ViT模型复现项目实战

news2024/11/7 19:16:04

项目源码获取方式见文章末尾! 600多个深度学习项目资料,快来加入社群一起学习吧。

《------往期经典推荐------》

项目名称
1.【基于CNN-RNN的影像报告生成】
2.【卫星图像道路检测DeepLabV3Plus模型】
3.【GAN模型实现二次元头像生成】
4.【CNN模型实现mnist手写数字识别】
5.【fasterRCNN模型实现飞机类目标检测】
6.【CNN-LSTM住宅用电量预测】
7.【VGG16模型实现新冠肺炎图片多分类】
8.【AlexNet模型实现鸟类识别】
9.【DIN模型实现推荐算法】
10.【FiBiNET模型实现推荐算法】
11.【钢板表面缺陷检测基于HRNET模型】

1. 项目简介

本项目的目标是复现Vision Transformer(ViT)模型,通过深入理解其核心架构和应用,探索其在图像分类任务中的性能表现。ViT模型是近年来在视觉任务中取得突破性进展的深度学习模型之一,核心思想是将Transformer这种原本应用于自然语言处理的模型引入到计算机视觉领域,解决了传统卷积神经网络(CNN)在处理全局信息时的局限性。ViT模型通过将输入图像划分为若干固定大小的图块(patches),再将这些图块展平并转换为序列形式输入Transformer模型,从而捕捉图像中的长距离依赖关系。这种方法克服了CNN的局部感受野限制,在大规模数据集上取得了比CNN更优的效果。本项目通过复现ViT模型,帮助用户深入理解其在计算机视觉中的应用及优势,尤其是在图像分类、目标检测等领域的实际表现。

2.技术创新点摘要

将Transformer引入计算机视觉领域:ViT模型将Transformer这种原本用于自然语言处理的架构引入到计算机视觉领域,摒弃了传统的卷积神经网络(CNN)。这种创新性设计突破了CNN局部感受野的限制,能够更好地捕捉图像中的长距离依赖关系。

图像分割为Patch的处理方法:ViT模型通过将输入图像划分为固定大小的图块(Patch),然后将这些图块展平并作为序列输入到Transformer中进行处理。这种处理方式与传统CNN处理整幅图像的方式不同,允许ViT模型能够对全局信息进行更加灵活的建模。

使用Multi-Head Attention机制捕捉全局依赖:模型中采用了多头自注意力机制(Multi-Head Attention),能够并行处理不同位置的图像信息,有效捕捉全局依赖关系。这使得模型可以在多个头上关注图像的不同部分,增强了模型对复杂场景的理解能力。

更高效的训练和推理:ViT模型相比CNN在大规模数据集上训练时效率更高,尤其是在处理高分辨率图像和复杂任务时展现出了显著的优势。这得益于其Transformer架构的优势,使得模型在图像分类任务中的表现优于传统的卷积神经网络。

权重初始化和自监督预训练:代码中还展示了对ViT模型权重初始化的优化方案,通过自监督预训练(self-supervised pretraining)技术,进一步提升了模型的泛化能力。

3. 数据集与预处理

在本项目中使用的数据集为经典的图像分类数据集,主要用于评估ViT模型在图像分类任务中的表现。常见的数据集包括ImageNet等大规模数据集,这些数据集具有类别丰富、样本数量大、图像分辨率高等特点。项目中选用的数据集包含多种类别的图片,每个类别的样本数较为均衡,能够为模型提供丰富的特征信息,帮助模型学习更具泛化能力的特征表示。

数据预处理流程是模型训练中至关重要的一环,确保输入数据的质量和一致性。首先,对于每张输入图片,进行了统一的图像尺寸调整,确保所有图像都能适配模型的输入要求。具体来说,ViT模型通常将图片划分为固定大小的图块(例如16×16像素),因此在预处理阶段,首先需要将图像缩放到指定大小。

接下来,应用了常见的归一化操作,将像素值缩放到[0, 1]或[-1, 1]区间。这有助于加快模型的收敛速度,并防止梯度消失或爆炸。此外,归一化还可以减少各特征间的量纲差异,提高模型的鲁棒性。

为了增强模型的泛化能力,数据增强技术也在预处理阶段被广泛应用。常见的数据增强方法包括随机裁剪、水平翻转、色彩抖动和旋转等操作。这些增强方法通过生成不同的图像变体,扩大了训练数据的多样性,减少了模型过拟合的风险。

在这里插入图片描述

4. 模型架构

该项目采用了Vision Transformer(ViT)模型,其模型结构由多个Transformer块组成,具体如下:

  • 输入层:输入图像被划分为固定大小的图块(Patch)。假设输入图像大小为 H×W×C,其中 H 为高度,W 为宽度,C为通道数。每个图像被分割为 N=HP×WP个图块, P 是每个Patch的大小。
  • Patch Embedding Layer:图块被展平,并通过一个线性投影层映射到固定维度的嵌入空间中。假设线性投影的输出维度为 D,则Patch的表示为:

Z 0 = [ x 1 E ; x 2 E ; …   ; x N E ] + E p o s Z_0 = [x_1E; x_2E; \dots; x_NE] + E_{pos} Z0=[x1E;x2E;;xNE]+Epos

  • 其中,xi 是第 iii 个Patch, E是可学习的嵌入矩阵, Epos是位置编码,确保模型能够捕捉Patch的相对位置。

  • Transformer Block:每个Transformer块包含以下部分:

    • Layer Normalization:对输入进行标准化处理。
    • 多头自注意力机制(Multi-Head Attention) Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V 其中 Q, K, V 分别是查询矩阵、键矩阵和值矩阵, dk是键的维度。
    • 前馈神经网络:包含两个线性层,中间有一个激活函数(通常为GELU)。公式如下: FFN ( x ) = GELU ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{GELU}(xW_1 + b_1)W_2 + b_2 FFN(x)=GELU(xW1+b1)W2+b2
  • 分类层:输出层是一个线性分类器,输入的是Transformer最后一层的输出(即第一个Token的表示),用于图像分类任务。

  1. 模型的整体训练流程

模型的训练过程分为以下几步:

  • 前向传播:将图像输入模型,通过各层的处理,输出分类结果。
  • 损失函数:使用交叉熵损失(CrossEntropy Loss)计算模型预测结果与真实标签之间的误差。公式为:

L = − ∑ i = 1 N y i log ⁡ ( y i ^ ) L = - \sum_{i=1}^{N} y_i \log(\hat{y_i}) L=i=1Nyilog(yi^)

  • 其中 yi是实际标签,yi^ 是模型预测的概率分布。
  • 反向传播:通过计算梯度来更新模型的参数,优化目标是最小化损失函数。
  • 评估指标:训练过程中主要采用准确率(Accuracy)作为评估指标,计算模型正确预测样本的比例:

Accuracy = 正确分类的样本数 总样本数 \text{Accuracy} = \frac{\text{正确分类的样本数}}{\text{总样本数}} Accuracy=总样本数正确分类的样本数

通过该架构,模型在处理图像分类任务时展现了较强的全局建模能力,有效捕捉图像中的长距离依赖关系

5. 核心代码详细讲解

1. 数据预处理

data_transform = transforms.Compose(
    [transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
  • transforms.Resize(256):将输入图像的最小边缩放到256像素。
  • transforms.CenterCrop(224):从缩放后的图像中取224×224像素的中心部分。这是标准的图像分类输入尺寸。
  • transforms.ToTensor():将图像从PIL格式转换为PyTorch的张量格式,并且像素值被归一化到 [0, 1]。
  • transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]):对每个通道进行归一化,将图像的像素值缩放到[-1, 1]之间,方便模型训练。

2. Patch Embedding Layer

class PatchEmbed(nn.Module):def init(self, img_size=224, patch_size=16, in_c=3, embed_dim=768):super().
__init__
()
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
  • img_size=224:输入图像的大小为224×224像素。
  • patch_size=16:图像被分割为16×16的Patch。
  • nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size):使用卷积层将图像分割为不重叠的Patch,并将其投影到嵌入维度(embed_dim),这里通过卷积的步幅等于Kernel Size实现图块划分。

3. Transformer Block

class Block(nn.Module):def init(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_ratio=0., attn_drop_ratio=0., drop_path_ratio=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):super(Block, self).
__init__
()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop_ratio=attn_drop_ratio, proj_drop_ratio=drop_ratio)
        self.drop_path = DropPath(drop_path_ratio) if drop_path_ratio > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop_ratio)
  • self.norm1 = norm_layer(dim):Layer Normalization层,用于对输入进行标准化。
  • self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias):多头自注意力机制(Attention),通过多头机制捕捉图像中不同区域的相互关系。
  • self.mlp = Mlp(...):多层感知机(MLP),包含激活函数GELU及两个线性层。

4. 前向传播

def forward(self, x):
    x = x + self.drop_path(self.attn(self.norm1(x)))
    x = x + self.drop_path(self.mlp(self.norm2(x)))return x
  • x = x + self.drop_path(self.attn(self.norm1(x))):首先对输入进行Layer Normalization,然后通过Attention层计算注意力得分,最后使用残差连接(Residual Connection)保留输入信息。
  • x = x + self.drop_path(self.mlp(self.norm2(x))):第二步是对输出进行标准化,并通过MLP层处理,同样使用残差连接。

5. 模型训练流程

for step, data in enumerate(data_loader):
    images, labels = data
    pred = model(images.to(device))
    loss = loss_function(pred, labels.to(device))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
  • images, labels = data:从数据加载器中获取一批图像和对应的标签。
  • pred = model(images.to(device)):将图像输入模型,得到预测结果。
  • loss = loss_function(pred, labels.to(device)):计算预测结果和实际标签之间的损失,这里使用的是交叉熵损失。
  • loss.backward():通过反向传播算法计算损失的梯度。
  • optimizer.step():更新模型的参数,使损失最小化。
  • optimizer.zero_grad():清除上一步的梯度,以防止累积。

6. 模型评估

@torch.no_grad()def evaluate(model, data_loader, device):
    model.eval()for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred_classes = torch.max(pred, dim=1)[1]
        accu_num += torch.eq(pred_classes, labels.to(device)).sum()return accu_num.item() / sample_num
  • model.eval():将模型置于评估模式,关闭Dropout等随机操作。
  • pred_classes = torch.max(pred, dim=1)[1]:通过取预测值的最大索引,得到模型的预测类别。
  • accu_num += torch.eq(pred_classes, labels.to(device)).sum():计算模型在当前批次中的预测正确数。

6. 模型优缺点评价

优点

  1. 全局信息捕捉能力强:Vision Transformer(ViT)通过自注意力机制能够有效捕捉图像中不同区域之间的长距离依赖关系,与传统的卷积神经网络(CNN)相比,ViT能够更好地理解全局特征。
  2. 较少依赖卷积操作:ViT抛弃了CNN中的卷积操作,减少了对局部感受野的依赖,适合处理高分辨率图像和大规模数据。
  3. 扩展性强:ViT架构灵活,可以通过增加Transformer的深度和宽度来提高模型的能力,在大数据集上表现突出,尤其在大规模预训练后迁移到下游任务时有较好的表现。

缺点

  1. 数据需求高:与传统CNN相比,ViT对大规模数据的依赖更强。如果数据量不足,ViT容易出现过拟合,难以学习到有效的特征。
  2. 训练成本高:Transformer模型计算复杂度高,训练过程中占用大量的计算资源,尤其是在较深的网络结构下,显著增加了计算时间和内存消耗。
  3. 不适合小型数据集:在小型数据集上,ViT的表现不如CNN,因为缺乏丰富的卷积特征提取能力。

改进方向

  1. 模型结构优化:可以在ViT中加入混合架构,如结合卷积层与Transformer层,使模型既具备局部特征提取能力,又能有效捕捉全局信息。
  2. 超参数调整:可以通过调节模型的深度、宽度、注意力头的数量以及学习率等超参数,找到适合特定任务的最佳模型配置。
  3. 更多数据增强方法:为减少对大规模数据的依赖,可以引入更多的数据增强技术,如CutMix、MixUp等,提高模型的泛化能力。
  4. 预训练技术:通过更大规模的自监督学习进行预训练,有助于提升ViT在下游任务中的表现,尤其是小数据集的迁移能力。

↓↓↓更多热门推荐:
Densenet模型花卉图像分类
ResNet18模型扑克牌图片预测
transformer模型写诗词

查看全部项目数据集、代码、教程点击下方名片

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2235213.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

是时候用开源降低AI落地门槛了

过去三十多年,从Linux到KVM,从OpenStack到Kubernetes,IT领域众多关键技术都来自开源。开源技术不仅大幅降低了IT成本,也降低了企业技术创新的门槛。 那么,在生成式AI时代,开源能够为AI带来什么?…

【C++打怪之路Lv13】- “继承“篇

🌈 个人主页:白子寰 🔥 分类专栏:重生之我在学Linux,C打怪之路,python从入门到精通,数据结构,C语言,C语言题集👈 希望得到您的订阅和支持~ 💡 坚持…

数据特征工程:如何计算Teager能量算子(TEO)? | 基于SQL实现

目录 0 TKEO能量算子 1 数据准备 2 特征求解 3 小结 0 TKEO能量算子 TEO(Teager能量算子),由Kaiser于1990年代提出的非线性分析方法(参见Kaiser, 1990; 1993),是一种有效的非线性信号处理工具,它能即时反映信号能量的变化。通过计算相邻采样点的值,TEO能够迅速跟…

淘宝/天猫探店大冒险:用taobao.item_search_shop API把宝贝一网打尽

想象一下,你是一位勇敢的探险家,手拿藏宝图(店铺ID),准备潜入神秘的淘宝/天猫店铺,寻找那些隐藏在角落里的宝贝。今天,我们要用taobao.item_search_shop API这张神奇的藏宝图,带你走…

D60【python 接口自动化学习】- python基础之数据库

day60 数据库定义 学习日期:20241106 学习目标:MySQL数据库-- 128:数据库定义 学习笔记: 无处不在的数据库 数据库如何存储数据 数据库管理系统(数据库软件) 数据库和SQL的关系 总结 数据库就是指数据…

2024年最佳解压软件推荐:轻松管理压缩文件的必备工具

在当今数字化时代,文件的传输和存储变得日益频繁,解压软件在文件管理中扮演着至关重要的角色。 随着数据量的不断增长,大文件的压缩和解压需求也越来越高。解压软件能够将大容量的文件压缩成较小的体积,便于存储和传输&#xff0…

Kubernetes的基本构建块和最小可调度单元pod-0

文章目录 一,什么是pod1.1pod在k8s中使用方法(1)使用方法一(2)使用方法二 1.2pod中容器的进程1.3pod的网络隔离管理(1)pause容器的作用 1.4 Pod分类:(1)自主式…

vue实现天地图电子围栏

一、文档 vue3 javascript WGS84、GCj02相互转换 天地图官方文档 注册登录然后申请应用key&#xff0c;通过CDN引入 <script src"http://api.tianditu.gov.cn/api?v4.0&tk您的密钥" type"text/javascript"></script>二、分析 所谓电子围…

QT 实现绘制汽车仪表盘

1.界面实现效果 以下是具体的项目需要用到的效果展示,通常需要使用QPainter类来绘制各种图形和文本,包括一个圆形的仪表盘、刻度、指针和数字。 2.简介 分为以下几个部分,首先设置抗锯齿 painter.setRenderHint(QPainter::Antialiasing)。 QPainter p(this);p.setRender…

【网络】传输层协议TCP(下)

目录 四次挥手状态变化 流量控制 PSH标记位 URG标记位 滑动窗口 快重传 拥塞控制 延迟应答 mtu TCP异常情况 四次挥手状态变化 之前我们讲了四次挥手的具体过程以及为什么要进行四次挥手&#xff0c;下面是四次挥手的状态变化 那么我们下面可以来验证一下CLOSE_WAIT这…

阿里云docker安装禅道记录

docker network ls docker network create -d bridge cl_network sudo docker run --name zentao --restart always -p 9982:80 --networkcl_network -v /data/zentao:/data -e MYSQL_INTERNALtrue -d hub.zentao.net/app/zentao:18.5 升级禅道 推荐用按照此文档升级&a…

迈入国际舞台,AORO M8防爆手机获国际IECEx、欧盟ATEX防爆认证

近日&#xff0c;深圳市遨游通讯设备有限公司&#xff08;以下简称“遨游通讯”&#xff09;旗下5G防爆手机——AORO M8&#xff0c;通过了CSA集团的严格测试和评估&#xff0c;荣获国际IECEx及欧盟ATEX防爆认证证书。2024年11月5日&#xff0c;CSA集团和遨游通讯双方领导在遨游…

Win11家庭版 配置 WSL/Ubuntu+Docker详细步骤

最近换了台工作电脑&#xff0c;Windows系统的&#xff0c;想发挥下显卡的AI算算力&#xff0c;所以准备搞下docker环境&#xff0c;下面开始详细介绍&#xff1a; 1、准备系统 最开始是想安装Windows Docker Desktop的&#xff0c;奈何网络问题&#xff0c;死活不能下载镜像…

apache poi 实现下拉框联动校验

apache poi 提供了 DataValidation​ 接口 让我们可以轻松实现 Excel 下拉框数据局校验。但是下拉框联动校验是无法直接通过 DataValidation ​实现&#xff0c;所以我们可以通过其他方式间接实现。 ‍ 步骤如下&#xff1a; 创建一个隐藏 sheet private static void create…

LabVIEW扫描探针显微镜系统

开发了一套基于LabVIEW软件开发的扫描探针显微镜系统。该系统专为微观尺度材料的热性能测量而设计&#xff0c;特别适用于纳米材料如石墨烯、碳纳米管等的研究。系统通过LabVIEW编程实现高精度的表面形貌和热性能测量&#xff0c;广泛应用于科研和工业领域。 项目背景 随着纳…

【Python】强大的正则表达式工具:re模块详解与应用

强大的正则表达式工具&#xff1a;re模块详解与应用 在编程和数据处理中&#xff0c;字符串的处理是不可避免的一项任务。无论是从文本中提取信息、验证数据格式&#xff0c;还是进行复杂的替换操作&#xff0c;正则表达式&#xff08;Regular Expression&#xff0c;简称Rege…

Redis数据库测试和缓存穿透、雪崩、击穿

Redis数据库测试实验 实验要求 1.新建一张user表&#xff0c;在表内插入10000条数据。 2.①通过jdbc查询这10000条数据&#xff0c;记录查询时间。 ②通过redis查询这10000条数据&#xff0c;记录查询时间。 3.①再次查询这一万条数据&#xff0c;要求根据年龄进行排序&#…

今天要重新认识下注解@RequestBody

在Spring框架中&#xff0c;RequestBody是一个常用的注解&#xff0c;它用于将HTTP请求体中的数据绑定到控制器&#xff08;Controller&#xff09;处理方法的参数上。这个注解通常与RESTful Web服务一起使用&#xff0c;在处理POST或PUT请求时尤为常见&#xff0c;因为这些请求…

在vscode中如何利用git 查看某一个文件的提交记录

在 Visual Studio Code (VSCode) 中&#xff0c;你可以使用内置的 Git 集成来查看某个文件的提交历史。以下是具体步骤&#xff1a; 使用 VSCode 内置 Git 功能 打开项目&#xff1a; 打开你的项目文件夹&#xff0c;确保该项目已经是一个 Git 仓库&#xff08;即项目根目录下…

JavaScript 23种经典设计模式简介

23种JavaScript经典设计模式 JavaScript经典设计模式 通过之前的学习&#xff0c;我们知道设计模式是一种解决代码组织、代码复用和代码可维护性等问题的技术方法。它通过将代码以特定的方式组织起来&#xff0c;使代码结构更加清晰、可读性更高、易于维护和扩展。为了在开发…