【多模态】ViT模型技术学习

news2025/1/11 8:06:29

前言

最近多模态模型特别火,模型也越来越小,性能优异的MiniCPM-2.6只有8B大小,它采用的图片编码器是SigLipViT模型,一起从头学习ViT和Transformer!本文记录一下学习过程,所以是自上而下的写,从ViT拆到Transformer。

用Transformer来做图像分类?!

  1. Vision Transformer (ViT)出自ICLR 2021的论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》,使用之前做文本任务的Transformer来做图片分类任务
  2. ViT模型的构成主要包含图像切片、图像映射、Transformer模块和分类头
    在这里插入图片描述

ViT整体工作流程

假设输入图片尺寸image_size是 224 × 224 224 \times 224 224×224,子图大小(patch_size)为16,图片编码维度(hidden_dim)为768,
当1张224*224的图输入ViT后(批大小batch_size=1)会经历:

  1. 图片切片 – 图片首先被分割为 16 × 16 16 \times 16 16×16大小的子图,总共 ( 224 / / 16 ) × ( 224 / / 16 ) = 14 × 14 = 196 (224//16) \times (224//16)=14 \times 14=196 (224//16)×(224//16)=14×14=196
  2. 图片映射 – 子图被分别送到Linear Projection这个模块进行映射,得到大小为[1,768,196]的向量
  3. 变换一下维度便于输入Transformer,所有子图拼成的图片隐向量维度为[1,196,768];
  4. 分类token – 在输入Transformer前,为了与bert架构统一,也使用一个类似[CLS]的标记,在图片隐向量前面插入一个class_token,最终输入Transformer的向量大小为[1,197,768]
  5. 位置编码 – 随机初始化的pos_embedding大小也是[1,197,768],加到图片向量上
  6. 输入Transformer,编码器输入输出维度一致,输出的维度是[1,197,768]
  7. 输出分类结果 – 取class_token对应的输出向量输入分类头

*需要注意的是:分类任务不一定要取class_token对应的向量,也可在最后一个Transformer块的输出接一个global average pooling层再接MLP分类层,特定学习率参数情况下效果类似;ViT是为了和bert架构统一所以加入了class_token

ViT源代码拆解

1. VisionTransformer类的forward()

在torchvision代码中可以找到ViT的torch官方实现

def forward(self, x: torch.Tensor):
    # 图片切片、图片编码并把图片向量调整为transformer能接受的维度
    x = self._process_input(x)
    n = x.shape[0] # n是batch size
    # 给这个batch的n个图片向量最前面,都加入一个class_token,类似[CLS]
    batch_class_token = self.class_token.expand(n, -1, -1)
    x = torch.cat([batch_class_token, x], dim=1)
                   
    # 图片向量用Transformer的block进行处理
    x = self.encoder(x)
    # 取class_token对应的向量,x是[1,197,768],x[:,0]表示x[:,0,:]
    x = x[:, 0]
    # 输入分类头进行分类任务
    x = self.heads(x)
    return x

2.图片切片与编码——VisionTransformer类的_process_input()

  • ViT框架图里面的Linear Projection模块实际上是用一个nn.Con2d隐式实现的
  • nn.Con2d起到的作用和单独把一个个子图放到Linear层编码是一样的
  • 所以实际上图片编码后的维度为[1,768,14,14]–> [1,196,768]
    ........
    self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
def _process_input(self, x: torch.Tensor) -> torch.Tensor:    n, c, h, w = x.shape 
    # 图片维度为(n, c, h, w),n是batchsize,c是图像通道数一般为3,h/w是图像高宽
    p = self.patch_size  # 图片切片大小,例如为16,子图大小为patch_size*patch_size 
    n_h = h // p         # 图片切片,高度维度切的片数
    n_w = w // p         # 图片切片,高度维度切的片数
    # (n, c, h, w) -> (n, hidden_dim, n_h, n_w)
    x = self.conv_proj(x)
    # (n, hidden_dim, n_h, n_w) -> (n, hidden_dim, (n_h * n_w)),进行展平操作
    x = x.reshape(n, self.hidden_dim, n_h * n_w)    
    # Transformer期望的输入维度是(N,S,E),N是batchsize,S是序列长度,E是文本编码隐向量维度
    # 所以把维度变换一下,permute(0,2,1)表示把第0维放最前面,第2维放中间,第1维放后面
    x = x.permute(0, 2, 1) # 得到(n, (n_h * n_w), hidden_dim), n_h * n_w是子图数,类似文本序列长度
    return x

其中,对于卷积操作而言

self.conv_proj = nn.Conv2d(in_channels=3, out_channels=hidden_dim, kernel_size=patch_size, stride=patch_size)
  • 默认hidden_dim=768,patch_size=16,卷积核个数也就是输出的特征图通道数为768
  • 卷积核大小为16,步长也是16,可以保证卷积扫描的时候每次正好对一个子图做运算,子图互相之间不重叠,一个卷积核卷积运算的次数为(224//16)* (224//16)正好是14*14,每个运算值对应一个子图
  • 有768个卷积核,所以输出的大小为(n,768,14,14),对RGB图像而言卷积核也是个[3,16,16]的矩阵
  • RGB图像的卷积如下,RGB分别计算后相加,这只是1/768个卷积核的计算结果,所有结果拼接为矩阵
    在这里插入图片描述

3. Transformer的Encoder

3.1 Encoder的forward()

ViT使用的是Encoder for sequence to sequence translation

    ......
    super().__init__()
    # Note that batch_size is on the first dim because
    # we have batch_first=True in nn.MultiAttention() by default
    self.pos_embedding = nn.Parameter(torch.empty(1, seq_length, hidden_dim).normal_(std=0.02))  # from BERT
    self.dropout = nn.Dropout(dropout)
    layers: OrderedDict[str, nn.Module] = OrderedDict()
    for i in range(num_layers):
        layers[f"encoder_layer_{i}"] = EncoderBlock(
        	num_heads, 
        	hidden_dim, 
        	mlp_dim, 
        	dropout,
        	attention_dropout,  
        	norm_layer,)
    self.layers = nn.Sequential(layers)
    self.ln = norm_layer(hidden_dim)
def forward(self, input: torch.Tensor):
    torch._assert(input.dim() == 3, f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}")
    input = input + self.pos_embedding
    return self.ln(self.layers(self.dropout(input)))

3.2 Transformer的Encoder Block

主要包含self_attention结构,在self-attention中每个patch和patch之间计算相似度,学习patch间的关系

        ......
        super().__init__()
        self.num_heads = num_heads

        # Attention block
        self.ln_1 = norm_layer(hidden_dim)  # 层归一化,是对单个样本在其特征维度(最后一个维度)上进行的归一化
        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, dropout=attention_dropout, batch_first=True
        )
        self.dropout = nn.Dropout(dropout)

        # MLP block
        self.ln_2 = norm_layer(hidden_dim)
        self.mlp = MLPBlock(hidden_dim, mlp_dim, dropout)

    def forward(self, input: torch.Tensor):
        torch._assert(
            input.dim() == 3,
            f"Expected (batch_size, seq_length, hidden_dim) got {input.shape}"
        )
        x = self.ln_1(input)
        x, _ = self.self_attention(x, x, x, need_weights=False)
        x = self.dropout(x)
        x = x + input  # 残差连接

        y = self.ln_2(x)
        y = self.mlp(y)
        return x + y

3.3 Transformer的Encoder Block的MultiHeadAttention模块

关于代码:

  • 多头注意力模块nn.MultiHeadAttention,forward方法在torch.nn.functional中,在这里之前ViT代码中已经统一把向量变换为(L,N,E)的形状

  • q(L,N,E) k(S,N,E) v(S,N,E) output(L,N,E)

  • L is the target
    length, S is the sequence length, H is the number of attention heads,
    N is the batch size, and E is the embedding dimension

  • nn.MultiHeadAttention的attention有一版注释的代码也在源文件中,搜索” multihead attention”往下翻

  • *torchtext.nn.modules.multiheadattention的多头注意力模块代码更简洁一些

下面是torchtext.nn.modules.multiheadattention的多头注意力模块代码:

# 假设这是在一个类的方法中定义的
    ......
    if self.batch_first:  # 如果是batch_first的先从(N, L, E)变为(L, N, E)形式
        query, key, value = query.transpose(-3, -2), key.transpose(-3, -2), value.transpose(-3, -2)

    # 获取维度信息
    tgt_len, src_len, bsz, embed_dim = (
        query.size(-3),
        key.size(-3),
        query.size(-2),
        query.size(-1)
    )

    # 分别乘qkv矩阵得到qkv
    q, k, v = self.in_proj_container(query, key, value)

    # 确保query的embed_dim可以被head数整除
    assert q.size(-1) % self.nhead == 0, "query's embed_dim must be divisible by the number of heads"
    head_dim = q.size(-1) // self.nhead
    q = q.reshape(tgt_len, bsz * self.nhead, head_dim)

    # 确保key的embed_dim可以被head数整除
    assert k.size(-1) % self.nhead == 0, "key's embed_dim must be divisible by the number of heads"
    head_dim = k.size(-1) // self.nhead
    k = k.reshape(src_len, bsz * self.nhead, head_dim)

    # 确保value的embed_dim可以被head数整除
    assert v.size(-1) % self.nhead == 0, "value's embed_dim must be divisible by the number of heads"
    head_dim = v.size(-1) // self.nhead
    v = v.reshape(src_len, bsz * self.nhead, head_dim)

    # 计算注意力输出和权重
    attn_output, attn_output_weights = self.attention_layer(
        q, k, v,
        attn_mask=attn_mask,
        bias_k=bias_k,
        bias_v=bias_v
    )

    # 将输出重新调整为原始形状
    attn_output = attn_output.reshape(tgt_len, bsz, embed_dim)
    attn_output = self.out_proj(attn_output)

    # 如果是batch_first从(L, N, E)变回去(N, L, E),编码器输入输出形状保持一致
    if self.batch_first:
        attn_output = attn_output.transpose(-3, -2)

    return attn_output, attn_output_weights

3.4 Transformer的Encoder Block的attention layer

  • torchtext.nn.modules.multiheadattention的self.attention_layer是ScaledDotProduct
  • query: (L, N * H, E / H) , key: (S, N * H, E / H),在self-attantion中序列长度L=序列长度S

attention计算方法为 S o f t m a x ( Q ⋅ K T d ) ⋅ V Softmax(\frac{Q \cdot K^T }{\sqrt{d}}) \cdot V Softmax(d QKT)V

  1. 计算注意力权重 : matmul(query,key)
  2. 权重归一化:对 attn_output_weights 进行 softmax 归一化时,希望确保每个查询位置(L)对所有键位置(S)的注意力权重之和为 1。因此,我们需要沿着最后一个维度 S 进行 softmax 归一化,即 dim=-1
  3. 加权求和:matmul(att_output_weights, value)
# Scale query
# 变成(N*H,L,E/H)
query, key, value = query.transpose(-2, -3), key.transpose(-2, -3), value.transpose(-2, -3)
query = query * (float(head_dim) ** -0.5)
# Dot product of q, k
#(N*H,L,E/H) ×  (N*H, E/H, S),matmul计算最后2维,也就是[N*H,:,:]×[N*H,:,:],得到[N*H,L,S]
attn_output_weights = torch.matmul(query, key.transpose(-2, -1))
attn_output_weights = torch.nn.functional.softmax(attn_output_weights, dim=-1) # (N*H, L, S)
attn_output_weights = torch.nn.functional.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_output_weights, value) # (N*H, L, E/H)

self-attention的直观解释-b站视频

Attention的解释有一个b站上搬运的视频非常直观,attention可以关注到全局上信息的关联,卷积只能关注到局部的信息
在这里插入图片描述

  • 假设图片有4个像素,RGB三个通道,表示起来x是[4,3]的矩阵

  • 如果隐空间维度hidden_dim=2,输入x乘以[3,2]的Wq/Wk/Wv矩阵可以得到[4,2]的Q/K/V向量
    在这里插入图片描述
    在这里插入图片描述

  • 计算相似性度量 𝑄 ∙ 𝐾 𝑇 𝑄∙𝐾^𝑇 QKT
    在这里插入图片描述

  • 注意到每次是向量的点积运算,例如Q的第4行表示q4,K的第4列表示k4,计算的实际上是向量相似度,得到的 𝑄 ∙ 𝐾 𝑇 𝑄∙𝐾^𝑇 QKT是每个像素间的相似度矩阵

  • 在self-attention的计算中涉及到除以√𝑑放缩,否则维度越大雅可比矩阵接近零矩阵梯度消失,详细原理可以在文末知乎专栏中找到

  • 计算softmax进行归一化,因为每一行是一个像素和其它像素的相似度,所以预期是每行概率值相加为1,对列做softmax:
    在这里插入图片描述

  • 最后乘以V矩阵完成注意力计算:
    在这里插入图片描述

  • 左边的相似度矩阵可以理解为权重,乘以V矩阵类似加权平均

  • 例如0.23表示第1个像素关注第1个像素的程度,0.33表示第1个像素关注第2个像素的程度

参考链接

  1. b站attention视频讲解:https://www.bilibili.com/video/BV1Ke411X7t7
  2. 知乎解释为什么attention需要除以√𝑑放缩:https://zhuanlan.zhihu.com/p/503321685

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2215785.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

windows上svn设置忽略

目的 就是在windows环境下设置svn的需要忽略的文件,这还是挺实用的一个功能,不然,很多编译的中间文件都上传到svn上了,这样就不好了;ignore设置,也需要注意一下。 过程 svn服务端式忽略,这是…

【前端】制作一个简单的网页(2)

单标签组成的元素 这类标签不需要内容产生效果&#xff0c;通常表示对网页的某种行为&#xff0c;它们不用标记任何内容&#xff0c;开始即是结束。 比如&#xff0c;<hr>标签的作用是在网页中添加一条分割线&#xff0c;它仅包含开始标签&#xff0c;是一个单标签元素。…

【Linux】解读信号的本质&相关函数及指令的介绍

前言 大家好吖&#xff0c;欢迎来到 YY 滴Linux系列 &#xff0c;热烈欢迎&#xff01; 本章主要内容面向接触过C的老铁 主要内容含&#xff1a; 欢迎订阅 YY滴C专栏&#xff01;更多干货持续更新&#xff01;以下是传送门&#xff01; YY的《C》专栏YY的《C11》专栏YY的《Lin…

k8s的部署和安装

k8s的部署和安装 一、Kubernets简介及部署方法 1.1 应用部署方式演变 在部署应用程序的方式上&#xff0c;主要经历了三个阶段&#xff1a; 传统部署&#xff1a;互联网早期&#xff0c;会直接将应用程序部署在物理机上 优点&#xff1a;简单&#xff0c;不需要其它技术的参…

【去哪儿-注册安全分析报告-缺少轨迹的滑动条】

前言 由于网站注册入口容易被黑客攻击&#xff0c;存在如下安全问题&#xff1a; 1. 暴力破解密码&#xff0c;造成用户信息泄露 2. 短信盗刷的安全问题&#xff0c;影响业务及导致用户投诉 3. 带来经济损失&#xff0c;尤其是后付费客户&#xff0c;风险巨大&#xff0c;造…

查看SQL执行计划 explain

查看SQL执行计划 explain explain使用方式 alter session set current_schematest; explain plan for sql语句; --并不会实际执行&#xff0c;因此生成的执行计划也是预估的 select * from table(dbms_xplan.display); explain使用场景 1.内存中没有谓词信息了&#xff0…

网站仿制的五大要素

网站仿制的五大要素 在数字化快速发展的今天&#xff0c;仿制一个网站不仅是技术上的挑战&#xff0c;更是对创意与灵感的考验。无论是为了学习设计理念&#xff0c;还是为企业进行市场竞争&#xff0c;以下五大要素是网站仿制时必不可少的。 **1. 目标分析** 在仿制网站之前…

1. ESP32简介

ESP32是什么&#xff1a;它是乐鑫科技研发和设计的一种无线系统级芯片优点&#xff1a; 强大的处理能力无线通信功能广泛的外设接口低功耗特性 为什么选择ESP-IDF开发&#xff1a; 基于C/C开发官方主推实际项目需求 常见的ESP32型号&#xff1a;

Docker 环境下 GPU 监控实战:使用 Prometheus 实现 DCGM Exporter 部署与 GPU 性能监控

Docker 环境下 GPU 监控实战&#xff1a;使用 Prometheus 实现 DCGM Exporter 部署与 GPU 性能监控 文章目录 Docker 环境下 GPU 监控实战&#xff1a;使用 Prometheus 实现 DCGM Exporter 部署与 GPU 性能监控一 查看当前 GPU 信息二 dcgm-exporter 部署1&#xff09;Docker r…

电脑端微信图片文件视频的缓存目录

C:\Users\aaa30\Documents\WeChat Files\wxid_n6b8j77iqho412\FileStorage\Video\2024-10

冠层体散射反射对称性协方差矩阵的模型,在多种分解中都适用

把这个矩阵对应到观测协方差矩阵中&#xff0c;看占比多少&#xff0c;然年后得到体散射功率&#xff0c;时前面的系数乘对角线的和

51单片机快速入门之左移右移流水灯 2024年10/15

51单片机快速入门之左移右移流水灯 左移操作: <<1 每次往左移动一位假设一个八位数为0000 1111 当这个数左移一次之后 0 0001 1110当这个数左移两次之后 00 0011 1100 注意观察 橙色 数字 Python代码如下: 0b表示这是一个二进制 注意这里前置0被省略了 …

数据结构——树和森林

目录 树的存储结构 1、双亲表示法 2、孩子链表 3、孩子兄弟表示法 树与二叉树的转换 将树转换为二叉树 将二叉树转换为树 森林与二叉树的转化 森林转换成二叉树 二叉树转换为森林 树和森林的遍历 1、 树的遍历&#xff08;三种方式&#xff09; 2、森林的遍历 树的存…

DVWA之File Inclusion(文件包含)

DVWA之File Inclusion&#xff08;文件包含&#xff09; 1、定义&#xff1a;服务器通过php的特性&#xff08;函数的特性&#xff09;去包含任意文件时&#xff0c;由于对包含的文件来源没有过滤或过滤不严&#xff0c;从而可去包含一个恶意的文件。文件包含包括&#xff1a;…

【从零开始的LeetCode-算法】3200. 三角形的最大高度

给你两个整数 red 和 blue&#xff0c;分别表示红色球和蓝色球的数量。你需要使用这些球来组成一个三角形&#xff0c;满足第 1 行有 1 个球&#xff0c;第 2 行有 2 个球&#xff0c;第 3 行有 3 个球&#xff0c;依此类推。 每一行的球必须是 相同 颜色&#xff0c;且相邻行…

AdaTAD(CVPR 2024)视频动作检测方法详解

前言 论文&#xff1a;End-to-End Temporal Action Detection with 1B Parameters Across 1000 Frames 代码&#xff1a;AdaTAD 从论文标题可以看出&#xff0c;AdaTAD 可以在 1B 参数且输入视频在 1000 帧的情况下实现端到端的训练&#xff0c;核心创新点是引入 Temporal-Inf…

STM32传感器模块编程实践(六) 1.8寸液晶屏TFT LCD彩屏简介及驱动源码

文章目录 一.概要二.TFT彩屏主要参数三.TFT彩屏参考原理图四.TFT彩屏模块接线说明五.模块SPI通讯协议介绍六.TFT彩屏模块显示1.显示英文字符串2.显示数字3.显示中文 七.TFT彩屏实现图片显示八.STM32单片机1.8寸 TFT LCD显示实验1.硬件准备2.软件工程3.软件主要代码4.实验效果 九…

【C++】——list 容器的解析与极致实现

人的一切痛苦&#xff0c;本质上都是对自己的无能的愤怒。 —— 王小波 目录 1、list 介绍 2、list的使用 2.1 list 的构造 2.2 iterator 的使用 2.3 list 的修改 2.4一些特殊接口 2.5 迭代器失效问题 3、实现list 3.1底层结构 结点类 list类 迭代器类 3.2功能接…

VLOG视频制作解决方案,开发者可自行定制包装模板

无论是旅行见闻、美食探店&#xff0c;还是日常琐事、创意挑战&#xff0c;每一个镜头背后都蕴含着创作者无限的热情和创意。然而&#xff0c;面对纷繁复杂的视频编辑工具&#xff0c;美摄科技凭借其前沿的视频制作技术和创新的解决方案&#xff0c;为每一位视频创作者提供了开…

服务端负载均衡和客户端负载

负载均衡分为服务端负载均衡和客户端负载均衡&#xff0c;图解&#xff1a; 客户端的负载均衡还需要从注册中心获取集群部署的服务地址&#xff0c;其中客户的负载均衡器定时读取注册中心的IP和端口&#xff0c;然后缓存起来&#xff0c;这样以后可以先判断缓存IP和端口是否可用…