VIT总结

news2025/1/12 23:04:56

关于transformer、VIT和Swin T的总结

1.transformer

1.1.注意力机制

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.[1]
输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility)每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出

Attention Is All You Need 中用到的attention叫做“Scaled Dot-Product Attention”,具体过程如下图所示:
在这里插入图片描述
代码实现:

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs  to  be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # Fills elements of self tensor with value where mask is True

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions

        out = self.fc_out(out)
        return out

1.为什么有mask?
NLP处理不定长文本需要padding,但是padding的内容无意义,所以处理时需要mask.
2.关于qkv
qkv是相同的,需要查询的q,与每一个key相乘得到权重信息,权重与v相乘,这样结果受权重大的v影响
3.为什么除以根号dk

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 1 √dk
点积过大,经过softmax,进入饱和区,梯度很小

4.为什么需要多头
在这里插入图片描述
不同头部的output就是从不同层面(representation subspace)考虑关联性而得到的输出。

1.2.TransformerBlock

解码端的后面两部分和编码段一样,所以打包成一个类
在这里插入图片描述

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

1.3.Encoder

关键的就是位置编码

class Encoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_lengh = x.shape
        positions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

2.VIT

在这里插入图片描述

Reference:

[1].Attention Is All You Need
[2].https://zhuanlan.zhihu.com/p/366592542
[3].代码实现:https://zhuanlan.zhihu.com/p/653170203
[4].An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1294701.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis基础系列-持久化

Redis基础系列-持久化 文章目录 Redis基础系列-持久化1. 什么是持久化2. 为什么要持久化3. 持久化的两种方式3.1 持久化方式1:RDB(redis默认持久化方式)3.11 配置步骤-自动触发3.12 配置步骤-手动触发3.12 优点3.13 缺点3.14 检查和修复RDB快照文件3.15 哪些情况会触…

【华为数据之道学习笔记】3-2 基础数据治理

基础数据用于对其他数据进行分类,在业界也称作参考数据。基础数据通常是静态的(如国家、币种),一般在业务事件发生之前就已经预先定义。它的可选值数量有限,可以用作业务或IT的开关和判断条件。当基础数据的取值发生变…

小航助学2023年6月GESP_Scratch四级真题(含题库答题软件账号)

需要在线模拟训练的题库账号请点击 小航助学编程在线模拟试卷系统(含题库答题软件账号 单选题2.00分 删除编辑附件图文 答案:D 第1题高级语言编写的程序需要经过以下( )操作,可以生成在计算机上运行的可执行代码。 A、编辑B、…

MQTT 协议入门:轻松上手,快速掌握核心要点

文章目录 什么是 MQTT?MQTT 的工作原理MQTT 客户端MQTT Broker发布-订阅模式主题QoS MQTT 的工作流程开始使用 MQTT:快速教程准备 MQTT Broker准备 MQTT 客户端创建 MQTT 连接通过通配符订阅主题发布 MQTT 消息MQTT 功能演示保留消息Clean Session遗嘱消…

【C语言】7-32 刮刮彩票 分数 20

7-32 刮刮彩票 分数 20 全屏浏览题目 切换布局 作者 DAI, Longao 单位 杭州百腾教育科技有限公司 “刮刮彩票”是一款网络游戏里面的一个小游戏。如图所示: 每次游戏玩家会拿到一张彩票,上面会有 9 个数字,分别为数字 1 到数字 9&#xf…

前端自动化测试Vue中TDD和单元测试示例详解

1、简单用例入门 Vue 提供了 vue/test-utils 来帮助我们进行单元测试,创建 Vue 项目的时候勾选测试选项会自动帮我们安装 先来介绍两个常用的挂载方法: mount:会将组件以及组件包含的子组件都进行挂载shallowMount:浅挂载&…

计算机网络复习资料

一、题型 选择题(包括单选和多选,共30分。其中单选每题1分,计20分;多选每题2分,计10分) 简答题(每题5分,共20分) 分析计算题(共40分,共4题) 论述题(本题10分,共1题) 二、考试大纲[人工智能…

防止企业敏感数据泄露

敏感数据泄露是指意外或故意泄露关键信息,例如个人身份信息(PII)、支付卡信息(PCI)、受保护的电子健康信息(ePHI)和知识产权(IP),数据保护措施不足的组织会在…

点滴生活记录2

我从小跟着我爷爷奶奶,小学六年级转到县城上小学,就没跟我奶奶他们住一起了。十一回家,把奶奶接到我这住,细想,自六年级之后,就很少跟奶奶住一起了。 奶奶(间歇性)耳聋,为…

Linux 驱动开发需要掌握哪些编程语言和技术?

Linux 驱动开发需要掌握哪些编程语言和技术? 在开始前我有一些资料,是我根据自己从业十年经验,熬夜搞了几个通宵,精心整理了一份「Linux从专业入门到高级教程工具包」,点个关注,全部无偿共享给大家&#xf…

(C语言实现)高精度除法 (洛谷 P2005 A/B Problem II)

前言 本期我们分享用C语言实现高精度除法,可通过该题测试点我点我,洛谷 p2005。 那么话不多说我们开始吧。 讲解 大家还记不记得小学的时候我们是怎么做除法的?我们以1115为例。 我们的高精度除法也将采用这个思路进行,分别用两…

JavaSE基础50题:23. 数组拷贝(数组练习题)

文章目录 概述方法一:运用for循环进行拷贝方法二:Java内置方法进行拷贝方法三:指定区间进行拷贝 概述 数组拷贝。 注意: public static void main(String[] args) {int[] array1 {1,2,3,4};System.out.println(myToString(array…

python爬取 HTTP_2 网站超时问题的解决方案

问题背景 在进行网络数据爬取时,使用 Python 程序访问支持 HTTP/2 协议的网站时,有时会遇到超时问题。这可能会导致数据获取不完整,影响爬虫程序的正常运行。 问题描述 在实际操作中,当使用 Python 编写的爬虫程序访问支持 HTT…

第一课【习题】给应用添加通知和提醒

构造进度条模板通知,name字段当前需要固定配置为downloadTemplate。 给通知设置分发时间,需要设置showDeliveryTime为false。 OpenHarmony提供后台代理提醒功能,在应用退居后台或退出后,计时和提醒通知功能被系统后台代理接管…

【开源】基于Vue+SpringBoot的教学过程管理系统

项目编号: S 054 ,文末获取源码。 \color{red}{项目编号:S054,文末获取源码。} 项目编号:S054,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 教师端2.2 学生端2.3 微信小程序端2…

使用 KubeRay 和 Kueue 在 Kubernetes 中托管 Ray 工作负载

在 KubeCon CN 2023 的「 Open AI 数据 | Open AI Data」专题中,火山引擎软件工程师胡元哲分享了《使用 KubeRay 和 Kueue 在 Kubernetes 中托管 Ray 工作负载|Sailing Ray workloads with KubeRay and Kueue in Kubernetes议题。以下是本次演讲的文字…

区块链实验室(28) - 拜占庭节点劫持区块链仿真

在以前的FISCO环境中仿真拜占庭节点攻击区块链网络。该环境共有100个节点,采用PBFT作为共识机制,节点编号分别为:Node0,Node,… ,Node99。这100个节点的前2010区块完全相同,自区块2011开始分叉。…

Qt/C++音视频开发58-逐帧播放/上一帧下一帧/切换播放进度/实时解码

一、前言 逐帧播放是近期增加的功能,之前也一直思考过这个功能该如何实现,对于mdk/qtav等内核组件,可以直接用该组件提供的接口实现即可,而对于ffmpeg,需要自己处理,如果有缓存的数据的话,可以…

一文了解半导体检测的利器—探针台

探针台是半导体行业重要的检测装备之一,其广泛应用于复杂、 高速器件的精密电气测量,旨在确保质量及可靠性,并缩减研发时间和器件制造工艺的成本。 半导体测试可以按生产流程可以分为三类:验证测试、晶圆测试测试、封装检测。探针…

王炸升级!PartyRock 10分钟构建 AI 应用

前言 一年一度的亚马逊云科技的 re:Invent 可谓是全球云计算、科技圈的狂欢,每次都能带来一些最前沿的方向标,这次也不例外。在看完一些 keynote 和介绍之后,我也去亲自体验了一些最近发布的内容。其中让我感受最深刻的无疑是 PartyRock 了。…