MKDCNet分割模型搭建

news2024/9/25 13:19:47

原论文:https://arxiv.org/abs/2206.06264v1
源码:https://github.com/nikhilroxtomar/MKDCNet

直接步入正题~~~

一、基础模块

class Conv2D(nn.Module):
    def __init__(self, in_c, out_c, kernel_size=3, padding=1, dilation=1, bias=False, act=True):
        super().__init__()
        self.act = act

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_c, out_c,
                kernel_size=kernel_size,
                padding=padding,
                dilation=dilation,
                bias=bias
            ),
            nn.BatchNorm2d(out_c)
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        if self.act == True:
            x = self.relu(x)
        return x

class residual_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.network = nn.Sequential(
            Conv2D(in_c, out_c),
            Conv2D(out_c, out_c, kernel_size=1, padding=0, act=False)

        )
        self.shortcut = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x_init):
        x = self.network(x_init)
        s = self.shortcut(x_init)
        x = self.relu(x+s)
        return x

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16): #in_planes=96
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x): #2,96,128,128
        # 2,96,128,128 -> 2,96,1,1 -> 2,6,1,1 -> 2,96,1,1
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        # 2,96,128,128 -> 2,96,1,1 -> 2,6,1,1 -> 2,96,1,1
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out #2,96,1,1
        return self.sigmoid(out)

二、ChannelAttention和SpatialAttention

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16): #in_planes=96
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc1   = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2   = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x): #2,96,128,128
        # 2,96,128,128 -> 2,96,1,1 -> 2,6,1,1 -> 2,96,1,1
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        # 2,96,128,128 -> 2,96,1,1 -> 2,6,1,1 -> 2,96,1,1
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out #2,96,1,1
        return self.sigmoid(out)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x): #2,96,128,128
        avg_out = torch.mean(x, dim=1, keepdim=True) #2,1,128,128
        max_out, _ = torch.max(x, dim=1, keepdim=True) #2,1,128,128
        x = torch.cat([avg_out, max_out], dim=1) #2,2,128,128
        x = self.conv1(x) #2,1,128,128
        return self.sigmoid(x)

三、encoder模块

class encoder(nn.Module):
    def __init__(self, ch):
        super().__init__()

        """ ResNet50 """
        backbone = resnet50()
        self.layer0 = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu)
        self.layer1 = nn.Sequential(backbone.maxpool, backbone.layer1)
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3

        """ Reduce feature channels """
        self.c1 = Conv2D(64, ch)
        self.c2 = Conv2D(256, ch)
        self.c3 = Conv2D(512, ch)
        self.c4 = Conv2D(1024, ch)

    def forward(self, x):
        """ Backbone: ResNet50 """
        x0 = x
        x1 = self.layer0(x0)    ## [-1, 64, h/2, w/2] 2,64,128,128
        x2 = self.layer1(x1)    ## [-1, 256, h/4, w/4] 2,256,64,64
        x3 = self.layer2(x2)    ## [-1, 512, h/8, w/8] 2,512,32,32
        x4 = self.layer3(x3)    ## [-1, 1024, h/16, w/16] 2,1024,16,16

        c1 = self.c1(x1) #2,96,128,128
        c2 = self.c2(x2) #2,96,64,64
        c3 = self.c3(x3) #2,96,32,32
        c4 = self.c4(x4) #2,96,16,16

        return c1, c2, c3, c4

四、MKDC模块

class multikernel_dilated_conv(nn.Module):
    def __init__(self, in_c, out_c):  #in_c=96, out_c=96
        super().__init__()
        self.relu = nn.ReLU(inplace=True)

        self.c1 = Conv2D(in_c, out_c, kernel_size=1, padding=0)
        self.c2 = Conv2D(in_c, out_c, kernel_size=3, padding=1)
        self.c3 = Conv2D(in_c, out_c, kernel_size=7, padding=3)
        self.c4 = Conv2D(in_c, out_c, kernel_size=11, padding=5)
        self.s1 = Conv2D(out_c*4, out_c, kernel_size=1, padding=0)

        self.d1 = Conv2D(out_c, out_c, kernel_size=3, padding=1, dilation=1)
        self.d2 = Conv2D(out_c, out_c, kernel_size=3, padding=3, dilation=3)
        self.d3 = Conv2D(out_c, out_c, kernel_size=3, padding=7, dilation=7)
        self.d4 = Conv2D(out_c, out_c, kernel_size=3, padding=11, dilation=11)
        self.s2 = Conv2D(out_c*4, out_c, kernel_size=1, padding=0, act=False)
        self.s3 = Conv2D(in_c, out_c, kernel_size=1, padding=0, act=False)

        self.ca = ChannelAttention(out_c)
        self.sa = SpatialAttention()

    def forward(self, x):  #假设x.shape [2,96,128,128]
        x0 = x
        x1 = self.c1(x) #2,96,128,128
        x2 = self.c2(x) #2,96,128,128
        x3 = self.c3(x) #2,96,128,128
        x4 = self.c4(x) #2,96,128,128
        x = torch.cat([x1, x2, x3, x4], axis=1) #2,96*4,128,128
        x = self.s1(x) #2,96,128,128

        x1 = self.d1(x) #2,96,128,128
        x2 = self.d2(x) #2,96,128,128
        x3 = self.d3(x) #2,96,128,128
        x4 = self.d4(x) #2,96,128,128
        x = torch.cat([x1, x2, x3, x4], axis=1) #2,96*4,128,128
        x = self.s2(x) #2,96,128,128
        s = self.s3(x0) #2,96,128,128

        x = self.relu(x+s) #2,96,128,128
        # 2,96,1,1 -> 2,96,128,128
        x = x * self.ca(x)
        # 2,1,128,128 -> 2,96,128,128
        x = x * self.sa(x)

        return x #2,96,128,128

五、MFF模块

class multiscale_feature_fusion(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up_2 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)

        self.c1 = Conv2D(in_c, out_c)
        self.c2 = Conv2D(out_c+in_c, out_c)
        self.c3 = Conv2D(in_c, out_c)
        self.c4 = Conv2D(out_c+in_c, out_c)

        self.ca = ChannelAttention(out_c)
        self.sa = SpatialAttention()

    def forward(self, f1, f2, f3): #f1:2,96,32,32, f2:2,96,64,64, f3:2,96,128,128
        x1 = self.up_2(f1) #2,96,64,64
        x1 = self.c1(x1) #2,96,64,64
        x1 = torch.cat([x1, f2], axis=1) #2,192,64,64
        x1 = self.up_2(x1) #2,192,128,128
        x1 = self.c2(x1) #2,96,128,128
        x1 = torch.cat([x1, f3], axis=1) #2,192,128,128
        x1 = self.up_2(x1) #2,192,256,256
        x1 = self.c4(x1) #2,96,256,256

        x1 = x1 * self.ca(x1) #2,96,256,256
        x1 = x1 * self.sa(x1) #2,96,256,256

        return x1

六、decoder模块

class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.r1 = residual_block(in_c[0]+in_c[1], out_c)
        self.r2 = residual_block(out_c, out_c)

    def forward(self, x, s): #假设 x:2,96,16,16  s:2,96,32,32
        x = self.up(x) #2,96,32,32
        x = torch.cat([x, s], axis=1) #2,192,32,32
        x = self.r1(x) #2,96,32,32
        x = self.r2(x) #2,96,32,32
        return x

七、整体网络结构

class DeepSegNet(nn.Module):
    def __init__(self):
        super().__init__()

        """ Encoder """
        self.encoder = encoder(96)

        """ MultiKernel Conv + Dilation """
        self.c1 = multikernel_dilated_conv(96, 96)
        self.c2 = multikernel_dilated_conv(96, 96)
        self.c3 = multikernel_dilated_conv(96, 96)
        self.c4 = multikernel_dilated_conv(96, 96)

        """ Decoder """
        self.d1 = decoder_block([96, 96], 96)
        self.d2 = decoder_block([96, 96], 96)
        self.d3 = decoder_block([96, 96], 96)

        """ Multiscale Feature Fusion """
        self.msf = multiscale_feature_fusion(96, 96)

        """ Output """
        self.y = nn.Conv2d(96, 1, kernel_size=1, padding=0)

    def forward(self, image): #image:2, 3, 256, 256
        s0 = image
        s1, s2, s3, s4 = self.encoder(image)

        x1 = self.c1(s1)
        x2 = self.c2(s2)
        x3 = self.c3(s3)
        x4 = self.c4(s4)

        d1 = self.d1(x4, x3)
        d2 = self.d2(d1, x2)
        d3 = self.d3(d2, x1)

        x = self.msf(d1, d2, d3)
        y = self.y(x)

        return y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/696336.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2.8C++继承和组合

C 继承和组合 C中的继承和组合都是面向对象的重要概念,它们可以帮助我们构建更加灵活和可扩展的程序。 继承是一种机制,它允许我们创建一个新的类,该类从现有的类中继承属性和方法。 在C中,我们使用关键字 class 或 struc 来定…

华为云obs桶授权

登录华为云控制台 https://auth.huaweicloud.com/authui/login.html?servicehttps://console.huaweicloud.com/console/#/login 进入并行文件系统页面 进入桶 添加ACL授权,填写被授权账号ID,点击确定即可

IMX6ULL系统移植篇-镜像烧写说明

一. 镜像烧写简介 之前一篇文章学习了 阿尔法开发板烧写镜像的方法。 即将 镜像烧写到 Nand-Flash内部,设备最终从 Nand-Flash启动。说明博文如下: IMX6ULL系统移植篇-镜像烧写方法_凌肖战的博客-CSDN博客 二. 镜像烧写说明 之前文章说明了 使用 mfg…

基于vue+Element Table 表格的封装

项目场景&#xff1a; 项目场景&#xff1a;需要频繁使用列表进行呈现数据&#xff0c;不可能每次都写一个表格&#xff0c;可以将表格封装为一个组件&#xff0c;在需要使用时可以直接调用。 效果展示&#xff1a; 项目结构&#xff1a; 具体实现&#xff1a; Table.vue <…

总结linux防火墙firewall端口开通步骤

之前开通过服务器端口&#xff0c;在这里也记录和分享一下。 Step1:检查白名单&#xff1a; sudo firewall-cmd --list-port step2:添加8080端口到白名单 [user zhangsan ~]$ sudo firewall-cmd --zonepublic --add-port8080/tcp --permanent Success Step3&#xff1a;r…

UR5机器人示教器使用——可视化控制部分(非编程)

感谢董青云师兄教我使用示教器 1.UR5机器人示教器 问师兄 3楼 UR-robotic 的控制 示教器相关内容&#xff08;UR5机器人的控制&#xff0c;有UR机器人的仿真环境&#xff0c;需要在虚拟器上运行&#xff09; 1.示教器上的控制有正逆控制&#xff1a;逆向运动学通常用于计算机…

同态加密的类型,同态加密示例

目录 什么是同态加密 同态加密的类型 同态加密示例 什么是同态加密 同态加密&#xff08;Homomorphic Encryption&#xff09;是指将原始数据经过同态加密后&#xff0c;对得到的密文进行特定的运算&#xff0c;然后将计算结果再进行同态解密后得到的明文等价于原始明文数据…

DJI AIR 2S

一、注意事项 注意&#xff1a; 1、侧飞时需要注意&#xff0c;没有侧向避障 2、返航高度设置&#xff0c;应高于飞行区域高楼高度&#xff08;如269m&#xff09; 3、遥控与飞机之间不能有建筑物遮挡&#xff0c;如果出现信号弱&#xff08;上升高度会改善信息&#xff09; 4、…

doris on k8s 的安装部署

官方文档 1. 按照官网提供地址下载部署文件 2. 修改内核配置 sysctl -w vm.max_map_count20000003. 根据服务器环境&#xff0c;修改doris_be.yml文件。 # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the…

深度学习编译器对比:The Deep Learning Compiler A Comprehensive Survey

参考&#xff1a;The Deep Learning Compiler: A Comprehensive Survey 记录几种深度学习编译器的功能和性能的对比&#xff1b; TVM在CPU和GPU的表现最好&#xff1b; MobileNet:TVM在conv、linear、expand表现最好&#xff1b;XLA在dewise的表现最好&#xff1b;

赛效:如何在线做图表

1&#xff1a;打开并登录图表秀&#xff0c;点击“我的模板”菜单里的“新建图表”。 2&#xff1a;根据自己的需要&#xff0c;在右侧的模板里选择一个。图表编辑区域里&#xff0c;会自动出现刚才点击的图表。 3&#xff1a;我们可以直接在右侧区域里编辑图标属性&#xff0c…

用户实操 | GBase 8a MPP Cluster慢SQL分析排查和优化方法

本期供稿 | 中国农业银行研发中心 蔡鹍鹏 01 排查和优化方法 SQL任务历史性能对比分析&#xff1a; 通过开启GBase 8a的audit_log审计日志&#xff0c;可以连续收集周期性任务的执行时间&#xff0c;通过对比相同SQL任务历史执行时长可以判定相同任务SQL长周期内的执行耗时趋…

【Java】如何在 Java 中连接字符串

本文仅供学习参考&#xff01; 字符串连接可以定义为将两个或多个字符串连接在一起以形成新字符串的过程。大多数编程语言至少提供一种连接字符串的方法。Java 为您提供了多种选择&#xff0c;包括&#xff1a; ****运算符**String.concat()**方法StringBuilder类StringBuffer…

LeetCode·每日一题·1186. 删除一次得到子数组最大和·动态规划

作者&#xff1a;小迅 链接&#xff1a;https://leetcode.cn/problems/maximum-subarray-sum-with-one-deletion/solutions/2321919/dong-tai-gui-hua-zhu-shi-chao-ji-xiang-x-cwvs/ 来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 著作权归作者所有。商业转载请联系作…

XILINX 7系列FPGA Dedicated Configuration Bank功能详解

&#x1f3e1;《Xilinx FPGA开发指南》 目录 1&#xff0c;概述2&#xff0c;功能详解2.1&#xff0c;DXP_0与DXN_02.2&#xff0c;VCCBATT_02.3&#xff0c;INIT_B_02.4&#xff0c;M0_0&#xff0c;M1_0&#xff0c;M2_02.5&#xff0c;TDI,TDO,TMS,TCK2.6&#xff0c; VCCAD…

【Unity开发小技巧】UnityWebGL打包本地浏览器运行查看

目录 一.前言&#xff1a; 二.WebGL打包 三.配置web.config&#xff08;重要&#xff09; 四.部署IIS 五.测试 一.前言&#xff1a; 正常打包WebGL后在浏览器直接运行会报以下这个错&#xff1a; It seems your browser does not support running Unity WebGL content fr…

【效率工具】Windows 10 终端自动补全、智能提示

1. 安装PSReadLine 2.1.0 Install-Module PSReadLine -RequiredVersion 2.1.02. 检查是否存在配置文件 Test-path $profile创建配置文件&#xff08;不存在的话&#xff09; New-item –type file –force $profile3. 编辑配置文件 notepad $profile4. 运行该指令后退出终端…

rancher-import-k8s集群

一、 二、 三、 四、 到k8s 节点服务器上执行&#xff1a; 其实在&#xff1a;https://192.168.31.105:8443/v3/import/fgmt2r88wn4xvkm9n88gnshhhb8l976n6rpdvgz79r6rsfhlljnsxn_c-m-kq6c2fvn.yaml 里面下载了镜像 我们可以先下载镜像&#xff1a; docker pull rancher/ranc…

CSDN-AI小组2023-半年-研发总结

目录 1.丐版「大模型」&#xff0c;Proof of concept2. LLM和AIGC的各种综述3. 基于Embedding的应用&#xff0c;问答&#xff0c;AI编程4. 评论区的AI助手5. 结合AIGC的各种数据自动计算6. 个性化推荐的系统重构7. 基于AIGC的个性化博客创作鼓励8. 博客质量分V5: 可解释性计算…

java基础之super

当父类拥有一个带参的构造方法时&#xff0c;子类要有一个带有相同类型参数的构造方法&#xff0c;并且第一行使用super&#xff08;参数&#xff09;来接受&#xff0c;否则会报错 上图是一个类 Two&#xff0c;拥有一个带String类型参数的构造方法。 上图是一个类One&#x…