PoolFormer实战:使用PoolFormer实现图像分类任务(一)

news2025/1/19 14:40:09

摘要

论文:https://arxiv.org/abs/2111.11418
论文翻译:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/128281326
官方源码:https://github.com/sail-sg/poolformer
模型代码解析:https://blog.csdn.net/hhhhhhhhhhwwwwwwwwww/article/details/128475827
MetaFormer是颜水成大佬的一篇Transformer的论文,该篇论文的贡献主要有两点:第一、将Transformer抽象为一个通用架构的MetaFormer,并通过经验证明MetaFormer架构在Transformer/ mlp类模型取得了极大的成功。
第二、通过仅采用简单的非参数算子pooling作为MetaFormer的极弱token混合器,构建了一个名为PoolFormer。
在这里插入图片描述
这篇文章主要讲解如何使用PoolFormer完成图像分类任务,接下来我们一起完成项目的实战。本例选用的模型是poolformer_s24,在植物幼苗数据集上实现了97%的准确率。

在这里插入图片描述
在这里插入图片描述

通过这篇文章能让你学到:

  1. 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
  2. 如何实现PoolFormer模型实现训练?
  3. 如何使用pytorch自带混合精度?
  4. 如何使用梯度裁剪防止梯度爆炸?
  5. 如何使用DP多显卡训练?
  6. 如何绘制loss和acc曲线?
  7. 如何生成val的测评报告?
  8. 如何编写测试脚本测试测试集?
  9. 如何使用余弦退火策略调整学习率?
  10. 如何使用AverageMeter类统计ACC和loss等自定义变量?
  11. 如何理解和统计ACC1和ACC5?
  12. 如何使用EMA?
  13. 如果使用Grad-CAM 实现热力图可视化?

安装包

安装timm

使用pip就行,命令:

pip install timm

本文实战用的timm里面的模型。

安装 grad-cam

pip install grad-cam

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    Cutout(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])

])

需要导入包:from timm.data.mixup import Mixup,

定义Mixup,和SoftTargetCrossEntropy

  mixup_fn = Mixup(
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    prob=0.1, switch_prob=0.5, mode='batch',
    label_smoothing=0.1, num_classes=12)
 criterion_train = SoftTargetCrossEntropy()

参数详解:

mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。

cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。

cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。

如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0

prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。

switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。

mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。

correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正

label_smoothing (float):将标签平滑应用于混合目标张量。

num_classes (int): 目标的类数。

EMA

EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:


import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

class ModelEma:
    def __init__(self, model, decay=0.9999, device='', resume=''):
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # ema model may have been wrapped by DataParallel, and need module prefix
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        # correct a mismatch in state dict keys
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

加入到模型中。

#初始化
if use_ema:
     model_ema = ModelEma(
            model_ft,
            decay=model_ema_decay,
            device='cpu',
            resume=resume)

# 训练过程中,更新完参数后,同步update shadow weights
def train():
    optimizer.step()
    if model_ema is not None:
        model_ema.update(model)


# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)

针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!

项目结构

PoolFormer_Demo
├─data1
│  ├─Black-grass
│  ├─Charlock
│  ├─Cleavers
│  ├─Common Chickweed
│  ├─Common wheat
│  ├─Fat Hen
│  ├─Loose Silky-bent
│  ├─Maize
│  ├─Scentless Mayweed
│  ├─Shepherds Purse
│  ├─Small-flowered Cranesbill
│  └─Sugar beet
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py

mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练PoolFormer模型
cam_image.py:热力图可视化

为了能在DP方式中使用混合精度,还需要在模型的forward函数前增加@autocast(),如果使用GPU训练导入包from torch.cuda.amp import autocast,如果使用CPU,则导入from torch.cpu.amp import autocast。
在这里插入图片描述

计算mean和std

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder
import torch
from torchvision import transforms

def get_mean_and_std(train_data):
    train_loader = torch.utils.data.DataLoader(
        train_data, batch_size=1, shuffle=False, num_workers=0,
        pin_memory=True)
    mean = torch.zeros(3)
    std = torch.zeros(3)
    for X, _ in train_loader:
        for d in range(3):
            mean[d] += X[:, d, :, :].mean()
            std[d] += X[:, d, :, :].std()
    mean.div_(len(train_data))
    std.div_(len(train_data))
    return list(mean.numpy()), list(std.numpy())

if __name__ == '__main__':
    train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())
    print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

我们整理还的图像分类的数据集结构是这样的

data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data
│  ├─val
│  │   ├─Black-grass
│  │   ├─Charlock
│  │   ├─Cleavers
│  │   ├─Common Chickweed
│  │   ├─Common wheat
│  │   ├─Fat Hen
│  │   ├─Loose Silky-bent
│  │   ├─Maize
│  │   ├─Scentless Mayweed
│  │   ├─Shepherds Purse
│  │   ├─Small-flowered Cranesbill
│  │   └─Sugar beet
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob
import os
import shutil

image_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):
    print('true')
    #os.rmdir(file_dir)
    shutil.rmtree(file_dir)#删除再建立
    os.makedirs(file_dir)
else:
    os.makedirs(file_dir)

from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(train_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

for file in val_files:
    file_class=file.replace("\\","/").split('/')[-2]
    file_name=file.replace("\\","/").split('/')[-1]
    file_class=os.path.join(val_root,file_class)
    if not os.path.isdir(file_class):
        os.makedirs(file_class)
    shutil.copy(file, file_class + '/' + file_name)

完成上面的内容就可以开启训练和测试了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/128322.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

超店有数给新手Tik Tok选品的三个建议

给新手Tik Tok选品的三个建议 1)避开饱和的红海类目 比如食品类、美妆个护类的产品,因为竞争过于激烈,一般都是有经验的大卖家操作运营,因为他们会有稳定的供应链、以及足够的成本去支撑,短期盈利比较难,也就是为什么…

企业电子招投标采购系统源码之登录页面

信息数智化招采系统 服务框架:Spring Cloud、Spring Boot2、Mybatis、OAuth2、Security 前端架构:VUE、Uniapp、Layui、Bootstrap、H5、CSS3 涉及技术:Eureka、Config、Zuul、OAuth2、Security、OSS、Turbine、Zipkin、Feign、Monitor、Stre…

(八)大白话MySQL通过配置多个Buffer Pool来优化数据库的并发性能

文章目录1、多线程在访问Buffer Pool的时候需要加锁吗?2、多线程并发访问会加锁,数据库的性能还能好吗?3、MySQL的生产优化经验:多个Buffer Pool优化并发能力100、创作不易,更多章节,请扫码关注&#xff0c…

学而不固,择善固之-杰克教诲

学而不固,择善固之 我也是第一次知道这个词,是受“杰克”大佬教诲。2022-12-30 解释一 学习有时候不是为了马上用起来,有时候只是让我们不再固执,通过学习知道原来世界上还有这么一说,对很多事情保有好奇心。 不持成见地学习&…

spring中为什么要三级缓存?二级不行吗

这是我看过视频中最能解释的文字表达了 先说bean的创建过程:实例化->依赖注入->初始化 实例化之后会提前暴露到缓存,用于解决循环依赖问题。 以下的解释保证你能看懂: 为什么需要一级缓存ioc容器 总得有个地方放那些单例吧 为什么需…

【决策树】简单介绍+个人理解(一)

∙\bullet∙ 分类模型中除了贝叶斯决策规则,SVM,最近邻分类器,还有决策树 ∙\bullet∙ 决策树就是选一个属性,根据属性的不同取值,将样本划分为不同的类,不断重复下去,直到终止。在叶子节点处&a…

支持图文公式混排的题库软件,Word试卷直接入库

试卷电子化的难题是入库难,只有试卷入库,才能做到后续的监督,复习,错题本等功能。 目前题库软件众多,但大多数题库软件仅支持纯文本题库,而很多试卷都是包括公式,图形,排版复杂。 …

中国外文局文化传播中心借力vLive虚拟直播,打造国际汉文化云讲堂

文明因多样而交流,因交流而互鉴,因互鉴而发展。 近日,中国外文局文化传播中心组织的“中华文化国际传播云讲堂”活动成功举办,本次云讲堂以“世界汉学家看中国文化”为主题,邀请世界汉学家共同探讨汉文化,…

【Git】一文带你入门Git分布式版本控制系统(创建合并分支、解决冲突)

个人简介 👀个人主页: 前端杂货铺 🙋‍♂️学习方向: 主攻前端方向,也会涉及到服务端 📃个人状态: 在校大学生一枚,已拿多个前端 offer(秋招) 🚀未…

Kyligence 客户浦发银行、招商银行荣获金融业科技赋能业务创新突出贡献奖

近日,由《金融电子化》杂志社主办的“2022中国金融科技年会暨第十三届金融科技应用创新奖颁奖典礼”成功举办。Kyligence 服务客户上海浦东发展银行股份公司(以下简称浦发银行)项目「客户旅程万花筒」、招商银行股份有限公司(以下…

户外运动如何安全享受音乐、专业户外运动耳机推荐

想采摘成熟的柿子、苹果、冰糕或栗子吗?出去运动吧在这个不冷不热的金秋季节里,大自然的一切都在等着我们出户外去探险,要说今年的哪一个户外运动最引人注目,露营和登山总是不相上下,但是运动怎么能少了音乐的陪伴呢&a…

智慧楼宇数字孪生应用方案

智慧楼宇也称智能建筑、智能楼宇,是将建筑、物联网感知和控制及结构、系统、管理和服务等各方面的先进科技相互交融结合,是现代化新型建筑发展的必经阶段。通过数字孪生技术,可将楼宇设备之间、系统之间融合数据互通,为组成智慧楼…

PyTorch学习笔记 7.TextCNN文本分类

PyTorch学习笔记 7.TextCNN文本分类一、模型结构二、文本分词与编码1. 分词与编码器2. 数据加载器二、模型定义1. 卷积层2. 池化层3. 全连接层三、训练过程四、测试过程五、预测过程一、模型结构 2014年,Yoon Kim针对CNN的输入层做了一些变形,提出了文本…

Redis事件循环

Redis事件循环文件事件时间事件事件调度和执行客户端部分关于客户端输出缓冲区限制ServerCron周期函数服务器启动流程小结Redis服务器是一个事件驱动程序, 主要处理两类事件: 文件事件 (File Event) : 对套接字操作的抽象,服务器与客户端的通信过程会产生相应的文件…

Java 中的继承和多态

面向对象的三大特性:封装、继承、多态。在这三个特性中,如果没有封装和继承,也不会有多态。 那么多态实现的途径和必要条件是什么呢?以及多态中的重写和重载在JVM中的表现是怎么样?在Java中是如何展现继承的特性呢&am…

常用密码算法介绍

算法种类 根据技术特征,现代密码学可分为三类: 对称算法 说明:加密密钥和解密密钥相同,对明文、密文长度没有限制 子算法: 流密码算法:每次加密或解密一位或一字节的明文或密文 分组密码算法&#xff…

LiveGBS流媒体平台国标GB/T28181功能-国标流媒体服务平台作为上级接入海康大华华为宇视等下级平台及摄像头

LiveGBS国标流媒体服务平台作为上级接入海康大华华为宇视等下级平台及摄像头1、背景说明2、部署国标平台2.1、安装使用说明2.2、服务器网络环境2.3、信令服务配置3、监控摄像头设备接入3.1、海康GB28181接入示例3.2、大华GB28181接入示例3.3、华为IPC GB28181接入示例4、硬件NV…

mysql 存储过程实现从一张表数据迁移到另一种表

通过存储过程迁移数据: 创建表 CREATE TABLE test1 ( idp varchar(255) DEFAULT NULL, brandIdp varchar(255) DEFAULT NULL, namep varchar(1000) DEFAULT NULL, urlp varchar(1000) DEFAULT NULL ) ENGINEInnoDB DEFAULT CHARSETkeybcs2; INSERT INTO t…

2023美国大学生数学建模竞赛(MCM/ICM)报名流程指南

数模乐园作为国内美赛报名最大官方平台,为参加美赛的同学解决国际支付报名难的问题,为同学们省去大部分繁琐流程的同时还附赠纸质证书打印邮寄、美赛赛题解析、美赛专属礼包、赛题翻译等备赛资料 数模乐园已累计为10万同学完成了美赛辅助报名&#xff0…

Android 音视频编解码(三) -- 视频编码和H264格式原理讲解

Android 音视频编解码(一) – MediaCodec 初探 Android 音视频编解码(二) – MediaCodec 解码(同步和异步) 前面学习了 MediaCodec 的基本原理,以及如何解码,在学习MediaCodec 编码之前,先来学习视频是如何编码的,以及最常用的 H2…