目录
1、前言
2、实现思路
3、实验代码
3.1 环境配置
3.2 数据集
3.3 训练
3.4 指标
3.5 推理
4、其他
1、前言
本章尝试将TransUnet和SAM结合,以期望达到更换的模型
TransUnet作为医学图像分割的基准,在许多数据集上均取得了很好的效果,然而最近SAM大模型的兴起,图像分割似乎有了新的方向
关于图像分割项目、sam模型复现参考本人其他专栏,这里之作简单介绍
TransUnet是一个专门为医学图像分割任务设计的深度学习模型。它是一种卷积神经网络(CNN),采用基于变压器的架构。TransUnet在具有相应分割掩模的大型医学图像数据集上进行训练,以学习如何从输入图像中准确分割器官、病变或其他结构。
TransUnet的一个关键优势是它能够处理医学图像中的大小物体。它通过结合CNN(擅长捕获局部空间信息)和变换器(擅长捕获全局上下文信息)的优点来实现这一点。这使得TransUnet能够有效地分割器官或其他结构,无论其大小如何。
TransUnet在多个医学图像分割基准测试中取得了最先进的性能。其高分割精度和多功能性使其成为各种医学应用的有前景的工具,如肿瘤检测、器官分割和疾病诊断。
总体而言,TransUnet代表了一种创新的医学图像分割方法,它结合了卷积神经网络和变换器的强大功能,在一系列医学成像任务中取得了卓越的结果。
而SAM大模型就是在推理的时候,加入人为的提示信息,例如boxes、point等等,这样sam在进行推理的时候,就会着重于指定部分的推理
2、实现思路
实现的思路很简单,在训练的时候,为每一个类别指定box,这样原来GRB 3通道的数据就会变成4个维度的数据(RGB+box)。而对于多类别的分割来说,标签前景是很多的,只需要随机取出来一个,然后根据前景自动获取box即可
每次取出一个类别
label_ids = np.unique(mask)[1:]
label_id = random.choice(label_ids.tolist())
mask = np.uint8(mask == label_id) # only one label
根据分割区域自动获取box提示信息
y_indices, x_indices = np.where(mask > 0)
x_min, x_max = np.min(x_indices), np.max(x_indices)
y_min, y_max = np.min(y_indices), np.max(y_indices)
H, W = mask.shape
x_min = max(0, x_min - random.randint(0, self.bbox_shift))
x_max = min(W, x_max + random.randint(0, self.bbox_shift))
y_min = max(0, y_min - random.randint(0, self.bbox_shift))
y_max = min(H, y_max + random.randint(0, self.bbox_shift))
bboxes = np.array([x_min, y_min, x_max, y_max])
主要注意的是,训练的时候,box是自动生成的
代码处理完,经过处理完,输送给网络的图像如下:
TIPS:数据里面的颜色都是白色(box是单通道的),为了可视化,这里才显示成红色的掩膜形式
为了消融试验,代码里还加了对比部分,然后增加了很多训练的tricks
3、实验代码
下载连接:https://download.csdn.net/download/qq_44886601/89878907
有偿下载
下载完目录如下:
readme 有详细的运行步骤,这里简单介绍
3.1 环境配置
建议用conda配置虚拟环境,参考:https://blog.csdn.net/qq_44886601/category_12573095.html
配置好环境后,一键安装库文件即可:
pip install -r requirements.txt
einops==0.8.0
matplotlib==3.7.5
monai==1.3.2
numpy==1.24.4
opencv_python==4.10.0.84
Pillow==10.4.0
torch==2.4.1
tqdm==4.66.5
3.2 数据集
数据集摆放如下:
这里测试的数据集是 MICCAI FLARE 腹部13器官分割,标签为:
{
"0": "background",
"1": "spleen",
"2": "right kidney",
"3": "left kidney",
"4": "gallbladder",
"5": "esophagus",
"6": "liver",
"7": "stomach",
"8": "aorta",
"9": "IVC",
"10": "veins",
"11": "pancreas",
"12": "rad",
"13": "lad"
}
可视化结果如下:
3.3 训练
训练参数如下,建议epoch尽量大点
网络的损失采用更好的分割损失:DiceCELoss
parser.add_argument("--batch-size", default=8, type=int)
parser.add_argument("--epochs", default=100, type=int)
parser.add_argument("--optim", default='SGD',type=str, help='SGD、Adam、RMSProp')
parser.add_argument('--lr', default=0.01, type=float)
parser.add_argument('--lrf',default=0.001,type=float) # 最终学习率 = lr * lrf
parser.add_argument("--img_f", default='.png', type=str) # 数据图像的后缀
parser.add_argument("--mask_f", default='.png', type=str) # mask图像的后缀
parser.add_argument("--imgSize", default=[224,224],help='image size') # img size
训练过程:
训练完效果还行:
3.4 指标
训练结果全部在runs目录下
其中可视化的数据:
loss:
dice:
iou:
训练日志:
[train hyper-parameters: Namespace(batch_size=8, epochs=100, imgSize=[224, 224], img_f='.png', lr=0.01, lrf=0.001, mask_f='.png', optim='SGD')] epoch train_loss train_mdice train_miou val_loss val_mdice val_miou 1 0.3523 0.5368 0.4148 0.2399 0.7367 0.5924 2 0.2271 0.6468 0.5066 0.2148 0.7816 0.6471 3 0.2063 0.6895 0.5529 0.2038 0.7751 0.6413 4 0.191 0.7144 0.5866 0.1907 0.8239 0.7099 5 0.1798 0.7704 0.6461 0.1666 0.8635 0.7636 6 0.1648 0.8044 0.6869 0.1547 0.8561 0.7523 7 0.1573 0.8198 0.7065 0.1518 0.8644 0.7657 8 0.1525 0.8184 0.7129 0.1404 0.8924 0.8083 9 0.1453 0.8457 0.7444 0.1464 0.8815 0.7914 10 0.1495 0.833 0.7279 0.1372 0.8914 0.8065 11 0.1376 0.8517 0.7554 0.1449 0.8879 0.8027 12 0.1296 0.8727 0.7805 0.1442 0.8941 0.8125 13 0.136 0.8619 0.7656 0.1307 0.896 0.8147 14 0.1308 0.8782 0.7892 0.1295 0.9074 0.8328 15 0.1289 0.8913 0.8074 0.1198 0.891 0.8083 16 0.1264 0.8809 0.7932 0.1287 0.9018 0.8253 17 0.1231 0.885 0.7994 0.1239 0.9103 0.8377 18 0.1276 0.8789 0.7911 0.1213 0.9018 0.8236 19 0.1222 0.9011 0.823 0.1295 0.9101 0.8376 20 0.122 0.898 0.8189 0.12 0.9088 0.8365 21 0.1195 0.9083 0.8344 0.1132 0.9198 0.8536 22 0.1159 0.8992 0.8228 0.1208 0.9063 0.8314 23 0.1128 0.9107 0.8391 0.1059 0.9238 0.861 24 0.1104 0.9128 0.8423 0.1146 0.9178 0.8505 25 0.1072 0.9186 0.8516 0.1062 0.9188 0.8524 26 0.1118 0.9092 0.8366 0.1143 0.9139 0.8433 27 0.1089 0.9135 0.844 0.1146 0.913 0.8432 28 0.1033 0.916 0.8473 0.1056 0.9243 0.8608 29 0.107 0.9216 0.8561 0.1059 0.9179 0.8505 30 0.104 0.9119 0.8412 0.1043 0.919 0.8522 31 0.1038 0.9171 0.849 0.1055 0.9215 0.8562 32 0.1009 0.9275 0.866 0.1022 0.9275 0.8661 33 0.0963 0.9301 0.8708 0.1036 0.9202 0.8551 34 0.0983 0.928 0.8672 0.0999 0.9272 0.8661 35 0.1038 0.9145 0.8457 0.0986 0.9277 0.8667 36 0.0995 0.9275 0.8663 0.0978 0.9252 0.863 37 0.0984 0.925 0.8622 0.1023 0.9197 0.8543 38 0.1007 0.9251 0.8635 0.0953 0.9297 0.8702 39 0.0949 0.9297 0.8705 0.0981 0.9283 0.8684 40 0.0954 0.9247 0.8626 0.1026 0.9203 0.8541 41 0.0943 0.9296 0.8698 0.096 0.9278 0.867 42 0.0923 0.9358 0.8804 0.0913 0.9304 0.8711 43 0.0906 0.9318 0.8739 0.093 0.9366 0.8817 44 0.0945 0.9324 0.875 0.0932 0.9308 0.8724 45 0.0891 0.9358 0.8808 0.0881 0.9395 0.887 46 0.0884 0.9358 0.8805 0.0904 0.9353 0.8798 47 0.0875 0.9393 0.8869 0.0855 0.9371 0.8832 48 0.0908 0.9337 0.8774 0.0881 0.9351 0.8798 49 0.089 0.9349 0.8794 0.0924 0.929 0.8698 50 0.0845 0.9381 0.8852 0.0824 0.9413 0.8899 51 0.0846 0.9357 0.8807 0.0837 0.9399 0.888 52 0.0858 0.9407 0.8892 0.0847 0.9365 0.882 53 0.0837 0.9413 0.8901 0.0864 0.9384 0.8858 54 0.0856 0.9415 0.8904 0.088 0.9335 0.8776 55 0.0842 0.9413 0.8907 0.086 0.9395 0.8876 56 0.0849 0.9398 0.8878 0.086 0.9318 0.8739 57 0.0833 0.9446 0.896 0.0817 0.9436 0.8942 58 0.0831 0.945 0.8968 0.0809 0.9396 0.8869 59 0.0803 0.9444 0.8957 0.0798 0.9388 0.8867 60 0.0783 0.9461 0.8987 0.0824 0.9446 0.8961 61 0.0769 0.9466 0.8995 0.0836 0.9389 0.8869 62 0.0794 0.9486 0.9031 0.0768 0.9453 0.8974 63 0.0797 0.9484 0.9028 0.0818 0.9424 0.8922 64 0.0784 0.9448 0.8963 0.0808 0.9441 0.8955 65 0.075 0.9485 0.9031 0.0756 0.9455 0.8975 66 0.077 0.9488 0.9035 0.0799 0.9372 0.8848 67 0.0777 0.9447 0.8965 0.0789 0.9404 0.8895 68 0.0761 0.9486 0.9033 0.0787 0.9468 0.9001 69 0.0776 0.9501 0.9058 0.0827 0.942 0.8917 70 0.0745 0.9506 0.9068 0.0731 0.9494 0.9046 71 0.072 0.9512 0.9078 0.0777 0.9446 0.8962 72 0.0727 0.9497 0.9052 0.0785 0.9437 0.8953 73 0.073 0.9492 0.9045 0.0799 0.9474 0.9011 74 0.0731 0.9515 0.9085 0.0746 0.9439 0.8956 75 0.0714 0.9523 0.9097 0.0785 0.9455 0.898 76 0.0724 0.9505 0.9065 0.0719 0.9501 0.9056 77 0.0739 0.9506 0.9068 0.0742 0.9479 0.9019 78 0.0715 0.95 0.9057 0.0736 0.9482 0.903 79 0.0681 0.9529 0.9109 0.0722 0.9467 0.8997 80 0.0698 0.9537 0.9124 0.0729 0.9509 0.9072 81 0.0719 0.9506 0.9067 0.0735 0.9513 0.9081 82 0.0712 0.9505 0.9066 0.0708 0.9502 0.9062 83 0.0719 0.9508 0.9073 0.0753 0.9495 0.9049 84 0.0683 0.9539 0.9127 0.0718 0.9484 0.9028 85 0.067 0.9543 0.9132 0.0705 0.9513 0.908 86 0.0681 0.9519 0.9092 0.0722 0.9466 0.8996 87 0.0676 0.9531 0.9113 0.072 0.9469 0.9002 88 0.0698 0.953 0.9111 0.07 0.9495 0.9051 89 0.0697 0.9546 0.9139 0.0688 0.9504 0.9064 90 0.0673 0.9561 0.9166 0.0738 0.9488 0.9037 91 0.0682 0.9547 0.914 0.0745 0.9492 0.9046 92 0.0682 0.953 0.9109 0.0729 0.952 0.9094 93 0.0691 0.9552 0.915 0.0719 0.9481 0.9024 94 0.0649 0.9536 0.9122 0.0713 0.9455 0.8986 95 0.0661 0.9531 0.9112 0.0713 0.9461 0.8991 96 0.0668 0.953 0.9112 0.0728 0.9494 0.9048 97 0.0661 0.9557 0.9157 0.0714 0.946 0.8992 98 0.0657 0.9568 0.9178 0.0712 0.9492 0.9048 99 0.0665 0.9537 0.9123 0.0711 0.9498 0.9053 100 0.0668 0.9558 0.9161 0.071 0.9538 0.9125
之前在TransUnet也训练了这个数据集,指标如下:
3.5 推理
推理的脚本是infer.py ,在生成的UI界面绘制box推理即可
在目录下会生成gt区域的图像:
4、其他
CT 图像数据的对比度很低,想要训练结果更好的话,可以使用对比度拉伸来增强数据。其实在数字图像处理中方法很多(灰度变换啊、空间滤波啊之类的)。
直方图均衡化增强:
sam在cv上的推理还有point推理,实现也很简单,其实在自动获取box的时候,通过数字图像处理的腐蚀操作就可以获得point,这样提示信息就换成了point
当然,更直接的改进可以增加attention机制,或者添加有多有效的module等