FasterViT实战:使用FasterViT实现图像分类任务(一)

news2024/11/24 1:10:35

文章目录

  • 摘要
  • 安装包
    • 安装timm
    • 安装 grad-cam
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

论文翻译:https://blog.csdn.net/m0_47867638/article/details/131542132
官方源码:https://github.com/NVlabs/FasterViT

这是一篇来自英伟达的论文。FasterViT结合了CNN的快速局部表示学习和ViT的全局建模特性的优点。新提出的分层注意力(HAT)方法将具有二次复杂度的全局自注意力分解为具有减少计算成本的多级注意力。受益于基于窗口的高效自注意力。每个窗口都可以访问参与局部和全局表示学习的专用载体Token。在高层次上,全局的自注意力使高效的跨窗口通信能够以较低的成本实现。FasterViT在精度与图像吞吐量方面达到了SOTA Pareto-front。广泛地验证了它在各种CV任务上的有效性,包括分类、目标检测和分割。HAT可以用作现有网络的即插即用模块并增强它们。

在这里插入图片描述

这篇文章使用FasterViT完成植物分类任务,模型采用faster_vit_0_224向大家展示如何使用FasterViT。由于没有预训练模型,faster_vit_0_224在这个数据集上实现了86+%的ACC,如下图:

请添加图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现FasterViT模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?
  13. 如果使用Grad-CAM 实现热力图可视化?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装 grad-cam

pip install grad-cam

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

InceptionNext_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  └─faster_vit.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

models:来源官方代码,对面的代码做了一些适应性修改。
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练InceptionNext模型
cam_image.py:热力图可视化

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/761165.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Flask_使用flask_marshmallow序列化数据

代码如下: from flask import Flask from flask_marshmallow import Marshmallow from flask_sqlalchemy import SQLAlchemy from marshmallow import fieldsapp Flask(__name__) app.config["SQLALCHEMY_DATABASE_URI"] "mysqlpymysql://root:12…

JavaFx 用户界面控件2——ListView

1.列表显示ListView 下面是一个JavaFX ListView的示例代码和使用方法&#xff1a; public class ListViewExample extends Application {Overridepublic void start(Stage primaryStage) {// 创建一个可观察的列表&#xff0c;用于存储ListView中的数据ObservableList<Str…

【深度学习笔记】正则化与 Dropout

本专栏是网易云课堂人工智能课程《神经网络与深度学习》的学习笔记&#xff0c;视频由网易云课堂与 deeplearning.ai 联合出品&#xff0c;主讲人是吴恩达 Andrew Ng 教授。感兴趣的网友可以观看网易云课堂的视频进行深入学习&#xff0c;视频的链接如下&#xff1a; 神经网络和…

如何手动初始化项目目录结构,并在命令行用gradle编译运行项目

父目录 Android 开发入门 - wuyujin1997 文章目录 Intro纯手动手动创建项目目录结构源码gradle tasksgradle wrapper执行前&#xff1a;执行后&#xff1a;执行前后对比&#xff1a; gradle wrappergradlew 和 gradlew.batplugin java编译Java项目【重点】如何通过 gradle run …

【GAMES202】Real-Time Shadows1—实时阴影1

一、Shadow Mapping回顾 [计算机图形学]光线追踪前瞻&#xff1a;阴影图(前瞻预习/复习回顾)__Yhisken的博客-CSDN博客 关于Shadow Mapping我们在GAMES101中学过&#xff0c;如果不知道可以参考我的博客。 Shadow Mapping是光栅化时为了实现阴影的一种算法&#xff0c;而它实…

python将dataframe数据导入MySQL

文章目录 mysql操作pymysql操作导入数据并启动使用pandas导入MySQL数据库连接引擎使用to_sql方法pandas读取sqlread_sql_tableread_sql_query mysql操作 创建数据库test和table create database test;CREATE TABLE car (文章ID int, 链接 VARCHAR(255), 标题 VARCHAR(255),发…

ts中setState的类型

两种方法: 例子: 父组件 const [value, setValue] useState(); <ChildsetValue{setValue} />子组件 interface Ipros {setValue: (value: string) > void } const Child: React.FC<Ipros> (props) > {}

应用层协议设计及ProtoBuf

文章目录 一、协议概述二、消息的完整性三、协议设计3.1 协议设计实例IM即时通讯的协议设计nginx协议HTTP协议redis协议 3.2 序列化方法3.3 协议安全3.4 协议压缩3.5 协议升级 四、Protobuf4.1 安装编译4.2 工作流程4.3 标量数值类型4.4 编码原理4.4.1 Varints 编码4.4.2 Zigza…

soci源码解析

结构 use_type into_type statement backend 针对不同数据库后端的抽象 session

vue对象复制(使用es6对象扩展运算符,深拷贝)

vue3es6语法 直接上代码 const objA { name: 小飞, age: 18 };const objACopy { ...objA };console.log(对比objA与objACopy的引用地址是否相同);console.log(objA objACopy); //falseconsole.log(objA);console.log(objACopy);//对象包含对象&#xff0c;浅拷贝const objB …

pytorch cv自带预训练模型再微调

参考&#xff1a; https://pytorch.org/docs/0.4.1/torchvision/models.html https://zhuanlan.zhihu.com/p/436574436 https://blog.csdn.net/u014297502/article/details/125884141 Network Top-1 error Top-5 error AlexNet 43.45 20.91 VGG-11 30.98 11.37 VGG-13 30.07 …

动态规划完全背包之518零钱兑换 II

题目&#xff1a; 给你一个整数数组 coins 表示不同面额的硬币&#xff0c;另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额&#xff0c;返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 …

【Kafka中间件】ubuntu 部署kafka,实现Django生产和消费

原文作者&#xff1a;我辈李想 版权声明&#xff1a;文章原创&#xff0c;转载时请务必加上原文超链接、作者信息和本声明。 文章目录 前言一、Kafka安装1.下载并安装Java2.下载和解压 Kafka3.配置 Kafka4.启动 Kafka5.创建主题和生产者/消费者6.发布和订阅消息 二、KafkaDjang…

红黑树底层原理【白话版】

一、红黑树——特殊的平衡二叉搜索树 定义&#xff1a;红黑树是一种特殊的平衡二叉搜索树。我们用它来排列数据&#xff0c;并方便以后快速检索数据。 估计看到这句话&#xff0c;你就崩溃了&#xff0c;因为这话说了等于没说。 先观察这个图。 球要不是黑色&#xff0c;要不…

console的奇妙用法

console的奇妙用法 console.log是调试 JavaScript 代码的最佳方法之一。但是本文将介绍几个与console交互的更好方法。 在vscode或者的其他ide中输入console可以看到里边提供了非常多的方法。 虽然我们通常都是用console.log&#xff0c;但是使用其他可以使调试过程变得更加容…

分布式链路追踪

文章目录 1、背景2、微服务架构下的问题3、链路追踪4、核心概念5、技术选型对比6、zipkin 1、背景 随着互联网业务快速扩展&#xff0c;软件架构也日益变得复杂&#xff0c;为了适应海量用户高并发请求&#xff0c;系统中越来越多的组件开始走向分布式化&#xff0c;如单体架构…

流水灯——FPGA

文章目录 前言一、流水灯介绍二、系统设计1.模块框图2.RTL视图 三、源码四、效果五、总结六、参考资料 前言 环境&#xff1a; 1、Quartus18.0 2、vscode 3、板子型号&#xff1a;EP4CE6F17C8 要求&#xff1a; 每隔0.2s循环亮起LED灯 一、流水灯介绍 从LED0开始亮起到LED3又回…

如何定制自己的应用层协议?|面向字节流|字节流如何解决黏包问题?如何将字节流分成数据报?

前言 那么这里博主先安利一些干货满满的专栏了&#xff01; 首先是博主的高质量博客的汇总&#xff0c;这个专栏里面的博客&#xff0c;都是博主最最用心写的一部分&#xff0c;干货满满&#xff0c;希望对大家有帮助。 高质量干货博客汇总https://blog.csdn.net/yu_cblog/c…

基于ssm的社区生活超市的设计与实现

博主介绍&#xff1a;专注于Java技术领域和毕业项目实战。专注于计算机毕设开发、定制、文档编写指导等&#xff0c;对软件开发具有浓厚的兴趣&#xff0c;工作之余喜欢钻研技术&#xff0c;关注IT技术的发展趋势&#xff0c;感谢大家的关注与支持。 技术交流和部署相关看文章…

SpringBoot拦截器

一、SpringBoot拦截器介绍 Spring Boot中的拦截器是一种用于在处理请求之前或之后执行特定操作的组件。拦截器通常用于实现对请求进行预处理、日志记录、权限验证等功能。 在Spring Boot中&#xff0c;可以使用HandlerInterceptor接口来定义自己的拦截器&#xff0c;并通过配…