第一步:准备数据
息肉分割数据,总共有1000张
第二步:搭建模型
UNet++,这是一种旨在克服以上限制的新型通用图像分割体系结构。如下图所示,UNet++由不同深度的U-Net组成,其解码器通过重新设计的跳接以相同的分辨率密集连接。 UNet++中引入的体系结构更改具有以下优点。首先,UNet++不易明确地选择网络深度,因为它在其体系结构中嵌入了不同深度的U-Net。所有这些U-Net都部分共享一个编码器,而它们的解码器则交织在一起。通过在深度监督下训练UNet++,可以同时训练所有组成的U-Net,同时受益于共享的图像表示。这种设计不仅可以提高整体分割性能,而且可以在推理期间修剪模型。其次,UNet++不会受到不必要的限制性跳接的限制,在这种情况下,只能融合来自编码器和解码器的相同比例的特征图。UNet++中引入的经过重新设计的跳接在解码器节点处提供了不同比例的特征图,从而使聚合层可以决定如何将跳接中携带的各种特征图与解码器特征图融合在一起。通过以相同的分辨率密集连接组成部分U-Net的解码器,可以在UNet++中实现重新设计的跳接。作者在六个分割数据集和不同深度的多个主干中对UNet++进行了广泛地评估:
五个贡献:
- 在UNet++中引入了一个内置的深度可变的U-Net集合,可为不同大小的对象提供改进的分割性能,这是对固定深度U-Net的改进。
- 重新设计了UNet++中的跳接,从而在解码器中实现了灵活的特征融合,这是对U-Net中仅需要融合相同比例特征图的限制性跳接的一种改进。
- 设计了一种方案来剪枝经过训练的UNet++,在保持其性能的同时加快其推理速度。
- 同时训练嵌入在UNet++体系结构中的多深度U-Net可以激发组成U-Net之间的协作学习,与单独训练具有相同体系结构的隔离U-Net相比,可以带来更好的性能。
- 展示了UNet++对多个主干编码器的可扩展性,并进一步将其应用于包括CT、MRI和电子显微镜在内的各种医学成像模式。
第三步:代码
1)损失函数为:交叉熵损失函数+dice_loss
2)网络代码:
class UNet3Plus(nn.Module):
def __init__(self, n_channels=3, n_classes=1, bilinear=True, feature_scale=4,
is_deconv=True, is_batchnorm=True):
super(UNet3Plus, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.feature_scale = feature_scale
self.is_deconv = is_deconv
self.is_batchnorm = is_batchnorm
filters = [16, 32, 64, 128, 256]
## -------------Encoder--------------
self.conv1 = unetConv2(self.n_channels, filters[0], self.is_batchnorm)
self.maxpool1 = nn.MaxPool2d(kernel_size=2)
self.conv2 = unetConv2(filters[0], filters[1], self.is_batchnorm)
self.maxpool2 = nn.MaxPool2d(kernel_size=2)
self.conv3 = unetConv2(filters[1], filters[2], self.is_batchnorm)
self.maxpool3 = nn.MaxPool2d(kernel_size=2)
self.conv4 = unetConv2(filters[2], filters[3], self.is_batchnorm)
self.maxpool4 = nn.MaxPool2d(kernel_size=2)
self.conv5 = unetConv2(filters[3], filters[4], self.is_batchnorm)
## -------------Decoder--------------
self.CatChannels = filters[0]
self.CatBlocks = 5
self.UpChannels = self.CatChannels * self.CatBlocks
'''stage 4d'''
# h1->320*320, hd4->40*40, Pooling 8 times
self.h1_PT_hd4 = nn.MaxPool2d(8, 8, ceil_mode=True)
self.h1_PT_hd4_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd4_relu = nn.ReLU(inplace=True)
# h2->160*160, hd4->40*40, Pooling 4 times
self.h2_PT_hd4 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h2_PT_hd4_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd4_relu = nn.ReLU(inplace=True)
# h3->80*80, hd4->40*40, Pooling 2 times
self.h3_PT_hd4 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h3_PT_hd4_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_PT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_PT_hd4_relu = nn.ReLU(inplace=True)
# h4->40*40, hd4->40*40, Concatenation
self.h4_Cat_hd4_conv = nn.Conv2d(filters[3], self.CatChannels, 3, padding=1)
self.h4_Cat_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.h4_Cat_hd4_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd4->40*40, Upsample 2 times
self.hd5_UT_hd4 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd5_UT_hd4_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd4_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd4_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4)
self.conv4d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16
self.bn4d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu4d_1 = nn.ReLU(inplace=True)
'''stage 3d'''
# h1->320*320, hd3->80*80, Pooling 4 times
self.h1_PT_hd3 = nn.MaxPool2d(4, 4, ceil_mode=True)
self.h1_PT_hd3_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd3_relu = nn.ReLU(inplace=True)
# h2->160*160, hd3->80*80, Pooling 2 times
self.h2_PT_hd3 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h2_PT_hd3_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_PT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_PT_hd3_relu = nn.ReLU(inplace=True)
# h3->80*80, hd3->80*80, Concatenation
self.h3_Cat_hd3_conv = nn.Conv2d(filters[2], self.CatChannels, 3, padding=1)
self.h3_Cat_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.h3_Cat_hd3_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd4->80*80, Upsample 2 times
self.hd4_UT_hd3 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd4_UT_hd3_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd3_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd4->80*80, Upsample 4 times
self.hd5_UT_hd3 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
self.hd5_UT_hd3_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd3_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd3_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3)
self.conv3d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16
self.bn3d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu3d_1 = nn.ReLU(inplace=True)
'''stage 2d '''
# h1->320*320, hd2->160*160, Pooling 2 times
self.h1_PT_hd2 = nn.MaxPool2d(2, 2, ceil_mode=True)
self.h1_PT_hd2_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_PT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_PT_hd2_relu = nn.ReLU(inplace=True)
# h2->160*160, hd2->160*160, Concatenation
self.h2_Cat_hd2_conv = nn.Conv2d(filters[1], self.CatChannels, 3, padding=1)
self.h2_Cat_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.h2_Cat_hd2_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd2->160*160, Upsample 2 times
self.hd3_UT_hd2 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd3_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd2_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd2->160*160, Upsample 4 times
self.hd4_UT_hd2 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
self.hd4_UT_hd2_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd2_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd2->160*160, Upsample 8 times
self.hd5_UT_hd2 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
self.hd5_UT_hd2_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd2_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd2_relu = nn.ReLU(inplace=True)
# fusion(h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2)
self.conv2d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16
self.bn2d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu2d_1 = nn.ReLU(inplace=True)
'''stage 1d'''
# h1->320*320, hd1->320*320, Concatenation
self.h1_Cat_hd1_conv = nn.Conv2d(filters[0], self.CatChannels, 3, padding=1)
self.h1_Cat_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.h1_Cat_hd1_relu = nn.ReLU(inplace=True)
# hd2->160*160, hd1->320*320, Upsample 2 times
self.hd2_UT_hd1 = nn.Upsample(scale_factor=2, mode='bilinear') # 14*14
self.hd2_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd2_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd2_UT_hd1_relu = nn.ReLU(inplace=True)
# hd3->80*80, hd1->320*320, Upsample 4 times
self.hd3_UT_hd1 = nn.Upsample(scale_factor=4, mode='bilinear') # 14*14
self.hd3_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd3_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd3_UT_hd1_relu = nn.ReLU(inplace=True)
# hd4->40*40, hd1->320*320, Upsample 8 times
self.hd4_UT_hd1 = nn.Upsample(scale_factor=8, mode='bilinear') # 14*14
self.hd4_UT_hd1_conv = nn.Conv2d(self.UpChannels, self.CatChannels, 3, padding=1)
self.hd4_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd4_UT_hd1_relu = nn.ReLU(inplace=True)
# hd5->20*20, hd1->320*320, Upsample 16 times
self.hd5_UT_hd1 = nn.Upsample(scale_factor=16, mode='bilinear') # 14*14
self.hd5_UT_hd1_conv = nn.Conv2d(filters[4], self.CatChannels, 3, padding=1)
self.hd5_UT_hd1_bn = nn.BatchNorm2d(self.CatChannels)
self.hd5_UT_hd1_relu = nn.ReLU(inplace=True)
# fusion(h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1)
self.conv1d_1 = nn.Conv2d(self.UpChannels, self.UpChannels, 3, padding=1) # 16
self.bn1d_1 = nn.BatchNorm2d(self.UpChannels)
self.relu1d_1 = nn.ReLU(inplace=True)
# output
self.outconv1 = nn.Conv2d(self.UpChannels, n_classes, 3, padding=1)
# initialise weights
for m in self.modules():
if isinstance(m, nn.Conv2d):
init_weights(m, init_type='kaiming')
elif isinstance(m, nn.BatchNorm2d):
init_weights(m, init_type='kaiming')
def forward(self, inputs):
## -------------Encoder-------------
h1 = self.conv1(inputs) # h1->320*320*64
h2 = self.maxpool1(h1)
h2 = self.conv2(h2) # h2->160*160*128
h3 = self.maxpool2(h2)
h3 = self.conv3(h3) # h3->80*80*256
h4 = self.maxpool3(h3)
h4 = self.conv4(h4) # h4->40*40*512
h5 = self.maxpool4(h4)
hd5 = self.conv5(h5) # h5->20*20*1024
## -------------Decoder-------------
h1_PT_hd4 = self.h1_PT_hd4_relu(self.h1_PT_hd4_bn(self.h1_PT_hd4_conv(self.h1_PT_hd4(h1))))
h2_PT_hd4 = self.h2_PT_hd4_relu(self.h2_PT_hd4_bn(self.h2_PT_hd4_conv(self.h2_PT_hd4(h2))))
h3_PT_hd4 = self.h3_PT_hd4_relu(self.h3_PT_hd4_bn(self.h3_PT_hd4_conv(self.h3_PT_hd4(h3))))
h4_Cat_hd4 = self.h4_Cat_hd4_relu(self.h4_Cat_hd4_bn(self.h4_Cat_hd4_conv(h4)))
hd5_UT_hd4 = self.hd5_UT_hd4_relu(self.hd5_UT_hd4_bn(self.hd5_UT_hd4_conv(self.hd5_UT_hd4(hd5))))
hd4 = self.relu4d_1(self.bn4d_1(self.conv4d_1(
torch.cat((h1_PT_hd4, h2_PT_hd4, h3_PT_hd4, h4_Cat_hd4, hd5_UT_hd4), 1)))) # hd4->40*40*UpChannels
h1_PT_hd3 = self.h1_PT_hd3_relu(self.h1_PT_hd3_bn(self.h1_PT_hd3_conv(self.h1_PT_hd3(h1))))
h2_PT_hd3 = self.h2_PT_hd3_relu(self.h2_PT_hd3_bn(self.h2_PT_hd3_conv(self.h2_PT_hd3(h2))))
h3_Cat_hd3 = self.h3_Cat_hd3_relu(self.h3_Cat_hd3_bn(self.h3_Cat_hd3_conv(h3)))
hd4_UT_hd3 = self.hd4_UT_hd3_relu(self.hd4_UT_hd3_bn(self.hd4_UT_hd3_conv(self.hd4_UT_hd3(hd4))))
hd5_UT_hd3 = self.hd5_UT_hd3_relu(self.hd5_UT_hd3_bn(self.hd5_UT_hd3_conv(self.hd5_UT_hd3(hd5))))
hd3 = self.relu3d_1(self.bn3d_1(self.conv3d_1(
torch.cat((h1_PT_hd3, h2_PT_hd3, h3_Cat_hd3, hd4_UT_hd3, hd5_UT_hd3), 1)))) # hd3->80*80*UpChannels
h1_PT_hd2 = self.h1_PT_hd2_relu(self.h1_PT_hd2_bn(self.h1_PT_hd2_conv(self.h1_PT_hd2(h1))))
h2_Cat_hd2 = self.h2_Cat_hd2_relu(self.h2_Cat_hd2_bn(self.h2_Cat_hd2_conv(h2)))
hd3_UT_hd2 = self.hd3_UT_hd2_relu(self.hd3_UT_hd2_bn(self.hd3_UT_hd2_conv(self.hd3_UT_hd2(hd3))))
hd4_UT_hd2 = self.hd4_UT_hd2_relu(self.hd4_UT_hd2_bn(self.hd4_UT_hd2_conv(self.hd4_UT_hd2(hd4))))
hd5_UT_hd2 = self.hd5_UT_hd2_relu(self.hd5_UT_hd2_bn(self.hd5_UT_hd2_conv(self.hd5_UT_hd2(hd5))))
hd2 = self.relu2d_1(self.bn2d_1(self.conv2d_1(
torch.cat((h1_PT_hd2, h2_Cat_hd2, hd3_UT_hd2, hd4_UT_hd2, hd5_UT_hd2), 1)))) # hd2->160*160*UpChannels
h1_Cat_hd1 = self.h1_Cat_hd1_relu(self.h1_Cat_hd1_bn(self.h1_Cat_hd1_conv(h1)))
hd2_UT_hd1 = self.hd2_UT_hd1_relu(self.hd2_UT_hd1_bn(self.hd2_UT_hd1_conv(self.hd2_UT_hd1(hd2))))
hd3_UT_hd1 = self.hd3_UT_hd1_relu(self.hd3_UT_hd1_bn(self.hd3_UT_hd1_conv(self.hd3_UT_hd1(hd3))))
hd4_UT_hd1 = self.hd4_UT_hd1_relu(self.hd4_UT_hd1_bn(self.hd4_UT_hd1_conv(self.hd4_UT_hd1(hd4))))
hd5_UT_hd1 = self.hd5_UT_hd1_relu(self.hd5_UT_hd1_bn(self.hd5_UT_hd1_conv(self.hd5_UT_hd1(hd5))))
hd1 = self.relu1d_1(self.bn1d_1(self.conv1d_1(
torch.cat((h1_Cat_hd1, hd2_UT_hd1, hd3_UT_hd1, hd4_UT_hd1, hd5_UT_hd1), 1)))) # hd1->320*320*UpChannels
d1 = self.outconv1(hd1) # d1->320*320*n_classes
return F.sigmoid(d1)
第四步:统计一些指标(训练过程中的loss和miou)
第五步:搭建GUI界面
第六步:整个工程的内容
有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码
代码见:
有问题可以私信或者留言,有问必答