写在最前面之如何只用nn.Linear实现nn.Conv2d的功能
很多人说,Swin-Transformer就是另一种Convolution,但是解释得真就是一坨shit,这里我郑重解释一下,这是为什么?
首先,Convolution是什么?
Convolution是一种矩形区域内参数共享的Linear
这么说可能不好理解,那么我们上代码
import torch
import torch.nn as nn
import torch.nn.functional as F
class Conv2D(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
"""
为了简单且便于理解,我们设定图片的Size是Kernel_size的整数倍,且Kernel_size等于Stride
"""
super(LinearConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.stride = stride
# 计算权重矩阵的维度
weight_size = in_channels * kernel_size * kernel_size
self.linear = nn.Linear(weight_size, out_channels, bias=False)
def forward(self, x):
# 计算输出特征图的尺寸
B, C, H, W = x.size()
output_height = H // self.stride
output_width = W // self.stride
# 展开输入特征,沿着kernel_size的窗口展开
x_flatten = x.view(B, H // self.kernel_size, self.kernel_size, W // self.kernel_size, self.kernel_size, C)
x_flatten = x_flatten.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.kernel_size, self.kernel_size, C)
# 应用线性变换
output_flatten = self.linear(x_flatten)
# 重塑输出形状
output = output_flatten.view(B, self.out_channels, output_height, output_width)
return output
# 使用nn.Linear实现nn.Conv2d(256, 256, k=7, s=7)
conv2d_manual = Conv2D(256, 256, 7, 7)
# 创建一个随机初始化的输入张量,确保尺寸是7的整数倍
input_tensor = torch.randn(1, 256, 56, 56) # 假设输入图像大小为56x56,56是7的倍数
# 应用卷积操作
output = conv2d_manual(input_tensor)
# 输出形状应为[1, 256, 8, 8]
print(output.shape)
上述代码通过了使用输入数据的维度变换,实现了利用nn.Linear来进行nn.Conv2d的过程,当然,nn.Conv1d甚至nn.Conv3d等也是同样操作。这里我们先记住,后面我们详细解释
Swin-Transformer为什么这么叫
首先,需要理解为什么叫Swin!
作者依然使用了Vision Transformer的主题架构,核心区别是对数据处理的区别!
在Vision Transformer中,数据根据spatial维度进行拉伸,并成为[Batch, HW, C]的样子,如图所示,具体参考Transformer之Vision Transformer结构解读
而在Swin-Transformer中,额外增加了一步,就是把维度为
[
B
a
t
c
h
,
H
×
W
,
C
]
[Batch, H\times W, C]
[Batch,H×W,C]的patch_embedding,进行二次分割,变成
[
B
a
t
c
h
×
n
u
m
_
w
i
n
d
o
w
2
,
w
i
n
d
o
w
_
s
i
z
e
,
w
i
n
d
o
w
_
s
i
z
e
,
C
]
[Batch \times num\_window^2, window\_size, window\_size, C]
[Batch×num_window2,window_size,window_size,C],如图所示,
- 第一张图片就是经过patch_embed的patch_embedding
- 第二张图片就是经过window_partrition分割后的图片
- 第三张图片就是处理成
[
B
a
t
c
h
×
n
u
m
_
w
i
n
d
o
w
2
,
w
i
n
d
o
w
_
s
i
z
e
,
w
i
n
d
o
w
_
s
i
z
e
,
C
]
[Batch \times num\_window^2, window\_size, window\_size, C]
[Batch×num_window2,window_size,window_size,C]的图片
这里还有一个操作,就是在第偶数个Attention-Block中,把输入的patch_embedding进行torch.roll操作,这个操作就是循环位移
这时候就可以解释为什么说Swin-Transformer就是另一种形式的CNN了
从上面的图片中可以看到如下过程: - 一张图片,经过nn.Conv2d(k=patch_size, stride=patch_size),将其分割成 N 2 N^2 N2个patch_embedding
- patch_embedding经过维度重整,从 [ B , H × W , C ] [B, H\times W, C] [B,H×W,C]变成 [ B a t c h × n u m _ w i n d o w 2 , w i n d o w _ s i z e , w i n d o w _ s i z e , C ] [Batch \times num\_window^2, window\_size, window\_size, C] [Batch×num_window2,window_size,window_size,C],然后送入nn.Linear()。这里的维度重整加上nn.Linear(),等于nn.Conv2d,可以通过写在最前面的"如何只用nn.Linear()实现nn.Conv2d的功能"看出
- 上一步可以总结为:经过nn.Conv2d的patch_embedding继续经过若干nn.Conv2d
Swin-Transformer的位置编码
绝对位置编码
详情参考Transformer之位置编码的通俗理解
在patch_embedding过程中,依然将Token和PE相加,如上图二所示。
但是既然有了相对位置编码,为什么还要加上绝对位置编码呢?
- 数学解释如下:
Q
E
+
P
E
×
K
E
+
P
E
T
=
X
E
+
P
E
×
W
q
×
[
X
E
+
P
E
×
W
k
]
T
=
X
E
+
P
E
×
W
q
×
W
k
T
×
X
E
+
P
E
T
=
(
X
q
+
P
E
q
)
×
W
q
×
W
k
T
×
(
X
k
+
P
E
k
)
T
=
X
q
×
W
q
⏞
Q
u
e
r
y
×
W
k
T
×
X
k
T
⏞
K
e
y
⏟
第一项
+
P
E
q
×
W
q
⏞
a
×
W
k
T
×
X
k
T
⏞
K
e
y
⏟
第二项
+
X
q
×
W
q
⏞
Q
u
e
r
y
×
W
k
T
×
P
E
k
T
⏞
b
⏟
第三项
+
P
E
q
×
W
q
⏞
a
×
W
k
T
×
P
E
k
T
⏞
b
⏟
第四项
\begin{array}{ccl} Q_{E+PE} \times K_{E+PE}^T &= & X_{E + PE} \times W_q \times \Big[X_{E + PE} \times W_k \Big]^T \\ && \\ &= & X_{E + PE} \times W_q \times W_k^T \times X^T_{E + PE} \\ && \\ & = &(X_q+PE_q) \times W_q \times W_k^T \times (X_k+PE_k)^T \\ &&\\ &= &\underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第一项}+ \underbrace{ \overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times X_k^T}^{Key}}_{第二项} + \underbrace{\overbrace{X_q \times W_q}^{Query} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第三项} + \underbrace{\overbrace{PE_q \times W_q}^{a} \times \overbrace{W_k^T \times PE^T_k}^{b}}_{第四项} \end{array}
QE+PE×KE+PET====XE+PE×Wq×[XE+PE×Wk]TXE+PE×Wq×WkT×XE+PET(Xq+PEq)×Wq×WkT×(Xk+PEk)T第一项
Xq×Wq
Query×WkT×XkT
Key+第二项
PEq×Wq
a×WkT×XkT
Key+第三项
Xq×Wq
Query×WkT×PEkT
b+第四项
PEq×Wq
a×WkT×PEkT
b
绝对位置编码只能消去第三项和第四项中的d项,依然需要第二项中的a项,才能具有完整的偏置
- 直觉解释如下
如果只有相对位置编码,也就是相当于只有相对位置偏置,这个过程和只有绝对位置偏置的意义是相同的,所以只有同时具有相对位置编码和绝对位置编码,才能避免两者是等效的
相对位置编码
详情参考Transformer之位置编码的通俗理解
相对位置编码,实际上是Attention机制的偏置的位置编码:
A
t
t
=
s
o
f
t
m
a
x
(
Q
×
K
T
D
i
m
+
r
e
l
a
t
i
v
e
_
p
o
s
i
t
i
o
n
_
b
i
a
s
)
×
V
Att = softmax\Big( \frac{Q \times K^T}{\sqrt{Dim}} + relative\_position\_bias\Big) \times V
Att=softmax(DimQ×KT+relative_position_bias)×V
这里受到CSDN图片尺寸的限制,只能发这种清晰度的,点击这里下载无损svg