pytorch复现_UNet

news2025/1/12 11:24:35

什么是UNet
U-Net由收缩路径和扩张路径组成。收缩路径是一系列卷积层和汇集层,其中要素地图的分辨率逐渐降低。扩展路径是一系列上采样层和卷积层,其中特征地图的分辨率逐渐增加。
在扩展路径中的每一步,来自收缩路径的对应特征地图与当前特征地图级联。
在这里插入图片描述
主干结构解析
左边为特征提取网络(编码器),右边为特征融合网络(解码器)

高分辨率—编码—低分辨率—解码—高分辨率

特征提取网络
高分辨率—编码—低分辨率

前半部分是编码, 它的作用是特征提取(获取局部特征,并做图片级分类),得到抽象语义特征

由两个3x3的卷积层(RELU)再加上一个2x2的maxpooling层组成一个下采样的模块,一共经过4次这样的操作

特征融合网络
低分辨率—解码—高分辨率

利用前面编码的抽象特征来恢复到原图尺寸的过程, 最终得到分割结果(掩码图片)

代码:

import torch.nn as nn
import torch

# 编码器(论文中称之为收缩路径)的基本单元
def contracting_block(in_channels, out_channels):
    block = torch.nn.Sequential(
        # 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小
        nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=out_channels),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels, out_channels=out_channels),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )
    return block


# 解码器(论文中称之为扩张路径)的基本单元
class expansive_block(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels):
        super(expansive_block, self).__init__()

        # 每进行一次反卷积,通道数减半,尺寸扩大2倍
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=(3, 3), stride=2, padding=1,
                                     output_padding=1)
        self.block = nn.Sequential(
            # 这里的卷积操作没有使用padding,所以每次卷积后图像的尺寸都会减少2个像素大小
            nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels, out_channels=mid_channels),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(),
            nn.Conv2d(kernel_size=(3, 3), in_channels=mid_channels, out_channels=out_channels),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, e, d):
        d = self.up(d)
        # concat
        # e是来自编码器部分的特征图,d是来自解码器部分的特征图,它们的形状都是[B,C,H,W]
        diffY = e.size()[2] - d.size()[2]
        diffX = e.size()[3] - d.size()[3]
        # 裁剪时,先计算e与d在高和宽方向的差距diffY和diffX,然后对e高方向进行裁剪,具体方法是两边分别裁剪diffY的一半,
        # 最后对e宽方向进行裁剪,具体方法是两边分别裁剪diffX的一半,
        # 具体的裁剪过程见下图一
        e = e[:, :, diffY // 2:e.size()[2] - diffY // 2, diffX // 2:e.size()[3] - diffX // 2]
        cat = torch.cat([e, d], dim=1)  # 在特征通道上进行拼接
        out = self.block(cat)
        return out


# 最后的输出卷积层
def final_block(in_channels, out_channels):
    block = nn.Conv2d(kernel_size=(1, 1), in_channels=in_channels, out_channels=out_channels)
    return block


class UNet(nn.Module):

    def __init__(self, in_channel, out_channel):
        super(UNet, self).__init__()

        # 编码器 (Encode)
        self.conv_encode1 = contracting_block(in_channels=in_channel, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = contracting_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = contracting_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = contracting_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 编码器与解码器之间的过渡部分(Bottleneck)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=1024),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
            nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024),
            nn.BatchNorm2d(1024),
            nn.ReLU()
        )

        # 解码器(Decode)
        self.conv_decode4 = expansive_block(1024, 512, 512)
        self.conv_decode3 = expansive_block(512, 256, 256)
        self.conv_decode2 = expansive_block(256, 128, 128)
        self.conv_decode1 = expansive_block(128, 64, 64)

        self.final_layer = final_block(64, out_channel)

    def forward(self, x):
        # Encode
        encode_block1 = self.conv_encode1(x)
        encode_pool1 = self.conv_pool1(encode_block1)
        encode_block2 = self.conv_encode2(encode_pool1)
        encode_pool2 = self.conv_pool2(encode_block2)
        encode_block3 = self.conv_encode3(encode_pool2)
        encode_pool3 = self.conv_pool3(encode_block3)
        encode_block4 = self.conv_encode4(encode_pool3)
        encode_pool4 = self.conv_pool4(encode_block4)

        # Bottleneck
        bottleneck = self.bottleneck(encode_pool4)

        # Decode
        decode_block4 = self.conv_decode4(encode_block4, bottleneck)
        decode_block3 = self.conv_decode3(encode_block3, decode_block4)
        decode_block2 = self.conv_decode2(encode_block2, decode_block3)
        decode_block1 = self.conv_decode1(encode_block1, decode_block2)

        final_layer = self.final_layer(decode_block1)
        return final_layer


if __name__ == '__main__':
    image = torch.rand((1, 3, 572, 572))
    unet = UNet(in_channel=3, out_channel=2)
    mask = unet(image)
    print(mask.shape)
    
    #输出结果:
    torch.Size([1, 2, 388, 388])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1178839.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

什么是分治算法?

分治算法(divide and conquer algorithm)是指把大问题分割成多个小问 题,然后把每个小问题分割成多个更小的问题,直到问题的规模小到能够 轻易解决。这种算法很适合用递归实现,因为把问题分割成多个与自身相 似的小问题正对应递归情况&#x…

Java —— 类和对象(一)

目录 1. 面向对象的初步认知 1.1 什么是面向对象 1.2 面向对象与面向过程 2. 类定义和使用 2.1 认识类 2.2 类的定义格式 3. 类的实例化(如何产生对象) 3.1 什么是实例化 3.2 访问对象的成员 3.3 类和对象的说明 4. this引用 4.1 为什么要有this引用 4.2 什么是this引用 4.3 th…

无线发射芯片解决方案在智能家居中的应用

随着物联网的发展,智能家居已经成为一个热门话题。智能家居利用无线技术来实现设备之间的互联互通,提供更智能、更便利的生活体验。无线发射芯片解决方案在智能家居中扮演着关键的角色,它们为智能家居设备之间的通信提供了稳定、高效的连接&a…

stm32f103+HC-SR04+ssd1306实现超声波测距

🙌秋名山码民的主页 😂oi退役选手,Java、大数据、单片机、IoT均有所涉猎,热爱技术,技术无罪 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 获取源码,添加WX 目录 前言HC…

【江协科技-用0.96寸OLED播放知名艺人打篮球视频】

Python进行视频图像处理,通过串口发送给stm32,stm32接收数据,刷新OLED进行显示。 步骤: 1.按照接线图连接好硬件 2.把Keil工程的代码下载到STM32中 3.运行Python代码,通过串口把处理后的数据发送给STM32进行显示 …

Spark 新特性+核心回顾

Spark 新特性核心 本文来自 B站 黑马程序员 - Spark教程 :原地址 1. 掌握Spark的Shuffle流程 1.1 Spark Shuffle Map和Reduce 在Shuffle过程中,提供数据的称之为Map端(Shuffle Write)接收数据的称之为Reduce端(Sh…

Leetcode刷题详解——组合

1. 题目链接:77. 组合 2. 题目描述: 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以按 任何顺序 返回答案。 示例 1: 输入:n 4, k 2 输出: [[2,4],[3,4],[2,3],[1,2],[1,3],[…

vue3拖拽排序——vuedraggable

文章目录 安装代码效果拖拽前拖拽时拖拽后 vue3 的拖拽排序博主用的是 vuedraggable 安装 安装 npm i vuedraggable4.1.0 --save 引用 import Draggable from vuedraggable;代码 html <van-checkbox-group v-model"dataMap.newsActionChecked"><van-cell…

LazyVim: 将 Neovim 升级为完整 IDE | 开源日报 No.67

curl/curl Stars: 31.5k License: NOASSERTION Curl 是一个命令行工具&#xff0c;用于通过 URL 语法传输数据。 核心优势和关键特点包括&#xff1a; 可在命令行中方便地进行数据传输支持多种协议 (HTTP、FTP 等)提供丰富的选项和参数来满足不同需求 kubernetes/ingress-n…

项目中登录验证码怎么做才合理

唠嗑部分 今天我们来聊聊项目实战中登录验证码如何做比较合理&#xff0c;首先我们聊以下几个问题 1、登录时验证码校验是否必要&#xff1f; 答案当然是很有必要的&#xff0c;因为用户登录行为会直接影响数据库&#xff0c;如果没有某些防范措施&#xff0c;有恶意用户暴力…

NOIP2023模拟12联测33 A. 构造

NOIP2023模拟12联测33 A. 构造 文章目录 NOIP2023模拟12联测33 A. 构造题目大意思路code 题目大意 构造题 思路 想一种构造方法&#xff0c;使得 y y y 能够凑成尽可能多的答案 第一行 x y r y ⋯ r xyry \cdots r xyry⋯r 第二行 r y x y ⋯ x ryxy \cdots x ryxy⋯x …

基于SSM的出租车管理系统

基于SSM的出租车管理系统的设计与实现~ 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringSpringMVCMyBatis工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 登录界面 管理员界面 驾驶员界面 摘要 基于SSM&#xff08;Spring、Spring MVC、My…

软考 -- 计算机学习(3)

文章目录 一、软件测试基础1.1 基本概念1.2 软件测试模型1.3 软件测试的分类 二、基于规格说明的测试技术(黑盒)2.1 重要的测试方法1. 等价类划分法2. 边界值法3. 判定表法4. 因果图法 2.2 其他测试方法 三、基于结构的测试技术(白盒)3.1 静态测试3.2 动态测试 一、软件测试基础…

Vue Vuex模块化编码

正常写vuex的index的时候如果数据太多很麻烦&#xff0c;如有的模块是管理用户信息或修改课程等这两个是不同一个种类的&#xff0c;如果代码太多会造成混乱&#xff0c;这时候可以使用模块化管理 原始写法 如果功能模块太多很乱 import Vue from vue import Vuex from vuex …

nodejs卸载和安装教程

一、卸载 1、Win菜单中找到Node.js的卸载程序&#xff0c;运行卸载程序。 3.选择 OK&#xff0c;等待卸载。 4. 删除C:\Users\用户名\AppData\Roaming目录下的npm和npm-cache&#xff1b;删除C:\Users\123\AppData\Local\目录下的npm-cache。 二、安装 傻瓜式安装&#xf…

socket开发步骤及相关API介绍

socket服务器和客户端的开发步骤 TCP服务端&#xff1a; 创建套接字socket为套接字添加信息&#xff08;IP地址和端口号&#xff09;bind监听网络连接listen监听到由客户端接入&#xff0c;接受一个连接accept数据交互read、write关闭套接字&#xff0c;断开连接close TCP客户…

JAVA二叉搜索树(专门用来查找)

目录 二叉搜索树又叫二叉排序树&#xff0c;它具有以下特征 二次搜索树的效率 模拟最简二叉搜索树代码 代码片段分析 查找二叉搜索树数据&#xff1a; 如果我们用递归的方法查找数据有什么不一样? 插入数据 删除数据(难点) 二叉搜索树又叫二叉排序树&#xff0c;它具有以下特征…

python之pyQt5实例:几何绘图界面

使用PyQt5设计一个界面&#xff0c;其中点击不同的按钮可以在画布上画出点、直线、圆和样条曲线 from PyQt5.QtWidgets import QApplication, QMainWindow, QPushButton,QHBoxLayout,QVBoxLayout,QWidget,QLabel from PyQt5.QtGui import QPainter, QPen, QColor from PyQt5.Q…

nssm将exe应用封装成windows服务

一、简介 NSSM&#xff08;Non-Sucking Service Manager&#xff09;是一个用于在Windows操作系统上管理和运行应用程序作为服务的工具。它提供了一种简单的方法来将任意可执行文件转换为Windows服务&#xff0c;并提供了一些额外的功能和配置选项。 优点&#xff1a; 简单易…

【遍历二叉树算法描述】

文章目录 遍历二叉树算法描述先序遍历二叉树的操作定义中序遍历二叉树的操作定义后序遍历二叉树的操作定义 遍历二叉树算法描述 1.遍历定义&#xff1a;顺着某一条搜索路径寻访二叉树中的结点&#xff0c;使得每一个结点均被访问一次&#xff0c;而且仅访问一次&#xff08;又…