UNet介绍
- 参考
- UNet网络介绍
- 整体架构
- UNet过程
- 输入
- 编码器(下采样)
- 中间特征表示
- 解码器(上采样)
- 输出
- 代码详解
- unetUP和Unet关系
- 上采样模块——unetUp
- 用于图像分割的卷积神经网络(CNN)架构模块——Unet
- 类的定义
- 初始化方法
- 上采样模块
- 额外的上采样卷积层(仅用于ResNet50)
- 最终卷积层
- 前向传播方法
- 冻结和解冻骨干网络
- 完整代码
参考
U-Net: Convolutional Networks for Biomedical Image Segmentation 输入文章名自行查询
参考博客
UNet网络介绍
整体架构
U-net 架构(以最低分辨率为 32x32 像素为例),每个蓝色框对应一个多通道特征图。通道数显示在框的顶部。x-y 大小位于框的左下边缘。白框表示复制的特征图。箭头表示不同的操作。
深蓝色箭头:利用3×3的卷积核对图片进行卷积后,通过ReLU激活函数输出特征通道;
灰色箭头:对左边下采样过程中的图片进行裁剪复制;
红色箭头:通过最大池化对图片进行下采样,池化核大小为2×2;
绿色箭头:反卷积,对图像进行上采样,卷积核大小为2×2;
青色箭头:使用1×1的卷积核对图片进行卷积。
因为网络形状像U,故被称为U-net
参考博客(网络结构很清晰,推荐)
UNet过程
输入
U-Net 的输入是一幅单通道的图像,通常大小为 572x572 像素,由于在不断valid卷积过程中,会使得图片越来越小,为了避免数据丢失,在图像输入前都需要进行镜像扩大。
编码器(下采样)
- U-Net 的编码器部分输入图像通过卷积层进行特征提取,这些卷积层通常使用 3x3 的卷积核,逐步提取图像特征并缩小空间维度。
- 然后,通过池化层(通常是最大池化)将图像的空间维度减小,例如从 572x572 缩小到 286x286。
- 这个过程会重复多次,每次都会减小图像的空间维度和增加特征通道数。
中间特征表示
- 在编码器的最后一层,我们获得了一个中间特征表示,通常是一个高维的特征张量。
- 这个特征表示包含了图像的抽象特征,可以用于后续的分割任务。
解码器(上采样)
- U-Net 的解码器部分将中间特征表示还原到原始的空间维度,并逐步增加分辨率。
- 首先,通过上采样操作将特征张量的空间维度扩大,例如从 286x286 扩大到 572x572。
- 然后,通过卷积层进行特征融合,将低级和高级特征结合起来。
- 最后,输出通道数为 64 的卷积层将特征映射到最终的分割结果。
输出
- U-Net 的输出是一个分割图像,大小与输入图像相同(通常为 572x572 像素),这幅分割图像被分成不同的区域,其中不同区域被分配不同的标签或类别。
- 分割图像中的每个像素都被分类到不同的类别中,即可以准确地知道图像中的每个像素属于哪个结构或区域。这个分割图像可以用于识别生物医学图像中的不同结构,例如肿瘤、器官等。
代码详解
unetUP和Unet关系
-
unetUp
:unetUp
是一个自定义的 PyTorch 模块(nn.Module
),用于实现 U-Net 模型中的上采样部分。- 它接受两个输入特征张量
inputs1
和inputs2
,并将它们进行上采样、特征融合和卷积操作,最终输出一个特征张量。 - 在 U-Net 中,
unetUp
负责将低分辨率的特征图上采样到与高分辨率特征图相同的尺寸,以便进行特征融合。
-
Unet
:Unet
是整个 U-Net 模型的主体部分,它由多个unetUp
模块组成。- 根据选择的
backbone
(可以是 VGG 或 ResNet-50),Unet
使用不同的主干网络提取特征。 Unet
的前向传播过程包括多次特征融合,上采样和卷积操作,最终生成语义分割结果。
上采样模块——unetUp
这段代码定义了一个名为 unetUp 的类,它是一个用于UNet架构中的上采样模块。这个模块的作用是将低分辨率特征图上采样并与高分辨率特征图结合,以生成更高分辨率的输出。
# 定义unetUp类,unetUp类继承自nn.Module,是PyTorch中所有神经网络模块的基类。
class unetUp(nn.Module):
# 初始化方法
in_size 和 out_size 是输入和输出通道的数量。
self.conv1 和 self.conv2 是两个二维卷积层,卷积核大小为3,填充为1。
self.up 是一个最近邻插值的上采样层,放大倍数为2。
self.relu 是一个ReLU激活函数。
def __init__(self, in_size, out_size):
super(unetUp, self).__init__()
self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
self.up = nn.UpsamplingNearest2d(scale_factor=2)
self.relu = nn.ReLU(inplace=True)
# 前向传播方法
# inputs1 和 inputs2 是前向传播时的输入张量。
# torch.cat([inputs1, self.up(inputs2)], 1) 将 inputs1 和上采样后的 inputs2 在通道维度上拼接。
# outputs = self.conv1(outputs) 对拼接后的张量进行第一次卷积。
# outputs = self.relu(outputs) 对卷积结果应用ReLU激活函数。
# outputs = self.conv2(outputs) 对激活后的张量进行第二次卷积。
# outputs = self.relu(outputs) 再次应用ReLU激活函数。最终返回处理后的 outputs。
# 这段代码的主要功能是将低分辨率特征图上采样并与高分辨率特征图结合,经过两次卷积和激活函数处理后,生成更高分辨率的输出特征图。
def forward(self, inputs1, inputs2):
outputs = torch.cat([inputs1,self.up(inputs2)],1)
outputs = self.conv1(outputs)
outputs = self.relu(outputs)
outputs = self.conv2(outputs)
outputs = self.relu(outputs)
return outputs
用于图像分割的卷积神经网络(CNN)架构模块——Unet
下面这段代码定义了一个名为 Unet
的类,它是一个用于图像分割的卷积神经网络(CNN)架构。这个类可以使用不同的骨干网络(backbone),如VGG16或ResNet50,并包含上采样模块以生成高分辨率的输出。以下是对代码的详细解释:
类的定义
class Unet(nn.Module):
Unet
类继承自 nn.Module
,这是PyTorch中所有神经网络模块的基类。
初始化方法
def __init__(self, num_classes=21, pretrained=False, backbone='vgg'):
super(Unet, self).__init__()
if backbone == 'vgg':
self.vgg = VGG16(pretrained=pretrained)
in_filters = [192, 384, 768, 1024]
elif backbone == "resnet50":
self.resnet = resnet50(pretrained=pretrained)
in_filters = [192, 512, 1024, 3072]
else:
raise ValueError('Unsupported backbone - `{}`, Use vgg, resnet50.'.format(backbone))
out_filters = [64, 128, 256, 512]
num_classes
是输出类别的数量。pretrained
指示是否使用预训练的权重。backbone
指定使用的骨干网络,可以是VGG16或ResNet50。- 根据选择的骨干网络,初始化相应的网络并设置输入过滤器的数量。
in_filters
(输入通道数):在卷积神经网络(CNN)中,in_filters 表示输入图像的通道数或特征图的数量。在输入层,如果是灰度图片,那就只有一个 feature map;如果是彩色图片,一般就是 3 个 feature map(对应红、绿、蓝通道)。在其他层,每个卷积核(也称为过滤器)与上一层的每个 feature map 做卷积,产生下一层的一个 feature map。因此,如果有 N 个卷积核,下一层就会产生 N 个 feature map。out_filters
(输出通道数):在卷积神经网络中,out_filters 表示卷积核的数量或输出的特征图数量。卷积核的个数决定了下一层的 feature map 数量。每个卷积核可以提取一种特征,并生成一个新的特征图。在多层卷积网络中,下一层的卷积核的通道数等于上一层的 feature map 数量。如果通道数不相等,就无法继续进行卷积操作。
上采样模块
# upsampling
self.up_concat4 = unetUp(in_filters[3], out_filters[3])
self.up_concat3 = unetUp(in_filters[2], out_filters[2])
self.up_concat2 = unetUp(in_filters[1], out_filters[1])
self.up_concat1 = unetUp(in_filters[0], out_filters[0])
- 定义四个上采样模块,每个模块将低分辨率特征图上采样到更高分辨率并与高分辨率特征图结合。
self.up_concat4, self.up_concat3, self.up_concat2 , self.up_concat1
是上采样操作的一部分。它们分别将不同层的特征图级联在一起,以获得更丰富的特征表示
额外的上采样卷积层(仅用于ResNet50)
if backbone == 'resnet50':
self.up_conv = nn.Sequential(
# 使用双线性插值(nn.UpsamplingBilinear2d)将特征图的大小放大两倍
nn.UpsamplingBilinear2d(scale_factor=2),
# 通过两个卷积层对特征图进行处理,以获得更好的特征表示
nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
# 使用 nn.ReLU() 激活函数来确保非线性变换
nn.ReLU(),
nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
nn.ReLU(),
)
else:
self.up_conv = None
- 如果使用ResNet50作为骨干网络,定义一个额外的上采样卷积层。
最终卷积层
# self.final 是一个卷积层,用于生成最终的输出。它将高分辨率的特征图映射到类别数(num_classes)
self.final = nn.Conv2d(out_filters[0], num_classes, 1)
- 定义一个最终的卷积层,将输出通道数转换为类别数。
前向传播方法
定义模型的前向传播过程,将输入数据通过网络的各层进行计算,最终生成输出。
def forward(self, inputs):
# 根据 self.backbone 的值,选择不同的模型(VGG 或 ResNet-50)进行前向传播。
# 通过卷积层和池化层对输入数据进行处理,得到特征图 feat1、feat2、feat3、feat4 和 feat5
if self.backbone == "vgg":
[feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
elif self.backbone == "resnet50":
[feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)
# 通过上采样操作将这些特征图进行级联,得到更高分辨率的特征图 up4、up3、up2 和 up1
up4 = self.up_concat4(feat4, feat5)
up3 = self.up_concat3(feat3, up4)
up2 = self.up_concat2(feat2, up3)
up1 = self.up_concat1(feat1, up2)
# 如果存在上采样卷积层 self.up_conv,则对 up1 进行进一步处理
if self.up_conv != None:
up1 = self.up_conv(up1)
# 通过 self.final 层获得最终的输出
final = self.final(up1)
return final
- 根据选择的骨干网络,获取不同层的特征图。
- 使用上采样模块逐层上采样并结合特征图。
- 如果定义了额外的上采样卷积层,则应用该层。
- 最终通过一个卷积层生成输出。
冻结和解冻骨干网络
def freeze_backbone(self):
if self.backbone == "vgg":
for param in self.vgg.parameters():
param.requires_grad = False
elif self.backbone == "resnet50":
for param in self.resnet.parameters():
param.requires_grad = False
def unfreeze_backbone(self):
if self.backbone == "vgg":
for param in self.vgg.parameters():
param.requires_grad = True
elif self.backbone == "resnet50":
for param in self.resnet.parameters():
param.requires_grad = True
freeze_backbone
方法用于冻结骨干网络的参数,使其在训练过程中不更新。unfreeze_backbone
方法用于解冻骨干网络的参数,使其在训练过程中可以更新。
冻结或解冻神经网络模型的特定层–更好地进行迁移学习或微调
-
迁移学习:
- 在迁移学习中,使用一个预训练的神经网络模型(通常在大规模数据集上进行训练)来解决新的任务。
- 通过冻结模型的底层层(例如卷积层),可以保留其在原始任务上学到的特征表示,然后在新任务上进行微调。
- 这样做有助于避免在新任务上过拟合,并且可以加快训练速度。
-
微调:
- 微调是指在预训练模型的基础上继续训练,以适应新任务的特定数据。
- 解冻底层层,允许其权重在新任务上进行调整,以更好地适应新数据。
- 通常,我们只微调模型的一部分,而不是整个模型,以避免丢失预训练模型的有用特征。
完整代码
import torch
import torch.nn as nn
from nets.resnet import resnet50
from nets.vgg import VGG16
class unetUp(nn.Module):
def __init__(self, in_size, out_size):
super(unetUp, self).__init__()
self.conv1 = nn.Conv2d(in_size, out_size, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(out_size, out_size, kernel_size=3, padding=1)
self.up = nn.UpsamplingNearest2d(scale_factor=2)
self.relu = nn.ReLU(inplace=True)
def forward(self, inputs1, inputs2):
outputs = torch.cat([inputs1,self.up(inputs2)],1)
outputs = self.conv1(outputs)
outputs = self.relu(outputs)
outputs = self.conv2(outputs)
outputs = self.relu(outputs)
return outputs
class Unet(nn.Module):
def __init__(self, num_classes = 2, pretrained = False, backbone = 'vgg'):
super(Unet, self).__init__()
if backbone == 'vgg':
self.vgg = VGG16(pretrained=pretrained)
in_filters = [192, 384, 768, 1024]
elif backbone == 'resnet50'
self.resnet = resnet50(pretrained=pretrained)
in_filters = [192, 512, 1024, 3072]
else:
raise ValueError('Unsupported backbone -`{}`, Use vgg, resnet50.'.format(backbone))
out_filters = [64, 128, 256, 512]
#???
self.up_concat4 = unetUp(in_filters[3], out_filters[3])
self.up_concat3 = unetUp(in_filters[2], out_filters[2])
self.up_concat2 = unetUp(in_filters[1], out_filters[1])
self.up_concat1 = unetUp(in_filters[0], out_filters[0])
if backbone == 'resnet50':
self.up_conv = nn.Sequential(
nn.UpsamplingNearest2d(scale_factor=2),
nn.Conv2d(out_filters[0],out_filters[0], kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(out_filters[0], out_filters[0], kernel_size=3, padding=1),
nn.ReLU(),
)
else:
self.up_conv = None
self.final = nn.Conv2d(out_filters[0], num_classes, 1)
self.backbone = backbone
def forward(self, inputs):
if self.backbone == "vgg":
[feat1, feat2, feat3, feat4, feat5] = self.vgg.forward(inputs)
elif self.backbone == "resnet50":
[feat1, feat2, feat3, feat4, feat5] = self.resnet.forward(inputs)
up4 = self.up_concat4(feat4, feat5)
up3 = self.up_concat3(feat3, up4)
up2 = self.up_concat2(feat2, up3)
up1 = self.up_concat1(feat1, up2)
if self.up_conv != None:
up1 = self.up_conv(up1)
final = self.final(up1)
return final
def freeze_backbone(self):
if self.backbone == "vgg":
for param in self.vgg.parameters():
param.requires_grad = False
elif self.backbone == "resnet50":
for param in self.resnet.parameters():
param.requires_grad = False
def unfreeze_backbone(self):
if self.backbone == "vgg":
for param in self.vgg.parameters():
param.requires_grad = True
elif self.backbone == "resnet50":
for param in self.resnet.parameters():
param.requires_grad = True