self.cls_token
在 Vision Transformer (ViT) 模型中,在训练阶段和推理阶段的行为和作用是不同的,而且它的值在训练过程中会发生变化。
1. self.cls_token
的作用
在 ViT 中,self.cls_token
是一个特殊的、可学习的嵌入向量(embedding vector),它被添加到输入序列(图像patch的embedding序列)的开头。这个 cls_token
的主要目的是在经过 Transformer Encoder 的多层自注意力计算后,其对应的输出向量能够聚合整个输入序列的信息,用于最终的分类任务。
可以把 cls_token
理解为一个“班长”的角色。每个图像块(patch)是一个“学生”。一开始,“班长”(cls_token
)和“学生”(patches)互相不认识(都是随机初始化的)。在 Transformer 的每一层,“班长”都会和每个“学生”交流(自注意力机制),同时“学生”之间也互相交流。经过多层交流后,“班长”就逐渐了解了整个班级的情况(图像的全局信息)。最后,我们只用“班长”的输出来做分类。
2. 训练阶段
-
随机初始化:在模型初始化时,
self.cls_token
是一个形状为(1, 1, embed_dim)
的张量,其中的值通常是从某个分布(如正态分布)中随机采样的。这意味着在训练开始时,cls_token
没有任何关于图像的先验信息。 -
可学习参数:
self.cls_token
被定义为nn.Parameter
,这意味着它是一个模型的可学习参数。在训练过程中,它会随着其他模型参数一起通过反向传播和梯度下降进行更新。 -
与输入交互:在每个训练批次中,
self.cls_token
会被复制并与每个输入图像的patch embeddings进行拼接(concatenate),形成 Transformer Encoder 的输入序列。# 假设 x 是图像patch embeddings, 形状为 (batch_size, num_patches, embed_dim) cls_token = self.cls_token.expand(x.shape[0], -1, -1) # 扩展到与 batch_size 匹配 x = torch.cat((cls_token, x), dim=1) # 拼接
-
信息聚合:在 Transformer Encoder 的每一层,
cls_token
对应的embedding都会与其他patch embeddings进行自注意力计算。通过这种方式,cls_token
逐渐“学习”到如何聚合来自所有patch的信息。 -
参数更新:在反向传播过程中,
cls_token
的梯度会根据分类损失进行计算,并通过优化器进行更新。这意味着cls_token
的值会不断调整,以更好地捕捉图像的全局特征。
3. 推理阶段
-
固定值:在推理阶段,模型的所有参数(包括
self.cls_token
)都是固定的,不再进行更新。cls_token
使用的是训练结束时学习到的值。 -
相同操作:与训练阶段类似,
self.cls_token
仍然会被复制并与输入图像的patch embeddings进行拼接,作为 Transformer Encoder 的输入。 -
信息提取:经过 Transformer Encoder 的处理后,
cls_token
对应的输出向量被用作分类器的输入,进行最终的类别预测。
4. 总结
特性 | 训练阶段 | 推理阶段 |
---|---|---|
值 | 随机初始化,通过反向传播更新 | 固定(使用训练结束时学习到的值) |
是否可学习 | 是 (nn.Parameter ) | 否 |
作用 | 与patch embeddings交互,聚合全局信息,参与梯度更新 | 与patch embeddings交互,提取全局信息,用于分类 |
5. 代码示例 (简化)
import torch
import torch.nn as nn
class VisionTransformer(nn.Module):
def __init__(self, embed_dim=768, ...):
super().__init__()
# ... 其他层 ...
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 可学习参数
# ... 其他层 ...
def forward(self, x):
# ... patch embedding ...
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # 复制cls_token
x = torch.cat((cls_token, x), dim=1) # 拼接
# ... Transformer Encoder ...
x = x[:, 0] # 取cls_token对应的输出
# ... 分类器 ...
return x
因此,self.cls_token
在训练阶段是随机初始化的可学习参数,通过与图像patch embeddings的交互和反向传播不断更新;在推理阶段,self.cls_token
的值是固定的,它利用训练中学到的知识来提取图像的全局特征,用于分类。
这种设计使得 ViT 能够有效地处理图像数据,并在各种视觉任务中取得了出色的性能。