ViT论文Pytorch代码解读

news2024/11/26 14:27:29

ViT论文代码实现

论文地址:https://arxiv.org/abs/2010.11929
Pytorch代码地址:https://github.com/lucidrains/vit-pytorch

ViT结构图

在这里插入图片描述

调用代码

import torch
from vit_pytorch import ViT

def test():
    v = ViT(
        image_size = 256, 
        patch_size = 32,  
        num_classes = 1000,  
        dim = 1024,  
        depth = 6,  
        heads = 16,  
        mlp_dim = 2048,  
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 256, 256)

    preds = v(img)
    print(preds.shape)
    assert preds.shape == (1, 1000), 'correct logits outputted'

if __name__ == '__main__':
    test()

ViT结构

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        
        # 将image_size和patch_size都转换为(height, width)形式
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)
		
		# 检查图像尺寸是否可以被patch尺寸整除
        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

		# 计算图像中的patch数量
        num_patches = (image_height // patch_height) * (image_width // patch_width)
		
		# 计算每个patch的维度(即每个patch的元素数量)
        patch_dim = channels * patch_height * patch_width
        
        # 确保池化方式是'cls'或'mean'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

		# 将图像转换为patch嵌入的操作
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # 图像切分重排,后文有注释
            # 注:此时的维度为[b, h*w/p1/p2, p1*p2*c]:[批处理尺寸、图像中patch的数、每个patch的元素数量]
            nn.LayerNorm(patch_dim),  # 对patch进行层归一化
            nn.Linear(patch_dim, dim),  # 使用线性层将patch的维度从patch_dim转化为dim
            nn.LayerNorm(dim),  # 对结果进行层归一化
        )
		
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))  # 初始化位置嵌入
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))  # 初始化CLS token(用于分类任务的特殊token)
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)  # 定义Transformer模块
 
        self.pool = pool  # 设置池化方式('cls'或'mean')
        self.to_latent = nn.Identity()  # 设置一个恒等映射(在此实现中不改变数据,但可以在子类或其他变种中进行修改)

        self.mlp_head = nn.Linear(dim, num_classes)   # 定义MLP头部,用于最终的分类

    def forward(self, img):
        x = self.to_patch_embedding(img) # 第一步,将图片切分为若干小块
		# 此时维度为:[b, h*w/p1/p2, dim]
        b, n, _ = x.shape
		
		# 第二步,设置位置编码
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)  # 将cls_token复制b个 
        # (为每个输入图像复制一个CLS token,使输入批次中的每张图像都有一个相应的CLS token)
        x = torch.cat((cls_tokens, x), dim=1)  # 将CLS token与patch嵌入合并; cat之后,原来的维度[1,64,1024],就变成了[1,65,1024]
        x += self.pos_embedding[:, :(n + 1)] # 原数据和位置编码直接进行相加操作,即完成结构图中的【Patch + Position Embedding】操作
        
        x = self.dropout(x)

		# 第三步,Transformer的Encoder结构
        x = self.transformer(x)
        
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]   # 根据所选的池化方式进行池化

        x = self.to_latent(x)  # 将数据传递给恒等映射
        return self.mlp_head(x)  # 使用MLP头部进行分类
  

Rearrange解释:
y = x.transpose(0, 2, 3, 1)
可以写成:y = rearrange(x, ‘b c h w -> b h w c’)

关于pos_embedding和cls_token的逻辑讲解:
在这里插入图片描述如图所示,红色框框出的部分。
图像被切分为多个小块之后,经过self.to_patch_embedding 中的Rearrange,原本的[b,c,h,w]维度变为[b, h*w/p1/p2, p1*p2*c]。
再经过线性层nn.Linear(patch_dim, dim),维度变为[b, h*w/p1/p2, dim]。
输出结果即为上图中黄色框标出的部分的粉色条(不包括紫色条,是因为此处还没进行Position Embedding操作)。
继续往下走,进行torch.cat((cls_tokens, x), dim=1),此时将xcls_tokens进行concat操作,得到红色框框出的所有粉色条(在原本的基础上增加了带*号的粉色条)。
记下来的x += self.pos_embedding[:, :(n + 1)]操作就是将xpos_embedding直接进行相加,用图表示出来就是上图中整个红色框框出的部分了(紫色条就是传说中的pos_embedding)。
举一个有数字的例子:
原本输入图像维度为[1, 3, 256, 256],dim设置为1023,经过self.to_patch_embedding后维度变为:[1,64,1024],cls_tokens的维度为:[1,1,1024],经过concat操作后,x的维度变为[1,65,1024],然后经过pos_embedding加操作后,维度依然是[1,65,1024],因为在设置变量pos_embedding时的维度就是torch.randn(1, num_patches + 1, dim)
~这个解释应该够清晰了吧!~

Transformer Encoder结构

# 定义前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            # Vit_base: dim=768,hidden_dim=3072
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),  # 将输入从dim维映射到hidden_dim维
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),  # 将隐藏状态从hidden_dim维映射回到dim维
            nn.Dropout(dropout) 
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads  # 64*8=512  # 计算内部维度
        project_out = not (heads == 1 and dim_head == dim) # 判断是否需要投影输出,投影输出就是是否需要经过线性层
        # 如果只有一个attention头并且其维度与输入相同则不需要投影输出,否则需要。

        self.heads = heads
        self.scale = dim_head ** -0.5 # 缩放因子,通常是头维度的平方根的倒数

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)   # softmax函数用于最后一个维度,计算注意力权重
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) # 一个线性层生成Q, K, V

		# 判断是否需要投影输出
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 用线性层生成QKV,并在最后一个维度上分块;相当于写3遍nn.Linear
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) 
        # 将[batch_size, sequence_length, heads_dimension] 转换为 [batch_size, number_of_heads, sequence_length, dimension_per_head]

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale  # 计算Q和K的点乘,然后进行缩放
        # q: [batch_size, number_of_heads, sequence_length, dimension_per_head]
        # k转置后:[batch_size, number_of_heads, sequence_length, dimension_per_head] -> [batch_size, number_of_heads, dimension_per_head, sequence_length]
        # q和k点乘后:[batch_size, number_of_heads, sequence_length, sequence_length]

        attn = self.attend(dots)   # 使用softmax函数获取注意力权重
        attn = self.dropout(attn)
		
		# 使用注意力权重对V进行加权
        out = torch.matmul(attn, v) 
        out = rearrange(out, 'b h n d -> b n (h d)') # 使用rearrange函数重新组织输出的维度
        return self.to_out(out)  # 投影输出(如果需要)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):  # depth设置为几层,就重复几次
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:  # 残差
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

如上就是ViT的整体结构了。

附:完整代码

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange


# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)


# classes

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            # Vit_base: dim=768,hidden_dim=3072
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads  # 64*8=512
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim=-1)  # 相当于写3遍nn.Linear
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        # 将[batch_size, sequence_length, heads_dimension] 转换为 [batch_size, number_of_heads, sequence_length, dimension_per_head]

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        # q: [batch_size, number_of_heads, sequence_length, dimension_per_head]
        # k转置后:[batch_size, number_of_heads, sequence_length, dimension_per_head] -> [batch_size, number_of_heads, dimension_per_head, sequence_length]
        # q和k点乘后:[batch_size, number_of_heads, sequence_length, sequence_length]

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)


class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3,
                 dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),  # 图像切分重排
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )
        # Rearrange解释:
        # y = x.transpose(0, 2, 3, 1)
        # 可以写成:y = rearrange(x, 'b c h w -> b h w c')

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)  # 数字编码,将cls_token复制b个
        x = torch.cat((cls_tokens, x), dim=1)  # cat之后,原来的维度[1,64,1024],就变成了[1,65,1024]
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

附:训练代码

model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=3,
).to(device)


# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0

    for data, label in tqdm(train_loader):
        data = data.to(device)
        label = label.to(device)

        output = model(data)
        loss = criterion(output, label)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = (output.argmax(dim=1) == label).float().mean()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)

    with torch.no_grad():
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        for data, label in valid_loader:
            data = data.to(device)
            label = label.to(device)

            val_output = model(data)
            val_loss = criterion(val_output, label)

            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss / len(valid_loader)

    print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}\n"
    )

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/940919.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【第四阶段】kotlin语言可变list集合

1.可变list集合 完整写法 var list:MutableList<String> mutableListOf<String>("java","kotlin","c","c") 省略写法 var list mutableListOf("java","kotlin","c","c")fun ma…

CobaltStrike提权

攻击机&#xff1a;Kali Linux 靶 机&#xff1a;Windows 7 一、上线CS 复制命令&#xff0c;在靶机执行上线CS 2.安装插件&#xff0c;获取shell https://github.com/rsmudge/ElevateKit 上线CS 右击shell&#xff0c;选择插件 有七个模块可以利用&#xff0c;可以逐一…

C++实现YOLOP

C实现YOLOP 一、简介 使用OpenCV部署全景驾驶感知网络YOLOP&#xff0c;可同时处理交通目标检测、可驾驶区域分割、车道线检测&#xff0c;三项视觉感知任务&#xff0c;依然是包含C和Python两种版本的程序实现 onnx文件从百度云盘下载&#xff0c;链接&#xff1a;https://…

【AutoLayout案例07-如何通过代码添加约束 Objective-C语言】

一、那么,接下来,我们就给大家介绍一下,如何通过代码,来实现这个AutoLayout 1.咱们之前是不是都是通过,storyboard、来拖、拉、拽、的方式实现的吧 现在给大家介绍一下,如何通过代码,来实现 在继续介绍,如何通过代码,来实现AutoLayout之前呢, 我们先要给大家补充一…

基于Java+SpringBoot+Vue前后端分离疫苗发布和接种预约系统设计和实现

博主介绍&#xff1a;✌全网粉丝30W,csdn特邀作者、博客专家、CSDN新星计划导师、Java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专…

前端开发之Element Plus的分页组件el-pagination显示英文转变为中文

前言 在使用element的时候分页提示语句是中文的到了element-plus中式英文的&#xff0c;本文讲解的就是怎样将英文转变为中文 效果图 解决方案 如果你的element-plus版本为2.2.29以下的 import { createApp } from vue import App from ./App.vue import ElementPlus from …

ruoyi-vue-plus 配置邮箱

ruoyi-vue-plus 配置邮箱 &#x1f4d4; 千寻简笔记介绍 千寻简笔记已开源&#xff0c;Gitee与GitHub搜索chihiro-notes&#xff0c;包含笔记源文件.md&#xff0c;以及PDF版本方便阅读&#xff0c;且是用了精美主题&#xff0c;阅读体验更佳&#xff0c;如果文章对你有帮助请…

Java实现根据短连接获取1688商品详情数据,1688淘口令接口,1688API接口封装方法

要通过1688的API获取商品详情数据&#xff0c;您可以使用1688开放平台提供的接口来实现。以下是一种使用Java编程语言实现的示例&#xff0c;展示如何通过1688开放平台API获取商品详情属性数据接口&#xff1a; 首先&#xff0c;确保您已注册成为1688开放平台的开发者&#xf…

网络工程师的尽头是……

大家好&#xff0c;我是许公子。 最近工作挺忙&#xff0c;很久没有给你们输出文章了&#xff0c;抽空和大家唠嗑唠嗑。 前两天&#xff0c;一个实习生问了我这个问题&#xff1a; “网络工程师的尽头是什么&#xff1f;” 我当时一下子&#xff0c;脑子空白了&#xff0c;…

stackoverflow问题

Stack Overflow requires external JavaScript from another domain, which is blocked or failed to load. stackoverflow引用了谷歌中被屏ajax.googleapis.com的jquery.min.js文件。“https://ajax.googleapis.com/ajax/libs/jquery/1.12.4/jquery.min.js” 方案1.打开网站…

Viobot算法控制

一.上位机控制 状态反馈在系统反馈出会根据当前系统状态显示。 控制是在操作栏里面的一个选项三个按键。具体的已经在前面一篇基本功能使用及介绍中讲过了。 二.ROS控制 算法的控制我们也整合成了一个ROS msg&#xff0c;具体的msg信息可以查看demo里面的msg包的algo_ctrl.m…

Arduino RGBLED灯 模块学习与使用

Arduino RGBLED灯模块学习与使用 硬件原理制作衍生连接线Mixly程序Arduino程序演示视频 人生如逆旅&#xff0c;我亦是行人。 —— 苏轼江客:时荒 硬件原理 RGBLED灯三个引脚分别控制三个LED灯的亮度&#xff0c;RGB分别是red&#xff0c;green&#xff0c;blue的英文缩写&…

iOS - 资源按需加载 - ODR

一、瘦身技术大图 二、On-Demand Resources 简介 将其保存管理在苹果的服务器&#xff0c;按需使用资源、优化包体积&#xff0c;实现更小的应用程序。ODR 的好处&#xff1a; 应用体积更小&#xff0c;下载更快&#xff0c;提升初次启动速度资源会在后台下载操作系统将会在磁…

《C/C++等级考试(1~8级)历届真题解析》专栏总目录

❤️ 专栏名称&#xff1a;《C/C等级考试&#xff08;1~8级&#xff09;历届真题解析》 &#x1f338; 专栏介绍&#xff1a;中国电子学会《全国青少年软件编程等级考试》C/C编程&#xff08;1~8级&#xff09;历届真题解析。 &#x1f680; 订阅专栏&#xff1a;订阅后可阅读专…

【Linux】进程概念,轻松入门【下篇】

目录 1. 基本概念 2. 常见环境变量 常见环境变量指令 &#xff08;1. PATH &#xff08;2. HOME &#xff08;3. SHELL 3.环境变量的组织形式 &#xff08;1&#xff09;通过代码如何获取环境变量 &#xff08;2&#xff09;普通变量与环境变量的区别 &#xff08;3&…

第十二届中国PMO大会在京成功召开

8月12-13日&#xff0c;由PMO评论主办&#xff0c;以“拥抱变革 展现PMO力量”为主题的第十二届中国PMO大会在京成功召开。全国项目管理标准化技术委员会俞彪秘书长、《项目管理技术》杂志张星明主编莅临大会并致开幕词&#xff0c;53位来自知名企业的PMO实践精英及业内专家做了…

【某大型攻防演练中 QQ 远程代码执行复现】

目录 前言 漏洞概述 复现过程 修复方式 前言 在最近某全国大型攻防演练中&#xff0c;接近尾声爆出了QQ 0day&#xff0c;攻击者使用QQ版本<9.7.13&#xff0c;存在远程命令执行漏洞&#xff0c;对众多靶标系统进行攻击。漏洞爆出后&#xff0c;就做过一次复现。最近活动…

EXSI技术--SAN组网

(1).SAN概述 存储区域网络(Storage Area Network,SAN)采用网状通道技术,通过FC交换机连接存储阵列和服务器主机,建立专用于数据存储的区域网络。 作为一种专门用于实现存储系统互连的高速网络技术,SAN(存储区域网络)克服了NAS(网络连接存储)中存储吞吐量受底层网络介质限…

【高阶数据结构】哈希表详解

文章目录 前言1. 哈希的概念2. 哈希冲突3. 哈希函数3.1 直接定址法3.2 除留余数法--(常用)3.3 平方取中法--(了解)3.4 折叠法--(了解)3.5 随机数法--(了解)3.6 数学分析法--(了解) 4. 哈希冲突的解决方法及不同方法对应的哈希表实现4.1 闭散列&#xff08;开放定址法&#xff0…

解决win10 wsl子系统安装的ubuntu环境中lsof,netstat命令查看端口没有任何输出的问题

最近有个以前的ssm项目需要在新电脑上运行测试一下&#xff0c;发现需要redis环境&#xff0c;看了官网说&#xff1a;有两种选择&#xff1a; 1. 要么在虚拟机比如vmware安装linux基础环境&#xff0c;然后再安装redis 2. 要么可以利用win10的wsl linux子系统安装ubuntu&…