AttentionFreeTransformer 源码解析(一):AFTFull、AFTSimple、AFTLocal

news2024/12/23 19:26:24

我觉得源码写的很好懂,我就不加注释了,直接上计算流程图。

AFTFull

在这里插入图片描述

class AFTFull(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full

        Number of heads is 1 as done in the paper
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        nn.init.xavier_uniform_(self.wbias)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)
        temp_wbias = self.wbias[:T, :T].unsqueeze(0) # sequences can still be variable length

        '''
        From the paper
        '''
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)

        return Yt

AFTSimple

在这里插入图片描述

class AFTSimple(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full
        
        Number of Heads is 1 as done in the paper.
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)

        '''
        From the paper
        '''
        weights = torch.mul(torch.softmax(K, 1), V).sum(dim=1, keepdim=True)
        Q_sig = torch.sigmoid(Q)
        Yt = torch.mul(Q_sig, weights)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)

        return Yt

AFTLocal

在这里插入图片描述

class AFTLocal(nn.Module):
    def __init__(self, max_seqlen, dim, hidden_dim=64, s=256):
        super().__init__()
        '''
        max_seqlen: the maximum number of timesteps (sequence length) to be fed in
        dim: the embedding dimension of the tokens
        hidden_dim: the hidden dimension used inside AFT Full
        s: the window size used for AFT-Local in the paper

        Number of heads is 1 as done in the paper
        '''
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.to_q = nn.Linear(dim, hidden_dim)
        self.to_k = nn.Linear(dim, hidden_dim)
        self.to_v = nn.Linear(dim, hidden_dim)
        self.project = nn.Linear(hidden_dim, dim)
        self.wbias = nn.Parameter(torch.Tensor(max_seqlen, max_seqlen))
        self.max_seqlen = max_seqlen
        self.s = s
        nn.init.xavier_uniform_(self.wbias)


    def forward(self, x):
        B, T, _ = x.shape
        Q = self.to_q(x).view(B, T, self.hidden_dim)
        K = self.to_k(x).view(B, T, self.hidden_dim)
        V = self.to_v(x).view(B, T, self.hidden_dim)
        self.wbias = nn.Parameter(torch.Tensor([
            [self.wbias[i][j] if math.fabs(i-j) < self.s else 0 for j in range(self.max_seqlen)] 
            for i in range(self.max_seqlen)
            ]))
        temp_wbias = self.wbias[:T, :T].unsqueeze(0) # sequences can still be variable length

        '''
        From the paper
        '''
        Q_sig = torch.sigmoid(Q)
        temp = torch.exp(temp_wbias) @ torch.mul(torch.exp(K), V)
        weighted = temp / (torch.exp(temp_wbias) @ torch.exp(K))
        Yt = torch.mul(Q_sig, weighted)

        Yt = Yt.view(B, T, self.hidden_dim)
        Yt = self.project(Yt)

        return Yt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/857495.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

echarts柱状图X轴增加table列表显示数据,多y轴

效果图 完整配置 data(){return{chart1:null,chartType1:1,data:{years:{date:[2015,2016,2017,2018,2019,2020,2021,2022,2023],business:[10,23,26,33,43,58,50,45,66],profit:[3,4,6,7,8,5,7,8,12],proportion:[12,8,15,20,12,16,13,15,9]},months:{date:[1月, 2月,3月, 4月…

在指定的 DSN 中,驱动程序和应用程序之间的体系结构不匹配

1.Cadence 17.2 配置CIS数据库报&#xff1a;ERROR(ORCIS-6245): Database Operation Failed 安装cadance17.2以上版本时&#xff0c;ERROR(ORCIS-6245): Database Operation Failed_收湾湾的博客-CSDN博客 原因是ODBC数据库没有配置&#xff0c;或者没有驱动&#xff0c; 驱…

侯捷 C++ part2 兼谈对象模型笔记——3 模板

3 模板 3.1 类模板/函数模板 补充&#xff1a;只有模板的尖括号中<>&#xff0c;关键字 typename 和 class 是一样的 3.2 成员模板 它即是模板的一部分&#xff0c;自己又是模板&#xff0c;则称为成员模板 其经常用于构造函数 ctor1 这是默认构造函数的实现&#…

阿里云服务器免费申请使用限制条件及云主机配置

阿里云服务器免费试用申请链接入口&#xff0c;阿里云个人用户和企业用户均可申请免费试用&#xff0c;最高可以免费使用3个小时&#xff0c;阿里云服务器网分享阿里云服务器免费试用申请入口链接及云服务器配置&#xff1a; 目录 阿里云服务器免费试用 企业用户免费服务器试…

解决Qt的列表加载大量数据卡顿的问题

问题概述 本人在使用QListView插入大量数据时&#xff0c;界面卡顿十分严重。数据量大概只有上千左右&#xff0c;但是每个Item的内容比较多。当数据不停地插入一段时间后&#xff0c;卡顿到鼠标的移动都有点困难。 解决思路 QListView是典型的MVC思想的产物。界面呈现出来的数…

CorelDRAW(CDR) 2023中文版64位下载新功能教程

CorelDRAW2023&#xff08;简称CDR2023&#xff09;是一款非常专业的图形设计工具&#xff0c;该产品推出了全新的2023版本&#xff0c;在功能和体验上更进一步&#xff0c;最新的填充和透明设备功能可以完全控制任何类型的纹理&#xff0c;适用于网络摄影、印刷项目、艺术、排…

G. Counting Graphs Codeforces Round 891 (Div. 3) 1857G

Problem - G - Codeforces 题目大意&#xff1a;给出一棵n个点的边权树&#xff0c;问有多少个边权最大不超过s的图的最小生成树是这棵树 2<n<2e5;1<w[i]<1e9 思路&#xff1a;对于最小生成树上的每一条边&#xff0c;如果我们在包含这条边的链上的两点之间加了…

SpringBoot 该如何预防 XSS 攻击

XSS 漏洞到底是什么&#xff0c;说实话我讲不太清楚。但是可以通过遇到的现象了解一下。在前端Form表单的输入框中&#xff0c;用户没有正常输入&#xff0c;而是输入了一段代码&#xff1a;</input><img src1 onerroralert1> 这个正常保存没有问题。问题出在了列表…

建设数字化工厂管理系统的必要性有哪些

随着科技的不断发展&#xff0c;数字化已经深入到各个行业和领域&#xff0c;其中包括制造业。数字化工厂管理系统作为一种先进的生产管理模式&#xff0c;能够提高生产效率&#xff0c;降低生产成本&#xff0c;提升企业的竞争力。下面&#xff0c;我们就来探讨一下建设数字化…

Mysql按小时进行分组统计数据

目录 前言 按1小时分组统计 按2小时分组统计 按X小时分组统计 前言 统计数据时这种是最常见的需求场景&#xff0c;今天写需求时发现按2小时进行分组统计也特别简单&#xff0c;特此记录下。 按1小时分组统计 sql&#xff1a; select hour(pass_time) …

自建hexo博客并将原有的文章发布其上

1、保存粘贴到memo9中的博客文章&#xff0c;并将txt转换成word文档 varPowerShellPath, CommandLine: string; // , ScriptPath begin//save to txtMemo9.Lines.SaveToFile(test.txt);memo10.Lines.SaveToFile(txt2word.ps1);//save as docxPowerShellPath : powershell.exe…

Simulink仿真模块 - Math Function

Math Function:执行数学函数 在仿真库中的位置为:Simulink / Math Operations HDL Coder / Math Operations 模型为: 双击模型打开参数设置界面,如图所示: 说明 Math Function 模块可执行许多常见的数学函数。 提示 要执行平方根计算,请使用Sqrt模块。 …

QGIS二次开发四:实现图层列表

在实际开发中我们通常会遇到同时显示多个图层&#xff0c;并且还要实时显示和隐藏各图层的需求&#xff0c;如同 ArcGIS 的图层列表那样&#xff0c;界面左侧显示图层列表&#xff0c;列出当前已加载的所有图层&#xff0c;同时每个图层前面有复选框可以控制图层的显示/隐藏&am…

虾皮运营每天需要做什么?如何处理后台数据?

#shopee#​有很多朋友想做电商&#xff0c;但是对电商运营比较朦胧&#xff0c;不知道电商运营每天到底该做些什么。今天咱们就来解析下&#xff0c;Shopee电商运营每天该做哪些事情一个合格的电商运营&#xff0c;每天都会做好以下几点&#xff1a; 一、查看数据&#xff1a; …

echarts自动轮播tooltip

1. 将自动轮播的工具函数封装到 utils/echart-tooltip-carousel.js /*** echarts自动轮播tooltip* param {Object} chart echart实例* param {Number} num 轮播数量* param {Number} time 轮播间隔时间*/ export function loopShowTooltip(chart, num20, time2000) {let count…

如何配置一个永久固定的公网TCP地址来SSH远程树莓派?

文章目录 如何配置一个永久固定的公网TCP地址来SSH远程树莓派&#xff1f;前置条件命令行使用举例&#xff1a;修改cpolar配置文件 1. Linux(centos8)安装redis数据库2. 配置redis数据库3. 内网穿透3.1 安装cpolar内网穿透3.2 创建隧道映射本地端口 4. 配置固定TCP端口地址4.1 …

IoT 物联网安全事件的持续检测和监控解决方案

对于IoT物联网安全事件的持续检测和监控&#xff0c;可以采用以下解决方案&#xff1a; 设备管理和漏洞修复&#xff1a;确保设备的固件和软件及时更新&#xff0c;并修补已知的漏洞。建立一个设备清单&#xff0c;并定期审查和更新其中的设备。 流量分析和异常检测&#xff1a…

蓝牙资讯|苹果发布AirPods最新开发者固件,最新功能引人关注

苹果面向 AirPods 耳机发布了开发者固件更新 AirPods 6.0 Beta 3&#xff0c;最新内部版本号为 6A5289c。 本次更新适用于 AirPods 3、AirPods Pro 2 和 AirPods Max&#xff0c;不过并非所有 AirPods 耳机都获得了更新。包含5 种新功能 / 新特性&#xff1a; 自适应音频&am…

virtualBox桥接模式下openEuler镜像修改IP地址、openEule修改IP地址、openEule设置IP地址

安装好openEuler后,设置远程登入前,必不可少的一步,主机与虚拟机之间的通信要解决,下面给出详细步骤: 第一步:检查虚拟机适配器模式:桥接模式 第二步:登入虚拟机修改IP cd /etc/sysconfig/network-scripts vim ifcfg-enpgs3 没有vim的安装或者用vi代替:sudo dnf …

【实操干货】如何开始用Qt Widgets编程?(二)

Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写&#xff0c;所有平台无差别运行&#xff0c;更提供了几乎所有开发过程中需要用到的工具。如今&#xff0c;Qt已被运用于超过70个行业、数千家企业&#xff0c;支持数百万设备及应用。 在本文中&#xff0…