每日Attention学习20——Group Shuffle Attention

news2025/2/7 0:44:57
模块出处

[MICCAI 24] [link] LB-UNet: A Lightweight Boundary-Assisted UNet for Skin Lesion Segmentation


模块名称

Group Shuffle Attention (GSA)


模块作用

轻量特征学习


模块结构

在这里插入图片描述


模块特点
  • 使用分组(Group)卷积降低计算量
  • 引入External Attention机制更好的学习特征
  • Shuffle操作促进不同Group之间信息的交互

模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError 
        self.normalized_shape = (normalized_shape, )
    
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x
        

class GSA(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        c_dim = dim_in // 4
        self.share_space1 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space1)
        self.conv1 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space2 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space2)
        self.conv2 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space3 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space3)
        self.conv3 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.share_space4 = nn.Parameter(torch.Tensor(1, c_dim, 8, 8), requires_grad=True)
        nn.init.ones_(self.share_space4)
        self.conv4 = nn.Sequential(
            nn.Conv2d(c_dim, c_dim, kernel_size=3, padding=1, groups=c_dim),
            nn.GELU(),
            nn.Conv2d(c_dim, c_dim, 1)
        )
        self.norm1 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.norm2 = LayerNorm(dim_in, eps=1e-6, data_format='channels_first')
        self.ldw = nn.Sequential(
            nn.Conv2d(dim_in, dim_in, kernel_size=3, padding=1, groups=dim_in),
            nn.GELU(),
            nn.Conv2d(dim_in, dim_out, 1),
        )

    def forward(self, x):
        x = self.norm1(x)
        x1, x2, x3, x4 = torch.chunk(x, 4, dim=1)
        B, C, H, W = x1.size()
        x1 = x1 * self.conv1(F.interpolate(self.share_space1, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x2 = x2 * self.conv2(F.interpolate(self.share_space2, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x3 = x3 * self.conv3(F.interpolate(self.share_space3, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x4 = x4 * self.conv4(F.interpolate(self.share_space4, size=x1.shape[2:4],mode='bilinear', align_corners=True))
        x = torch.cat([x2,x4,x1,x3], dim=1)
        x = self.norm2(x)
        x = self.ldw(x)
        return x
    

if __name__ == '__main__':
    x = torch.randn([1, 64, 44, 44])
    gsa = GSA(dim_in=64, dim_out=64)
    out = gsa(x)
    print(out.shape)  # [1, 64, 44, 44]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2294020.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VUE之组件通信(二)

1、v-model v-model的底层原理:是:value值和input事件的结合 $event到底是啥?啥时候能.target 对于原生事件,$event就是事件对象 ,能.target对应自定义事件,$event就是触发事件时,所传递的数据&#xff…

[x86 ubuntu22.04]进入S4失败

目录 1 问题描述 2 解决过程 2.1 查看内核日志 2.2 新建一个交换分区 2.3 指定交换分区的位置 1 问题描述 CPU:G6900E OS:ubuntu22.04 Kernel:6.8.0-49-generic 使用“echo disk > /sys/power/state”命令进入 S4,但是无法…

idea隐藏无关文件

idea隐藏无关文件 如果你想隐藏某些特定类型的文件(例如 .log 文件或 .tmp 文件),可以通过以下步骤设置: 打开设置 在菜单栏中选择 File > Settings(Windows/Linux)或 IntelliJ IDEA > Preference…

文献阅读 250205-Global patterns and drivers of tropical aboveground carbon changes

Global patterns and drivers of tropical aboveground carbon changes 来自 <Global patterns and drivers of tropical aboveground carbon changes | Nature Climate Change> 热带地上碳变化的全球模式和驱动因素 ## Abstract: Tropical terrestrial ecosystems play …

【数据结构】循环链表

循环链表 单链表局限性单向循环链表判断链表是否有环思路code 找到链表入口思路代码结构与逻辑 code 单链表局限性 单链表作为一种基本的数据结构&#xff0c;虽然在很多场景下都非常有用&#xff0c;但它也存在一些局限性&#xff1a; 单向访问&#xff1a;由于每个节点仅包含…

ImGui 学习笔记(二)—— 多视口

在计算机图形学中&#xff0c;视口&#xff08;Viewport&#xff09;是一个可观察的多边形区域。 将物体渲染至图像的过程中&#xff0c;会用两种区域表示。世界坐标窗口是用户所关注的区域&#xff08;即用户想要可视化的东西&#xff09;&#xff0c;坐标系由应用程序确定。…

安装和卸载RabbitMQ

我的飞书:https://rvg7rs2jk1g.feishu.cn/docx/SUWXdDb0UoCV86xP6b3c7qtMn6b 使用Ubuntu环境进行安装 一、安装Erlang 在安装RabbitMQ之前,我们需要先安装Erlang,RabbitMQ需要Erlang的语言支持 #安装Erlang sudo apt-get install erlang 在安装的过程中,会弹出一段信息,此…

Apache HttpClient

HttpClient是apache组织下面的一个用于处理HTTP请求和响应的来源工具&#xff0c;是一个在JDK基础类库是做了更好的封装的类库。 HttpClient 使用了连接池技术来管理 TCP 连接&#xff0c;这有助于提高性能并减少资源消耗。连接池允许 HttpClient 复用已经建立的连接&#xff0…

Golang 并发机制-6:掌握优雅的错误处理艺术

并发编程可能是提高软件系统效率和响应能力的一种强有力的技术。它允许多个工作负载同时运行&#xff0c;充分利用现代多核cpu。然而&#xff0c;巨大的能力带来巨大的责任&#xff0c;良好的错误管理是并发编程的主要任务之一。 并发代码的复杂性 并发编程增加了顺序程序所不…

react使用DatePicker日期选择器

1、引入&#xff1a;npm i day 2、页面引入&#xff1a; import dayjs from dayjs; 3、使用 <DatePicker onChange{onChange} value{datas ? dayjs(datas) : null} /> 4、事件 const onChange (date, dateString) > {setInput(dateString)setDatas(dateString)…

深度学习 Pytorch 基础网络手动搭建与快速实现

为了方便后续练习的展开&#xff0c;我们尝试自己创建一个数据生成器&#xff0c;用于自主生成一些符合某些条件、具备某些特性的数据集。 导入相关的包 # 随机模块 import random# 绘图模块 import matplotlib as mpl import matplotlib.pyplot as plt# 导入numpy import nu…

保姆级教程Docker部署KRaft模式的Kafka官方镜像

目录 一、安装Docker及可视化工具 二、单节点部署 1、创建挂载目录 2、运行Kafka容器 3、Compose运行Kafka容器 4、查看Kafka运行状态 三、集群部署 四、部署可视化工具 1、创建挂载目录 2、运行Kafka-ui容器 3、Compose运行Kafka-ui容器 4、查看Kafka-ui运行状态 …

51单片机看门狗系统

在 STC89C52 单片机中&#xff0c;看门狗控制寄存器的固定地址为 0xE1。此地址由芯片厂商在硬件设计时确定&#xff0c;但是它在头文件中并未给出&#xff0c;因此在使用看门狗系统时需要声明下这个特殊功能寄存器 sfr WDT_CONTR 0xE1; 本案将用一个小灯的工作状况来展示看门…

RNN/LSTM/GRU 学习笔记

文章目录 RNN/LSTM/GRU一、RNN1、为何引入RNN&#xff1f;2、RNN的基本结构3、各种形式的RNN及其应用4、RNN的缺陷5、如何应对RNN的缺陷&#xff1f;6、BPTT和BP的区别 二、LSTM1、LSTM 简介2、LSTM如何缓解梯度消失与梯度爆炸&#xff1f; 三、GRU四、参考文献 RNN/LSTM/GRU …

Android记事本App设计开发项目实战教程2025最新版Android Studio

平时上课录了个视频&#xff0c;从新建工程到打包Apk&#xff0c;从头做到尾&#xff0c;没有遗漏任何实现细节&#xff0c;欢迎学过Android基础的同学参加&#xff0c;如果你做过其他终端软件开发&#xff0c;也可以学习&#xff0c;快速上手Android基础开发。 Android记事本课…

【R语言】获取数据

R语言自带2种数据存储格式&#xff1a;*.RData和*.rds。 这两者的区别是&#xff1a;前者既可以存储数据&#xff0c;也可以存储当前工作空间中的所有变量&#xff0c;属于非标准化存储&#xff1b;后者仅用于存储单个R对象&#xff0c;且存储时可以创建标准化档案&#xff0c…

为什么在springboot中使用autowired的时候它黄色警告说不建议使用字段注入

byType找到多种实现类导致报错 Autowired: 通过byType 方式进行装配, 找不到或是找到多个&#xff0c;都会抛出异常 我们在单元测试中无法进行字段注入 字段注入通常是 private 修饰的&#xff0c;Spring 容器通过反射为这些字段注入依赖。然而&#xff0c;在单元测试中&…

Unity游戏(Assault空对地打击)开发(6) 鼠标光标的隐藏

前言 鼠标光标在游戏界面太碍眼了&#xff0c;要隐藏掉。 详细操作 新建一个脚本HideCursor&#xff0c;用于隐藏/取消隐藏光标。 写入以下代码。 意义&#xff1a;游戏开始自动隐藏光标&#xff0c;按Esc&#xff08;隐藏<-->显示&#xff09;。 using System.Collectio…

哪些专业跟FPGA有关?

FPGA产业作为近几年新兴的技术领域&#xff0c;薪资高、待遇好&#xff0c;吸引了大量的求职者。特别是对于毕业生&#xff0c;FPGA领域的岗位需求供不应求。那么&#xff0c;哪些专业和FPGA相关呢&#xff1f; 哪些专业跟FPGA有关&#xff1f; 微电子学与固体电子学、微电子科…

UE5 蓝图学习计划 - Day 14:搭建基础游戏场景

在上一节中&#xff0c;我们 确定了游戏类型&#xff0c;并完成了 项目搭建、角色蓝图的基础设置&#xff08;移动&#xff09;。今天&#xff0c;我们将进一步完善 游戏场景&#xff0c;搭建 地形、墙壁、机关、触发器 等基础元素&#xff0c;并添加角色跳跃功能&#xff0c;为…