Background: Transformers in computer vision
Paper: https://arxiv.org/abs/2010.11929
PyTorch implementation: pytorch_classification/vision_transformer (code by the blogger 太阳花的小绿豆)
Researchers have been exploring both hybrid CNN+Transformer designs and pure-Transformer approaches to vision tasks.
The paper's abstract states: "We show that this reliance on CNNs is not necessary and a pure transformer applied directly to sequences of image patches can perform very well on image classification tasks."
For a more detailed walkthrough, see the post《Vision Transformer详解》by 太阳花的小绿豆 (the related figures are from that blogger). Here I use the ONNX model structure to explain things, which may be more intuitive (if that's not your thing, feel free to skip it).
ViT overall architecture diagram
ViT shape changes
Generated with the torchsummary API: summary(model, (3, 224, 224))
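A minimal sketch of how the table below could be reproduced, assuming vit_base_patch16_224 from the pytorch_classification/vision_transformer repo mentioned above, with a 5-class head matching the flower dataset used at the end:

```python
from torchsummary import summary

# Assumption: vit_model.py is the ViT implementation from
# pytorch_classification/vision_transformer (by 太阳花的小绿豆).
from vit_model import vit_base_patch16_224

model = vit_base_patch16_224(num_classes=5)  # ViT-Base/16, 5 output classes
summary(model, (3, 224, 224), device="cpu")  # prints the table below
```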
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
(1) Pre-processing
Conv2d-1 [-1, 768, 14, 14] 590,592
Identity-2 [-1, 196, 768] 0
PatchEmbed-3 [-1, 196, 768] 0
Dropout-4 [-1, 197, 768] 0
(2) Transformer encoder
block 1
LayerNorm-5 [-1, 197, 768] 1,536
Linear-6 [-1, 197, 2304] 1,771,776
Dropout-7 [-1, 12, 197, 197] 0
Linear-8 [-1, 197, 768] 590,592
Dropout-9 [-1, 197, 768] 0
Attention-10 [-1, 197, 768] 0
Identity-11 [-1, 197, 768] 0
LayerNorm-12 [-1, 197, 768] 1,536
Linear-13 [-1, 197, 3072] 2,362,368
GELU-14 [-1, 197, 3072] 0
Dropout-15 [-1, 197, 3072] 0
Linear-16 [-1, 197, 768] 2,360,064
Dropout-17 [-1, 197, 768] 0
Mlp-18 [-1, 197, 768] 0
Identity-19 [-1, 197, 768] 0
Block-20 [-1, 197, 768] 0
block 2
LayerNorm-21 [-1, 197, 768] 1,536
Linear-22 [-1, 197, 2304] 1,771,776
Dropout-23 [-1, 12, 197, 197] 0
Linear-24 [-1, 197, 768] 590,592
Dropout-25 [-1, 197, 768] 0
Attention-26 [-1, 197, 768] 0
Identity-27 [-1, 197, 768] 0
LayerNorm-28 [-1, 197, 768] 1,536
Linear-29 [-1, 197, 3072] 2,362,368
GELU-30 [-1, 197, 3072] 0
Dropout-31 [-1, 197, 3072] 0
Linear-32 [-1, 197, 768] 2,360,064
Dropout-33 [-1, 197, 768] 0
Mlp-34 [-1, 197, 768] 0
Identity-35 [-1, 197, 768] 0
Block-36 [-1, 197, 768] 0
block 3
LayerNorm-37 [-1, 197, 768] 1,536
Linear-38 [-1, 197, 2304] 1,771,776
Dropout-39 [-1, 12, 197, 197] 0
Linear-40 [-1, 197, 768] 590,592
Dropout-41 [-1, 197, 768] 0
Attention-42 [-1, 197, 768] 0
Identity-43 [-1, 197, 768] 0
LayerNorm-44 [-1, 197, 768] 1,536
Linear-45 [-1, 197, 3072] 2,362,368
GELU-46 [-1, 197, 3072] 0
Dropout-47 [-1, 197, 3072] 0
Linear-48 [-1, 197, 768] 2,360,064
Dropout-49 [-1, 197, 768] 0
Mlp-50 [-1, 197, 768] 0
Identity-51 [-1, 197, 768] 0
Block-52 [-1, 197, 768] 0
block 4
LayerNorm-53 [-1, 197, 768] 1,536
Linear-54 [-1, 197, 2304] 1,771,776
Dropout-55 [-1, 12, 197, 197] 0
Linear-56 [-1, 197, 768] 590,592
Dropout-57 [-1, 197, 768] 0
Attention-58 [-1, 197, 768] 0
Identity-59 [-1, 197, 768] 0
LayerNorm-60 [-1, 197, 768] 1,536
Linear-61 [-1, 197, 3072] 2,362,368
GELU-62 [-1, 197, 3072] 0
Dropout-63 [-1, 197, 3072] 0
Linear-64 [-1, 197, 768] 2,360,064
Dropout-65 [-1, 197, 768] 0
Mlp-66 [-1, 197, 768] 0
Identity-67 [-1, 197, 768] 0
Block-68 [-1, 197, 768] 0
block 5
LayerNorm-69 [-1, 197, 768] 1,536
Linear-70 [-1, 197, 2304] 1,771,776
Dropout-71 [-1, 12, 197, 197] 0
Linear-72 [-1, 197, 768] 590,592
Dropout-73 [-1, 197, 768] 0
Attention-74 [-1, 197, 768] 0
Identity-75 [-1, 197, 768] 0
LayerNorm-76 [-1, 197, 768] 1,536
Linear-77 [-1, 197, 3072] 2,362,368
GELU-78 [-1, 197, 3072] 0
Dropout-79 [-1, 197, 3072] 0
Linear-80 [-1, 197, 768] 2,360,064
Dropout-81 [-1, 197, 768] 0
Mlp-82 [-1, 197, 768] 0
Identity-83 [-1, 197, 768] 0
Block-84 [-1, 197, 768] 0
block 6
LayerNorm-85 [-1, 197, 768] 1,536
Linear-86 [-1, 197, 2304] 1,771,776
Dropout-87 [-1, 12, 197, 197] 0
Linear-88 [-1, 197, 768] 590,592
Dropout-89 [-1, 197, 768] 0
Attention-90 [-1, 197, 768] 0
Identity-91 [-1, 197, 768] 0
LayerNorm-92 [-1, 197, 768] 1,536
Linear-93 [-1, 197, 3072] 2,362,368
GELU-94 [-1, 197, 3072] 0
Dropout-95 [-1, 197, 3072] 0
Linear-96 [-1, 197, 768] 2,360,064
Dropout-97 [-1, 197, 768] 0
Mlp-98 [-1, 197, 768] 0
Identity-99 [-1, 197, 768] 0
Block-100 [-1, 197, 768] 0
block 7
LayerNorm-101 [-1, 197, 768] 1,536
Linear-102 [-1, 197, 2304] 1,771,776
Dropout-103 [-1, 12, 197, 197] 0
Linear-104 [-1, 197, 768] 590,592
Dropout-105 [-1, 197, 768] 0
Attention-106 [-1, 197, 768] 0
Identity-107 [-1, 197, 768] 0
LayerNorm-108 [-1, 197, 768] 1,536
Linear-109 [-1, 197, 3072] 2,362,368
GELU-110 [-1, 197, 3072] 0
Dropout-111 [-1, 197, 3072] 0
Linear-112 [-1, 197, 768] 2,360,064
Dropout-113 [-1, 197, 768] 0
Mlp-114 [-1, 197, 768] 0
Identity-115 [-1, 197, 768] 0
Block-116 [-1, 197, 768] 0
block 8
LayerNorm-117 [-1, 197, 768] 1,536
Linear-118 [-1, 197, 2304] 1,771,776
Dropout-119 [-1, 12, 197, 197] 0
Linear-120 [-1, 197, 768] 590,592
Dropout-121 [-1, 197, 768] 0
Attention-122 [-1, 197, 768] 0
Identity-123 [-1, 197, 768] 0
LayerNorm-124 [-1, 197, 768] 1,536
Linear-125 [-1, 197, 3072] 2,362,368
GELU-126 [-1, 197, 3072] 0
Dropout-127 [-1, 197, 3072] 0
Linear-128 [-1, 197, 768] 2,360,064
Dropout-129 [-1, 197, 768] 0
Mlp-130 [-1, 197, 768] 0
Identity-131 [-1, 197, 768] 0
Block-132 [-1, 197, 768] 0
block 9
LayerNorm-133 [-1, 197, 768] 1,536
Linear-134 [-1, 197, 2304] 1,771,776
Dropout-135 [-1, 12, 197, 197] 0
Linear-136 [-1, 197, 768] 590,592
Dropout-137 [-1, 197, 768] 0
Attention-138 [-1, 197, 768] 0
Identity-139 [-1, 197, 768] 0
LayerNorm-140 [-1, 197, 768] 1,536
Linear-141 [-1, 197, 3072] 2,362,368
GELU-142 [-1, 197, 3072] 0
Dropout-143 [-1, 197, 3072] 0
Linear-144 [-1, 197, 768] 2,360,064
Dropout-145 [-1, 197, 768] 0
Mlp-146 [-1, 197, 768] 0
Identity-147 [-1, 197, 768] 0
Block-148 [-1, 197, 768] 0
block 10
LayerNorm-149 [-1, 197, 768] 1,536
Linear-150 [-1, 197, 2304] 1,771,776
Dropout-151 [-1, 12, 197, 197] 0
Linear-152 [-1, 197, 768] 590,592
Dropout-153 [-1, 197, 768] 0
Attention-154 [-1, 197, 768] 0
Identity-155 [-1, 197, 768] 0
LayerNorm-156 [-1, 197, 768] 1,536
Linear-157 [-1, 197, 3072] 2,362,368
GELU-158 [-1, 197, 3072] 0
Dropout-159 [-1, 197, 3072] 0
Linear-160 [-1, 197, 768] 2,360,064
Dropout-161 [-1, 197, 768] 0
Mlp-162 [-1, 197, 768] 0
Identity-163 [-1, 197, 768] 0
Block-164 [-1, 197, 768] 0
block 11
LayerNorm-165 [-1, 197, 768] 1,536
Linear-166 [-1, 197, 2304] 1,771,776
Dropout-167 [-1, 12, 197, 197] 0
Linear-168 [-1, 197, 768] 590,592
Dropout-169 [-1, 197, 768] 0
Attention-170 [-1, 197, 768] 0
Identity-171 [-1, 197, 768] 0
LayerNorm-172 [-1, 197, 768] 1,536
Linear-173 [-1, 197, 3072] 2,362,368
GELU-174 [-1, 197, 3072] 0
Dropout-175 [-1, 197, 3072] 0
Linear-176 [-1, 197, 768] 2,360,064
Dropout-177 [-1, 197, 768] 0
Mlp-178 [-1, 197, 768] 0
Identity-179 [-1, 197, 768] 0
Block-180 [-1, 197, 768] 0
block 12
LayerNorm-181 [-1, 197, 768] 1,536
Linear-182 [-1, 197, 2304] 1,771,776
Dropout-183 [-1, 12, 197, 197] 0
Linear-184 [-1, 197, 768] 590,592
Dropout-185 [-1, 197, 768] 0
Attention-186 [-1, 197, 768] 0
Identity-187 [-1, 197, 768] 0
LayerNorm-188 [-1, 197, 768] 1,536
Linear-189 [-1, 197, 3072] 2,362,368
GELU-190 [-1, 197, 3072] 0
Dropout-191 [-1, 197, 3072] 0
Linear-192 [-1, 197, 768] 2,360,064
Dropout-193 [-1, 197, 768] 0
Mlp-194 [-1, 197, 768] 0
Identity-195 [-1, 197, 768] 0
Block-196 [-1, 197, 768] 0
(3) Post-processing
LayerNorm-197 [-1, 197, 768] 1,536
Identity-198 [-1, 768] 0
Linear-199 [-1, 5] 3,845
================================================================
Total params: 85,650,437
Trainable params: 85,650,437
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 408.54
Params size (MB): 326.73
Estimated Total Size (MB): 735.84
----------------------------------------------------------------
3. Data pre-processing
A 3×224×224 image passes through a convolution with 768 filters of size 16×16 (stride 16), giving a 768×14×14 output.
Flatten the spatial dimensions of the output: 768×196 (14×14 = 196).
Transpose the token and channel dimensions: 196×768.
Prepend a learnable class token (1×768, which carries the classification information) and concatenate it with the 196×768 patch tokens to get 197×768.
Add the positional embedding pos (an element-wise add; the shape stays 197×768). A minimal sketch of these steps follows.
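A PyTorch sketch of the pre-processing pipeline under the shapes listed above (module and parameter names are illustrative, not necessarily the blogger's exact code):

```python
import torch
import torch.nn as nn

class PatchEmbedSketch(nn.Module):
    """Patch embedding + class token + positional embedding for ViT-Base/16."""
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2             # 14 * 14 = 196
        # 768 filters of 16x16 with stride 16: [B,3,224,224] -> [B,768,14,14]
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

    def forward(self, x):
        x = self.proj(x)                                  # [B,768,14,14]
        x = x.flatten(2)                                  # [B,768,196]
        x = x.transpose(1, 2)                             # [B,196,768]
        cls = self.cls_token.expand(x.shape[0], -1, -1)   # [B,1,768]
        x = torch.cat((cls, x), dim=1)                    # [B,197,768]
        return x + self.pos_embed                         # add pos, still [B,197,768]

print(PatchEmbedSketch()(torch.randn(1, 3, 224, 224)).shape)  # [1, 197, 768]
```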
4. ONNX graph of the data entering the transformer encoder
Honestly, the ONNX ops here are a bit hard to handle: the LayerNorm layer gets decomposed into a rather complex subgraph, and I haven't looked into operator fusion yet (I will study it later; no time for now). In any case, this is the transformer encoder. One block of it, repeated twelve times, looks like the listing below (a PyTorch sketch follows the listing).
LayerNorm-5 [-1, 197, 768]
Linear-6 [-1, 197, 2304]
Dropout-7 [-1, 12, 197, 197]
Linear-8 [-1, 197, 768]
Dropout-9 [-1, 197, 768]
Attention-10 [-1, 197, 768]
Identity-11 [-1, 197, 768]
LayerNorm-12 [-1, 197, 768]
Linear-13 [-1, 197, 3072]
GELU-14 [-1, 197, 3072]
Dropout-15 [-1, 197, 3072]
Linear-16 [-1, 197, 768]
Dropout-17 [-1, 197, 768]
Mlp-18 [-1, 197, 768]
Identity-19 [-1, 197, 768]
Block-20 [-1, 197, 768]
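A sketch of one encoder block matching the shapes above: pre-norm attention, then an MLP, each wrapped in a residual connection. Note that nn.MultiheadAttention stands in here for the blogger's hand-written Attention module, so the 768→2304 qkv projection (Linear-6 above) lives inside it; the Dropout layers are omitted for brevity:

```python
import torch
import torch.nn as nn

class BlockSketch(nn.Module):
    """One ViT encoder block: x + Attn(LN(x)), then x + MLP(LN(x))."""
    def __init__(self, dim=768, num_heads=12, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden = int(dim * mlp_ratio)                      # 3072
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))

    def forward(self, x):
        y = self.norm1(x)
        x = x + self.attn(y, y, y, need_weights=False)[0]  # residual 1
        x = x + self.mlp(self.norm2(x))                    # residual 2
        return x                                           # [B,197,768]

print(BlockSketch()(torch.randn(1, 197, 768)).shape)       # [1, 197, 768]
```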
5. Post-processing
LayerNorm-197 [-1, 197, 768]
Identity-198 [-1, 768]
Linear-199 [-1, 5]
So there you have it: that tangle of primitive ops is what a single LayerNorm becomes in ONNX (I'll spare you my complaints).
The class token is then taken out of the 197 tokens (the Identity-198 step above, 197×768 → 768) and fed through a fully connected layer to produce the classification output, as sketched below.
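A sketch of this post-processing, matching the three layers above (the 3,845 parameters of Linear-199 are exactly 768×5 + 5):

```python
import torch
import torch.nn as nn

norm = nn.LayerNorm(768)          # LayerNorm-197
head = nn.Linear(768, 5)          # Linear-199: 768*5 + 5 = 3,845 params

x = torch.randn(1, 197, 768)      # encoder output
x = norm(x)                       # [1,197,768]
cls = x[:, 0]                     # keep only the class token: [1,768] (Identity-198)
logits = head(cls)                # [1,5]
print(torch.softmax(logits, dim=-1))
```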
Summary
Looking at the ops, ViT introduces no new operators at all; it is just a stack of existing layers, yet it performs remarkably well. A plain, unadorned structure doing profound work. Keep learning; there is no end to it! If enough readers are interested, I can upload the related ONNX code later for everyone to try; please follow or leave a comment (⊙o⊙)! Here is an example prediction on the five-class flower dataset:
class: daisy prob: 0.995
class: dandelion prob: 0.00298
class: roses prob: 0.000599
class: sunflowers prob: 0.000633
class: tulips prob: 0.000771
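For reference, a hypothetical inference script that would print output in this format. The weight and image paths are placeholders, and the 0.5 mean/std normalization is an assumption borrowed from the repo's predict script:

```python
import torch
from PIL import Image
from torchvision import transforms
from vit_model import vit_base_patch16_224   # same assumption as above

classes = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

model = vit_base_patch16_224(num_classes=5)
model.load_state_dict(torch.load("model.pth", map_location="cpu"))  # placeholder path
model.eval()

img = tf(Image.open("daisy.jpg")).unsqueeze(0)     # placeholder image, [1,3,224,224]
with torch.no_grad():
    probs = torch.softmax(model(img)[0], dim=-1)
for name, p in zip(classes, probs):
    print(f"class: {name:<12} prob: {p.item():.3}")
```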