目录

1、前言

2、实现思路

3、实验代码

3.1 环境配置

3.2 数据集

 3.3 训练

3.4 指标

3.5 推理

4、其他


1、前言

本章尝试将TransUnet和SAM结合,以期望达到更换的模型

TransUnet作为医学图像分割的基准,在许多数据集上均取得了很好的效果,然而最近SAM大模型的兴起,图像分割似乎有了新的方向

关于图像分割项目、sam模型复现参考本人其他专栏,这里之作简单介绍

TransUnet是一个专门为医学图像分割任务设计的深度学习模型。它是一种卷积神经网络(CNN),采用基于变压器的架构。TransUnet在具有相应分割掩模的大型医学图像数据集上进行训练,以学习如何从输入图像中准确分割器官、病变或其他结构。

TransUnet的一个关键优势是它能够处理医学图像中的大小物体。它通过结合CNN(擅长捕获局部空间信息)和变换器(擅长捕获全局上下文信息)的优点来实现这一点。这使得TransUnet能够有效地分割器官或其他结构,无论其大小如何。

TransUnet在多个医学图像分割基准测试中取得了最先进的性能。其高分割精度和多功能性使其成为各种医学应用的有前景的工具,如肿瘤检测、器官分割和疾病诊断。

总体而言,TransUnet代表了一种创新的医学图像分割方法,它结合了卷积神经网络和变换器的强大功能,在一系列医学成像任务中取得了卓越的结果。

而SAM大模型就是在推理的时候,加入人为的提示信息,例如boxes、point等等,这样sam在进行推理的时候,就会着重于指定部分的推理

2、实现思路

实现的思路很简单,在训练的时候,为每一个类别指定box,这样原来GRB 3通道的数据就会变成4个维度的数据(RGB+box)。而对于多类别的分割来说,标签前景是很多的,只需要随机取出来一个,然后根据前景自动获取box即可

每次取出一个类别

        label_ids = np.unique(mask)[1:]
        label_id = random.choice(label_ids.tolist())
        mask = np.uint8(mask == label_id)  # only one label

根据分割区域自动获取box提示信息

        y_indices, x_indices = np.where(mask > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)

        H, W = mask.shape
        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
        bboxes = np.array([x_min, y_min, x_max, y_max])

主要注意的是,训练的时候,box是自动生成的

代码处理完,经过处理完,输送给网络的图像如下:

TIPS:数据里面的颜色都是白色(box是单通道的),为了可视化,这里才显示成红色的掩膜形式

为了消融试验,代码里还加了对比部分,然后增加了很多训练的tricks

3、实验代码

下载连接:https://download.csdn.net/download/qq_44886601/89878907

有偿下载

下载完目录如下:

readme 有详细的运行步骤,这里简单介绍

3.1 环境配置

建议用conda配置虚拟环境,参考:https://blog.csdn.net/qq_44886601/category_12573095.html

配置好环境后,一键安装库文件即可:

pip install -r requirements.txt

einops==0.8.0
matplotlib==3.7.5
monai==1.3.2
numpy==1.24.4
opencv_python==4.10.0.84
Pillow==10.4.0
torch==2.4.1
tqdm==4.66.5

3.2 数据集

数据集摆放如下:

这里测试的数据集是 MICCAI FLARE 腹部13器官分割,标签为:

{
        "0": "background",
        "1": "spleen",
        "2": "right kidney",
        "3": "left kidney",
        "4": "gallbladder",
        "5": "esophagus",
        "6": "liver",
        "7": "stomach",
        "8": "aorta",
        "9": "IVC",
        "10": "veins",
        "11": "pancreas",
        "12": "rad",
        "13": "lad"
    }

可视化结果如下:

 3.3 训练

训练参数如下,建议epoch尽量大点

网络的损失采用更好的分割损失:DiceCELoss

    parser.add_argument("--batch-size", default=8, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--optim", default='SGD',type=str, help='SGD、Adam、RMSProp')

    parser.add_argument('--lr', default=0.01, type=float)
    parser.add_argument('--lrf',default=0.001,type=float)                  # 最终学习率 = lr * lrf

    parser.add_argument("--img_f", default='.png', type=str)               # 数据图像的后缀
    parser.add_argument("--mask_f", default='.png', type=str)              # mask图像的后缀

    parser.add_argument("--imgSize", default=[224,224],help='image size')              # img size

训练过程:

训练完效果还行:

3.4 指标

训练结果全部在runs目录下

其中可视化的数据:

loss:

dice:

iou:

训练日志:

[train hyper-parameters: Namespace(batch_size=8, epochs=100, imgSize=[224, 224], img_f='.png', lr=0.01, lrf=0.001, mask_f='.png', optim='SGD')]

epoch  train_loss train_mdice    train_miou val_loss   val_mdice  val_miou   
1     0.3523    0.5368    0.4148    0.2399    0.7367    0.5924
2     0.2271    0.6468    0.5066    0.2148    0.7816    0.6471
3     0.2063    0.6895    0.5529    0.2038    0.7751    0.6413
4     0.191     0.7144    0.5866    0.1907    0.8239    0.7099
5     0.1798    0.7704    0.6461    0.1666    0.8635    0.7636
6     0.1648    0.8044    0.6869    0.1547    0.8561    0.7523
7     0.1573    0.8198    0.7065    0.1518    0.8644    0.7657
8     0.1525    0.8184    0.7129    0.1404    0.8924    0.8083
9     0.1453    0.8457    0.7444    0.1464    0.8815    0.7914
10    0.1495    0.833     0.7279    0.1372    0.8914    0.8065
11    0.1376    0.8517    0.7554    0.1449    0.8879    0.8027
12    0.1296    0.8727    0.7805    0.1442    0.8941    0.8125
13    0.136     0.8619    0.7656    0.1307    0.896     0.8147
14    0.1308    0.8782    0.7892    0.1295    0.9074    0.8328
15    0.1289    0.8913    0.8074    0.1198    0.891     0.8083
16    0.1264    0.8809    0.7932    0.1287    0.9018    0.8253
17    0.1231    0.885     0.7994    0.1239    0.9103    0.8377
18    0.1276    0.8789    0.7911    0.1213    0.9018    0.8236
19    0.1222    0.9011    0.823     0.1295    0.9101    0.8376
20    0.122     0.898     0.8189    0.12      0.9088    0.8365
21    0.1195    0.9083    0.8344    0.1132    0.9198    0.8536
22    0.1159    0.8992    0.8228    0.1208    0.9063    0.8314
23    0.1128    0.9107    0.8391    0.1059    0.9238    0.861
24    0.1104    0.9128    0.8423    0.1146    0.9178    0.8505
25    0.1072    0.9186    0.8516    0.1062    0.9188    0.8524
26    0.1118    0.9092    0.8366    0.1143    0.9139    0.8433
27    0.1089    0.9135    0.844     0.1146    0.913     0.8432
28    0.1033    0.916     0.8473    0.1056    0.9243    0.8608
29    0.107     0.9216    0.8561    0.1059    0.9179    0.8505
30    0.104     0.9119    0.8412    0.1043    0.919     0.8522
31    0.1038    0.9171    0.849     0.1055    0.9215    0.8562
32    0.1009    0.9275    0.866     0.1022    0.9275    0.8661
33    0.0963    0.9301    0.8708    0.1036    0.9202    0.8551
34    0.0983    0.928     0.8672    0.0999    0.9272    0.8661
35    0.1038    0.9145    0.8457    0.0986    0.9277    0.8667
36    0.0995    0.9275    0.8663    0.0978    0.9252    0.863
37    0.0984    0.925     0.8622    0.1023    0.9197    0.8543
38    0.1007    0.9251    0.8635    0.0953    0.9297    0.8702
39    0.0949    0.9297    0.8705    0.0981    0.9283    0.8684
40    0.0954    0.9247    0.8626    0.1026    0.9203    0.8541
41    0.0943    0.9296    0.8698    0.096     0.9278    0.867
42    0.0923    0.9358    0.8804    0.0913    0.9304    0.8711
43    0.0906    0.9318    0.8739    0.093     0.9366    0.8817
44    0.0945    0.9324    0.875     0.0932    0.9308    0.8724
45    0.0891    0.9358    0.8808    0.0881    0.9395    0.887
46    0.0884    0.9358    0.8805    0.0904    0.9353    0.8798
47    0.0875    0.9393    0.8869    0.0855    0.9371    0.8832
48    0.0908    0.9337    0.8774    0.0881    0.9351    0.8798
49    0.089     0.9349    0.8794    0.0924    0.929     0.8698
50    0.0845    0.9381    0.8852    0.0824    0.9413    0.8899
51    0.0846    0.9357    0.8807    0.0837    0.9399    0.888
52    0.0858    0.9407    0.8892    0.0847    0.9365    0.882
53    0.0837    0.9413    0.8901    0.0864    0.9384    0.8858
54    0.0856    0.9415    0.8904    0.088     0.9335    0.8776
55    0.0842    0.9413    0.8907    0.086     0.9395    0.8876
56    0.0849    0.9398    0.8878    0.086     0.9318    0.8739
57    0.0833    0.9446    0.896     0.0817    0.9436    0.8942
58    0.0831    0.945     0.8968    0.0809    0.9396    0.8869
59    0.0803    0.9444    0.8957    0.0798    0.9388    0.8867
60    0.0783    0.9461    0.8987    0.0824    0.9446    0.8961
61    0.0769    0.9466    0.8995    0.0836    0.9389    0.8869
62    0.0794    0.9486    0.9031    0.0768    0.9453    0.8974
63    0.0797    0.9484    0.9028    0.0818    0.9424    0.8922
64    0.0784    0.9448    0.8963    0.0808    0.9441    0.8955
65    0.075     0.9485    0.9031    0.0756    0.9455    0.8975
66    0.077     0.9488    0.9035    0.0799    0.9372    0.8848
67    0.0777    0.9447    0.8965    0.0789    0.9404    0.8895
68    0.0761    0.9486    0.9033    0.0787    0.9468    0.9001
69    0.0776    0.9501    0.9058    0.0827    0.942     0.8917
70    0.0745    0.9506    0.9068    0.0731    0.9494    0.9046
71    0.072     0.9512    0.9078    0.0777    0.9446    0.8962
72    0.0727    0.9497    0.9052    0.0785    0.9437    0.8953
73    0.073     0.9492    0.9045    0.0799    0.9474    0.9011
74    0.0731    0.9515    0.9085    0.0746    0.9439    0.8956
75    0.0714    0.9523    0.9097    0.0785    0.9455    0.898
76    0.0724    0.9505    0.9065    0.0719    0.9501    0.9056
77    0.0739    0.9506    0.9068    0.0742    0.9479    0.9019
78    0.0715    0.95      0.9057    0.0736    0.9482    0.903
79    0.0681    0.9529    0.9109    0.0722    0.9467    0.8997
80    0.0698    0.9537    0.9124    0.0729    0.9509    0.9072
81    0.0719    0.9506    0.9067    0.0735    0.9513    0.9081
82    0.0712    0.9505    0.9066    0.0708    0.9502    0.9062
83    0.0719    0.9508    0.9073    0.0753    0.9495    0.9049
84    0.0683    0.9539    0.9127    0.0718    0.9484    0.9028
85    0.067     0.9543    0.9132    0.0705    0.9513    0.908
86    0.0681    0.9519    0.9092    0.0722    0.9466    0.8996
87    0.0676    0.9531    0.9113    0.072     0.9469    0.9002
88    0.0698    0.953     0.9111    0.07      0.9495    0.9051
89    0.0697    0.9546    0.9139    0.0688    0.9504    0.9064
90    0.0673    0.9561    0.9166    0.0738    0.9488    0.9037
91    0.0682    0.9547    0.914     0.0745    0.9492    0.9046
92    0.0682    0.953     0.9109    0.0729    0.952     0.9094
93    0.0691    0.9552    0.915     0.0719    0.9481    0.9024
94    0.0649    0.9536    0.9122    0.0713    0.9455    0.8986
95    0.0661    0.9531    0.9112    0.0713    0.9461    0.8991
96    0.0668    0.953     0.9112    0.0728    0.9494    0.9048
97    0.0661    0.9557    0.9157    0.0714    0.946     0.8992
98    0.0657    0.9568    0.9178    0.0712    0.9492    0.9048
99    0.0665    0.9537    0.9123    0.0711    0.9498    0.9053
100       0.0668    0.9558    0.9161    0.071     0.9538    0.9125

之前在TransUnet也训练了这个数据集,指标如下:

3.5 推理

推理的脚本是infer.py ,在生成的UI界面绘制box推理即可

在目录下会生成gt区域的图像:

4、其他

CT 图像数据的对比度很低,想要训练结果更好的话,可以使用对比度拉伸来增强数据。其实在数字图像处理中方法很多(灰度变换啊、空间滤波啊之类的)。

直方图均衡化增强:

sam在cv上的推理还有point推理,实现也很简单,其实在自动获取box的时候,通过数字图像处理的腐蚀操作就可以获得point,这样提示信息就换成了point

当然,更直接的改进可以增加attention机制,或者添加有多有效的module等

Logo

有“AI”的1024 = 2048,欢迎大家加入2048 AI社区

更多推荐