医工互联

 找回密码
 注册[Register]

手机动态码快速登录

手机号快速登录

微信登录

微信扫一扫,快速登录

QQ登录

只需一步,快速开始

查看: 144|回复: 0
收起左侧

【手把手教你】搭建神经网络(CT扫描3D图像的分类)

[复制链接]

  离线 

发表于 2022-9-23 19:25:56 | 显示全部楼层 |阅读模式 <
203016ca0cc116mar141o4.jpeg

大家好,我是羽峰,今天要和大家分享的是一个基于tensorflow的CT扫描3D图像的分类。文章会把整个代码进行分割讲解,完整看完,相信你一定会有所收获。
欢迎关注“羽峰码字
目录
1. 项目简介
2. API准备
3. 数据集准备
3.1 下载数据
3.2 数据预处理
3.3 建立训练和验证数据集
3.4 数据增强
4 模型构建
4.1 定义3D卷积神经网络
4.2 训练模型
4.3 验证模型
5. 使用模型进行预测
References

1. 项目简介

此示例将显示构建3D卷积神经网络(CNN)以预测计算机断层扫描(CT)扫描中病毒性肺炎的存在所需的步骤。 2D CNN通常用于处理RGB图像(3通道)。 3D CNN只是3D等效项:它以3D体积或2D帧序列(例如CT扫描中的切片)为输入,因此3D CNN是学习体积数据表示的强大模型。
2. API准备

  1. import os
  2. import zipfile
  3. import numpy as np
  4. import tensorflow as tf
  5. from tensorflow import keras
  6. from tensorflow.keras import layers
复制代码
3. 数据集准备

3.1 下载数据

下载MosMedData:具有COVID-19相关发现的胸部CT扫描
在此示例中,我们使用了MosMedData: Chest CT Scans with COVID-19 Related Findings。 该数据集包含具有COVID-19相关发现以及没有发现的肺部CT扫描。
我们将使用CT扫描的相关放射学发现作为标记,以建立分类器来预测病毒性肺炎的存在。 因此,该任务是二进制分类问题。
  1. # Download url of normal CT scans.
  2. url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip"
  3. filename = os.path.join(os.getcwd(), "CT-0.zip")
  4. keras.utils.get_file(filename, url)
  5. # Download url of abnormal CT scans.
  6. url = "https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip"
  7. filename = os.path.join(os.getcwd(), "CT-23.zip")
  8. keras.utils.get_file(filename, url)
  9. # Make a directory to store the data.
  10. os.makedirs("MosMedData")
  11. # Unzip data in the newly created directory.
  12. with zipfile.ZipFile("CT-0.zip", "r") as z_fp:
  13.     z_fp.extractall("./MosMedData/")
  14. with zipfile.ZipFile("CT-23.zip", "r") as z_fp:
  15.     z_fp.extractall("./MosMedData/")
复制代码
3.2 数据预处理

这些文件以Nifti格式提供,扩展名为.nii。 要读取扫描结果,我们使用nibabel软件包。 您可以通过pip install nibabel安装软件包。 CT扫描以Hounsfield单位(HU)存储原始体素强度。 在此数据集中,它们的范围从-1024到2000以上。 高于400的骨骼具有不同的放射强度,因此将其用作更高的界限。 通常将-1000到400之间的阈值用于归一化CT扫描。
要处理数据,我们执行以下操作:


  • 我们首先将体积旋转90度,因此方向是固定的
  • 我们将HU值缩放为介于0和1之间。
  • 我们调整宽度,高度和深度的大小。
在这里,我们定义了几个辅助函数来处理数据。 在构建训练和验证数据集时将使用这些功能。
  1. import nibabel as nib
  2. from scipy import ndimage
  3. def read_nifti_file(filepath):
  4.     """Read and load volume"""
  5.     # Read file
  6.     scan = nib.load(filepath)
  7.     # Get raw data
  8.     scan = scan.get_fdata()
  9.     return scan
  10. def normalize(volume):
  11.     """Normalize the volume"""
  12.     min = -1000
  13.     max = 400
  14.     volume[volume < min] = min
  15.     volume[volume > max] = max
  16.     volume = (volume - min) / (max - min)
  17.     volume = volume.astype("float32")
  18.     return volume
  19. def resize_volume(img):
  20.     """Resize across z-axis"""
  21.     # Set the desired depth
  22.     desired_depth = 64
  23.     desired_width = 128
  24.     desired_height = 128
  25.     # Get current depth
  26.     current_depth = img.shape[-1]
  27.     current_width = img.shape[0]
  28.     current_height = img.shape[1]
  29.     # Compute depth factor
  30.     depth = current_depth / desired_depth
  31.     width = current_width / desired_width
  32.     height = current_height / desired_height
  33.     depth_factor = 1 / depth
  34.     width_factor = 1 / width
  35.     height_factor = 1 / height
  36.     # Rotate
  37.     img = ndimage.rotate(img, 90, reshape=False)
  38.     # Resize across z-axis
  39.     img = ndimage.zoom(img, (width_factor, height_factor, depth_factor), order=1)
  40.     return img
  41. def process_scan(path):
  42.     """Read and resize volume"""
  43.     # Read scan
  44.     volume = read_nifti_file(path)
  45.     # Normalize
  46.     volume = normalize(volume)
  47.     # Resize width, height and depth
  48.     volume = resize_volume(volume)
  49.     return volume
复制代码
让我们从类目录中读取CT扫描的路径。
  1. # Folder "CT-0" consist of CT scans having normal lung tissue,
  2. # no CT-signs of viral pneumonia.
  3. normal_scan_paths = [
  4.     os.path.join(os.getcwd(), "MosMedData/CT-0", x)
  5.     for x in os.listdir("MosMedData/CT-0")
  6. ]
  7. # Folder "CT-23" consist of CT scans having several ground-glass opacifications,
  8. # involvement of lung parenchyma.
  9. abnormal_scan_paths = [
  10.     os.path.join(os.getcwd(), "MosMedData/CT-23", x)
  11.     for x in os.listdir("MosMedData/CT-23")
  12. ]
  13. print("CT scans with normal lung tissue: " + str(len(normal_scan_paths)))
  14. print("CT scans with abnormal lung tissue: " + str(len(abnormal_scan_paths)))
复制代码
3.3 建立训练和验证数据集

从类目录中读取扫描并分配标签。 对扫描进行下采样以具有128x128x64的形状。 将原始HU值重新缩放为0到1的范围。最后,将数据集拆分为训练和验证子集。
  1. # Read and process the scans.
  2. # Each scan is resized across height, width, and depth and rescaled.
  3. abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
  4. normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
  5. # For the CT scans having presence of viral pneumonia
  6. # assign 1, for the normal ones assign 0.
  7. abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
  8. normal_labels = np.array([0 for _ in range(len(normal_scans))])
  9. # Split data in the ratio 70-30 for training and validation.
  10. x_train = np.concatenate((abnormal_scans[:70], normal_scans[:70]), axis=0)
  11. y_train = np.concatenate((abnormal_labels[:70], normal_labels[:70]), axis=0)
  12. x_val = np.concatenate((abnormal_scans[70:], normal_scans[70:]), axis=0)
  13. y_val = np.concatenate((abnormal_labels[70:], normal_labels[70:]), axis=0)
  14. print(
  15.     "Number of samples in train and validation are %d and %d."
  16.     % (x_train.shape[0], x_val.shape[0])
  17. )
复制代码
3.4 数据增强

在训练过程中,CT扫描也可以通过以任意角度旋转来增强。 由于数据存储在形状(样本,高度,宽度,深度)的3级张量中,因此我们在轴4上添加了尺寸为1的尺寸,以便能够对数据执行3D卷积。 因此,新形状为(样本,高度,宽度,深度,1)。还有有各种各样的预处理和扩充技术请读者们查阅相关资料,此示例显示了一些简单的入门和增强技术。
  1. import random
  2. from scipy import ndimage
  3. @tf.function
  4. def rotate(volume):
  5.     """Rotate the volume by a few degrees"""
  6.     def scipy_rotate(volume):
  7.         # define some rotation angles
  8.         angles = [-20, -10, -5, 5, 10, 20]
  9.         # pick angles at random
  10.         angle = random.choice(angles)
  11.         # rotate volume
  12.         volume = ndimage.rotate(volume, angle, reshape=False)
  13.         volume[volume < 0] = 0
  14.         volume[volume > 1] = 1
  15.         return volume
  16.     augmented_volume = tf.numpy_function(scipy_rotate, [volume], tf.float32)
  17.     return augmented_volume
  18. def train_preprocessing(volume, label):
  19.     """Process training data by rotating and adding a channel."""
  20.     # Rotate volume
  21.     volume = rotate(volume)
  22.     volume = tf.expand_dims(volume, axis=3)
  23.     return volume, label
  24. def validation_preprocessing(volume, label):
  25.     """Process validation data by only adding a channel."""
  26.     volume = tf.expand_dims(volume, axis=3)
  27.     return volume, label
复制代码
在定义训练和验证数据加载器时,训练数据会通过和增强功能传递,该功能会随机旋转不同角度的体积。 请注意,训练和验证数据均已重新缩放为0到1之间的值。
  1. # Define data loaders.
  2. train_loader = tf.data.Dataset.from_tensor_slices((x_train, y_train))
  3. validation_loader = tf.data.Dataset.from_tensor_slices((x_val, y_val))
  4. batch_size = 2
  5. # Augment the on the fly during training.
  6. train_dataset = (
  7.     train_loader.shuffle(len(x_train))
  8.     .map(train_preprocessing)
  9.     .batch(batch_size)
  10.     .prefetch(2)
  11. )
  12. # Only rescale.
  13. validation_dataset = (
  14.     validation_loader.shuffle(len(x_val))
  15.     .map(validation_preprocessing)
  16.     .batch(batch_size)
  17.     .prefetch(2)
  18. )
复制代码
可视化CT增强图像
  1. import matplotlib.pyplot as plt
  2. data = train_dataset.take(1)
  3. images, labels = list(data)[0]
  4. images = images.numpy()
  5. image = images[0]
  6. print("Dimension of the CT scan is:", image.shape)
  7. plt.imshow(np.squeeze(image[:, :, 30]), cmap="gray")
复制代码
203017gbqbh25eu2qoedpp.png

由于CT扫描有很多切片,用序列形式进行排列显示
  1. def plot_slices(num_rows, num_columns, width, height, data):
  2.     """Plot a montage of 20 CT slices"""
  3.     data = np.rot90(np.array(data))
  4.     data = np.transpose(data)
  5.     data = np.reshape(data, (num_rows, num_columns, width, height))
  6.     rows_data, columns_data = data.shape[0], data.shape[1]
  7.     heights = [slc[0].shape[0] for slc in data]
  8.     widths = [slc.shape[1] for slc in data[0]]
  9.     fig_width = 12.0
  10.     fig_height = fig_width * sum(heights) / sum(widths)
  11.     f, axarr = plt.subplots(
  12.         rows_data,
  13.         columns_data,
  14.         figsize=(fig_width, fig_height),
  15.         gridspec_kw={"height_ratios": heights},
  16.     )
  17.     for i in range(rows_data):
  18.         for j in range(columns_data):
  19.             axarr[i, j].imshow(data[i][j], cmap="gray")
  20.             axarr[i, j].axis("off")
  21.     plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
  22.     plt.show()
  23. # Visualize montage of slices.
  24. # 4 rows and 10 columns for 100 slices of the CT scan.
  25. plot_slices(4, 10, 128, 128, image[:, :, :40])
复制代码
203017hnl1nkisnrn9zl37.png

4 模型构建

4.1 定义3D卷积神经网络

为了使模型更易于理解,我们将其构造为块。 本示例中使用的3D CNN的体系结构是基于本文的。
  1. def get_model(width=128, height=128, depth=64):
  2.     """Build a 3D convolutional neural network model."""
  3.     inputs = keras.Input((width, height, depth, 1))
  4.     x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(inputs)
  5.     x = layers.MaxPool3D(pool_size=2)(x)
  6.     x = layers.BatchNormalization()(x)
  7.     x = layers.Conv3D(filters=64, kernel_size=3, activation="relu")(x)
  8.     x = layers.MaxPool3D(pool_size=2)(x)
  9.     x = layers.BatchNormalization()(x)
  10.     x = layers.Conv3D(filters=128, kernel_size=3, activation="relu")(x)
  11.     x = layers.MaxPool3D(pool_size=2)(x)
  12.     x = layers.BatchNormalization()(x)
  13.     x = layers.Conv3D(filters=256, kernel_size=3, activation="relu")(x)
  14.     x = layers.MaxPool3D(pool_size=2)(x)
  15.     x = layers.BatchNormalization()(x)
  16.     x = layers.GlobalAveragePooling3D()(x)
  17.     x = layers.Dense(units=512, activation="relu")(x)
  18.     x = layers.Dropout(0.3)(x)
  19.     outputs = layers.Dense(units=1, activation="sigmoid")(x)
  20.     # Define the model.
  21.     model = keras.Model(inputs, outputs, name="3dcnn")
  22.     return model
  23. # Build model.
  24. model = get_model(width=128, height=128, depth=64)
  25. model.summary()
复制代码
4.2 训练模型

  1. # Compile model.
  2. initial_learning_rate = 0.0001
  3. lr_schedule = keras.optimizers.schedules.ExponentialDecay(
  4.     initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
  5. )
  6. model.compile(
  7.     loss="binary_crossentropy",
  8.     optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
  9.     metrics=["acc"],
  10. )
  11. # Define callbacks.
  12. checkpoint_cb = keras.callbacks.ModelCheckpoint(
  13.     "3d_image_classification.h5", save_best_only=True
  14. )
  15. early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)
  16. # Train the model, doing validation at the end of each epoch
  17. epochs = 100
  18. model.fit(
  19.     train_dataset,
  20.     validation_data=validation_dataset,
  21.     epochs=epochs,
  22.     shuffle=True,
  23.     verbose=2,
  24.     callbacks=[checkpoint_cb, early_stopping_cb],
  25. )
复制代码
203018f9okuuslblp66pfs.png

4.3 验证模型

在此绘制了训练和验证集的模型准确性和损失。 由于验证集是类平衡的,因此准确性提供了模型性能的公正表示。
  1. fig, ax = plt.subplots(1, 2, figsize=(20, 3))
  2. ax = ax.ravel()
  3. for i, metric in enumerate(["acc", "loss"]):
  4.     ax[i].plot(model.history.history[metric])
  5.     ax[i].plot(model.history.history["val_" + metric])
  6.     ax[i].set_title("Model {}".format(metric))
  7.     ax[i].set_xlabel("epochs")
  8.     ax[i].set_ylabel(metric)
  9.     ax[i].legend(["train", "val"])
复制代码
203019u047mc8z3dc3z74d.png

5. 使用模型进行预测

  1. # Load best weights.
  2. model.load_weights("3d_image_classification.h5")
  3. prediction = model.predict(np.expand_dims(x_val[0], axis=0))[0]
  4. scores = [1 - prediction[0], prediction[0]]
  5. class_names = ["normal", "abnormal"]
  6. for score, name in zip(scores, class_names):
  7.     print(
  8.         "This model is %.2f percent confident that CT scan is %s"
  9.         % ((100 * score), name)
  10.     )
复制代码
203019ygrhluv6hh09ee9h.png

 
References



  • A survey on Deep Learning Advances on Different 3D DataRepresentations(https://arxiv.org/pdf/1808.01462.pdf)
  • VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition(https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf)
  • FusionNet: 3D Object Classification Using MultipleData Representations(http://3ddl.cs.princeton.edu/2016/papers/Hegde_Zadeh.pdf)
  • Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction(https://arxiv.org/abs/2007.13224)
至此,今天的分享结束了,希望通过以上分享,你能学习到语义分割的基本流程,基本过程,与图像分割类似,但更具象化。强烈建议新手能按照上述步骤一步步实践下来,必有收获。
今天代码翻译于:https://keras.io/examples/vision/3D_image_classification/,新入门的小伙伴可以好好看看这个网站,很基础,很适合新手。
当然,这里不得不重点推荐一下这三个网站:
https://tensorflow.google.cn/tutorials/keras/classification
https://keras.io/examples
https://keras.io/zh/
其中keras中文网址中能找到各种API定义,都是中文通俗易懂,如果想看英文直接到https://keras.io/,就可以,这里也有很多案例,也是很基础明白。入门时可以看看。
我是羽峰,还在成长道路上摸爬滚打的小白,希望能够结识一起成长的你,公众号“羽峰码字”,欢迎来撩。

来源:https://blog.csdn.net/m0_37940804/article/details/116910858
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
回复

使用道具 举报

提醒:禁止复制他人回复等『恶意灌水』行为,违者重罚!
您需要登录后才可以回帖 登录 | 注册[Register] 手机动态码快速登录 微信登录

本版积分规则

发布主题 快速回复 收藏帖子 返回列表 客服中心 搜索
简体中文 繁體中文 English 한국 사람 日本語 Deutsch русский بالعربية TÜRKÇE português คนไทย french

QQ|RSS订阅|小黑屋|处罚记录|手机版|联系我们|Archiver|医工互联 |粤ICP备2021178090号 |网站地图

GMT+8, 2025-1-23 02:19 , Processed in 0.265360 second(s), 66 queries .

Powered by Discuz!

Copyright © 2001-2023, Discuz! Team.

快速回复 返回顶部 返回列表