pytorch 训练EfficientnetV2

news2024/12/23 6:23:43

文章目录

  • 前言
  • 一、数据摆放
  • 二、训练
  • 二、测试
  • 三、训练和测试结果
  • 总结


前言

  前不久用pytorch复现了efficientnetv2的网络结构,但是后边自己一直有其他事情再做,所以训练部分的文章拖到了现在。关于Efficientnet的部分文章链接可参考如下:
EfficientnetV1训练
Flask部署EfficientnetV1分类网络
pytorch构建EfficientnetV2网络结构
然后,本篇的训练代码也是基于v1的训练代码完成,所以和官方会有差距。


一、数据摆放

  训练数据摆放方式可以参考EfficientnetV1训练的文章,大致如下:
在这里插入图片描述
在train和val文件夹下又有各个类别的数据,这里拿猫狗分类来说,如下:
在这里插入图片描述

二、训练

  这里直接给了训练代码,部分注释已在代码中添加,仔细看会发现和V1的训练代码差不多,只是改了封装方式和几个参数而已。需要注意的是需要将上边提到的用pytorch复现的efficientnetV2网络结构copy至model.py中,摆放如下
在这里插入图片描述
代码详情如下:

from model import EfficientnetV2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
import os,time,argparse

device="cuda" if torch.cuda.is_available() else "cpu"
#数据处理
def process(opt):
    # 数据增强
    data_transforms = {
        'train': transforms.Compose([
            # transforms.Resize((self.imgsz, self.imgsz)),  # resize
            transforms.CenterCrop((opt.imgsz, opt.imgsz)),  # 中心裁剪
            transforms.RandomRotation(10),  # 随机旋转,旋转范围为【-10,10】
            transforms.RandomHorizontalFlip(p=0.2),  # 水平镜像
            transforms.ToTensor(),  # 转换为张量
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
        ]),
        "val": transforms.Compose([
            # transforms.Resize((self.imgsz, self.imgsz)),  # resize
            transforms.CenterCrop((opt.imgsz, opt.imgsz)),  # 中心裁剪
            transforms.ToTensor(),  # 张量转换
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    }

    # 定义图像生成器
    image_datasets = {x: datasets.ImageFolder(os.path.join(opt.img_dir, x), data_transforms[x]) for x in
                      ['train', 'val']}
    # 得到训练集和验证集
    trainx = DataLoader(image_datasets["train"], batch_size=opt.batch_size, shuffle=True, drop_last=True)
    valx = DataLoader(image_datasets["val"], batch_size=opt.batch_size, shuffle=True, drop_last=True)

    b = image_datasets["train"].class_to_idx  # id和类别对应
    print(b)
    return trainx,valx,b
#训练
def train(opt):
    start_time = (time.strftime("%m%d_%H%M", time.localtime()))
    save_weight = opt.save_dir + os.sep + start_time  # 保存路径
    os.makedirs(save_weight, exist_ok=True)
    model=EfficientnetV2(opt.model_type,opt.class_num).cuda()
    best_acc = 0
    best_epoch = 0  # 准确率最高的模型的训练周期
    model.train(True)
    # 优化器
    optimzer=optim.SGD(model.parameters(),lr=opt.lr,momentum=opt.m,weight_decay=0.0004)
    cross = nn.CrossEntropyLoss() #损失函数
    trainx, valx, b = process(opt)

    for ech in range(opt.epochs):
        optimzer1 = lrfn(ech, optimzer,opt.lr)

        print("----------Start Train Epoch %d----------" % (ech + 1))
        # 开始训练
        run_loss = 0.0  # 损失
        run_correct = 0.0  # 准确率
        count = 0.0  # 分类正确的个数

        for i, data in enumerate(trainx):

            inputs, label = data
            inputs, label = inputs.to(device), label.to(device)

            # 训练
            optimzer1.zero_grad()
            output = model(inputs)

            loss = cross(output, label)
            loss.backward()
            optimzer1.step()

            run_loss += loss.item()  # 损失累加
            _, pred = torch.max(output.data, 1)
            count += label.size(0)  # 求总共的训练个数
            run_correct += pred.eq(label.data).cpu().sum()  # 截止当前预测正确的个数
            # 每隔100个batch打印一次信息,这里打印的ACC是当前预测正确的个数/当前训练过的的个数
            if (i + 1) % 500 == 0:
                print('[Epoch:{}__iter:{}/{}] | Acc:{}'.format(ech + 1, i + 1, len(trainx), run_correct / count))

        train_acc = run_correct / count
        # 每次训完一批打印一次信息
        print('Epoch:{} | Loss:{} | Acc:{}'.format(ech + 1, run_loss / len(trainx), train_acc))

        # 训完一批次后进行验证
        print("----------Waiting Test Epoch {}----------".format(ech + 1))
        with torch.no_grad():
            correct = 0.  # 预测正确的个数
            total = 0.  # 总个数
            for inputs, labels in valx:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)

                # 获取最高分的那个类的索引
                _, pred = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += pred.eq(labels).cpu().sum()
            test_acc = correct / total
            print("批次%d的验证集准确率:" % (ech + 1), test_acc.cpu().detach().numpy())
        if best_acc < test_acc:
            best_acc = test_acc
            best_epoch = ech + 1

            torch.save(model, save_weight + os.sep + "best.pth")
        print(f'best epoch : {best_epoch}, best accuracy : {best_acc}')

#学习率设置
def lrfn(num_epoch,optim,lr):
    lr_start=lr
    max_lr=0.01
    lr_up_epoch = 5  # 学习率上升批次
    lr_sustain_epoch = 10  # 学习率保持不变
    lr_exp = .8  # 衰减因子
    if num_epoch < lr_up_epoch:  # 0-10个epoch学习率线性增加
        lr = (max_lr - lr_start) / lr_up_epoch * num_epoch + lr_start
    elif num_epoch < lr_up_epoch + lr_sustain_epoch:  # 学习率保持不变
        lr = max_lr
    else:  # 指数下降
        lr = (max_lr - lr_start) * lr_exp ** (num_epoch - lr_up_epoch - lr_sustain_epoch) + lr_start
    for param_group in optim.param_groups:
        param_group['lr'] = lr
    return optim


def parse_opt():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type",type=str,default="S",help="Model type") #模型选型,可选s,m,l,大小写均可
    parser.add_argument("--img-dir", type=str, default="", help="train image path")  # 数据集的路径
    parser.add_argument("--imgsz", type=int, default=480, help="image size")  # 图像尺寸
    parser.add_argument("--epochs", type=int, default=100, help="train epochs")  # 训练批次
    parser.add_argument("--batch-size", type=int, default=16, help="train batch-size")  # batch-size
    parser.add_argument("--class_num", type=int, default=2, help="class num")  # 类别数
    parser.add_argument("--lr",type=float,default=0.0001,help="Init lr") #学习率初始值
    parser.add_argument("--m", type=float, default=0.9, help="optimer momentum")  # 动量
    parser.add_argument("--save_dir", type=str, default="",
                        help="save models dir")  # 保存模型路径
    opt = parser.parse_known_args()[0]
    return opt



if __name__ == '__main__':
    opt=parse_opt()
    models=train(opt)

在train.py中,只需要在parse_opt()中选取个模型类型,如S、M、L等,其次,数据集路径、类别数、保存路径等参数均设置为自己的。
注意:代码中38行附近有个print(b),这个是类别列表,一定记住这个列表,在测试中会用,否则在测试时,类别可能都是错的。

二、测试

测试部分直接给出代码,其中第44行和第48行需要该为自己的测试数据路径和模型路径,代码如下

import torch
import torchvision
from PIL import Image
import cv2,glob,os,time
import shutil
from pathlib import Path

def expend_img(img,img_size=480,expand_pix=0):
    '''
    :param img: 图片数据
    :param fill_pix: 填充像素,默认为灰色,自行更改
    :return:
    '''
    h, w = img.shape[:2]
    if h > w and h >= img_size:  # 左右padding
        top_expand = 0
        bottom_expand = 0
        left_expand = int((h - w) / 2)
        right_expand = left_expand
        new_img = cv2.copyMakeBorder(img, top_expand, bottom_expand, left_expand, right_expand, cv2.BORDER_CONSTANT,
                                     value=expand_pix)
    elif w > h and w >= img_size:  # 上下padding
        left_expand = 0
        right_expand = 0
        top_expand = int((w - h) / 2)
        bottom_expand = top_expand
        new_img = cv2.copyMakeBorder(img, top_expand, bottom_expand, left_expand, right_expand, cv2.BORDER_CONSTANT,
                                     value=expand_pix)
    elif w < img_size and h < img_size:  # 四周padding
        left_expand = int((img_size - w) / 2)
        right_expand = left_expand
        top_expand = int((img_size - h) / 2)
        bottom_expand = top_expand
        new_img = cv2.copyMakeBorder(img, top_expand, bottom_expand, left_expand, right_expand, cv2.BORDER_CONSTANT,
                                     value=expand_pix)
    else:
        new_img = img

    new_img = cv2.resize(new_img, (img_size, img_size))
    return new_img

#模糊分类
if __name__ == '__main__':
    img_dir="" #测试数据路径
    img_list=glob.glob(img_dir+os.sep+"*.jpg")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    #加载模型
    model=torch.load("").to(device)
    model.eval()

    class_list=[]

    for imgpath in img_list:
        img=cv2.imread(imgpath)
        s=time.time()
        img=expend_img(img)

        #PIL
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # 注意时间
        # # img= img[:, :, ::-1].transpose(2, 0, 1).copy() #注意时间

        data_transorform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        img = data_transorform(img)
        pred_img = torch.reshape(img, (-1, 3, 480,480)).to(device)
        start_time=time.time()
        pred=model(pred_img)[0]
        pred = torch.nn.Softmax(dim=0)(pred)
        end_time=time.time()
        score, pred_id = torch.max(pred, dim=0)
        #预测类别
        pred_class=class_list[pred_id]
        e=time.time()

        print(f"{imgpath} is {pred_class},score is {score},inference time is {e-s}")
    print("Finished!")



三、训练和测试结果

这里只给出部分结果,如下:
训练12epoch时,自己有事情请要用服务器,所以中断了训练,此时的准确率为:
在这里插入图片描述
用该模型测试3000张数据时,结果为:
在这里插入图片描述
注意:训练和测试代码中,设置的imgsize均为480,可以根据自己需求进行更改


总结

  以上就是本篇的全部内容,如有问题,欢迎评论区交流。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/592907.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

一百二十、Kettle——用kettle把Hive数据同步到ClickHouse

一、目标 用kettle把hive数据同步到clickhouse&#xff0c;简单运行、直接全量导入数据 工具版本&#xff1a;kettle&#xff1a;8.2 Hive:3.1.2 ClickHouse21.9.5.16 二、前提 &#xff08;一&#xff09;kettle连上hive &#xff08;二&#xff09;kettle连上cli…

10. 数据结构之树

前言 之前介绍了顺序表的数据结构&#xff0c;包含队列&#xff0c;栈等&#xff0c;这种结构都是一对一的&#xff0c;但是现实生活中&#xff0c;经常会遇见一对多的数据结构&#xff0c;比如族谱&#xff0c;部门机构等&#xff0c;此时我们需要一个更复杂的数据结构来表示…

分布式系统概念和设计——(事务与并发控制)

分布式系统概念和设计 事务与并发控制 简介 事务的目标是在多个事务访问对象以及服务器面临崩溃的情况下&#xff0c;保证所有由服务器管理的对象始终维持在一个一致的状态上 事务是由客户定义的针对服务器对象的一组操作&#xff0c;组成为一个不可分割的单元&#xff0c;由…

Unity | HDRP高清渲染管线学习笔记:HDRP配置文件(HDRP Asset)

目录 一、Frame Settings&#xff08;帧设置&#xff09; 二、Volume 三、HDRP配置文件、帧设置和Volume之间的关系 四、HDRP配置文件 1.Rendering &#xff08;1&#xff09;Color Buffer Format&#xff08;颜色缓存格式&#xff09; &#xff08;2&#xff09;Lit Sh…

芭比Q了,现在的00后实在是太卷了.....

都说00后躺平了&#xff0c;但是有一说一&#xff0c;该卷的还是卷。 这不&#xff0c;前段时间我们公司来了个00后&#xff0c;工作都没两年&#xff0c;跳槽到我们公司起薪20K&#xff0c;都快接近我了。后来才知道人家是个卷王&#xff0c;从早干到晚就差搬张床到工位睡觉了…

掌握这个90%的人都不会的大屏技术,裁员、降薪与你无关

裁员话题时不时就被拉到热搜上溜几圈&#xff0c;一方面让各位打工人们焦虑恐惧失业风险&#xff0c;另一方面也能让各位从一波波裁员危机事件中吸取“经验”。例如&#xff0c;技术人员狂敲代码、业务人员猛冲业绩…该被裁的依旧如此&#xff0c;在当今你得具备点别人没有的技…

测评补单操作在美客多店铺及产品优化中的决定性角色:深度解读

许多经营美客多平台的商家有一种观念&#xff0c;他们认为美客多平台的规则与亚马逊有所区别。在美客多上&#xff0c;店铺比产品更重要&#xff0c;而且平台的竞争相对较小。因此&#xff0c;他们认为在美客多平台进行补单操作是不必要的。 然而&#xff0c;是否真的如此呢&a…

RF接口测试(1)

RF是做接口测试的一个非常方便的工具&#xff0c;我们只需要写好发送报文的脚本&#xff0c;就可以灵活的对接口进行测试。 做接口测试我们需要做如下工作&#xff1a; 1、拼接发送的报文 2、发送请求的方法 3、对结果进行判断 我们先按步骤实现&#xff0c;再进行RF操作的…

人效九宫格|三个提升路径,三种管理模式,让人效实时可量化

文&#xff5c;盖雅学苑‍‍ 本文共5202字 在经济高速发展的过去&#xff0c;企业更关注机遇&#xff0c;当经济发展速度进入新常态时&#xff0c;企业更关注效率。在盖雅工场发布的《企业人效管理白皮书》中的数据显示&#xff0c;69.9%的企业依旧将人效提升作为紧急事项&am…

Vue主界面精美模板分享

文章目录 &#x1f412;个人主页&#x1f3c5;Vue项目常用组件模板仓库&#x1f4d6;前言&#xff1a;&#x1f380;源码如下&#xff1a; &#x1f412;个人主页 &#x1f3c5;Vue项目常用组件模板仓库 &#x1f4d6;前言&#xff1a; 本篇博客主要提供vue组件之主页面组件源…

代码级质量技术之基本框架介绍

作者 | CQT&星云团队 一、背景 代码级质量技术&#xff1a;顾名思义为了服务质量更好&#xff0c;涉及到代码层面的相关技术&#xff0c;特别要指出的是&#xff0c;代码级质量技术不单纯指代码召回技术&#xff0c;如静态代码扫描、单元测试等。 研究代码级质量技术主要…

1.6 初探JdbcTemplate操作

一、JdbcTemplate案例演示 1、创建数据库与表 &#xff08;1&#xff09;创建数据库 执行命令&#xff1a;CREATE DATABASE simonshop DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci; 或者利用菜单方式创建数据库 - simonshop 打开数据库simonshop &#x…

边缘计算盒子在视觉分析领域的优势

边缘计算盒子在视觉分析领域有广泛的应用。边缘计算盒子是一种集成了计算、存储和网络连接功能的设备&#xff0c;通常部署在物理环境中的边缘位置&#xff0c;如工厂、城市、交通系统等。它们能够在离数据源更近的位置进行实时数据处理和分析&#xff0c;从而提供更低的延迟和…

使用Docker安装Kafka

第一步&#xff1a;使用下述命令从Docker Hub查找镜像&#xff0c;此处我们要选择的是zookeeper官网的镜像 docker search zookeeper 第二步&#xff1a;拉取zookeeper镜像 docker pull zookeeper:latest 第三步&#xff1a;启动zookeeper容器 docker run -d --name zookee…

微服务-Elasticsearch基础篇【内含思维导图】

Elasticsearch官网&#xff1a;欢迎来到 Elastic — Elasticsearch 和 Kibana 的开发者 | Elastic 注意&#xff1a;Elasticsearch官网访问和加载的耗时很长&#xff01;&#xff01;&#xff01; Lucene官网&#xff1a;Apache Lucene - Welcome to Apache Lucene 目录 一、E…

Docker基本操作与自定义镜像Docker-Compose与Docker镜像仓库

目录 一.基本操作 1.镜像操作 1.1.镜像名称 1.2.镜像命令 1.3.案例-拉取、查看镜像 1.4.案例-保存、导入镜像 2.容器操作 2.1.容器相关命令 2.2.案例-创建并运行一个容器 2.3.案例-进入容器&#xff0c;修改文件 2.4.小结 3.数据卷&#xff08;容器数据管理&#x…

挑选在线帮助文档协作工具的技巧与要点

随着互联网的发展&#xff0c;越来越多的公司和团队开始使用在线帮助文档协作工具来共同编辑和维护文档。这些工具可以让多个用户同时协作编辑同一篇文档&#xff0c;从而提高工作效率和减少沟通成本。然而&#xff0c;在选择在线帮助文档协作工具时&#xff0c;需要注意一些技…

低代码平台投票榜揭晓:这些平台最受欢迎

低代码平台是软件开发工具&#xff0c;允许用户快速轻松地创建和部署应用程序&#xff0c;只需最少的编程知识。对于寻求在不需要大量IT资源的情况下构建自定义应用程序的企业来说&#xff0c;这些平台非常有用。在本文中&#xff0c;我们将讨论低代码平台排行榜投票榜&#xf…

注解和反射复习

注解 注解&#xff1a;给程序和人看的&#xff0c;被程序读取&#xff0c;jdk5.0引用 内置注解 override:修饰方法&#xff0c;方法声明和重写父类方法&#xff0c; Deprecated:修饰&#xff0c;不推荐使用 suppressWarnings用来抑制编译时的警告,必须添加一个或多个参数s…

外贸客户背调的几种干货技巧

外贸人要想做到知己知彼&#xff0c;那背调是必不可少的。 有经验的外贸人会通过关键词、邮箱等开展模糊搜索&#xff0c;然而这种方式不光效率低&#xff0c;而且搜索到的信息也不全。今天小编分享的这几种背调组合工具&#xff0c;不光收集到的客户信息全面&#xff0c;而且…