论文学习——U-Net: Convolutional Networks for Biomedical Image Segmentation

news2024/9/21 19:53:38

UNet的特点

  • 采用端到端的结构,通过FCN(最后一层仍然是通过卷积完成),最后输出图像。
  • 通过编码(下采样)-解码(上采样)形成一个“U”型结构。每次下采样时,先进行两次卷积(通道数不变),然后通过一次池化层(也可以通过卷积)处理(长宽减半,通道数加倍);在每次上采样时,同样先进行两次卷积操作,再通过反卷积函数进行上采样(长宽加倍,通道不变),然后与编码过程中对应层进行拼接(通道加倍)。到最后一层时,通过1x1的卷积核修改通道数,最后输出目标图像。编码操作逐层提取图像特征,解码操作则逐层恢复图像信息。
  • 通过跳跃连接,将编码器结构中的底层信息与解码器结构中的高层信息融合,从而提高了分割精度。

网络结构如图所示:
在这里插入图片描述
代码实现(基于pytorch):
相关包的引入:

from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F

定义卷积块:
定义了两个卷积操作,分别使用大小为3x3的卷积核进行卷积,步长为1,并且对卷积后的输出进行批量归一化(批量归一化的作用),激活函数采用ReLU。使用卷积模块时,需要指明输入通道数(in_channel)和输出通道数(out_channel)。

class Conv_Block(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Conv_Block, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channel, out_channel, 3, 1, 1),
            nn.BatchNorm2d(out_channel),
            nn.ReLU(inplace=True)
        )

    def forward(self, input):
        outputs = self.conv1(input)
        outputs = self.conv2(outputs)
        return outputs

编码操作(下采样):将卷积模块的输出进行池化处理。

class UnetDown(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UnetDown, self).__init__()
        self.conv = Conv_Block(in_channel, out_channel)
        self.down = nn.MaxPool2d(2, 2, ceil_mode=True)

    def forward(self, inputs):
        outputs = self.conv(inputs)
        outputs = self.down(outputs)
        return outputs

解码操作(上采样):这里的上采样操作提出了两种——ConvTranspose2d和UpsamplingBilinear2d,两者的区别见这里。另外,由于要进行拼接操作,所以在拼接前对上采样的输出进行填充,避免拼接出错。
(ps:代码里面的解码操作是先进行上采样,然后拼接数据,最后进行卷积的,但是在UnetModel中的最后一个编码操作后,单独进行了一次卷积操作,最后的网络结构还是没有变的。)

class UnetUp(nn.Module):
    def __init__(self, in_channel, out_channel, is_deconv=True):
        super(UnetUp, self).__init__()
        self.conv = Conv_Block(in_channel, out_channel)

        if is_deconv:
            self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)
        else:
            self.up = nn.UpsamplingBilinear2d(scale_factor=2)

    def forward(self, inputs1, inputs2):
        outputs2 = self.up(inputs2)
        offset1 = (outputs2.size()[2] - inputs1.size()[2])
        offset2 = (outputs2.size()[3] - inputs1.size()[1])
        # pad传入四个元素时,指的是左填充,右填充,上填充,下填充;前两个元素作用在第一四维,后两个元素作用在第三维
        padding = [offset2 // 2, (offset2 + 1) // 2, offset1 // 2, (offset1 + 1) // 2]
        # Skip and concatenate
        outputs1 = F.pad(inputs1, padding)
        return self.conv(torch.cat([outputs1, outputs2], 1))

最后定义整个UNet模块:将代码和网络结构的图结合起来看就很容易理解了。

class UnetModel(nn.Module):
    def __init__(self, n_classes, in_channels, is_deconv):
        super(UnetModel, self).__init__()
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.n_classes = n_classes

        filters = [64, 128, 256, 512, 1024]

        self.down1 = UnetDown(self.in_channels, filters[0])
        self.down2 = UnetDown(filters[0], filters[1])
        self.down3 = UnetDown(filters[1], filters[2])
        self.down4 = UnetDown(filters[2], filters[3])
        self.center = Conv_Block(filters[3], filters[4])
        self.up4 = UnetUp(filters[4], filters[3], self.is_deconv)
        self.up3 = UnetUp(filters[3], filters[2], self.is_deconv)
        self.up2 = UnetUp(filters[2], filters[1], self.is_deconv)
        self.up1 = UnetUp(filters[1], filters[0], self.is_deconv)
        self.final = nn.Conv2d(filters[0], self.n_classes, 1)

    def forward(self, inputs, label_dsp_dim):
        down1 = self.down1(inputs)
        down2 = self.down2(down1)
        down3 = self.down3(down2)
        down4 = self.down4(down3)
        center = self.center(down4)
        up4 = self.up1(down4, center)
        up3 = self.up2(down3, up4)
        up2 = self.up3(down2, up3)
        up1 = self.up4(down1, up2)
        up1 = up1[:, :, 1:1 + label_dsp_dim[0], 1:1 + label_dsp_dim[1]].contiguous()
        return self.final(up1)

    # Initialization of parameters
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.ConvTranspose2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()

总结:UNet 是一种经典的图像分割网络,它通过编码器-解码器结构、跳跃连接和多尺度特征融合等设计,能够在图像分割任务中取得优秀的性能。基于UNet还衍生出了很多网络,例如 U-Net++, ResUNet, Dense U-Net等,接下来就学习它的衍生网络吧,学习大佬是怎么魔改网络的~另外,刚开始写深度学习的代码时,我不知道从何下手,通过学习大佬实现代码的过程,我发现结合两点就能轻松实现代码:1)写代码时结合网络结构的图片,2)百度相关操作的函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/729142.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

3D格式转化工具HOOPS Exchange的功能特点

​领先的CAD导入和导出库 使用用于CAD数据转换的HOOPS Exchange SDK将30多种CAD文件格式导入到您的应用程序中,具有对2D和3D CAD文件格式(包括 CATIA、SOLIDWORKS、Inventor™、Revit™、Creo、 NX™、Solid Edge 等,全部通过单个API实现。 …

aop拦截所有请求并打印请求方法参数

效果图 代码 package com.hxnwm.ny.diancan.common.aspect;import lombok.extern.slf4j.Slf4j; import org.aspectj.lang.JoinPoint; import org.aspectj.lang.annotation.AfterReturning; import org.aspectj.lang.annotation.AfterThrowing; import org.aspectj.lang.annota…

redis的主从复制、哨兵模式和集群模式搭建

概述 redis主从复制搭建 主从复制的作用 主从复制流程 搭建Redis 主从复制 修改 Redis 配置文件(Master节点操作) 修改 Redis 配置文件(Slave节点操作) 验证主从效果 ​编辑 Redis 哨兵模式 哨兵模式的作用 故障转移机制 …

idea切换Git分支时保存未提交的文件

问题描述 现在需要开发一个新功能A时,我们需要从Dev分支上创建一个新的功能分支tenant,然后我们就在这个分支上进行开发。假设有一天,你正在开发,本地已经在tenant上修改了几个文件,但是功能还没有完全开发完成&#…

游戏反调试方案解析与frida/IDA框架分析

近来年,游戏黑灰产攻击趋势呈现出角度多样化的特点。据FairGuard游戏安全数据分析发现,游戏黑灰产攻击以工作室、定制注入挂、内存修改器、模拟点击、破解等形式为主。 游戏安全风险分布占比图 对于一款游戏而言,上述的风险中,被…

Kubernetes(k8s)实战:深入详解Volume,详解k8s文件同步存储

文章目录 一、Volume1、什么是Volume2、Host类型volume实战(不推荐)(1)小总结 二、PersistentVolume持久化volume(推荐)1、什么是PersistentVolume2、什么是PersistentVolumeClaim3、Pod中如何使用PVC 三、…

初中学物理实验室教学装备配置标准

初中物理实验室建设,以现行义务教育物理教科书为基本参照,以学生学科核心素养发展为基本遵循,以加强实验等实践性教学活动,落实立德树人根本任务为目标。因此,初中物理实验室的建设实施过程中,需结合校情、…

PCB封装设计指导(三)如何创建PAD

PCB封装设计指导(三)如何创建PAD 当我们完全看完了Datasheet之后,确定了需要建立的封装类型,以及尺寸之后,这一步就可以开始创建封装的PAD了。 下面介绍如何创建各种类型的PAD和一些技巧和注意点,包括创建flash的注意事项。 1.如何创建PAD 1. 打开pad designer ,根据尺…

STL源码刨析 string实现

目录 一. string 类介绍 二. string 的简单实现 1. 类内成员变量 2. Member functions string ~string operator string(const string& str) 3. Capacity size capacity empty clear reserve resize 4.Modifiers push_back append operator insert era…

【回溯算法part05】| 491.递增子序列、46.全排列、47.全排列||

目录 🎈LeetCode491.递增子序列 🎈LeetCode46.全排列 🎈LeetCode47.全排列|| 🎈LeetCode491.递增子序列 链接:491.递增子序列 给你一个整数数组 nums ,找出并返回所有该数组中不同的递增子序列&#xf…

【玩转Linux操作】详细讲解Linux的 权限 操作

🎊专栏【​​​​​​​玩转Linux操作】 🍔喜欢的诗句:更喜岷山千里雪 三军过后尽开颜。 🎆音乐分享【Love Story】 🥰欢迎并且感谢大家指出小吉的问题🥰 文章目录 🍔权限的基本介绍⭐具体分析&…

【复习1-2天的内容】【我们一起60天准备考研算法面试(大全)-第六天 6/60】

专注 效率 记忆 预习 笔记 复习 做题 欢迎观看我的博客,如有问题交流,欢迎评论区留言,一定尽快回复!(大家可以去看我的专栏,是所有文章的目录)   文章字体风格: 红色文字表示&#…

网络编程---day4

广播发送方: 广播接收方: 组播发射方: 组播接收方:

OpenCV的remap实现图像垂直翻转

以下是完整的代码: #include <opencv2/highgui/highgui.hpp> #include <opencv2/imgproc/imgproc.hpp> #include <iostream>int main() {

从“存算一体”到“存算分离”:金融核心数据库改造的必经之路

科技云报道原创。 近年来&#xff0c;数据库国产化趋势愈发明显&#xff0c;上百家金融业试点单位在数据库国产化的进程中&#xff0c;进一步增强信心&#xff0c;向50%国产化率大步迈进。 但随着数据库国产化的深入&#xff0c;一些金融机构采用国产数据库服务器本地盘的“存…

【微信小程序创作之路】- 小程序中WXML、JS、JSON、WXSS作用

【微信小程序创作之路】- 小程序中WXML、JS、JSON、WXSS作用 第三章 微信小程序WXML、JS、JSON、WXSS作用 文章目录 【微信小程序创作之路】- 小程序中WXML、JS、JSON、WXSS作用前言一、WXML是什么&#xff1f;二、JS是什么&#xff1f;三、JSON是什么&#xff1f;四、WXSS是什…

IPC 进程间通讯 (1)

目录 1.1 为什么要通信 1.2 为什么能通信 2.1 进程间通信机制的结构 2.2 进程间通信机制的类型 2.3 进程间通信机制的接口设计 3.1 SysV共享内存 3.2 POSIX共享内存 3.3 共享内存映射 3.4 Android ION 3.5 dma-buf heaps 3.6 匿名管道 3.7 命名管道 3.8 SysV消息队列…

基于Java实验室考勤管理系统设计实现(源码+lw+部署文档+讲解等)

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

buuctf pwn入门1

目录 1. test_your_nc(简单nc ) pwn做题过程 2. rip(简单栈溢出) 3. warmup_csaw_2016(栈溢出 覆盖Return_Address) 4. ciscn_2019_n_1(栈溢出 浮点数十六进制) (1) 覆盖v2值 (2) 利用system("cat /flag"); 5. pwn1_sctf_2016(字符逃逸栈溢出 32位) 6. jarvis…

实现【Linux--NTP 时间同步服务搭建】

实现【Linux--NTP 时间同步服务搭建】 &#x1f53b; 前言&#x1f53b; 一、NTP 校时&#x1f530; 1.1 NTP 服务校时与 ntpdate 校时的区别&#x1f530; 1.2 NTP 校时服务搭建&#x1f530; 1.2.1 确认 ntp 的安装&#x1f530; 1.2.2 配置 ntp 服务&#x1f530; 1.2.3 启动…