AOT源码解析4.3-model主体解析

news2024/9/24 18:16:41

1.添加参考图像(add_reference_frame)

1.1 生成位置编码和ID编码

具体操作见详情。
在这里插入图片描述

图1:如图所示,显示的是参考图像的位置编码和id编码的生成过程。对于id编码,将mask图像输入进conv2d卷积网络后,进行结构转换,得到相应的id编码。对于位置编码,①根据最后一个比例特征图的高度和宽度生成高宽位置索引,索引值是0~29。②根据高宽的位置缩影得到xy两个维度的位置编码,分别命名为grid_x和gird_y。③在位置编码的前后分别扩充一个维度后,将位置编码除以温度变量dim_t。④计算位置编码偶数位的sin值和奇数位的cos值。⑤合并两个位置编码

位置编码和ID编码的生成步骤如上图所示。

1.2 reference mask匹配和预测其余帧mask

1.2.1 LSTT输入数据

在这里插入图片描述

图2,如图所示,是LSTT结构输入的数据,它们分别为curr-enc-embs第一个列表(这个列表存储的是相关帧的图像特征)的最后一个特征图、最后一个特征图的位置编码矩阵以及one-hot-mask的ID编码矩阵。

1.2.2 LSTT整体结构

在这里插入图片描述

图3,如图所示,是LSTT结构的概览,接下来将会逐一分解各个模块。注意,图上的spilit是将shape为[900,4,256]的特征图分解为[4,256,30,30]。

1.2.3 self-attention结构

在这里插入图片描述

图4, 如图4所示是self-attnetion结构示意图。

在self-attention结构中,将当前特征图进行层归一化后,得到v分支;将当前特征图与位置编码进行相加后得到K和Q分支。
将K\Q\V分支的特征图经过Linear线性层和多头分割后,原本的[900,4,256]结构的特征图的第三个维度被切分成32*8,因此得到[4,8,900,32]大小的特征图。
再将K、Q分支的多头特征图进行相乘和softmask操作,得到包含自注意力权重的特征图atten,(自注意力权重即图像本身应重点关注的区域,这些区域的特征图权重高。反之不需要关注的区域权重低)。
最后将自注意力权重特征图atten与V分支的特征图相乘,得到经过自注意力权重加权后的特征图output。
关于自注意力权重部分,以两个相同的3*3的矩阵相乘为例(分别为矩阵A和矩阵B),A和B相乘后得到矩阵C。矩阵A和矩阵B的值在神经网络中是一个个向量,所以矩阵A和矩阵B中的值越相似,相乘后权重越高,反之权重越低。因此矩阵C中的值Cij可以看作矩阵A第i行和矩阵B的第j列的相关程度。

    def forward(self, Q, K, V):
        """
        :param Q: A 3d tensor with shape of [T_q, bs, C_q]
        :param K: A 3d tensor with shape of [T_k, bs, C_k]
        :param V: A 3d tensor with shape of [T_v, bs, C_v]
        """
        num_head = self.num_head
        hidden_dim = self.hidden_dim

        bs = Q.size()[1]

        # Linear projections
        if self.use_linear:
            Q = self.linear_Q(Q)
            K = self.linear_K(K)
            V = self.linear_V(V)

        # Scale
        Q = Q / self.T

        if not self.training and self.max_mem_len_ratio > 0:
            mem_len_ratio = float(K.size(0)) / Q.size(0)
            if mem_len_ratio > self.max_mem_len_ratio:
                scaling_ratio = math.log(mem_len_ratio) / math.log(
                    self.max_mem_len_ratio)
                Q = Q * scaling_ratio

        # Multi-head
        Q = Q.view(-1, bs, num_head, self.d_att).permute(1, 2, 0, 3)
        K = K.view(-1, bs, num_head, self.d_att).permute(1, 2, 3, 0)
        V = V.view(-1, bs, num_head, hidden_dim).permute(1, 2, 0, 3)

        # Multiplication
        QK = multiply_by_ychunks(Q, K, self.qk_chunks)
        if self.use_dis:
            QK = 2 * QK - K.pow(2).sum(dim=-2, keepdim=True)

        # Activation
        if not self.training and self.top_k > 0 and self.top_k < QK.size()[-1]:
            top_QK, indices = torch.topk(QK, k=self.top_k, dim=-1)
            top_attn = torch.softmax(top_QK, dim=-1)
            attn = torch.zeros_like(QK).scatter_(-1, indices, top_attn)
        else:
            attn = torch.softmax(QK, dim=-1)

        # Dropouts
        attn = self.dropout(attn)

        # Weighted sum
        outputs = multiply_by_xchunks(attn, V,
                                      self.qk_chunks).permute(2, 0, 1, 3)

        # Restore shape
        outputs = outputs.reshape(-1, bs, self.d_model)

        outputs = self.projection(outputs)

        return outputs, attn

1.2.4 Fuse ID

        if curr_id_emb is not None:
            global_K, global_V = self.fuse_key_value_id(
                curr_K, curr_V, curr_id_emb)

    def fuse_key_value_id(self, key, value, id_emb):
        K = key
        V = self.linear_V(value + id_emb)
        return K, V

如上代码可以发现global_k=curr_k。
global_v 等于curr_v和id编码矩阵相加后,再输入进linear层。
在这里插入图片描述

图5,如图所示是globalv和global k的生成步骤

1.2.5 Long term attention

Long term attention的结构和1.2.3节的self-attention结构一致,都是MultiHeadAttention结构。MultiHeadAttention的特色就在于将Q、K、V分支分解成多个Block后再计算注意力权重和加权后的特征图。

再Long term attention中,Q K V分支分别是 curr_Q, global_K, global_V。这里的原始特征图是经多头自注意力机制加权后的特征图tgt。将原始特征图经过LN和lnear层后得到curr_Q和curr_K,原始特征图先经过LN层得到curr_V,curr_V和ID编码矩阵相加后再输入进Linear层得到global_V。

这里的curr KQV和 global KQV被写的很复杂,但实际上就是:
KQV分支都是原始特征图经过LN+Linear层的结果,这里的是curr 信息
而V分支由于要添加ID编码矩阵,因此在LN层后,先添加ID矩阵,再进行Linear操作。因此添加了ID编码矩阵的V分支成为了global信息

值得注意的是,位置编码矩阵被添加到KQ分支,而ID编码矩阵被添加到V分支。这里我的理解是,有助于注意力理解的信息添加到KQ分支,用于计算相关性权重;而与注意力理解无关的,是特征图本身的信息被添加到V分支中。

由Long term attention的KQV分支的输入可知,Long Term Attention就是添加了位置编码和ID编码后,进行了两次自注意力操作。

1.2.6 Short term attention

Short Term Attention的KQV输入是,将Lone term attention的KQV输入分解成块后,再进行注意力计算。
因此在这里Lone term attention和Short Term Attention的区别在于,KQV分支的特征图是否被分解。

1.2.7 Feed - Forward

这里是将Short Term Attention的输出和Lone term attention的输出与原始特征图(原始特征图假设为tgt)相加后,再进行一些了LN、GN、Conv和线性操作(这里得到的假设为tgt2).最后将tgt与tgt2相加便完成了Feed Forward。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2161059.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

容器化安装Jenkins部署devops

基础环境介绍 系统使用的是centos7.9 内核使用的是5.16.13-1.el7.elrepo.x86_64 容器使用的是26.1.4 docker-compose使用的是 v2.29.0 链路图 devops 配置git环境插件 部署好jenkins后开始配置 jenkins连接git&#xff0c;这里需要jenkins有连接git的插件。在已安装的插件…

【SD教程】图片也能开口说话?别惊讶!用SadTalker插件,一键生成自己的数字人,本地部署,免费使用!(附资料)

最近数字人越来越火&#xff0c;连互联网大佬都纷纷下场&#xff0c;比如360的周鸿祎&#xff0c;京东的刘强东等等。小伙伴可能也想拥有自己的数字人如果想用最简单的方式&#xff0c;那么可以用第三方的网站&#xff0c;例如 HeyGen平台、腾讯的智影等等。可这些网站都是收费…

HFSS中看TDR波形详细设置以及相关的解释

时域反射测量&#xff08;TDR&#xff09;中心思想就是用阶跃函数作为激励&#xff0c;应用在模型上&#xff0c;并检查反射随时间的变化。在检查时域之前&#xff0c;必须对driven solution&#xff08;Modal、Terminal或Transient&#xff09;执行插值扫描。然后&#xff0c;…

vite分目录打包以及去掉默认的.gz 文件

1.vite打包情况介绍&#xff1a; 1.1vite在不进行任何配置的情况下&#xff0c;会将除开public的所有引用到资源打包编译添加哈希值至assets文件夹中&#xff08;非引用文件以及行内样式图片未被打包编译资源会被treeSharp直接忽略不打包&#xff09;&#xff0c;     1.2w…

阿里云函数计算 x NVIDIA 加速企业 AI 应用落地

作者&#xff1a;付宇轩 前言 阿里云函数计算&#xff08;Function Compute, FC&#xff09;是一种无服务器&#xff08;Serverless&#xff09;计算服务&#xff0c;允许用户在无需管理底层基础设施的情况下&#xff0c;直接运行代码。与传统的计算架构相比&#xff0c;函数…

极星Polestar EDI 项目案例

近期国内汽车行业供应商J公司收到了极星Polestar的邀请&#xff0c;需要通过EDI与其国内工厂传输业务数据。本案例将为大家介绍对接过程以及实施方案。 梳理需求文档 极星Polestar的EDI需求与Volvo一样&#xff0c;传输协议选择 OFTP&#xff0c;报文标准为EDIFACT&#xff0…

Swing模拟银行柜台系统

> 这是一个基于JavaSwing实现的模拟银行柜台系统。 > 具有管理员、柜员、客户三种登录角色。 > 支持开户、注册、存取款、转账、汇款、账单查询等功能。 > 本项目适合JAVA初学者作为入门学习项目。 一、部分界面演示 二、基础依赖 技术/框架版本描述Java11编…

Vue前端浏览器指纹获取:数字世界的身份密码

程序员必备宝典https://tmxkj.top/#/一个开源的JavaScript库&#xff0c;它通过收集用户浏览器的多种属性&#xff08;如屏幕分辨率、浏览器插件、字体、Canvas和WebGL等&#xff09;来生成一个独特的浏览器指纹&#xff0c;用于识别和追踪用户。 #Github地址 GitHub - finger…

Uniapp时间戳转时间显示/时间格式

使用uview2 time 时间格式 | uView 2.0 - 全面兼容 nvue 的 uni-app 生态框架 - uni-app UI 框架 <text class"cell-tit clamp1">{{item.create_time}} --- {{ $u.timeFormat(item.create_time, yyyy-mm-dd hh:MM:ss)}} </text>

apply、call和bind的作用和区别

apply与call 首先介绍一下apply与call&#xff0c;因为这两个方法的功能和使用方式都差不多&#xff0c;只是传参的方式不同。call和apply的作用都是改变函数运行时的上下文&#xff08;context&#xff09; 语法 fun.call(thisArg, arg1, arg2, ...)fun.apply(thisArg, arg…

类的难疑点

一、知识点 1、类的属性和对象属性&#xff08;实例属性&#xff09; shuxing"123" self.shuxing"123" 2、类的对象 self.loginMyclass() loginMyclass() 3、访问类属性和方法的操作 通过“类名.属性”访问&#xff1a;Myclass.shuxing 通…

详解常见排序

目录 ​编辑 插入排序 希尔排序&#xff08;缩小增量排序&#xff09; 选择排序 冒泡排序 堆排序 快速排序 hoare版 挖坑法 前后指针法 非递归版 归并排序 递归版 非递归版 计数排序 声明&#xff1a;以下排序代码由Java实现&#xff01;&#xff01;&#xff01…

【研赛D题成品论文】24华为杯数学建模研赛D题成品论文(第一问)+可运行代码丨免费分享

2024华为杯研究生数学建模竞赛D题精品成品论文已出&#xff01; D题 大数据驱动的地理综合问题 一、问题分析 问题一&#xff1a;目标&#xff1a;利用1990-2020年的数据&#xff0c;针对降水量和土地利用的时空演化特征进行描述。数据&#xff1a;两个核心变量&#xff0c;一…

电商效果图渲染神器:轻松高效出图

在这个电商行业飞速发展的今天&#xff0c;离不开商品图的效果。而电商效果图同样离不开渲染&#xff0c;而大量的渲染需求有需要大量的机器&#xff0c;还要追求更快的渲染速度和更稳定的性能。毕竟&#xff0c;谁不想快点完成项目又省心呢&#xff1f; 而云渲染服务是个很好…

C++之STL—deque容器

双端数组 区别于 vector (单端数组)&#xff0c; 构造函数 注意&#xff1a;读取数据时&#xff0c;const修饰保证函数内只能读取&#xff0c;不能修改数据 void print(const deque<int>& deq) {for (deque<int>::const iterator it deq.begin(); it ! deq.e…

使用 Nuxt Kit 的构建器 API 来扩展配置

title: 使用 Nuxt Kit 的构建器 API 来扩展配置 date: 2024/9/24 updated: 2024/9/24 author: cmdragon excerpt: 摘要:本文详细介绍了如何使用 Nuxt Kit 的构建器 API 来扩展和定制 Nuxt 3 项目的 webpack 和 Vite 构建配置,包括扩展Webpack和Vite配置、添加自定义插件、…

正向科技|格雷母线定位系统的设备接线安装示范

格雷母线安装规范又来了&#xff0c;这次是设备接线步骤 格雷母线是格雷母线定位系统的核心部件&#xff0c;沿着移动机车轨道方向上铺设&#xff0c;格雷母线以相互靠近的扁平状电缆与天线箱电磁偶合来进行信号传递&#xff0c;从而检测得到天线箱在格雷母线长度方向上的位置。…

OpenLayers 开源的Web GIS引擎 - 添加地图控件地图控件

中心点按钮、地图放大缩小滑块、全图和比例尺控件 直接上代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.…

python爬虫案例——腾讯网新闻标题(异步加载网站数据抓取,post请求)(6)

文章目录 前言1、任务目标2、抓取流程2.1 分析网页2.2 编写代码2.3 思路分析前言 本篇案例主要讲解异步加载网站如何分析网页接口,以及如何观察post请求URL的参数,网站数据并不难抓取,主要是将要抓取的数据接口分析清楚,才能根据需求编写想要的代码。 1、任务目标 目标网…

基于深度学习的树叶识别系统的设计与实现(pyqt5 python3.9 yolov8 10000张数据集)

&#x1f497;博主介绍&#x1f497;&#xff1a;✌在职Java研发工程师、专注于程序设计、源码分享、技术交流、专注于Java技术领域和毕业设计✌ 温馨提示&#xff1a;文末有 CSDN 平台官方提供的老师 Wechat / QQ 名片 :) Java精品实战案例《700套》 2025最新毕业设计选题推荐…