【深度学习中的注意力机制9】11种主流注意力机制112个创新研究paper+代码——滑动窗口注意力(Sliding Window Attention)
【深度学习中的注意力机制9】11种主流注意力机制112个创新研究paper+代码——滑动窗口注意力(Sliding Window Attention)
文章目录
- 【深度学习中的注意力机制9】11种主流注意力机制112个创新研究paper+代码——滑动窗口注意力(Sliding Window Attention)
- 1. 滑动窗口注意力的起源与提出
- 2. 滑动窗口注意力的原理
- 3. 滑动窗口注意力的数学表示
- 4. 滑动窗口注意力的发展
- 5. 代码实现
- 6. 代码解释
- 7. 总结
欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz
1. 滑动窗口注意力的起源与提出
滑动窗口注意力(Sliding Window Attention)是针对大型图像和高维特征数据中注意力机制的计算复杂度问题提出的一种解决方案。标准的全局注意力机制需要对每个位置和其他所有位置计算注意力权重,其计算复杂度是 O ( N 2 ) O(N^2) O(N2),这里 N N N 是输入的特征长度。对于大尺寸图像,尤其是在自然语言处理和图像任务中的高分辨率场景,这种复杂度是非常高的。
滑动窗口注意力提出的目标是降低计算复杂度的同时保留局部相关性。它的基本思想是,不再在全局范围内计算每个位置的注意力,而是在局部窗口中进行计算,即将图像或序列分割成多个小的区域(滑动窗口),每个窗口仅计算其内部元素的注意力。
该机制在视觉Transformer(Vision Transformer,ViT)以及类似的变体模型(如Swin Transformer)中被广泛应用。滑动窗口注意力机制在这些模型中能够有效降低全局注意力带来的高计算成本,并保留对局部信息的有效建模。
2. 滑动窗口注意力的原理
滑动窗口注意力的核心思路是将输入特征(如图像)划分为多个局部窗口,每个窗口内的特征计算注意力权重,并在这些窗口间进行滑动,逐步覆盖整个输入的所有特征。这样既能避免全局注意力的高昂计算开销,又能保证局部信息的充分捕获。
滑动窗口注意力主要包括以下几个步骤:
- 窗口划分:输入特征按照固定大小的窗口进行划分,例如对于二维图像,可以划分为多个大小为 M × M M×M M×M 的子块。
- 局部注意力计算:在每个窗口内,应用标准的自注意力机制,即计算窗口内部所有位置的注意力权重。
- 滑动窗口:为避免窗口之间的隔离,可以通过滑动窗口的方法,让不同窗口的边界部分重叠,从而实现跨窗口的信息交流。这一步类似于卷积操作中的滑动窗口机制。
- 结果合并:最后,将所有窗口计算后的结果合并为完整的输出特征。
滑动窗口注意力的优势在于:
- 计算效率高:每个窗口的计算限制在局部范围,避免了全局注意力的二次方复杂度。
- 局部信息捕获:能够充分捕捉局部区域内的重要信息,尤其适合处理图像等局部相关性强的任务。
- 可扩展性强:可以通过调整窗口大小或滑动步长来适应不同的任务需求。
3. 滑动窗口注意力的数学表示
假设输入的特征图为 X ∈ R H × W × C X∈R^{H×W×C} X∈RH×W×C,其中 H H H 为高度, W W W 为宽度, C C C 为通道数。
- 窗口划分: 将输入划分为多个大小为 M × M M×M M×M 的窗口。假设窗口的大小为 M × M M×M M×M,则划分后会得到 H M × W M \frac{H}{M}×\frac{W}{M} MH×MW个窗口,每个窗口的特征为 M × M × C M×M×C M×M×C。
- 局部注意力计算: 对于每个窗口
W
i
W_i
Wi ,计算其内部特征的自注意力。标准的自注意力机制可以表示为:
其中,Q、K、V 分别是查询、键和值矩阵, d k d_k dk是键向量的维度。 - 滑动窗口: 在完成局部窗口的自注意力计算后,通过滑动窗口将不同窗口间的信息交互。滑动步长可以为窗口大小的一部分,例如 M / 2 M/2 M/2,从而使得相邻窗口有重叠部分。
- 合并输出: 将所有窗口的计算结果合并,恢复原始输入的尺寸。
4. 滑动窗口注意力的发展
滑动窗口注意力作为局部注意力机制的一个重要变体,在ViT模型的改进模型中得到了广泛应用。最具代表性的是Swin Transformer,它通过滑动窗口机制和层次化结构,实现了对高分辨率图像的高效处理。Swin Transformer不仅在视觉任务中表现出色,还为后续的图像分割、目标检测等任务提供了有效的解决方案。
5. 代码实现
下面是基于PyTorch的滑动窗口注意力的简单实现代码示例,展示如何在局部窗口内计算注意力并通过滑动窗口机制实现信息交互。
import torch
import torch.nn as nn
class WindowAttention(nn.Module):
def __init__(self, dim, window_size, num_heads):
super(WindowAttention, self).__init__()
self.dim = dim # 输入特征的维度
self.window_size = window_size # 窗口大小
self.num_heads = num_heads # 注意力头的数量
# 定义Query、Key、Value的线性投影层
self.qkv = nn.Linear(dim, dim * 3, bias=False)
self.attn_drop = nn.Dropout(0.1) # 注意力dropout
self.proj = nn.Linear(dim, dim) # 输出的线性投影层
self.proj_drop = nn.Dropout(0.1) # 输出的dropout
self.softmax = nn.Softmax(dim=-1) # softmax函数
def forward(self, x):
# 输入x的形状: [batch_size * num_windows, window_size * window_size, dim]
B_, N, C = x.shape
# 1. 生成Q, K, V矩阵
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] # 分别获取query, key, value
# 2. 计算注意力得分
attn = (q @ k.transpose(-2, -1)) * (1.0 / (C // self.num_heads) ** 0.5)
attn = self.softmax(attn) # 通过softmax获得注意力权重
attn = self.attn_drop(attn) # 加入dropout防止过拟合
# 3. 计算加权后的值
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
# 4. 输出投影
x = self.proj(x)
x = self.proj_drop(x)
return x
# 定义滑动窗口Attention机制
class SwinAttention(nn.Module):
def __init__(self, dim, window_size=7, num_heads=8):
super(SwinAttention, self).__init__()
self.window_size = window_size # 窗口大小
self.attn = WindowAttention(dim, window_size, num_heads) # 局部窗口的注意力
def forward(self, x, H, W):
B, L, C = x.shape
assert L == H * W, "输入特征的大小与给定的高度和宽度不匹配"
# 将输入特征重新reshape为二维图像的形状
x = x.view(B, H, W, C)
# 1. 将特征划分为多个窗口
pad_l = pad_t = 0
pad_r = (self.window_size - W % self.window_size) % self.window_size
pad_b = (self.window_size - H % self.window_size) % self.window_size
x = nn.functional.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
# 2. 划分窗口
x_windows = x.unfold(1, self.window_size, self.window_size).unfold(2, self.window_size, self.window_size)
x_windows = x_windows.contiguous().view(-1, self.window_size * self.window_size, C)
# 3. 对每个窗口应用注意力机制
attn_windows = self.attn(x_windows)
# 4. 恢复窗口后的特征图
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
x = attn_windows.permute(0, 3, 1, 2).contiguous()
return x.view(B, H, W, C)
# 测试滑动窗口注意力机制
if __name__ == "__main__":
B, H, W, C = 2, 32, 32, 96 # batch_size, height, width, channels
x = torch.randn(B, H * W, C)
swin_attn = SwinAttention(dim=C, window_size=8, num_heads=4)
output = swin_attn(x, H, W)
print("输出尺寸:", output.shape)
6. 代码解释
WindowAttention
: 实现了局部窗口内的注意力机制。定义了qkv
线性投影层,生成 Query、Key 和 Value矩阵,并通过自注意力计算局部窗口内的特征交互。SwinAttention
: 实现了滑动窗口的注意力机制。通过unfold
函数将输入特征划分为多个局部窗口,并在每个窗口内应用注意力机制。最后将所有窗口的计算结果合并。- 在测试部分,随机生成了一个大小为 32 × 32 32×32 32×32 的输入特征图,验证了滑动窗口注意力机制的正确性。
7. 总结
滑动窗口注意力有效解决了全局自注意力机制的计算瓶颈问题,特别适用于处理大尺寸图像等高维数据。通过局部窗口的划分和滑动窗口操作,它在保证计算效率的同时,也能有效捕捉局部特征信息。在Swin Transformer等视觉模型中,滑动窗口注意力机制表现出了极高的应用价值。
欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
祝所有的硕博生都能遇到好的导师!好的审稿人!好的同门!顺利毕业!
大多数高校硕博生毕业要求需要参加学术会议,发表EI或者SCI检索的学术论文会议论文:
可访问艾思科蓝官网,浏览即将召开的学术会议列表。会议入口:https://ais.cn/u/mmmiUz