Vision Transformer(vit)的主干

news2025/1/22 19:20:29

图解:

代码:

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,
                 qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,
                 attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size
#输入图像的大小,通常是 224 或其他标准尺寸
            patch_size (int, tuple): patch size
#每个块(patch)的大小,例如 16x16
            in_c (int): number of input channels
#输入图像的通道数,RGB 图像是 3
            num_classes (int): number of classes for classification head
#最终分类的类别数,默认 1000 类
            embed_dim (int): embedding dimension
#嵌入维度,即每个 patch 被映射到的向量的维度,默认是 768
            depth (int): depth of transformer
#Transformer 的深度,即堆叠的块(Block)数量。
            num_heads (int): number of attention heads
#注意力头的数量,默认设为 12
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
# MLP 隐藏层的维度与嵌入维度的比例。
            qkv_bias (bool): enable bias for qkv if True
#是否为 QKV(查询、键、值)矩阵添加偏置
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
#如果设定,将会覆盖默认的 qk 缩放因子
            representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
#如果设置了这个值,将会有一个表示层(pre-logits)
            distilled (bool): model includes a distillation token and head as in DeiT models
#vit中可以不管这个参数
            drop_ratio (float): dropout rate
# Dropout 的比例
            attn_drop_ratio (float): attention dropout rate
#注意力层的 Dropout 比例
            drop_path_ratio (float): stochastic depth rate
#droppath比例
            embed_layer (nn.Module): patch embedding layer
#用于嵌入图像的层,默认使用 PatchEmbed
            norm_layer: (nn.Module): normalization layer
#正则化层,通常是 LayerNorm
        """
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
# 与 embed_dim 保持一致,表示嵌入的维度。
        self.num_tokens = 2 if distilled else 1
#不管distilled所以distilled=1
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
#使用 LayerNorm作为默认的规范化层
        act_layer = act_layer or nn.GELU
#默认使用 GELU 作为激活函数

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
#Embedding层结构
        num_patches = self.patch_embed.num_patches
#patches的个数

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#这是用于分类的分类标记(Class Token),它是一个可学习的参数,初始值为零
        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
#不管distilled所以self.dist_token=None
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
#位置编码(Position Embedding)
        self.pos_drop = nn.Dropout(p=drop_ratio)
#位置编码后的 Dropout 操作

        dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
#用于控制每个 Block 的 DropPath 比例
        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                  drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],
                  norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)
        ])
#使用 Block 类构建了Transformer的主体部分,包括注意力和MLP层,并使用残差连接和 DropPath 
        self.norm = norm_layer(embed_dim)
#最后的归一化层,用于 Transformer 输出的处理

        # Representation layer
        if representation_size and not distilled:
#设置了 representation_size则会增加一个表示层 pre_logits,not distilled=true
            self.has_logits = True
            self.num_features = representation_size
            self.pre_logits = nn.Sequential(OrderedDict([
                ("fc", nn.Linear(embed_dim, representation_size)),
                ("act", nn.Tanh())
            ]))
#pre_logits层结构一个全连接和tanh激活函数
        else:
            self.has_logits = False
            self.pre_logits = nn.Identity()

        # Classifier head(s)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
#distilled为none不用管

        # Weight init
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.dist_token is not None:
            nn.init.trunc_normal_(self.dist_token, std=0.02)

        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(_init_vit_weights)
#权重初始化

    def forward_features(self, x):
        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层
        # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接

        x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropout
        x = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理
        x = self.norm(x)
#进行归一化
#vit中self.dist_token is None所以模型只有分类标记 (class token)。
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
        else:
            return x[:, 0], x[:, 1]

    def forward(self, x):
        x = self.forward_features(x)
#首先获取 Transformer 的特征输出
        if self.head_dist is not None:
            x, x_dist = self.head(x[0]), self.head_dist(x[1])
            if self.training and not torch.jit.is_scripting():
                # during inference, return the average of both classifier predictions
                return x, x_dist
            else:
                return (x + x_dist) / 2
        else:
#self.head_dist为none只看head层就是最后的全连接层输出为num_classes
            x = self.head(x)
        return x

 操作:

代码:

        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层

操作:

代码:

    # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接

操作:

代码:

x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropout

操作:

代码:

        x = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理
        x = self.norm(x)
#进行归一化

操作:

代码:

#vit中self.dist_token is None所以模型只有分类标记 (class token)。
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
        else:
            return x[:, 0], x[:, 1]

操作:

代码:

#self.head_dist为none只看head层就是最后的全连接层输出为num_classes
            x = self.head(x)

分类标记 (Class Token):

是一种特殊的 输入 token,在 Transformer 模型中被用来聚合全局特征。

它在模型中起到了类似于 CNN 中全局池化 (Global Pooling) 的作用,负责从所有 patch 的信息中提取一个全局表示。

这个 token 的输出向量被用作分类任务的特征输入,之后会被送入分类头 (classifier head) 进行最终的类别预测。

embedding层:

Vision Transformer(vit)的Embedding层结构-CSDN博客

Multi-Head Self-Attention:

Vision Transformer(vit)的Multi-Head Self-Attention(多头注意力机制)结构-CSDN博客

MLP模块:

Vision Transformer(vit)的MLP模块-CSDN博客

Encoder block:

Vision Transformer(vit)的Encoder层结构-CSDN博客

详解:Vision Transformer详解-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2252408.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mongodb配置ssl连接

mongodb5.0.9 centos7.6x86 1、正常启动mongod -f mongodb.conf 2、生成所需要的ssl证书 服务端ssl配置: 2.1生成ca.pem证书 #-x509: 用于生成自签证书,如果不是自签证书则不需要此项 #-days: 证书的有效期限&…

Linux 中的 ls 命令:从使用到源码解析

ls 命令是 Linux 系统中最常用和最基本的命令之一。下面将深入探讨 ls 命令的使用方法、工作原理、源码解析以及实际应用场景。 1. ls 命令的使用** ls 命令用于列出目录内容,显示文件和目录的详细信息。 1.1 基本用法 ls [选项] [文件或目录]例如: …

Python 【图像分类】之 PyTorch 进行猫狗分类功能的实现(Swanlab训练可视化/ Gradio 实现猫狗分类 Demo)

Python 【图像分类】之 PyTorch 进行猫狗分类功能的实现(Swanlab训练可视化/ Gradio 实现猫狗分类 Demo) 目录 Python 【图像分类】之 PyTorch 进行猫狗分类功能的实现(Swanlab训练可视化/ Gradio 实现猫狗分类 Demo) 一、简单介绍 二、PyTorch 三、CNN 1、神经网络 2、卷…

【C语言】结构体(二)

一&#xff0c;结构体的初始化 和其它类型变量一样&#xff0c;对结构体变量可以在定义时指定初始值 #include <stdio.h> #include <stdlib.h> struct books // 结构体类型 {char title[50];char author[50]; //结构体成员char subject[100];int book_id; }…

C++(4个类型转换)

1. C语言中的类型转换 1. 隐式 类型转换&#xff1a; 具有相近的类型才能进行互相转换&#xff0c;如&#xff1a;int,char,double都表示数值。 2. 强制类型转换&#xff1a;能隐式类型转换就能强制类型转换&#xff0c;隐式类型之间的转换类型强相关&#xff0c;强制类型转换…

Windows下从命令行(Powershell/CMD)发送内容到系统通知中心

Windows下从命令行&#xff08;Powershell/CMD&#xff09;发送内容到系统通知中心 01 前言 在平时写脚本的时候&#xff0c;将日志等信息直接输出到控制台固然是最直接的&#xff0c;而如果是一些后台执行的任务&#xff0c;不需要时刻关注运行细节但是又想知道一些大致的情…

四、初识C语言(4)

一、作业&#xff1a;static修饰局部变量 #define _CRT_SECURE_NO_WARNINGS 1 #include <stdio.h> #include <string.h> //作业&#xff1a;static修饰局部变量 int sum (int a) {int c 0;static int b 3;c 1;b 2;return (abc); } int main() {int i 0;int a …

基于深度学习的甲状腺结节影像自动化诊断系统(PyQt5界面+数据集+训练代码)

随着医学影像技术的发展&#xff0c;计算机辅助诊断在甲状腺结节的早期筛查中发挥着重要作用。甲状腺结节的良恶性鉴别对临床治疗具有重要意义&#xff0c;但传统的诊断方法依赖于医生的经验和影像学特征&#xff0c;存在一定的主观性和局限性。为了解决这一问题&#xff0c;本…

VLC 播放的音视频数据处理流水线搭建

VLC 播放的音视频数据处理流水线搭建 音视频流播放处理循环音频输出处理流水线VLC 用 input_thread_t 对象直接或间接管理音视频播放有关的各种资源,包括 Access, Demux, Decode, Output, Filter 等,这个类型定义 (位于 vlc-3.0.16/include/vlc_input.h) 如下: s…

Android 12系统源码_RRO机制(一)Runtime Resource Overlay机制实践

前言 Android的RRO&#xff08;Runtime Resource Overlay&#xff09;机制允许开发者在运行时替换或重写系统资源&#xff0c;例如布局、图标、字符串等。这个机制的目标是为了支持设备定制和主题化&#xff0c;特别是在不修改系统源代码的情况下。RRO通过在系统的资源上叠加一…

Tomcat新手成长之路:安装部署优化全解析(下)

接上篇《Tomcat新手成长之路&#xff1a;安装部署优化全解析&#xff08;上&#xff09;》: link 文章目录 7.应用部署7.1.上下文7.2.启动时进行部署7.3.动态应用部署 8.Tomcat 类加载机制8.1.简介8.2.类加载器定义8.3.XML解析器和 Java 9.JMS监控9.1.简介9.2.启用 JMX 远程监…

动态代理如何加强安全性

在当今这个信息爆炸、网络无孔不入的时代&#xff0c;我们的每一次点击、每一次浏览都可能留下痕迹&#xff0c;成为潜在的安全隐患。如何在享受网络便利的同时&#xff0c;有效保护自己的隐私和信息安全&#xff0c;成为了每位网络使用者必须面对的重要课题。动态代理服务器&a…

python---面向对象-python中的实践(2)

如何定义一个类&#xff1f; class 类名:pass怎样通过类&#xff0c;创建出一个对象&#xff1f; 根据类创建对象one Money() 执行流程1. 类的定义2. 根据类&#xff0c;创建出一个对象3. 将对象的唯一标识返回class Money:passprint(Money.__name__) xxx Money print(xxx.…

数据结构-散列函数的构造方法

一.数字关键词 关键词存储应该尽可能的离散 直接定址法:利用线性函数,例如上面的例子,h(key)key-1990,key1990&#xff0c;这个就被存放在0的位置 数字分析法:关键字可能有很到位组成,每一位变化可能都不一样&#xff0c;有的位是不变的,就是说不同的对象这一位都是一样的,有的…

单点登录解决方案 CAS(Central Authentication Service)详解

目录 CAS 的工作原理 票据&#xff08;Ticket&#xff09;详解 CAS 的优势 CAS 的应用场景 小结 参考资料 Central Authentication Service&#xff08;中央认证服务&#xff0c;简称 CAS&#xff09;是一个开源的企业级单点登录&#xff08;Single Sign-On, SSO&#xf…

输入json 达到预览效果

下载 npm i vue-json-pretty2.4.0 <template><div class"newBranchesDialog"><t-base-dialogv-if"addDialogShow"title"Json数据配置"closeDialog"closeDialog":dialogVisible"addDialogShow":center"…

U盘文件夹变打不开的文件:深度解析、恢复策略与预防之道

一、U盘文件夹变打不开的文件现象解析 在日常使用U盘的过程中&#xff0c;我们时常会遇到这样的困扰&#xff1a;原本存储有序、可以轻松访问的文件夹&#xff0c;突然之间变成了无法打开的文件。这些文件通常以未知图标或乱码形式显示&#xff0c;双击或右键尝试打开时&#…

2025年软考-网络工程师新旧教程及考纲变化对比!

2025网工教材改版基本确认&#xff01;网络工程师一直是软考中级的热门科目。最近&#xff0c;官方发布了官方第六版的网工教材&#xff0c;本次出版在原有第五版的基础上做了大量的删减&#xff0c;并新增了部分新内容。明年的软考考纲大概率会根据这次的新版教材进行修改&…

视觉处理基础1

目录 一、CNN 1. 概述 1.1 与传统网络的区别 1.2 全连接的局限性 1.3 卷积思想 1.4 卷积的概念 1.4.1 概念 1.4.2 局部连接 1.4.3 权重共享 2. 卷积层 2.1 卷积核 2.2 卷积计算 2.3 边缘填充 2.4 步长Stride 2.5 多通道卷积计算 2.7 特征图大小计算方法 2…

大疆T100大载重吊运植保无人机技术详解

大疆T100作为一款大载重吊运植保无人机&#xff0c;融合了全新的AI和AR功能&#xff0c;旨在进一步提升安全性并满足喷洒、播撒、吊运等多种作业场景的需求。以下是对其技术的详细解析&#xff1a; 一、总体性能 最大起飞重量&#xff1a;149.9公斤 喷洒容量&#xff1a;75升…