【代码】Swan-Transformer 代码详解(待完成)

news2024/11/10 0:56:01

1. 局部注意力  Window Attention (W-MSA Module)

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # [Mh, Mw]
        print(self.window_size)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # [2*Mh-1 * 2*Mw-1, nH]

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # [2, Mh, Mw]
        coords_flatten = torch.flatten(coords, 1)  # [2, Mh*Mw]
        # [2, Mh*Mw, 1] - [2, 1, Mh*Mw]
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # [2, Mh*Mw, Mh*Mw]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # [Mh*Mw, Mh*Mw, 2]
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # [Mh*Mw, Mh*Mw]
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        """
        Args:
            x: input features with shape of (num_windows*B, Mh*Mw, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # [batch_size*num_windows, Mh*Mw, total_embed_dim]
        B_, N, C = x.shape
        # qkv(): -> [batch_size*num_windows, Mh*Mw, 3 * total_embed_dim]
        # reshape: -> [batch_size*num_windows, Mh*Mw, 3, num_heads, embed_dim_per_head]
        # permute: -> [3, batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

        # transpose: -> [batch_size*num_windows, num_heads, embed_dim_per_head, Mh*Mw]
        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, Mh*Mw]
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        # relative_position_bias_table.view: [Mh*Mw*Mh*Mw,nH] -> [Mh*Mw,Mh*Mw,nH]
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # [nH, Mh*Mw, Mh*Mw]
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # mask: [nW, Mh*Mw, Mh*Mw]
            nW = mask.shape[0]  # num_windows
            # attn.view: [batch_size, num_windows, num_heads, Mh*Mw, Mh*Mw]
            # mask.unsqueeze: [1, nW, 1, Mh*Mw, Mh*Mw]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        # @: multiply -> [batch_size*num_windows, num_heads, Mh*Mw, embed_dim_per_head]
        # transpose: -> [batch_size*num_windows, Mh*Mw, num_heads, embed_dim_per_head]
        # reshape: -> [batch_size*num_windows, Mh*Mw, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2058493.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

汽车的UDS诊断01

UDS(Unified Diagnostic Services):ISO14229中定义了汽车通用诊断协议;ISO15765规定了帧的格式; 1)UDS中的四种帧 UDS中的四种帧:单帧、首帧、流空帧、连续帧 图1 …

美团面试题:new Integer(“127“)和Integer.valueOf(“128“)有什么

🍅 作者简介:哪吒,CSDN2021博客之星亚军🏆、新星计划导师✌、博客专家💪 🍅 哪吒多年工作总结:Java学习路线总结,搬砖工逆袭Java架构师 🍅 技术交流:定期更新…

Windosw下Visual Studio2022编译安装VTK(支持QT),ITK

VTK(Visualization Toolkit)是一个开源的、跨平台的三维可视化开发库,用于处理和可视化三维数据。它提供了一系列算法和工具,用于创建、操作和渲染复杂的三维图形,并支持多种数据表示方式,包括点、线、面、…

桔子哥/基于云快充协议1.5版本的充电桩系统软件-充电桩系统 -新能源车充电平台源码

基于云快充协议1.5版本的充电桩系统软件 介绍 SpringBoot 框架,充电桩平台充电桩系统充电平台充电桩互联互通协议云快充协议1.5-1.6协议新能源汽车二轮车公交车二轮车充电-四轮车充电充电源代码充电平台源码Java源码 软件功能 小程序端:城市切换、附…

植物神经紊乱也不怕!吃出好心情,饮食调整秘籍大公开

Hey小伙伴们~👋 今天我们来聊聊一个可能听起来有点陌生但又挺常见的健康问题——植物神经紊乱。是不是有时候感觉心跳加速、呼吸不畅、还容易失眠多梦?别怕,除了专业治疗,饮食调整也是超级重要的一环哦!🍽️…

想要不得痉挛性斜颈?做这六件事!

一、保持良好的坐姿和站姿 长期不正确的姿势会给颈部肌肉带来过大的压力,增加痉挛性斜颈的发病风险。无论是工作还是休息,都要时刻提醒自己保持挺胸抬头、肩膀放松、颈椎正直的姿势。比如,在办公时,调整电脑屏幕的高度和角度&…

2024东湖高新区下半年水测报名开始啦

东湖高新区下半年职称评审水测报名开始啦,报名时间8月3--8月16号,马上报名截止了!! 请想明年拿证的需要先准备论文和软著 中级工程师职称基本评审条件:1、专科及以上学历2、大学理工类专业3、专科工作满七年&#xf…

【FreeRTOS】队列实验-多设备玩游戏(旋转编码器)

目录 0 前言1 任务1.1 本节源码1.2实验目的1.3实现方案 2 code2.1 创建队列2.2 写队列2.3 创建任务 3 勘误 0 前言 学习视频: 【FreeRTOS入门与工程实践 --由浅入深带你学习FreeRTOS(FreeRTOS教程 基于STM32,以实际项目为导向)】…

基于SpringBoot+Vu e.js校园疫情防控系统的设计与实现

文章目录 前言具体实现截图详细视频演示技术栈系统测试为什么选择我官方认证玩家,服务很多代码文档,百分百好评,战绩可查!!入职于互联网大厂,可以交流,共同进步。有保障的售后 代码参考数据库参…

轮换IP与固定IP,如何抉择?

IP地址相信大家都知道,它是标识我们网络身份的重要凭证。从访问网站到数据抓取,都能看到IP地址的身影。那么,轮换IP和固定IP该怎么理解呢?本文将详细介绍这两种IP类型,旨在帮助你根据需求做出合适的选择。 什么是固定I…

VM Ubuntu22.04 ROS2 从头安装

目录 前言安装步骤1 设置编码2 添加ROS2软件源(从哪去下载ros2相关软件)报错解决方法 3 安装报错解决方法1解决方法2 报错 4 设置环境变量5 Ros2 测试Hello World 发送和监听小海龟键盘控制 成功 Hello World 发送和监听界面成功控制小海龟界面 前言 本…

OpenAI发布微调功能 允许企业客户定制AI模型

当地时间8月20日,OpenAI发布了一项新功能,允许企业客户使用他们自己的公司数据来定制其最强大的模型GPT-4o,这将大大提高应用程序的性能和准确性。此举出台之际,初创企业在人工智能(AI)产品上面临着日益激烈…

MySQL 高阶三 (索引性能分析)

执行过程 Explain explain select * from student s, course c , student_coure sc where s.id sc.studentid and c.id sc.courseid;EXPLAIN执行计划各字段含义: 【ld】 id相同,执行顺序从上到下; id不同,值越大,越先执行)。 【select_type…

【论文学习与撰写】快捷搜索指令filetype:pdf,搜索引擎关键词搜索pdf格式文件或者word格式文件。文献搜索方法大全。

1、使用快捷搜索指令 在搜索框中输入:关键词空格filetype英文冒号文件格式 (如:关键词 filetype:pdf)。 通过这种方式,搜索引擎会限定搜索结果只显示 PDF 格式的文件。 比如搜索“2018 年考研英语真题 filetype:pd…

使用 Docker 安装 Ollama 部署本地大模型并接入 One-API

Ollama是一款开源工具,它允许用户在本地便捷地运行多种大型开源模型,包括清华大学的ChatGLM、阿里的千问以及Meta的llama等。目前,Ollama兼容macOS、Linux和Windows三大主流操作系统。本文将介绍如何通过Docker安装Ollama,并将其部…

苹果手机如何备份通讯录?4个方法手把手教你备份

苹果手机通讯录是我们联系亲朋好友的重要工具。然而,如果苹果手机出现损坏或者是丢失的情况,那手机通讯录存储的联系方式也会随之消失。为了避免这种情况的发生,定期备份通讯录变得至关重要,那么,苹果手机如何备份通讯…

基于Nginx进行服务器隐私保护:隐藏真实的服务器IP地址或主机名( 转发代理、服务器的别名)

文章目录 引言I 隐藏站点请求API的真实服务器IP和端口查看主文件配置服务的端口和站点目录的映射配置proxy_pass代理转发代理转发的其他配置【可选】II 服务器主机名处理隐藏真实的服务器主机名判断API请求是哪个服务器处理的III GLC日志中心新增用户信息扩展:在Linux中配置主…

趋势分享|Gartner解读中国企业容器管理新挑战:混合环境、容器安全、AI支持

不少企业都使用容器管理类软件/平台,方便容器环境的部署和运维。而随着应用系统的运行环境逐渐多元化,IT 运维人员仅依靠容器管理产品,已难以同时兼顾多种 IT 基础设施上的多个应用运行环境。同时,AI 等高性能应用场景的兴起&…

探索802.1X:构筑安全网络的认证之盾

在现代网络安全的世界里,有一个极其重要但又常常被忽视的角色,它就是802.1x认证协议。这个协议可以被称作网络安全的守护者,为我们提供了强有力的防护。今天,我们就来深入探讨一下802.1x的原理、应用和测试,看看它是如…

做谷歌seo如何创建良好的用户体验?

Google 希望排名靠前的页面能够为用户提供良好的体验,所以网站提升用户体验很重要。以下是一些实用的小建议,让你的网站更受用户欢迎,并且有助于提升你的 SEO 排名。 现代化设计:确保你的网站设计符合当前的审美和功能趋势。使用高…