【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练

news2025/1/12 3:02:00

之前介绍:

https://qq742971636.blog.csdn.net/article/details/132061304

文章目录

  • 背景
      • 实现代码示例
      • 解释
  • 训练
      • 数据准备
      • 模型定义
      • 训练和评估
      • 总结

在这里插入图片描述

Vision Transformer(ViT)是一种基于transformer架构的视觉模型,它最初是由谷歌研究团队在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。ViT将图像分割成固定大小的patches(例如16x16),并将每个patch视为一个词(类似于NLP中的单词)进行处理。以下是ViT的详细讲解:

背景

在计算机视觉领域,传统的卷积神经网络(CNNs)一直是处理图像的主流方法。然而,CNNs存在一些局限性,如在处理长距离依赖关系时表现不佳。ViT引入了transformer架构,通过全局注意力机制,有效地处理图像中的长距离依赖关系。

实现代码示例

ViT代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, embed_dim, H, W]
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        return x

# 示例输入
img = torch.randn(1, 3, 224, 224)
model = VisionTransformer()
output = model(img)
print(output.shape)  # 输出大小为 [1, 1000]

解释

  1. PatchEmbedding:将输入图像分割为不重叠的patches,并通过卷积操作将其转换为embedding。
  2. Attention:实现自注意力机制。
  3. MLP:实现多层感知器(MLP),包括GELU激活函数和Dropout。
  4. Block:包含一个注意力层和一个MLP层,每层都有残差连接和层归一化。
  5. VisionTransformer:组合上述模块,形成完整的ViT模型。包含位置嵌入和分类头。

训练

为了在GPU上训练ViT模型,你可以使用PyTorch中的DataLoader来处理数据,并确保模型和数据都在GPU上。以下是一个详细的代码示例,包括数据准备、模型定义、训练和评估。

数据准备

假设你的数据结构如下:

dataset/
    class1/
        img1.jpg
        img2.jpg
        ...
    class2/
        img1.jpg
        img2.jpg
        ...
    ...

你可以使用 torchvision.datasets.ImageFolder 来加载数据。

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# 数据转换和增强
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载数据
data_dir = 'dataset'
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# 获取类别数
num_classes = len(train_dataset.classes)

模型定义

定义ViT模型并将其移动到GPU上。

# VisionTransformer定义(使用上面的定义)
model = VisionTransformer(num_classes=num_classes).cuda()

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 如果有多个GPU,使用DataParallel
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

训练和评估

定义训练和评估函数,并进行训练。

def train_one_epoch(model, criterion, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(data_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = running_corrects.double() / len(data_loader.dataset)
    return epoch_loss, epoch_acc


def evaluate(model, criterion, data_loader, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = running_corrects.double() / len(data_loader.dataset)
    return epoch_loss, epoch_acc


# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epochs = 25

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device)
    val_loss, val_acc = evaluate(model, criterion, val_loader, device)

    print(f'Epoch {epoch}/{num_epochs - 1}')
    print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

# 保存模型
torch.save(model.state_dict(), 'vit_model.pth')

总结

这段代码展示了如何使用PyTorch在GPU上训练Vision Transformer模型。包括数据加载、模型定义、训练和评估步骤。请根据你的实际需求调整批量大小、学习率和训练轮数等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1828061.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java阻塞队列:ArrayBlockingQueue

Java阻塞队列:ArrayBlockingQueue ArrayBlockingQueue是Java中的一个阻塞队列(Blocking Queue)实现,它是线程安全的,并且基于数组实现。ArrayBlockingQueue常用于生产者-消费者模型,在这种模型中&#xff…

移动硬盘打不开怎么办?原因解析!

移动硬盘是一种方便携带、快速传输大量数据的存储设备。但有时我们会遇到这样的问题:插上电脑后,移动硬盘无法打开,出现各种错误提示。这时候我们该怎么办呢? 以下是一些可能导致移动硬盘打不开的原因及解决方法: 1.硬…

47.PyCharm P版突然无法启动

目录 1.启动cmd.exe,进到pycharm\bin目录,启动.\pycharm.bat,如果正常,就像下面这个样子,如果不正常,则会报错, 2.用记事本打开pycharm.bat文件,加上以下代码后 今晨,无…

llama3-70B体验

NVIDIA LLAMA3-70B大模型体验地址: NVIDIA NIM | llama3-70b 问题几个关于宇宙的问题,答案挺有意思的,很有启发性,记录一下: 问题1:既然相对论认为时间是相对的,为何却说宇宙寿命有137亿年&a…

股指期货功能

其金融期货的本质,决定了股指期货具有以下几方面特点: (1)交割方式为现金交割; (2)股指期货的持有成本较低; (3)股指期货的保证金率较低,杠杆性…

【尚庭公寓SpringBoot + Vue 项目实战】图片上传(十)

【尚庭公寓SpringBoot Vue 项目实战】图片上传(十) 文章目录 【尚庭公寓SpringBoot Vue 项目实战】图片上传(十)1、图片上传流程2、图片上传接口查看3、代码开发3.1、配置Minio Client3.2、开发上传图片接口 4、异常处理 1、图片…

12.泛型、trait和生命周期(中)

标题 五、定义共享的行为5.1 定义trait5.2 实现trait5.4 将trait作为参数5.5 Trait Bound5.6 返回实现了trait的类型5.7 使用trait bounds修复get_max_value_from_vector的错误5.8 使用trait bound有条件地实现方法 五、定义共享的行为 如果告诉Rust编译器某个特定类型拥有可能…

【办公类-04-03】华为助手导出照片视频分类(根据图片、视频的文件名日期分类导出)

背景需求: 用华为手机助手导出的照片视频,只能将jpg照片(exifread读取图片的exif拍摄日期,Png、JPEG、mp4都无法识别到exif信息) 【办公类-04-02】华为助手导出照片(jpg)读取拍摄时间分类导出…

Linux——ansible的应用

要让ansible管理业务里的主机 1.得先知道,有哪些主机 用IP地址,用主机名 2.知道了有哪些主机以后,精细、细分管理 主机要用某些办法,分组管理 在ansible里,要用一个东西:清单->inventory inventory …

C学习自学笔记-会陆续完善对应章节编程经典例子

C学习笔记 0>C语言概述 为什么学习C语言 1)C的起源和发展------了解即可 B语言、C语言、C语言的产生地:都出自 美国贝尔实验室 2)C的特点 优点:代码量小、速度快、功能强大 缺点:危险性高、开发周期长、可移植性…

POC EXP | woodpecker插件编写

woodpecker插件编写 目录 woodpecker介绍woodpecker使用插件编写 安装环境 woodpecker-sdkwoodpecker-request 创建Maven项目 Confluence OGNL表达式注入漏洞插件编写 创建Package包和Class类编写POC 漏洞POC代码编写导出jar包将jar包放入woodpecker的plugin目录运行woodpeck…

2.2 抽头

目录 为什么要抽头 什么是抽头 接入系数 怎么抽头 信号源端抽头 负载端抽头 例题分析 要点总结 为什么要抽头 阻抗转换,使信号源内阻Rs与负载电阻RL变得很大,分流小,再使用并联方式。 什么是抽头 接入系数 电容越大,分压越…

Hadoop 2.0:主流开源云架构(四)

目录 五、Hadoop 2.0访问接口(一)访问接口综述(二)浏览器接口(三)命令行接口 六、Hadoop 2.0编程接口(一)HDFS编程(二)Yarn编程 五、Hadoop 2.0访问接口 &am…

最大连续子序列和问题详解

最大连续子序列和问题如下:给定一个数字序列,求i,j,使得最大,输出这个最大和。 这个问题如果用暴力来做,枚举左端点和右端点,需要的复杂度,而计算需要的复杂度,因此总时间复杂度为。…

表面声波滤波器——压电材料(2)

声表面波压电材料 压电效应: 1880年,法国物理学家居里兄弟(PCunie 和J.Curie)发现将重物置于α石英晶体上,品体部分表面会产生电荷,电荷量与所受压力成正比正压电效应:外加压力作应下在表面间产生电位差。 和逆压电效&#xff1a…

14 学习PID--步进电机梯形加减速实现原理

步进电机加减速使用的场景有那些呢?为什么要使用加减速呢? 硬件驱动细分器与软件的细分参数或定时器分频参数设置不当时启动电机时,会遇见步进电机有啸叫声但是不会转动,这是因为软件产生脉冲的频率大于步进电机的启动频率&#x…

Iptables深入浅出

1、iptables的基本概念 众所周知iptables是Linux系统下自带免费的包过滤防火墙。其实不然,iptables其实不是真正的防火墙,我们可以把它理解成一个客户端代理,用户通过iptables这个代理,将用户的安全设定执行到对应的”安全框架”…

【Oracle APEX开发小技巧1】转换类型实现显示小数点前的 0 以 及常见类型转换

在 apex 交互式式网格中,有一数值类型为 NUMBER,保留小数点后两位的项,在 展示时小数点前的 0 不显示。 效果如下: 转换前: m.WEIGHT_COEFFICIENT 解决方案: 将 NUMBER(20,2&#xf…

Python进阶:从函数到文件的编程艺术!!!

第二章:Python进阶 模块概述 函数是一段可重复使用的代码块,它接受输入参数并返回一个结果。函数可以用于执行特定的任务、计算结果、修改数据等,使得代码更具模块化和可重用性。 模块是一组相关函数、类和变量的集合,它们被封…

Android Media Framework(六)插件式编程与OMXStore

OpenMAX IL Spec阅读到上一节就结束了,这一节开始正式进入到Framework阅读阶段,我们将了解OpenMAX框架是如何与Android Framework连接的。 1、插件式编程 插件式编程(Plugin-based Programming)是一种软件开发模式,它…