CVPR 2023 Hybrid Tutorial: All Things ViTs之mean attention distance (MAD)

news2025/1/12 18:41:12

All Things ViTs系列讲座从ViT视觉模型注意力机制出发,本文给出mean attention distance可视化部分阅读学习体会.

课程视频与课件: https://all-things-vits.github.io/atv/
代码: https://colab.research.google.com/github/all-things-vits/code-samples/blob/main/probing/mean_attention_distance.ipynb
文献:A N I MAGE IS W ORTH 16 X 16 W ORDS :
T RANSFORMERS FOR I MAGE R ECOGNITION AT S CALE

1.总述

之前在阅读ViT论文的时候对MAD这部分没有十分理解,及MAD究竟是什么,如下图所示.将该部分代码进行调试理解,能够比较深入理解ViT的注意力机制.
Fig 1 vit-base-patch16-224 MAD可视化
Fig 1 vit-base-patch16-224 MAD可视化

2.关键代码讲解

2.1 注意力分数获得
def perform_inference(image: Image, model: torch.nn.Module, processor):
    """Performs inference given an image, a model, and its processor."""
    inputs = processor(image, return_tensors="pt")#[1, 3, 224, 224]
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
        print(type(outputs))

    # model predicts one of the 1000 ImageNet classes
    predicted_label = outputs.logits.argmax(-1).item()
    print(model.config.id2label[predicted_label])
    return outputs.attentions #[[1, 12, 197, 197]*12]

这部分代码将图像输入ViT网络,并得到输出的logits,类别以及ViT中每个block(如图Fig2)中每个head的注意力分数(outputs.attentions).ViT可以看作是transformer的一个encoder,如下:
在这里插入图片描述
Fig 2 ViT的一个block

此外,outputs.attentions是一个tuple,其中包括12个维度为[1, 12, 197, 197]的tensor.这个tensor可理解如下,其中12为head的数量,197是token的数量.197*197表示每个token之间的注意力分数.197包含196个图像token与一个cls token.其中MAD是图像token之间的距离

2.2 计算MAD
def gather_mads(attention_scores, patch_size: int = 16):
    all_mean_distances = {
        f"block_{i}_mean_dist": compute_mean_attention_dist(
            patch_size=patch_size, attention_weights=attention_weight.numpy()
        )
        for i, attention_weight in enumerate(attention_scores)
    }
    return all_mean_distances 

这段代码是遍历计算每一个block中的MAD

def compute_mean_attention_dist(patch_size, attention_weights):
    # The attention_weights shape = (batch, num_heads, num_patches, num_patches)
    attention_weights = attention_weights[
        ..., num_cls_tokens:, num_cls_tokens:
    ]  # Removing the CLS token, [1, 12, 196, 196]
    num_patches = attention_weights.shape[-1]
    length = int(np.sqrt(num_patches))
    assert length**2 == num_patches, "Num patches is not perfect square"

    distance_matrix = compute_distance_matrix(patch_size, num_patches, length)#[196, 196]
    h, w = distance_matrix.shape

    distance_matrix = distance_matrix.reshape((1, 1, h, w))#[1, 1, 196, 196], space distance between batch in the image
    # The attention_weights along the last axis adds to 1
    # this is due to the fact that they are softmax of the raw logits
    # summation of the (attention_weights * distance_matrix)
    # should result in an average distance per token
    mean_distances = attention_weights * distance_matrix#[1, 12, 196, 196]
    mean_distances = np.sum(
        mean_distances, axis=-1
    )  # sum along last axis to get average distance per token, [1, 12, 196]
    mean_distances = np.mean(
        mean_distances, axis=-1
    )  # now average across all the tokens

    return mean_distances

这段代码则是具体计算MAD.首先计算patch(Fig 1中阐述了什么是patch)之间的距离,ViT中的token可以理解为对每个patch的编码,patch之间的距离计算方法如下:

def compute_distance_matrix(patch_size, num_patches, length):
    """Helper function to compute distance matrix."""
    distance_matrix = np.zeros((num_patches, num_patches))
    for i in range(num_patches):
        for j in range(num_patches):
            if i == j:  # zero distance
                continue

            xi, yi = (int(i / length)), (i % length)
            xj, yj = (int(j / length)), (j % length)
            distance_matrix[i, j] = patch_size * np.linalg.norm([xi - xj, yi - yj])

    return distance_matrix

patch之间的距离即patch之间的空间距离.而MAD的核心计算代码为:

mean_distances = attention_weights * distance_matrix

之后在求每个head中所有token的距离均值.MAD是衡量每个patch与其他patch之间的综合距离,这个距离既考虑了它与其他patch的实际物理距离,又将注意力分数作为物理距离的加权.我对MAD的理解是,它是经过学习,对离散图像patch的一种建模.这种建模既考虑了patch与patch之间的空间关系,又考虑了patch之间实际的联系(注意力分数).这个距离可以用来探究每个head关注的范围,类似CNN中的感受野.

3.总述

接下来再回到Fig 2,我们再来理解这张图的含义.这张图横轴为block的编号,包含12个block,纵轴为每个head的MAD. 可以看到,ViT在浅层中就有的head开始关注全局(MAD大的head),有的关注局部(MAD小的head),这与CNN有所不同,CNN在浅层多关注局部,深层关注全局.因此说明.随着层数的加深,ViT逐步过渡到关注全局.相比于CNN来说,ViT是对图像的更一般的一种建模,这有利于表达更复杂的空间关系,但也更加难学习,因此一般认为在数据量比较大的情况下,ViT才能展现出其优势.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1401554.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024年【河北省安全员B证】最新解析及河北省安全员B证试题及解析

题库来源:安全生产模拟考试一点通公众号小程序 河北省安全员B证最新解析是安全生产模拟考试一点通生成的,河北省安全员B证证模拟考试题库是根据河北省安全员B证最新版教材汇编出河北省安全员B证仿真模拟考试。2024年【河北省安全员B证】最新解析及河北省…

python222网站实战(SpringBoot+SpringSecurity+MybatisPlus+thymeleaf+layui)-热门标签推荐显示实现

锋哥原创的SpringbootLayui python222网站实战: python222网站实战课程视频教程(SpringBootPython爬虫实战) ( 火爆连载更新中... )_哔哩哔哩_bilibilipython222网站实战课程视频教程(SpringBootPython爬虫实战) ( 火…

【Spring 篇】MyBatis注解开发:编写你的数据乐章

欢迎来到MyBatis的音乐殿堂!在这个充满节奏和韵律的舞台上,注解是我们编写数据乐章的得力助手。无需繁琐的XML配置,通过简单而强大的注解,你将能够轻松地与数据库交互。在这篇博客中,我们将深入探讨MyBatis注解开发的精…

5G_射频测试_发射机测量(四)

6.2 Base station output power 用于测量载波发射功率的大小,功率越大小区半径越大但是杂散也会越大 载波功率(用频谱仪测)天线口功率(用功率计测)载波功率是以RBW为单位的filter测量的积分功率不同带宽的多载波测试时…

一文读懂「RAG,Retrieval-Augmented Generation」检索增强生成

Retrieval-Augmented Generation(RAG)作为机器学习和自然语言处理领域的一大创新,不仅代表了技术的进步,更在实际应用中展示了其惊人的潜力。 RAG结合了检索(Retrieval)和生成(Generation&#…

项目解决方案:多地医馆的高清视频监控接入汇聚联网

目 录 一、背景 二、建设目标及需求 1.建设目标 2.现状分析 3.需求分析 三、方案设计 1.设计依据 2.设计原则 3.方案设计 3.1 方案描述 3.2 组网说明 四、产品介绍 1.视频监控综合资源管理平台介绍 2.视频录像服务器和存储 2.1概述 2.2存储设计 …

蓝桥杯省赛无忧 编程9

#include<bits/stdc.h> using namespace std; int main() {int n,k,ans0;cin>>n>>k;while(n--){int a;cin>>a;ansa&1;}if(ans&1) cout<<"Alice"<<\n;else cout<<"Bob"; return 0; }这个游戏是基于数…

软件工程应用题汇总

绘制数据流图(L0/L1/L2) DFD/L0&#xff08;基本系统模型&#xff09; 只包含源点终点和一个处理(XXX系统) DFD/L1&#xff08;功能级数据流图&#xff09;在L0基础上进一步划分处理(XXX系统) 个人理解 DFD/L2&#xff08;在L1基础上进一步分解后的数据流图&#xff09; 数据…

3.php开发-个人博客项目输入输出类留言板访问IPUA头来源

目录 知识点 : 输入输出 配置环境时&#xff1a; 搜索框&#xff1a; 留言板&#xff1a; 留言板的显示&#xff08;html&#xff09;&#xff1a; php代码显示提交的留言&#xff1a; 写入数据库 对留言内容进行显示&#xff1a; php全局变量-$_SERVER 检测来源 墨…

【复现】Hytec Inter HWL 2511 SS路由器RCE漏洞_25

目录 一.概述 二 .漏洞影响 三.漏洞复现 1. 漏洞一&#xff1a; 四.修复建议&#xff1a; 五. 搜索语法&#xff1a; 六.免责声明 一.概述 Hytec Inter HWL 2511 SS是日本Hytec Inter 公司的一款工业级 LTE 路由器&#xff0c;可用于远程数据传输&#xff0c;例如收集传…

网络安全(初版,以后会不断更新)

1.网络安全常识及术语 资产 任何对组织业务具有价值的信息资产&#xff0c;包括计算机硬件、通信设施、IT 环境、数据库、软件、文档 资料、信息服务和人员等。 漏洞 上边提到的“永恒之蓝”就是windows系统的漏洞 漏洞又被称为脆弱性或弱点&#xff08;Weakness&#xff09;&a…

class_15:虚函数

#include <iostream> #include <string> using namespace std;//基类,父类 class Vehicle{ public:string type;string contry;string color;double price;int numOfWheel;virtual ~Vehicle(){};//类中有虚函数&#xff0c;析构函数一般也写成虚函数virtual voi…

第135期 一周游历(上)(20240120)

数据库管理135期 2024-01-20 第135期 一周游历(上)&#xff08;20240120&#xff09;1 PolarDB开发者大会2 工作3 Oracle甲骨文4 Oracle ACE总结 第135期 一周游历(上)&#xff08;20240120&#xff09; 作者&#xff1a;胖头鱼的鱼缸&#xff08;尹海文&#xff09; Oracle AC…

基于Servlet建立表白墙网站

目录 一、设计思想 二、设计表白墙页面&#xff08;前端--VSCode&#xff09; 1、效果图 2、html部分&#xff08;网页上有哪些内容&#xff09; 3、css部分&#xff08;页面内容的具体样式&#xff09; 4、js部分&#xff08;页面行为&#xff09; 三、借助Servlet实现客…

PageHelper分页插件的使用

1.引入依赖 <!-- pagehelper 分页插件 --><dependency><groupId>com.github.pagehelper</groupId><artifactId>pagehelper-spring-boot-starter</artifactId><version>1.4.7</version></dependency> 2.application.yml…

webpack 核心武器:loader 和 plugin 的使用指南(上)

&#x1f90d; 前端开发工程师、技术日更博主、已过CET6 &#x1f368; 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 &#x1f560; 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 &#x1f35a; 蓝桥云课签约作者、上架课程《Vue.js 和 E…

vite 打包优化

✨专栏介绍 在当今数字化时代&#xff0c;Web应用程序已经成为了人们生活和工作中不可或缺的一部分。而要构建出令人印象深刻且功能强大的Web应用程序&#xff0c;就需要掌握一系列前端技术。前端技术涵盖了HTML、CSS和JavaScript等核心技术&#xff0c;以及各种框架、库和工具…

matplotlib画波动很小的图

今天测试一个画图时&#xff0c;有一个很神奇的发现 import matplotlib.pyplot as plt import numpy as np import pandas as pd x [10,20,30,40,50,60,70,80]y[-23.99534833975495,-23.9999998600783,-24.000000070633167,-24.000000068469788,-24.00000006672905,-24.000…

2023年DevOps国际峰会暨 BizDevOps 企业峰会(DOIS北京站):核心内容与学习收获(附大会核心PPT下载)

随着科技的飞速发展&#xff0c;软件开发的模式和流程也在不断地演变。在众多软件开发方法中&#xff0c;DevOps已成为当下热门的软件开发运维一体化模式。特别是在中国&#xff0c;随着越来越多的企业开始认识到DevOps的价值&#xff0c;这一领域的研究与实践活动日益活跃。本…

electron + selenium报错: Server terminated early with status 1

解决办法&#xff1a; 这种错误一般是浏览器创建的某方法致命错误导致的&#xff0c;查看一下实例化driver的地方有哪些配置&#xff0c;着重看日志、用执行信息存储一类的配置&#xff0c;我的问题是日志文件夹改过了但没有创建 // 浏览器参数设置 const customArguments [-…