Vision Transformer(ViT)模型原理及PyTorch逐行实现

news2024/9/20 16:37:29

Vision Transformer(ViT)模型原理及PyTorch逐行实现

一、TRM模型结构

1.Encoder

  1. Position Embedding 注入位置信息
  2. Multi-head Self-attention 对各个位置的embedding融合(空间融合)
  3. LayerNorm & Residual
  4. Feedforward Neural Network 对每个位置上单独仿射变换(通道融合)
    • Linear1(large)
    • Linear2(d_model)
  5. LayerNorm & Residual

2.Decoder

  1. Position Embedding
  2. Casual Multi-head Self-attention
  3. LayerNorm & Residual
  4. Memory-base Multi-head Cross-attention
  5. LayerNorm & Residual
  6. Feedforward Neural Network
    • Linear1(large)
    • Linear2(d_model)
  7. LayerNorm & Residual

二、TRM使用类型

  1. Encoder only 【 ViT 所使用的】
    • BERT、分类任务、非流式任务
  2. Decoder only
    • GPT系列、语言建模、自回归生成任务、流式任务
  3. Encoder-Decoder
    • 机器翻译、语音识别

三、TRM特点

  1. 无先验假设(例如:局部关联性、有序建模性)
  2. 核心计算在于自注意力机制,平方复杂度
  3. 数据量的要求与归纳偏置【人类通过归纳法得到的经验,把这些经验带入到模型中,很多事物的共性】的引入成反比

四、Vision Transformer(ViT)

  1. DNN perspective 图像的信息量主要还是聚集在一块区域上
    • image2patch 将图片切分成很多个块
    • patch2embedding 将每个块转换为向量
  2. CNN perspective 从卷积的角度得到向量
    • 2D convolution over image 二维卷积
    • flatten the output feature map 把输出的卷积图拉直
  3. class token embedding 占位符
  4. position embedding
    • interpolation when inference
  5. Transformer Encoder 只使用的Encoder
  6. classification head 最后分类

五、ViT论文讲解

image-20240908092056541

​ 首先将一副图片分为很多个块,每个块的大小都是不会变化的,图片即使大一点,只是序列更长一点。先左到右,再上到下,把图片拉直成一个序列的形状。把每个块中的像素点进行归一化,范围变为0到1之间,再把块里面的所有值通过一个线性变换映射到模型的维度,得到patchembedding,得到以后,我们为了做分类任务,还需要在序列的开头加上一个可训练的embedding,这个是随机初始化的。这样就构造出了一个n+1长度的序列,然后我们再加入position embedding,加上后的这个序列的表征就可以送入到TRM的encoder当中,最后取出结果中的我们加入的可训练的embedding位置上的值(输出状态),经过一个MLP,得到各个类别的概率分布,再通过一个交叉熵函数算出分类的loss,这样就完成了一个ViT模型的搭建。

六、代码实现

1.convert image to embedding vector sequence

1.通过DNN实现

import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_naive(image,patch_size,weight):
    # image shape: bs*channel*h*w
    patch = F.unfold(image,kernel_size=patch_size,stride=patch_size).transpose(-1,-2)
    patch_embedding = patch @ weight
    return patch_embedding

# test code for image2emb
bs,ic,image_h,image_w=1,3,8,8
patch_size=4 # 每个块的大小为4*4(自定义)
model_dim=8 #将每个块映射成长度为8的向量(自定义)
patch_depth=patch_size*patch_size*ic
image=torch.randn(bs,ic,image_h,image_w) #初始化
weight=torch.randn(patch_depth,model_dim)#初始化

patch_embedding_navie=image2emb_navie(image,patch_size,weight)
print(patch_embedding_naive.shape) # [1,4,8],分成四块了,每块对应一个长度为8的向量 

2.通过CNN实现

import torch
import torch.nn as nn
import torch.nn.functional as F

def image2emb_conv(image,kernel,stride):
    conv_output=F.conv2d(image,kernel,stride=stride) # bs*oc*oh*ow
    bs,oc,oh,ow=conv_output.shape
    patch_embedding=conv_output.reshape((bs,oc,oh*ow)).transpose(-1,-2)
    return patch_embedding

# test code for image2emb
bs,ic,image_h,image_w=1,3,8,8
patch_size=4
model_dim=8
patch_depth=patch_size*patch_size*ic
image=torch.randn(bs,ic,image_h,image_w)
weight=torch.randn(patch_depth,model_dim) #model_dim是输出通道数目,patch_depth是卷积核的面积乘以输入通道数

kernel=weight.transpose(0,1).reshape((-1,ic,patch_size,patch_size)) # oc*ic*kh*kw
patch_embedding_conv=image2emb_conv(image,kernel,patch_size) # 二维卷积的方法得到embedding

2.prepend CLS token embedding

cls_token_embedding = torch.randn(1,model_dim,requires_grad=True)
token_embedding = torch.cat([[bs,cls_token_embedding],patch_embedding_conv],dim=1)

​ 提问:本身cls_token_embedding没有和任何样本矩阵有乘法联系,最后训练出来的也是一张确定的表,在做inference的时候,完全是一个常数的作用。送入transformer后,又与其他矩阵做了MHA,没搞懂用意何在啊?

​ 答:有联系啊,就是与其他时刻的sample做MHSA。这个token其实是取代了avg pool的作用,也就是说,你可以用avg pool得到分类的logits,也可以用采用cls token来得到分类的logits

注意:cls_token_embedding作为batch_size中每一个序列的开始,应该对于每一个序列的开始都torch.cat同样的一个cls_token_embedding,然后都是对这同一个cls_token_embedding进行训练,所以这里的cls token embedding应该是二维的,1*model_dim,与batchsize无关。

3.add position embedding

max_num_token=16 #自定义
position_embedding_table = torch.randn(max_num_token,model_dim,requires_grad=True)
seq_len=token_embedding.shape[1] # 刚刚的1+4
position_embedding=torch.tile(position_embedding_table[:seq_len],[token_embedding.shape[0],1,1]) # 5,bs,1,1
token_embedding += position_embedding

4.pass embedding to Transformer Encoder

encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim,nhead=8)
transformer_encoder=nn.TransformerEncoder(encoder_layer,num_layers=6)
encoder_output=transformer_encoder(token_embedding)

5.do classification

cls_token_output=encoder_output[:,0,:] #拿到TRM的输出值
num_classes=10 # 自定义的类别数目
label=torch.randint(10,(bs,)) # 自定义的生成的label
linear_layer = nn.Linear(model_dim,num_classes) 
logits = linear_layer(cls_token_output)
loss_fn=nn.CrossEntropyLoss()
loss=loss_fn(logits,label)
print(loss)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2115129.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

LabVIEW FIFO详解

在LabVIEW的FPGA开发中,FIFO(先入先出队列)是常用的数据传输机制。通过配置FIFO的属性,工程师可以在FPGA和主机之间,或不同FPGA VIs之间进行高效的数据传输。根据具体需求,FIFO有多种类型与实现方式&#x…

SpringSecurity原理解析(二):认证流程

1、SpringSecurity认证流程包含哪几个子流程? 1)账号验证 2)密码验证 3)记住我—>Cookie记录 4)登录成功—>页面跳转 2、UsernamePasswordAuthenticationFilter 在SpringSecurity中处理认证逻辑是在UsernamePas…

iOS——线程安全、线程同步与线程通信

线程安全和线程同步 线程安全:如果你的代码所在的进程中有多个线程在同时运行,而这些线程可能会同时运行这段代码。如果每次运行结果和单线程运行的结果是一样的,而且其他的变量的值也和预期的是一样的,就是线程安全的。 若每个…

18055 主对角线上的元素之和

### 思路 1. 输入一个3行4列的整数矩阵。 2. 计算主对角线上的元素之和。 3. 输出主对角线上的元素之和。 ### 伪代码 1. 初始化一个3行4列的矩阵 matrix。 2. 输入矩阵的元素。 3. 初始化一个变量 sum 为0,用于存储主对角线元素之和。 4. 遍历矩阵的行&#xff0c…

【Day08-IO-文件字节流】

File 1. 概述 File对象既可以代表文件、也可以代表文件夹。它封装的对象仅仅是一个路径名,这个路径可以存在,也可以不存在 构造器 说明 public File​(String pathname) 根据文件路径创建文件对象 public File​(String parent, String child) 根据…

vscode中使用go环境配置细节

1、在docker容器中下载了go的sdk 2、在/etc/profile.d/go.sh里填入如下内容: #!/bin/bashexport GOROOT=/home/ud_dev/go export PATH=$GOROOT/bin:$PATH 3、设置go env go env -w GOPROXY=https://goproxy.cn,direct go env -w GO111MODULE=on 4、重启这个容器,使得vscod…

DBAPI如何使用内存缓存

背景 在使用DBAPI创建API的时候,有时候SQL查询比较耗时,如果业务上对数据时效性要求不高,这种耗时的SQL可以使用缓存插件来将数据缓存起来,避免重复查询。 一般来说,可以使用redis memcache等缓存服务来存储缓存数据。…

活动|华院计算宣晓华受邀出席“AI引领新工业革命”大会,探讨全球科技的最新趋势

8月31日,“AI引领新工业革命”大会于上海图书馆圆满落幕。本次大会由TAA校联会和台协科创工委会联合主办,得到上海市台办、上海市台联、康师傅的大力支持。大会邀请了NVIDIA全球副总裁、亚太区企业营销负责人刘念宁,元禾厚望资本创始合伙人潘…

ispunct函数讲解 <ctype.h>头文件函数

目录 1.头文件函数 2.ispunct函数使用 小心&#xff01;VS2022不可直接接触&#xff0c;否则..!没有这个必要&#xff0c;方源一把抓住VS2022&#xff0c;顷刻 炼化&#xff01; 1.头文件函数 以上函数都需要包括头文件<ctype.h> &#xff0c;其中包括 ispunct 函数 #…

esp8266+sg90实现远程开关灯(接线问题)

1需要准备的设备 首先需要的设备 硬件&#xff1a;esp8266开发板和sg90舵机&#xff0c;还有公对母的杜邦线&#xff0c;以及一根usb程序下载线。 软件&#xff1a;Arduino IDE 因为sg90舵机接口是三个连着的&#xff0c;只能用公对母的杜邦线把三条信号线接到esp8266的不同引…

Linux驱动.之字符设备驱动框架,新内核框架,设备树(二)

第一篇比较长&#xff0c;第二篇&#xff0c;继续写&#xff0c;内容有重复 一、字符设备驱动框架 在用户空间中调用open&#xff0c;打开一个字符设备&#xff0c;执行流程如下&#xff1a;最终会执行chrdev中的ops对应的open函数。

【python计算机视觉编程——8.图像内容分类】

python计算机视觉编程——8.图像内容分类 8.图像内容分类8.1 K邻近分类法&#xff08;KNN&#xff09;8.1.1 一个简单的二维示例8.1.2 用稠密SIFT作为图像特征8.1.3 图像分类:手势识别 8.2贝叶斯分类器用PCA降维 8.3 支持向量机8.3.2 再论手势识别 8.4 光学字符识别8.4.2 选取特…

面试官:你是怎么处理vue项目中的错误的?

一、错误类型 任何一个框架&#xff0c;对于错误的处理都是一种必备的能力 在Vue 中&#xff0c;则是定义了一套对应的错误处理规则给到使用者&#xff0c;且在源代码级别&#xff0c;对部分必要的过程做了一定的错误处理。 主要的错误来源包括&#xff1a; 后端接口错误代…

网络原理之TCP协议(万字详解!!!)

目录 前言 TCP协议段格式 TCP协议相关特性 1.确认应答 2.超时重传 3.连接管理&#xff08;三次握手、四次挥手&#xff09; 三次握手&#xff08;建立TCP连接&#xff09; 四次挥手&#xff08;断开连接&#xff09; 4.滑动窗口 5.流量控制 6.拥塞控制 7.延迟应答…

(入门篇)JavaScript 网页设计案例浅析-简单的交互式图片轮播

网页设计已经成为了每个前端开发者的必备技能,而 JavaScript 作为前端三大基础之一,更是为网页赋予了互动性和动态效果。本篇文章将通过一个简单的 JavaScript 案例,带你了解网页设计中的一些常见技巧和技术原理。今天就说一说一个常见的图片轮播效果。相信大家在各类电商网…

使用vscode上传git远程仓库流程(Gitee)

目录 参考附件 git远程仓库上传流程 1&#xff0c;先将文件夹用VScode打开 2&#xff0c;第一次进入要初始化一下仓库 3&#xff0c;通过这个&#xff08;.gitignore&#xff09;可以把一些不重要的文件不显示 注&#xff1a;&#xff08;.gitignore中&#xff09;可屏蔽…

AI辅助编程里的 Atom Group 的概念和使用

背景 在我们实际的开发当中&#xff0c;一个需求往往会涉及到多个文件修改&#xff0c;而需求也往往有相似性。 举个例子&#xff0c;我经常需要在 auto-coder中需要添加命令行参数&#xff0c;通常是这样的&#xff1a; /coding 添加一个新的命令行参数 --chat_model 默认值为…

基于RAG和知识库的智能问答系统设计与实现

开局一张图&#xff0c;其余全靠编。 自己画的图&#xff0c;内容是由Claude根据图优化帮忙写的。 1. 引言 在当今数字化时代&#xff0c;智能问答系统已成为提升用户体验和提高信息获取效率的重要工具。随着自然语言处理技术的不断进步&#xff0c;特别是大型语言模型&#x…

Sonarqube 和 Sonar-scanner的安装和配置

SonarQube 简介 所谓sonarqube 就是代码质量扫描工具。 官网&#xff1a; https://www.sonarsource.com/sonarqube/ 在个人开发学习中用处不大&#xff0c; 我草&#xff0c; 我的代码质量这么高需要这玩意&#xff1f; 但是在公司项目中&#xff0c; 这个可是必须的&#x…

【高校主办,EI稳定检索】2024年人机交互与虚拟现实国际会议(HCIVR 2024)

会议简介 2024年人机交互与虚拟现实国际会议&#xff08;HCIVR 2024&#xff09;定于2024年11月15-17日在中国杭州召开&#xff0c;会议由浙江工业大学主办。人机交互&#xff0c;虚拟现实技术的发展趋势主要体现在系统将越来越实际化&#xff0c;也越来越贴近人类的感知和需求…