Unet网络架构讲解(从零到一,逐行编写并重点讲解数据维度变化)

news2024/9/24 17:14:01

📚博客主页:knighthood2001
公众号:认知up吧 (目前正在带领大家一起提升认知,感兴趣可以来围观一下)
🎃知识星球:【认知up吧|成长|副业】介绍
❤️感谢大家点赞👍🏻收藏⭐评论✍🏻,您的三连就是我持续更新的动力❤️
🙏笔者水平有限,欢迎各位大佬指点,相互学习进步!

今天开始讲解一下Unet网络架构以及Pytorch代码构建。

整体架构图

在这里插入图片描述
这张图片应该是Unet网络最出名的图片,网络形状像“U”,故被称为U-net。

网络讲解

我觉得很多人对这个网络架构可能还是一知半解的,包括我最初也是这样的。

首先就是这几个箭头表示的是什么,这几个箭头相比于VGG16网络架构,难了不少。
因为VGG16网络架构中只有卷积层、全连接层,不涉及到特别复杂的操作。

  1. conv 3x3,ReLu就是卷积层,其中卷积核大小是3x3,然后经过ReLu激活。
  2. copy and crop的意思是复制和裁剪。这块内容我觉得很多人最初和我一样,不明白是什么意思,这里的意思就是对于你输出的尺寸,你需要进行复制并进行中心剪裁。方便和后面上采样生成的尺寸进行拼接。
  3. max pool 2x2,就是最大池化层,卷积核为2x2。
  4. up-conv 2x2:这里对于初学者来说,是最难领悟的地方,因为看不懂这个符号是啥意思。我最初以为是upsample+conv2d,试了一下,好像生成不了符合要求的尺寸,后来想了一下,这个是不是就是反卷积,用来上采样的,然后试了一下,可以实现,并且卷积核也是2x2。本文中使用的就是ConvTranspose2d()函数进行该操作。
  5. conv 1x1 这里就是卷积层,卷积核大小是1x1。

网络架构图,大家需要好好理解一下。

以上内容,我会在接下来的文章进行讲解。

左边部分代码讲解

导入包

import torch
import torch.nn as nn

Unet网络的左边部分和VGG16网络结构类似,都是卷积+最大池化,因此这部分讲解,可以看我之前写的这篇文章,里面着重讲了参数如何设置,也希望基础不好的,先去这篇文章补一补。

深度学习VGG16网络构建(Pytorch代码从零到一精讲,帮助理解网络的参数定义)

第一块内容

        # 由572*572*1变成了570*570*64
        self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.relu1_1 = nn.ReLU(inplace=True)
        # 由570*570*64变成了568*568*64
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)  
        self.relu1_2 = nn.ReLU(inplace=True)

由Unet网络架构图,可以看出输入图像是1x572x572大小,其中的1代表的是通道数(后续可以自己更改成自己想要的,比如3通道),输出通道是64,并且通过conv3x3,得知卷积核为3x3尺寸,并且由图片中的尺寸变成570x570,因此可以得出相关的参数值in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0,整个图片的蓝色箭头的卷积操作都是这样,因此kernel_size=3, stride=1, padding=0可以固定了。只需要更改输入和输出通道数的大小即可。
数据维度变化:1x572x572->64x570x570->64x568x568

最大池化层1

# 采用最大池化进行下采样,图片大小减半,通道数不变,由568*568*64变成284*284*64
self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  

最大池化的卷积核和步长都设置为2,使得输出尺寸减半,通道数不变。
数据维度变化:64x568x568->64x284x284

第二块内容

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)  # 284*284*64->282*282*128
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)  # 282*282*128->280*280*128
        self.relu2_2 = nn.ReLU(inplace=True)

数据维度变化:64x284x284->128x282x282->128x280x280

最大池化层2

# 采用最大池化进行下采样  280*280*128->140*140*128
self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  

最大池化的卷积核和步长都设置为2,使得输出尺寸减半,通道数不变。
数据维度变化:128x280x280->128x140x140

Unet左边部分剩下内容,等等等等(有空补)

Unet左边部分汇总

由Unet网络架构图,可以看出,每经过一次卷积+relu操作,图像尺寸-2,可以得出padding=0(VGG16中padding=1,因此使得图像尺寸不变);每经过一次最大池化,图像尺寸减半。

左边部分代码如下:

class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)  # 由572*572*1变成了570*570*64
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)  # 由570*570*64变成了568*568*64
        self.relu1_2 = nn.ReLU(inplace=True)

        self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样,图片大小减半,通道数不变,由568*568*64变成284*284*64

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)  # 284*284*64->282*282*128
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)  # 282*282*128->280*280*128
        self.relu2_2 = nn.ReLU(inplace=True)

        self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样  280*280*128->140*140*128

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=0)  # 140*140*128->138*138*256
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)  # 138*138*256->136*136*256
        self.relu3_2 = nn.ReLU(inplace=True)

        self.maxpool_3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样  136*136*256->68*68*256

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=0)  # 68*68*256->66*66*512
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)  # 66*66*512->64*64*512
        self.relu4_2 = nn.ReLU(inplace=True)

        self.maxpool_4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样  64*64*512->32*32*512

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=0)  # 32*32*512->30*30*1024
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0)  # 30*30*1024->28*28*1024
        self.relu5_2 = nn.ReLU(inplace=True)

前向传播函数中,你不能将所有对象都写成x,因为这个网络涉及到copy and crop,如果你全部都当作x,那么就无法复制和裁剪了,因为每次都是对最终结果进行复制,而不是中间步骤进行复制。

注意一下,下面的写法是错误的。
在这里插入图片描述
原因是因为Unet网络需要copy,因此你不能将所有层的输出都定义为x。

**正确做法应该是,在最大池化(下采样)之前,你需要有个新变量保存输出的内容,方便后续进行复制和裁剪。**这里不知道你能否听懂。代码如下:

    def forward(self, x):
        x1 = self.conv1_1(x)
        x1 = self.relu1_1(x1)
        x2 = self.conv1_2(x1)
        x2 = self.relu1_2(x2)  # 这个后续需要使用
        down1 = self.maxpool_1(x2)

        x3 = self.conv2_1(down1)
        x3 = self.relu2_1(x3)
        x4 = self.conv2_2(x3)
        x4 = self.relu2_2(x4)  # 这个后续需要使用
        down2 = self.maxpool_2(x4)

        x5 = self.conv3_1(down2)
        x5 = self.relu3_1(x5)
        x6 = self.conv3_2(x5)
        x6 = self.relu3_2(x6)  # 这个后续需要使用
        down3 = self.maxpool_3(x6)

        x7 = self.conv4_1(down3)
        x7 = self.relu4_1(x7)
        x8 = self.conv4_2(x7)
        x8 = self.relu4_2(x8)  # 这个后续需要使用
        down4 = self.maxpool_4(x8)

        x9 = self.conv5_1(down4)
        x9 = self.relu5_1(x9)
        x10 = self.conv5_2(x9)
        x10 = self.relu5_2(x10)

右边部分代码讲解

右边部分的架构如下,当然,由于Unet网络的特殊性,不能只看右半边。
在这里插入图片描述
右半部分每一层最开始的数据,由两部分组成,一部分由up-conv 2x2的上采样组成,另外一部风是由左边部分进行复制并进行中心裁剪后得到的,然后对这两部分进行拼接。
在这里插入图片描述
以最下面的绿色箭头这部分,举个例子。

最下面的是1024x28x28的图像,经过上采样(绿色箭头),得到512x56x56的图像,尺寸扩大一倍,通道数减半。

然后看最下面的灰色的横向箭头。灰色箭头左边的图像是512x64x64,然后对其进行复制并中心裁剪(中心裁剪是看图得出的),最后得到512x56x56,然后和刚刚说的上采样得到的图像进行拼接,最后得出1024x56x56,我这应该讲的很清楚了。我最开始的时候,这地方没有仔细看图,一直在想到底是如何得出的。

有了上面这个例子,大家应该就能理解右半部分了。
接下来就实现这个上面说的。
注意:我在init中只是定义了上采样的函数,没有涉及到copy and crop,这个我放到forward函数中实现。
下面这四个上采样,就是图片中绿色箭头部分,大家可以关注一下数据维度的变化。

上采样部分代码

# 接下来实现上采样中的up-conv2*2
self.up_conv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0) # 28*28*1024->56*56*512

数据维度变化:1024x28x28->512x56x56

self.up_conv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=0) # 52*52*512->104*104*256

数据维度变化:512x52x52->256x104x104

self.up_conv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0) # 100*100*256->200*200*128

数据维度变化:256x100x100->128x200x200

self.up_conv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0) # 196*196*128->392*392*64

数据维度变化:128x196x196->64x392x392


右半部分的卷积

右边部分的卷积层也有四个大层,每个大层经过两个卷积层。

        self.conv6_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=0)  # 56*56*1024->54*54*512
        self.relu6_1 = nn.ReLU(inplace=True)
        self.conv6_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)  # 54*54*512->52*52*512
        self.relu6_2 = nn.ReLU(inplace=True)

数据维度变化:1024x56x56->512x54x54->512x52x52

        self.conv7_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=0)  # 104*104*512->102*102*256
        self.relu7_1 = nn.ReLU(inplace=True)
        self.conv7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)  # 102*102*256->100*100*256
        self.relu7_2 = nn.ReLU(inplace=True)

数据维度变化:512x104x104->256x102x102->256x100x100

        self.conv8_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=0)  # 200*200*256->198*198*128
        self.relu8_1 = nn.ReLU(inplace=True)
        self.conv8_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)  # 198*198*128->196*196*128
        self.relu8_2 = nn.ReLU(inplace=True)

数据维度变化:256x200x200->128x198x198->128x196x196

        self.conv9_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0)  # 392*392*128->390*390*64
        self.relu9_1 = nn.ReLU(inplace=True)
        self.conv9_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)  # 390*390*64->388*388*64
        self.relu9_2 = nn.ReLU(inplace=True)

数据维度变化:128x392x392->64x390x390->64x388x388

最后的conv 1x1

        # 最后的conv1*1
        self.conv_10 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, stride=1, padding=0)

这个代码就是最后的conv1x1操作,输入通道数为64,输出通道数为2,卷积核大小为1,步长为1,padding=0,使得

数据维度变化:64x388x388->2x388x388

其中的输出通道数可以根据自己的需要进行更改。

copy and crop的实现,并且实现拼接操作

上面的代码是init中定义好的层,copy and crop操作没有能够直接实现的函数,因此我放到forward函数中。

    # 中心裁剪,
    def crop_tensor(self, tensor, target_tensor):
        target_size = target_tensor.size()[2]
        tensor_size = tensor.size()[2]
        delta = tensor_size - target_size
        delta = delta // 2
        # 如果原始张量的尺寸为10,而delta为2,那么"delta:tensor_size - delta"将截取从索引2到索引8的部分,长度为6,以使得截取后的张量尺寸变为6。
        return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

首先我实现了一个这样的函数。这个函数可以帮助我将tensor中心裁剪成target_tensor的尺寸,符合Unet网络的需求。

        # 第一次上采样,需要"Copy and crop"(复制并裁剪)
        up1 = self.up_conv_1(x10)  # 得到56*56*512
        # 需要对x8进行裁剪,从中心往外裁剪
        crop1 = self.crop_tensor(x8, up1)
        # 拼接操作
        up_1 = torch.cat([crop1, up1], dim=1)

这是第一次实现上采样并且进行拼接。

首先up1 = self.up_conv_1(x10)这段代码实现上采样,得到512x56x56的数据,x8就是经过conv4_2和relu操作后,处在左下角灰色箭头左边的数据,其维度是512x64x64,我们需要将其裁剪成up1的形状,因此可以调用self.crop_tensor函数,得到crop1,其维度和up1一样,都是512x56x56。

然后就可以进行拼接,使用torch.cat()函数对张量列表在指定维度上进行拼接,这里就是将crop1和up1进行在通道数维度上的拼接,最后拼接成1024x56x56大小的数据(由unet架构图中可以看出,crop1在前面,up1在后面)。


然后经过两次卷积后,继续上采样,copy and crop,然后进行拼接。

这是第二次的这个过程:上采样+裁剪+拼接

		# 第二次上采样,需要"Copy and crop"(复制并裁剪)
        up2 = self.up_conv_2(y2)
        # 需要对x6进行裁剪,从中心往外裁剪
        crop2 = self.crop_tensor(x6, up2)
        # 拼接
        up_2 = torch.cat([crop2, up2], dim=1)

同理:经过两次卷积后,继续上采样,copy and crop,然后进行拼接。

这是第三次的这个过程:上采样+裁剪+拼接

        # 第三次上采样,需要"Copy and crop"(复制并裁剪)
        up3 = self.up_conv_3(y4)
        # 需要对x4进行裁剪,从中心往外裁剪
        crop3 = self.crop_tensor(x4, up3)
        up_3 = torch.cat([crop3, up3], dim=1)

同理:经过两次卷积后,继续上采样,copy and crop,然后进行拼接。

这是第四次的这个过程:上采样+裁剪+拼接

        # 第四次上采样,需要"Copy and crop"(复制并裁剪)
        up4 = self.up_conv_4(y6)
        # 需要对x2进行裁剪,从中心往外裁剪
        crop4 = self.crop_tensor(x2, up4)
        up_4 = torch.cat([crop4, up4], dim=1)

最终代码展示

这个代码我敢肯定,这是全网最基础、最简单的代码,但是是最适合小白的代码。

并且我配上了本代码的具体名称对应的层,比如x1就是第一个conv 3x3,大家可以自己对应着看。
在这里插入图片描述

import torch
import torch.nn as nn

class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.conv1_1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=0)  # 由572*572*1变成了570*570*64
        self.relu1_1 = nn.ReLU(inplace=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)  # 由570*570*64变成了568*568*64
        self.relu1_2 = nn.ReLU(inplace=True)

        self.maxpool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样,图片大小减半,通道数不变,由568*568*64变成284*284*64

        self.conv2_1 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=0)  # 284*284*64->282*282*128
        self.relu2_1 = nn.ReLU(inplace=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)  # 282*282*128->280*280*128
        self.relu2_2 = nn.ReLU(inplace=True)

        self.maxpool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样  280*280*128->140*140*128

        self.conv3_1 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=0)  # 140*140*128->138*138*256
        self.relu3_1 = nn.ReLU(inplace=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)  # 138*138*256->136*136*256
        self.relu3_2 = nn.ReLU(inplace=True)

        self.maxpool_3 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样  136*136*256->68*68*256

        self.conv4_1 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=0)  # 68*68*256->66*66*512
        self.relu4_1 = nn.ReLU(inplace=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)  # 66*66*512->64*64*512
        self.relu4_2 = nn.ReLU(inplace=True)

        self.maxpool_4 = nn.MaxPool2d(kernel_size=2, stride=2)  # 采用最大池化进行下采样  64*64*512->32*32*512

        self.conv5_1 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=0)  # 32*32*512->30*30*1024
        self.relu5_1 = nn.ReLU(inplace=True)
        self.conv5_2 = nn.Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0)  # 30*30*1024->28*28*1024
        self.relu5_2 = nn.ReLU(inplace=True)

        # 接下来实现上采样中的up-conv2*2
        self.up_conv_1 = nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=2, stride=2, padding=0) # 28*28*1024->56*56*512


        self.conv6_1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=0)  # 56*56*1024->54*54*512
        self.relu6_1 = nn.ReLU(inplace=True)
        self.conv6_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=0)  # 54*54*512->52*52*512
        self.relu6_2 = nn.ReLU(inplace=True)

        self.up_conv_2 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=2, stride=2, padding=0) # 52*52*512->104*104*256

        self.conv7_1 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=0)  # 104*104*512->102*102*256
        self.relu7_1 = nn.ReLU(inplace=True)
        self.conv7_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=0)  # 102*102*256->100*100*256
        self.relu7_2 = nn.ReLU(inplace=True)

        self.up_conv_3 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=2, stride=2, padding=0) # 100*100*256->200*200*128


        self.conv8_1 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=0)  # 200*200*256->198*198*128
        self.relu8_1 = nn.ReLU(inplace=True)
        self.conv8_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=0)  # 198*198*128->196*196*128
        self.relu8_2 = nn.ReLU(inplace=True)

        self.up_conv_4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=2, stride=2, padding=0) # 196*196*128->392*392*64


        self.conv9_1 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=0)  # 392*392*128->390*390*64
        self.relu9_1 = nn.ReLU(inplace=True)
        self.conv9_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)  # 390*390*64->388*388*64
        self.relu9_2 = nn.ReLU(inplace=True)

        # 最后的conv1*1
        self.conv_10 = nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1, stride=1, padding=0)

    # 中心裁剪,
    def crop_tensor(self, tensor, target_tensor):
        target_size = target_tensor.size()[2]
        tensor_size = tensor.size()[2]
        delta = tensor_size - target_size
        delta = delta // 2
        # 如果原始张量的尺寸为10,而delta为2,那么"delta:tensor_size - delta"将截取从索引2到索引8的部分,长度为6,以使得截取后的张量尺寸变为6。
        return tensor[:, :, delta:tensor_size - delta, delta:tensor_size - delta]

    def forward(self, x):
        x1 = self.conv1_1(x)
        x1 = self.relu1_1(x1)
        x2 = self.conv1_2(x1)
        x2 = self.relu1_2(x2)  # 这个后续需要使用
        down1 = self.maxpool_1(x2)

        x3 = self.conv2_1(down1)
        x3 = self.relu2_1(x3)
        x4 = self.conv2_2(x3)
        x4 = self.relu2_2(x4)  # 这个后续需要使用
        down2 = self.maxpool_2(x4)

        x5 = self.conv3_1(down2)
        x5 = self.relu3_1(x5)
        x6 = self.conv3_2(x5)
        x6 = self.relu3_2(x6)  # 这个后续需要使用
        down3 = self.maxpool_3(x6)

        x7 = self.conv4_1(down3)
        x7 = self.relu4_1(x7)
        x8 = self.conv4_2(x7)
        x8 = self.relu4_2(x8)  # 这个后续需要使用
        down4 = self.maxpool_4(x8)

        x9 = self.conv5_1(down4)
        x9 = self.relu5_1(x9)
        x10 = self.conv5_2(x9)
        x10 = self.relu5_2(x10)

        # 第一次上采样,需要"Copy and crop"(复制并裁剪)
        up1 = self.up_conv_1(x10)  # 得到56*56*512
        # 需要对x8进行裁剪,从中心往外裁剪
        crop1 = self.crop_tensor(x8, up1)
        up_1 = torch.cat([crop1, up1], dim=1)

        y1 = self.conv6_1(up_1)
        y1 = self.relu6_1(y1)
        y2 = self.conv6_2(y1)
        y2 = self.relu6_2(y2)

        # 第二次上采样,需要"Copy and crop"(复制并裁剪)
        up2 = self.up_conv_2(y2)
        # 需要对x6进行裁剪,从中心往外裁剪
        crop2 = self.crop_tensor(x6, up2)
        up_2 = torch.cat([crop2, up2], dim=1)

        y3 = self.conv7_1(up_2)
        y3 = self.relu7_1(y3)
        y4 = self.conv7_2(y3)
        y4 = self.relu7_2(y4)

        # 第三次上采样,需要"Copy and crop"(复制并裁剪)
        up3 = self.up_conv_3(y4)
        # 需要对x4进行裁剪,从中心往外裁剪
        crop3 = self.crop_tensor(x4, up3)
        up_3 = torch.cat([crop3, up3], dim=1)

        y5 = self.conv8_1(up_3)
        y5 = self.relu8_1(y5)
        y6 = self.conv8_2(y5)
        y6 = self.relu8_2(y6)

        # 第四次上采样,需要"Copy and crop"(复制并裁剪)
        up4 = self.up_conv_4(y6)
        # 需要对x2进行裁剪,从中心往外裁剪
        crop4 = self.crop_tensor(x2, up4)
        up_4 = torch.cat([crop4, up4], dim=1)

        y7 = self.conv9_1(up_4)
        y7 = self.relu9_1(y7)
        y8 = self.conv9_2(y7)
        y8 = self.relu9_2(y8)

        # 最后的conv1*1
        out = self.conv_10(y8)
        return out
if __name__ == '__main__':
    input_data = torch.randn([1, 1, 572, 572])
    unet = Unet()
    output = unet(input_data)
    print(output.shape)
    # torch.Size([1, 2, 388, 388])

这段代码包括空行在内,一共写了160行++,因为我是初学者,我懂初学者的痛。

网上的代码是经过封装的,因为Unet网络中涉及到很多重复的操作,大家为了简便代码,都通过定义相同操作的类,通过调用,从而减少代码量,使得代码看起来简短一些。但是这就不利于初学者去学习了,因为一般,大家都不喜欢嵌套,跳来跳去会容易晕(除非你一步一步debug)。

正是由于这一点,我才花了很多时间,来写这个博客,想着从初学者的角度,如何逐行编写网络结构。

后续也会打算从这篇文章开始,把重复操作的步骤,通过定义成类,进行调用,从而使得代码简短一些,也更加符合大佬们的写法。

我相信,经过这两天,VGG16+Unet网络架构从零到一的编写,大家的能力会得到很大的提升。

本文如有错误或者不合理的地方,请指出(笔者没有去看过原作者的paper,是根据Unet网络架构图和网上的资料进行编写的),谢谢!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1616026.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【bug】使用mmsegmentaion遇到的问题

利用mmsegmentaion跑自定义数据集时的bug处理(使用bisenetV2) 1. ValueError: val_dataloader, val_cfg, and val_evaluator should be either all None or not None, but got val_dataloader{batch_size: 1, num_workers: 4}, val_cfg{type: ValLoop}, …

Mysql 在Windows Server系统下修改数据文件存储路径遇到的坑

因项目需要搭建一个Mysql数据库,为了方便日常运维操作开始选择了Windows Server 2012R2(已有的虚拟机),考滤到要300G空间,原来的盘空间不够了,就是给虚拟机加了磁盘,Mysql 8.0.26社区版安装路径没得选择,默认就装在C&a…

OpenHarmony开发实例:【 待办事项TodoList】

简介 TodoList应用是基于OpenHarmony SDK开发的安装在润和HiSpark Taurus AI Camera(Hi3516d)开发板标准系统上的应用;应用主要功能是以列表的形式,展示需要完成的日程;通过本demo可以学习到 JS UI 框架List使用; 运行效果 样例…

记一次 Vscode + Latex 正向/反向搜索忽然失效

遥望大半个月前,完成论文撰写后,这些天虽然多次打开项目,但我真的一个字都没动过,今天想着开始着手修改一下,打开项目发现正向/反向搜索忽然失效了,感觉浑身有蚂蚁在爬,思索再三后找到问题&…

端口被占用的解决方案汇总

端口被占用的解决方案汇总 【一】windows系统端口被占用【二】Linux系统端口被占用【三】Linux的ps命令查找(1)ps命令常用的方式有三种(2)ps -ef |grep 8080 【一】windows系统端口被占用 (1)键盘上按住Wi…

javaEE初阶——多线程(八)——常见的锁策略 以及 CAS机制

T04BF 👋专栏: 算法|JAVA|MySQL|C语言 🫵 小比特 大梦想 此篇文章与大家分享分治算法关于多线程进阶的章节——关于常见的锁策略以及CAS机制 如果有不足的或者错误的请您指出! 多线程进阶 1.常见的锁策略 我们需要了解的是,我们使用是锁,在加锁 / 解锁…

【软考】UML中的关系

目录 1. 说明2. 依赖3. 关联4. 泛化5. 实现 1. 说明 1.UML中有4种关系:依赖、关联、泛化和实现2.这 4种关系是 UML,模型中可以包含的基本关系事物。它们也有变体,例如,依赖的变体有精化、跟踪、包含和延伸 2. 依赖 1.依赖(Dependency)。2.…

文心一言 VS 讯飞星火 VS chatgpt (242)-- 算法导论17.4 1题

一、假定我们希望实现一个动态的开地址散列表。为什么我们需要当装载因子达到一个严格小于 1 的值 a 时就认为表满?简要描述如何为动态开地址散列表设计一个插入算法,使得每个插入操作的摊还代价的期望值为 O(1) 。为什么每个插入操作的实际代价的期望值…

Excel如何计算时间差

HOUR(B1-A1)&"小时 "&MINUTE(B1-A1)&"分钟 "&SECOND(B1-A1)&"秒"

下级平台级联EasyCVR视频汇聚安防监控平台后,设备显示层级并存在重复的原因排查和解决

视频汇聚平台/视频监控系统/国标GB28181协议EasyCVR安防平台可以提供实时远程视频监控、视频录像、录像回放与存储、告警、语音对讲、云台控制、平台级联、磁盘阵列存储、视频集中存储、云存储等丰富的视频能力,平台支持7*24小时实时高清视频监控,能同时…

机器学习-10-基于paddle实现神经网络

文章目录 总结参考本门课程的目标机器学习定义第一步:数据准备第二步:定义网络第三步:训练网络第四步:测试训练好的网络 总结 本系列是机器学习课程的系列课程,主要介绍基于paddle实现神经网络。 参考 MNIST 训练_副…

深入剖析机器学习领域的璀璨明珠——支持向量机算法

在机器学习的广袤星空中,支持向量机(Support Vector Machine,简称SVM)无疑是一颗璀璨的明珠。它以其独特的分类能力和强大的泛化性能,在数据分类、模式识别、回归分析等领域大放异彩。本文将详细剖析SVM算法的原理、特…

MLLM | InternLM-XComposer2-4KHD: 支持336 像素到 4K 高清的分辨率的大视觉语言模型

上海AI Lab,香港中文大学等 论文标题:InternLM-XComposer2-4KHD: A Pioneering Large Vision-Language Model Handling Resolutions from 336 Pixels to 4K HD 论文地址:https://arxiv.org/abs/2404.06512 Code and models are publicly available at https://gi…

互联网扭蛋机小程序:打破传统扭蛋机的局限,提高销量

扭蛋机作为一种适合全年龄层的娱乐消费方式,深受人们的喜欢,通过一个具有神秘性的商品给大家带来欢乐。近几年,扭蛋机在我国的发展非常迅速,市场规模在不断上升。 经过市场的发展,淘宝线上扭蛋机小程序开始流行起来。…

一文讲透彻Redis 持久化

文章目录 ⛄1.RDB持久化🪂🪂1.1.执行时机🪂🪂1.2.RDB原理🪂🪂1.3.小结 ⛄2.AOF持久化🪂🪂2.1.AOF原理🪂🪂2.2.AOF配置🪂🪂2.3.AOF文件…

40+ Node.js 常见面试问题 [2024]

今天就开始你的Node.js生涯。在这里,我们探讨了最佳Node.js面试问题和答案,以帮助应届生和经验丰富的候选人获得理想的工作。 Node.js 是许多大公司技术堆栈的重要组成部分,例如 PayPal、Trello、沃尔玛和 NASA。 根据 ZipRecruiter 的数据&…

了解边缘计算,在制造行业使用边缘计算。

边缘计算是一种工业元宇宙技术,可以帮助组织实现其数据的全部潜力。 处理公司的所有数据可能具有挑战性,而边缘计算可以帮助公司更快地处理数据。在制造业中,边缘计算可以帮助进行预测性维护和自动驾驶汽车操作等工作。 什么是边缘计算? …

ruoyi-cloud-plus添加一个不要认证的公开新页面

文章目录 一、前端1. 组件创建2. src/router/index.ts3. src/permission.ts 二、后端1. 设计思想2. ruoyi-gateway.yml3. 开发Controller 版本RuoYiCloudPlusv2.1.2plus-uiVue3 ts 以新增一个公开的课程搜索页面为例。 一、前端 1. 组件创建 在view目录下创建一个页面的vue…

python--使用pika库操作rabbitmq实现需求

Author: wencoo Blog:https://wencoo.blog.csdn.net/ Date: 22/04/2024 Email: jianwen056aliyun.com Wechat:wencoo824 QQ:1419440391 Details:文章目录 目录正文 或 背景pika链接mqpika指定消费数量pika自动消费实现pika获取队列任务数量pi…

去哪儿网开源的一个对应用透明,无侵入的Java应用诊断工具

今天 V 哥给大家带来一款开源工具Bistoury,Bistoury 是去哪儿网开源的一个对应用透明,无侵入的java应用诊断工具,用于提升开发人员的诊断效率和能力。 Bistoury 的目标是一站式java应用诊断解决方案,让开发人员无需登录机器或修改…