每日Attention学习12——Exterior Contextual-Relation Module

news2024/11/14 21:58:21
模块出处

[ISBI 22] [link] [code] Duplex Contextual Relation Network for Polyp Segmentation


模块名称

Exterior Contextual-Relation Module (ECRM)


模块作用

内存型特征增强模块


模块结构

在这里插入图片描述


模块思想

原文表述:在临床环境中,不同样本之间存在息肉的同步视觉模式。基于这一关键观察,属于所有训练数据的同一语义类的区域特征应该具有上下文关系。因此,我们提出了一种新颖的跨不同样本的上下文关系探索模块。
具体做法则是,对于编码器最后一层得到的全局特征(图中红色方块),进行两次增强:
第一次是直接将全局特征送入一个 1 × 1 1 \times 1 1×1卷积(图中浅紫色部分)以获取一个粗糙分割mask,该mask与全局特征相乘后便能得到过滤掉背景特征的增强特征(图中enqueue左边的部分)。
第二次增强则是基于网络存储的源自其他训练样本的历史上下文信息(图中的Cross-Batch Memory)。即,当前特征与Memory内特征进行Cross Attention操作,从而利用历史经验对当前状态进行补全。


模块代码

代码实现有几个额外要注意的地方:

  • 模块返回的aux_out要进行side supervision监督,以保证准确性;
  • Memory负责维护网络的历史信息,为防止被破坏,这部分信息并不参与梯度更新过程;
  • 在测试阶段,Memory不再更新,直接使用训练所存储的历史信息,这一思想与BatchNorm类似。
import torch
from torch import nn

def conv2d(in_channel, out_channel, kernel_size):
    layers = [
        nn.Conv2d(
            in_channel, out_channel, kernel_size, padding=kernel_size // 2, bias=False
        ),
        nn.BatchNorm2d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


def conv1d(in_channel, out_channel):
    layers = [
        nn.Conv1d(in_channel, out_channel, 1, bias=False),
        nn.BatchNorm1d(out_channel),
        nn.ReLU(),
    ]
    return nn.Sequential(*layers)


class ECRM(nn.Module):
    def __init__(self, bank_size=20, feat_channels=512, num_classes=1):
        super(ECRM, self).__init__()  
        # BANK CONFIG
        self.bank_size = bank_size
        self.register_buffer("bank_ptr", torch.zeros(1, dtype=torch.long))  # memory bank pointer
        self.register_buffer("bank", torch.zeros(self.bank_size, feat_channels, num_classes))  # memory bank
        self.bank_full = False

        # ATTENTION CONFIG
        self.feat_channels = feat_channels
        self.L = nn.Conv2d(feat_channels, num_classes, 1)
        self.X = conv2d(feat_channels, 512, 3)
        self.phi = conv1d(512, 256)
        self.psi = conv1d(512, 256)
        self.delta = conv1d(512, 256)
        self.rho = conv1d(256, 512)
        self.g = conv2d(512 + 512, 512, 1)

    def init(self):
        self.bank_ptr[0] = 0
        self.bank_full = False

    @torch.no_grad()
    def update_bank(self, x):
        ptr = int(self.bank_ptr)
        batch_size = x.shape[0]
        vacancy = self.bank_size - ptr
        if batch_size >= vacancy:
            self.bank_full = True
        pos = min(batch_size, vacancy)
        self.bank[ptr:ptr+pos] = x[0:pos].clone()
        # update pointer
        ptr = (ptr + pos) % self.bank_size
        self.bank_ptr[0] = ptr

    def enhance_by_memory(self, bank, X_flat, X):
        batch, n_class, height, width = X.shape
        # query = S * C
        query = self.phi(bank).squeeze(dim=2)
        # key: = B * C * HW
        key = self.psi(X_flat)
        # logit = HW * S * B (cross image relation)
        logit = torch.matmul(query, key).transpose(0,2)
        # attn = HW * S * B
        attn = torch.softmax(logit, 2)
        # delta = S * C
        delta = self.delta(bank).squeeze(dim=2)
        # attn_sum = B * C * HW
        attn_sum = torch.matmul(attn.transpose(1,2), delta).transpose(1,2)
        # x_obj = B * C * H * W
        X_obj = self.rho(attn_sum).view(batch, -1, height, width)
        concat = torch.cat([X, X_obj], 1)
        out = self.g(concat)
        return out
    
    def get_prototype(self, input):
        L = self.L(input)
        aux_out = L
        batch, n_class, _, _ = L.shape
        l_flat = L.view(batch, n_class, -1)
        M = torch.softmax(l_flat, -1)
        X = self.X(input)
        channel = X.shape[1]
        X_flat = X.view(batch, channel, -1)
        f_k = (M @ X_flat.transpose(1, 2)).transpose(1, 2)
        return aux_out, f_k, X_flat, X

    def forward(self, x, flag='train'):
        # x [3, 512, 11, 11]
        # patch [3, 512, 1]
        aux_out, patch, feats_flat, feats = self.get_prototype(x)
        if flag == 'train':
            self.update_bank(patch)
            ptr = int(self.bank_ptr)
            if self.bank_full == True:
                out = self.enhance_by_memory(self.bank, feats_flat, feats)
            else:
                out = self.enhance_by_memory(self.bank[0:ptr], feats_flat, feats)
        elif flag == 'test':
            out = self.enhance_by_memory(patch, feats_flat, feats)
        return out, aux_out
    
if __name__ == '__main__':
    x = torch.randn([3, 512, 11, 11])
    ecrm = ECRM()
    out = ecrm(x)
    print(out[0].shape)  # 3, 512, 11, 11
    print(out[1].shape)  # 3, 1, 11, 11

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1946000.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python算法基础:解锁冒泡排序与选择排序的奥秘

在数据处理和算法设计中,排序是一项基础且重要的操作。本文将介绍两种经典的排序算法:冒泡排序(Bubble Sort)和选择排序(Selection Sort)。我们将通过示例代码来演示这两种算法如何对列表进行升序排列。 一…

BGP选路之Local Preference

原理概述 当一台BGP路由器中存在多条去往同一目标网络的BGP路由时,BGP协议会对这些BGP路由的属性进行比较,以确定去往该目标网络的最优BGP路由。BGP首先比较的是路由信息的首选值(PrefVal),如果 PrefVal相同,就会比较本…

python-NLP:2词性标注与命名实体识别

文章目录 词性标注命名实体识别时间命名实体(规则方法)CRF 命名实体识别方法 词性标注 词性是词汇基本的语法属性,通常也称为词类。词性标注是在给定句子中判定每个词的语法范畴,确定其词性并加以标注的过程。例如,表示…

爱回收严选买的二手iPad Air 4已经使用一周啦!

有多少人是跟我一样,手里一旦有点小钱就留不住,只想花出去的? 本24届应届生目前刚开始实习工作,虽然工资低的可怜,但是比起大学时期还是宽裕了不少。 于是发完工资的我就非常想消费!而我最近最想要的就是…

Fedora40安装telnet-server启用telnet服务

Fedora40安装telnet-server启用telnet服务 安装 telnet-server sudo yum install telnet-server或 sudo dnf install telnet-server启用服务 fedora40 或 CentosStream9 不能用 yum或dnf安装xinetd, telnet-server 的服务名为: telnet.socket 启用 telnet.socket.service …

Kithara和Halcon (二)

Kithara使用Halcon QT 进行二维码实时识别 目录 Kithara使用Halcon QT 进行二维码实时识别Halcon 简介以及二维码检测的简要说明Halcon 简介Halcon的二维码检测功能 Qt应用框架简介项目说明关键代码抖动测试测试平台:测试结果: 开源源码 Halcon 简介以…

C++与C中,由函数形参test(int *a)引出的问题

文章参考来源: 1.c函数中形参为引用的情况;C中a和&a的区别 描述: 最近在看循环单链表时,看到有篇文章中,链表初始化函数为图下,我在想,这个函数形参(类似 "int * & a"一样)到…

数据结构(二叉树-1)

文章目录 一、树 1.1 树的概念与结构 1.2 树的相关术语 1.3 树的表示 二、二叉树 2.1 二叉树的概念与结构 2.2特殊的二叉树 满二叉树 完全二叉树 2.3 二叉树的存储结构 三、实现顺序结构二叉树 3.1 堆的概念与结构 3.2 堆的实现 Heap.h Heap.c 默认初始化堆 堆的销毁 堆的插入 …

关于使用宝兰德bes中间件进行windows部署遇到的问题——license不存在

报错信息 日志文件中是这么报错的 遇到的具体情况: 实例按照**的文档手册正常步骤下去节点部署的时候没有报错,成功启动,但是日志里会有报错信息,也是license不存在实例创建的时候失败了,报错信息如下所示 解决方法…

Gitops-Argo-Cli安装与使用

一、安装Argo-Cli工具 Release v2.9.21 argoproj/argo-cd GitHub **选择合适的符合你操作系统以及CPU架构的二进制文件 #依v2.9.21-X86-64-Linux操作系统为例 wget https://github.com/argoproj/argo-cd/releases/download/v2.9.21/argocd-linux-amd64 #添加执行权限并且移…

昇思25天学习打卡营第19天|生成式-DCGAN生成漫画头像

打卡 目录 打卡 GAN基础原理 DCGAN原理 案例说明 数据集操作 数据准备 数据处理和增强 部分训练数据的展示 构造网络 生成器 生成器代码 ​编辑 判别器 判别器代码 模型训练 训练代码 结果展示(3 epoch) 模型推理 GAN基础原理 原理介…

AV1技术学习:Loop Restoration Filter

环路恢复滤波器(restoration filter)适用于64 64、128 128 或 256 256 像素块单元,称为 loop restoration units (LRUs)。每个单元可以独立选择是否跳过滤波、使用维纳滤波器(Wiener filter)或使用自导滤波器&#…

AM62x和rk3568的异同点

AM62x 和 RK3568 是两款不同的处理器,分别来自 Texas Instruments(TI)和 Rockchip。它们在设计目标、架构、性能和应用领域等方面存在一些异同。以下是这两款处理器的对比: 1. 基本架构 AM62x: 架构:基于…

【云原生】Kubernetes 中的 PV 和 PVC 介绍、原理、用法及实战案例分析

✨✨ 欢迎大家来到景天科技苑✨✨ 🎈🎈 养成好习惯,先赞后看哦~🎈🎈 🏆 作者简介:景天科技苑 🏆《头衔》:大厂架构师,华为云开发者社区专家博主,…

CRM客户管理系统是什么?如何利用CRM盘活老客户?

人与人之间的差距,不仅在于业务能力的高低,更在于如何高效地管理客户、建立深厚的客户关系。在这个“内卷化”严重的时代,借助工具来管理客户成为必不可少的流程。如果你保持怀疑态度,那我们先来聊聊。 客户管理是什么&#xff1…

HormonyOs之 路由简单跳转

Navigation路由相关的操作都是基于页面栈NavPathStack提供的方法进行,每个Navigation都需要创建并传入一个NavPathStack对象,用于管理页面。主要涉及页面跳转、页面返回、页面替换、页面删除、参数获取、路由拦截等功能。 Entry Component struct Index …

探索 Electron:快捷键与剪切板操作

Electron是一个开源的桌面应用程序开发框架,它允许开发者使用Web技术(如 HTML、CSS 和 JavaScript)构建跨平台的桌面应用程序,它的出现极大地简化了桌面应用程序的开发流程,让更多的开发者能够利用已有的 Web 开发技能…

Mysql9安装

目录 一、下载mysql 二、安装 三、配置mysql环境变量 四、mysql初始化和启动 1.以管理员身份运行cmd 2.cd到mysql的安装目录 3.初始化mysql的数据库 4.为Windows系统安装MySQL服务 5.查看一下名为mysql的服务: 6.启动MySQL服务 五、附录 1.系统变量还在&…

grafana对接zabbix数据展示

目录 1、初始化、安装grafana 2、浏览器访问 3、安装zabbix 4、zabbix数据对接grafana 5、如何导入模板? ① 设置键值 ② 在zabbix web端完成自定义监控项 ③ garafana里添加nginx上面的的三个监控项 6、如何自定义监控项? 以下实验沿用上一篇z…

arm环境下构建Flink的Docker镜像

准备工作 资源准备 按需下载 flink,我的是1.17.2版本。官方说1.13版本之后的安装包兼容了arm架构,所以直接下载就行。 如需要cdc组件,提前下载好。 服务器准备 可在某云上购买arm服务器,2c/4g即可,按量付费。 带宽…