机器学习课程学习周报八

news2024/12/29 1:58:53

机器学习课程学习周报八

文章目录

  • 机器学习课程学习周报八
    • 摘要
    • Abstract
    • 一、机器学习部分
      • 1.1 self-attention的计算量
      • 1.2 人类理解代替自注意力计算
        • 1.2.1 Local Attention/Truncated Attention
        • 1.2.2 Stride Attention
        • 1.2.3 Global Attention
        • 1.2.4 聚类Query和Key
      • 1.3 自动选择自注意力计算
      • 1.4 Attention Matrix中的线性组合
      • 1.5 通过矩阵乘法推导自注意力计算
      • 1.6 Batch Normalization
    • 总结

摘要

本周的学习重点是自注意力机制的计算优化。我探讨了如何通过Local Attention、Stride Attention、Global Attention等方法减少计算量。此外,还介绍了自动选择注意力计算和Attention Matrix的线性组合方法。最后,补充了Batch Normalization的知识,为模型训练提供了更好的稳定性。

Abstract

This week’s focus is on optimizing the computation of the self-attention mechanism. I explored methods like Local Attention, Stride Attention, and Global Attention to reduce computational load. Additionally, we discussed automatic selection of attention computation and linear combinations in the Attention Matrix. Lastly, we supplemented our understanding with Batch Normalization, enhancing model training stability.

一、机器学习部分

1.1 self-attention的计算量

请添加图片描述
如果现在自注意力模型输入的序列长度为 N N N,则对应的Query为 N N N个,对应的Key也为 N N N个。它们之间相互计算关联性(即注意力分数),可以得到上图中的Attention Matrix,这个矩阵的复杂度是 N 2 {N^2} N2,当 N N N的数值很大时,该矩阵的计算量就会变得很大。因此,这一节介绍多种方法以加速计算Attention Matrix的计算。

Notice:当 N N N很大时,self-attention的计算才会主导整个模型中计算量。例如:在Transformer模型中,除了self-attention还有其他模块的计算量,self-attention模块的计算量占模型整体计算量是与 N N N有关的,当 N N N过小时,对self-attention的改进计算并不会明显提高Transformer模型的运算速度。

1.2 人类理解代替自注意力计算

根据人类对问题的理解,对Attention Matrix某些位置的值直接赋值,跳过计算步骤,从而减少计算量。

1.2.1 Local Attention/Truncated Attention

计算self-attention时,并非计算整个序列间的self-attention分数,而是只看自己和左右的邻居,其他的关联性都设定为0。下图在Attention Matrix中,表示为灰色的部分都人工设定为0,只计算蓝色部分的self-attention分数。这种方法叫做Local Attention或Truncated Attention。
请添加图片描述

Local Attention与CNN较为相似,主要体现在它们的局部关注机制上。这种机制使得模型在处理输入数据时,只关注输入数据的局部区域,而不是整体。卷积神经网络(CNN)中,卷积层通过滑动窗口的方式在输入数据上提取特征。这种操作也可以看作是一种局部关注机制,通过卷积核仅关注输入数据的局部区域来提取特征。Local attention相比于之前介绍的包含全序列的注意力,更加注重输入数据的局部关系,与卷积核的滑动也很类似。

1.2.2 Stride Attention

根据自己对问题的理解,计算局部的self-attention并不一定是左右邻居,如下图,可以是分别计算序列中两步前或两步后的关联性,也可以是分别计算序列中一步前或一步后的关联性,灰色的地方设定为0。这种方法叫做Stride Attention。

在这里插入图片描述

1.2.3 Global Attention

前面介绍的方法都是以某一个位置为中心,分别计算左右的关联性。Global Attention注重于整个序列,其会添加特殊的token到原始的序列中,特殊的token分别与整个序列计算self-attention,具体做法有两种:

  • 从原来的token序列中,选择一部分作为特殊的token。
  • 外加一部分额外的token。

在这里插入图片描述

从上图的Attention Matrix观察得到,在原始的序列中,第一和第二个位置被选择为特殊的token。从横轴的角度看,第一和第二个位置的Query与整个序列的Key分别做了self-attention。从纵轴的角度看,序列每一个位置的Query都与第一和第二位置的Key做了self-attention。灰色的位置设定为0。

在这里插入图片描述

在Big Bird中提出了Random attention并且将其与前面的Local Attention和Global Attention一并融合。

1.2.4 聚类Query和Key

在这里插入图片描述

第一步,根据相似度聚类Query和Key,上图中根据不同颜色聚类为了4类。

在这里插入图片描述

第二步,相同类之间的Query和Key才做self-attention。

1.3 自动选择自注意力计算

在这里插入图片描述

通过神经网络学习出一个0-1矩阵,深色位置代表1,浅色位置代表0。只有深色位置计算self-attention,浅色位置不计算。

在这里插入图片描述

输入序列中的每一个位置都通过一个神经网络产生一个长度为 N N N的向量,然后将这些向量拼起来得到大小为 N × N N \times N N×N的矩阵。然而现在这个由向量拼成得到的矩阵中的值,是连续值,要转换为0-1矩阵,这一部分是可以微分的,所以可以通过学习得到,具体需要看Sinkhorn Sorting Network的论文。

1.4 Attention Matrix中的线性组合

计算Attention Matrix的Rank(秩),得到Low Rank,说明该矩阵的很多列是其它列的线性组合。由此可得,实际上并不需要 N × N N \times N N×N的矩阵,目前 N × N N \times N N×N的矩阵中包含很多重复的信息,也许可以通过减少Attention Matrix的大小(主要是列数量)实现减少运算量。

在这里插入图片描述

选择具有代表性的Key,得到K个Key,即得到大小为 N × K N \times K N×K的Attention Matrix。接下来考虑self-attention这一层的输出,同样地要从N个Value中挑出具有代表性的K个Value,一个Key对应一个Value向量。然后用Value矩阵乘上Attention Matrix可以得到self-attention层的输出。

为什么我们不能挑出K个代表的Query呢?

输出序列的长度与Query的数量是一致的,如果减少Query的数量,输出序列的长度就会变短。

挑选具有代表性的Key的方法为:

卷积降维和线性组合(K个向量是N个向量的K种线性组合,下图右)

在这里插入图片描述

1.5 通过矩阵乘法推导自注意力计算

在这里插入图片描述

简要复习一下自注意力机制的矩阵计算过程:第一步,输入序列分别做三种不同的变换,得到 d × N d \times N d×N大小的Query和 d × N d \times N d×N大小的Key,其中 d d d是Query和Key的维度, N N N代表序列的长度。并得到 d ′ × N d' \times N d×N大小的Value,其中特别用 d ′ d' d表示Value的维度,是因为Value的维度可以与Query、Key不一样。第二步, K T {K^{\rm T}} KT乘上 Q Q Q得到Attention Matrix,然后通过softmax做归一化。第三步,用 V V V乘上归一化后的Attention Matrix( A ′ A' A)得到自注意力层的输出 O O O

在这里插入图片描述

如果我们先忽略softmax的操作,self-attention的计算方法就是上图中第一行的计算过程,现在考虑第二行运算,先算 V V V乘上 K T {K^{\rm T}} KT的结果,再乘上 Q Q Q,这样的计算顺序与第一行有何不同?得到的结果是一样的,运算量是不一样的。

请添加图片描述

尽管 A ( C P ) = ( A C ) P A\left( {CP} \right) = \left( {AC} \right)P A(CP)=(AC)P,但是第一种计算方式的计算量是 1 0 6 {10^6} 106,第二种计算方式的计算量的 1 0 3 {10^3} 103,两者计算量之间的差异很大。因此我们这里先忽略softmax操作,考虑self-attention中矩阵计算的改进。

请添加图片描述

根据上图证明, V ( K T Q ) V({K^{\rm T}}Q) V(KTQ)的计算量通常大于 ( V K T ) Q (V{K^{\rm T}})Q (VKT)Q的计算量。

接下来加入softmax,写出计算self-attention的数学表达式:

请添加图片描述

下面通过数学证明的角度说明更换矩阵乘法顺序,计算self-attention的过程:

请添加图片描述

还有一个问题是, exp ⁡ ( q ⋅ k ) ≈ Φ ( q ) ⋅ Φ ( k ) \exp (q \cdot k) \approx \Phi (q) \cdot \Phi (k) exp(qk)Φ(q)Φ(k)是如何实现的,具体需要参考下面的论文。

请添加图片描述

1.6 Batch Normalization

在Transformer的编码器中使用到了Layer Normalization,在上一周的周报中并将其与Batch Normalization做了比较,这里特别补充Batch Normalization的知识。

请添加图片描述

做标准化的原因是,希望能把不同维度的特征值规范到同样的数值范围,从而使得error surface比较平滑,更好训练。

请添加图片描述

Batch Normalization是对不同特征向量的同一维度,计算平均值和标准差,然后将特征值减去平均值再除以标准差,实现标准化。标准化后,同一维度上的数值的平均值是0,方差是1,接近高斯分布。

请添加图片描述

在神经网络中,输入特征 x ~ 1 {\tilde x^1} x~1 x ~ 2 {\tilde x^2} x~2 x ~ 3 {\tilde x^3} x~3已经做过了标准化,在经过 W 1 {W^1} W1层后,且输入 W 2 {W^2} W2层之前仍需要做标准化。至于是对激活函数前的 z 1 {z^1} z1 z 2 {z^2} z2 z 3 {z^3} z3还是之后的 a 1 {a^1} a1 a 2 {a^2} a2 a 3 {a^3} a3做标准化,差别不是很大。以 z 1 {z^1} z1 z 2 {z^2} z2 z 3 {z^3} z3为例, z 1 {z^1} z1 z 2 {z^2} z2 z 3 {z^3} z3都是向量,做标准化的方法如下:

请添加图片描述

μ = 1 3 ∑ i = 1 3 z i \mu = \frac{1}{3}\sum\limits_{i = 1}^3 {{z^i}} μ=31i=13zi是对向量 z i {z^i} zi中对应元素进行相加,然后取平均。 σ = 1 3 ∑ i = 1 3 ( z i − μ ) 2 \sigma = \sqrt {\frac{1}{3}\sum\limits_{i = 1}^3 {{{\left( {{z^i} - \mu } \right)}^2}} } σ=31i=13(ziμ)2 是向量 z i {z^i} zi μ \mu μ相减,然后逐元素平方,求和平均后,再对向量的逐元素开根号。如果直接看公式会有一些歧义,因为 z i {z^i} zi μ \mu μ σ \sigma σ都是向量,其中的求和,平方,开根号都是对向量中逐元素操作。最后标准化公式为:

z ~ i = z i − μ σ {{\tilde z}^i} = \frac{{{z^i} - \mu }}{\sigma } z~i=σziμ

实际上,GPU的内存不足以把整个dataset的数据一次性加载。因此,只考虑一个batch中的样本,对一个batch中的样本做Batch Normalization。在inference中,不可能等到整个batch数量的输入才做推理,具体方法为:在训练时计算 μ \mu μ σ \sigma σ的moving average,训练时的第一个batch为 μ 1 {\mu^1} μ1,第二个batch为 μ 1 {\mu^1} μ1,直到第t个batch为 μ t {\mu^t} μt,且不断地计算moving average:

μ ˉ ← p μ ˉ + ( 1 − p ) μ t \bar \mu \leftarrow p\bar \mu + \left( {1 - p} \right){\mu ^t} μˉpμˉ+(1p)μt

inference中标准化的公式变为:

z ~ i = z i − μ ˉ σ ˉ {{\tilde z}^i} = \frac{{{z^i} - \bar \mu }}{{\bar \sigma }} z~i=σˉziμˉ

总结

通过本周的学习,我对自注意力机制的优化策略有了更深入的了解,不同的注意力方法提供了多样化的计算选择,有助于提高模型的效率。下周还会围绕自注意力机制进行拓展学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2050340.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用哪种方式可以将 MATLAB 算法转换到FPGA中运行?

FPGA在进行相关算法计算时,一般都会使用高级语言进行算法验证,目前比较常见的就是 MATLAB ,那么使用哪种方式可以将MATLAB中实现的算哒转换到FPGA中? 目前可以通过多种方式在 FPGA 中实现算法。 Simulink HDL Coder MathWorks 提供…

Keepalived学习

环境准备:两台服务器,两台客户机,关闭火墙和selinux 在两台主机上安装ka yum install keepalived -y 开启软件 keepalived配置 进入文件 vim /etc/keepalived/keepalived.conf 修改配置 配置slave 效果 在另一台路由配置 抢占模式和非…

UE基础 —— 项目设置

目录 访问项目设置 类别和分段 Project Game Engine Editor Platforms Plugins 通过 项目设置(Project Settings),可以配置影响以下内容: 虚幻引擎项目;引擎在运行项目时的行为;项目如何在特定平台…

JavaEE 第13节 synchronized关键字基本实现原理

目录 synchronized的基本特点: synchronized关键字的底层实现: 1)锁升级 2)锁消除 3)锁粗化 synchronized的基本特点: 以下特点只考虑(jdk1.8): 1)刚开始…

高可用集群keep-alive

keepalive简介 keepalive为LVS应用延伸的高可用服务。lvs的调度器无法做高可用。但keepalive不是为lvs专门集群服务的,也可以为其他的的代理服务器做高可用。 keepalive在lvs的高可用集群,主调度器和备调度器(可以有多个) 一主两备或一主一备。 VRRP: k…

Windows下枚举USB设备信息Demo

目录 1 简介 1.1 设备接口类 1.2 枚举设备信息原理 2 SetupDi系列函数介绍 2.1 SetupDiGetClassDevs 2.2 SetupDiEnumDeviceInfo 2.3 SetupDiGetDeviceRegistryProperty 2.4 SetupDiGetDeviceRegistryProperty 3 演示Demo 3.1 开发环境 3.2 功能介绍 3.3 下载地址 …

70 爬楼梯

解题思路一:(动态规划) \qquad 假设F(n)返回的是爬n阶的所有方法个数,由题可知,每次可以爬1-2级台阶,那么可以得到: \qquad \qquad \qquad \qquad \qquad F(n) F(n - 1) F(n - 2) \qquad 我…

WeTab AI桌面端的下载安装

wetab AI的使用很方便,收费也不高,专业版的最新版本的AI核心配置如下: 现在推出了桌面端,下载链接:桌面端下载链接 在下载页面点击windows(Beta版): 下载并安装,桌面上就…

DRF组件讲解

DRF组件 1. Web应用模式 在开发Web应用中,有两种应用模式: 前后端不分离[客户端看到的内容和所有界面效果都是由服务端提供出来的。 前后端分离【把前端的界面效果(html,css,js分离到另一个服务端,python服务端只需…

LLM agentic模式之工具使用: Toolformer、CoA、MM-React思路

Toolformer Toolformer出自2023年2月Meta上传的论文《Toolformer: Language Models Can Teach Themselves to Use Tools》,它提出了一种通过自监督训练的方式来让模型决定调哪个API什么时候调用。 API调用的表示:为了让模型去能够调用API,将…

实现随机地牢与摄像机追随与拖拽

//author bilibili 民用级脑的研发记录 // 开发环境 小熊猫c 2.25.1 raylib 版本 4.5 // 2024-7-14 // AABB 碰撞检测 在拖拽,绘制,放大缩小中 // 2024-7-20 // 直线改每帧打印一个点,生长的直线,直线炮弹 // 2024-8-4 // 实现敌…

JavaScript高级程序设计 -- -- 观后记录

一、什么是 JavaScript 1、JavaScript 实现 完整的 JavaScript 实现包含以下几个部分: -- --  核心(ECMAScript)  文档对象模型(DOM)  浏览器对象模型(BOM) 2、DOM 文档对象模型&#…

橙色简洁大气体育直播自适应模板赛事直播门户自适应网站源码

源码名称:酷黑简洁大气体育直播自适应模板赛事直播门户网站 源码开发环境:帝国cms 7.5 安装环境:phpmysql 带采集,可以挂着电脑上自动采集发布,无需人工操作! 橙色简洁大气体育直播自适应模板赛事直播门户…

广州必看自闭症康复机构十大排名名单出炉

在众多为自闭症儿童提供帮助的机构中,星贝育园以其卓越的服务和显著的成效脱颖而出,成功跻身广州必看自闭症康复机构十大排名。 星贝育园在广州、浙江拥有三个校区,为更多的自闭症儿童和家庭带来了希望。这里的特教老师和生活老师不辞辛劳&a…

一次现网redis CPU使用率异常定位

背景 618大促前,运维对系统做巡检时发现redis cpu利用率白天基本保持在72%左右,夜里也在60%以上。担心618流量比平时大,导致redis超负荷,因此找开发进行优化,降低redis的负载。 定位思路 其实资源使用率过高定位都…

大数据技术—— Clickhouse安装

目录 第一章 ClickHouse入门 1.1 ClickHouse的特点 1.1.1 列式存储 1.1.2 DBMS的功能 1.1.3 多样化引擎 1.1.4 高吞吐写入能力 1.1.5 数据分区与线程级并行 1.1.6 性能对比 第二章 ClickHouse的安装 2.1 准备工作 2.1.1 确定防火墙处于关闭状态 2.1.2 CentOS取消…

Vue UI - 可视化的Vue项目管理器

概述 Vue CLI 3.0 更新后,提供了一套全新的可视化Vue项目管理器 —— Vue UI。所以要想使用它,你的 Vue CL I版本必须要在v3.0以上。 一、启动Vue UI 1.1 环境准备 1.1.1 安装node.js 访问官网(外网下载速度较慢)或 http://nod…

民航管理局无人机运营合格证技术详解

1. 证书定义与意义 民航管理局无人机运营合格证(以下简称“合格证”)是对符合民航法规、规章及标准要求的无人机运营单位或个人进行资质认证的重要证明。该证书旨在确保无人机运营活动的安全、有序进行,保护国家空域安全,维护公众…

电子电气架构 --- 软件定义汽车需要怎么样的EE架构

我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和事,多看一眼都是你的不…

反射型XSS的几种payload

目录 第一种&#xff1a;采用的是urlcode编码 第二种&#xff1a;前面用html实体编码&#xff0c;后面用urlcode编码 第三种&#xff1a;只对&#xff1a;使用urlcode编码 第四种&#xff1a;对<>进行html实体编码 第五种&#xff1a;textarea 第六种&#xff1a;和…