Transformer---ViT:vision transformer

news2025/1/27 13:06:48

记录一下对transformer方法在计算机视觉任务中的应用方法的理解
参考博客:https://blog.csdn.net/weixin_42392454/article/details/122667271
参考代码:https://gitcode.net/mirrors/Runist/torch_vision_transformer?utm_source=csdn_github_accelerator

模型训练流程:

import torch
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

optimizer = optim.SGD(params, lr=args.lr, momentum=0.9, weight_decay=5e-5)
lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

optimizer.zero_grad()
logits = model(images)
loss = loss_function(logits, labels)
loss.backward()
optimizer.step()
scheduler.step()

模型网络图
在这里插入图片描述

一.模型训练

假设输入维度为[B, C, H, W],假设其中C=3,H=224,W=224

1.x = self.patch_embed(x) # [B, 196, 768]

假设patch_size=16,
则:num_patches=(H/patch_size)(W/patch_size)=(224/16)(224/16)=1414=196
embed_dim=C
patch_sizepatch_size=316*16=768

# [B, C, H, W] -> [B, num_patches, embed_dim]=[B,  196, 768]

具体的流程:

image_size=224, patch_size=16, in_c=3, embed_dim=768
# The input tensor is divided into patches using 16x16 convolution
self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

# forward
#self.proj:[B, C, image_size, image_size]=[B, 3, 224, 224] -> [B, embed_dim, H, W]=[B, 768, 14, 14],后续操作中C=embed_dim=768,H=W=14
#flatten: [B, C, H, W] -> [B, C, HW]=[B, 768, 14*14]=[B, 768, 196]
# transpose: [B, C, HW] -> [B, HW, C]=[B, 196, 768],需要在最后一个维度对embed_dim=768维度进行norm,故做transpose
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)

2.cls_token = self.cls_token.expand(x.shape[0], -1, -1) # [1, 1, 768] -> [B, 1, 768]

3.x = torch.cat((cls_token, x), dim=1) # [B, 197, 768]

增加一个class token(分类层),数据格式和其他token一样,长度为768的向量,与位置编码的融合方式不一样,这里做的是Concat,这样做是因为分类信息是在后面需要取出来单独做预测的,所以不能以Add方式融合,shape也就从[196, 768]变为[197, 768].

4.x = self.pos_drop(x + self.pos_embed) # [B, 197, 768]

self.pos_embed   #[1,  num_patches + self.num_tokens,  embed_dim] = [1, 196+1, 768] = [1, 197, 768]

tensor broadcast:广播机制,self.pos_embed 的第一维是1,x的第一维是B,相加时,会对x的第一维的每个通道都加self.pos_embed

5. x = self.blocks(x) :# [B, 197, 768]

1) x = x + self.drop_path(self.attn(self.norm1(x)))

norm_layer=nn.LayerNorm
self.norm1 = norm_layer(dim)
self.drop_path:随机drop一个完整的block
self.attn: # [B, 197, 768]

[batch_size, num_patches + 1, total_embed_dim]
B, N, C = x.shape   #B=batch_size, N=num_patches + 1=196+1=197, C=total_embed_dim=768

#1)获得q, k, v
# qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim] = [B, 197, 3*768]
# reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head] = [B, 197, 3, 8, 768/8], 其中假设num_heads=8
# permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head] = [3, B, 8, 197, 768/8]
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
# [batch_size, num_heads, num_patches + 1, embed_dim_per_head] = [B, 8, 197, 768/8]
q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

#2)计算注意力权重:w = softmax((q@k)*scale)
# transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1] = [B, 8, 768/8, 197]
# @: multiply -> [batch_size, num_heads, num_patches + 1, num_patches + 1] = [B, 8, 197, 197]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

#3)计算注意力得分s = w@v
# @: multiply -> [batch_size, num_heads, num_patches + 1, embed_dim_per_head] = [B, 8, 197, 768/8]
# transpose: -> [batch_size, num_patches + 1, num_heads, embed_dim_per_head] = [B, 197, 8, 768/8]
# reshape: -> [batch_size, num_patches + 1, total_embed_dim] = [B, 197, 768]
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)

2) x = x + self.drop_path(self.mlp(self.norm2(x)))

self.norm2 = norm_layer(dim)
self.drop_path:随机drop一个完整的block
self.mlp: # [B, 197, 768]

in_features=dim=embed_dim #768
mlp_ratio=4.
mlp_hidden_dim = int(dim * mlp_ratio) #7687*4
hidden_features=mlp_hidden_dim
out_features = in_features #768

self.fc1 = nn.Linear(in_features, hidden_features) #(768, 768*4)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features) (768*4, 768)
self.drop = nn.Dropout(drop)

x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)

6.x = self.norm(x) :# [B, 197, 768]

norm_layer = nn.LayerNorm
self.norm = norm_layer(embed_dim)

7.x = self.head(x) :# [B, 197, 1000]

self.num_features = self.embed_dim = embed_dim # 768
num_classes=1000
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

二.推理预测

 with torch.no_grad():
        # predict class
        output = torch.squeeze(model(image.to(device))).cpu()
        predict = torch.softmax(output, dim=0)  #将所有输出归一化映射为概率分布,概率和为1
        index = torch.argmax(predict).numpy()   #最大概率所在位置索引,即属于1000个类别中的哪一类
    # 输出最大类别
    print("prediction: {}   prob: {:.3}\n".format(args.label_name[index],
                                                predict[index].numpy()))
    #输出每个类别的得分
    for i in range(len(predict)):
        print("class: {}   prob: {:.3}".format(args.label_name[i],
                                               predict[i].numpy()))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/905494.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【数据结构与算法】克鲁斯卡尔算法

克鲁斯卡尔算法 介绍 克鲁斯卡尔(Kruskal)算法是用来求加权连通图的最小生成树的算法。基本思想:按照权值从小到大的顺序选择 n - 1 条边,并保证这 n - 1 条边不构成回路。具体做法:首先构造一个只含 n 个顶点的森林…

文本三剑客sed grep awk

目录 1、sed 1.1、基本用法 1.2、sed脚本格式 1.3、搜索与替换 1.4、变量 2、awk 2.1、基础用法 2.2、常见的内置变量 2.3、模式 2.4、判断 2.5、for计算 2.6、数组 3、grep 1、sed sed 即 Stream EDitor,和 vi 不同,sed是行编辑器 Sed是从…

leetcode刷题之283:移动零

问题 实现思路 首先, 将dest指向-1 位置, cur指向下标为0 的位置, 在cur遍历的过程中: 1) 遇到非零元素则与下标dest1 位置的元素交换, 2) 若遇到零元素则只继续cur遍历. 下标为1 的位置上是 非零元素 执行1) 交换得到右图结果 随后cur 得到下图结果 下标为2 的位置上是零…

day-27 代码随想录算法训练营(19)part03

78.子集 画图分析: 思路:横向遍历,每次遍历的时候都进行一次添加,然后进行纵向递归,递归完之后进行回溯。 注意:空集也是子集。 90.子集|| 分析:和上题一样,区别在于有重复数字 …

LeetCode283.移动零

这道题还是很简单的,我用的是双指针,左指针i从头开始遍历数组,右指针j是从i后面第一个数开始遍历,当左指针i等于0的时候,右指针j去寻找i右边第一个为0的数和i交换位置,交换完了就break内层循环,…

STM8遇坑[EEPROM读取debug不正常release正常][ STVP下载成功单运行不成功][定时器消抖莫名其妙的跑不通流程]

EEPROM读取debug不正常release正常 这个超级无语,研究和半天,突然发现调到release就正常了,表现为写入看起来正常读取不正常,这个无语了,不想研究了 STVP下载不能够成功运行 本文摘录于:https://blog.csdn.net/qlexcel/article/details/71270780只是做学习备份之…

每周AI大事件 百度文心一言上线搜索、文生视频、图表制作等5大插件

每周AI大事件 | 百度文心一言上线搜索、文生视频、图表制作等5大插件 文章目录 一、百度文心一言简介二、百度文心一言五大插件功能详解三、 开启文心一言 体验览卷文档E言易图 (貌似不太理想,可能指令姿势不对)说图解画(貌似不太…

「第2讲」正版PyCharm但是免费,安装教程来了,还有中文插件哦~

大家好,这里是程序员晚枫。 免费的【50讲Python自动化办公】持续更新中,关注我学习吧👇想了解更多精彩内容,快来关注程序员晚枫 上一讲:「第1讲」Python的下载、安装和卸载,有手就能学 装完了Python&#…

char *str,char str,char * str和char str的区别

1.char *str是一个指向字符或字符串的指针&#xff0c;总是指向一个字符的起始地址&#xff0c;例如 char *str "Hello"; cout << *str << endl; // 输出&#xff1a;H cout << str << endl; // 输出&#xff1a;Hello str "World…

5.4 webrtc的线程

那今天呢&#xff1f;我们来了解一下webrtc中的threed&#xff0c;首先我们看一下threed的类&#xff0c;它里边儿都含了哪些内容&#xff1f;由于threed的类非常大啊&#xff0c;我们将它分成两部分。 那第一部分呢&#xff0c;是我们看threed的类中都包含了哪些数据之后呢&a…

linux设备驱动:kset、uevent、class

目录 kset&#xff1a;驱动的骨架 kset_create_and_add()函数 设备驱动模型实验2-kobject点灯&#xff08;加入kset&#xff09; kset.c文件 Makefile文件 执行过程 uevent&#xff1a;内核消息的快递包 uevent机制 kobject_uevent()函数 设备驱动模型实验3-kobject点…

AMBA总线协议(3)——AHB(一)

目录 一、前言 二、什么是AHB总线 1、概述 2、一个典型的基于AHB总线的微处理器架构 3、基本的 AHB 传送特性 三、AMBA AHB总线互联 四、小结 一、前言 在之前的文章中我们初步的了解了一下AMBA总线中AHB,APB,AXI的信号线及其功能&#xff0c;从本文开始我们…

NOIP2014普及组复赛 珠心算测验 螺旋矩阵 真题答案

珠心算测验 说明 珠心算是一种通过在脑中模拟算盘变化来完成快速运算的一种计算技术。珠心算训练&#xff0c; 既能够开发智力&#xff0c;又能够为日常生活带来很多便利&#xff0c;因而在很多学校得到普及。 某学校的珠心算老师采用一种快速考察珠心算加法能力的测验方法。他…

wustoj2006后天

#include <stdio.h> int main() {int n;scanf("%d",&n); printf("%d",(n2)%7);return 0;}

星际争霸之小霸王之小蜜蜂(一)--窗口界面设计

目录 前言 一、安装pygame库 1、pygame库简介 2、在windows系统安装pygame库 二 、搭建游戏框架 1、创建游戏窗口 2、改变窗口颜色 总结 前言 大家应该都看过或者都听说过python神书“大蟒蛇”&#xff0c;上面有一个案例是《外星人入侵》&#xff0c;游戏介绍让我想起了上…

上位机系统(系统的架构、串口的使用、协议的定义、开发环境的配置)

上位机系统 1. 系统架构 实机拓扑架构 硬件支持 使用 VSPD 6.9 实现&#xff1a; 效果图 当状态值超过警戒值&#xff0c;就会变成红色&#xff0c;同时在界面的上方显示红色的“设备告警” 3. 串口电气特性 波特率&#xff1a;19200 数据位数&#xff1a;8 位 u 奇偶校验&…

shell脚本之函数

shell函数 函数的组成&#xff1a;函数名和函数体 函数的格式 function 函数名 { 命令序列 } function cat {cat /etc/passwd}函数名() { 命令序列 } cat () {cat /etc/passwd}function 函数名 (){ 命令序列 } function cat() {cat /etc/passwd}函数相关命令 declare -F #查…

记录每日LeetCode 2236. 判断根结点是否等于子结点之和 Java实现

题目描述&#xff1a; 给你一个 二叉树 的根结点 root&#xff0c;该二叉树由恰好 3 个结点组成&#xff1a;根结点、左子结点和右子结点。 如果根结点值等于两个子结点值之和&#xff0c;返回 true &#xff0c;否则返回 false 。 初始代码&#xff1a; /*** Definition f…

Cpp学习——类与对象3

目录 一&#xff0c;初始化列表 1.初始化列表的使用 2.初始化列表的特点 3.必须要使用初始化列表的场景 二&#xff0c;单参数构造函数的隐式类型转换 1.内置类型的隐式类型转换 2. 自定义类型的隐式类型转换 3.多参数构造函数的隐式类型转换 4.当你不想要发生隐式类型转换…