Vit Transformer

news2025/1/8 19:27:10

一 VitTransformer 介绍

vit : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

        论文是基于Attention Is All You Need,由于图像数据和词数据数据格式不一样,经典的transformer不能处理图像数据,在视觉领域的应用有限。本文提出的方法可以将transformer直接应用图像分类任务,引入Patch Embedding,位置编码等方法,克服了Transformer在处理图像数据时的限制。整体流程如下。

从图中可以看出, Vision Transformer 主要有三个部分组成: 1 ) 第一部分是Linear Projection of Flattened Patches ,也就是 Emdedding 层,主要的工作就是将图像数据转换成transformer可以处理的数据格式。2)第二部分是Transformer Encoder部分,它是vit 最核心的组件(原始的NLP的transformer还有Decoder部分)。它主要是层归一化,多头注意力机制,MLP,Dropout/DropPath四个小block组成,用于学习图像数据。3) 第三部分就是MLP head ,用于分类。

二 PatchEmbedding & Positional Encoding

        首先,每个图像被分割成一系列不重叠的块(16x16或者 32x32),然后做一个线性的embedding ,由于这些块如果并行的输入到transformer中,不提供位置信息,模型不知道这些块的顺序。因此要加一个 positional encoding。 

        在实际的实现上,图像数据是[batch_size, C , H, W] 的格式,要将其变成[batch_size , token_len , dim],其中token_len 可以理解成图像patch token的数量。以[4,3,224,224]的图像为例子,首先我们模拟分割块,对于一个图像,我们要将其分割成 (H*w)/(patch_size*patch_size)个patches,即(224x224)// (16x16) = 196个 patches 。每个patch的大小是(3,16,16),然后我们将其flatten一个768( 3x16x16)dim的 token。这样数据格式就变成[4,196,768]。

代码分割图像块 : 

def split_patches(x, patch_size=16):
    batch_size, channels, height, width = x.shape
    x = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.reshape(batch_size, -1, channels * patch_size * patch_size)
    return x

当然这个过程可以通过卷积实现,官方代码其实就是用卷积来实现的。

class PatchEmbed(nn.Module) :

    def __init__(self,img_size=224,patch_size=16,in_channels=3,embed_dim=768,norm_layer=None):

        super().__init__()
        img_size = (img_size,img_size)
        patch_size = (patch_size,patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0],img_size[1]//patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1] 

        # 卷积层
        self.proj = nn.Conv2d(in_channels=in_channels,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size)
        # 归一化
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
    def forward(self,x) :
        B,C,H,W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
        f"Input image size ({H}x{W} does not match input image size ({self.img_size[0]}x{self.img_size[1]}"
        x = self.proj(x).flatten(2).transpose(1,2)
        x = self.norm(x)

        return x

 positional Embedding 

由于输入的图像数据patch序列没有能够表达patch之间相对位置关系,因此需要加入位置编码(Positional encoding)这个特征,为了得到不同位置的对应的编码,Transformer模型使用不同频率的正余弦函数

PE(Pos,2i) = sin(\frac{pos}{10000^{2i/d}})

PE(Pos,2i+1) = cos(\frac{pos}{10000^{2i/d}})

 其中 pos是表示token(flattened image patch)的位置,2i和2i+1表示位置编码向量中对应的维度,d是对应位置编码的总维度。

def add_positional_encoding(x, max_len):
    batch_size, patch_numbers, dim = x.shape
    position = torch.arange(max_len).reshape(-1, 1)
    div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
    pe = torch.zeros((max_len, dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    x += pe[:patch_numbers]
    return x

三 Self-Attention 

        其中最核心的部分就是对于注意力部分了。在基于Transformer的机器翻译模型中,要建模源语言和目标语言任意两个单词的依赖关系,引入自注意力K(键)Q(查询)V(值)。这三个用来计算上下文单词所对应的权重得分,这些权重反映了在编码当前单词时,对于上下文不同部分所需要关注程度。

    同样在vision transformer中,对于一个个image patch token 来说,也需要建模任意 token之间的相互关注关系,当处理当前token时,哪些token与它有更高的关联度。

        上图是论文中的Scaled Dot-Product AttentionMulti-head Attention,我们首先定义三个矩阵 Q,K,V,这三个矩阵是由 输入X([4,196,768])分别经过三个权重矩阵$w_{q},w_{k},w_{v}$得到的。其中 Q 矩阵和K 矩阵,V矩阵是“同源”的,因为它们都是来自于同一个输入序列(图像patch token)的某种表示(线性变换的嵌入表示)。

根据Attenton分数的计算公式,Q(shape=[4,196,768])左乘一个K(shape=[4,196,768])矩阵的转置,得到一个相似度矩阵(shape=[4,196,196]),为了防止过大的相似度数值在后续Softmax计算过程中导致的梯度爆炸以及收敛效率差的问题,因此使用一个缩放因子\sqrt{d}缩放来稳定优化。放缩后的得分经过Softmax函数归一化为概率后,与其他位置的值向量相乘来聚合希望关注的上下文信息,并最小化不相关信息的干扰。

def self_attention(x, w_q, w_k, w_v):
    query = torch.matmul(x, w_q)
    key = torch.matmul(x, w_k)
    value = torch.matmul(x, w_v)
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
    attention_scores = softmax(scores)
    output = torch.matmul(attention_scores, value)
    return attention_scores, output

过程如下

四 Multi-head self-attention (MSA)

        为了进一步提升自注意力机制的全局信息聚合能力,提出了Multi-head attention机制,具体来说,上下文的每个token 向量的表示x_{i}的经过多组的线性{W_{j}^{Q},W_{j}^{K},W_{j}^{V}}映射到不同的表示空间。计算出不同子空间得到的attention score得到{Z_{j}}_{j=1}^{N},再用一个线性变换w^{o} 用于综合不同子空间中的上下文表示形成最后的输出。

import torch
import torch.nn.functional as F

class MultiHeadAttention(torch.nn.Module):
    def __init__(self, input_dim, num_heads, head_dim):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        assert input_dim % self.num_heads == 0
        self.projection_dim = input_dim // self.num_heads

       # 定义 权重矩阵 
        self.weight_q = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))
        self.weight_k = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))
        self.weight_v = torch.nn.Parameter(torch.randn(num_heads, input_dim, self.projection_dim))

       
        self.weight_combine = torch.nn.Parameter(torch.randn(num_heads * self.projection_dim, input_dim))

    def forward(self, x):
        batch_size, seq_length, _ = x.size()

        
        queries = torch.matmul(x, self.weight_q)
        keys = torch.matmul(x, self.weight_k)
        values = torch.matmul(x, self.weight_v)

     
        queries = queries.view(batch_size, seq_length, self.num_heads, self.projection_dim)
        keys = keys.view(batch_size, seq_length, self.num_heads, self.projection_dim)
        values = values.view(batch_size, seq_length, self.num_heads, self.projection_dim)

    
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # 计算得分
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.projection_dim ** 0.5)
        attention_weights = F.softmax(scores, dim=-1)
        attention_output = torch.matmul(attention_weights, values)

        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
        output = torch.matmul(attention_output, self.weight_combine)

        return output

input_dim = 64
num_heads = 8
head_dim = input_dim // num_heads
seq_length = 10
batch_size = 4


multihead_attention = MultiHeadAttention(input_dim, num_heads, head_dim)
x = torch.rand(batch_size, seq_length, input_dim)
output = multihead_attention(x)

print("输入形状:", x.shape)
print("输出形状:", output.shape)


五 代码实现

import torch
import matplotlib.pyplot as plt
from PIL import Image
from torchvision.transforms import transforms
import numpy as np

def softmax(x):
    return torch.nn.functional.softmax(x, dim=-1)

def split_patches(x, patch_size=16):
    batch_size, channels, height, width = x.shape
    x = x.reshape(batch_size, channels, height // patch_size, patch_size, width // patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5)
    x = x.reshape(batch_size, -1, channels * patch_size * patch_size)
    return x

def add_positional_encoding(x, max_len):
    batch_size, patch_numbers, dim = x.shape
    position = torch.arange(max_len).reshape(-1, 1)
    div_term = torch.exp(torch.arange(0, dim, 2) * -(torch.log(torch.tensor(10000.0)) / dim))
    pe = torch.zeros((max_len, dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    x += pe[:patch_numbers]
    return x

def plot_heatmap(scores,index,name ):
    plt.figure(figsize=(8, 6))
    plt.imshow(scores, cmap='hot', interpolation='nearest')
    plt.xlabel('Keys')
    plt.ylabel('Queries')
    plt.title(f'Attention Scores Heatmap {name}')
    plt.colorbar()
    plt.savefig(f"./attention_heatmap{index}.png")
    # plt.show()
def self_attention(x, w_q, w_k, w_v):
    query = torch.matmul(x, w_q)
    key = torch.matmul(x, w_k)
    value = torch.matmul(x, w_v)
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(query.shape[-1], dtype=torch.float32))
    attention_scores = softmax(scores)
    output = torch.matmul(attention_scores, value)
    return attention_scores, output

def plot_heatmap_on_image(image, attention_scores, patch_size=16,index=0):
    # 对每个patch的注意力分数求平均
    attention_scores_mean = attention_scores.mean(dim=1)
    # 将注意力分数转换为与原始图像大小相匹配的热力图
    attention_map = attention_scores_mean.view(1, 1, int(224 / patch_size), int(224 / patch_size))
    attention_map = torch.nn.functional.interpolate(attention_map, size=(224, 224), mode='bilinear', align_corners=False)
    attention_map = attention_map.squeeze().cpu().detach().numpy()


    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.imshow(attention_map, cmap='jet', alpha=0.5)
    plt.axis('off')
    plt.savefig(f'attention_map{index}.png')
    # plt.show()

if __name__ == '__main__':
    batch_size = 4
    channels = 3
    height = 224
    width = 224
    input_dim = 768
    output_dim = 64

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])


    image_paths = ["./images/11.jpg",
                   "./images/15.jpg",
                   "./images/16.jpg",
                   "./images/17.jpg"]

    images = torch.zeros((4, 3, 224, 224), dtype=torch.float32)
    for i, path in enumerate(image_paths):
        img = Image.open(path).convert('RGB')
        img_tensor = transform(img)
        images[i] = img_tensor


    patch_embeddings = split_patches(images, patch_size=16)
    patch_embeddings_pe = add_positional_encoding(patch_embeddings, max_len=196)

    w_q = torch.normal(0, 0.01, size=(input_dim, output_dim))
    w_k = torch.normal(0, 0.01, size=(input_dim, output_dim))
    w_v = torch.normal(0, 0.01, size=(input_dim, output_dim))

    attention_scores, output = self_attention(patch_embeddings_pe, w_q, w_k, w_v)

    # plot_heatmap(attention_scores[0])
    # for index in range(4) :
    #
    #     name = image_paths[index].split('/')[-1].split('.')[0]
    #     plot_heatmap(attention_scores[index],index,name)  # 选择第一张图像的注意力分数进行绘制


    # 将热力图叠加到原始图像上
    for index in range(4) :
        image_path = image_paths[index]
        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img)

        img_np = np.array(img)

        plot_heatmap_on_image(img_np, attention_scores[index],16,index=index)

参考

  • LLM(廿四):Transformer 的结构改进与替代方案 - 知乎

  • 【深度学习系列】五、Self Attention_self attention 加入位置信息-CSDN博客

  • NLP(五):Transformer及其attention机制 - 知乎

  • 有关vision transformer的一个综述 - 知乎

  • 为什么 Vision transformer 训练和推理很慢? - 知乎

  • 大规模语言模型:从理论到实践 -- 张奇、桂韬、郑锐、黄萱菁

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1548909.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【jenkins+cmake+svn管理c++项目】Windows环境安装以及工具配置

一、目标和环境 目标:搭建一个jenkins环境,实现jenkins调用cmake和svn和VS编译c项目,并将生成的库上传svn。 环境:win10虚拟机(练习流程用,正式用的话还是放到服务器),VS2017. 二、…

肿瘤靶向肽 iRGD peptide环肽 1392278-76-0 c(CRGDKGPDC)

RGD环肽 c(CRGDKGPDC),iRGD peptide 1392278-76-0 结 构 式: H2N-CRGDKGPDC-OH(Disulfide Bridge:C1-C9) H2N-Cys-Arg-Gly-Asp-Lys-Gly-Pro-Asp-Cys-COOH(Disulfide Bridge:Cys1-Cys9) 氨基酸个数: 9 C35H57N13O14S2 平均分子量:…

智能优化算法 | Matlab实现牛顿-拉夫逊优化算法Newton-Raphson-based optimize(内含完整源码)

文章目录 效果一览文章概述源码设计参考资料效果一览 文章概述 智能优化算法 | Matlab实现牛顿-拉夫逊优化算法Newton-Raphson-based optimize(内含完整源码) 源码设计 % ------------------------------------------------------------------------------------------------…

GIS与Python机器学习:开创地质灾害风险评价新纪元

地质灾害是指全球地壳自然地质演化过程中,由于地球内动力、外动力或者人为地质动力作用下导致的自然地质和人类的自然灾害突发事件。由于降水、地震等自然作用下,地质灾害在世界范围内频繁发生。我国除滑坡灾害外,还包括崩塌、泥石流、地面沉…

55、Qt/事件机制相关学习20240326

一、代码实现设置闹钟,到时间后语音提醒用户。示意图如下: 代码: #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent): QWidget(parent), ui(new Ui::Widget), speecher(new QTextToSpeech(t…

【Redis】Redis 介绍Redis 为什么这么快?Redis数据结构Redis 和Memcache区别 ?为何Redis单线程效率也高?

目录 Redis 介绍 Redis 为什么这么快? Redis数据结构 Redis 和Memcache区别 ? 为何Redis单线程效率也高? Redis 介绍 Redis 是一个开源(BSD 许可)、基于内存、支持多种数据结构的存储系统,可以作为数据…

【Linux】从零开始认识进程 — 中下篇

送给大家一句话: 人一切的痛苦,本质上都是对自己无能的愤怒。而自律,恰恰是解决人生痛苦的根本途径。—— 王小波 从零认识进程 1 进程优先级1.1 什么是优先级1.2 为什么要有优先级1.3 Linux优先级的特点 && 查看方式1.4 其他概念 2…

1.6.1 变换

我们要想改变物体的位置,现有解决办法是,每一帧改变物体的顶点并且重配置缓冲区从而使物体移动,但是这样太繁琐,更好的解决方式是使用矩阵(Matrix)来更好的变换(Transform)一个物体。…

Python更改Word文档的页面大小

页面大小确定文档中每个页面的尺寸和布局。在某些情况下,您可能需要自定义页面大小以满足特定要求。在这种情况下,Python可以帮助您。通过利用Python,您可以自动化更改Word文档中页面大小的过程,节省时间和精力。本文将介绍如何使…

每日一题 --- 删除链表的倒数第 N 个结点[力扣][Go]

删除链表的倒数第 N 个结点 题目:19. 删除链表的倒数第 N 个结点 给你一个链表,删除链表的倒数第 n 个结点,并且返回链表的头结点。 示例 1: 输入:head [1,2,3,4,5], n 2 输出:[1,2,3,5]示例 2&#x…

Python中lambda函数使用方法

在Python中,lambda 关键字用于创建匿名函数(无名函数),这些函数的特点是简洁、一次性使用,并且通常用于只需要一行表达式的简单场景。下面是lambda函数的基本结构和使用方法: 基本语法: lambd…

代码随想录算法训练营DAY7| C++哈希表Part.2|LeetCode:454.四数相加II、383.赎金信、15. 三数之和、18.四数之和

文章目录 454.四数相加II思路C代码 383.赎金信C 代码 15. 三数之和排序哈希法思路C代码 排序双指针法思路去重C代码 18.四数之和前言剪枝C代码 454.四数相加II 力扣题目链接 文章链接:454.四数相加II 视频链接:学透哈希表,map使用有技巧&…

STL的基本概念

一、STL的诞生 长久以来,软件界一直希望建立一种可重复利用的东西 C的面向对象和泛型编程思想,目的就是复用性的提升 面向对象的三大特性(简单理解) 封装:把属性和行为抽象出来作为一个整体来实现事和物 继承:子类继承父类&a…

linux下docker容器的使用

1、根据已有镜像images创建容器 1.1、查看镜像 如果是接手的别人的项目,需要从以往的images镜像中创建新容器,使用命令查看当前机器上的docker镜像: docker images1.2、创建容器 使用docker run 根据images镜像名创建容器,命令…

电阻的妙用:限流、分压、滤波,助力电路设计!

电阻可以降低电压,这是通过电阻的分压来实现的。事实上,利用电阻来降低电压只是电阻的多种功能之一。电路中的电阻与其他元件(电容、电感)结合用于限流、滤波等。(本文素材来源:https://www.icdhs.com/news…

SV-7045V网络草坪音箱 室外网络广播POE供电石头音箱

SV-7045V网络草坪音箱 室外网络广播POE供电石头音箱 描述 IP网络广播草坪音箱 SV-7045V是深圳锐科达电子有限公司的一款防水网络草坪音箱,具有10/100M以太网接口,可将网络音源通过自带的功放和喇叭输出播放,可达到功率20W。用在公园&#…

Kotlin高效App爬取工具:利用HttpClient与代理服务器的技巧

在当今数字化时代,移动应用(App)数据的价值日益凸显,而为了获取并分析这些数据,开发高效的数据爬取工具变得至关重要。Kotlin作为一种现代化、功能强大的编程语言,与HttpClient等强大工具的结合&#xff0c…

day53 动态规划part10

121. 买卖股票的最佳时机 简单 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格。 你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日子 卖出该股票。设计一个算法来计算你所能获取的最大利润。 返回你可…

蓝桥杯学习笔记 单词分析

试题 G: 单词分析 时间限制: 1.0s 内存限制: 512.0MB 本题总分:20 分 [问题描述] 小蓝正在学习一门神奇的语言,这门语言中的单词都是由小写英文字母组成,有些单词很长,远远超过正常英文单词的长度。小蓝学了很长时间也记不住一些单词&#xf…

C语言实现顺序表(增,删,改,查)

目录 一.概念: 1.静态顺序表:使用定长数组存储元素。 2.动态顺序表:使用动态开辟的数组存储。 二.顺序表的实现: 1.顺序表增加元素 1.检查顺序表 2.头插 3.尾插 2.顺序表删除元素 1.头删 2.尾删 3.指定位置删 3.顺序表查找元素 …