参考
导师!博主的复现太细了。做个记录。
层神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解
计算机视觉中的transformer模型创新思路总结_Tom Hardy的博客-CSDN博
Vision Transformer详解
ViT
前处理
网络结构
整体思想
目标检测DETR(2020.5)-->分类ViT(2020.10)-->分割SETR(2020.12)-->Swin Transformer(2021.3)-->
transformer用于计算机视觉领域的难点在于序列太长,前面的工作有使用CNN提取特征后再transformer的操作,有在小窗口里操作自注意力,也有对图像长宽分别使用自注意力的机制,都差强人意。自注意力完全取代卷积在CV领域之前也有应用,但直接将transformer用在视觉领域的网络没有出现。
transformer缺少归纳偏置(卷积核的局部性和卷积操作的互不干扰性,滑动平移,这些归纳偏置使CNN有一定的先验信息),需要更大的训练量。
ViT只有分割图像块和位置编码使用了一些图像特有的归纳偏置,这都是尽可能证明NLP领域标准的transformer可以胜任视觉任务。
具体结构
特征提取
(224, 224, 3)-->(14, 14, 768)/(196, 768)/(197, 768)-->(197, 768)
1.Patch
步长16的16*16卷积/高宽维度的平铺+Cls Token(1, 768)
Cls Token会一起进行特征提取。
2.Position Embedding
为所有特征添加上位置信息,这样网络才有区分不同区域的能力。
nn.Parameter()生成可学习的张量(196, 768)与上面的Cls Token cat后得到(197, 768)。再与1得到的张量相加。
3.Transformer Encoder
(1)说明
1)上面的L表示要将Transformer块叠加多少个。
2)attention内部qk相乘后,全连接后会设置dropout。atention外部,mlp外部也会设置dropout,这里的dropout是将输入的特征图像素值全部置于0,并且随着层数的叠加,置0的概率越低(推测这里的操作是破坏掉网络的拟合效果,防止过拟合,这个破坏概率是很低的,可以研究研究源码)。进入encoder之前也设置了dropout。
3)序列长度仅为3,每个单位序列的特征长度仅为3,在VIT的Transformer Encoder中,序列长度为197,每个单位序列的特征长度为768 // num_heads。
(2)内部执行细节
(3)具体模块
Norm
nn.LayerNorm
Multi-Head Attention
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
self.scale = (dim // num_heads) ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
# batchsize, 197, 768
B, N, C = x.shape
# 通过全连接层扩充维度为3倍,再将维度拆分为num_head份:3(qkv), batchsize, 12(nums_head), 197(patch), 64(768//12)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# 分配:batchsize, 12(nums_head), 197(patch), 64(768//12)
q, k, v = qkv[0], qkv[1], qkv[2]
# q,k矩阵相乘
attn = (q @ k.transpose(-2, -1)) * self.scale
# softmax求每个元素在每个行上的占比是多少
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
# 得到的attn与v矩阵相乘
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
# 进入线性层
x = self.proj(x)
x = self.proj_drop(x)
return x
MLP
2个nn.Linear(),中间的激活函数用GELU
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
drop_probs = (drop, drop)
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.fc2(x)
x = self.drop2(x)
return x
add
注意到中间两条连接线了没有,残差连接。
(一个)patch与所有patch点乘,计算重要程度,再将这个patch反馈的重要程度与所有patch点乘,得到(一个)patch。将(一个)换成所有就是整个自注意力过程。
总结:每个patch处拥有了其他patch相对于该patch的加权和。
分类
(1)说明
(197, 768)-->(, 768)
到这Cls Token要被搞出来了,前面提到的,Cls Token拥有与其他所有patch交互的信息。做全连接就行了啊。完事。
(2)内部执行细节
(3)具体模块