UNet入门总结

news2024/12/23 23:38:04

作者:AI浩 来源:投稿
编辑:学姐

Unet已经是非常老的分割模型了,是2015年《U-Net: Convolutional Networks for Biomedical Image Segmentation》提出的模型。

论文连接:https://arxiv.org/abs/1505.04597

在Unet之前,则是更老的FCN网络,FCN是Fully Convolutional Netowkrs的缩写,确立分割网络的基础框架 ,不过FCN网络的准确度较低,没有Unet好用。

Unet网络非常的简单,前半部分就是特征提取,后半部分是上采样。在一些文献中把这种结构叫做编码器-解码器结构,由于网络的整体结构是一个大些的英文字母U,所以叫做U-net。

网络结构如下图:

  • Encoder:左半部分,由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块(后面代码可以看出);

  • Decoder:有半部分,由一个上采样的卷积层(去卷积层)+特征拼接concat+两个3x3的卷积层(ReLU)反复构成(代码中可以看出来);这种通过通道数的拼接,可以得到更多的特征,同时也会消耗更多的显存。

UNet的结构设计,使其能够结合底层特征和高层特征的信息。

底层(深层)特征:

经过多次下采样后的低分辨率信息。能够提供分割目标在整个图像中上下文语义信息,可理解为反应目标和它的环境之间关系的特征。这个特征有助于物体的类别判断(所以分类问题通常只需要低分辨率/深层信息,不涉及多尺度融合)

高层(浅层)特征:

经过concatenate操作从encoder直接传递到同高度decoder上的高分辨率信息。能够为分割提供更加精细的特征,如梯度等。

关于size对不上的原因:

从图上可以看到左侧和右侧的尺寸是对不上的,所以如果要对上,就要经历了裁切,所有灰色箭头的解释是copy and crop,不过复现的模型都没有采用这样的思路,都是将左侧和右侧的尺寸设置成一样的,而且每次卷积都加入了padding,这样经过卷积后尺寸不会改变。

不足之处:

  1. 该网络运行效率很慢。对于每个邻域,网络都要运行一次,且对于邻域重叠部分,网络会进行重复运算。

  2. 网络需要在精确的定位和获取上下文信息之间进行权衡。越大的patch需要越多的最大[池化层],其会降低定位的精确度,而小的邻域使得网络获取较少的上下文信息。

UNet 代码(pytorch)

import torch.nn as nn
import torch
from torch import autograd
 
#把常用的2个卷积操作简单封装下
class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch), #添加了BN层
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
 
    def forward(self, input):
        return self.conv(input)
 
class Unet(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Unet, self).__init__()
        self.conv1 = DoubleConv(in_ch, 64)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(2)
        self.conv5 = DoubleConv(512, 1024)
        # 逆卷积,也可以使用上采样(保证k=stride,stride即上采样倍数)
        self.up6 = nn.ConvTranspose2d(1024, 512, 2, stride=2)
        self.conv6 = DoubleConv(1024, 512)
        self.up7 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.conv7 = DoubleConv(512, 256)
        self.up8 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv8 = DoubleConv(256, 128)
        self.up9 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv9 = DoubleConv(128, 64)
        self.conv10 = nn.Conv2d(64, out_ch, 1)
 
    def forward(self, x):
        c1 = self.conv1(x)
        p1 = self.pool1(c1)
        c2 = self.conv2(p1)
        p2 = self.pool2(c2)
        c3 = self.conv3(p2)
        p3 = self.pool3(c3)
        c4 = self.conv4(p3)
        p4 = self.pool4(c4)
        c5 = self.conv5(p4)
        up_6 = self.up6(c5)
        merge6 = torch.cat([up_6, c4], dim=1)
        c6 = self.conv6(merge6)
        up_7 = self.up7(c6)
        merge7 = torch.cat([up_7, c3], dim=1)
        c7 = self.conv7(merge7)
        up_8 = self.up8(c7)
        merge8 = torch.cat([up_8, c2], dim=1)
        c8 = self.conv8(merge8)
        up_9 = self.up9(c8)
        merge9 = torch.cat([up_9, c1], dim=1)
        c9 = self.conv9(merge9)
        c10 = self.conv10(c9)
        out = nn.Sigmoid()(c10)
        return out

Unet代码(Keras)

def unet(pretrained_weights=None, input_size=(256, 256, 3)):
    inputs = Input(input_size)

    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(inputs)
    conv1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool2)
    conv3 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool3)
    conv4 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv4)
    drop4 = Dropout(0.5)(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(pool4)
    conv5 = Conv2D(1024, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv5)
    drop5 = Dropout(0.5)(conv5)

    up6 = Conv2D(512, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(drop5))
    merge6 = concatenate([drop4, up6], axis=3)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge6)
    conv6 = Conv2D(512, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv6)

    up7 = Conv2D(256, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(conv6))
    merge7 = concatenate([conv3, up7], axis=3)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge7)
    conv7 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv7)

    up8 = Conv2D(128, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(conv7))
    merge8 = concatenate([conv2, up8], axis=3)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge8)
    conv8 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv8)

    up9 = Conv2D(64, 2, activation='relu', padding='same', kernel_initializer='he_normal')(
        UpSampling2D(size=(2, 2))(conv8))
    merge9 = concatenate([conv1, up9], axis=3)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(merge9)
    conv9 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv9 = Conv2D(2, 3, activation='relu', padding='same', kernel_initializer='he_normal')(conv9)
    conv10 = Conv2D(1, 1, activation='sigmoid')(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    model.summary()

    if (pretrained_weights):
        model.load_weights(pretrained_weights)

    return model

图像分割论文资料🚀🚀🚀

关注下方《学姐带你玩AI》免费领取

码字不易,欢迎大家点赞评论收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/148601.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android 深入系统完全讲解(4)

4 SystemServer 创建过程 SystemServer 进程非常关键了,我们上层的服务都是在这里以线程的形式存在,比如 AMS,PMS,WindowManagerService,壁纸服务,而关于调试这个服务进程,我们随后就会讲到。 …

虚拟人-面部表情-Audio2Face语音驱动表情

任务: 输入自己的音频,导入maya模型,让maya模型通过音频驱动说话 教程: https://www.bilibili.com/video/BV1rZ4y1R7H4/?p2&spm_id_frompageDriver&vd_sourceef114f70c3fd4d5394f12dbd3d022bbe 一.下载和安装 1.首先…

Java面试常见问题-SE篇

JavaSE面试问题汇总①int和Integer的区别为什么设计封装类型?JDK、JRE、JVM的区别和equals方法的区别hashCode()与equals()之间的关系泛型中extends和super的区别String、StringBuffer、StringBuilder的区别重载和重写的区别接口和抽象类的区别List与Set的区别Array…

2023/1/8总结

今天学了了强连通算法 Tarjan算法 Tarjan算法是一种求解有向图强连通分量的线性时间的算法,他运用到了DFS算法以及DFS的特性和数据结构——栈。 算法介绍:如果两个顶点可以相互通达,则称两个顶点强连通(strongly connected)。如果有向图G…

LeetCode题解 二叉树(十三):701 二叉搜索树的插入操作;450 删除二叉搜索树中的结点

701 二叉搜索树的插入操作 medium 给定二叉搜索树(BST)的根节点和要插入树中的值,将值插入二叉搜索树。 返回插入后二叉搜索树的根节点。 输入数据保证,新值和原始二叉搜索树中的任意节点值都不同。 如果要按照题目中所说改变二叉…

渗透测试中的常用编码

渗透测试中的常用编码 页面编码 在网页设置网页编码 在中加入设置特定html标签 这样页面的编码就会变成utf-8 ,如果没有设置编码就会使用默认 的编码,而浏览器默认编码与之不同就会出现乱码。 常用的有三种格式分别是 utf-8、gbk、gbk2312 ascii编码…

_Linux 进程信号-信号处理篇

文章目录前言捕捉信号1. 内核如何实现信号的捕捉2. sigaction代码验证可重入函数volatile(关键字)SIGCHLD信号实验一实验二前言 信号发送 信号处理 已经讲过,本章讲解信号处理最后一部分。 捕捉信号 信号捕捉过程图 经过信号捕捉过程图:我们知道信号…

语音文件分析

语音文件格式的重要参数 语音波形,它的这个文件,主要的格式就是采样率,那么电话或者嵌入式设备,采样率一般是8000Hz,就一秒钟8000个点,如果是PC机麦克风,那是16K,就一秒钟是16000个点,像这个CD是高保真的,音乐唱片的是用这个44.1K。 量化位数,又叫采样精度,…

绿通科技在创业板通过注册:收入依赖美国市场,张志江为实控人

2023年1月6日,证监会发布关于同意广东绿通新能源电动车科技股份有限公司(下称“绿通科技”或“绿通电动车”)、杭州国泰环保科技股份有限公司首次公开发行股票注册的批复。换句话说,证监会同意上述两家公司的创业板IPO注册。 同日…

【手写 Vue2.x 源码】第十一篇 - Vue 的数据渲染流程

一,前言 上篇,主要介绍了数组数据变化的观测情况,涉及以下几个点: 实现了数组数据变化被劫持后,已重写原型方法的具体逻辑;数组各种数据变化时的观测情况分析; 目前为止,数据劫持…

webpack 如何编写 loader

三种本地开发测试 loader 的方法 1. 匹配(test)单个 loader 你可以通过在 rule 对象使用 path.resolve 指定一个本地文件: webpack.config.js const path require(path);module.exports {//...module: {rules: [{test: /\.js$/,use: [{…

Ansible Automation Platform - 在 RHEL 安装 Ansible Automation Platform 2.3 环境

《OpenShift / RHEL / DevSecOps 汇总目录》 文本已在 RHEL 9 AAP 2.3 环境中进行验证。 说明: 本文介绍如何在一个节点上部署一套 all-in-one 的 Ansible Automation Platform 2.3 的运行环境。红帽 Ansible Automation Platform 2.3 需要至少 RHEL 8.4 以上的环…

【EHub_tx1_tx2_E100】Ubuntu18.04 + ROS_ Melodic + Velodyne VLP-16雷达 测试使用

简介:介绍 Velodyne VLP-16 16线激光雷达 在EHub_tx1_tx2_E100载板,TX1核心模块环境(Ubuntu18.04)下测试ROS驱动,打开使用RVIZ 查看点云数据,本文的前提条件是你的TX1里已经安装了ROS版本:Melod…

HashMap、HashTable、ConcurrentHashMap之间的区别及常见面试题

Java集合类有的集合类是存在线程安全的问题,但是由于之前对于集合类的使用都是在单线程的情况下使用的,不没有在多线程环境下使用,所以不涉及线程安全的问题;这篇博客着重讲解一下多线程环境下使用哈希表。HashMapHashMap本身不是…

一些开发时常用的网站或命令

目录关于gitgit下载网址git安装教程Gortoisegit下载地址关于PythonAnyWhere关于Linux压缩与解压命令关于python的相对与绝对路径使用_\_file_\_实现跨平台关于宝塔面板关于浏览器驱动下载本博客首次编辑于2023.01.04 ,后续将持续进行更新 关于git git下载网址 gi…

Linux - 系统文件目录说明

目录/ - 根目录/bin - 用户基础二进制文件目录/boot - 静态启动文件/dev - 设备文件/etc - 配置文件/home - 主目录/lib - 基础共享库/lib64 - 64位基础共享库/lostfound - 可恢复的文件/media - 可移动媒体/mnt - 临时挂载点目录/opt - 自选软件包/proc - 内核 & 进程文件…

【Node】事件循环机制

Node 中的异步 API 定时器:setTimeout、setIntervalI/O 操作:文件读写、数据库操作、网络请求…Node 独有的 API:process.nextTick、setImmediate 事件循环的流程 Node 的事件循环分为 6 个阶段,这 6 个阶段会按顺序反复运行运行…

高并发内存池项目

文章目录一、项目介绍二、什么是内存池2.1 池化技术2.2 内存池2.3 内存池的作用2.4 malloc三、设计定长内存池四、高并发内存池整体框架设计六、threadcache6.1 threadcache整体设计6.2 threadcache哈希桶映射对齐规则6.3 编写对齐和映射的相关函数6.4 编写ThreadCache类6.5 th…

电网头条知识竞赛题库答案(自动答题)

今天教你们自动完成2023年电网头条的知识竞赛,小编也为大家安排好了教程,首先呢需要知道电网助手,打开电网助手网页https://wwwl.lanzouw.com/b01w803yj 为了帮到大家,我特地分享出来,希望能给大家带来一丝丝便利&…

1.3第二周 星期二Samba、FTP

目录 01 Samba文件共享服务 Samba服务基础 2.主配置文件 02 linux文件传输服务 1.用户访问的Samba 03 FTP服务概述 1.vsftp知识预备 04 操作流程: 1.使用时记得装ftp的包:yum install ftp -y 2.装完之后启动服务&am…