论文及其创新点学习cvpr2022 On the Integration of Self-Attention and Convolution

news2024/10/25 13:21:26

代码地址

https://github.com/LeapLabTHU/ACmix

 https://gitee.com/mindspore/models

论文创新点,将注意力机制 和卷积 相结合

# encoding: utf-8
'''
@author: duhanyue
@start time: 2024/10/13 10:04
'''

import torch
import torch.nn as nn
def position(H, W, is_cuda=True):
    if is_cuda:
        loc_w = torch.linspace(-1.0, 1.0, W).cuda().unsqueeze(0).repeat(H, 1)
        loc_h = torch.linspace(-1.0, 1.0, H).cuda().unsqueeze(1).repeat(1, W)
    else:
        loc_w = torch.linspace(-1.0, 1.0, W).unsqueeze(0).repeat(H, 1)
        loc_h = torch.linspace(-1.0, 1.0, H).unsqueeze(1).repeat(1, W)
    loc = torch.cat([loc_w.unsqueeze(0), loc_h.unsqueeze(0)], 0).unsqueeze(0)
    return loc


def stride(x, stride):
    b, c, h, w = x.shape
    return x[:, :, ::stride, ::stride]

def init_rate_half(tensor):
    if tensor is not None:
        tensor.data.fill_(0.5)

def init_rate_0(tensor):
    if tensor is not None:
        tensor.data.fill_(0.)
class ACmix(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1):
        super(ACmix, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.head = head
        self.kernel_att = kernel_att
        self.kernel_conv = kernel_conv
        self.stride = stride
        self.dilation = dilation
        self.rate1 = torch.nn.Parameter(torch.Tensor(1))
        self.rate2 = torch.nn.Parameter(torch.Tensor(1))
        self.head_dim = self.out_planes // self.head

        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)
        self.conv_p = nn.Conv2d(2, self.head_dim, kernel_size=1)

        self.padding_att = (self.dilation * (self.kernel_att - 1) + 1) // 2
        self.pad_att = torch.nn.ReflectionPad2d(self.padding_att)
        self.unfold = nn.Unfold(kernel_size=self.kernel_att, padding=0, stride=self.stride)
        self.softmax = torch.nn.Softmax(dim=1)

        self.fc = nn.Conv2d(3 * self.head, self.kernel_conv * self.kernel_conv, kernel_size=1, bias=False)
        self.dep_conv = nn.Conv2d(self.kernel_conv * self.kernel_conv * self.head_dim, out_planes,
                                  kernel_size=self.kernel_conv, bias=True, groups=self.head_dim, padding=1,
                               stride=stride)

        self.reset_parameters()
    def reset_parameters(self):
        init_rate_half(self.rate1)
        init_rate_half(self.rate2)
        kernel = torch.zeros(self.kernel_conv * self.kernel_conv, self.kernel_conv, self.kernel_conv)
        for i in range(self.kernel_conv * self.kernel_conv):
            kernel[i, i // self.kernel_conv, i % self.kernel_conv] = 1.
        kernel = kernel.squeeze(0).repeat(self.out_planes, 1, 1, 1)
        self.dep_conv.weight = nn.Parameter(data=kernel, requires_grad=True)
        self.dep_conv.bias = init_rate_0(self.dep_conv.bias)

    def forward(self, x):
        q, k, v = self.conv1(x), self.conv2(x), self.conv3(x)
        scaling = float(self.head_dim) ** -0.5
        b, c, h, w = q.shape
        h_out, w_out = h // self.stride, w // self.stride

        # ### att
        # ## positional encoding
        pe = self.conv_p(position(h, w, x.is_cuda))

        q_att = q.view(b * self.head, self.head_dim, h, w) * scaling
        k_att = k.view(b * self.head, self.head_dim, h, w)
        v_att = v.view(b * self.head, self.head_dim, h, w)

        if self.stride > 1:
            q_att = stride(q_att, self.stride)
            q_pe = stride(pe, self.stride)
        else:
            q_pe = pe

        unfold_k = self.unfold(self.pad_att(k_att)).view(b * self.head, self.head_dim,
                                                         self.kernel_att * self.kernel_att, h_out,
                                                         w_out)  # b*head, head_dim, k_att^2, h_out, w_out
        unfold_rpe = self.unfold(self.pad_att(pe)).view(1, self.head_dim, self.kernel_att * self.kernel_att, h_out,
                                                        w_out)  # 1, head_dim, k_att^2, h_out, w_out

        att = (q_att.unsqueeze(2) * (unfold_k + q_pe.unsqueeze(2) - unfold_rpe)).sum(
            1)  # (b*head, head_dim, 1, h_out, w_out) * (b*head, head_dim, k_att^2, h_out, w_out) -> (b*head, k_att^2, h_out, w_out)
        att = self.softmax(att)

        out_att = self.unfold(self.pad_att(v_att)).view(b * self.head, self.head_dim, self.kernel_att * self.kernel_att,
                                                        h_out, w_out)
        out_att = (att.unsqueeze(1) * out_att).sum(2).view(b, self.out_planes, h_out, w_out)

        ## conv
        f_all = self.fc(torch.cat(
            [q.view(b, self.head, self.head_dim, h * w), k.view(b, self.head, self.head_dim, h * w),
             v.view(b, self.head, self.head_dim, h * w)], 1))
        f_conv = f_all.permute(0, 2, 1, 3).reshape(x.shape[0], -1, x.shape[-2], x.shape[-1])

        out_conv = self.dep_conv(f_conv)

        return self.rate1 * out_att + self.rate2 * out_conv
acmix_model = ACmix(in_planes=64,out_planes=64, kernel_att=7, head=4, kernel_conv=3, stride=1, dilation=1)
x=torch.randn(16,64,64,44)
out=acmix_model(x)
x=x+out
print(out.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2210772.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

邮票鉴赏系统| 邮票鉴赏系统平台|基于java和vue的邮票鉴赏系统设计与实现(源码+数据库+文档)

邮票鉴赏系统\ 目录 基于java和vue的邮票鉴赏系统设计与实现 一、前言 二、系统功能设计 三、系统实现 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获取: 博主介绍:✌️大厂码农|毕设布道师,阿里…

用 Gemini Google 生成图片的魔法

用 Gemini Google 生成图片的魔法指南 你是否曾经想过,用一些简单的文本描述来生成一张图片?这听起来像是科幻小说中的魔法,但实际上,这就是 Gemini Google 的魔力!在这篇文章中,我将向你详细介绍如何使用…

【HarmonyOS NEXT】实现页面水印功能

关键词:鸿蒙、水印、Watermark、页面、触摸问题 注:本期文章同样适用 OpenHarmony 的开发 在app开发过程中时常会出现敏感信息页面,为保护信息安全和及时的数据追踪,通常会采用给页面加水印的形式,那么本期文章会介绍…

自回归视觉生成里程碑!比ControlNet 和 T2I-Adapter 快五倍!北大腾讯提出CAR:灵活、高效且即插即用的可控框架

文章链接:https://arxiv.org/pdf/2410.04671 项目链接:https://github.com/MiracleDance/CAR 亮点直击 CAR是首个为自回归模型家族设计的灵活、高效且即插即用的可控框架。CAR基于预训练的自回归模型,不仅保留了原有的生成能力,还…

sherpa-ncnn 语言模型简单对比

在昨天把系统搞崩溃前,对sherpa-ncnn的中文模型做了一个简单的对比。这次使用的分别是sherpa-ncnn-streaming-zipformer-bilingual-zh-en-2023-02-13(以下简称bilingual-zh-en-2023-02-13)和sherpa-ncnn-streaming-zipformer-small-bilingual…

服务器数据恢复—EMC存储RAID5磁盘阵列数据恢复案例

服务器数据恢复环境: 一台EMC某型号存储设备,该存储中有一组由12块(包括2块热备盘)STAT硬盘组建的raid5阵列。 服务器故障: 该存储在运行过程中突然崩溃,raid瘫痪。数据恢复工程师到达现场对故障存储设备进…

GPT联网分析到底有多强?实测效果告诉你答案!

文章目录 零、前言一、gpt-4o操作指导gpt4o 二、感受 零、前言 早上在聊到博主在选择平台时,要选择哪个平台发展。 通过GPT查询并分析了小红书,微信视频号,抖音和B站的用户群体。 由此可举一反三,如何让GPT联网分析,…

部署私有仓库以及docker web ui应用

官方地址:https://hub.docker.com/_/registry/tags 一、拉取registry私有仓库镜像 docker pull registry:latest 二、运⾏容器 docker run -itd -v /home/dockerdata/registry:/var/lib/registry --name "pri_registry1" --restartalways -p 5000:5000 …

如何针对项目中的技术难点准备面试?——黑马点评为例

最核心的,包装和准备 个人项目,怎么包装?一定要写出代码才可以吗? 你可以在系统A中实现就可以,了解其中实现的细节,怎么跟面试官对线等等,这些话术到位了之后,再把它融入到系统B&a…

《CUDA编程》7.全局内存的合理使用

上一章简单的介绍了一下各种内存,本章开始详细讲解各个内存的合理使用,在所有设备中,全局内存的访问速度最慢,是CUDA程序的一个性能瓶颈,所以值得特别关注 1 全局内存的合并与非合并访问 对全局内存的访问将触发内存事…

LabVIEW如何实现高精度定时器

在LabVIEW中实现高精度定时器通常需要考虑以下几个方面:定时器的精度要求、操作系统的调度机制、硬件资源(如计时器、触发器)等。以下是几种常见的实现方式: ​ 1. 使用 Wait(ms) 或 Wait Until Next ms Multiple VI 这两个函数…

【无人机设计与控制】PID_积分滑模_积分反步四旋翼无人机轨迹跟踪控制算法

摘要 本文基于四旋翼无人机设计与控制,提出了一种结合PID控制、积分滑模控制以及积分反步控制的轨迹跟踪算法。该算法通过调节无人机的运动轨迹,提升其在复杂环境下的稳定性与抗扰动能力。实验结果表明,该算法能有效改善无人机的轨迹跟踪精度…

【python实操】python小程序之计算对象个数、游戏更新分数

引言 python小程序之计算对象个数、游戏更新分数 文章目录 引言一、计算对象个数1.1 题目1.2 代码1.3 代码解释1.3.1 代码结构1.3.2 模块解释1.3.3 解释输出 二、游戏更新分数2.1 题目2.2 代码2.3 代码解释2.3.1 定义 Game 类2.3.2 创建 Game 实例并调用方法 三、思考3.1 计算对…

C++之String类模拟实现(下)

片头 哈喽~小伙伴们,在上一篇中,我们讲解了C的string类的相关函数,这一章中,我们将继续深入学习string类函数,准备好了吗?咱们开始咯~ 五、对内容进行修改 ⑤insert函数 在指定位置插入字符或者字符串 …

基于Raspberry Pi人脸识别自动门

人脸识别自动门 简介 在当今数字化时代,智能家居安全变得越来越重要。今天,我要向大家介绍一个结合了安全性与便利性的项目——人脸识别自动门。这个项目通过在门上实施基于面部识别的高级安全系统,使用摄像头验证房主的面部,自…

重学SpringBoot3-集成Spring Boot Actuator

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Spring Boot Actuator 1. 什么是 Spring Boot Actuator?2. Spring Boot Actuator 的核心功能3. Spring Boot 3 中集成 Actuator3.1 添加…

ElasticSearch是什么?

1.概述 Elasticsearch 是一个基于 Apache Lucene 构建的开源分布式搜索引擎和分析引擎。它专为云计算环境设计,提供了一个分布式的、高可用的实时分析和搜索平台。Elasticsearch 可以处理大量数据,并且具备横向扩展能力,能够通过增加更多的硬…

2014年国赛高教杯数学建模C题生猪养殖场的经营管理解题全过程文档及程序

2014年国赛高教杯数学建模 C题 生猪养殖场的经营管理 某养猪场最多能养10000头猪,该养猪场利用自己的种猪进行繁育。养猪的一般过程是:母猪配种后怀孕约114天产下乳猪,经过哺乳期后乳猪成为小猪。小猪的一部分将被选为种猪(其中公…

20240727 影石 笔试

文章目录 1、选择题1.11.21.31.41.51.61.71.81.91.10 2、简答题2.12.22.32.42.52.62.72.8 3、编程题3.1 岗位:云台嵌入式工程师-2025校招 题型:10 道选择题,8 道简答题,1 道编程题 1、选择题 1.1 【多选】以下关于DMA的描述哪些…

Pytest中fixture含返回值时如何隐式自动应用?

在我们使用 Pytest 框架进行自动化测试时,强大的 fixture 夹具为框架的灵活应用提供了极大的便利。比如我们可以利用 fixture 的autouse属性,使它在测试方法的不同范围层级上自动生效。但如果要引用fixture的返回,我们通常还是要明确指定&…