【3D图像分割】基于 Pytorch 的 VNet 3D 图像分割3(3D UNet 模型篇)

news2024/12/1 0:19:53

在本文中,主要是对3D UNet 进行一个学习和梳理。对于3D UNet 网上的资料和GitHub直接获取的代码很多,不需要自己从0开始。那么本文的目的是啥呢?

本文就是想拆解下其中的结构,看看对于一个3DUNet,和2DUNet,究竟有什么不同?如果是你自己构建,有什么样的经验和技巧可以学习。

3DUNet的论文地址:3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation

对于2DUNet感兴趣的小伙伴,可以先跳转去这里:【BraTS】Brain Tumor Segmentation 脑部肿瘤分割2(UNet的复现);相信阅读完,你会对这个模型,心中已经有了结构。

对本系列的其他篇章,点击下面👇链接:

  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割1(综述篇)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割2(基础数据流篇)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割6(数据预处理)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割7(数据预处理)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割8(CT肺实质分割)
  • 【3D 图像分割】基于 Pytorch 的 VNet 3D 图像分割9(patch 的 crop 和 merge 操作)

一、 3D UNet 结构剖析

unet无论是2D,还是3D,从整体结构上进行划分,大体可以分位以下两个阶段:

  1. 下采样的阶段,也就是U的左边(encoder),负责对特征提取;
  2. 上采样的阶段,也就是U的右边(decoder),负责对预测恢复。

如下图展示的这样:
1

其中:

  • 蓝色框表示的是特征图;
  • 绿色长箭头,是concat操作;
  • 橘色三角,是conv+bn+relu的组合;
  • 红色的向下箭头,是max pool
  • 黄色的向上箭头,是up conv
  • 最后的紫色三角,是conv,恢复了最终的输出特征图;

对于模型构建这块,可以在论文中,看看作者是如何描述网络结构的:

  1. Like the standard u-net, it has an analysis and a synthesis path each with four resolution steps.
  2. In the analysis path, each layer contains two 3 × 3 × 3 convolutions each followed by a rectified linear unit (ReLu), and then a 2 × 2 × 2 max pooling with strides of two in each dimension.
  3. In the synthesis path, each layer consists of an upconvolution of 2 × 2 × 2 by strides of two in each dimension, followed by two 3 × 3 × 3 convolutions each followed by a ReLu.
  4. Shortcut connections from layers of equal resolution in the analysis path provide the essential high-resolution features to the synthesis path.
  5. In the last layer a 1×1×1 convolution reduces the number of output channels to the number of labels which is 3 in our case.

从论文中的网络结构示意图也可以发现:

  1. 水平看,每一个小块,基本都是三个特征图,最后一层除外;
  2. 水平看,每个特征图之间,都是橘色三角,是conv+bn+relu的组合,最后一层除外;
  3. encoder阶段,连接各个水平块的,是下采样;
  4. decoder阶段,连接各个水平块的,是反卷积(upconvolution);
  5. 还有就是绿色长箭头的concat,和最后的conv输出特征图。

二、 3D UNet 复现

复线在3D UNet前,可以先参照下相对简单,且很深渊源的2D UNet结构。其中被多次使用的一个水平块中,也是两个conv+bn+relu的组合,2D UNet的构建如下所示:

class ConvBlock2d(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(ConvBlock2d, self).__init__()

        # 第1个3*3的卷积层
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

        # 第2个3*3的卷积层
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    # 定义数据前向流动形式
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

而在3D UNet的一个水平块中,同样是两个conv+bn+relu的组合,如下所示:

is_elu = False
def activateELU(is_elu, nchan):
    if is_elu:
        return nn.ELU(inplace=True)
    else:
        return nn.PReLU(nchan)

def ConvBnActivate(in_channels, middle_channels, out_channels):
    # This is a block with 2 convolutions
    # The first convolution goes from in_channels to middle_channels feature maps
    # The second convolution goes from middle_channels to out_channels feature maps
    conv = nn.Sequential(
        nn.Conv3d(in_channels, middle_channels, stride=1, kernel_size=3, padding=1),
        nn.BatchNorm3d(middle_channels),
        activateELU(is_elu, middle_channels),

        nn.Conv3d(middle_channels, out_channels, stride=1, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_channels),
        activateELU(is_elu, out_channels),
    )
    return conv

可以发现,nn.Conv2d变成了nn.Conv3dnn.BatchNorm2d变成了nn.BatchNorm3d。遵照这个规则,构建下采样MaxPool3d、上采样反卷积ConvTranspose3d,以及最后紫色一层卷积,输出特征层FinalConvolution,如下:

def DownSample():
    # It halves the spatial dimensions on every axes (x,y,z)
    return nn.MaxPool3d(kernel_size=2, stride=2)

def UpSample(in_channels, out_channels):
    # It doubles the spatial dimensions on every axes (x,y,z)
    return nn.ConvTranspose3d(in_channels, out_channels, kernel_size=2, stride=2)

def FinalConvolution(in_channels, out_channels):
    return nn.Conv3d(in_channels, out_channels, kernel_size=1)

除此之外,绿色长箭头,concat操作,是在水平方向上,也就是列上进行组合,如下所示:

def CatBlock(x1, x2):
    return torch.cat((x1, x2), 1)

至此,构建模型所需要的各个组块,都准备完毕了。接下来就是构建模型,将各个组块搭起来。其中有个规律:

  • encoder中第一conv+bn+relu外,每一次前都需要下采样;
  • decoder中,每一个conv+bn+relu前,都需要上采样;
  • 并且,decoder中第一个conv操作,需要进行concat操作;
  • DownSamplechannel不变,特征图尺寸变小;
  • UpSamplechannel不变,特征图尺寸变大;

那就把这些规则,根据图示给加上,组合后的一个类,就如下所示:

import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet3D(nn.Module):
    def __init__(self, num_out_classes=2, input_channels=1, init_feat_channels=32):
        super().__init__()

        # Encoder layers definitions
        self.down_sample = DownSample()

        self.init_conv = ConvBnActivate(input_channels, init_feat_channels, init_feat_channels*2)
        self.down_conv1 = ConvBnActivate(init_feat_channels*2, init_feat_channels*2, init_feat_channels*4)
        self.down_conv2 = ConvBnActivate(init_feat_channels*4, init_feat_channels*4, init_feat_channels*8)
        self.down_conv3 = ConvBnActivate(init_feat_channels*8, init_feat_channels*8, init_feat_channels*16)

        # Decoder layers definitions
        self.up_sample1 = UpSample(init_feat_channels*16, init_feat_channels*16)
        self.up_conv1   = ConvBnActivate(init_feat_channels*(16+8), init_feat_channels*8, init_feat_channels*8)

        self.up_sample2 = UpSample(init_feat_channels*8, init_feat_channels*8)
        self.up_conv2   = ConvBnActivate(init_feat_channels*(8+4), init_feat_channels*4, init_feat_channels*4)

        self.up_sample3 = UpSample(init_feat_channels*4, init_feat_channels*4)
        self.up_conv3   = ConvBnActivate(init_feat_channels*(4+2), init_feat_channels*2, init_feat_channels*2)

        self.final_conv = FinalConvolution(init_feat_channels*2, num_out_classes)

        # Softmax
        self.softmax = F.softmax

    def forward(self, image):
        # Encoder Part #
        # B x  1 x Z x Y x X
        layer_init = self.init_conv(image)

        # B x 64 x Z x Y x X
        max_pool1  = self.down_sample(layer_init)
        # B x 64 x Z//2 x Y//2 x X//2
        layer_down2 = self.down_conv1(max_pool1)

        # B x 128 x Z//2 x Y//2 x X//2
        max_pool2   = self.down_sample(layer_down2)
        # B x 128 x Z//4 x Y//4 x X//4
        layer_down3 = self.down_conv2(max_pool2)

        # B x 256 x Z//4 x Y//4 x X//4
        max_pool_3  = self.down_sample(layer_down3)
        # B x 256 x Z//8 x Y//8 x X//8
        layer_down4 = self.down_conv3(max_pool_3)
        # B x 512 x Z//8 x Y//8 x X//8

        # Decoder part #
        layer_up1 = self.up_sample1(layer_down4)
        # B x 512 x Z//4 x Y//4 x X//4
        cat_block1 = CatBlock(layer_down3, layer_up1)
        # B x (256+512) x Z//4 x Y//4 x X//4
        layer_conv_up1 = self.up_conv1(cat_block1)
        # B x 256 x Z//4 x Y//4 x X//4

        layer_up2 = self.up_sample2(layer_conv_up1)
        # B x 256 x Z//2 x Y//2 x X//2
        cat_block2 = CatBlock(layer_down2, layer_up2)
        # B x (128+256) x Z//2 x Y//2 x X//2
        layer_conv_up2 = self.up_conv2(cat_block2)
        # B x 128 x Z//2 x Y//2 x X//2

        layer_up3 = self.up_sample3(layer_conv_up2)
        # B x 128 x Z x Y x X
        cat_block3 = CatBlock(layer_init, layer_up3)
        # B x (64+128) x Z x Y x X
        layer_conv_up3 = self.up_conv3(cat_block3)

        # B x 64 x Z x Y x X
        final_layer = self.final_conv(layer_conv_up3)
        # B x 2 x Z x Y x X
        return self.softmax(final_layer, dim=1)

定义好了模型还不算完,分阶段测试下构建的网络是不是和我们所预想的一样。我们给他一个输入,测试下是否与我们最初的想法是一致的,是否报错等等问题,如下这样:

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 没gpu就用cpu
print(DEVICE)

# Tensors for 3D Image Processing in PyTorch
# Batch x Channel x Z x Y x X
# Batch size BY x Number of channels x (BY Z dim) x (BY Y dim) x (BY X dim)

if __name__ == '__main__':
    from torchsummary import summary

    model = UNet3D(num_out_classes=3, input_channels=3, init_feat_channels=32)
    # print(model)
    summary(model, input_size=(3, 128, 128, 64), batch_size=-1, device='cpu')

打印的内容如下:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv3d-1     [-1, 32, 128, 128, 64]           2,624
       BatchNorm3d-2     [-1, 32, 128, 128, 64]              64
             PReLU-3     [-1, 32, 128, 128, 64]              32
            Conv3d-4     [-1, 64, 128, 128, 64]          55,360
       BatchNorm3d-5     [-1, 64, 128, 128, 64]             128
             PReLU-6     [-1, 64, 128, 128, 64]              64
         MaxPool3d-7       [-1, 64, 64, 64, 32]               0
            Conv3d-8       [-1, 64, 64, 64, 32]         110,656
       BatchNorm3d-9       [-1, 64, 64, 64, 32]             128
            PReLU-10       [-1, 64, 64, 64, 32]              64
           Conv3d-11      [-1, 128, 64, 64, 32]         221,312
      BatchNorm3d-12      [-1, 128, 64, 64, 32]             256
            PReLU-13      [-1, 128, 64, 64, 32]             128
        MaxPool3d-14      [-1, 128, 32, 32, 16]               0
           Conv3d-15      [-1, 128, 32, 32, 16]         442,496
      BatchNorm3d-16      [-1, 128, 32, 32, 16]             256
            PReLU-17      [-1, 128, 32, 32, 16]             128
           Conv3d-18      [-1, 256, 32, 32, 16]         884,992
      BatchNorm3d-19      [-1, 256, 32, 32, 16]             512
            PReLU-20      [-1, 256, 32, 32, 16]             256
        MaxPool3d-21       [-1, 256, 16, 16, 8]               0
           Conv3d-22       [-1, 256, 16, 16, 8]       1,769,728
      BatchNorm3d-23       [-1, 256, 16, 16, 8]             512
            PReLU-24       [-1, 256, 16, 16, 8]             256
           Conv3d-25       [-1, 512, 16, 16, 8]       3,539,456
      BatchNorm3d-26       [-1, 512, 16, 16, 8]           1,024
            PReLU-27       [-1, 512, 16, 16, 8]             512
  ConvTranspose3d-28      [-1, 512, 32, 32, 16]       2,097,664
           Conv3d-29      [-1, 256, 32, 32, 16]       5,308,672
      BatchNorm3d-30      [-1, 256, 32, 32, 16]             512
            PReLU-31      [-1, 256, 32, 32, 16]             256
           Conv3d-32      [-1, 256, 32, 32, 16]       1,769,728
      BatchNorm3d-33      [-1, 256, 32, 32, 16]             512
            PReLU-34      [-1, 256, 32, 32, 16]             256
  ConvTranspose3d-35      [-1, 256, 64, 64, 32]         524,544
           Conv3d-36      [-1, 128, 64, 64, 32]       1,327,232
      BatchNorm3d-37      [-1, 128, 64, 64, 32]             256
            PReLU-38      [-1, 128, 64, 64, 32]             128
           Conv3d-39      [-1, 128, 64, 64, 32]         442,496
      BatchNorm3d-40      [-1, 128, 64, 64, 32]             256
            PReLU-41      [-1, 128, 64, 64, 32]             128
  ConvTranspose3d-42    [-1, 128, 128, 128, 64]         131,200
           Conv3d-43     [-1, 64, 128, 128, 64]         331,840
      BatchNorm3d-44     [-1, 64, 128, 128, 64]             128
            PReLU-45     [-1, 64, 128, 128, 64]              64
           Conv3d-46     [-1, 64, 128, 128, 64]         110,656
      BatchNorm3d-47     [-1, 64, 128, 128, 64]             128
            PReLU-48     [-1, 64, 128, 128, 64]              64
           Conv3d-49      [-1, 3, 128, 128, 64]             195
================================================================
Total params: 19,077,859
Trainable params: 19,077,859
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 12.00
Forward/backward pass size (MB): 8544.00
Params size (MB): 72.78
Estimated Total Size (MB): 8628.78
----------------------------------------------------------------

其中,我们测试的参数量是19,077,859,论文中说的参数量:The architecture has 19069955 parameters in total. 有略微的差别。

后面再调用模型,进行一次前向传播,loss运算和反向回归。如果这里都通过了,那么后面构建训练代码,就更简单了很多。如下:

if __name__ == '__main__':
    input_channels = 3
    num_out_classes = 2
    init_feat_channels = 32

    batch_size = 4
    model = UNet3D(num_out_classes=num_out_classes, input_channels=input_channels, init_feat_channels=init_feat_channels)

    # B x C x Z x Y x X
    # 4 x 1 x 64 x 64 x 64

    input_batch_size = (batch_size, input_channels, 128, 128, 64)
    input_example = torch.rand(input_batch_size)

    unet = model.to(DEVICE)
    input_example = input_example.to(DEVICE)
    output = unet(input_example)
    # output = output.cpu().detach().numpy()
    # Expected output shape
    # B x N x Z x Y x X
    # 4 x 2 x 64 x 64 x 64
    expected_output_shape = (batch_size, num_out_classes, 128, 128, 64)
    print("Output shape = {}".format(output.shape))
    assert output.shape == expected_output_shape, "Unexpected output shape, check the architecture!"

    expected_gt_shape = (batch_size, 128, 128, 64)
    ground_truth = torch.ones(expected_gt_shape)
    ground_truth = ground_truth.long().to(DEVICE)

    # Defining loss fn
    ce_layer = torch.nn.CrossEntropyLoss()
    # Calculating loss
    ce_loss = ce_layer(output, ground_truth)
    print("CE Loss = {}".format(ce_loss))
    # Back propagation
    ce_loss.backward()

输出内容如下:

Output shape = torch.Size([4, 2, 128, 128, 64])
CE Loss = 0.6823387145996094

一个疑问:什么时候使用softmax?什么时候使用sigmoid
答:

第二个问题:训练阶段是不是不需要softmax/sigmoid?只在推理阶段使用呢?
答:

三、总结

UNet网络的结构,无论是二维的,还是三维的,都是比较容易理解的,这可能也是为什么那么受欢迎的原因之一吧。如果你看过之前那篇关于2D UNet的过程,再看本篇应该就简单的很多。觉得本篇更简单一些呢。

我觉得本篇最大的价值,就是:

  1. 逐模块的分析了结构;
  2. 对后续的模型构建提供了思路;
  3. 构建完模型需要先预测试,两种方式可选;
  4. 对模型的优势和劣势,分析。

如果你阅读的过程中,发现了问题和疑问,欢迎评论区交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1165634.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

QCC TX 音频输入切换+提示声音

QCC TX 音频输入切换提示声音 QCC蓝牙芯片(QCC3040 QCC3056 等等),AUX、I2S、USB输入 蓝牙音频输入,模拟输出是最常见的方式。 也可以再此基础上动态切换输入方式。 针对TX切换EQ,调节音量不能出提示声音问题,可以增…

Spring底层原理(五)

Spring底层原理(五) 本章内容 介绍Aware接口与InitializingBean接口、Bean的初始化与销毁、Scope Aware接口 作用:用于注入一些与容器相关的信息 类名作用BeanNameAware注入Bean的名称BeanFactoryAware注入BeanFactory容器ApplicationContextAware注入ApplicationContext容…

Pycharm 对容器中的 Python 程序断点远程调试

pycharm如何连接远程服务器的docker容器有两种方法: 第一种:pycharm通过ssh连接已在运行中的docker容器 第二种:pycharm连接docker镜像,pycharm运行代码再自动创建容器 本文是第一种方法的教程,第二种请点击以上的链接…

JavaScript中BOM与DOM

BOM window对象 所有的浏览器都支持window对象,他表示浏览器窗口, 所有 JavaScript 全局对象、函数以及变量均自动成为 window 对象的成员。 全局变量是 window 对象的属性。全局函数是 window 对象的方法。 接下来要讲的HTML DOM 的 document 也是…

遥遥领先,免费开源的django4-vue3前后端分离项目

星域后台管理系统前端介绍 🌿项目简介 本项目前端基于当下流行且常用的vue3作为主要技术栈进行开发,融合了typescript和element-plus-ui,提供暗黑模式和白昼模式两种主题以及全屏切换,开发bug少,简单易学&#xff0c…

攻略合集!游戏要领你一定要知道!

大家好!作为游戏玩家,我们都希望在游戏中能够成为顶尖的高手。为了帮助大家更好地掌握游戏的要领,我要分享一些实用的攻略和技巧。 首先,不同游戏有不同的技巧和要领。对于《绝地求生》来说,我们需要注重观察和战略规划…

C++和 C 混合编程处理

原因是因为有很多功能是用 C 语言开发的,而 C是兼容 C 的,C应该能直接使用这些功能,那么我们把 C调用 C 实现的功能的这个做法,称为混合编程 但是用 C 开发的功能,很可能已经用 C 编译器编程成目标文件(或打包成库了)…

Zinx框架-游戏服务器开发001:zinx框架的安装

文章目录 1 zinx下载地址1.1 zinx框架的源码路径:1.2 安装好之后动态库的位置 2 Zinx框架运行的基本概况3 测试Zinx-框架的基本使用3.0 流程预览3.1 初始化框架3.2 标准输入回显标准输出的编写思路3.2.1 回显Echo3.2.2 写标准输入stdin通道类,用通道输入…

协同办公系统:企业提质增效的利器

随着科技的不断发展,企业对于提高工作效率、优化管理流程、降低成本的需求日益迫切。协同办公系统应运而生,成为了许多企业提质增效的利器。那么,协同办公系统究竟是如何帮助企业实现这些目标的呢?本文将从以下几个方面进行详细阐…

【基于MRA:自适应高频融合和注入系数优化:Pansharpening】

Pansharpening Based on Adaptive High-Frequency Fusion and Injection Coefficients Optimization (基于自适应高频融合和注入系数优化的全色锐化) 全色锐化的目的是将多光谱(MS)图像与全色(PAN)图像融…

【实战Flask API项目指南】之六 数据库集成 SQLAlchemy

实战Flask API项目指南之 数据库集成 本系列文章将带你深入探索实战Flask API项目指南,通过跟随小菜的学习之旅,你将逐步掌握 Flask 在实际项目中的应用。让我们一起踏上这个精彩的学习之旅吧! 前言 在上一篇文章中,我们实现了…

私有化部署即时通讯软件WorkPlus,全面适配信创环境

对于企业而言,保护数据的安全至关重要。WorkPlus即时通讯软件允许企业在自己的服务器上部署一套私有化的聊天工具,确保数据完全受控于企业内部。通过私有化部署,企业可以有效地管理和保护敏感信息,防止数据泄露和滥用。 另外&…

React基础知识02

一、通过属性来传值(props) react中可以使用属性(props)可以传递给子组件,子组件可以使用这些属性值来控制其行为和呈现输出。 例子: // 1.1 父组件 import React, { useState } from react // 1.2引入子…

浅谈安科瑞直流电表在荷兰光伏充电桩系统中的应用

摘要:本文介绍了安科瑞直流电表在荷兰光伏充电桩系统中的应用。主要用于充电桩的电流电压电能的计量。 Abstract: This article introduces the application of Acrel DC meters in PV charging pile system in Netherlands.The device is measuring current,volt…

腾讯云域名备案后,如何解析到华为云服务器Linux宝塔面板

一、购买域名并且进行备案和解析,正常情况下,购买完域名,如果找不到去哪备案,可以在腾讯云上搜索“备案”关键词就会出现了,所以这里不做详细介绍,直接进行步骤提示: 二、申请ssl证书&#xff0…

mysql简单备份和恢复

版本:mysql8.0 官方文档 :MySQL :: MySQL 8.0 Reference Manual :: 7 Backup and Recovery 1.物理备份恢复 物理备份是以数据文件形式备份。这种方式效率高点,适合大型数据库备份。物理备份可冷备可热备。 使用mysqlbackup 命令进行物理备…

命名数据网络(NDN)介绍

命名数据网络的由来 IP网络最开始其解决的问题是两个实体间点对点通信需求,实现资源共享。(简单知道即可) 随着互联网的发展,互联网用户对internet的需求现已经发生了巨大变化。目前面临着以下挑战 首先是随着以内容为中心&…

力扣刷题 day63:11-02

1.字符串中的第一个唯一字符 给定一个字符串 s ,找到 它的第一个不重复的字符,并返回它的索引 。如果不存在,则返回 -1 。 方法一:两次遍历哈希表 #方法一:两次遍历哈希表 def firstUniqChar(s):d{}for i in s:if …

Leetcode—707.设计链表【中等】双链表的设计明天再写

2023每日刷题(十七) Leetcode—707.设计链表 设计单链表实现代码 typedef struct Node {int val;struct Node* next; } MyLinkedList;MyLinkedList* myLinkedListCreate() {MyLinkedList* mList (MyLinkedList *)malloc(sizeof(MyLinkedList));mList-&…

知乎盈利来源分析与指标体系构建

知乎用户画像 知乎所属行业:内容社区平台 知乎上的内容涉及的领域: 婚恋情感(300亿总阅读量,截止2022年12月)、法律纠纷(200亿)、教育(200亿)、游戏(150亿&…