医工互联

 找回密码
 注册[Register]

手机动态码快速登录

手机号快速登录

微信登录

微信扫一扫,快速登录

QQ登录

只需一步,快速开始

查看: 178|回复: 0
收起左侧

新冠肺炎CT辅助诊断文献实战-02

[复制链接]

  在线 

发表于 2022-10-23 02:13:56 | 显示全部楼层 |阅读模式 <
接着上一讲的内容,上一讲已经把第一个模型训练完毕。
该模型的主要作用就是将CT图像分为三种类型

  • NiCT:无用CT图片;
  • pCT:潜在的可能与cov19相关的图片;
  • nCT:与cov19无关的图片;
然后我们现在就要将这个模型运用在cohort1和cohort2的模型中
  1. # load model
  2. # the model and initial setting save in VGG_Simple.py
  3. from VGG_Simple import *
  4. %matplotlib inline
  5. %config InlineBackend.figure_format = 'svg'
  6. # load model parameter
  7. version = 'version_14'  # choose best model path
  8. path = os.path.join(os.getcwd(), 'check_point', version)
  9. ckpt = str(os.listdir(path)[2])  # choose best model
  10. print(ckpt)
  11. # Create model
  12. ckpt = os.path.join(path, ckpt)
  13. class_weight = torch.FloatTensor([2, 5, 2]).cuda()
  14. myloss = nn.CrossEntropyLoss(weight=class_weight)
  15. model = VGG_Simple(myloss=myloss)
  16. epoch=30-val_acc=0.99097-val_loss=0.03236.ckpt
复制代码
这里要注意一下,因为我们是将数据批量的下载到patientCT这个文件夹中,所以首先我们需要将 cohort1 和 cohort2 的数据分离开来。
cohort1 patient 1-1170
chohrt2 patient 1171 - 1521
  1. batch_size = 64
  2. # get patient CT-file path
  3. cohort1 = ['Patient {}'.format(i) for i in range(0,1171)]
  4. cohort2 = ['Patient {}'.format(i) for i in range(1171,1522)]
  5. patient_list = os.listdir('../patientCT/')
  6. # attention! not all samples have CT-file
  7. cohort1 = set(cohort1) & set(patient_list)
  8. cohort2 = set(cohort2) & set(patient_list)
  9. cohort1 = [os.path.join('..','patientCT',cohort,cohort) for cohort in cohort1]
  10. cohort2 = [os.path.join('..','patientCT',cohort,cohort) for cohort in cohort2]
  11. cohort1 = sorted(cohort1)
  12. cohort2 = sorted(cohort2)
  13. # make an empty dataframe
  14. result = pd.DataFrame(columns=['NiCT', 'nCT', 'pCT', 'patientId', 'label'])
复制代码
导入模型(一)

  1. for i in range(len(cohort1)):
  2.     print(i)
  3.     cohort_sample_path = cohort1[i] # get patient CT-file path
  4.     cohort_sample = cohort_sample_path.split('\\')[3] # get patientId
  5.     # import patient one by one
  6.     cohort = torchvision.datasets.ImageFolder(cohort_sample_path,
  7.                                             transform=torchvision.transforms.Compose([
  8.                                                 torchvision.transforms.Grayscale(num_output_channels=1),
  9.                                                 torchvision.transforms.Scale(256),
  10.                                                 torchvision.transforms.CenterCrop(200),
  11.                                                 torchvision.transforms.ToTensor()])
  12.                                             )
  13.     # load all ct file
  14.     cohort = torch.utils.data.DataLoader(cohort, batch_size=batch_size,shuffle=False,num_workers =24)
  15.     XRawData = []
  16.     YRawData = []
  17.     for i,(x,y) in enumerate(cohort):
  18.         XRawData.append(x)
  19.         YRawData.append(y)
  20.     XRawData = torch.cat(XRawData)
  21.     YRawData = torch.cat(YRawData)
  22.     cohort_iter = makeDataiter(XRawData, YRawData, batch_size=batch_size,shuffle=False)
  23.     model.test_predict = []
  24.     model.test_sample_label = []
  25.     model.test_decoder = []
  26.     model.test_matrix = []
  27.     model.test_conv = []
  28.     model.test_primary = []
  29.     model.test_digitcaps = []
  30.     trainer = pl.Trainer(resume_from_checkpoint=ckpt, gpus=-1)
  31.     # predict patient ct
  32.     trainer.test(model, cohort_iter)
  33.     predict = torch.cat(model.test_predict)
  34.     predict_label = toLabel(predict).cpu().numpy()
  35.     # get each picture id
  36.     sample = [cohort_sample +' ' + cohort.dataset.samples[i][0].split(
  37.         '\\')[5] for i in range(len(cohort.dataset.samples))]
  38.     predict = pd.DataFrame(predict.numpy())
  39.     predict.columns = ['NiCT','nCT',  'pCT']
  40.     predict.index = sample
  41.     predict['patientId'] = cohort_sample
  42.     predict['label'] = predict_label
  43.     predict['label'] = predict['label'].replace(
  44.        [0,1,2],  # 构建 label
  45.        ['NiCT','nCT',  'pCT']).astype(str)
  46.     # get top 10 pct
  47.     predict = predict.sort_values(by='pCT',ascending = False)
  48.     predict = predict.iloc[0:10,]
  49.     result = pd.concat([result,predict])
复制代码
分别对两套数据进行预测,并将结果保存输出

  1. # result2 = pd.DataFrame(columns=['NiCT', 'nCT', 'pCT', 'patientId', 'label'])
  2. #
  3. # for i in range(len(cohort2)):
  4. #     print(i)
  5. #     cohort_sample_path = cohort2[i]
  6. #     cohort_sample = cohort_sample_path.split('\\')[3]
  7. #     # 依次导入数据
  8. #     cohort = torchvision.datasets.ImageFolder(cohort_sample_path,
  9. #                                             transform=torchvision.transforms.Compose([
  10. #                                                 torchvision.transforms.Grayscale(num_output_channels=1),
  11. #                                                 torchvision.transforms.Scale(256),
  12. #                                                 torchvision.transforms.CenterCrop(200),
  13. #                                                 torchvision.transforms.ToTensor()])
  14. #                                             )
  15. #     cohort = torch.utils.data.DataLoader(cohort, batch_size=batch_size,shuffle=False,num_workers =24)
  16. #
  17. #     # 数据类型转换
  18. #     XRawData = []
  19. #     YRawData = []
  20. #     for i,(x,y) in enumerate(cohort):
  21. #         XRawData.append(x)
  22. #         YRawData.append(y)
  23. #         
  24. #     XRawData = torch.cat(XRawData)
  25. #     YRawData = torch.cat(YRawData)
  26. #     
  27. #     cohort_iter = makeDataiter(XRawData, YRawData, batch_size=batch_size,shuffle=False)
  28. #
  29. #     model.test_predict = []
  30. #     model.test_sample_label = []
  31. #     model.test_decoder = []
  32. #     model.test_matrix = []
  33. #     model.test_conv = []
  34. #     model.test_primary = []
  35. #     model.test_digitcaps = []
  36. #     
  37. #     trainer = pl.Trainer(resume_from_checkpoint=ckpt, gpus=-1)
  38. #     
  39. #     # predict patient ct
  40. #     trainer.test(model, cohort_iter)
  41. #     predict = torch.cat(model.test_predict)
  42. #     predict_label = toLabel(predict).cpu().numpy()
  43. #     
  44. #     # get each picture id
  45. #     sample = [cohort_sample +' ' + cohort.dataset.samples[i][0].split(
  46. #         '\\')[5] for i in range(len(cohort.dataset.samples))]
  47. #     predict = pd.DataFrame(predict.numpy())
  48. #     predict.columns = ['NiCT','nCT',  'pCT']
  49. #     predict.index = sample
  50. #     predict['patientId'] = cohort_sample
  51. #     predict['label'] = predict_label
  52. #     predict['label'] = predict['label'].replace(
  53. #        [0,1,2],  # 构建 label
  54. #        ['NiCT','nCT',  'pCT']).astype(str)
  55. #     
  56. #     # get top 10 pct
  57. #     
  58. #     predict = predict.sort_values(by='pCT',ascending = False)
  59. #     predict = predict.iloc[0:10,]
  60. #
  61. #     result2 = pd.concat([result2,predict])
  62. # save result
  63. # result.to_csv('cohort1_pCT_detection.csv')
  64. # result2.to_csv('cohort2_pCT_detection.csv')
复制代码
导入数据,进行模型可靠性验证

  1. cohort1_pCT = pd.read_csv('cohort1_pCT_detection.csv',index_col=0)
  2. cohort2_pCT = pd.read_csv('cohort2_pCT_detection.csv',index_col=0)
  3. meta = pd.read_csv('metadata.csv',index_col=0)
复制代码
这里我们要稍微对meta稍微处理一下,因为并不是每一个样本都有CT
  1. cohort1 = set(cohort1_pCT.index.map(lambda x: x.split(' IMG')[0]))
  2. cohort2 = set(cohort2_pCT.index.map(lambda x: x.split(' IMG')[0]))
  3. patientIds = list(cohort1)+list(cohort2)
  4. meta = meta.loc[patientIds]
  5. # cohort1 control pictures
  6. cohort1_predict_control = set(cohort1_pCT[cohort1_pCT['label'] != 'pCT'].index.map(
  7.     lambda x: x.split(' IMG')[0]))
  8. # cohort2 control pictures
  9. cohort2_predict_control = set(cohort2_pCT[cohort2_pCT['label'] != 'pCT'].index.map(
  10.     lambda x: x.split(' IMG')[0]))
  11. # cohort1 pCT pictures
  12. cohort1_predict_pCT = set(cohort1_pCT[cohort1_pCT['label'] == 'pCT'].index.map(
  13.     lambda x: x.split(' IMG')[0]))
  14. # cohort2 pCT pictures
  15. cohort2_predict_pCT = set(cohort2_pCT[cohort2_pCT['label'] == 'pCT'].index.map(
  16.     lambda x: x.split(' IMG')[0]))
  17. cohort1_ambiguous = cohort1_predict_control & cohort1_predict_pCT
  18. cohort2_ambiguous = cohort2_predict_control & cohort2_predict_pCT
  19. cohort1_control = cohort1_predict_control - cohort1_ambiguous
  20. cohort2_control = cohort2_predict_control - cohort2_ambiguous
  21. cohort1_pCT = cohort1_predict_pCT - cohort1_ambiguous
  22. cohort2_pCT = cohort2_predict_pCT - cohort2_ambiguous
  23. true_control = set(meta[meta['Morbidity outcome'] == 'Control'].index)
复制代码
这里我们注意到,原始数据中控制组分为两部分,其中一部分为Community-acquired pneumonia即普通肺炎,另一部分为Control,我们就暂且当作完全健康的CT
那么根据文章所给出的定义,按道理来说pCT应该在完全健康组中不存在
  1. Counter(meta['Morbidity outcome'])
  2. Counter({'Regular': 542,
  3.          'Severe': 170,
  4.          'Suspected': 259,
  5.          'Control': 96,
  6.          'Control (Community-acquired pneumonia)': 218,
  7.          'Mild': 22,
  8.          'Critically ill': 35})
  9. len(set(meta[meta['Morbidity outcome'] == 'Control'].index) & cohort1_pCT)
  10. len(set(meta[meta['Morbidity outcome'] == 'Control'].index) & cohort1_ambiguous)
  11. len(set(meta[meta['Morbidity outcome'] == 'Control'].index) & cohort1_control)
  12. len(set(meta[meta['Morbidity outcome'] == 'Control'].index) & cohort2_pCT)
  13. len(set(meta[meta['Morbidity outcome'] == 'Control'].index) & cohort2_ambiguous)
  14. len(set(meta[meta['Morbidity outcome'] == 'Control'].index) & cohort2_control)
复制代码
  11
15
70
0
0
0
  在健康对照中根据输出的结果,有大约10%左右的假阳性
另外,我们还得关心一下阳性率的情况
  1. disease = list(
  2.     set(meta[(meta['Morbidity outcome'] != 'Control') & (
  3.         meta['Morbidity outcome'] != 'Control (Community-acquired pneumonia)') & (
  4.         meta['Morbidity outcome'] != 'Suspected')
  5.     ].index))
  6. len(set(disease) & cohort1_control)
  7. len(set(disease) & cohort2_control)
复制代码
  212 102
211+102 / 769
  从上面的输出结果来看,阳性率大约为60%左右,还是可以接受的
因此,可以开始训练第二个模型
  1. Best Regards,  
  2. Yuan.SH  
  3. ---------------------------------------
  4. School of Basic Medical Sciences,  
  5. Fujian Medical University,  
  6. Fuzhou, Fujian, China.  
  7. please contact with me via the following ways:  
  8. (a) e-mail :yuansh3354@163.com
复制代码
来源:https://blog.csdn.net/qq_40966210/article/details/114021462
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
回复

使用道具 举报

提醒:禁止复制他人回复等『恶意灌水』行为,违者重罚!
您需要登录后才可以回帖 登录 | 注册[Register] 手机动态码快速登录 微信登录

本版积分规则

发布主题 快速回复 收藏帖子 返回列表 客服中心 搜索
简体中文 繁體中文 English 한국 사람 日本語 Deutsch русский بالعربية TÜRKÇE português คนไทย french

QQ|RSS订阅|小黑屋|处罚记录|手机版|联系我们|Archiver|医工互联 |粤ICP备2021178090号 |网站地图

GMT+8, 2024-9-20 06:17 , Processed in 0.288797 second(s), 61 queries .

Powered by Discuz!

Copyright © 2001-2023, Discuz! Team.

快速回复 返回顶部 返回列表