【Python/Pytorch - 网络模型】-- 手把手搭建3D U-Net模型

news2024/11/29 18:31:22

在这里插入图片描述
文章目录

文章目录

  • 00 写在前面
  • 01 基于Pytorch版本的3D UNet代码
  • 02 论文下载

00 写在前面

通过3D U-Net代码学习,可以学习基于Pytorch的网络结构模块化编程,对于后续学习其他更复杂3D网络模型,有很大的帮助作用。

在01中,可以根据3D U-Net的网络结构(开头图片),进行模块化编程。包括卷积模块定义、上采样模块定义、下采样模块定义、输出卷积层定义、网络模型定义等。

在模型调试过程中,可以先通过简单测试代码,进行代码调试。

01 基于Pytorch版本的3D UNet代码

# 库函数调用
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np

# from measure import Four_three


# 三维卷积块定义
class DoubleConv(nn.Module):
    """(Conv3D -> IN -> ReLU) * 2"""

    def __init__(self, in_channels, out_channels, num_groups = 8):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1,bias=True),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1,bias=True),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

# 下采样模块定义
class Down(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.MaxPool3d(2,2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.encoder(x)

# 上采样模块定义
class Up(nn.Module):

    def __init__(self, in_channels, out_channels, trilinear = True):
        super().__init__()

        if trilinear:
            self.up = nn.Upsample(scale_factor = 2)
        else:
            self.up = nn.ConvTranspose3d(in_channels // 2, in_channels // 2, kernel_size = 2, stride = 2)

        self.conv = DoubleConv(in_channels, out_channels)
        self.downc = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride=1, padding=1, bias=True)
        self.downr = nn.ReLU(inplace=True)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        diffZ = x2.size()[2] - x1.size()[2]
        diffY = x2.size()[3] - x1.size()[3]
        diffX = x2.size()[4] - x1.size()[4]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2, diffZ // 2, diffZ - diffZ // 2])

        x1 = self.downr(self.downc(x1))

        x = torch.cat([x2, x1], dim = 1)
        return self.conv(x)

# 输出卷积层定义
class Out(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size = 3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(x)

# 3D-UNet模型定义
class 3DUNET(nn.Module):
    def __init__(self, in_channels=3, out_channels=1,n_channels=64):
        super().__init__()
        self.in_channels = in_channels
        self.n_channels = n_channels

        self.conv = DoubleConv(in_channels, n_channels)
        self.enc1 = Down(n_channels, 2 * n_channels)
        self.enc2 = Down(2 * n_channels, 4 * n_channels)
        self.enc3 = Down(4 * n_channels, 8 * n_channels)
        self.enc4 = Down(8 * n_channels, 16 * n_channels)

        self.dec1 = Up(16 * n_channels, 8 * n_channels)
        self.dec2 = Up(8 * n_channels, 4 * n_channels)
        self.dec3 = Up(4 * n_channels, 2*n_channels)
        self.dec4 = Up(2 * n_channels, n_channels)
        self.out = Out(n_channels, out_channels) #(1,4,128,128,n)


    def forward(self, x):
        # print('size of x:', x.shape)
        x1 = self.conv(x)
        # print('size of x1:', x1.shape)
        x2 = self.enc1(x1)
        # print('size of x2:', x2.shape)
        x3 = self.enc2(x2)
        # print('size of x3:', x3.shape)
        x4 = self.enc3(x3)
        # print('size of x4:', x4.shape)
        x5 = self.enc4(x4)
        # print('size of x5:', x5.shape)

        mask = self.dec1(x5, x4)
        # print('size of mask:', mask.shape)
        mask = self.dec2(mask, x3)
        # print('size of mask:', mask.shape)
        mask = self.dec3(mask, x2)
        # print('size of mask:', mask.shape)
        mask = self.dec4(mask, x1)
        # print('size of mask:', mask.shape)
        mask = self.out(mask)
        # print('size of mask:', mask.shape)

        return mask

# 测试代码
if __name__ == '__main__':
	input_channels = 4
	output_channels = 1
	x = torch.ones([16, 4, 16, 16,16])
	model = 3DUNET(input_channels, output_channels)
	print('model initialization finished!')
	f = model(x)
	print(f)

02 论文下载

3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation
arXiv: 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1815520.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C#——方法函数详情

方法(函数) C#是面向对象的,所以C#中的方法也是相对于对象来说的,是指某个对象的行为,比如,有一个动物的类,兔子是这个动物类里的一个对象,那么跳这个行为就是兔子这个对象的方法了.其实也就是C中的函数(C是面向过程的,叫函数). 方法: 就是把一系列相关的代码组织到一块 用于…

优化Elasticsearch搜索性能:查询调优与索引设计

在构建基于 Elasticsearch 的搜索解决方案时,性能优化是关键。本文将深入探讨如何通过查询调优和索引设计来优化 Elasticsearch 的搜索性能,从而提高用户体验和系统效率。 查询调优 优化查询是提高 Elasticsearch 性能的重要方法。以下是一些有效的查询…

运筹学基础与应用(简洁版总复习)

第一章 线性规划及单纯形法 图解法 单纯形法 大m法 看案例(综合题) 化标准形式 目标函数的转换 min z变为max z 变量的变换 变量取值无约束 约束方程的转换 ≤:加一个松弛变量 ≥:减一个剩余变量 变量符号≤0的变换 保持变量≥…

免密支付存隐患 谨防“便捷”变“踩坑”

免密支付存隐患 谨防“便捷”变“踩坑” 当前,我国网购用户已超9亿人,越来越便捷的支付手段让网络消费体验更加“丝滑”。但免密支付、自动续费等方式在简化付款流程的同时,也成为一些平台“套路”消费者的手段,暗藏诱导消费陷阱。…

L1306——串口的配置

这里需要介绍一下串口的时钟来源,串口的时钟来源一共有三个,分别是: BUSCLK:由内部高频振荡器提供的CPU时钟,通常芯片出厂时设置为了32MHz。 MFCLK:只能使用固定的4MHz时钟(参考用户手册132页)。开启的话…

工程英语【计算机英语】

文章目录 第一专题 Vocabulary, Terms 和 Jargons 区别1.1 知识1.1.1 Vocabulary——词汇1.1.2 Terms——术语1.1.3 Jargons——行话1.1.4 Buzzword——流行语 1.2 阅读文章【5.2 Type of Connection】1.2.1 翻译1.2.2 回答问题 第二专题 Abbreviations2.1 知识2.1.1 Abbreviat…

C语言 指针——字符数组与字符指针:字符串的表示与存储

目录 字符串常量 字符串变量? 字符数组的定义和初始化 字符指针的定义和初始化 将字符指针指向一个字符串 用字符数组保存一个字符串 将字符指针指向一个字符数组 使用字符指针的基本原则 使用指针的基本原则 字符串常量 字符串变量?  C 语言…

海外媒体发稿渠道和方法有哪些?如何选择靠谱的国外媒体发稿服务商?

在选择海外媒体发稿服务商时,以下是一些关键点可以帮助您找到靠谱的服务商: 服务商的经验和口碑:查找该服务商在行业内的声誉和客户评价。拥有丰富经验和良好口碑的服务商通常更可靠。 媒体资源和覆盖范围:了解服务商所能提供的媒…

定个小目标之刷LeetCode热题(10)

这道题属于一道中等题&#xff0c;看来又得背题了&#xff0c;直接看题解吧&#xff0c;有两种解法 第一种动态规划法 状态&#xff1a;dp[i][j] 表示字符串s在[i,j]区间的子串是否是一个回文串 状态转移方程&#xff1a;当s[i] s[j] && (j - i < 2 || dp[i 1]…

【android】安卓入门学习

文档介绍&#xff1a;http://8.136.122.222/book/primary/kotlin/kotlin-intro.html 文档补充说明&#xff1a;https://blog.csdn.net/qq_42059717/category_12047508.html 一、搭建环境及工具安装 见文档 二、工具界面及项目文件介绍 ├── app //工程主模块名称 │ …

男士应该穿三角裤还是平角裤?三角内裤和平角内裤的区别!

在当今市场&#xff0c;男士内裤的材质种类琳琅满目&#xff0c;但令人遗憾的是&#xff0c;众多男士在选择内裤时却常常忽视舒适度与耐用性&#xff0c;导致穿着体验不佳&#xff0c;甚至出现破损起球的问题。作为一位专业的测评博主&#xff0c;我深感有必要为大家深度剖析男…

nsight systems gui报错

问题&#xff1a;使用命令&#xff1a;nsys-ui打开GUI&#xff0c;点击START以后报错&#xff0c;如图 解决&#xff1a; 命令使用&#xff1a;sudo nsys-ui

Javascript时间循环应用—nextTick()详解

简单易懂 关于nextTick()的理解-CSDN博客 【Vue面试专题】56道经典Vue面试题详解&#xff01;说说nextTick使用和原理&#xff1f;_哔哩哔哩_bilibili Vue.nextTick() 是 Vue.js 提供的一个全局 API&#xff0c;用于在 DOM 更新后执行延迟回调。它通常用于在数据更新后立即获取…

电视剧推荐

1、《春色寄情人》 2、《唐朝诡事录》 3、《南来北往》 4、《与凤行》 5、《利剑玫瑰》 6、《承欢记》

Apple ID已成历史,在ios18中正式更名为Apple Account

随着iOS18的首个开发者预览版成功推送&#xff0c;众多热衷于尝鲜的用户已纷纷升级并开启全新体验。在这个版本中&#xff0c;备受瞩目的Apple ID正式迎来了它的进化——更名为Apple Account&#xff0c;并且拥有了中文名称“Apple账户”或简称“苹果账户”。 不过目前官网还称…

数字员工将重塑工作与生产的未来格局?

数字员工&#xff0c;由AI、机器学习和自动化技术驱动&#xff0c;正逐渐取代或协助人类完成从基础到高端的任务&#xff0c;极大提升工作效率&#xff0c;并改变工作认知。它们不仅影响各行业&#xff0c;还重塑人与机器、社会、自然的关系。与二十世纪末的国企下岗变革相比&a…

Nginx与Gateway

Nginx与Gateway Nginx 基本介绍 Nginx 是一款轻量级的高性能 Web 服务器/反向代理服务器及电子邮件&#xff08;IMAP/POP3&#xff09;代理服务器。它由俄罗斯的 Igor Sysoev 所开发&#xff0c;最初供俄罗斯大型的门户网站及搜索引擎 Rambler 使用。 Nginx 的特点在于其占用…

GiantPandaCV | 提升分类模型acc(二):图像分类技巧实战

本文来源公众号“GiantPandaCV”&#xff0c;仅用于学术分享&#xff0c;侵权删&#xff0c;干货满满。 原文链接&#xff1a;提升分类模型acc(二)&#xff1a;图像分类技巧实战 上一篇文章GiantPandaCV | 提升分类模型acc(一)&#xff1a;BatchSize&LARS-CSDN博客探讨了…

开发了一个宝藏云桌面系统,编程小白强烈安利

大家会不会也会有这样的困扰&#xff0c;一个开发小白&#xff0c;在满怀激情的想踏入代码世界时&#xff0c;往往会被一系列复杂的环境配置和软件安装过程绊住了脚步。想象一下&#xff0c;如果你满心期待地想要运行一个简单的“Hello, World!”程序&#xff0c;或是尝试一段刚…

【OpenCV】opencv-4.9.0源码编译

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ&#xff1a;870202403 公众号&#xff1a;VTK忠粉 前言 本文分享OpenCV-4.9.0源码编译流程&#xff0c;包含CUDA模块&#xff0c;包含Python-opencv&#xff0c;希望对各位小伙伴有所帮助&#xff01; 感谢各位小伙伴的点赞…