Transformer实战-系列教程6:Vision Transformer 源码解读4

news2024/10/5 21:24:21

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

Vision Transformer 源码解读1
Vision Transformer 源码解读2
Vision Transformer 源码解读3
Vision Transformer 源码解读4

11、Encoder类------前向传播

class Encoder(nn.Module):
    def forward(self, hidden_states):
        # print(hidden_states.shape)
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)
        return encoded, attn_weights

hidden_states.shape = torch.Size([16, 197, 768])
encoded.shape = torch.Size([16, 197, 768])

  1. attn_weights ,用于存储注意力权重
  2. 循环处理每一个Block层
  3. 将隐藏状态传递给当前的Block层,获取处理后的隐藏状态和注意力权重
  4. 将注意力权重添加到attn_weights列表中
  5. 将最后一层的输出通过LayerNorm层进行归一化处理
  6. 返回归一化后的输出和(如果有的话)注意力权重列表

这段代码实现了一个编码器,能够处理序列数据并可选择性地输出每层的注意力权重,通过层层叠加的Block和最终的归一化处理,它能够有效地学习输入数据的特征表示

12、Block类------前向传播

class Block(nn.Module):
    def forward(self, x):
        h = x
        x = self.attention_norm(x)
        x, weights = self.attn(x)
        x = x + h
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x, weights
  1. x.shape = torch.Size([16, 197, 768]),输入数据
  2. h=x
  3. self.attention_norm(x).shape = torch.Size([16, 197, 768]),x经过一个层归一化,层归一化直接在torch中的nn中调用
  4. (h+self.attn(x)).shape = torch.Size([16, 197, 768])经过一个self-Attention,再加上前面的x,加上x是一个残差连接
  5. self.ffn_norm(x).shape = torch.Size([16, 197, 768]),再经过一个层归一化
  6. self.ffn(x).shape = torch.Size([16, 197, 768]),经过一个MLP类
  7. (x + h).shape = torch.Size([16, 197, 768]),再来一个残差连接

13、MLP类------前向传播

class Mlp(nn.Module):
    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
  1. x.shape = torch.Size([16, 197, 768]),原始输入
  2. x.shape = torch.Size([16, 197, 3072]),经过第1层全连接
  3. x.shape = torch.Size([16, 197, 3072]),经过gelu经过函数
  4. x.shape = torch.Size([16, 197, 3072]),经过第1次dropout
  5. x.shape = torch.Size([16, 197, 768]),经过第2层全连接
  6. x.shape = torch.Size([16, 197, 768]),经过第2次dropout

14、Attention类------前向传播

最重要的Attention类

class Attention(nn.Module):
    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)
    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)#Linear(in_features=768, out_features=768, bias=True)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights

14.1 transpose_for_scores函数

x------torch.Size([16, 197, 12, 64])
new_x_shape------torch.Size([16, 197, 12, 64])
x.permute(0, 2, 1, 3)------torch.Size([16, 12, 197, 64])
transpose_for_scores函数将输入张量的最后一个维度拆分为num_attention_heads(注意力头数)和attention_head_size(每个头的大小),然后将维度重新排列以满足矩阵乘法的需求

14.2 前向传播

  1. hidden_states,torch.Size([16, 197, 768]),原始输入
  2. mixed_query_layer,torch.Size([16, 197, 768]),经过一层全连接生成Q向量,维度不变
  3. mixed_key_layer,torch.Size([16, 197, 768]),经过一层全连接生成K向量,维度不变
  4. mixed_value_layer,torch.Size([16, 197, 768]),经过一层全连接生成V向量,维度不变
  5. query_layer ,torch.Size([16, 12, 197, 64]),多头注意力的拆分,拆分12头,每头64维向量
  6. key_layer,torch.Size([16, 12, 197, 64]),多头注意力的拆分,拆分12头,每头64维向量
  7. value_layer,torch.Size([16, 12, 197, 64]),多头注意力的拆分,拆分12头,每头64维向量
  8. attention_scores ,torch.Size([16, 12, 197, 197]),q和k的内积结果
  9. attention_scores ,torch.Size([16, 12, 197, 197]),将注意力得分除以注意力头的大小的平方根进行缩放,防止梯度过小
  10. attention_probs ,softmax归一化
  11. weights ,如果开启可视化则保留注意力权重,否则不保留
  12. attention_probs , 对注意力权重应用dropout
  13. context_layer ,torch.Size([16, 12, 197, 64]),注意力权重重构v向量
  14. context_layer ,torch.Size([16, 197, 12, 64]),重新排列维度
  15. new_context_layer_shape ,计算重塑后的上下文表示形状
  16. context_layer ,torch.Size([16, 197, 768]),重塑上下文表示的形状,合并所有注意力头的输出,将多头进行还原
  17. attention_output ,torch.Size([16, 197, 768]),使用另一个线性层处理上下文表示,生成最终的注意力模块输出
  18. attention_output ,torch.Size([16, 197, 768]),使用另一个线性层处理上下文表示,生成最终的注意力模块输出

最后回顾一下ViT的网络架构:
在这里插入图片描述

Vision Transformer 源码解读1
Vision Transformer 源码解读2
Vision Transformer 源码解读3
Vision Transformer 源码解读4

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1436223.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mac检查CPU温度和风扇速度软件:Macs Fan Control Pro 1.5.17中文版

Macs Fan Control Pro for Mac是一款专业的电脑风扇控制工具,旨在帮助Mac用户有效控制电脑的风扇速度,提高电脑的运行效率和稳定性。 软件下载:Macs Fan Control Pro 1.5.17中文版 该软件支持多种风扇控制模式和预设方案,用户可以…

【leetcode】206. 反转链表(简单)题解学习

题目描述: 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 示例 1: 输入:head [1,2,3,4,5] 输出:[5,4,3,2,1]示例 2: 输入:head [1,2] 输出:[2,1]示例 …

VLM 系列——Llava1.6——论文解读

一、概述 1、是什么 Llava1.6 是llava1.5 的升级暂时还没有论文等,是一个多模态视觉-文本大语言模型,可以完成:图像描述、视觉问答、根据图片写代码(HTML、JS、CSS),潜在可以完成单个目标的视觉定位、名画…

【JavaScript 漫游】【006】数据类型 array

文章简介 本文为【JavaScript 漫游】专栏的第 006 篇文章,记录笔者在了解 JS 数据类型 array 中摘录的知识点。 数组的本质是对象属组的 length 属性for ... in 循环和数组的遍历数组的空位类数组对象 除了上述 5 个重要知识点,学习数组更为重要的是掌…

MySQL组复制的介绍

前言 本文介绍关于MySQL组复制的背景信息和基本原理。包括,介绍MySQL传统复制方法的原理和隐患、介绍组复制的原理,单主模式和多主模式等等。通过结合原理图学习这些概念,可以很好的帮助我们理解组复制技术这一MySQL高可用方案,有…

Linux系统中安装JDK

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

深入探究:JSONCPP库的使用与原理解析

君子不器 🚀JsonCPP开源项目直达链接 文章目录 简介Json示例小结 JsoncppJson::Value序列化Json::Writer 类Json::FastWriter 类Json::StyledWriter 类Json::StreamWriter 类Json::StreamWriterBuilder 类示例 反序列化Json::Reader 类Json::CharReader 类Json::Ch…

VScode上无法运行TSC命令,Typescript

如何解决问题 第一步:使用 winx 快捷键,会出现如下弹窗,鼠标左键单击Windows PowerShell 即可打开shell 第二步:运行 set-ExecutionPolicy RemoteSigned 命令,在询问更改执行策略的时候选择敲Y或者A 第三步&#xff…

python_蓝桥杯刷题记录_笔记_全AC代码_入门5

前言 关于入门地刷题到现在就结束了。 题单目录 1.P1579 哥德巴赫猜想(升级版) 2.P1426 小鱼会有危险吗 1.P1579 哥德巴赫猜想(升级版) 一开始写的代码是三重循环,结果提交上去一堆地TLE,然后我就给减少…

数据库学习案例20240206-ORACLE NEW RAC agent and resource关系汇总。

1 集群架构图 整体集群架构图如下: 1 数据库启动顺序OHASD层面 操作系统进程init.ohasd run启动ohasd.bin init.ohasd run 集群自动启动是否被禁用 crsctl enable has/crsGIHOME所在文件系统是否被正常挂载。管道文件npohasd是否能够被访问, cd /var/t…

Etsy做什么店铺比较靠谱? 什么叫做真人店?

作为Etsy是美国最大的原创手工在线销售平台之一,以其店铺利润高、平台佣金低、平台同质化不严重的优点在跨境电商里颇具竞争力。然而Etsy开店却不是件容易事,原因是Etsy真人店对于环境以及卖家的运营操作要求较高,下面就给各位有意入局Etsy的…

2024版细致idea解读(包含下载,安装,破解,讲解怎么使用)

前言 我们历经了对应的javase开发,使用的软件从eclipse也逐步升级到了idea,IntelliJ旗下的产品之一 内部复函很大的集成平台插件供大家使用 下载介绍 IntelliJ IDEA – 领先的 Java 和 Kotlin IDE 这个是他的网站地址 进入之后我们可以看到对应的界面…

Linux-3进程概念(一)

1.冯诺伊曼结构 1.1 冯诺依曼结构的概念 冯诺依曼结构,又称为普林斯顿结构,是一种将程序指令存储器和数据存储器合并在一起的存储器结构。程序指令存储地址和数据存储地址指向同一个存储器的不同物理位置,因此程序指令和数据的宽度相同&…

【大模型】万亿级别的大语言模型训练,基础设施如何支持

万亿级别的大语言模型训练,基础设施如何支持 前言1)培训百万亿参数的LLM是可行的,但需要每个GPU高达1 TiB的次级内存池,双向带宽为100 GB/s。2)对于1T模型的强扩展在约12288个GPU左右停滞,因为矩阵乘法变得…

C++二维数组

个人主页:PingdiGuo_guo 收录专栏:C干货专栏 大家好,我是PingdiGuo_guo,今天我们来学习二维数组。 文章目录 1.二维数组的概念与思想 2.二维数组和一维数组的区别 3.二维数组的特点 4.二维数组的操作 1.定义 2.初始化 1.直…

IPv4的公网地址不够?NAT机制可能是当下最好的解决方案

目录 1.前言 2.介绍 3.NAT机制详解 1.前言 我们都知道IPv4的地址范围是32个字节,这其中还有很多地址是不可用的.比如127.*,这些都是环回地址.那么在网路发展日新月异的今天,互联网设备越来越多,我们该如何解决IP地址不够用的问题呢?目前有一种主流的解决方案,也是大家都在用…

JVM相关-JVM模型、垃圾回收、JVM调优

一、JVM模型 JVM内部体型划分 JVM的内部体系结构分为三部分,分别是:类加载器(ClassLoader)子系统、运行时数据区(内存)和执行引擎 1、类加载器 概念 每个JVM都有一个类加载器子系统(class l…

C语言的malloc(0)问题

malloc(0)详解 首先来解释malloc(0)的问题,这个语法是对的,而且确实也分配了内存,但是内存空间是0,就是说返回给你的指针是不能用的,感觉奇怪吧?但是从操作系统的原理来解释就不奇怪…

(c语言版)开源项目热榜,某个开源社区希望将最近热度比较高的开源项目出一个榜单,推荐给社区里面的开发者。对于每个开源项目

某个开源社区希望将最近热度比较高的开源项目出一个榜单,推荐给社区里面的开发者。对于每个开源项目,开发者可以进行关注(watch)、收藏(star)、fork、提issue、提交合并请求(MR)等。 数据库里面统计了每个开源项目关注、收藏、fork、issue、MR的数量&…

3. ⼤语⾔模型深度学习背景知识

1. LLM⼤语⾔模型⼀般训练过程 #mermaid-svg-8kci1fjEPiVolPue {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-8kci1fjEPiVolPue .error-icon{fill:#552222;}#mermaid-svg-8kci1fjEPiVolPue .error-text{fill:#5522…