医工互联

 找回密码
 注册[Register]

手机动态码快速登录

手机号快速登录

微信登录

微信扫一扫,快速登录

QQ登录

只需一步,快速开始

查看: 212|回复: 0
收起左侧

新冠肺炎CT辅助诊断文献实战-01

[复制链接]

  离线 

发表于 2022-10-23 02:13:10 | 显示全部楼层 |阅读模式 <
新冠肺炎CT辅助诊断文献实战-01

   https://www.nature.com/articles/s41551-020-00633-5
2020年11月,华中科技大学发表在Nature Biomedical Engineering
如果不想看文献的话可以看我写的文献导读(文献导读001)基于图像识别技术的cov19辅助诊断 - 知乎 (zhihu.com)
建议还是看一下,文章不难,我也只用了30分钟就大概的看了一遍
  由于上一个系列——深度学习与单细胞,我自己在写教程运行的过程中发现了一系列的问题导致目前暂时无法更新。然后最近老师推荐看来一篇关于医学图像识别辅助诊断新冠肺炎的文章,看了一下文章发现特别简单,于是就打算利用这篇文章继续我的深度学习系列教程。
   这个系列教程与深度学习和生物信息学的应用暂时无关,就是非常基础的pytorch实战和一系列的问题。
网络上其实可以搜索到很多相关的教程,只是我这个是针对医学数据而已哈
  数据下载

   iCTCF - CT images and clinical features for COVID-19 (biocuckoo.cn)
  一共有两套数据需要下载:


  • 已经标记好的label数据,用于构建第一个模型
  • cohort1 和cohort2
cohort1,cohort2的下载略微有点麻烦得一个个点,这里的话可以使用迅雷自带的批量下载功能
如果是ubuntu系统的话,可以使用如下代码批量下载:
  1. # 这里本身就是国内源,不要做奇奇怪怪的操作
  2. for i in {1..1521}  
  3. do  
  4.      wget -c http://ictcf.biocuckoo.cn/patient/CT/Patient%20{$i}.zip
  5. done
复制代码
背景

根据文章的神经网络框架,该框架分为三个部分其中,第一部分和第二部分是用图像数据训练,第三部分是用临床特征数据进行训练。
由于CT图像一次性会扫描出很多张图片,但是并不是每一张图片都是有用的。因此,医生在对CT图像做出判断决策时候通常需要以下几个步骤:

  • 从大量的CT图像中挑选出包涵信息的图片
  • 从这些图片中挑选出可能与疾病相关的图片
  • 从可能与疾病相关的图片中做出决策
上面的三个步骤又可以总结成为两个步骤,即:


  • 挑选图片
  • 做出决策
那么深度学习技术的辅助诊断也是从这两个部分进行的。那么,在完全不考虑经费和计算资源的情况下那么辅助诊断的实施就非常的简单。直接整合所有的CT图像数据和clinical outcomes 建模即可,但是,显然这种做法并不现实,因为CT图像通常一个样本产生的数据非常多,并且由于噪声越大,所需要的模型也越大,训练时长也越久。因此,文章的做法是将两个模型分开做,即构建一个挑选图片的模型,将挑选后的图片再次构建一个clinical outcomes 相关的模型,这个构思也是目前辅助诊断中最为常用的模型。此外,将步骤分开的话,由于单个任务的复杂度低,因此构建模型也相对简单。
正文

那么,我们先按照文献给的思路一步步的构建模型,先构建第一个模型
   文章使用的是kears,我用的是pytorch哈,由于文章中有些部分我感觉并不是很合理,因此我打算稍微改动一下
  首先初始化配置
  1. #### no mean but to fold the code
  2. import os
  3. import math
  4. import itertools
  5. import numpy as np
  6. import pandas as pd2
  7. import matplotlib.pyplot as plt
  8. from collections import Counter
  9. from matplotlib import cm
  10. import cv2
  11. import torch
  12. import torchvision
  13. import torch.utils.data as Data
  14. import torch.nn as nn
  15. import torch.nn.functional as F
  16. import pytorch_lightning as pl
  17. from sklearn.preprocessing import StandardScaler,MinMaxScaler
  18. from sklearn.model_selection import train_test_split
  19. import warnings
  20. from myexptorch import *
  21. %matplotlib inline
  22. %config InlineBackend.figure_format = 'svg'
  23. warnings.filterwarnings("ignore")
  24. # if GPU avaliabel, data to
  25. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  26. def deviceChange(x):
  27.     if (torch.cuda.is_available() & x == 'cuda'):
  28.         device = torch.device('cuda')
  29.     else:
  30.         device = torch.device('cpu')
  31.     return device
  32. # =============================================================================
  33. # the following code is common pytorch formula
  34. # =============================================================================
  35. # change expr to tensor
  36. def expToTorch(x):
  37.     return torch.from_numpy(np.array(x)).float()
  38. # change label to tensor
  39. def labelTotorch(y):
  40.     return torch.LongTensor(y)
  41. # data to data iter
  42. def makeDataiter(x,y,batch_size,shuffle=True):
  43.     return Data.DataLoader(Data.TensorDataset(x, y), batch_size, shuffle=shuffle)
  44. # scale
  45. def DataFrame_normalize(my_data,std_method):
  46.     if std_method == 'minmax':
  47.         method = MinMaxScaler()
  48.     if std_method == 'std':
  49.         method = StandardScaler()
  50.     my_data = pd.DataFrame(method.fit(my_data).transform(
  51.         my_data),columns=my_data.columns,index=my_data.index)
  52.     return my_data
  53. def toOneHot(ylabels,n_class):
  54.     onehot = torch.zeros(ylabels.shape[0],n_class)
  55.     index = torch.LongTensor(ylabels).view(-1,1)
  56.     onehot.scatter_(dim=1, index=index, value=1)
  57.     return onehot
  58. def toLabel(ylabels):
  59.     return torch.topk(ylabels, 1)[1].squeeze(1)
  60. # show tenor to picture
  61. def show_tensor_pictures(x, y):
  62.     plt.figure(figsize=(12,12))
  63.     fig, axes = plt.subplots(1, len(x))
  64.     for ax, image, label in zip(axes, x, y):
  65.         ax.imshow(image.view((28, 28)).numpy())
  66.         ax.set_title(label)
  67.         ax.axes.get_xaxis().set_visible(False)
  68.         ax.axes.get_yaxis().set_visible(False)
  69.     plt.show()
  70. def show_batch(imgs):
  71.     grid = torchvision.utils.make_grid(imgs,nrow=5)
  72.     plt.imshow(grid.numpy().transpose((1, 2, 0)))
  73.     plt.title('Batch from dataloader')
复制代码
导入数据

pytorch的数据导入方法非常的方便,还会自动添加label,导入的数据仅需要满足如下要求:


  • 创建一个用于储存数据的文件夹 (totallabeldata)
  • 在该文件夹下将同label的数据存放在相同的文件夹下,并命名
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-x33KBSOd-161********44)(https://tva1.sinaimg.cn/large/008eGmZEly1gnylnfyvh2j310r06ggm9.jpg)]
然后使用如下代码进行数据导入
  1. # load images to torch
  2. # plz put same class images to dif-folder
  3. train = torchvision.datasets.ImageFolder('../totallabeldata/',
  4.                                             transform=torchvision.transforms.Compose([
  5.                                                 torchvision.transforms.Grayscale(num_output_channels=1),
  6.                                                 torchvision.transforms.Scale(256),
  7.                                                 torchvision.transforms.CenterCrop(200),
  8.                                                 torchvision.transforms.ToTensor()])
  9.                                             )
  10. train_loader = torch.utils.data.DataLoader(train, batch_size=128,shuffle=True,num_workers =24)
  11. len(train_loader.dataset.targets)
  12. train_loader.dataset.class_to_idx
  13. """
  14. 在train_loader.dataset.class_to_idx的对象中记录了每个类别对应的class number
  15. """
  16. # [out] :
  17. # 19685
  18. # {'NiCT': 0, 'nCT': 1, 'pCT': 2}
复制代码
然后这里的话就面临着一个问题,就是没办法进行data的分割,总所周知,我们在进行建模的时候,通常要将数据分割成3份:training, validation, test,通常情况下比例为8:1:1。 因此要对数据进行如下操作
  1. # 目前我没有找到更好的方法了,如果有知道的朋友请留个言呗
  2. # Create empty list
  3. XRawData = []
  4. YRawData = []
  5. # from iteration extract batch
  6. for i,(x,y) in enumerate(train_loader):
  7.     XRawData.append(x)
  8.     YRawData.append(y)
  9. # concat all batch
  10. XRawData = torch.cat(XRawData)
  11. YRawData = torch.cat(YRawData)
  12. randoms = 42
  13. batch_size =150
  14. # Step.3 data split to train - test
  15. x_train, x_test, y_train, y_test = train_test_split(
  16.     XRawData, YRawData, stratify = YRawData,
  17.     test_size=0.1, random_state=randoms)
  18. # Step.4 train data split validation data
  19. x_train, x_val, y_train, y_val = train_test_split(
  20.     x_train, y_train, stratify = y_train,
  21.     test_size=0.1, random_state=randoms)
  22. # Step.6 transform to iterate object
  23. train_iter = makeDataiter(x_train, y_train, batch_size=batch_size,shuffle=True)
  24. val_iter = makeDataiter(x_val, y_val, batch_size=batch_size,shuffle=True)
  25. # visualization
  26. for i, (batch_x, batch_y) in enumerate(train_iter):
  27.     if(i<1):
  28.         show_batch(batch_x)
  29.         plt.axis('off')
  30.         plt.show()
  31. # class count
  32. Counter(YRawData.numpy())
  33. # [out]:
  34. # Counter({2: 4001, 1: 9979, 0: 5705})
复制代码
构建模型(按文章)

  1. class VGG_Simple(pl.LightningModule):
  2.     train_epoch_loss = []
  3.     train_epoch_acc = []
  4.     train_epoch_aucroc = []
  5.     val_epoch_loss = []
  6.     val_epoch_acc = []
  7.     val_epoch_aucroc = []
  8.     test_predict = []
  9.     test_sample_label = []
  10.     test_decoder = []
  11.     test_digitcaps = []
  12.     test_matrix = []
  13.     test_conv = []
  14.     test_primary = []
  15.     test_decoder = []
  16.     def __init__(self,myloss = None):
  17.         super(VGG_Simple, self).__init__()
  18.         self.myloss = myloss
  19.         self.cov = nn.Sequential(
  20.             nn.Conv2d(in_channels=1,out_channels=64,kernel_size=3,stride=1),
  21.             nn.ReLU(inplace=True),
  22.             nn.Conv2d(in_channels=64,out_channels=64,kernel_size=3,stride=1),
  23.             nn.ReLU(inplace=True),
  24.             nn.MaxPool2d((2, 2)),
  25.             nn.Conv2d(in_channels=64,out_channels=32,kernel_size=3,stride=1),
  26.             nn.ReLU(inplace=True),
  27.             nn.Conv2d(in_channels=32,out_channels=32,kernel_size=3,stride=1),
  28.             nn.ReLU(inplace=True),
  29.             nn.MaxPool2d((2, 2)),   
  30.             nn.Conv2d(in_channels=32,out_channels=16,kernel_size=3,stride=1),
  31.             nn.ReLU(inplace=True),
  32.             nn.Conv2d(in_channels=16,out_channels=16,kernel_size=3,stride=1),
  33.             nn.ReLU(inplace=True),
  34.             nn.MaxPool2d((2, 2))      
  35.         )
  36.         self.dnn = nn.Sequential(
  37.             nn.Linear(7056, 64),
  38.             nn.ReLU(inplace=True),
  39.             nn.Dropout(0.5),
  40.             nn.Linear(64, 32),
  41.             nn.ReLU(inplace=True),
  42.             nn.Dropout(0.5),
  43.             nn.Linear(32, 3)
  44.         )
  45.     def forward(self, x):
  46.         # x.size(0)即为batch_size
  47.         in_size = x.size(0)
  48.         out = self.cov(x)
  49.         out = out.view(in_size, -1)
  50.         out = self.dnn(out)
  51.         return out
  52.     # 3. 定义优化器
  53.     def configure_optimizers(self):
  54.         optimizer = torch.optim.Adam(self.parameters(), lr=0.001,weight_decay=0.05)
  55.         return optimizer
  56.     # 4. 训练loop
  57.     def training_step(self, train_batch, batch_ix):
  58.         vector_sample, sample_label = train_batch
  59.         probs = self.forward(vector_sample)
  60.         loss  = self.myloss(probs, sample_label)
  61.         acc = pl.metrics.functional.accuracy(probs, sample_label)
  62.         mylogdict = {'loss': loss, 'log': {'train_loss': loss, 'train_acc': acc}}
  63.         return mylogdict
  64.     def validation_step(self, validation_batch, batch_ix):
  65.         vector_sample, sample_label = validation_batch
  66.         probs= self.forward(vector_sample)
  67.         loss  = self.myloss(probs, sample_label)
  68.         val_acc = pl.metrics.functional.accuracy(probs, sample_label)
  69.         self.log_dict({'val_loss': loss, 'val_acc': val_acc})
  70.         mylogdict = {'log': {'val_loss': loss, 'val_acc': val_acc}}
  71.         return mylogdict
  72.     def test_step(self, test_batch, batch_ix):
  73.         vector_sample, sample_label = test_batch
  74.         probs= self.forward(vector_sample)
  75.         self.test_predict.append(probs.cpu())
  76.         self.test_sample_label.append(sample_label.cpu())
  77.         return {'test': 0}
  78.     def training_epoch_end(self, output):
  79.         train_loss = sum([out['log']['train_loss'].item() for out in output]) / len(output)
  80.         self.train_epoch_loss.append(train_loss)
  81.         train_acc = sum([out['log']['train_acc'].item() for out in output]) / len(output)
  82.         self.train_epoch_acc.append(train_acc)
  83.         print(train_acc)
  84.         return train_loss
  85.     def validation_epoch_end(self, output):
  86.         val_loss = sum([out['log']['val_loss'].item() for out in output]) / len(output)
  87.         self.val_epoch_loss.append(val_loss)
  88.         val_acc = sum([out['log']['val_acc'].item() for out in output]) / len(output)
  89.         self.val_epoch_acc.append(val_acc)
  90.         print('mean_val_loss: ', val_loss, '\t', 'mean_val_acc: ', val_acc)
  91.         return val_loss
复制代码
然后开始训练模型

  1. # training
  2. traning = True
  3. # set seed
  4. pl.utilities.seed.seed_everything(seed=42)
  5. # Defined function
  6. myloss = nn.CrossEntropyLoss()
  7. model = VGG_Simple(myloss=myloss)
  8. if traning:
  9.     # output file
  10.     OUTPUT_DIR = './lightning_logs'
  11.     tb_logger = pl.loggers.TensorBoardLogger(save_dir='./',
  12.                                              name=f'check_point')
  13.     # set check point to choose best model
  14.     checkpoint_callback = pl.callbacks.ModelCheckpoint(
  15.         dirpath=tb_logger.log_dir,
  16.         filename='{epoch}-{val_acc:.5f}-{val_loss:.5f}',
  17.         save_top_k = 15, #保留最优的15
  18.         monitor='val_acc', # check acc
  19.         mode='auto'
  20.     )
  21.     # train loop
  22.     trainer = pl.Trainer(gpus=-1,
  23.                          callbacks=[checkpoint_callback],
  24.                          max_epochs = 100,
  25.                          gradient_clip_val=0,
  26.                          auto_lr_find=False)# 开始训练
  27.     trainer.fit(model, train_iter, val_iter)
  28.     from pathlib import Path
  29.     out = Path(tb_logger.log_dir)
  30.     print(out)
  31.     [ckpt.stem for ckpt in out.iterdir()]
复制代码
但是,如果你这样训练那么会出现如下的问题:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-p21vXGnK-161********46)(https://tva1.sinaimg.cn/large/008eGmZEly1gnylnhvxatj30u00t642p.jpg)]
不管如何训练,loss不会改变,accuracy也不会改变,那么这是什么原因呢?常见的有如下两种原因:

  • 模型复杂度不够
  • 模型一开始就过拟合
当然,在这次训练中,这两种情况可以完全排除,理由很简单,我用的模型和文章一模一样
那么,我们就要从loss函数或者数据本生出发,这里我们要注意一下小细节,在之前我们进行class counter的时候的输出
  1. # class count
  2. Counter(YRawData.numpy())
  3. # [out]:
  4. # Counter({2: 4001, 1: 9979, 0: 5705})
复制代码
发现,样本量的差距着实有点大,但是说大把也不是很大,可能在运行三五百次的epoch应该是完全没有什么太大的问题
但是为了节约时间我们要对loss进行加权
  1. class_weight = torch.FloatTensor([2,5,2]).cuda()
  2. myloss = nn.CrossEntropyLoss(weight=class_weight)
复制代码
最终修改代码如下
  1. # training traning = True# set seed pl.utilities.seed.seed_everything(seed=42)# Defined function class_weight = torch.FloatTensor([2,5,2]).cuda()
  2. myloss = nn.CrossEntropyLoss(weight=class_weight)
  3. model = VGG_Simple(myloss=myloss)if traning:    # output file    OUTPUT_DIR = './lightning_logs'    tb_logger = pl.loggers.TensorBoardLogger(save_dir='./',                                             name=f'check_point')    # set check point to choose best model     checkpoint_callback = pl.callbacks.ModelCheckpoint(        dirpath=tb_logger.log_dir,        filename='{epoch}-{val_acc:.5f}-{val_loss:.5f}',        save_top_k = 15, #保留最优的15        monitor='val_acc', # check acc         mode='auto'    )    # train loop     trainer = pl.Trainer(gpus=-1,                          callbacks=[checkpoint_callback],                         max_epochs = 100,                         gradient_clip_val=0,                         auto_lr_find=False)# 开始训练    trainer.fit(model, train_iter, val_iter)    from pathlib import Path    out = Path(tb_logger.log_dir)    print(out)    [ckpt.stem for ckpt in out.iterdir()]
复制代码
输出结果如下
031502ojzoe3jjnvxnebiu.jpeg

我们可以看到,在第8次epoch就已经达到92%的准确率了,但是,还有一个问题,这个准确率还是太低了。
考虑到,CT数据的图像是灰度图,且图像的大致形状也差不多,感觉各个医院拍出来的图也差不了多少(自己认为的),所以我认为,模型中的dropout层是没有用的
   这里是我觉得文章不妥的第一个点
  1. model.train_epoch_acc[40:44]
  2. # [out]
  3. # [ 0.8633503256557143,
  4. # 0.8609685636012354,
  5. # 0.862724444576513,
  6. # 0.8625120386899074]
复制代码
我们发现train的过程准确率也才86%
在删除dropout后,准确率得到显著的提升
031503ou45liuawxiw6nz6.jpeg

  1. model.train_epoch_acc[40:44]
  2. # [out]
  3. #[0.9764486033225728,
  4. # 0.9797252973663473,
  5. # 0.979190********92,
  6. # 0.9742679161446117]
复制代码
验证模型

  1. # load model parameter
  2. version = 'version_14' # choose best model path
  3. path = os.path.join(os.getcwd(),'check_point',version)
  4. ckpt = str(os.listdir(path)[2]) # choose best model
  5. print(ckpt)
  6. # get best model parameter
  7. ckpt = os.path.join(path,ckpt)
  8. model = VGG_Simple(myloss=myloss)
  9. trainer = pl.Trainer(resume_from_checkpoint=ckpt, gpus=-1)
  10. # create test_iter
  11. test_iter = makeDataiter(x_test, y_test, batch_size=batch_size,shuffle=False)
  12. trainer.test(model, test_iter)
  13. # get test accuracy
  14. predict = torch.cat(model.test_predict)
  15. y = torch.cat(model.test_sample_label)
  16. predict = toLabel(predict).cpu().numpy()
  17. y = y.cpu().numpy()
  18. sum( y == predict) / len(y)
  19. # [out]:
  20. # 0.9862874555611986
复制代码
测试集准确率98%,然后这里要注意一下,文章用的是AUC去评估,对于像这种的3分类任务,使用AUC评估很不直观,应该使用混淆矩阵或者accuarcy进行评估
   这里是我觉得文章不妥的第二个点
  1. Best Regards,  
  2. Yuan.SH;
  3. School of Basic Medical Sciences,  
  4. Fujian Medical University,  
  5. Fuzhou, Fujian, China.  
  6. please contact with me via the following ways:  
  7. (a) e-mail :yuansh3354@163.com  
复制代码
来源:https://blog.csdn.net/qq_40966210/article/details/114021421
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
回复

使用道具 举报

提醒:禁止复制他人回复等『恶意灌水』行为,违者重罚!
您需要登录后才可以回帖 登录 | 注册[Register] 手机动态码快速登录 微信登录

本版积分规则

发布主题 快速回复 收藏帖子 返回列表 客服中心 搜索
简体中文 繁體中文 English 한국 사람 日本語 Deutsch русский بالعربية TÜRKÇE português คนไทย french

QQ|RSS订阅|小黑屋|处罚记录|手机版|联系我们|Archiver|医工互联 |粤ICP备2021178090号 |网站地图

GMT+8, 2024-11-23 00:59 , Processed in 0.303884 second(s), 66 queries .

Powered by Discuz!

Copyright © 2001-2023, Discuz! Team.

快速回复 返回顶部 返回列表