【代码分析】Unet-Pytorch

news2025/1/2 22:40:06

1:unet_parts.py

主要包含:

【1】double conv,双层卷积

【2】down,下采样

【3】up,上采样

【4】out conv,输出卷积

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # // 是整除运算
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

【1】double conv

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

=》卷积。卷积核是3*3,填充是1

=》批归一化。

=》ReLU。激活函数

【2】down

=》最大池化。池化核是2*2

=》double conv。

【3】up

=》上采样。可选择upsample + double conv 和 transpose + double conv

=》计算尺寸差异。

=》填充x1。使得x1和x2对齐

=》拼接x2和x1。按照dim=1,也就是channel通道拼接

=》double conv。

【4】out conv

=》卷积。卷积核是1*1

2:unet_model.py

主要包含:UNet完整架构

""" Full assembly of the parts to form the complete network """

from .unet_parts import *


class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = (DoubleConv(n_channels, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits

    def use_checkpointing(self):
        self.inc = torch.utils.checkpoint(self.inc)
        self.down1 = torch.utils.checkpoint(self.down1)
        self.down2 = torch.utils.checkpoint(self.down2)
        self.down3 = torch.utils.checkpoint(self.down3)
        self.down4 = torch.utils.checkpoint(self.down4)
        self.up1 = torch.utils.checkpoint(self.up1)
        self.up2 = torch.utils.checkpoint(self.up2)
        self.up3 = torch.utils.checkpoint(self.up3)
        self.up4 = torch.utils.checkpoint(self.up4)
        self.outc = torch.utils.checkpoint(self.outc)

其中,use_checkpointing的作用是丢弃中间计算结果,加快训练速度。

上面的代码可以结合下图分析

前向传播过程:

        x1 = self.inc(x)

通过double conv双层卷积,输入通道为图像自身的,输出通道为64

        x2 = self.down1(x1)

通过down下采样,输入通道为64,输出通道为128

        x3 = self.down2(x2)

通过down下采样,输入通道为128,输出通道为256

        x4 = self.down3(x3)

通过down下采样,输入通道为256,输出通道为512

        x5 = self.down4(x4)

通过down下采样,输入通道为512,输出通道为1024(非bilinear,后续上采样也是如此)

        x = self.up1(x5, x4)

通过up上采样,输入通道为1024,输出通道为512

这个地方concat的对象是x4,也就是下采样输出通道为512的时候的特征

        x = self.up2(x, x3)

通过up上采样,输入通道为512,输出通道为256

这个地方concat的对象是x,也就是原图(后续也是原图)

其实这里和原作者的跳跃连接有点不太一样,代码库的作者直接省事用了原图进行拼接

        x = self.up3(x, x2)

通过up上采样,输入通道为256,输出通道为128

        x = self.up4(x, x1)

通过up上采样,输入通道为128,输出通道为64

        logits = self.outc(x)

通过out conv输出卷积,输入通道为64,输出通道为2,也就是分割为背景和物体2个类别的像素

3:完整代码

可以在github上通过git clone下载

milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images (github.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2268147.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

小程序基础 —— 08 文件和目录结构

文件和目录结构 一个完整的小程序项目由两部分组成:主体文件、页面文件: 主体文件:全局文件,能够作用于整个小程序,影响小程序的每个页面,主体文件必须放到项目的根目录下; 主体文件由三部分组…

Maven 测试和单元测试介绍

一、测试介绍 二、单元测试 1&#xff09;介绍 2&#xff09;快速入门 添加依赖 <dependencies><!-- junit依赖 --><dependency><groupId>org.junit.jupiter</groupId><artifactId>junit-jupiter</artifactId><version>5.9…

大数据技术-Hadoop(一)Hadoop集群的安装与配置

目录 一、准备工作 1、安装jdk&#xff08;每个节点都执行&#xff09; 2、修改主机配置 &#xff08;每个节点都执行&#xff09; 3、配置ssh无密登录 &#xff08;每个节点都执行&#xff09; 二、安装Hadoop&#xff08;每个节点都执行&#xff09; 三、集群启动配置&a…

android sqlite 数据库简单封装示例(java)

sqlite 数据库简单封装示例&#xff0c;使用记事本数据库表进行示例。 首先继承SQLiteOpenHelper 使用sql语句进行创建一张表。 public class noteDBHelper extends SQLiteOpenHelper {public noteDBHelper(Context context, String name, SQLiteDatabase.CursorFactory fact…

【CSS in Depth 2 精译_095】16.3:深入理解 CSS 动画(animation)的性能

当前内容所在位置&#xff08;可进入专栏查看其他译好的章节内容&#xff09; 第五部分 添加动效 ✔️【第 16 章 变换】 ✔️ 16.1 旋转、平移、缩放与倾斜 16.1.1 变换原点的更改16.1.2 多重变换的设置16.1.3 单个变换属性的设置 16.2 变换在动效中的应用 16.2.1 放大图标&am…

yii2 手动添加 phpoffice\phpexcel

1.下载地址&#xff1a;https://github.com/PHPOffice/PHPExcel 2.解压并修改文件名为phpexcel 在yii项目的vendor目录下创建一个文件夹命名为phpoffice 把phpexcel目录放到phpoffic文件夹下 查看vendor\phpoffice\phpexcel目录下会看到这些文件 3.到vendor\composer目录下…

如何通过采购管理系统提升供应链协同效率?

供应链是企业运营的命脉&#xff0c;任何环节的延迟或失误都会对企业造成严重影响。在采购环节中&#xff0c;如何保证与供应商的协同效率&#xff0c;避免因信息不对称而导致的决策失误&#xff0c;是企业面临的一大挑战。采购管理系统作为数字化供应链管理的重要工具&#xf…

yolov5 yolov6 yolov7 yolov8 yolov9目标检测、目标分类 目标切割 性能对比

文章目录 YOLOv1-YOLOv8之间的对比如下表所示&#xff1a;一、YOLO算法的核心思想1. YOLO系列算法的步骤2. Backbone、Neck和Head 二、YOLO系列的算法1.1 模型介绍1.2 网络结构1.3 实现细节1.4 性能表现 2. YOLOv2&#xff08;2016&#xff09;2.1 改进部分2.2 网络结构 3. YOL…

Vue项目如何设置多个静态文件;如何自定义静态文件目录

Vite实现方案 安装插件 npm i vite-plugin-static-copy在vite.config.ts引入 import { viteStaticCopy } from vite-plugin-static-copy配置 plugins: [viteStaticCopy({targets: [{src: "要设置的静态文件目录的相对路径 相对于vite.config.ts的", dest: ./, // …

学习threejs,THREE.RingGeometry 二维平面圆环几何体

&#x1f468;‍⚕️ 主页&#xff1a; gis分享者 &#x1f468;‍⚕️ 感谢各位大佬 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍⚕️ 收录于专栏&#xff1a;threejs gis工程师 文章目录 一、&#x1f340;前言1.1 ☘️THREE.RingGeometry 圆环几…

【畅购商城】详情页模块之评论

目录 接口 分析 后端实现&#xff1a;JavaBean 后端实现 前端实现 接口 GET http://localhost:10010/web-service/comments/spu/2?current1&size2 { "code": 20000, "message": "查询成功", "data": { "impressions&q…

04.HTTPS的实现原理-HTTPS的混合加密流程

04.HTTPS的实现原理-HTTPS的混合加密流程 简介1. 非对称加密与对称加密2. 非对称加密的工作流程3. 对称加密的工作流程4. HTTPS的加密流程总结 简介 主要讲述了HTTPS的加密流程&#xff0c;包括非对称加密和对称加密两个阶段。首先&#xff0c;客户端向服务器发送请求&#xf…

【每日学点鸿蒙知识】数据迁移、大量图片存放、原生自定义键盘调用、APP包安装到测试机、photoPicker顶部高度

1、迁移&#xff08;克隆&#xff09;手机中经过 ArkData &#xff08;方舟数据管理&#xff09;服务持久化后的数据&#xff1f; 在用户手动迁移&#xff08;克隆&#xff09;手机数据至另一台设备后&#xff0c;使用 ArkData &#xff08;方舟数据管理&#xff09;服务持久化…

Centos 7.6 安装mysql 5.7

卸载mysql 之前服务器上一直是mysql8&#xff0c;因为不经常使用&#xff0c;而且8的内存占用还挺高的&#xff0c;所以想降低到5.7&#xff0c;腾出点运行内存 停止服务 # 查询服务的状态 systemctl status mysqld # 停止服务 systemctl stop mysqld随后再次查询状态 查询…

大数据学习之Redis 缓存数据库二,Scala分布式语言一

一.Redis 缓存数据库二 26.Redis数据安全_AOF持久化机制 27.Redis数据安全_企业中该如何选择持久化机制 28.Redis集群_主从复制概念 29.Redis集群_主从复制搭建 30.Redis集群_主从复制原理剖析 31.Redis集群_哨兵监控概述 32.Redis集群_配置哨兵监控 33.Redis集群_哨兵监控原理…

Elasticsearch:使用 Ollama 和 Go 开发 RAG 应用程序

作者&#xff1a;来自 Elastic Gustavo Llermaly 使用 Ollama 通过 Go 创建 RAG 应用程序来利用本地模型。 关于各种开放模型&#xff0c;有很多话要说。其中一些被称为 Mixtral 系列&#xff0c;各种规模都有&#xff0c;而一种可能不太为人所知的是 openbiollm&#xff0c;这…

【日常开发】Git Stash使用技巧

文章目录 引言一、git stash 基础命令&#xff08;一&#xff09;存储当前工作区的修改&#xff08;二&#xff09;查看存储列表 二、查看存储的内容&#xff08;一&#xff09;查看特定存储的详细内容&#xff08;二&#xff09;查看特定存储修改的文件列表 三、恢复存储的修改…

GXUOJ-算法-第三次作业

1.基础练习 Huffman树 问题描述 GXUOJ | 基础练习 Huffman树 代码解析 #include<bits/stdc.h> using namespace std; int main(){int n;cin>>n;priority_queue<int,vector <int>,greater<int> >pq;for(int i0;i<n;i){int value;cin>>…

04-微服务02

我们将黑马商城拆分为5个微服务&#xff1a; 用户服务 商品服务 购物车服务 交易服务 支付服务 由于每个微服务都有不同的地址或端口&#xff0c;相信大家在与前端联调的时候发现了一些问题&#xff1a; 请求不同数据时要访问不同的入口&#xff0c;需要维护多个入口地址…

vue导入导出excel、设置单元格文字颜色、背景色、合并单元格(使用xlsx-js-style库)

npm i xlsx-js-style <template><button click"download">下载 Excel 表格</button><el-table :data"tableData" style"width: 100%"><el-table-column prop"date" label"日期" width"180…