分割模型TransNetR的pytorch代码学习笔记

news2024/10/2 14:21:50

这个模型在U-net的基础上融合了Transformer模块和残差网络的原理。

论文地址:https://arxiv.org/pdf/2303.07428.pdf

具体的网络结构如下:

网络的原理还是比较简单的,

编码分支用的是预训练的resnet模块,解码分支则重新设计了。

解码器分支的模块结构示意图如下:

可以看出来,就是Transformer模块和残差连接相加,然后再经过一个residual模块处理。

1,用pytorch实现时,首先要把这个解码器模块实现出来:

class residual_transformer_block(nn.Module):
    def __init__(self, in_c, out_c, patch_size=4, num_heads=4, num_layers=2, dim=None):
        super().__init__()

        self.ps = patch_size
        self.c1 = Conv2D(in_c, out_c)

        encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
        self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.c2 = Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)
        self.c3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.r1 = residual_block(out_c, out_c)

    def forward(self, inputs):
        x = self.c1(inputs)

        b, c, h, w = x.shape
        num_patches = (h*w)//(self.ps**2)
        x = torch.reshape(x, (b, (self.ps**2)*c, num_patches))
        x = self.te(x)
        x = torch.reshape(x, (b, c, h, w))

        x = self.c2(x)
        s = self.c3(inputs)
        x = self.relu(x + s)
        x = self.r1(x)
        return x

其中我们需要注意的是这里构建Transformer块的方法,也就是下面两句:

encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=num_heads)
self.te = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

首先,第一句是用nn.TransformerEncoderLayer定义了一个Transformer层,并存储在encoder_layer变量中。

nn.TransformerEncoderLayer的参数包括:d_model(输入特征的维度大小),nhead(自注意力机制中注意力头数量),dim_feedforward(前馈网络的隐藏层维度大小),dropout(dropout比例),apply(用于在编码器层及其子层上应用函数,例如初始化或者权重共享等功能)。

第二句则是将多个Transformer层堆叠在一起,构建一个Transformer编码器块。

nn.TransformerEncoder的参数包括:encoder_layer(用于构建模块的每个Transformer层),num_layer(堆叠的层数),norm(执行的标准化方法),apply(同上)。

2,在上面的解码器模块中,还有一个residual block需要额外实现,如下:

class residual_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c)
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=1, padding=0),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, inputs):
        x = self.conv(inputs)
        s = self.shortcut(inputs)
        return self.relu(x + s)

这个代码就是简单的残差卷积模块,不赘述。

3,重要的模块实现完了,接下来就是按照UNet的形状拼装网络,代码如下:

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        backbone = resnet50()
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        self.e1 = Conv2D(64, 64, kernel_size=1, padding=0)
        self.e2 = Conv2D(256, 64, kernel_size=1, padding=0)
        self.e3 = Conv2D(512, 64, kernel_size=1, padding=0)
        self.e4 = Conv2D(1024, 64, kernel_size=1, padding=0)


        """ Decoder """
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.r1 = residual_transformer_block(64+64, 64, dim=64)
        self.r2 = residual_transformer_block(64+64, 64, dim=256)
        self.r3 = residual_block(64+64, 64)

        """ Classifier """
        self.outputs = nn.Conv2d(64, 1, kernel_size=1, padding=0)

    def forward(self, inputs):
        """ Encoder """
        x0 = inputs
        x1 = self.layer0(x0)    ## [-1, 64, h/2, w/2]
        x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4]
        x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8]
        x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16]

        e1 = self.e1(x1)
        e2 = self.e2(x2)
        e3 = self.e3(x3)
        e4 = self.e4(x4)

        """ Decoder """
        x = self.up(e4)
        x = torch.cat([x, e3], axis=1)
        x = self.r1(x)

        x = self.up(x)
        x = torch.cat([x, e2], axis=1)
        x = self.r2(x)

        x = self.up(x)
        x = torch.cat([x, e1], axis=1)
        x = self.r3(x)

        x = self.up(x)

        """ Classifier """
        outputs = self.outputs(x)
        return outputs

其中,x1,x2,x3,x4就是编码器模块,用的都是resnet50的预训练模块。

其中r1,r2,r3,r4则是解码器的模块,就是上面实现的模块。

而e1,e2,e3,e4则是在skip connection前给编码器的输出做1x1卷积,作用大体上就是减少计算量。

完整代码:https://github.com/DebeshJha/TransNetR/blob/main/model.py#L45

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1505281.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据结构奇妙旅程之二叉平衡树

꒰˃͈꒵˂͈꒱ write in front ꒰˃͈꒵˂͈꒱ ʕ̯•͡˔•̯᷅ʔ大家好,我是xiaoxie.希望你看完之后,有不足之处请多多谅解,让我们一起共同进步૮₍❀ᴗ͈ . ᴗ͈ აxiaoxieʕ̯•͡˔•̯᷅ʔ—CSDN博客 本文由xiaoxieʕ̯•͡˔•̯᷅ʔ 原创 CSDN …

【C++庖丁解牛】实现string容器的增删查改 | string容器的基本接口使用

📙 作者简介 :RO-BERRY 📗 学习方向:致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 📒 日后方向 : 偏向于CPP开发以及大数据方向,欢迎各位关注,谢谢各位的支持 目录 前言📖pu…

双链表()

双链表 实现一个双链表,双链表初始为空,支持 55 种操作: 在最左侧插入一个数;在最右侧插入一个数;将第 k 个插入的数删除;在第 k 个插入的数左侧插入一个数;在第 k 个插入的数右侧插入一个数 …

【LeetCode: 299. 猜数字游戏 - 模拟 + 计数】

🚀 算法题 🚀 🌲 算法刷题专栏 | 面试必备算法 | 面试高频算法 🍀 🌲 越难的东西,越要努力坚持,因为它具有很高的价值,算法就是这样✨ 🌲 作者简介:硕风和炜,…

解决阿里云服务器开启frp服务端,内网服务器开启frp客户端却连接不上的问题

解决方法: 把阿里云自带的Alibabxxxxxxxlinux系统 换成centos 7系统!!!! 说一下我的过程和问题:由于我们内网的服务器在校外是不能连接的,因此我弄了个阿里云服务器做内网穿透,所谓…

nRF52832——GPIO端口的应用

nRF52832——GPIO端口的应用 nRF52832 GPIO 端口资源描述nRF52832 GPIO 寄存器介绍GPIO 端口状态的设置GPIO 输出设置 nRF52832 GPIO 输出应用点亮第一个 LED 灯硬件部分Keil 工程搭建 蜂鸣器驱动硬件设计程序编写测试验证 nRF52832 GPIO 输入应用GPIO 输入扫描流程机械按键输入…

基于GAN对抗网进行图像修复

一、简介 使用PyTorch实现的生成对抗网络(GAN)模型,包括编码器(Encoder)、解码器(Decoder)、生成器(ResnetGenerator)和判别器(Discriminator)。…

vue 自定义组件绑定model+弹出选择支持上下按键选择

参考地址v-modelhttps://v2.cn.vuejs.org/v2/guide/components-custom-events.html#%E8%87%AA%E5%AE%9A%E4%B9%89%E7%BB%84%E4%BB%B6%E7%9A%84-v-model 原文代码 Vue.component(base-checkbox, {model: {prop: checked,event: change},props: {checked: Boolean},template: `…

阅读最新的论文,研究趋势

我们需要时刻了解技术的发展趋势,阅读最新的论文研究。那么,怎么阅读论文最高效?最近我们使用了全新的阅读方法: 第一步,阅读最新分类好的列表 第二步,挑选感兴趣的论文,阅读其一页纸总结 第三步…

spring-cloud-openfeign 3.0.0之前版本(对应spring boot 2.4.x之前版本)feign配置加载顺序

在之前写的文章配置基础上 https://blog.csdn.net/zlpzlpzyd/article/details/136060312 下图为自己整理的

rk3399使用阿里推理引擎MNN使用cpu和gpu进行benchmark,OpenCL效果不佳?

视频讲解 rk3399使用阿里推理引擎MNN使用cpu和gpu进行benchmark,OpenCL效果不佳? 背景 MNN是阿里开源的推理引擎,今天测试一下在rk3399平台上的benchmark怎么样? alibaba/MNN: MNN is a blazing fast, lightweight deep learning…

百家争鸣!AI艺术生成器的进化: 深入AI生成艺术世界

人工智能(AI)已经彻底改变了艺术界,AI艺术生成器现在能够创作出独特而迷人的作品。然而,关于AI生成艺术与人类创作艺术的艺术价值的争论仍然在引起争议。 社区对AI生成图像的原创性和所有权提出了关注,导致了法律纠纷和…

第十六章垃圾回收相关概念

第十六章垃圾回收相关概念 文章目录 第十六章垃圾回收相关概念1. System.gc()的理解2. 内存溢出与内存泄漏2.1 内存溢出(OOM)2.2 内存泄漏(Memory Leak) 3. Stop The World4. 垃圾回收的并行与并发4.1 并发(Concurrent…

基于SpringBoot的招聘网站

基于jspmysqlSpring的SpringBoot招聘网站项目(完整源码sql) 博主介绍:多年java开发经验,专注Java开发、定制、远程、文档编写指导等,csdn特邀作者、专注于Java技术领域 作者主页 央顺技术团队 Java毕设项目精品实战案例《1000套》…

ModuleNotFoundError: No module named ‘aitodpycocotools‘

具体不清楚,反正pip下载也下载不了,改为pycocotools后没问题了 解决 分析 是承接之前错误,为了解决keyerror问题,pip install -v -e .重新安装mmdet,导致的

PostgreSQL 安装部署

文章目录 一、PostgreSQL部署方式1.Yum方式部署2.RPM方式部署3.源码方式部署4.二进制方式部署5.Docker方式部署 二、PostgreSQL部署1.Yum方式部署1.1.部署数据库1.2.连接数据库 2.RPM方式部署2.1.部署数据库2.2.连接数据库 3.源码方式部署3.1.准备工作3.2.编译安装3.3.配置数据…

[LeetCode][LCR151]彩灯装饰记录 III——队列

题目 LCR 151. 彩灯装饰记录 III 一棵圣诞树记作根节点为 root 的二叉树,节点值为该位置装饰彩灯的颜色编号。请按照如下规则记录彩灯装饰结果: 第一层按照从左到右的顺序记录除第一层外每一层的记录顺序均与上一层相反。即第一层为从左到右&#xff0c…

Spring AOP底层原理

目录 代理模式 静态代理 动态代理 1. JDK动态代理 创建⼀个代理对象并使用 2. CGLIB动态代理 SpringAOP底层原理面试 代理模式 Spring AOP是基于动态代理模式来实现的 代理模式:静态代理模式动态代理模式 代理模式, 也叫委托模式。 定义:为其…

Mysql - is marked as crashed and should be repaired

概述 上周发生了一个Mysql报错的问题,今天有时间整理一下产生的原因和来龙去脉,Mysql的版本是5.5,发生错误的表存储引擎都是MyISAM,产生的报错信息是Table xxxxxx is marked as crashed and should be repaired。 定位问题 产生的后果是Nginx服务没有…

Util工具类功能设计与类设计(http模块一)

目录 类功能 类定义 类实现 编译测试 Split分割字符串测试 ReadFile读取测试 WriteFile写入测试 UrlEncode编码测试 UrlDecode编码测试 StatuDesc状态码信息获取测试 ExtMime后缀名获取文件mime测试 IsDirectory&IsRegular测试 VaildPath请求路径有效性判断测…