Vision Transformer架构Pytorch逐行实现

news2025/1/11 18:37:50

前言

  • 代码来自哔哩哔哩博主deep_thoughts,视频地址,该博主对深度学习框架方面讲的非常详细,推荐大家也去看看原视频,不管是否已经非常熟练,我相信都能有很大收获。
  • 论文An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale,下载地址。开源项目地址
  • 本文不对开源项目中代码进行解析,仅使用pytorch实现ViT框架,让大家对框架有更清楚的认知。

模型框架展示

在这里插入图片描述

  • Encoder部分和Transformer中的实现方法一致,可以直接调用pytorch中的API实现(博主在前面几个视频中使用pytorch逐行写了decoderencoder,再次推荐大家去看他的视频),下面主要针对左边的部分进行实现。
  • 架构思维导图,如下图
    -在这里插入图片描述
  • 导入必要包
import torch
import torch.nn as nn
import torch.nn.functional as F
  • 定义初始变量
# batch_size, 输入通道数,图像高,图像宽
bs, ic, image_h, image_w = 1, 3, 8, 8
# 分块边长
patch_size = 4
# 输出通道数
model_dim = 8
# 最大子图片块数
max_num_token = 16
# 分类数
num_classes = 10
# 生成真实标签
label = torch.randint(10,(bs,))
# 卷积核面积 * 输入通道数
patch_depth = patch_size * patch_size * ic
# image张量
image = torch.randn(bs, ic, image_h, image_w)
# model_dim:输出通道数,patcg_depth:卷积核面积 * 输入通道数
weight = torch.randn(patch_depth, model_dim)

perspective

  • 这一部分有两种实现方式,第1种是DNN方式,利用pytorch中的unfold函数滑动提取图像块。第2种是使用2维卷积的方法,最后将特征铺平。

DNN perspective

  • 首先使用unfold函数,滑动提取不重叠的块,所以kernel_size和stride相同。
  • 再与weight进行矩阵相乘,维度变化以及每个维度意义都在注释中。
def image2emb_naive(image, patch_size, weight):
    # patch:[batch_size, patch_size * patch_size * ic, (image_h * image_w) / (patch_size * patch_size)]
    patch = F.unfold(image, kernel_size=patch_size,stride=patch_size)
    # 转置操作[batch_size, (image_h * image_w) / (patch_size * patch_size), patch_size * patch_size * ic]]
    patch = patch.transpose(-1, -2)
    # 矩阵乘法weight:[patch_size * patch_size * ic, model_dim]
    patch_embedding = patch @ weight
    return patch_embedding
  • 调用函数,得到patch_embedding,检查维度
# 得到patch_embedding:[batch_size, (image_h * image_w) / (patch_size * patch_size), model_dim]
patch_embedding_naive = image2emb_naive(image, patch_size, weight)
print(patch_embedding_naive.shape)

输出:

torch.Size([1, 4, 8])

CNN perspective

def image2emb_conv(image, kernel, stride):
    conv_output = F.conv2d(image, kernel, stride = stride)
    bs, oc, oh, ow = conv_output.shape
    # patch_embedding:[batch_size, outchannel, o_h * o_w]
    patch_embedding = conv_output.reshape((bs, oc, oh*ow))
    print(patch_embedding.shape)
    # patch_embedding:[batch_size, o_h * o_w, outchannel]
    patch_embedding = patch_embedding.transpose(-1,-2)
    print(patch_embedding.shape)
    return patch_embedding

weight = weight.transpose(0,1)
print(weight.shape)
# kernel:[outchannel, inchannel, patch_size, patch_size]
kernel = weight.reshape((-1,ic, patch_size, patch_size))
print(kernel.shape)
patch_embedding_conv = image2emb_conv(image, kernel, patch_size)
print(patch_embedding_conv.shape)

输出:

torch.Size([8, 48])
torch.Size([8, 3, 4, 4])
torch.Size([1, 8, 4])
torch.Size([1, 4, 8])
torch.Size([1, 4, 8])

class token embedding

  • 随机生成cls_token_emnedding,并将其设为可训练参数。沿着图片块数维度进行拼接,检查cls_token_emneddingtoken_embedding维度。
# CLS token embedding
# cls_token_emnedding:[batch_size,1,mode_dim]
cls_token_emnedding = torch.randn(bs, 1, model_dim, requires_grad=True)
# 沿着图片块数维度进行拼接
token_embedding = torch.cat([cls_token_emnedding, patch_embedding_naive], dim=1)
print(cls_token_emnedding.shape)
print(token_embedding.shape)

输出:

torch.Size([1, 1, 8])
torch.Size([1, 5, 8])

position embedding

  • 创建pos embedding:[max_num_token,model_dim],然后使用tile函数进行增自我拼接,重复batch_size次。
# add position embedding
# 创建pos embedding:[max_num_token,model_dim]
position_embedding_table = torch.randn(max_num_token, model_dim, requires_grad = True)
# 取图片块数维度
seq_len = token_embedding.shape[1]
# tile增自我拼接,dims参数指定每个维度中的重复次数,dims = [batch_size,1,1]
position_embedding = torch.tile(position_embedding_table[:seq_len], [token_embedding.shape[0],1,1])
print(position_embedding.shape)

Transformer Encoder部分

  • 实例化TransformerEncoderLayer,再实例化TransformerEncoder,得到Encoder输出。
# pass embedding to Transformer Encoder
encoder_layer = nn.TransformerEncoderLayer(d_model=model_dim,nhead=8)
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
enconder_output = transformer_encoder(token_embedding)
print(enconder_output.shape)

classification head

  • 取出pos embedding维,经过线性层,对输出计算交叉熵损失
# 取出第1个图片块数维度,就是pos embedding维
cls_token_output = enconder_output[:,0,:]
# 实例化线性层model_dim --> num_classes
linear_layer = nn.Linear(model_dim, num_classes)
# 得到线性层输出
logits = linear_layer(cls_token_output)
# 交叉熵损失
loss_fn = nn.CrossEntropyLoss()
# 计算交叉熵损失
loss = loss_fn(logits,label)
print(loss)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/480158.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

iOS审核这些坑,腾讯游戏也踩过

WeTest 导读 在App上架苹果应用商店的过程中,相信大多数iOS开发者往往都有过这样的经历:辛苦开发出来的产品,测试验收也通过了,满怀期待的提交App给苹果审核,结果经常被苹果各种理由拒之门外,苦不堪言。 …

Prometheus监控系统存储容量优化攻略,让你的数据安心保存!

云原生监控领域不可撼动,Prometheus 是不是就没缺点?显然不是。 一个软件如果什么问题都想解决,就会导致什么问题都解决不好。所以Prometheus 也存在不足,广受诟病的问题就是 单机存储不好扩展。 1 真的需要扩展容量吗&#xff…

0x80070570文件或目录损坏且无法读取解决方法

第一种解决方法:命令提示符修复。 1、首先按下“Win标R”键,打开运行。 2、然后如果要修复的文件在E盘,那就输入:chkdsk e: /f,h盘就是:chkdsk h: /f,反正是哪个盘就把中间的字幕改成那个盘的…

ecs思考

VPC网络诊断,从router看起,连接公有子网路有一个默认,再新增一条指向igw路由;连接私有子网路由有一个默认,再新增一条指向NAT网关的路由,其中NAT网关一定要在公有子网中,否则,私有子…

Android 10.0 设置默认浏览器后安装另外浏览器后默认浏览器功能修复

1.前言 在10.0的系统rom定制化开发中,当在系统中有多个浏览器的时候,会在用代码启用浏览器的时候,让用户选择进入哪个浏览器,这样显得特别的不方便 所以产品开发中,要求用RoleManager的相关api来设置默认浏览器,但是在设置完默认浏览器以后,在安装一款浏览器的时候,默认…

〔金融帝国实验室〕(Capitalism Lab)v9.0.00官方重大版本更新!

〖金融帝国实验室〗(Capitalism Lab)v9.0.00正式发布! ◎制作发行:Enlight Software ◎发布时间:2023年04月28日 ————————————— ※v9.0.00更新说明: 1.实现6项数据信息双窗口并列显示&#…

兴寿镇“春踏青,兴寿行”特色旅游线路点靓辛庄

记者:云飞 踏着欢乐的节拍,伴着春日的暖阳,2023年4月29日,北京市昌平区兴寿镇,2023党建引领文旅农产业融合发展系列旅游季——“春踏青,兴寿行”特色旅游线路第二站,在兴寿镇辛庄村圆满举办。 此…

【搭建私有云盘】无公网IP,在外远程访问本地微力同步

文章目录 1.前言2. 微力同步网站搭建2.1 微力同步下载和安装2.2 微力同步网页测试2.3 cpolar的安装和注册 3.本地网页发布3.1 Cpolar云端设置3.2 Cpolar本地设置 4. 公网访问测试5. 结语 1.前言 私有云盘作为云存储概念的延伸,虽然谈不上多么新颖,但是其…

《QDebug 2023年4月》

一、Qt Widgets 问题交流 二、Qt Quick 问题交流 1.对 qml 基本类型 list 的编辑 在 Qt5 中,QML 的 list 类型只提供了 push 添加数据,或者重新赋值,没法 pop。到了 Qt6,实测可以对 list 调用 pop/shift 等操作。 Qt5 中可以先…

【Liunx】进程的程序替换——自定义编写极简版shell

目录 进程程序替换[1~5]1.程序替换的接口(加载器)2.什么是程序替换?3.进程替换的原理4.引入多进程5.系列程序替换接口的详细解析(重点!) 自定义编写一个极简版shell[6~8]6.完成命令行提示符7.获取输入的命令…

Docker 架构

Docker 架构 简介Docker daemon (守护进程)Docker client (客户端)Docker registries (仓库)Images (镜像)Containers (容器)The underlying technology &…

前缀和 技巧小记

前缀和 子数组的元素之和:一维前缀和子矩阵的元素之和:二维前缀和前缀和 哈希表:寻找和为 target 的子数组 子数组的元素之和:一维前缀和 前缀和适用于快速、频繁地计算一个索引区间内的元素之和。 int res 0; // 存储区间[…

链表:常见面试题-拷贝特殊链表

题目: 一种特殊的单链表节点类描述如下: class Node { int value; Node next; Node rand; Node(int val) {value val} } rand指针是单链表节点结构中新增的指针,rand可能指向链表中的任意一个节点(包括自己),也可…

计算机电脑中了勒索病毒怎么办,Windows系统中了faust勒索病毒解密数据恢复

电脑的操作系统被恶意软件攻击已不再是新鲜的话题了。而攻击的恶意软件中有一种叫做faust勒索病毒,常常袭击Windows电脑系统。如果我们的电脑在使用Windows操作系统时感染了faust勒索软件,请不要慌张,我们可以咨询专业的数据恢复厂商&#xf…

深度学习技巧应用11-模型训练中稀疏化参数与稀疏损失函数的应用

大家好,我是微学AI,今天给大家介绍一下深度学习技巧应用11-模型训练中稀疏化参数与稀疏损失函数的应用,在训练神经网络的过程中,将稀疏损失加入到常规损失函数的作用主要是降低模型复杂性和提高模型泛化能力。通过引入稀疏性约束,优化算法会在减小常规损失的同时,尽量让参…

快速上手非关系型数据库Redis

一、Redis介绍 1.非关系型数据库,纯内存操作,key-value存储,性能很高,可持久化(内存---->保存到硬盘上) 2.缓存,计数器,验证码,geo地理位置信息,发布订阅…

【前端知识】Cookie, Session,Token和JWT的发展及区别(上)

【前端知识】Cookie, Session,Token和JWT的发展及区别(上) 1. 背景2. Cookie2.1 Cookie的定义2.2 Cookie的特点2.3 Cookie的一些重要属性✨2.3.1 Cookie的重要属性🎇2.3.2 Cookie的有效期,max-age和作用域,…

SQL注入(一)联合查询 报错注入

目录 1.sql注入漏洞是什么 2.联合查询: 2.1注入思想 2.2 了解information_schema 数据库及表 3.可替代information_schema的表 3.1 sys库中重要的表 4. 无列名注入 利用 join-using 注列名。 4. 报错注入 4.1 常用函数:updatexml、extractvalue…

Java 基础进阶篇(五)—— 接口详解

文章目录 一、接口概述二、接口的基本使用三、接口从 JDK 8 开始新增的方法四、接口的注意事项(了解)补充:接口与接口的关系 一、接口概述 规范的基本特征是约束和公开。 接口就是一种规范,其约束别人必须干什么事情。 所以&…

FileZilla读取目录列表失败(vsftpd被动模式passive mode部署不正确)

文章目录 现象问题原因解决方法临时解决(将默认连接方式改成主动模式)从根本解决(正确部署vsftpd的被动模式) 现象 用FileZilla快速连接vsftpd服务器时,提示读取目录列表失败 问题原因 是我vsftpd服务端的被动模式没…