SeaFormer实战:使用SeaFormer实现图像分类任务(一)

news2024/9/28 9:22:14

文章目录

  • 摘要
  • 安装包
    • 安装timm
    • 安装mmcv
    • 安装 grad-cam
  • 数据增强Cutout和Mixup
  • EMA
  • 项目结构
  • 计算mean和std
  • 生成数据集

摘要

论文翻译:https://blog.csdn.net/m0_47867638/article/details/130437649?spm=1001.2014.3001.5501
官方源码:https://github.com/fudan-zvg/SeaFormer

SeaFormer是一个轻量级的Transformers模型,最小的SeaFormer_T只有6M大小。设计了一种具有压缩轴向和细节增强的注意力模块,使其能够更好的在移动端应用,架构设计如下:
在这里插入图片描述
这篇文章使用植物分类任务,模型采用SeaFormer_T向大家展示如何使用SeaFormer。SeaFormer_T在这个数据集上实现了95%的ACC,如下图:

请添加图片描述
请添加图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现SeaFormer模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?
  13. 如果使用Grad-CAM 实现热力图可视化?

如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。

安装包

安装timm

使用pip就行,命令:

pip install timm

mixup增强和EMA用到了timm

安装mmcv

pip install mmcv

安装 grad-cam

pip install grad-cam

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

SeaFormer_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─models
│  ├─__init__.py
│  └─seaformer.py
├─mean_std.py
├─makedata.py
├─SeaFormer_T_cls_68.1.pth
├─train.py
├─cam_image.py
└─test.py

models:来源官方代码,对面的代码做了一些适应性修改。增加了一些加载预训练,调用模型的逻辑。
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
SeaFormer_T_cls_68.1.pth:预训练权重
train.py:训练PoolFormer模型
cam_image.py:热力图可视化

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/587079.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Vue2 创建 Vite 项目,新手教学

关于vite Vite是一种快速的现代化构建工具,可以显著提高Web应用程序的开发效率和性能。 以下是一些Vite的好处: 快速的冷启动:Vite使用原生ES模块解析器,在冷启动时会非常快速,不需要像Webpack一样构建整个应用程序。…

Linux输入输出重定向

目录 Linux输入输出重定向 Linux中的默认设备 输入输出重定向定义 输入输出重定向操作符 实用形式 标准输入、标准输出、标准错误 输出重定向案例 案例1 --- 输出重定向(覆盖) 案例2 --- 输出重定向(追加) 案例3 --- 错误…

chatgpt赋能python:Python中向上取整函数详解

Python中向上取整函数详解 对于Python中的向上取整运算,大家一定不会感到陌生。在FPython中,我们通常使用math.ceil()函数来对数值进行向上取整。本文将为大家详细介绍Python中的向上取整函数,以及如何在实践中应用。 什么是向上取整&#…

被黑客攻击了?无所谓,我会拔网线。。。

「作者简介」:CSDN top100、阿里云博客专家、华为云享专家、网络安全领域优质创作者 「推荐专栏」:对网络安全感兴趣的小伙伴可以关注专栏《网络安全入门到精通》 最近老是有粉丝问我,被黑客攻击了,一定要拔网线吗?还有…

C/S客户端核服务端-简单收发

一、程序 首先上程序 client端的程序 #include <arpa/inet.h> #include <netinet/in.h> #include <netinet/ip.h> #include <stdio.h> #include <stdlib.h> #include <strings.h> #include <sys/socket.h> #include <sys/type…

keep-alive 是 Vue 内置的一个组件,被用来缓存组件实例。

文章目录 简介注意点使用 keep-alive 有以下优缺点优点缺点 简介 keep-alive 是 Vue 内置的一个组件&#xff0c;被用来缓存组件实例。 使用 keep-alive 包裹动态组件时&#xff0c;被包裹的组件实例将会被缓存起来&#xff0c;而不会被销毁&#xff0c;直到 keep-alive 组件…

LSM零知识学习一、概念与框架机制

本文内容参考&#xff1a; LSM(Linux Security Modules)框架原理解析_lsm框架_pwl999的博客-CSDN博客 LSM相关知识及理解-布布扣-bubuko.com 一文了解Linux安全模块&#xff08;LSM&#xff09; - 嵌入式技术 - 电子发烧友网 在此特别致谢&#xff01; 一、什么是LSM LSM全…

HiFB 与Linux Framebuffer的对比

引言 HiFB和Linux Framebuffer是两种不同的图形缓冲区技术&#xff0c;它们在处理计算机图形显示方面有着重要的作用。以下是对这两种技术的简短定义&#xff1a; HiFB&#xff08;High-performance Intelligent FrameBuffer&#xff09;&#xff1a;HiFB是华为推出的一种高性…

Socket(五)

文章目录 1. 日志2. 如何记录日志 1. 日志 服务器要在无人看管的情况下运行很长时间&#xff0c;通常需要在很久以后对服务器中发生的情况进行调试&#xff0c;这很重要。由于这个原因&#xff0c;建议在存储服务器日志&#xff0c;至少要存储一段时间的日志。日志中通常希望记…

ARM微架构与程序编写

目录 1.流水线 2.指令流水线 3. 多核处理器​编辑 4. 工程搭建 4.1为Keil软件配置编译工具链 5.程序编写 5.1 数据处理指令 5.2 带标志位的加法ADC ADDS 5.3 跳转指令B\BL 5.4 单寄存器内存访问 5.5 批量寄存器内存访问 5.6 满减操作 1.流水线 2.指令流水线 3.…

算法基础学习笔记——⑭欧拉函数\快速幂\扩展欧几里得算法\中国剩余定理

✨博主&#xff1a;命运之光 ✨专栏&#xff1a;算法基础学习 目录 ✨欧拉函数 &#x1f353;求欧拉函数 : &#x1f353;筛法求欧拉函数 : ✨快速幂 ✨扩展欧几里得算法 ✨中国剩余定理 前言&#xff1a;算法学习笔记记录日常分享&#xff0c;需要的看哈O(∩_∩)O&#…

chatgpt赋能python:Python中的倒序输出方法

Python中的倒序输出方法 在Python中&#xff0c;倒序输出是一个经常用到的操作。倒序输出可以用于字符串、列表、元组等数据类型&#xff0c;帮助我们更方便地处理数据。 字符串的倒序输出 对于字符串&#xff0c;我们可以使用字符串切片的方法倒序输出。例如&#xff0c;我…

十二、Vben之Vue3+vite跨域代理地址实现

在vue2中使用proxy进行跨域的原理是:将域名发送给本地的服务器(启动vue项目的服务,loclahost:8080),再由本地的服务器去请求真正的服务器。 代码如下: 1.在proxy中设置要访问的地址,并重写/api为空的字符串,这里如果不重写,会相当于在代理的地址上默认加了/api,所以…

chatgpt赋能python:Python中安装jieba分词器

Python中安装jieba分词器 介绍 中文分词是文本挖掘中非常重要的一个环节&#xff0c;而jieba是Python中最受欢迎的中文分词器之一。jieba分词器是基于汉语词汇库进行分词&#xff0c;并支持多种分词模式&#xff0c;可以满足不同场景的分词需求。 本文将介绍如何在Python环境…

chatgpt赋能python:Python中如何安装pip

Python中如何安装pip 什么是pip&#xff1f; pip&#xff0c;全称pip installs packages&#xff0c;是一个Python包管理工具&#xff0c;可以用来安装、升级和卸载Python包。它广泛地应用于Python社区&#xff0c;可以帮助Python开发者快速地获取和分享Python代码。 安装pi…

对比 RS232,RS422,RS485

对比 RS232,RS422,RS485 首先&#xff0c; 串口、UART口、COM口、RJ45网口、USB口是指的物理接口形式(硬件)。TTL、RS-232、RS-485、RS-422是指的电平标准(电信号)。 RS232,RS422,RS485 对比表格 通信标准RS-232RS-422RS-485工作方式单端差分差分通信线数量4 地线52 地线3节…

《深入理解计算机系统(CSAPP)》第5章 优化程序性能 - 学习笔记

写在前面的话&#xff1a;此系列文章为笔者学习CSAPP时的个人笔记&#xff0c;分享出来与大家学习交流&#xff0c;目录大体与《深入理解计算机系统》书本一致。因是初次预习时写的笔记&#xff0c;在复习回看时发现部分内容存在一些小问题&#xff0c;因时间紧张来不及再次整理…

Java中如何判断是否为闰年

✨博主&#xff1a;命运之光 ✨专栏&#xff1a;Java经典程序设计 目录 ✨介绍 &#x1f353;引言&#xff1a;闰年的定义和在编程中的应用 &#x1f353;目的&#xff1a;介绍如何使用Java编写一个函数来判断年份是否为闰年 ✨闰年的条件 ✨提供数学原理和背景知识 &…

软考A计划-试题模拟含答案解析-卷十一

点击跳转专栏>Unity3D特效百例点击跳转专栏>案例项目实战源码点击跳转专栏>游戏脚本-辅助自动化点击跳转专栏>Android控件全解手册点击跳转专栏>Scratch编程案例 &#x1f449;关于作者 专注于Android/Unity和各种游戏开发技巧&#xff0c;以及各种资源分享&am…

牛客网刷题学习SQL(三)

SQL23 统计每个学校各难度的用户平均刷题数 首先分析题目&#xff1a; 想要计算一些参加了答题的不同学校、不同难度的用户平均答题量 不同学校&#xff1a; group by 学校 不同难度&#xff1a; group by 难度 平均答题量&#xff1a;注意用户去重&#xff0c;还有指定questi…