DEiT实战:使用DEiT实现图像分类任务(一)

news2025/1/24 11:45:35

DEiT实战

  • 摘要
  • 安装包
    • 安装timm
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

DEiT是FaceBook在2020年提出的一篇Transformer模型。该模型解决了Transformer难以训练的问题,三天内使用4块GPU,完成了ImageNet的训练,并且没有使用外部数据,达到了SOTA水平。
DEiT提出的蒸馏策略只增加了对token的蒸馏,没有引入其他的重要架构。如下图:
在这里插入图片描述
蒸馏令牌与类令牌的使用类似:它通过自注意力与其他令牌交互,并在最后一层后由网络输出。蒸馏令牌允许模型从老师的输出中学习,就像在常规蒸馏中一样,同时与类令牌保持互补。这一点我们可以代码中找到:

 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
 self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

cls_tokens 是类令牌,dist_token 是蒸馏令牌,确实很像是,仔细看都没有找到差别。

 def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2

如果想让模型从Teacher模型里学习,将 if self.training:设置为true,这样我们就可以像论文中那样使用RegNet做Teacher,使用DeiT模型做Student去蒸馏x_dist,否则将两者做平均实现二者的互补。
timm里面的代码和官方的代码有所不通,不过逻辑上是一样的,下面是timm的代码:

 def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        if pre_logits:
            return (x[:, 0] + x[:, 1]) / 2
        x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1])
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train / finetune, inference average the classifier predictions
            return (x + x_dist) / 2

(终于搞明白了。憋了好几天了,直到看了官方的代码才理解。)
等我有时间了再写一篇使用外部模型蒸馏的教程。

这篇文章主要讲解如何使用DEiT完成图像分类任务,接下来我们一起完成项目的实战。本例选用的模型是deit_small_patch16_224和deit_small_distilled_patch16_224,在植物幼苗数据集上实现了96%和97%的准确率。deit_small_patch16_224是没有蒸馏token的操作,deit_small_distilled_patch16_224有蒸馏token的操作,从结果上看蒸馏还是有不错的效果。
论文链接:https://arxiv.org/abs/2012.12877v2
论文翻译:https://wanghao.blog.csdn.net/article/details/128180419?spm=1001.2014.3001.5502

DeiT测试结果:
在这里插入图片描述

在这里插入图片描述

DeiT_dist测试结果:

在这里插入图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现DEit模型和DEiT_dist模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?

安装包

安装timm

使用pip就行,命令:

pip install timm

本文实战用的timm里面的模型。

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:

class EMA():
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

    def register(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

加入到模型中。

# 初始化
ema = EMA(model, 0.999)
ema.register()

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    ema.update()

# eval前,apply shadow weights;eval之后,恢复原来模型的参数
def evaluate():
    ema.apply_shadow()
    # evaluate
    ema.restore()

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

DEiT_demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─mean_std.py
├─makedata.py
├─ema.py
├─train.py
├─train_dist.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练DEiT模型
train_dist.py:训练蒸馏策略的DEiT模型。

为了能在DP方式中使用混合精度,还需要在模型的forward函数前增加@autocast(),如果使用GPU训练导入包from torch.cuda.amp import autocast,如果使用CPU,则导入from torch.cpu.amp import autocast。
在这里插入图片描述

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/91358.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

mPEG-N3;mPEG-Azide;甲氧基聚乙二醇叠氮CAS:89485-61-0

叠氮化物/叠氮基官能化的甲氧基聚乙二醇(mPEG-N3)是一种单官能PEG衍生物,可用于修饰蛋白质,肽和其他材料。 叠氮化物基团可以在铜催化的水溶液中与炔烃反应。 也可以容易地还原成胺基。 名称 甲氧基聚乙二醇叠氮 mPEG-N3 别称 甲…

周志华 《机器学习初步》模型评估与选择

周志华 《机器学习初步》模型评估与选择 Datawhale2022年12月组队学习 ✌ 文章目录周志华 《机器学习初步》模型评估与选择一.泛化能力二.过拟合和欠拟合泛化误差 VS 经验误差过拟合 VS 欠拟合三.模型选择的三大问题如何获得测试结果:评估方法如何评估性能优劣&…

工厂设备管理中经常会遇到哪些问题?

我调查过上百家企业的设备管理问题,发现大家认为所有设备管理问题中,最典型的问题主要包括以下五个方面: 1)领导不重视管理 “生产量是最重要的”、“销售额是最重要”、“重ERP,轻现场管理”……等管理理念是企业中的…

镜像法的理解——工程电磁场 P9

模型一:无限大导体平面 此处有几点理解需要格外谈一下 1. 只有在有电力线的地方,才会产生电场的作用 2.对于下平面的分析,下平面如果存在电荷的话,必然存在电力线,那么从无穷远处做功到此处,必然会存在电…

Java网络多线程——UDP编程

UDP编程通信 基本介绍 类DatagramSocket和DatagramPacket【数据包/数据报】实现了基于UDP协议网络程序。UDP数据报通过数据报套接字DatagramSocket发送和接收,系统不保证UDP数据报一定能安全送到目的地,也不确信什么时候可以抵达。DatagramPacket对象封…

从「堆叠」到「降本」,智能汽车传感器颠覆性革命即将到来!

随着汽车智能化的演进,传感器的堆叠造成了整车成本的急剧上升。尤其是多传感器融合(摄像头、毫米波雷达和激光雷达)技术作为当下的主流趋势之一,焦点依然回到成本层面。 同时,传统的整车电子架构和计算能力的限制&…

Flutter 小技巧之快速理解手势逻辑

又到了小技巧系列更新时间,今天我们主要分享 Flutter 里的手势触摸逻辑,其实在很久之前我就写过 《面深入触摸和滑动原理》相关的源码分析文章,但是最近有人说源码分析看不懂,有没有简要好理解的,那么本篇就用更简单的…

[附源码]Node.js计算机毕业设计高校图书馆网站Express

项目运行 环境配置: Node.js最新版 Vscode Mysql5.7 HBuilderXNavicat11Vue。 项目技术: Express框架 Node.js Vue 等等组成,B/S模式 Vscode管理前后端分离等等。 环境需要 1.运行环境:最好是Nodejs最新版,我…

高通平台 5G RF调试总结

目录: 1.QRCT4的使用 2.RFC配置 3.5G CA 配置概括 4.RFPD 运行及错误分析 5.CA吞吐率问题分析 最新的5G HImalyaa平台RFC的配置方法和之前的平台发生了根本性的变化,主要体现在使用QRCT4工具来配置RFC XML文件,然后根据XML文件编译生成s…

MobileNetV3原理说明及实践落地

本文参考: pytorch实现并训练MobileNetV3 - 灰信网(软件开发博客聚合) 【神经网络】(16) MobileNetV3 代码复现,网络解析,附Tensorflow完整代码 - 代码天地 1 MobileNetV3与V1、V2对比 (1)Mob…

【LeetCode每日一题:1945. 字符串转化后的各位数字之和~~~模拟】

题目描述 给你一个由小写字母组成的字符串 s ,以及一个整数 k 。 首先,用字母在字母表中的位置替换该字母,将 s 转化 为一个整数(也就是,‘a’ 用 1 替换,‘b’ 用 2 替换,… ‘z’ 用 26 替换…

匿名浏览器是什么?为什么联盟营销需要借助匿名浏览器?

这段时间小伙伴们都对联盟营销很感兴趣,东哥也是陆陆续续出了两三篇相关的科普文章,今天继续给大家介绍匿名浏览器在联盟营销上的帮助,毕竟互联网时代,学会如何借助工具高效工作是很重要的。关于联盟营销的概念科普文章大家可以看…

学不会的python之通过某几个关键字排序、分组一个字典列表(列表中嵌套字典)

通过某个关键字排序、分组一个字典列表排序问题描述解决方案1.operator 模块的 itemgetter 函数2.lambda 表达式引申分组问题描述解决方案1.itertools.groupby() 函数2.defaultdict() 构建多值字典排序 问题描述 现在你有一个字典列表(列表中嵌套字典),你想要根据…

web 向 unity 传输文件流 blob 记录

场景:web 与unity 通信,向 unity 传输文件 二进制流。 由 unity 转换并下载文件。 流程: web 端将缓存的 blob 数据流读取为 base64 编码的数据 → 传给 unity, →unity 解码转换 base64 数据并下载。 web 端: 1、 将数据转换成…

【Axure教程】自定义审批流原型模板

审批流即审批流程,是对某项工作的审批活动的一系列有序组合。审批流在业务系统中担当者非常重要的角色,所以今天作者就教大家制作一个通用的自定也审批流的原型模板,方便大家日后的工作。 一、效果展示 1、可以根据业务需要添加多个审批节点…

QT学习笔记(中)

QT学习笔记(中) 文章目录QT学习笔记(中)P21 消息对话框P22 其他标准对话框P23 登录窗口界面和布局P24 控件 按钮组P25 QListWidget控件P26 QTreeWidget控件的使用P27 tableWidgetP28 其他常用控件介绍P30 自定义控件P31 QEventP32…

PyQt5 QtChart-折线图

PyQt5 QtChart-QLineSeries 折线图QLineSeriesQLineSeries QLineSeries类将数据序列显示为折线图,其核心代码: lineSeries QLineSeries() lineSeries.append(1, 3) lineSeries.append(5, 8) … chart.addSeries(lineSeries) 常用方法: set…

【linux】容器之代码自动发布-docker

一、分析 旧: 代码发布环境提前准备,以主机为颗粒度静态 新: 代码发布环境 多套,以容器为颗粒度编译 二、业务发布逻辑设计图 三、工具使用流程图 工具 gitgitlabjenkinstomcatmavenharbordocker 流程图 四、主机规划 五…

​智能化加速,「中国供应商」如何跨越规模化周期|高工观察

在过去的十年时间里,中国在智能电动汽车行业下了巨大的「赌注」,整个行业及其背后快速成长的本地化产业链生态系统成为新一轮汽车产业增长的新引擎。 与此同时,电动化、智能化技术的国产化突围,也让整个中国本土汽车产业链获得了…

SuperMap GIS的TIN地形数据处理QA

目录 一、TIN地形数据简介 二、TIN地形数据格式 三、TIN地形数据处理 3.1 导入数据集 3.2 生成TIN地形缓存 3.3 IDesktop场景加载TIN地形 3.4 发布服务 3.5 WebGL场景加载 3.5.1 viewer初始化加载 3.5.2 scene.open加载 四、可能遇到的报错及解决方案 问题一:多个TI…