动态卷积(轻量级卷积)替代多头自注意力

news2025/1/11 2:27:20

        动态卷积,它比自注意力更简单、更有效。我们仅基于当前时间步长预测单独的卷积核,以确定上下文元素的重要性。这种方法所需的操作数量随输入长度呈线性增长,而自注意力是二次的。在大规模机器翻译、语言建模和抽象摘要上的实验表明,动态卷积比强自注意模型有更好的改进。

1. 引言

        RNN通过在每个时间步更新一个隐藏状态来整合上下文信息,CNN通过多层总结固定大小的上下文,而自关注则直接总结所有上下文。注意力为上下文元素分配注意权重,该权重定义了上下文表示的加权和。源-目标注意从另一个序列(如机器翻译)中总结信息,而自注意力则对当前序列进行操作。自注意力被表述为基于内容的,其中通过比较当前时间步与上下文中的所有元素来计算注意权重。在这种不受限制的背景大小上计算比较的能力被视为自注意力的一个关键特征。

        由于输入长度的二次复杂度,无限上下文大小在计算上非常具有挑战性。此外,在实践中,长序列需要引入层次结构。

        本文中介绍了一种轻量级卷积,它们是深度可分离的,经过softmax归一化,并在通道维度上共享权重。与标准的不可分离卷积相比,轻量级卷积的权重数量减少了几个数量级。不同于自注意力机制,轻量级卷积在处理上下文元素时,无论当前时间步长如何,都重复使用相同的权重。

        动态卷积在轻量级卷积的基础上进行改进,通过在每个时间步长预测不同的卷积核来实现。与自注意力机制需要考虑整个上下文不同,动态卷积的卷积核仅是当前时间步长的函数。动态卷积类似于局部连接层,后者在每个位置的权重会发生变化,但不同之处在于动态卷积的权重是由模型动态生成的,而不是训练后固定的。与基于位置的注意力相似,后者不需要访问上下文来确定注意力权重,但我们并不直接考虑前一个时间步的注意力权重。

        实验表明,轻量级卷积在性能上能够与强大的自注意力机制结果相媲美,而动态卷积的表现甚至更好。

 2. 背景

        序列到序列学习:序列到序列学习是一种将源序列映射到目标序列的方法,通常通过两个独立的网络实现,这在机器翻译等任务中尤为常见。在这个过程中,编码器(Encoder)网络负责计算源序列(如英文句子)的表示,而解码器(Decoder)网络则基于编码器的输出,通过自回归的方式生成目标序列。这种框架允许模型在处理诸如语言翻译、文本摘要等任务时,能够捕捉到源序列和目标序列之间的复杂关系。

        自注意力机制:自注意力机制由Vaswani等人提出,并在Transformer模型中得到了广泛应用。它通过对输入X(X\in R^{n\times d},其中n是时间步的数量,d是输入/输出维度)进行三个投影操作,得到键(K)、查询(Q)和值(V)三种表示。自注意力机制的核心在于它能够同时关注输入序列中的不同位置,通过计算键和查询之间的点积,并对结果进行缩放和softmax归一化,从而得到每个位置的注意力权重。最后,这些权重被用于计算值的加权和,从而得到考虑了序列中所有位置信息的输出表示。自注意力机制通过多个头(Heads)的并行处理,能够捕捉到输入序列中的多种不同特征,并提高了模型处理长序列的能力。

        深度卷积:在每个通道上独立地执行卷积。参数的数量可以从d^2k减少到dk,其中 k 是核宽度。对元素 i 和输出维度 c 进行权值为W\in R^{d\times k}的深度卷积的输出O\in R^{n\times d}定义为:

O_{i,c}=\text{DepthwiseConv}(X,W_{c,:},i,c)=\sum_{j=1}^{k}W_{c,j}\cdot X_{(i+j-\lceil\frac{k+1}{2}\rceil),c}

        对于d个输入通道,我们只需要d * k个参数(每个通道一个卷积核,每个卷积核宽度为k),而不是传统卷积中的d^2 * k个参数 

3. 轻量级卷积(LightConv)

        LightConv,这是一种深度卷积,它共享某些输出通道,其权重在时间维度上使用softmax进行归一化。与自注意力相比,LightConv有一个固定的上下文窗口,它用一组不随时间变化的权重来确定上下文元素的重要性。

LightConv为序列和输出通道 c 中的第 i 个元素计算以下内容:

\text{LightConv}(X,W_{\lceil\frac{cH}{d}\rceil,:},i,c)=\text{DepthwiseConv}(X,\text{softmax}(W_{\lceil\frac{cH}{d}\rceil,:}),i,c) 

W_{\lceil\frac{cH}{d}\rceil,:}        \lceil \frac{cH}{d} \rceil:表示将通道数 c 和特征图高度 H 按比例 d 缩减后向上取整得到的索引。

                        ,:表示选择该索引处的所有列,即取出某个特定位置的整行权重。

i:当前时间步或图像块的索引

c:当前通道索引

3.1 权重共享

        将每个后续的 \frac{d}{H} 通道(即共享的通道数)的参数捆绑在一起,这将参数的数量减少了 \frac{d}{H} 倍数。例如,对于d = 1024和k = 7,一个正则卷积需要7,340,032 (d^2\times k)个权重,一个深度可分离卷积有7,168个权重(d\times k),而对于H = 16的权重共享,只有112个(H\times k)个权重。参数数量的大量减少对于在当前硬件上实现动态卷积至关重要。Wang & Ji提出在所有通道之间共享权重(即H=1)。 

        使用softmax操作将权重W\in R^{H\times K}在卷积核宽度 k 上归一化: 

\mathrm{softmax}(W)_{h,j}=\frac{\exp W_{h,j}}{\sum_{j'=1}^k\exp W_{h,j'}}

        图2b显示了LightConv的模块架构:首先应用从d维到2d维的输入投影映射,然后是门控线性单元(GLU),以及实际的轻量级卷积。GLU通过应用s型单元使用一半的输入作为门,然后计算与其他输入的点积。将大小为W^O\in R^{d\times d}的输出投影应用于LightConv的输出。 

        DropConnect是一种有效的正则化方法(Wan et al, 2013)。具体来说,DropConnect以概率p随机丢弃softmax(W)中每个归一化权重的条目,并在训练过程中将剩余的权重除以1-p以保持权重的总体比例。这种操作实质上是在每个通道内去除部分时间信息,从而防止模型过度拟合训练数据。通过减少模型对特定时间信息的依赖,DropConnect增强了模型的泛化能力,使其能够在未见过的数据上表现更好。

        在尝试实现LightConv时,现有的CUDA卷积原语并不适合用于处理短序列,性能表现不佳。因此,我们采用了一种更快的解决方案来处理短序列。首先,我们将归一化的权重W(形状为H×k)复制并扩展到一个大小为 BH×n×n 的带状矩阵中,其中B是批次大小,n是序列长度(或卷积的输出维度)。接着,将输入数据重新整形并转置为 BH×n×d_H 的形状,其中d_H是输入通道数或特征维度。然后,使用批量矩阵乘法来计算输出。

4. 动态卷积

        动态卷积具有随时间变化的核,作为单个时间步长的学习函数。标准卷积的动态版本对于当前的gpu来说是不切实际的,因为它们需要大量的内存。通过建立LightConv来解决这个问题,大大减少了参数的数量。 

        DynamicConv采用与LightConv相同的形式,但使用时间步长相关的内核,该内核使用函数f:\mathbb{R}^d\to\mathbb{R}^{H\times k}: \mathbf{DynamicConv}(X,i,c)=\mathbf{LightConv}(X,f(X_i)_{h,:},i,c)

        具体来说,DynamicConv 在时间步 i 和通道 c 的情况下等同于使用 f(X_i) 生成的权重的 LightConv。 用一个简单的线性模块来建模函数 f,该模块具有学习到的权重 W^Q\in\mathbb{R}^{H\times k\times d}。具体而言,函数 f(X_i) 的计算方式是 f(X_i)=\sum_{c=1}^{d}W_{h,j,c}^{Q}X_{i,c}​,即通过线性组合输入特征 X_i​ 生成卷积核。

        与自注意力类似,DynamicConv会随时间改变分配给上下文元素的权重。然而,DynamicConv的权重并不依赖于整个上下文,它们只是当前时间步长的函数。自注意力需要在句子长度上进行二次运算来计算注意权值,而DynamicConv的动态核计算在序列长度上呈线性缩放。

5. 模型架构

        在序列到序列学习(sequence to sequence learning)任务中使用的一种编码器-解码器(encoder-decoder)架构。具体来说,这个架构在很大程度上遵循了Transformer模型的设计,同时在一些模块中引入了LightConv和DynamicConv来替代传统的自注意力模块。下面是对这个架构的详细解释:

1. 基本架构

        编码器-解码器架构:该模型使用了编码器-解码器架构,用于序列到序列学习任务。

        Transformer Big:该架构的自注意力基础模型是fairseq重新实现的Transformer Big。

2. 编码器和解码器

        编码器和解码器网络:每个编码器和解码器网络都有N个块(blocks)。

        N的值:通常设定为7(为了大致匹配Transformer Big的参数数量)。

编码器块

        第一个子块:可以是自注意力模块、LightConv模块或DynamicConv模块。

        第二个子块:前馈神经网络模块(Feed-forward module)。

        形式为:ReLU(W^1X+b_1)W^2+b_2\text{ where }W^1\in\mathbb{R}^{d\times d_{ff}},W^2\in\mathbb{R}^{d_{ff}\times d}

        参数:d = 1024,d_{ff}= 4096。

        残差连接和层归一化:每个子块都包裹在残差连接和层归一化中。

解码器块

        结构与编码器块类似,但有一个额外的源-目标注意力子块,位于自注意力和前馈模块之间。

        源-目标注意力:等同于自注意力模块,但值和值键是对每个源词的编码器输出的投影。

3. 嵌入与位置编码

        词嵌入:词嵌入维度为d。

        位置编码:使用正弦位置编码来表示序列中每个词的绝对位置。

4. 词汇分布计算

        词汇分布:通过线性层(权重为W^V\in R^{d\times V})和softmax归一化,将解码器输出转化为词汇V的分布。

5. LightConv和DynamicConv的设置

        替换自注意力模块:LightConv和DynamicConv与Transformer Big相同,只是将自注意力模块替换为固定或动态卷积。

        参数更少:每个块的参数更少,因此增加了编码器块的数量(N = 7)。

        卷积核大小:编码器和解码器的卷积核大小分别设为3, 7, 15, 31;解码器中只有前三层使用31的卷积核大小。

        H的值:一般设置为16。

6. 代码示例

6.1 轻量级卷积 

import torch
import torch.nn as nn
import torch.nn.functional as F

class LightConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(LightConv, self).__init__()
        self.depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, groups=in_channels, padding=kernel_size//2)
        self.pointwise_conv = nn.Conv1d(in_channels, out_channels, 1)
        
    def forward(self, x):
        # 对深度卷积的权重进行softmax归一化
        weight = F.softmax(self.depthwise_conv.weight, dim=-1)
        x = F.conv1d(x, weight, bias=self.depthwise_conv.bias, padding=self.depthwise_conv.padding, groups=self.depthwise_conv.groups)
        
        # 进行逐点卷积
        x = self.pointwise_conv(x)
        
        return x

# 参数设置
in_channels = 3
out_channels = 5
kernel_size = 3

# 创建 LightConv 模块
light_conv = LightConv(in_channels, out_channels, kernel_size)

# 创建一个示例输入
input_tensor = torch.randn(2, in_channels, 10)
output_tensor = light_conv(input_tensor)

print("LightConv output shape:", output_tensor.shape)

6.2 动态卷积

import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size):
        super(DynamicConv, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        
        # 深度卷积层
        self.depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size, groups=in_channels, padding=kernel_size//2)
        
        # 用于动态生成卷积核权重的线性层
        self.dynamic_weight_generator = nn.Linear(in_channels, in_channels * kernel_size)
        
        # 点卷积层
        self.pointwise_conv = nn.Conv1d(in_channels, out_channels, 1)
        
    def forward(self, x):
        # 动态生成权重
        batch_size, channels, length = x.size()
        dynamic_weight = self.dynamic_weight_generator(x.transpose(1, 2))
        dynamic_weight = dynamic_weight.view(batch_size, channels, self.kernel_size, length)
        dynamic_weight = F.softmax(dynamic_weight, dim=2)
        
        # 使用动态权重进行深度卷积
        x = x.transpose(1, 2)
        output = torch.zeros_like(x)
        for i in range(length):
            output[:, :, i] = F.conv1d(x[:, :, max(0, i - self.kernel_size // 2):i + self.kernel_size // 2 + 1],
                                       dynamic_weight[:, :, :, i].transpose(1, 2), groups=channels)
        x = output.transpose(1, 2)
        
        # 进行逐点卷积
        x = self.pointwise_conv(x)
        
        return x

# 参数设置
in_channels = 3
out_channels = 5
kernel_size = 3

# 创建 DynamicConv 模块
dynamic_conv = DynamicConv(in_channels, out_channels, kernel_size)

# 创建一个示例输入
input_tensor = torch.randn(2, in_channels, 10)
output_tensor = dynamic_conv(input_tensor)

print("DynamicConv output shape:", output_tensor.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1964661.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【论文阅读笔记 + 思考 + 总结】MoMask: Generative Masked Modeling of 3D Human Motions

创新点: VQ-VAE 👉 Residual VQ-VAE,对每个 motion sequence 输出一组 base motion tokens 和 v 组 residual motion tokensbidirectional 的 Masked transformer 用来生成 base motion tokensResidual Transformer 对 residual motion toke…

机器学习 | 分类算法原理——似然函数

Hi,大家好,我是半亩花海。接着上次的逻辑回归继续更新《白话机器学习的数学》这本书的学习笔记,在此分享似然函数这一分类算法原理。本章的分类算法原理基于《基于图像大小进行分类》项目,欢迎大家交流学习! 目录 一、…

个性化你的生产力工具:待办事项App定制指南

国内外主流的10款待办事项软件对比:PingCode、Worktile、滴答清单、番茄ToDo、Teambition、Todoist、Microsoft To Do、TickTick、Any.do、Trello。 在寻找合适的待办事项软件时,你是否感到选择众多、难以决断?一个好的待办事项工具可以大大提…

stl-algorithm【1】

#include《algorithm》 交换两数swap(x,y) 不只可以交换两个“数”(数据类型) 翻转【借助迭代器】reverse(it1,it2) 仍是左闭右开

国产开源夜莺部署

使用二进制方式部署夜莺 - 快猫星云 (flashcat.cloud) # install mysql yum -y install mariadb* systemctl enable mariadb systemctl restart mariadb mysql -e "SET PASSWORD FOR rootlocalhost PASSWORD(1234);"# install redis yum install -y redis systemctl…

navicat 17 下载安装

百度网盘 通过网盘分享的文件:Navicat17 链接: https://pan.baidu.com/s/1nFFQzWhjxRUM_X6bVlWNGw?pwd8888 提取码: 8888 1.双击运行安装包 2.点击下一步 2.勾选我同意,点击下一步 3.自定义安装路径,点击下一步 4.注意勾选桌面快捷方式&a…

编程新手指南:从入门到精通

编程小白如何成为大神?大学新生的最佳入门攻略 编程已成为当代大学生的必备技能,但面对众多编程语言和学习资源,新生们常常感到迷茫。如何选择适合自己的编程语言?如何制定有效的学习计划?如何避免常见的学习陷阱&…

基于YOLOv8的高压输电线路异物检测系统

基于YOLOv8的高压输电线路异物检测系统 (价格88) 包含 【“鸟窝”,“风筝”,“气球”,“垃圾”】 4个类 通过PYQT构建UI界面,包含图片检测,视频检测,摄像头实时检测。 (该系统可以根据数…

众人帮蚂蚁帮任务平台修复版源码,含搭建教程。

全修复运营版本的任务平台,支持垂直领域细分,定向导流,带有排行榜功能,任务发布上传审核,用户信用等级,充值接口等等均完美可用。支付对接Z支付免签接口,环境配置及安装教程都已经打包。 搭建环…

ARM学习(31)编译器对overlay方式的支持

ARM学习(31)编译器对overlay方式的支持 1、overlay介绍 overlay:重叠得意思,就是可以重复利用得空间,一般在内存上使用这种空间。比如以Windows操作系统为例,其存储空间(ROM/FLASH)…

springboot垂钓服务系统-计算机毕业设计源码17434

摘要 本文旨在针对垂钓爱好者的需求,基于微信小程序平台,设计并实现一套垂钓服务系统。首先,通过对用户需求进行调研和分析,确定了系统的基本功能模块,包括垂钓点信息展示、用户预约和支付、钓具租赁信息等。接着&…

WebView加载数据的几种方式

之前客户端加载H5时遇到了一些问题,我为了方便解决问题,所以将对应场景复刻到了Demo中,从之前的网络加载模拟为了本地加载Html的方式,但是没想到无意被一个基础知识点卡了一些时间,翻看往昔笔记发现未曾记录这种基础场…

【MATLAB源码】机器视觉与图像识别技术(7)续---BP神经网络

系列文章目录在最后面,各位同仁感兴趣可以看看! BP神经网络 第一节、BP网络定义第二节、BP网络结构及其特点第三节、信息传播方式 信息的正向传播:实质是计算网络的输出误差的反向传播:实质是学习过程第四节、 BP网络的算法流程…

python:plotly 网页交互式数据可视化工具

pip install plotly plotly-5.22.0-py3-none-any.whl pip install plotly_express 包含:GDP数据、餐厅的订单流水数据、鸢尾花 Iris数据集 等等 pip show plotly Name: plotly Version: 5.22.0 Summary: An open-source, interactive data visualization librar…

每日OJ_牛客HJ60 查找组成一个偶数最接近的两个素数

目录 牛客HJ60 查找组成一个偶数最接近的两个素数 解析代码 牛客HJ60 查找组成一个偶数最接近的两个素数 查找组成一个偶数最接近的两个素数_牛客题霸_牛客网 解析代码 首先需要判断素数,素数表示除过1和本身,不能被其它数整除。通过循环遍历来判断一…

飞致云开源社区月度动态报告(2024年7月)

自2023年6月起,中国领先的开源软件公司FIT2CLOUD飞致云以月度为单位发布《飞致云开源社区月度动态报告》,旨在向广大社区用户同步飞致云旗下系列开源软件的发展情况,以及当月主要的产品新版本发布、社区运营成果等相关信息。 飞致云开源大屏…

pycharm怎么使用Anaconda和配置

打开Anaconda Prompt 要删除 Conda 环境 yolov5sconda,你可以使用以下命令: conda remove --name yolov5sconda --all这个命令会删除名为 yolov5sconda 的整个环境,包括其中安装的所有包和依赖项。请在命令提示符或终端中运行此命令。执行此…

Java线程池的设计与使用

Java线程池的设计与使用 多线程情景引入 情景分析 请求积压的情况 系统资源受限: 当大量用户请求同时到来时,服务器受限于内存、CPU、和网络带宽等资源,导致用户长时间等待。后端处理能力限制: 如频率限制措施(每秒或每几秒的访问限制&…

嵌入式day15

数组指针 能够指向整个数组 一维数组: &a,考察a的数据类型 int(*p)[10]:表示一个指向长度为10的一维整型数组的指针 二维数组: 指向函数的指针 函数的函数名,即为函数的入口地址&#x…

亲测推荐!PixPin便捷高效,让你的截图工作轻松搞定,还在等什么?

前言 如果你经常使用电脑,是不是也经常遇到这样的烦恼:需要频繁地截图、标注、编辑图片,可是手里的截图工具却总是那么不给力?要么功能单一,要么操作复杂,让人头疼不已;今天咱们的小江湖就要给大…