pytorch实现分割模型TransUNet

news2025/1/10 15:58:47

TransUNet是一个非常经典的图像分割模型。该模型出现在Transformer引入图像领域的早期,所以结构比较简单,但是实际上效果却比很多后续花哨的模型更好。所以有必要捋一遍pytorch实现TransUNet的整体流程。

首先,按照惯例,先看一下TransUNet的结构图:

根据结构图,我们可以看出,整体结构就是基于UNet魔改的。

1,具体结构如下:

1. CNN-Transformer混合编码器:TransUNet使用卷积神经网络(CNN)作为特征提取器,生成特征图。然后,从CNN特征图中提取的1x1 patches通过patch embedding转换为序列,作为Transformer的输入。这种设计允许模型利用CNN的高分辨率特征图。

2. Transformer编码器:Transformer编码器由多头自注意力(Multihead Self-Attention, MSA)和多层感知器(MLP)块组成。这些层处理输入序列,捕获全局上下文信息。

3. 级联上采样器(Cascaded Upsampler, CUP):为了从Transformer编码器的输出中恢复空间分辨率,TransUNet引入了CUP。CUP由多个上采样步骤组成,每个步骤包括一个2x上采样操作、一个3x3卷积层和一个ReLU激活层。这些步骤将特征图从低分辨率逐步上采样到原始图像的分辨率。

4. skip connection:TransUNet采用了U-Net的u形架构设计,通过跳跃连接(skip-connections)将编码器中的高分辨率CNN特征图与Transformer编码的全局上下文特征结合起来,以实现精确的定位。

5. 解码器:解码器部分使用CUP来从Transformer编码器的输出中恢复出最终的分割掩码。这包括将Transformer的输出特征图与CNN特征图结合,并通过上采样步骤恢复到原始图像的分辨率。

我们只需要实现其每个模块,然后安装UNet拼装成整体就可以了。

2,首先实现的是编码器分支的卷积部分:

每个卷积模块可以使用resnet的一个块,或者自己实现一个

class EncoderBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1, base_width=64):
        super().__init__()  # 初始化父类
        self.downsample = nn.Sequential(  # 下采样层,用于降低特征图的维度
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels)
        )
        width = int(out_channels * (base_width / 64))  # 计算中间通道数
        self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1, bias=False)  # 第一个卷积层
        self.norm1 = nn.BatchNorm2d(width)  # 第一个批量归一化层
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=2, groups=1, padding=1, dilation=1, bias=False)  # 第二个卷积层
        self.norm2 = nn.BatchNorm2d(width)  # 第二个批量归一化层
        self.conv3 = nn.Conv2d(width, out_channels, kernel_size=1, stride=1, bias=False)  # 第三个卷积层
        self.norm3 = nn.BatchNorm2d(out_channels)  # 第三个批量归一化层
        self.relu = nn.ReLU(inplace=True)  # ReLU激活函数

    def forward(self, x):
        x_down = self.downsample(x)  # 下采样操作
        x = self.conv1(x)  # 第一个卷积操作
        x = self.norm1(x)  # 第一个批量归一化
        x = self.relu(x)  # ReLU激活
        x = self.conv2(x)  # 第二个卷积操作
        x = self.norm2(x)  # 第二个批量归一化
        x = self.relu(x)  # ReLU激活
        x = self.conv3(x)  # 第三个卷积操作
        x = self.norm3(x)  # 第三个批量归一化
        x = x + x_down  # 残差连接
        x = self.relu(x)  # ReLU激活
        return x

3,实现ViT模块

多头注意力实现如下:

class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_dim, head_num):
        super().__init__()  # 调用父类构造函数

        self.head_num = head_num  # 多头的数量
        self.dk = (embedding_dim // head_num) ** (1 / 2)  # 缩放因子,用于缩放点积注意力

        self.qkv_layer = nn.Linear(embedding_dim, embedding_dim * 3, bias=False)  # 线性层,用于生成查询(Q)、键(K)和值(V)
        self.out_attention = nn.Linear(embedding_dim, embedding_dim, bias=False)  # 输出线性层

    def forward(self, x, mask=None):
        qkv = self.qkv_layer(x)  # 通过线性层生成Q、K、V

        query, key, value = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d', k=3, h=self.head_num))  # 将Q、K、V重塑为多头注意力的格式
        energy = torch.einsum("... i d , ... j d -> ... i j", query, key) * self.dk  # 计算点积注意力的能量

        if mask is not None:  # 如果提供了掩码,则在能量上应用掩码
            energy = energy.masked_fill(mask, -np.inf)

        attention = torch.softmax(energy, dim=-1)  # 应用softmax函数,得到注意力权重

        x = torch.einsum("... i j , ... j d -> ... i d", attention, value)  # 应用注意力权重到值上

        x = rearrange(x, "b h t d -> b t (h d)")  # 重塑x以准备输出
        x = self.out_attention(x)  # 通过输出线性层

        return x

MLP实现如下:

# 定义MLP模块
class MLP(nn.Module):
    def __init__(self, embedding_dim, mlp_dim):
        super().__init__()  # 调用父类构造函数

        self.mlp_layers = nn.Sequential(  # 定义MLP的层
            nn.Linear(embedding_dim, mlp_dim),
            nn.GELU(),  # GELU激活函数
            nn.Dropout(0.1),  # Dropout层,用于正则化
            nn.Linear(mlp_dim, embedding_dim),  # 线性层
            nn.Dropout(0.1)  # Dropout层
        )

    def forward(self, x):
        x = self.mlp_layers(x)  # 通过MLP层
        return x

一个Transformer编码器块由归一化层,多头注意力,MLP和残差连接组成,实现如下:

# 定义Transformer编码器块
class TransformerEncoderBlock(nn.Module):
    def __init__(self, embedding_dim, head_num, mlp_dim):
        super().__init__()  # 调用父类构造函数

        self.multi_head_attention = MultiHeadAttention(embedding_dim, head_num)  # 多头注意力模块
        self.mlp = MLP(embedding_dim, mlp_dim)  # MLP模块

        self.layer_norm1 = nn.LayerNorm(embedding_dim)  # 第一层归一化
        self.layer_norm2 = nn.LayerNorm(embedding_dim)  # 第二层归一化

        self.dropout = nn.Dropout(0.1)  # Dropout层

    def forward(self, x):
        _x = self.multi_head_attention(x)  # 通过多头注意力模块
        _x = self.dropout(_x)  # 应用dropout
        x = x + _x  # 残差连接
        x = self.layer_norm1(x)  # 第一层归一化

        _x = self.mlp(x)  # 通过MLP模块
        x = x + _x  # 残差连接
        x = self.layer_norm2(x)  # 第二层归一化

        return x

Transformer 编码器由多层Transformer块堆叠而成,其中block_num代表的就是堆叠的层数

# 定义Transformer编码器
class TransformerEncoder(nn.Module):
    def __init__(self, embedding_dim, head_num, mlp_dim, block_num=12):
        super().__init__()  # 调用父类构造函数

        self.layer_blocks = nn.ModuleList([  # 创建一个模块列表,包含多个编码器块
            TransformerEncoderBlock(embedding_dim, head_num, mlp_dim) for _ in range(block_num)
        ])

    def forward(self, x):
        for layer_block in self.layer_blocks:  # 遍历每个编码器块
            x = layer_block(x)  # 通过每个块
        return x

vit的全部模块已经实现,下面就vit整体结构了。

vit的整体结构就是先将输入图片划分patches,然后将patches做embedding。

vit的分类头是一组额外添加的cl-token,将这个class-token复制batches遍,之后就可以将复制后的class_Token拼接到之前的embedding上了。

之后需要把位置编码加到这个embedding上。

这样,输入的图像特征就被处理好了,转换成了输入给Transformer块的形式。

之后只要输入一个Transformer编码器和一个MLP头,就可以得到vit的输出结果。

如果是分类任务,则class_token就是分类结果,如果不是分类任务,比如分割或者vit作为一个模块,那么输出的就是patches形式的特征图。

# 定义ViT模型
class ViT(nn.Module):
    def __init__(self, img_dim, in_channels, embedding_dim, head_num, mlp_dim, block_num, patch_dim, classification=True, num_classes=1):
        super().__init__()  # 调用父类构造函数

        self.patch_dim = patch_dim  # 定义patch的维度
        self.classification = classification  # 是否进行分类
        self.num_tokens = (img_dim // patch_dim) ** 2  # 计算tokens的数量
        self.token_dim = in_channels * (patch_dim ** 2)  # 计算每个token的维度

        self.projection = nn.Linear(self.token_dim, embedding_dim)  # 线性层,用于将patches投影到embedding空间
        self.embedding = nn.Parameter(torch.rand(self.num_tokens + 1, embedding_dim))  # 可学习的embedding
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))  # 类别token

        self.dropout = nn.Dropout(0.1)  # Dropout层

        self.transformer = TransformerEncoder(embedding_dim, head_num, mlp_dim, block_num)  # Transformer编码器

        if self.classification:  # 如果是分类任务
            self.mlp_head = nn.Linear(embedding_dim, num_classes)  # 分类头

    def forward(self, x):
        img_patches = rearrange(x,  # 将输入图像重塑为patches序列
                                'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                                patch_x=self.patch_dim, patch_y=self.patch_dim)

        batch_size, tokens, _ = img_patches.shape  # 获取批次大小、tokens数量和通道数

        project = self.projection(img_patches)  # 将patches投影到embedding空间
        token = repeat(self.cls_token, 'b ... -> (b batch_size) ...', batch_size=batch_size)  # 重复cls_token以匹配批次大小

        patches = torch.cat((token, project), dim=1)  # 将cls_token和投影后的patches拼接
        patches += self.embedding[:tokens + 1, :]  # 将可学习的embedding添加到patches

        x = self.dropout(patches)  # 应用dropout
        x = self.transformer(x)  # 通过Transformer编码器
        x = self.mlp_head(x[:, 0, :]) if self.classification else x[:, 1:, :]  # 如果是分类任务,使用cls_token的输出;否则,使用patches的输出

        return x

4,实现解码器的模块

解码器的模块就是卷积模块,接受两个输入:上采样而来的特征图以及skip-connection来的特征图

# 定义解码器中的瓶颈层
class DecoderBottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()  # 初始化父类
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=True)  # 上采样层
        self.layer = nn.Sequential(  # 解码器层
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, x_concat=None):
        x = self.upsample(x)  # 上采样操作
        if x_concat is not None:  # 如果有额外的特征图进行拼接
            x = torch.cat([x_concat, x], dim=1)  # 在通道维度上拼接
        x = self.layer(x)  # 通过解码器层
        return x

5,组装成模型

所有模块都已经定义完成,下面拿这些模块来组装成模型。

编码器分支由三个卷积模块和一个vit模块组成,输出的x为解码分支最终的特征图,x1,x2,x3分别是三个卷积模块的输出

# 定义编码器
class Encoder(nn.Module):
    def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim):
        super().__init__()  # 初始化父类
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=2, padding=3, bias=False)  # 第一个卷积层
        self.norm1 = nn.BatchNorm2d(out_channels)  # 第一个批量归一化层
        self.relu = nn.ReLU(inplace=True)  # ReLU激活函数
        self.encoder1 = EncoderBottleneck(out_channels, out_channels * 2, stride=2)  # 第一个编码器瓶颈层
        self.encoder2 = EncoderBottleneck(out_channels * 2, out_channels * 4, stride=2)  # 第二个编码器瓶颈层
        self.encoder3 = EncoderBottleneck(out_channels * 4, out_channels * 8, stride=2)  # 第三个编码器瓶颈层
        self.vit_img_dim = img_dim // patch_dim  # ViT的图像维度
        self.vit = ViT(self.vit_img_dim, out_channels * 8, out_channels * 8,  # ViT模型
                       head_num, mlp_dim, block_num, patch_dim=1, classification=False)
        self.conv2 = nn.Conv2d(out_channels * 8, 512, kernel_size=3, stride=1, padding=1)  # 第四个卷积层
        self.norm2 = nn.BatchNorm2d(512)  # 第四个批量归一化层

    def forward(self, x):
        x = self.conv1(x)  # 第一个卷积操作
        x = self.norm1(x)  # 第一个批量归一化
        x1 = self.relu(x)  # ReLU激活
        x2 = self.encoder1(x1)  # 第一个编码器瓶颈层
        x3 = self.encoder2(x2)  # 第二个编码器瓶颈层
        x = self.encoder3(x3)  # 第三个编码器瓶颈层
        x = self.vit(x)  # 通过ViT模型
        x = rearrange(x, "b (x y) c -> b c x y", x=self.vit_img_dim, y=self.vit_img_dim)  # 重塑特征图
        x = self.conv2(x)  # 第四个卷积操作
        x = self.norm2(x)  # 第四个批量归一化
        x = self.relu(x)  # ReLU激活
        return x, x1, x2, x3  # 返回多个特征图

解码分支接受编码分支的输出x,以及三个卷积模块的输出x1,x2,x3,

# 定义解码器
class Decoder(nn.Module):
    def __init__(self, out_channels, class_num):
        super().__init__()  # 初始化父类
        self.decoder1 = DecoderBottleneck(out_channels * 8, out_channels * 2)  # 第一个解码器瓶颈层
        self.decoder2 = DecoderBottleneck(out_channels * 4, out_channels)  # 第二个解码器瓶颈层
        self.decoder3 = DecoderBottleneck(out_channels * 2, int(out_channels * 1 / 2))  # 第三个解码器瓶颈层
        self.decoder4 = DecoderBottleneck(int(out_channels * 1 / 2), int(out_channels * 1 / 8))  # 第四个解码器瓶颈层
        self.conv1 = nn.Conv2d(int(out_channels * 1 / 8), class_num, kernel_size=1)  # 最后一个卷积层,用于输出

    def forward(self, x, x1, x2, x3):
        x = self.decoder1(x, x3)  # 第一个解码器瓶颈层
        x = self.decoder2(x, x2)  # 第二个解码器瓶颈层
        x = self.decoder3(x, x1)  # 第三个解码器瓶颈层
        x = self.decoder4(x)  # 第四个解码器瓶颈层
        x = self.conv1(x)  # 最后一个卷积层
        return x  # 返回解码器的输出

整个模型结构:

# 定义TransUNet模型
class TransUNet(nn.Module):
    def __init__(self, img_dim, in_channels, out_channels, head_num, mlp_dim, block_num, patch_dim, class_num):
        super().__init__()  # 初始化父类
        self.encoder = Encoder(img_dim, in_channels, out_channels,  # 初始化编码器
                               head_num, mlp_dim, block_num, patch_dim)
        self.decoder = Decoder(out_channels, class_num)  # 初始化解码器

    def forward(self, x):
        x, x1, x2, x3 = self.encoder(x)  # 编码分支
        x = self.decoder(x, x1, x2, x3)  # 解码分支
        return x  # 返回最终输出

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1523017.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MySQL 篇- Java 连接 MySQL 数据库并实现数据交互

🔥博客主页: 【小扳_-CSDN博客】 ❤感谢大家点赞👍收藏⭐评论✍ 文章目录 1.0 JDBC 概述 2.0 实现 Java 连接 MySQL 数据库并实现数据交互的完整过程 2.1 安装数据库驱动包 2.2 创建数据源对象 2.3 获取数据库连接对象 2.4 创建 SQL 语句 2.…

mac激活pycharm,python环境安装和包安装问题

1.PyCharm到官网下载就行 地址:Other Versions - PyCharm (jetbrains.com) 2.MacOS 下载python环境,地址: Python Releases for macOS | Python.org 3.PyCharm环境配置: 4. 如果包下载不下来可以换个源试试 pip install py…

如何在Ubuntu中查看编辑lvgl的demo和examples?

如何在Ubuntu中查看编辑lvgl的demo和examples? 如何在 Ubuntu系统中运行查看lvgl 1、拉取代码 在lvgl的github主页面有50多个仓库,找到lv_port_pc_eclipse这个仓库,点进去 拉取仓库代码和子仓库代码 仓库网址:https://github…

Maven: There are test failures.(已解决)

问题解决办法 进行package打包时报错如下: 然后这些并不能看出是测试的哪里的问题,可以点击上一级进行查看更详细的错误,越向上日志越详细,可以看到是52行出了错误, 52对应代码如下: 原因是存在注册的测…

分享一篇Oracle RAC实战安装11G

分享一次很久以前的Oracle rac项目实施。 1、拓扑结构 基础环境是2台H3C的服务器2台3PAR的双活存储,操作系统centos7.2。借用下别人家的拓扑先(这是一套典型的RAC架构)。 2、网卡TEAM操作 以eno51和en052组成Team1组为示例: nm…

校园闲置物品交易网站 |基于springboot框架+ Mysql+Java+Tomcat的校园闲置物品交易网站设计与实现(可运行源码+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 用户功能模块 管理员功能登录前台功能效果图 系统功能设计 数据库E-R图设计 lunwen…

章鱼网络 Community Call #19|​开启与 Eigenlayer 的合作

香港时间2024年3月8日12点,章鱼网络举行第19期 Community Call。 在过去的一个月,章鱼网络在成功完成 $NEAR Restaking 功能的安全审计之后,一直在稳步吸引关注。事实上,在整个行业中,我们是极少数已经推出 Restaking …

JavaWeb笔记 --- 三、MyBatis

三、MyBatis 概述 MyBatis是一个持久层框架,用于简化JDBC Mapper代理开发 在resources配置文件包中创建多级目录用 / MyBatis核心配置文件 enviroments:配置数据库连接环境信息。 可以配置多个enviroment,通过default属性切换不同的envir…

MySQL语法分类 DQL(5)分组查询

为了更好的学习这里给出基本表数据用于查询操作 create table student (id int, name varchar(20), age int, sex varchar(5),address varchar(100),math int,english int );insert into student (id,name,age,sex,address,math,english) values (1,马云,55,男,杭州,66,78),…

2核4g服务器能多少人在线?腾讯云2核4g服务器性能测评

腾讯云轻量应用服务器2核4G5M配置性能测评,腾讯云轻量2核4G5M带宽服务器支持多少人在线访问?并发数10,支持每天5000IP人数访问,腾讯云百科txybk.com整理2核4G服务器支持多少人同时在线?并发数测试、CPU性能、内存性能、…

FFmpeg 常用命令汇总

​​​​​​经常用到ffmpeg做一些视频数据的处理转换等,用来做测试,今天总结了一下,参考了网上部分朋友的经验,一起在这里汇总了一下。 1、ffmpeg使用语法 命令格式: ffmpeg -i [输入文件名] [参数选项] -f [格…

unity内存优化之AB包篇(微信小游戏)

1.搭建资源服务器使用(HFS软件(https://www.pianshen.com/article/54621708008/)) using System.Collections; using System.Collections.Generic; using UnityEngine;using System;public class Singleton<T> where T : class, new() {private static readonly Lazy<…

Unity的AssetBundle资源运行内存管理的再次深入思考

大家好&#xff0c;我是阿赵。   这篇文章我想写了很久&#xff0c;是关于Unity项目使用AssetBundle加载资源时的内存管理的。这篇文章不会分享代码&#xff0c;只是分享思路&#xff0c;思路不一定正确&#xff0c;欢迎讨论。   对于Unity引擎的资源内存管理&#xff0c;我…

【网络原理】TCP 协议中比较重要的一些特性(三)

目录 1、拥塞控制 2、延时应答 3、捎带应答 4、面向字节流 5、异常情况处理 5.1、其中一方出现了进程崩溃 5.2、其中一方出现关机&#xff08;正常流程的关机&#xff09; 5.3、其中一方出现断电&#xff08;直接拔电源&#xff0c;也是关机&#xff0c;更突然的关机&am…

校园博客系统 |基于springboot框架+ Mysql+Java的校园博客系统设计与实现(可运行源码+数据库+设计文档)

推荐阅读100套最新项目 最新ssmjava项目文档视频演示可运行源码分享 最新jspjava项目文档视频演示可运行源码分享 最新Spring Boot项目文档视频演示可运行源码分享 目录 前台功能效果图 管理员功能登录前台功能效果图 系统功能设计 数据库E-R图设计 lunwen参考 摘要 研究…

每日一练:LeeCode-125、验证回文串【字符串+双指针】

如果在将所有大写字符转换为小写字符、并移除所有非字母数字字符之后&#xff0c;短语正着读和反着读都一样。则可以认为该短语是一个 回文串 。 字母和数字都属于字母数字字符。 给你一个字符串 s&#xff0c;如果它是 回文串 &#xff0c;返回 true &#xff1b;否则&#…

mysql中的非空间数据导入sqlserver中空间化

以下操作都在Navicat Premium 15软件中操作 1、mysql导出数据 以导出csv为例 不修改导出路径的话默认就是在桌面 设置编码UTF-8 这边还是默认,最好不要修改,如果文本识别符号为空,导入的时候可能字段会错乱 开始即可 2、导入sqlserver数据库中

App的测试,和传统软件测试有哪些区别?增加哪些方面的测试用例

从上图可知&#xff0c;测试人员所测项目占比中&#xff0c;App测试占比是最高的。 这就意味着学习期间&#xff0c;我们要花最多的精力去学App的各类测试。也意味着我们找工作前&#xff0c;就得知道&#xff0c;App的测试点是什么&#xff0c;App功能我们得会测试&#xff0…

绩效考核设计:拟定工时标准,实现量化考核

该度假村工程维修部的主要工作是修灯泡、换水管、修门、开锁等&#xff0c;部门员工大多是老员工&#xff0c;随着年龄的增加&#xff0c;这些员工的工作积极性越来越差&#xff0c;“老油条”越来越多&#xff0c;其他部门对工程维修部的抱怨声也越来越大。一起来看看人力资源…

RTC的Google拥塞控制算法 rmcat-gcc-02

摘要 本文档描述了使用时的两种拥塞控制方法万维网&#xff08;RTCWEB&#xff09;上的实时通信&#xff1b;一种算法是基于延迟策略&#xff0c;一种算法是基于丢包策略。 1.简介 拥塞控制是所有共享网络的应用程序的要求互联网资源 [RFC2914]。 实时媒体的拥塞控制对于许…