self.cls_token在 Vision Transformer (ViT) 模型中的训练阶段和推理阶段的行为和作用的异同

news2025/3/14 10:05:00

self.cls_token 在 Vision Transformer (ViT) 模型中,在训练阶段和推理阶段的行为和作用是不同的,而且它的值在训练过程中会发生变化。

1. self.cls_token 的作用

在 ViT 中,self.cls_token 是一个特殊的、可学习的嵌入向量(embedding vector),它被添加到输入序列(图像patch的embedding序列)的开头。这个 cls_token 的主要目的是在经过 Transformer Encoder 的多层自注意力计算后,其对应的输出向量能够聚合整个输入序列的信息,用于最终的分类任务。

可以把 cls_token 理解为一个“班长”的角色。每个图像块(patch)是一个“学生”。一开始,“班长”(cls_token)和“学生”(patches)互相不认识(都是随机初始化的)。在 Transformer 的每一层,“班长”都会和每个“学生”交流(自注意力机制),同时“学生”之间也互相交流。经过多层交流后,“班长”就逐渐了解了整个班级的情况(图像的全局信息)。最后,我们只用“班长”的输出来做分类。

2. 训练阶段

  1. 随机初始化:在模型初始化时,self.cls_token 是一个形状为 (1, 1, embed_dim) 的张量,其中的值通常是从某个分布(如正态分布)中随机采样的。这意味着在训练开始时,cls_token 没有任何关于图像的先验信息。

  2. 可学习参数self.cls_token 被定义为 nn.Parameter,这意味着它是一个模型的可学习参数。在训练过程中,它会随着其他模型参数一起通过反向传播和梯度下降进行更新。

  3. 与输入交互:在每个训练批次中,self.cls_token 会被复制并与每个输入图像的patch embeddings进行拼接(concatenate),形成 Transformer Encoder 的输入序列。

    # 假设 x 是图像patch embeddings, 形状为 (batch_size, num_patches, embed_dim)
    cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # 扩展到与 batch_size 匹配
    x = torch.cat((cls_token, x), dim=1)  # 拼接
    
  4. 信息聚合:在 Transformer Encoder 的每一层,cls_token 对应的embedding都会与其他patch embeddings进行自注意力计算。通过这种方式,cls_token 逐渐“学习”到如何聚合来自所有patch的信息。

  5. 参数更新:在反向传播过程中,cls_token 的梯度会根据分类损失进行计算,并通过优化器进行更新。这意味着 cls_token 的值会不断调整,以更好地捕捉图像的全局特征。

3. 推理阶段

  1. 固定值:在推理阶段,模型的所有参数(包括 self.cls_token)都是固定的,不再进行更新。cls_token 使用的是训练结束时学习到的值。

  2. 相同操作:与训练阶段类似,self.cls_token 仍然会被复制并与输入图像的patch embeddings进行拼接,作为 Transformer Encoder 的输入。

  3. 信息提取:经过 Transformer Encoder 的处理后,cls_token 对应的输出向量被用作分类器的输入,进行最终的类别预测。

4. 总结

特性训练阶段推理阶段
随机初始化,通过反向传播更新固定(使用训练结束时学习到的值)
是否可学习是 (nn.Parameter)
作用与patch embeddings交互,聚合全局信息,参与梯度更新与patch embeddings交互,提取全局信息,用于分类

5. 代码示例 (简化)

import torch
import torch.nn as nn

class VisionTransformer(nn.Module):
    def __init__(self, embed_dim=768, ...):
        super().__init__()
        # ... 其他层 ...
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # 可学习参数
        # ... 其他层 ...

    def forward(self, x):
        # ... patch embedding ...
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # 复制cls_token
        x = torch.cat((cls_token, x), dim=1)  # 拼接
        # ... Transformer Encoder ...
        x = x[:, 0]  # 取cls_token对应的输出
        # ... 分类器 ...
        return x

因此,self.cls_token 在训练阶段是随机初始化的可学习参数,通过与图像patch embeddings的交互和反向传播不断更新;在推理阶段,self.cls_token 的值是固定的,它利用训练中学到的知识来提取图像的全局特征,用于分类。
这种设计使得 ViT 能够有效地处理图像数据,并在各种视觉任务中取得了出色的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2314799.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

247g 的工业级电调,如何让无人机飞得更 “聪明“?——STONE 200A-M 深度测评

一、轻量化设计背后的技术取舍 当拿到 STONE 200A-M 时,247g 的重量让人意外 —— 这个接近传统 200A 电调 70% 的重量,源自 1205624.5mm 的紧凑结构(0.1mm 公差控制)。实测装机显示,相比同规格产品,其体积…

Node.js:快速启动你的第一个Web服务器

Node.js 全面入门指南 文章目录 Node.js 全面入门指南一 安装Node.js1. Windows2. MacOS/Linux 二 配置开发环境1. VSCode集成 三 第一个Node.js程序1. 创建你的第一个Node.js程序 四 使用Express框架1. 快速搭建服务器 一 安装Node.js 1. Windows 以下是Windows环境下Node.j…

自定义日志回调函数实现第三方库日志集成:从理论到实战

一、应用场景与痛点分析 在开发过程中,我们经常会遇到以下场景: 日志格式统一:第三方库使用自己的日志格式,导致系统日志混杂,难以统一管理和分析。日志分级过滤:需要动态调整第三方库的日志输出级别&…

Linux练级宝典->任务管理和守护进程

任务管理 进程组概念 每个进程除了进程ID以外,还有一个进程组,进程组就是一个或多个进程的集合 同一个进程组,代表着他们是共同作业的,可以接收同一个终端的各种信号,进程组也有其唯一的进程组号。还有一个组长进程&a…

C语言:计算并输出三个整数的最大值 并对三个数排序

这是《C语言程序设计》73页的思考题。下面分享自己的思路和代码 思路&#xff1a; 代码&#xff1a; #include <stdio.h> int main() {int a,b,c,max,min,mid ; //设置大中小的数分别为max&#xff0c;mid&#xff0c;min&#xff0c;abc为输入的三个数printf("ple…

工具(十二):Java导出MySQL数据库表结构信息到excel

一、背景 遇到需求&#xff1a;将指定数据库表设计&#xff0c;统一导出到一个Excel中&#xff0c;存档查看。 如果一个一个弄&#xff0c;很复杂&#xff0c;耗时长。 二、写一个工具导出下 废话少絮&#xff0c;上码&#xff1a; 2.1 pom导入 <dependency><grou…

ACL初级总结

ACL–访问控制列表 1.访问控制 在路由器流量流入或者流出的接口上,匹配流量,然后执行相应动作 permit允许 deny拒绝 2.抓取感兴趣流 3.ACL匹配规则 自上而下逐一匹配,若匹配到了则按照对应规则执行动作,而不再向下继续匹配 思科:ACL列表末尾隐含一条拒绝所有的规则 华为:AC…

调优案例一:堆空间扩容提升吞吐量实战记录

&#x1f4dd; 调优案例一&#xff1a;堆空间扩容提升吞吐量实战记录 &#x1f527; 调优策略&#xff1a;堆空间扩容三部曲 # 原配置&#xff08;30MB堆空间&#xff09; export CATALINA_OPTS"$CATALINA_OPTS -Xms30m -Xmx30m"# 新配置&#xff08;扩容至120MB&am…

C语言 —— 此去经年梦浪荡魂音 - 深入理解指针(卷一)

目录 1. 内存和地址 2. 指针变量和地址 2.1 取地址操作符&#xff08;&&#xff09; 2.2 指针变量 2.3 解引用操作符 &#xff08;*&#xff09; 3. 指针的解引用 3.1 指针 - 整数 3.2 void* 指针 4. const修饰指针 4.1 const修饰变量 4.2 const修饰指针变量 5…

计算机毕业设计:留守儿童的可视化界面

留守儿童的可视化界面mysql数据库创建语句留守儿童的可视化界面oracle数据库创建语句留守儿童的可视化界面sqlserver数据库创建语句留守儿童的可视化界面springspringMVChibernate框架对象(javaBean,pojo)设计留守儿童的可视化界面springspringMVCmybatis框架对象(javaBean,poj…

golang算法二叉树对称平衡右视图

100. 相同的树 给你两棵二叉树的根节点 p 和 q &#xff0c;编写一个函数来检验这两棵树是否相同。 如果两个树在结构上相同&#xff0c;并且节点具有相同的值&#xff0c;则认为它们是相同的。 示例 1&#xff1a; 输入&#xff1a;p [1,2,3], q [1,2,3] 输出&#xff1a…

Chatbox通过百炼调用DeepSeek

解决方案链接&#xff1a;评测&#xff5c;零门槛&#xff0c;即刻拥有DeepSeek-R1满血版 方案概览 本方案以 DeepSeek-R1 满血版为例进行演示&#xff0c;通过百炼模型服务进行 DeepSeek 开源模型调用&#xff0c;可以根据实际需求选择其他参数规模的 DeepSeek 模型。百炼平台…

【数据结构】6栈

0 章节 3&#xff0e;1到3&#xff0e;3小节。 认知与理解栈结构&#xff1b; 列举栈的操作特点。 理解并列举栈的应用案例。 重点 栈的特点与实现&#xff1b; 难点 栈的灵活实现与应用 作业或思考题 完成学习测试&#xff12;&#xff0c;&#xff1f; 内容达成以下标准(考核…

PyTorch 入门学习

目录 PyTorch 定义 核心作用 应用场景 Pytorch 基本语法 1. 张量的创建 2. 张量的类型转换 3. 张量数值计算 4. 张量运算函数 5. 张量索引操作 6. 张量形状操作 7. 张量拼接操作 8. 自动微分模块 9. 案例-线性回归案例 PyTorch 定义 PyTorch 是一个基于 Python 深…

mov格式视频如何转换mp4?

mov格式视频如何转换mp4&#xff1f;在日常的视频处理中&#xff0c;经常需要将MOV格式的视频转换为MP4格式&#xff0c;以兼容更多的播放设备和平台。下面给大家分享如何将MOV视频转换为MP4&#xff0c;4款视频格式转换工具分享。 一、牛学长转码大师 牛学长转码大师是一款功…

二进制求和(js实现,LeetCode:67)

这道题我的解决思路是先将a和b的长度保持一致以方便后续按位加减 let lena a.length let lenb b.length if (lena ! lenb) {if (lena > lenb) {for (let i 0; i <lena-lenb; i) {b 0 b}} else {for (let i 0; i < lenb-lena; i) {a 0 a}} } 下一步直接进行按…

【C#】使用DeepSeek帮助评估数据库性能问题,C# 使用定时任务,每隔一分钟移除一次表,再重新创建表,和往新创建的表追加5万多条记录

&#x1f339;欢迎来到《小5讲堂》&#x1f339; &#x1f339;这是《C#》系列文章&#xff0c;每篇文章将以博主理解的角度展开讲解。&#x1f339; &#x1f339;温馨提示&#xff1a;博主能力有限&#xff0c;理解水平有限&#xff0c;若有不对之处望指正&#xff01;&#…

【openGauss】物理备份恢复

文章目录 1. gs_backup&#xff08;1&#xff09;备份&#xff08;2&#xff09;恢复&#xff08;3&#xff09;手动恢复的办法 2. gs_basebackup&#xff08;1&#xff09;备份&#xff08;2&#xff09;恢复① 伪造数据目录丢失② 恢复 3. gs_probackup&#xff08;1&#xf…

蓝桥杯备赛-基础练习 day1

1、闰年判断 问题描述 给定一个年份&#xff0c;判断这一年是不是闰年。 当以下情况之一满足时&#xff0c;这一年是闰年:1.年份是4的倍数而不是100的倍数 2&#xff0e;年份是400的倍数。 其他的年份都不是闰年。 输入格式 输入包含一个…

实验四 Python聚类决策树训练与预测 基于神经网络的MNIST手写体识别

一、实验目的 Python聚类决策树训练与预测&#xff1a; 1、掌握决策树的基本原理并理解监督学习的基本思想。 2、掌握Python实现决策树的方法。 基于神经网络的MNIST手写体识别&#xff1a; 1、学习导入和使用Tensorflow。 2、理解学习神经网络的基本原理。 3、学习使用…