UNet - unet网络

news2024/10/3 8:16:44

目录

1. u-net介绍

2. u-net网络结构

3. u-net 网络搭建

3.1 DoubleConv

3.2 Down 下采样

3.3 Up 上采样

3.4 网络输出

3.5 UNet 网络

UNet 网络

forward  前向传播

3.6 网络的参数

4. 完整代码


1. u-net介绍

Unet网络是医学图像分割领域常用的分割网络,因为网络的结构很像个U,所以称为Unet

Unet 网络是针对像素点的分类,之前介绍的LeNet、ResNet等等都是图像分类,最后分的是整幅图像的类别,而Unet是对像素点输出的是前景还是背景的分类

注:因为Unet 具体的网络框架均有所不同,例如有的连续卷积后会改变图像的size,有的上采样用的是线性插值的方法。这只介绍same卷积和上采样用的转置卷积

Unet网络是个U型结构,左边是Encoder,右边为Decoder

左边是下采样的过程,通过减少图像size,增加图像channel来提取特征。

右边是还原图像的过程,上采样将逐步还原图像的size,这里上采样的输入特征图不仅仅是上一步的输出,还包含了左边对应特征信息。

2. u-net网络结构

本章采用的unet网络如图,为了后面数据的训练和预测。这里实现的方式和下图有些细小的区别,具体的会在下面讲解

首先,网络输入图像的size设定为(480,480)的灰度图像(注意:这里输入是单通道的灰度图)

然后经过成对的3*3卷积,将图像的深度加深,变成维度为(64,480,480),这里因为图像的size没有变,又因为kernel_size = 3,stride = 1,因此需要保证padding = 1

接下来是下采样层,先经过一个最大池化层,stride = 2,kernel_size = 2 将图像的size变为原来的一半。然后接两个3*3 的卷积,输出的特征图维度是(128,240,240)

下采样层总共有四次,根据每次下采样都会将图像的size减半,图像的channel翻倍来计算的话。最后一次图像的size = 480 / (2^4) = 30 ,channel = 64 * (2^4) = 1024 ,所以最后一次下采样图像的维度为(1024,30,30)------> 这里和图上不一样,因为后面用的是转置卷积

左边的下采样部分实现后,就是右边的上采样部分

上采样会使图像的channel减半,size变为两倍,正好和下采样的部分反过来。这里利用的操作是转置卷积,转置卷积具体的实现这里不做介绍,主要看它的维度变换。转置卷积变换的公式为:

out = (in - 1) * stride - 2 * padding + ksize

这里为了保证图像的size变为两倍,所以要保证 out = 2 * in ,而in的系数2只能从stride来,所以公式变为out = 2 * in - 2 - 2 * padding + ksize ,这里我们让ksize = 2,因此padding = 0 就可以满足要求。而channel的减半只需要把卷积核的个数减半即可

之前介绍过,最后一层的维度是(1024,30,30),这样通过转置卷积的操作图像的维度就变成了(512,60,60),刚好等于左边下采样的维度!! 所以将它们加在一块,然后进行成对的3*3卷积

之后就是和下采样的次数一样,重复四次上采样,直到将图像还原成(64,480,480)

最后一步,如果是图像分类的话,这里应该是全连接层找最大的预测值了。但是Unet是像素点的分类,所以最后产生的也是一副图像,因为这时候图像的size已经是480不需要变了,只需要将图像的channel改变,所以这里只需要一个kernel_size = 1的卷积核就可以了。

注:最后输出图像的维度是(480,480)的灰度图像,准确的说是二值图像

3. u-net 网络搭建

3.1 DoubleConv

观察unet 网络可以发现,3*3的卷积核都是成对出现的,所以这里将成对卷积核的操作封装成一个类

1. 因为采用的是两个连续的3*3  卷积,不改变图像的size,所以这里卷积的参数要设置padding=1

2. ResNet 介绍过,BN代替Dropout 的时候,不需要Bias 

3. 最后经过ReLU 激活函数

3.2 Down 下采样

然后定义下采样的操作

 

1. 这里下采样采用的就是最大池化层,kernel_size = 2,padding =2 会让图像的size减半

2. 然后经过两个连续3*3 的卷积

3. 将 下采样+两个3*3 的卷积 封装成一个新的类Down

3.3 Up 上采样

然后是定义上采样

 

1. 上采样用的是转置卷积,会将图像的size扩大两倍

2.  注意这里不是定义成 Sequential ,因为 Sequential 会从上到下顺序传播。这里还需要一步尺度融合,就是拼接的操作

3. 前向传播的时候,图像首先上采样,会将channel减小一半,size扩大两倍。这样就和左边对应的下采样的位置维度一致,将它们通过torch.cat 拼接,dim = 1是因为batch的维度是0 。然后经过两个3*3 的卷积就行了

3.4 网络输出

最后网络的输出很简单,经过一个1*1 的卷积核,不改变size的情况下。通过卷积核的个数调整图像的channel就行了

3.5 UNet 网络

UNet 网络

网络的框架很简单,因为每个小的模块已经搭好了,将它们拼接起来就行了

因为搭建小的模块的时候,我们对于模块的输入都是in和out channel,所以在定义网络的时候,每个模块只要传入对应的channel就行了。

这里按照UNet 网络的框架设置

 

forward  前向传播

前向传播的过程如下:

在下采样的时候,每个输出都要用变量保存,为了和后面上采样拼接使用

 

3.6 网络的参数

# 计算 UNet 的网络参数个数
model = UNet(in_channels=1,num_classes=1)
print("Total number of paramerters in networks is {}  ".format(sum(x.numel() for x in model.parameters()))) 

UNet 网络参数个数为:

 

4. 完整代码

代码:

import torch.nn as nn
import torch


# 搭建unet 网络
class DoubleConv(nn.Module):    # 连续两次卷积
    def __init__(self,in_channels,out_channels):
        super(DoubleConv,self).__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels,out_channels,kernel_size=3,padding=1,bias=False),     # 3*3 卷积核
            nn.BatchNorm2d(out_channels),                                               # 用 BN 代替 Dropout
            nn.ReLU(inplace=True),                                                      # ReLU 激活函数

            nn.Conv2d(out_channels,out_channels,kernel_size=3,padding=1,bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):    # 前向传播
        x = self.double_conv(x)
        return x


class Down(nn.Module):   # 下采样
    def __init__(self,in_channels,out_channels):
        super(Down, self).__init__()
        self.downsampling = nn.Sequential(
            nn.MaxPool2d(kernel_size=2,stride=2),
            DoubleConv(in_channels,out_channels)
        )

    def forward(self,x):
        x = self.downsampling(x)
        return x


class Up(nn.Module):    # 上采样
    def __init__(self, in_channels, out_channels):
        super(Up,self).__init__()

        self.upsampling = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) # 转置卷积
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.upsampling(x1)
        x = torch.cat([x2, x1], dim=1)  # 从channel 通道拼接
        x = self.conv(x)
        return x


class OutConv(nn.Module):   # 最后一个网络的输出
    def __init__(self, in_channels, num_classes):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class UNet(nn.Module):   # unet 网络
    def __init__(self, in_channels = 1, num_classes = 1):
        super(UNet, self).__init__()
        self.in_channels = in_channels                  # 输入图像的channel
        self.num_classes = num_classes                  # 网络最后的输出

        self.in_conv = DoubleConv(in_channels, 64)      # 第一层

        self.down1 = Down(64, 128)                      # 下采样过程
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)

        self.up1 = Up(1024, 512)                        # 上采样过程
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)

        self.out_conv = OutConv(64, num_classes)        # 网络输出

    def forward(self, x):           # 前向传播    输入size为 (10,1,480,480),这里设置batch = 10

        x1 = self.in_conv(x)        # torch.Size([10, 64, 480, 480])
        x2 = self.down1(x1)         # torch.Size([10, 128, 240, 240])
        x3 = self.down2(x2)         # torch.Size([10, 256, 120, 120])
        x4 = self.down3(x3)         # torch.Size([10, 512, 60, 60])
        x5 = self.down4(x4)         # torch.Size([10, 1024, 30, 30])

        x = self.up1(x5, x4)        # torch.Size([10, 512, 60, 60])
        x = self.up2(x, x3)         # torch.Size([10, 256, 120, 120])
        x = self.up3(x, x2)         # torch.Size([10, 128, 240, 240])
        x = self.up4(x, x1)         # torch.Size([10, 64, 480, 480])
        x = self.out_conv(x)        # torch.Size([10, 1, 480, 480])

        return x


# 计算 UNet 的网络参数个数
model = UNet(in_channels=1,num_classes=1)
print("Total number of paramerters in networks is {}  ".format(sum(x.numel() for x in model.parameters())))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/7233.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

软件测试找bug小技巧总结,从初级跨入中级测试......

目录:导读前言一、必备知识二、定位技巧三、常用小技巧方法1、分析问题场景进行预判2、关注请求体的状态码3、关注请求的入参与响应数据4、查看日志5、经验法则四、总结前言 测试工作 测试的过程就是一个寻找影响产品功能和用户体验bug的过程,测试人员发…

C++之多态详解

文章目录前言一、多态的概念概念二、多态的定义及实现1.多态的构成条件2.虚函数3.虚函数的重写3.1多态条件探究(1)多态条件探究一:不符合重写 -- virtual函数(2)多态条件探究二:不符合重写 -- 不是父类的指针或者引用调用(3)多态条件探究三:不符合重写 -- 参数不同(4)多态条件探…

Vue3 - ref 基本类型(通俗易懂,详细教程)

简介 它是 Vue3 提供的一个用于创建基本数据类型的函数,能将普通的数据包装为响应式。 大白话说,就是咱们要创建一个响应式变量,需要通过这玩意才行! 回忆 Vue2 不理解没关系啊,我们先来回顾一下 Vue2 中是如何做到的…

MaxViT: Multi-Axis Vision Transformer

论文:https://arxiv.org/abs/2204.01697 代码地址:https://github.com/google-research/maxvit 在本文中,介绍了一种高效且可扩展的注意力模型,称之为多轴注意力,该模型由两个方面组成:分块的局部注意力和…

开源作品:引流宝!集活码、短网址等功能为一体的工具!致力于提高引流效率,减少资源流失!

前言 开发这款工具的初衷是为了辅助自己的工作,提供自己日常工作的效率,自己使用了一段时间下来觉得很有用,于是完善之后开源。如今已经开源近2年,第一个版本是在2020年9月份开源,收获了390个star,后来持续…

用ue4怎么制作一个物体故障闪烁的特效

这是一篇关于电子标牌出现故障时可以使用的毛刺效应的文章。本文将其分解为两个简单的效果,将使用 GIF 来解释它们。 噪音效果 第一个组合是噪音效果。 在 Component Mask 中指定 G 矢量并在 Sine 和 Ceil 中调整 G 值。要将线条更改为水平移动,请在 …

操作系统4小时速成:操作系统的基本概念,它是系统软件,管理处理机、存储器、io设备、文件,并发和共享是最基本特征,还有虚拟和异步

操作系统4小时速成:操作系统的基本概念,它是系统软件,管理处理机、存储器、io设备、文件,并发和共享是最基本特征,还有虚拟和异步 2022找工作是学历、能力和运气的超强结合体,遇到寒冬,大厂不招…

二叉树的存储结构

引言: 对于二叉树的存储,我们可以采取顺序存储和链式存储结构 顺序存储结构 ● 按编号次序存储节点 • 对树中每个节点进行编号 • 其编号从小到大的顺序就是节点在连续存储单元的先后次序。 我们是从编号为1开始,为了保持数组位序和编号保持…

EMS Advanced Data Import高级数据导入选项Crack版

EMS Advanced Data Import高级数据导入选项Crack版 EMS Advanced Data Import是Dolphi和CBuilder应用。允许您一次以著名的MS Excel、MS Access、DBF、XML、TXT、CSV、ODF和HTML格式输入数据文件。 EMS高级数据导入选项: 导入流行格式信息:S Excel 97-20…

什么是shuffle?shuffle的原理及过程

目录 一、什么是shuffle 二、为什么要引入shuffle,有哪些影响 三、shuffle的工作原理 1、shuffle的阶段 2、shuffle的中间文件 3、Shuffle Write 4、Shuffle Read 四、总结回顾 一、什么是shuffle 类比分公司的人与物和 Spark 的相关概念是这样对应的&#…

同时看过 unreal4 和 Unity 源代码的人觉得哪个引擎架构更好?

同时看过 unreal4 和 Unity 源代码的人觉得哪个引擎架构更好? UE VS U3D 技术策略上 U3D技术策略是很保守的,发出来的featurelist测试覆盖率无可非议,开发者无需多少新功能的熟悉测试成本。 UE4的技术策略是激进的,发出来的fea…

嵌入式开发学习之--点亮LED灯(上)

在嵌入式学习里,点亮LED灯的地位就如同编程语言学习里的“hello world”,是每个初学者都必须经历的一关,因为点亮了LED灯,至少可以说明几件事: 1.开发环境没问题,包括软件,硬件都没问题。 2.能…

电脑重装系统前怎么备份,重装系统怎么备份数据

有很多用户想把操作系统升级成为Win11的,但是又不知道怎么备份原来的数据,担心升级操作系统后,电脑中的重要数据全部丢失了。电脑重装系统前怎么备份?”这里小编就给我们详细介绍一下吧! 一、升级Windows 11系统要不要…

梯度多云管理技术架构的优势

随着云计算模式的日渐成熟,作为私有云和公有云的混合形态,混合云迎来了爆发期。在混合云的建设过程中,多云管理能力成为关键,梯度多云管理平台是多云时代下的服务管理利器。梯度多云管理平台是对多个公有云、私有云及各种异构资源…

视频声音怎么翻译?这几个办法教你实现视频声音翻译成中文

现如今刷视频已经成为我们的日常生活中不可缺少的一部分了,例如有时我们看到一些有用的教学视频,可能会想要把这些视频保存下来,但有些视频却都是英文的,有些小伙伴可能英语基础不好,查看起来不方便,这个时…

MySQL运算符

算术运算符 算术运算符主要用于数学运算,其可以连接运算符前后的两个数值或表达式,对数值或表达式进行加()、减(-)、乘(*)、除(/)和取模(%&#…

【math】Hiden Markov Model 隐马尔可夫模型了解

文章目录Introduction to Hidden Markov ModelIntroductionMarkov chainHidden Markov Model(HMM)Three QuestionsQ1: evaluate problem -- Forward algorithmQ2: decode problem -- Viterbi algorithmQ3: learn problem -- Baum-Welch algorithmApplicationIntroduction to Hi…

重装Windows系统教程(U盘制作+重装系统)

一、U盘制作 找一个不用的U盘,大小建议在15G以上,因为后面要存储下载好的电脑系统。U盘在被制作成系统盘的时候会被格式化,注意使用前将有用的信息提前保存以免丢失。 第一步:用能够正常联网的电脑打开U盘制作网站,打开…

MySQL解决group by分组后未排序问题

MySQL解决group by分组后未排序问题一、遇见问题1、错误SQL2、正确SQL一、遇见问题 当我们要实现SQL分组后取第一条数据则需要进行排序结果作为子查询后分组 CREATE TABLE op_joke (id int(11) NOT NULL AUTO_INCREMENT,name1 varchar(255) DEFAULT NULL,name2 varchar(255) D…

Spring

Spring[TOC](Spring)1、概述1.1、优点1.2、组成2. IOC概述2.1 什么是IOC2.1.1 推导过程2.1.2 IOC本质2.2 HelloSpring2.2.1 导入Jar包2.2.2 编写代码2.2.2 思考2.3 IOC过程2.4 IOC 接口3. Bean 管理3.1 基于xml方式——set方法注入3.2 FactoryBean3.3 bean 作用域3.4 bean 生命…