文章目录
- 00 写在前面
- 01 基于PyTorch的UNet代码
- 02 论文下载
00 写在前面
通过学习U-Net代码，可以掌握基于PyTorch的模块化网络编程方法，这对后续学习其他更复杂的网络模型很有帮助。
在第01节中，可以对照U-Net的网络结构（见文首图片）进行模块化编程，包括卷积模块、上采样模块、输出卷积层、损失函数以及网络模型的定义等。
在模型调试过程中，可以先运行简单的测试代码（见代码末尾的测试部分）进行调试。
01 基于PyTorch的UNet代码
# 库函数调用
import torch
import torch.nn as nn
from network.ops import TotalVariation
from torchvision.models import vgg19
# 卷积块定义
class conv_block(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Stride 1 with padding 1 keeps the spatial size unchanged; only the
    channel count changes, from ``ch_in`` to ``ch_out``.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels of both convolutions.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        # Built as conv, relu, conv, relu so the Sequential indices (and
        # therefore the state_dict keys conv.0.* / conv.2.*) stay fixed.
        # NOTE: there is no BatchNorm2d between conv and ReLU; inserting one
        # would shift these indices and break loading of saved weights.
        layers = []
        for in_channels in (ch_in, ch_out):
            layers.append(
                nn.Conv2d(in_channels, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
            )
            layers.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        return self.conv(x)
# 上采样部分定义
class up_conv(nn.Module):
    """Decoder upsampling step: 2x upsample, then 3x3 conv and in-place ReLU.

    ``nn.Upsample`` is given only ``scale_factor``, so it uses its default
    interpolation mode. Height and width are doubled, and the channel count
    goes from ``ch_in`` to ``ch_out``.

    Args:
        ch_in: number of input channels.
        ch_out: number of output channels.
    """

    def __init__(self, ch_in, ch_out):
        super().__init__()
        upsample = nn.Upsample(scale_factor=2)
        conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True)
        # Order fixes the parameter names to up.1.weight / up.1.bias.
        self.up = nn.Sequential(upsample, conv, nn.ReLU(inplace=True))

    def forward(self, x):
        """Upsample ``x`` by 2 and apply conv + ReLU."""
        return self.up(x)
# 输出卷积层定义
class outconv(nn.Module):
    """Output head: a single 3x3 convolution from ``in_ch`` to ``out_ch`` channels.

    No activation follows the convolution, so outputs are unbounded.
    Spatial size is preserved (stride 1, padding 1).

    Args:
        in_ch: number of input feature channels.
        out_ch: number of output channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        head = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1)
        # Kept inside a one-element Sequential so parameters are named
        # conv.0.weight / conv.0.bias.
        self.conv = nn.Sequential(head)

    def forward(self, x):
        """Project ``x`` to the output channels."""
        return self.conv(x)
class UNET_MODEL(nn.Module):
    """U-Net: a five-level convolutional encoder/decoder with skip connections.

    Encoder: ``conv_block`` at five resolutions separated by 2x2 max pooling.
    The channel count starts at ``filter_dim`` and doubles at every level
    (filter_dim, 2x, 4x, 8x, 16x).

    Decoder: at each level, ``up_conv`` (2x upsample + conv) is followed by
    concatenation with the encoder feature map of the same level (skip
    first) and a ``conv_block``. A final ``outconv`` maps to ``output_ch``
    channels.

    Args:
        img_ch: number of input channels.
        output_ch: number of output channels.
        filter_dim: channels of the first encoder level. The default of 64
            reproduces the previously hard-coded 64-128-256-512-1024 layout.

    The output has the same spatial size as the input. Heights and widths
    that are not divisible by 16 are supported: after each upsampling, the
    decoder map is zero-padded to the size of the matching skip connection.
    """

    def __init__(self, img_ch=3, output_ch=1, filter_dim=64):
        super().__init__()
        f = filter_dim
        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # Encoder widths are derived from filter_dim. Previously only Conv1
        # used it and the rest assumed 64, so any other value crashed.
        self.Conv1 = conv_block(ch_in=img_ch, ch_out=f)
        self.Conv2 = conv_block(ch_in=f, ch_out=f * 2)
        self.Conv3 = conv_block(ch_in=f * 2, ch_out=f * 4)
        self.Conv4 = conv_block(ch_in=f * 4, ch_out=f * 8)
        self.Conv5 = conv_block(ch_in=f * 8, ch_out=f * 16)
        # Each Up_conv block receives skip + upsampled channels (2x its output).
        self.Up5 = up_conv(ch_in=f * 16, ch_out=f * 8)
        self.Up_conv5 = conv_block(ch_in=f * 16, ch_out=f * 8)
        self.Up4 = up_conv(ch_in=f * 8, ch_out=f * 4)
        self.Up_conv4 = conv_block(ch_in=f * 8, ch_out=f * 4)
        self.Up3 = up_conv(ch_in=f * 4, ch_out=f * 2)
        self.Up_conv3 = conv_block(ch_in=f * 4, ch_out=f * 2)
        self.Up2 = up_conv(ch_in=f * 2, ch_out=f)
        self.Up_conv2 = conv_block(ch_in=f * 2, ch_out=f)
        self.Conv11 = outconv(f, output_ch)

    def forward(self, x):
        """Run the encoder/decoder on ``x`` and return the ``output_ch`` map."""
        # encoding path
        x1 = self.Conv1(x)
        x2 = self.Conv2(self.Maxpool(x1))
        x3 = self.Conv3(self.Maxpool(x2))
        x4 = self.Conv4(self.Maxpool(x3))
        x5 = self.Conv5(self.Maxpool(x4))

        # decoding + concat path
        d5 = self.Up_conv5(self._cat_skip(x4, self.Up5(x5)))
        d4 = self.Up_conv4(self._cat_skip(x3, self.Up4(d5)))
        d3 = self.Up_conv3(self._cat_skip(x2, self.Up3(d4)))
        d2 = self.Up_conv2(self._cat_skip(x1, self.Up2(d3)))
        return self.Conv11(d2)

    @staticmethod
    def _cat_skip(skip, up):
        """Pad ``up`` to the spatial size of ``skip``, then concatenate (skip first) on dim 1."""
        diff_h = skip.size(2) - up.size(2)
        diff_w = skip.size(3) - up.size(3)
        if diff_h or diff_w:
            # MaxPool2d floors odd sizes, so the upsampled map can come out one
            # pixel smaller than the skip map. pad() takes (left, right, top, bottom).
            up = nn.functional.pad(
                up, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
            )
        return torch.cat((skip, up), dim=1)
# 损失函数定义
class loss_fun(nn.Module):
    """Weighted squared-error data term plus a weighted total-variation term.

    ``y`` packs three single-channel maps along dim 1:
      - channel 0 multiplies the residual ``x - target``;
      - channel 1 multiplies ``x`` before it is passed to the TV term;
      - channel 2 is the target.

    Returns ``mean(((x - target) * w) ** 2) + regular * mean(tv(x * mask))``.

    Args:
        regular: weight of the total-variation term.
    """

    def __init__(self, regular):
        super().__init__()
        # TotalVariation comes from network.ops; its exact reduction is not
        # visible here (forward takes the mean of whatever it returns).
        self.tv = TotalVariation()
        self.regular = regular

    def forward(self, x, y):
        """Compute the loss for prediction ``x`` against packed ``y`` (both indexed as 4-D)."""
        pred = x[:, :, :, :]  # full slice; requires x to have at least 4 dims
        weight = y[:, 0:1, :, :]
        mask = y[:, 1:2, :, :]
        target = y[:, 2:3, :, :]
        data_term = ((pred - target) * weight).pow(2).mean()
        tv_term = self.tv(pred * mask).mean()
        return data_term + self.regular * tv_term
class loss_fun_total(nn.Module):
    """Mean squared error between channel 0 of ``x`` and a scaled channel 0 of ``y``.

    Returns ``mean((x[:, 0:1] - y[:, 0:1] * target_scale) ** 2)``.

    Args:
        regular: stored as ``self.regular``; not used by ``forward``.
        target_scale: factor applied to the target channel before the
            comparison. Defaults to 10, the value that used to be hard-coded.
    """

    def __init__(self, regular, target_scale=10):
        super().__init__()
        # NOTE(review): self.tv and self.regular are not used by forward().
        # They are kept so the constructor behaves as before, including any
        # state TotalVariation registers on this module.
        self.tv = TotalVariation()
        self.regular = regular
        self.target_scale = target_scale

    def forward(self, x, y):
        """Compute the MSE between ``x`` channel 0 and ``target_scale`` * ``y`` channel 0."""
        pred = x[:, 0:1, :, :]
        target = y[:, 0:1, :, :] * self.target_scale
        return torch.mean(torch.pow(pred - target, 2))
# 测试代码
if __name__ == '__main__':
    # Smoke test: build the model and run one forward pass on dummy input.
    input_channels = 4
    output_channels = 1
    # A batch of 32 at 256x256 through a 64-filter U-Net with autograd
    # enabled needs several GB of activations on CPU. A small batch is
    # enough to check that the shapes line up.
    batch_size = 2
    x = torch.ones([batch_size, input_channels, 256, 256])
    model = UNET_MODEL(input_channels, output_channels)
    print('model initialization finished!')
    # Gradients are not needed for a shape check, so skip building the graph.
    with torch.no_grad():
        f = model(x)
    print(f.shape)
02 论文下载
U-Net: deep learning for cell counting, detection, and morphometry
U-Net: Convolutional Networks for Biomedical Image Segmentation