ACSNet分割模型搭建

news2024/12/25 9:06:27

原论文:Adaptive Context Selection for Polyp Segmentation
源码:https://github.com/ReaFly/ACSNet.

直接步入正题~~~

一、基础模块

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels,
                              kernel_size=kernel_size,
                              stride=stride,
                              padding=padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=1, padding=1):
        super(DecoderBlock, self).__init__()

        self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size,
                               stride=stride, padding=padding)

        self.conv2 = ConvBlock(in_channels // 4, out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.upsample(x)
        return x


class SideoutBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(SideoutBlock, self).__init__()

        self.conv1 = ConvBlock(in_channels, in_channels // 4, kernel_size=kernel_size,
                               stride=stride, padding=padding)

        self.dropout = nn.Dropout2d(0.1)

        self.conv2 = nn.Conv2d(in_channels // 4, out_channels, 1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.dropout(x)
        x = self.conv2(x)

        return x

二、LCA模块 

class LCA(nn.Module):
    def __init__(self):
        super(LCA, self).__init__()

    def forward(self, x, pred): #x:256,16,16  pre:1,16,16
        residual = x
        score = torch.sigmoid(pred)
        dist = torch.abs(score - 0.5)
        att = 1 - (dist / 0.5)
        att_x = x * att #256,16,16
        out = att_x + residual #256,16,16

        return out

三、GCM模块

class GCM(nn.Module):
    def __init__(self, in_channels, out_channels): #in_channels=512, out_channels=64
        super(GCM, self).__init__()
        pool_size = [1, 3, 5]
        out_channel_list = [256, 128, 64, 64]
        upsampe_scale = [2, 4, 8, 16]
        GClist = []
        GCoutlist = []
        for ps in pool_size:
            GClist.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(ps),
                nn.Conv2d(in_channels, out_channels, 1, 1),
                nn.ReLU(inplace=True)))
        GClist.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1),
            nn.ReLU(inplace=True),
            NonLocalBlock(out_channels)))
        self.GCmodule = nn.ModuleList(GClist)
        for i in range(4):
            GCoutlist.append(nn.Sequential(nn.Conv2d(out_channels * 4, out_channel_list[i], 3, 1, 1),
                                           nn.ReLU(inplace=True),
                                           nn.Upsample(scale_factor=upsampe_scale[i], mode='bilinear')))
        self.GCoutmodel = nn.ModuleList(GCoutlist)

    def forward(self, x): # 输入x: 512,8,8
        xsize = x.size()[2:]
        global_context = []
        for i in range(len(self.GCmodule) - 1): #range(3)
            global_context.append(F.interpolate(self.GCmodule[i](x), xsize, mode='bilinear', align_corners=True))
        global_context.append(self.GCmodule[-1](x))
        global_context = torch.cat(global_context, dim=1)

        output = []
        for i in range(len(self.GCoutmodel)): #range(4)
            output.append(self.GCoutmodel[i](global_context))

        return output

四、NonLocalBlock模块

class NonLocalBlock(nn.Module):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): #in_channels=64
        super(NonLocalBlock, self).__init__()

        self.sub_sample = sub_sample

        self.in_channels = in_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1

        self.g = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0)

        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                          kernel_size=1, stride=1, padding=0),
                nn.BatchNorm2d(self.in_channels)
            )
            # nn.init.constant_(tensor, val):基于输入参数(val)初始化输入张量tensor,即tensor的值均初始化为val。
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv2d(in_channels=self.inter_channels, out_channels=self.in_channels,
                               kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)

        self.theta = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                               kernel_size=1, stride=1, padding=0)
        self.phi = nn.Conv2d(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=1, stride=1, padding=0)

        if sub_sample:
            self.g = nn.Sequential(self.g, nn.MaxPool2d(kernel_size=(2, 2)))
            self.phi = nn.Sequential(self.phi, nn.MaxPool2d(kernel_size=(2, 2)))

    def forward(self, x): #bs,64,8,8

        batch_size = x.size(0)

        # bs,64,8,8->bs,32,4,4->bs,32,16
        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1) #bs,16,32

        # bs,64,8,8->bs,32,8,8->bs,32,64
        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1) #bs,64,32

        # bs,64,8,8->bs,32,4,4->bs,32,16
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x) #bs,64,16
        f_div_C = F.softmax(f, dim=-1) #bs,64,16

        y = torch.matmul(f_div_C, g_x) #bs,64,32
        y = y.permute(0, 2, 1).contiguous() #bs,32,64
        y = y.view(batch_size, self.inter_channels, *x.size()[2:]) #bs,32,8,8
        W_y = self.W(y) #bs,64,8,8
        z = W_y + x #bs,64,8,8

        return z

五、SE模块

class SELayer(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)

六、ASM模块

class ASM(nn.Module):
    def __init__(self, in_channels, all_channels):
        super(ASM, self).__init__()
        self.non_local = NonLocalBlock(in_channels)
        self.selayer = SELayer(all_channels)

    def forward(self, lc, fuse, gc):
        fuse = self.non_local(fuse)
        fuse = torch.cat([lc, fuse, gc], dim=1)
        fuse = self.selayer(fuse)

        return fuse

七、ACSNet网络结构

class ACSNet(nn.Module):
    def __init__(self, num_classes):
        super(ACSNet, self).__init__()

        self.resnet = resnet34(pretrained=False)
        
        # Encoder
        self.encoder1_conv = self.resnet.conv1
        self.encoder1_bn = self.resnet.bn1
        self.encoder1_relu = self.resnet.relu
        self.maxpool = self.resnet.maxpool
        self.encoder2 = self.resnet.layer1
        self.encoder3 = self.resnet.layer2
        self.encoder4 = self.resnet.layer3
        self.encoder5 = self.resnet.layer4

        # Decoder
        self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
        self.decoder4 = DecoderBlock(in_channels=1024, out_channels=256)
        self.decoder3 = DecoderBlock(in_channels=512, out_channels=128)
        self.decoder2 = DecoderBlock(in_channels=256, out_channels=64)
        self.decoder1 = DecoderBlock(in_channels=192, out_channels=64)

        self.outconv = nn.Sequential(ConvBlock(64, 32, kernel_size=3, stride=1, padding=1),
                                      nn.Dropout2d(0.1),
                                      nn.Conv2d(32, num_classes, 1))

        # Sideout
        self.sideout2 = SideoutBlock(64, 1)
        self.sideout3 = SideoutBlock(128, 1)
        self.sideout4 = SideoutBlock(256, 1)
        self.sideout5 = SideoutBlock(512, 1)

        # local context attention module
        self.lca1 = LCA()
        self.lca2 = LCA()
        self.lca3 = LCA()
        self.lca4 = LCA()

        # global context module
        self.gcm = GCM(512, 64)

        # adaptive selection module
        self.asm4 = ASM(512, 1024)
        self.asm3 = ASM(256, 512)
        self.asm2 = ASM(128, 256)
        self.asm1 = ASM(64, 192)

    def forward(self, x):
        # x: 3,256,256
        e1 = self.encoder1_conv(x)  # 64,128,128
        e1 = self.encoder1_bn(e1)
        e1 = self.encoder1_relu(e1)
        e1_pool = self.maxpool(e1) # 64,64,64
        e2 = self.encoder2(e1_pool) # 64,64,64
        e3 = self.encoder3(e2)  # 128,32,32
        e4 = self.encoder4(e3)  # 256,16,16
        e5 = self.encoder5(e4)  # 512,8,8

        global_contexts = self.gcm(e5)
        # print(global_contexts[0].shape) [1, 256, 16, 16]
        # print(global_contexts[1].shape) [1, 128, 32, 32]
        # print(global_contexts[2].shape) [1, 64, 64, 64]
        # print(global_contexts[3].shape) [1, 64, 128, 128]
        
        d5 = self.decoder5(e5) # 512,8,8->512,16,16
        out5 = self.sideout5(d5) # 1,16,16
        lc4  = self.lca4(e4, out5) # 256,16,16
        gc4 = global_contexts[0]
        comb4 = self.asm4(lc4, d5, gc4) # 1024, 16, 16

        d4 = self.decoder4(comb4) # 256, 32, 32
        out4 = self.sideout4(d4) # 1, 32, 32
        lc3 = self.lca3(e3, out4) # 128, 32, 32
        gc3 = global_contexts[1]
        comb3 = self.asm3(lc3, d4, gc3) # 512,32,32


        d3 = self.decoder3(comb3)  # 128,64,64
        out3 = self.sideout3(d3) # 1,64,64
        lc2 = self.lca2(e2, out3) # 64,64,64
        gc2 = global_contexts[2]
        comb2 = self.asm2(lc2, d3, gc2)  # 256, 64, 64

        d2 = self.decoder2(comb2)  # 64,128,128
        out2 = self.sideout2(d2) # 1,128,128
        lc1 = self.lca1(e1, out2) # 64,128,128
        gc1 = global_contexts[3]
        comb1 = self.asm1(lc1, d2, gc1) # 192,128,128


        d1 = self.decoder1(comb1)  # 64,256,256
        out1 = self.outconv(d1)  # num_classes,256,256

        # return out1
        return torch.sigmoid(out1), torch.sigmoid(out2), torch.sigmoid(out3), \
            torch.sigmoid(out4), torch.sigmoid(out5)


if __name__ == '__main__':
    input_tensor = torch.randn((1, 3, 256, 256))
    model = ACSNet(num_classes=4)
    # out1 = model(input_tensor)
    # print(out1.shape)
    o1,o2,o3,o4,o5 = model(input_tensor)
    print(o1.shape,o2.shape,o3.shape,o4.shape,o5.shape)

 八、损失函数(Deep Supervision Loss)

def DeepSupervisionLoss(pred, gt):
    d0, d1, d2, d3, d4 = pred[0:]

    criterion = BceDiceLoss()

    loss0 = criterion(d0, gt) #256,256
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss1 = criterion(d1, gt) #128,128
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss2 = criterion(d2, gt) #64,64
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss3 = criterion(d3, gt) #32,32
    gt = F.interpolate(gt, scale_factor=0.5, mode='bilinear', align_corners=True)
    loss4 = criterion(d4, gt) #16,16

    return loss0 + loss1 + loss2 + loss3 + loss4

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/695118.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

linux下卸载vmware

linux下卸载vmware 一、卸载二、vmware安装及使用 &emps;如果linux下安装了vmware系列产品,需要卸载或重装的,本文介绍了详细的卸载方法: 一、卸载 一般的发行版都不会带有vmware,所以通常是下载安装包来安装。这里主要说的就…

谷歌chrome浏览器所有历史版本下载及selenium自动化控制插件资源分享

使用python selenium做网页自动化开发的小伙伴经常需要用到google chrome浏览器以及chromedriver插件。 谷歌浏览器所有历史版本下载链接: chrome历史版本,点击下载 chromedriver插件下载地址:下载链接1:点击下载下载链接2…

一文了解python中pip的使用

目录 🍒pip的作用 🍒pip的使用 🍒运行python 🍓1.终端运行 🍓2.运行python文件 🍓3.pycharm 🍒IDE的概念 🦐博客主页:大虾好吃吗的博客 🦐专栏地址&#xff1…

scGen perturbation response prediction

如何利用scGen进行扰动响应预测 scGen是什么 scGen is a generative model to predict single-cell perturbation response across cell types, studies and species (Nature Methods, 2019). scGen is implemented using the scvi-tools framework. 文章传送门 Original tut…

Axure教程——100内的随机加减

本文介绍用Axure制作随机加减数效果 效果 预览地址https://ik31ey.axshare.com 制作 一、需要的元件: 2个动态面板,2个文本框,4个矩形,1个图片 二、制作过程 1、随机加 拖入一个矩形元件,命名为“被加数”&#xff…

MQ选手终极对决:比较几个主流MQ实现分布式事务的方案

1、MQ实现分布式事务 在分布式事务的实现中,ACK消息和半消息是两种实现分布式事务的两种不同机制,它们的共同点是都使用了两阶段提交的机制,但是它们的实现细节和适用场景有所不同。具有一些区别和特点。下面是ACK消息和半消息在实现分布式事…

暑期旅游市场火热,旅游行业如何做好邮件营销

随着2023年全国高考结束,以及各大高校毕业生也将陆续告别校园,许多学生都会用一场具有庆贺、欢聚意义的毕业旅游来为自己的高中生涯或大学生涯画上完美句号。携程数据显示,6月9日至8月底,高中和大学毕业生的旅游订单逐渐增多&…

QT 5.9.9 配置使用 MYSQL5.7 数据库

目录 Mysql下载安装 QT 下载安装 编译MYSQL mingw构建方式 msvc构建方式 QT中MYSQL测试使用 因为版权问题,Qt本身不自带Mysql数据库的驱动,因此如果想要借用Qt操作Mysql数据库,需要手动进行编译。 Mysql下载安装 【Qt】 Mysql服务端安装…

抖音seo账号矩阵系统源码搭建服务器内耗量?

一、短视频账号矩阵系统视频库存量技术云端如何处理? 1.抖音SEO账号矩阵系统源码搭建服务器内耗量怎么后端处理 非常低,可以轻松应对大量用户使用。通过服务器的高效运行,系统可以提供稳定的服务,并且快速响应用户的操作。源码搭…

2023年软件测试这个行业怎么样?

今天我要从一个新的角度来论述软件测试行业怎么样。 最近热搜新闻是张雪峰最近抨击的“新闻专业不要报,建议把孩子打晕后,随便选个专业都比新闻好”,重庆大学新闻学教授发文抨击张雪峰偏激言论,眼睛雪亮的人民群众却纷纷站队张雪…

C#复制构造函数学习

通过从另一个对象复制变量或将一个对象的数据复制到另一个对象来创建对象的构造函数称为复制构造函数。 复制构造函数是一个参数化构造函数,包含相同类类型的参数。它的主要用途是将新实例初始化为现有实例的值。 using System;namespace Mytest{class User {publi…

非凸科技金牌赞助“第三届中国Rust开发者大会”

6月17-18日,由Rust中文社区主办的“第三届中国Rust开发者大会”在上海圆满举行。非凸科技作为金牌赞助商,全力协助大会顺利开展,共同为中国 Rustaceans带来一场技术交流盛宴。 本次大会演讲主题内容广泛,涉及编程语言、量化金融、…

springboot mybatis-plus慢sql输出日志,log4jdbc使用

前言 无论使用原生JDBC、mybatis还是hibernate,使用log4j等日志框架可以看到生成的SQL,但是占位符和参数总是分开打印的,不便于分析,显示如下的效果: Log4jdbc 是一个开源 SQL 日志框架,它使用代理模式实现对常用的…

C++之operator重载运算符应用(一百四十一)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

采集发布到WordPress指定文章作者

采集的数据发布到wordpress系统网站,指定发布文章的作者设置方法教程。 目录 1. 获取用户名 2. 对接采集器指定作者 3. 随机作者 1. 获取用户名 进入Wordpress系统后台,点击控制台左侧菜单的【用户】,再点击展开列表的【所有用户】&…

word2vec self-attention transformer diffusion的技术演变

这一段时间大模型的相关进展如火如荼,吸引了很多人的目光;本文从nlp领域入门的角度来总结相关的技术路线演变路线。 1、introduction 自然语言处理(Natural Language Processing),简称NLP。这个领域是通过统计学、数…

使用homebrew安装RabbitMQ3.12.XX版本无法启动的解决方案

使用brew安装RabbitMQ3.12.XX版本遇到无法启动的天坑 首先来看RabbitMQ 3.12.0的新版说明,这也是我为什么无法启动的原因 所需的功能标志 RabbitMQ 3.12.0 将要求在升级前启用 3.11.x 系列版本的所有功能标志、 类似于 3.11.0 要求在 3.9.0 之前引入的所有功能标…

Android SDK file not found: F:\androidSDK\build-tools\34.0.0\aapt

问题表现 执行flutter doctor 的时候,报错Android SDK file not found,很明确的说没有配置 30.0.3 问题解决 首先去报错的SDK路径中排查是否有这个版本。发现有,但是是个空文件夹,所以删除掉该文件夹重新运行 flutter doctor &a…

【大语言模型】5分钟了解预训练、微调和上下文学习

5分钟了解预训练、微调和上下文学习 什么是预训练?什么是微调?什么是上下文学习?相关资料 近年来大语言模型在自然语言理解和生成方面、多模态学习等方面取得了显著进展。这些模型通过 预训练、 微调和 上下文学习的组合来学习。本文将快速…

JWT数字签名与token实现

JWT介绍 官方介绍 JSON Web Token (JWT)是一个开放标准(RFC 7519),它定义了一种紧凑的、自包含的方式,用于作为JSON对象在各方之间安全地传输信息。该信息可以被验证和信任,因为它是数字签名的。 什么时候你应该用JSON Web Token &#xf…