医工互联

 找回密码
 注册[Register]

手机动态码快速登录

手机号快速登录

微信登录

微信扫一扫,快速登录

QQ登录

只需一步,快速开始

查看: 149|回复: 0
收起左侧

利用tensorFlow api 识别手术器械

[复制链接]

  离线 

发表于 2023-2-18 20:22:12 来自手机 | 显示全部楼层 |阅读模式 <
数据集来源:18岁NIPS Workshop一作,用目标检测评估手术技能点
Tensorflow-model:(选中master,点击tag选择和自己tensorflow适配的版本)
https://github.com/tensorflow/models/
包含手术器械数据集及tensorflow-model所需文件
链接: https://pan.baidu.com/s/1eaUsVQEz0-SK_ADJDhw2Ng
提取码:nj4x
文件目录:
  1. dataset\
  2.         ├─ssd_mobilenet_v1_coco_2018_01_28
  3.         ├─faster_rcnn_resnet101_coco_2018_01_28
  4.                 ├── checkpoint
  5.                 ├── frozen_inference_graph.pb
  6.                 ├── model.ckpt.data-00000-of-00001
  7.                 ├── model.ckpt.index
  8.                 ├── model.ckpt.meta
  9.                 ├── pipeline.config
  10.         ├─output
  11.         ├─tf_text_data
  12.             ├─dataset_test.record
  13.             ├─dataset_train.record
  14.             └─dataset_val.record
  15.         ├─tf_text_graph
  16.             ├─tf_text_graph_common.py
  17.             ├─tf_text_graph_faster_rcnn.py
  18.             └─tf_text_graph_ssd.py
  19.         └─VOCdevkit
  20.                     └─VOC2007
  21.                 ├─Annotations
  22.                 ├─ImageSets
  23.                 │  └─Main
  24.                 ├─JPEGImages
  25.                 └─classfier.py
复制代码
然后将其解压到tensorflow-model文件下即可。
 
前言

识别手术器械效果:
   
1.png
   
2.png
  tensorflow-model的配置可自行百度。
 
一、下载coco-trained models:

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/tf1_detection_zoo.md
3.png

把下载的模型放到dataset文件夹中,例如:
faster_rcnn_resnet101_coco_2018_01_28.tar.gz
然后解压:
  1. tar -vxf faster_rcnn_resnet101_coco_2018_01_28.tar.gz -C .
复制代码
生成faster_rcnn_resnet101_coco_2018_01_28文件夹目录
  1. faster_rcnn_resnet101_coco_2018_01_28/
  2. ├── checkpoint
  3. ├── frozen_inference_graph.pb
  4. ├── model.ckpt.data-00000-of-00001
  5. ├── model.ckpt.index
  6. ├── model.ckpt.meta
  7. ├── pipeline.config
  8. └── saved_model
  9.     ├── saved_model.pb
  10.     └── variables
复制代码
将里面的pipeline.config复制到dataset文件夹中,并按照本文件夹的名字重新命名(这样的目的是为了规范,当你拥有很多个模型时就不会乱)
 
二、数据集的处理

1.将数据集放到JPEGImages文件夹中,xml标签文件放到Annotations文件夹中
2.执行:
  1. python classfier.py
复制代码
生成ImageSets/Main文件夹下的三个文件:test.txt ,train.txt,val.txt以及一个trainval.txt
 
3.修改dataset_label_map.pbtxt,有几类写几类,里面的标签必须和xml文件中的一致
 
4.cd到dataset目录,生成tfrecord,包括验证集,测试集和训练集

  1. python ./create_dataset_tf_record.py --data_dir=./VOCdevkit --year=VOC2007 --label_map_path=./dataset_label_map.pbtxt --set=val --output_path=./tf_text_data/dataset_val.record
复制代码
  1. python ./create_dataset_tf_record.py --data_dir=./VOCdevkit --year=VOC2007 --label_map_path=./dataset_label_map.pbtxt --set=val --output_path=./tf_text_data/dataset_test.record
复制代码
  1. python ./create_dataset_tf_record.py --data_dir=./VOCdevkit --year=VOC2007 --label_map_path=./dataset_label_map.pbtxt --set=train --output_path=./tf_text_data/dataset_train.record
复制代码

 
 
三、修改pipeline.config文件

cd到下载模型文件中,可以看到pipeline.config
总共包括五个部分:
1.model:
  1. 主要修改:
  2.         num_classes(分类的类别数)
  3.             image_resizer {
  4.   fixed_shape_resizer {
  5.     height: 300
  6.     width: 300      (输入网络图像的大小尺寸,一般默认)
  7.   }
  8. }
复制代码
2.train_config:
主要修改:
  1. batch_size: 64(按照电脑的配置来定)
  2.        
  3.         ############################################################################
  4.         #########       如果使用faster-rcnn,batch_size只能设为1      ##############
  5.         ############################################################################
  6.        
  7. initial_learning_rate: 0.0023(初始学习率,可以自己调节)
  8. decay_steps: 600(每多少个steps变化一次学习率)
  9. decay_factor: 0.96(每次变化的学习率:current_learning_rate = decay_factor*initial_learning_rate)
  10. num_steps: 50000(训练的次数)
  11. fine_tune_checkpoint: "#####/output/model.ckpt-18508"(迁移学习的模型文件)
复制代码
3.train_input_reader:
  1. tf_record_input_reader {
  2.         input_path: "#####/tf_text_data/dataset_train.record"(训练集的位置)
  3. }
  4. label_map_path: "#####/dataset_label_map.pbtxt"(标签的定义文件)
复制代码
4.eval_config:
  1.         metrics_set:"coco_detection_metrics"(测量的方式,主要是用mAP来度量)
  2.         num_examples: 300(验证集的数量)
  3.         max_evals: 1(验证的循环次数)
  4.         use_moving_averages:false(采用滑动平均)
复制代码
5.eval_input_reader:
  1.        
  2.         tf_record_input_reader {
  3.             input_path: "#####/tf_text_data/dataset_val.record"(验证集的位置)
  4.          }
  5.         label_map_path: "#####/dataset_label_map.pbtxt"(标签的定义文件)
复制代码
四、训练

对于ssd:
  1. python ../research/object_detection/legacy/train.py --logtostderr --train_dir=./output --pipeline_config_path=./ssd_mobilenet_v1_coco_2018_01_28/pipeline.config
复制代码
对于faster-rcnn:
  1. python ../research/object_detection/legacy/train.py  --logtostderr --train_dir=./output --pipeline_config_path=./faster_rcnn_resnet101_coco_2018_01_28/pipeline.config
复制代码

 
 
五、保存节点pb

ssd算法:
  1. python ../research/object_detection/export_inference_graph.py input_type image_tensor --pipeline_config_path ./ssd_mobilenet_v1_coco_2018_01_28/pipeline.config --trained_checkpoint_prefix ./output/model.ckpt-50000 --output_directory ./output
复制代码
faster-rcnn算法:
  1. python ../research/object_detection/export_inference_graph.py input_type image_tensor --pipeline_config_path ./faster_rcnn_resnet101_coco_2018_01_28/pipeline.config --trained_checkpoint_prefix ./output/model.ckpt-0 --output_directory ./output
复制代码
六、生成pbtxt文件,opencv调用需要

生成pbtxt的py文件位于opencv源码文件/samples/dnn中
地址:
https://github.com/opencv/opencv/tree/4.5.2/samples/dnn/tf_text_graph_ssd.py
OpenCV调用示例:https://github.com/opencv/opencv/blob/4.5.2/samples/dnn/object_detection.cpp

  1. python ./tf_text_graph/tf_text_graph_ssd.py --input=./output/frozen_inference_graph.pb --output=./output/frozen_inference_graph.pbtxt  --config=./ssd_mobilenet_v1_coco_2018_01_28/pipeline.config
复制代码

七、测试图片

打开object_detection_test.py文件,修改:
模型位置:
  1. MODEL_NAME = './output'
复制代码
模型名
  1. PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'
复制代码
模型标签文件
  1. PATH_TO_LABELS = './dataset_label_map.pbtxt'
复制代码
类别数
  1. NUM_CLASSES = 7
复制代码

  1. python ./object_detection_test.py --image ./VOCdevkit/VOC2007/JPEGImages/v03_062250.jpg
复制代码
或者:
  1. python ./object_detection_test.py  --data ./VOCdevkit/VOC2007/ImageSets/Main/val.txt
复制代码


来源:https://blog.csdn.net/qq_42995327/article/details/117639785
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
回复

使用道具 举报

提醒:禁止复制他人回复等『恶意灌水』行为,违者重罚!
您需要登录后才可以回帖 登录 | 注册[Register] 手机动态码快速登录 微信登录

本版积分规则

发布主题 快速回复 收藏帖子 返回列表 客服中心 搜索
简体中文 繁體中文 English 한국 사람 日本語 Deutsch русский بالعربية TÜRKÇE português คนไทย french

QQ|RSS订阅|小黑屋|处罚记录|手机版|联系我们|Archiver|医工互联 |粤ICP备2021178090号 |网站地图

GMT+8, 2024-11-21 23:33 , Processed in 0.276530 second(s), 66 queries .

Powered by Discuz!

Copyright © 2001-2023, Discuz! Team.

快速回复 返回顶部 返回列表