【Python/Pytorch - 网络模型】-- 手把手搭建U-Net模型

news2024/11/17 1:55:55

在这里插入图片描述
文章目录

文章目录

  • 00 写在前面
  • 01 基于Pytorch版本的UNet代码
  • 02 论文下载

00 写在前面

通过U-Net代码学习,可以学习基于Pytorch的网络结构模块化编程,对于后续学习其他更复杂网络模型,有很大的帮助作用。

在01中,可以根据U-Net的网络结构(开头图片),进行模块化编程。包括卷积模块定义、上采样模块定义、输出卷积层定义、损失函数定义、网络模型定义等。

在模型调试过程中,可以先通过简单测试代码,进行代码调试。

01 基于Pytorch版本的UNet代码

# 库函数调用
import torch
import torch.nn as nn
from network.ops import TotalVariation
from torchvision.models import vgg19

# 卷积块定义
class conv_block(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            #nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),
            #nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x

# 上采样部分定义
class up_conv(nn.Module):
    def __init__(self,ch_in,ch_out):
        super(up_conv,self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
		    #nn.BatchNorm2d(ch_out),
			nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.up(x)
        return x

# 输出卷积层定义
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            #nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x = self.conv(x)
        return x


class UNET_MODEL(nn.Module):
    def __init__(self, img_ch=3, output_ch=1,filter_dim=64):
        super().__init__()

        self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.Conv1 = conv_block(ch_in=img_ch, ch_out=filter_dim)
        self.Conv2 = conv_block(ch_in=64, ch_out=128)
        self.Conv3 = conv_block(ch_in=128, ch_out=256)
        self.Conv4 = conv_block(ch_in=256, ch_out=512)
        self.Conv5 = conv_block(ch_in=512, ch_out=1024)

        self.Up5 = up_conv(ch_in=1024, ch_out=512)
        self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)

        self.Up4 = up_conv(ch_in=512, ch_out=256)
        self.Up_conv4 = conv_block(ch_in=512, ch_out=256)

        self.Up3 = up_conv(ch_in=256, ch_out=128)
        self.Up_conv3 = conv_block(ch_in=256, ch_out=128)

        self.Up2 = up_conv(ch_in=128, ch_out=64)
        self.Up_conv2 = conv_block(ch_in=128, ch_out=64)

        self.Conv11 = outconv(64, output_ch)

    def forward(self, x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x5 = self.Conv5(x5)

        # decoding + concat path
        d5 = self.Up5(x5)
        d5 = torch.cat((x4, d5), dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3, d4), dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2, d3), dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1, d2), dim=1)
        d2 = self.Up_conv2(d2)

        T2 = self.Conv11(d2)

        return T2

# 损失函数定义
class loss_fun(nn.Module):
    def __init__(self, regular):
        super().__init__()
        self.tv = TotalVariation()
        self.regular = regular

    def forward(self, x, y):
        ychange = y[:, 0:1, :, :]
        mask = y[:, 1:2, :, :]
        return torch.add(torch.mean(torch.pow((x[:,:,:,:] - y[:,2:3,:,:])*ychange, 2)), self.regular* torch.mean(self.tv(x[:, :, :, :]*mask)))

class loss_fun_total(nn.Module):
    def __init__(self, regular):
        super().__init__()
        self.tv = TotalVariation()
        self.regular = regular

    def forward(self, x, y):
        loss1 = torch.mean(torch.pow((x[:,0:1,:,:] - y[:,0:1,:,:]*10), 2))
        return loss1

# 测试代码
if __name__ == '__main__':
	input_channels = 4
	output_channels = 1
	x = torch.ones([32, 4, 256, 256])
	model = UNET_MODEL(input_channels, output_channels)
	print('model initialization finished!')
	f = model(x)
	print(f)

02 论文下载

U-Net: deep learning for cell counting, detection, and morphometry
U-Net: Convolutional Networks for Biomedical Image Segmentation

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1812500.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

瓦片边界可视化工具

本文涉及的核心内容 瓦片边界可视化-VisibleTileBoundariesmeethigher/visible-tile-boundaries: visible tiles boundaries demo 一、瓦片边界可视化 1.1 背景 日常GIS开发中,需要了解瓦片是什么,瓦片展示的效果是什么样的。这种口头上抽象的东西&a…

计算机哈佛架构、冯·诺依曼架构对比

哈佛架构和冯诺依曼架构是两种不同的计算机系统架构,它们在存储器组织方式上有着显著的区别。下面是它们的原理、优缺点的对比以及一些常见的 MCU 采用的架构: 哈佛架构: 原理:哈佛架构将指令存储器(程序存储器&#x…

Androd adb命令汇总,app专项测试命令。

1.普通命令 1.1 devices命令 # 语法格式 :adb devices [-l] # 作用 :返回已连接设备的信息 # 示例 :adb devices : 返回设备的信息adb devices -l : 返回设备的详细信息1.2 help命令 # 语法格式 :adb --help # 作用 &…

攻防世界--杂项misc-2017_Dating_in_Singapore

题目信息 题目描述和附件分别是一串数字和新加坡日历,数字中间有短线-连接,刚好分成了12个字段。猜想对应了12个月 01081522291516170310172431-050607132027262728-0102030209162330-02091623020310090910172423-02010814222930-0605041118252627-020…

集合进阶(接口Collection(迭代器、增强for、Lambda表达式)、List中常见的方法和五种遍历方式、数据结构(栈、队列、数组、链表)

一、单列集合顶层Collection List系列集合:添加的元素是有序、可重复、有索引Set系列集合:添加的元素是无序、不重复、无索引 Collection是单列集合的祖宗接口,它的功能是全部单列集合都可以继承使用的。 Collection的遍历方式 1、迭代器——…

catia零件装配中通过指南针移动零件

1 将零件导入进来后 2 把指南针移动到零件上 具体移动哪个可以通过模型树点击选中,选中那个就可以移动那个。 这种情况需要注意的是 需要双击选择要移动零件的父节点 如下图,Product2蓝色表示是激活的,这样才可以单击选中下面的零件后通过…

STM32F103RCT6换STM32F103C8T6后delay函数延时了10倍

更换单片机步骤: 1、型号选择 2、启动文件,将HD改为MD。 3、引入对应的启动文件。 4、后面发现delay比之前延时了差不多10倍,解决办法:在初始化后加入SystemInit();即可。

Frontiers旗下期刊,23年分区表整理出炉!它还值得投吗?

本周投稿推荐 SSCI • 中科院2区,6.0-7.0(录用友好) EI • 各领域沾边均可(2天录用) CNKI • 7天录用-检索(急录友好) SCI&EI • 4区生物医学类,0.5-1.0(录用…

第十五届蓝桥杯pb组国赛E题[马与象] (15分)BFS算法 详解

博客主页:誓则盟约 系列专栏:IT竞赛 专栏 关注博主,后期持续更新系列文章 如果有错误感谢请大家批评指出,及时修改 感谢大家点赞👍收藏⭐评论✍ 问题描述: 小蓝有一个大小为 N N 的棋盘(棋…

110.网络游戏逆向分析与漏洞攻防-装备系统数据分析-装备与技能描述信息的处理

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 如果看不懂、不知道现在做的什么,那就跟着做完看效果,代码看不懂是正常的,只要会抄就行,抄着抄着就能懂了 内容…

javaWeb项目-ssm+vue医院住院信息管理系统功能介绍

项目关键技术 开发工具:IDEA 、Eclipse 编程语言: Java 数据库: MySQL5.7 框架:ssm、Springboot 前端:Vue、ElementUI 关键技术:springboot、SSM、vue、MYSQL、MAVEN 数据库工具:Navicat、SQLyog 1、Java简介 现代社…

Nvidia/算能 +FPGA+AI大算力边缘计算盒子:AI智能监控 用于沙滩救援

以色列的一个团队在人工智能领域取得的成果引起了轰动。 今天他们取得的成果源于多年前的一个想法。Netanel Eliav 和 Adam Bismut 是校园时代的旧伙伴,当时他们想要解决一个可以改变世界的问题,由此引出这样一个想法:溺水的 Bismut 漂流到死…

RV32M指令集

RV32M指令集 1、乘法运算2、除法运算1、乘法运算 MUL 指令(得到整数32位乘积(64位中的低32位)) MUL 指令用于执行两个带符号或无符号整数之间的乘法运算。其语法如下: mul rd, rs1, rs2 它将寄存器 rs1 和 rs2 中的值相乘,并将结果写入寄存器 rd 中。如果 rs1 和 rs2 都是有…

catia零件装配时预览零件的形状

这样的显示方式看不到 选择大或中图标就可预览零件形状

基于STM32的智能水产养殖系统(二)

TPS5433IDR TPS5433IDR 是一款由德州仪器 (Texas Instruments) 生产的高效降压转换器(Buck Converter)。它能够将较高的输入电压转换为较低的输出电压,适用于各种电源管理应用。 主要特性 输入电压范围: 5.5V 至 36V输出电压范围: 0.9V 至 …

惊艳的短视频:成都科成博通文化传媒公司

惊艳的短视频:瞬间之美,震撼心灵 在数字化时代,短视频以其短小精悍、内容丰富的特点,迅速占领了我们的屏幕和时间。而在这个浩如烟海的视频海洋中,总有一些短视频能够脱颖而出,以其惊艳的视觉效果、深刻的…

设计模式-代理模式(结构型)

代理模式 代理模式是一种结构型模式,它可以通过一个类代理另一个类的功能。代理类持有被代理类的引用地址,负责将请求转发给代理类,并且可以在转发前后做一些处理 图解 角色 抽象主题(Subject): 定义代理对象和被代理…

足球实况分析系统YOLO

① 足球运动员、裁判和球检测; ② 球员球队预测; ③ 足球地图上球员和球位置的估计; ④ 足球跟踪; 当你启动应用程序时,会自动加载两个演示视频以及推荐的设置和超参数. 1. 使用侧栏菜单“浏览文件”按钮上传视频…

什么是OCR转写服务?

OCR(Optical Character Recognition,光学字符识别)转写服务是一种技术,用于将图像或扫描文档中的文字转换为可编辑的文本格式。这项服务通过识别图像中的文字,并将其转换成计算机可读的文本形式,从而使得用…

力扣39. 组合总和

Problem: 39. 组合总和 文章目录 题目描述思路及解题方法复杂度Code 题目描述 思路及解题方法 1.创建一个 res 变量存储所有满足条件的组合结果,使用 track 变量记录当前的组合路径,使用 trackSum 变量记录当前路径中元素的和。 2.回溯方法 backtrack: 2.1.基本情况…