文章目录
- 00 写在前面
- 01 基于PyTorch的3D U-Net代码
- 02 论文下载
00 写在前面
通过学习3D U-Net代码，可以掌握基于PyTorch的网络结构模块化编程方法，这对后续学习其他更复杂的3D网络模型有很大帮助。
在第01节中，可以根据3D U-Net的网络结构（见文章开头的网络结构图）进行模块化编程，包括：卷积模块、下采样模块、上采样模块、输出卷积层以及整体网络模型的定义。
在模型调试过程中，可以先用简单的测试代码（代码末尾 `if __name__ == '__main__':` 部分）运行一次前向传播，检查各模块的输入输出尺寸是否正确。
01 基于PyTorch的3D U-Net代码
# 库函数调用
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
# from measure import Four_three
# 三维卷积块定义
class DoubleConv(nn.Module):
    """Two stacked 3x3x3 convolutions, each followed by ReLU.

    Layout: (Conv3d -> ReLU) * 2. Kernel 3 / stride 1 / padding 1 keeps
    D, H and W unchanged; only the channel count changes.

    No normalization layer (InstanceNorm / GroupNorm) is applied.

    Args:
        in_channels: number of channels of the input volume.
        out_channels: number of channels produced by both convolutions.
        num_groups: accepted but currently unused; kept so existing call
            sites that pass it keep working.
    """

    def __init__(self, in_channels, out_channels, num_groups=8):
        super().__init__()
        # The submodule indices (0..3) inside ``double_conv`` form the
        # state_dict keys, so the layer order must stay exactly like this.
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both conv + ReLU stages to ``x``."""
        return self.double_conv(x)
# 下采样模块定义
class Down(nn.Module):
    """Encoder stage: 2x max-pool followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool3d(2, 2)
        convs = DoubleConv(in_channels, out_channels)
        # Attribute name and ordering determine the state_dict keys:
        # ``encoder.0`` is the (parameter-free) pool, ``encoder.1.*`` the convs.
        self.encoder = nn.Sequential(pool, convs)

    def forward(self, x):
        """Max-pool ``x`` (kernel 2, stride 2), then run the DoubleConv."""
        pooled_and_convolved = self.encoder(x)
        return pooled_and_convolved
# 上采样模块定义
class Up(nn.Module):
    """Decoder stage: upsample, reduce channels, concatenate the skip, DoubleConv.

    The channel bookkeeping assumes ``in_channels == 2 * out_channels`` and
    that the skip tensor ``x2`` has ``out_channels`` channels (this is how the
    U-Net model in this file wires its decoder). After ``downc`` the upsampled
    tensor has ``out_channels`` channels, so the concatenation has
    ``in_channels`` channels again, which is what ``self.conv`` expects.

    Args:
        in_channels: channels of the deeper feature map ``x1``.
        out_channels: channels of the skip connection and of the output.
        trilinear: if True, upsample with ``nn.Upsample(scale_factor=2)``;
            otherwise use a learned ``ConvTranspose3d`` (kernel 2, stride 2).
    """

    def __init__(self, in_channels, out_channels, trilinear=True):
        super().__init__()
        if trilinear:
            # NOTE(review): nn.Upsample defaults to mode="nearest", so despite the
            # flag name this is nearest-neighbour, not trilinear, interpolation.
            # Left as-is so existing trained models produce the same outputs.
            self.up = nn.Upsample(scale_factor=2)
        else:
            # x1 arrives with ``in_channels`` channels and ``downc`` below also
            # consumes ``in_channels``, so the transposed conv must map
            # in_channels -> in_channels (an ``in_channels // 2`` input raises a
            # channel-mismatch error in forward).
            self.up = nn.ConvTranspose3d(in_channels, in_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)
        self.downc = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.downr = nn.ReLU(inplace=True)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it to ``x2``'s spatial size, fuse and convolve.

        Args:
            x1: deeper feature map, 5-D (N, in_channels, D, H, W).
            x2: skip feature map, 5-D (N, out_channels, D', H', W').

        Returns:
            Tensor of shape (N, out_channels, D', H', W').
        """
        x1 = self.up(x1)
        # Pad (or crop, when the difference is negative) x1 so its D/H/W match
        # x2; needed when pooling floored an odd dimension on the way down.
        diffZ = x2.size(2) - x1.size(2)
        diffY = x2.size(3) - x1.size(3)
        diffX = x2.size(4) - x1.size(4)
        # F.pad lists (before, after) pairs starting from the last dim: W, H, D.
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2,
                        diffZ // 2, diffZ - diffZ // 2])
        x1 = self.downr(self.downc(x1))
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
# 输出卷积层定义
class Out(nn.Module):
    """Final projection from feature channels to output channels.

    A 3x3x3 convolution with padding 1 (and a bias), so D/H/W are unchanged.
    No activation is applied to the result.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Project ``x`` to ``out_channels`` channels."""
        projected = self.conv(x)
        return projected
# 3D-UNet模型定义
class UNet3D(nn.Module):
    """3D U-Net with four encoder stages and four decoder stages.

    Named ``UNet3D`` because ``3DUNET`` is not a valid Python identifier
    (names cannot start with a digit), so a class with that name cannot be
    defined or called.

    Channel widths going down are ``n_channels * (1, 2, 4, 8, 16)``. Each
    encoder stage max-pools D/H/W by 2, and each decoder stage upsamples by 2
    and pads/crops back to the matching skip tensor's size.

    Args:
        in_channels: channels of the input volume.
        out_channels: channels of the output volume.
        n_channels: width of the first DoubleConv; deeper stages multiply it.

    Returns (from ``forward``):
        The raw output of the final convolution (no activation applied), with
        ``out_channels`` channels.
    """

    def __init__(self, in_channels=3, out_channels=1, n_channels=64):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_channels = n_channels

        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 16 * n_channels)

        # Each Up halves the channel count and fuses the skip of that width.
        self.dec1 = Up(16 * n_channels, 8 * n_channels)
        self.dec2 = Up(8 * n_channels, 4 * n_channels)
        self.dec3 = Up(4 * n_channels, 2 * n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, out_channels)

    def forward(self, x):
        """Run the encoder, then the decoder with skip connections."""
        # Encoder; x1..x4 are kept as skip connections.
        x1 = self.conv(x)
        x2 = self.enc1(x1)
        x3 = self.enc2(x2)
        x4 = self.enc3(x3)
        x5 = self.enc4(x4)  # bottleneck, 16 * n_channels channels

        # Decoder
        mask = self.dec1(x5, x4)
        mask = self.dec2(mask, x3)
        mask = self.dec3(mask, x2)
        mask = self.dec4(mask, x1)
        return self.out(mask)


# Smoke test: one forward pass to check that the shapes line up.
if __name__ == '__main__':
    input_channels = 4
    output_channels = 1
    # (batch, channels, D, H, W); 16 stays a whole number through four 2x pools.
    x = torch.ones([16, input_channels, 16, 16, 16])
    model = UNet3D(input_channels, output_channels)
    print('model initialization finished!')
    with torch.no_grad():
        f = model(x)
    print(f.shape)
02 论文下载
论文：3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation（Çiçek et al., MICCAI 2016）
arXiv：https://arxiv.org/abs/1606.06650