「解析」YOLOv5 classify分类模板

news2025/1/10 4:05:57

学习深度学习有些时间了,相信很多小伙伴都已经接触 图像分类、目标检测甚至图像分割(语义分割)等算法了,相信大部分小伙伴都是从分类入门,接触各式各样的 Backbone算法开启自己的炼丹之路。

但是炼丹并非全是 Backbone,更多的是各种辅助代码,而这部分公开的并不多,特别是对于刚接触/入门的人来说就更难了,博主当时就苦于没有完善的辅助代码,走了很多弯路,好在YOLOv5提供了分类、目标检测的完整代码,不同于目标检测,因数据集不同,对应的数据辅助代码也不兼容,图像分类就不会有这方面的影响,只需要更换下模型,设置下输出类即可。可谓相当的成熟,学者必备!!!

官方代码:https://github.com/ultralytics/yolov5

在这里插入图片描述
分类任务有四部分组成:tutorial说明,train、val、predict 脚本

对于有一定基础的小伙伴可以直接查看 tutorial 自行运行,如果遇到一些暂无解决的问题时再往下阅读!


train任务

整个 图像分类任务还是较为复杂的,内容略微庞大,一篇讲解不完,讲解不清的可以下方留言,较难问题博主再出新博文解释。

在这里插入图片描述

parse_opt() 函数

首先大家在学习代码时,一定要学会 debug 模型,这样才知道代码是如何运行的,一般从 if __name__ == "__main__": 开始进行。
首先是 def parse_opt(known=False): 解析配置参数

  1. parser.add_argument('--model', type=str, default='yolov5s-cls.pt', help='initial weights path')
    –model 参数是配置模型类型,从下面的解析 --model参数可以看出,如果 --model的值是模型权重名称/路径的话,直接加载到模型model,如果–model是torchvision模型库的,将从torchvision库中读取, 如果都没有的话,将以错误输出。
    所以 --model 一定要是 模型权重名称/路径,并且需要能够读取得到才可以。亦可以是torchvision模型库中的模型名称也可以(可以通过 torchvision.models.__dict_ 查看安装的torchvision封装了哪些模型库)
    此外 torchvision.models.__dict__[opt.model](weights='IMAGENET1K_V1' if pretrained else None) 代码并不适用于所有版本的 torchvision模型,还是需要进入 torchvision.model下的具体模型代码中查看 调用方法,否则会出现错误。
    在这里插入图片描述

  2. parser.add_argument('--data', type=str, default='mnist', help='cifar10, cifar100, mnist, imagenet, etc.')
    –data 可以是数据集的路径,也可以是数据集而名称, 只是数据集名称必须是 ultralytics 公开的数据集才可以,比如:Classification:Caltech 101、Caltech 256、CIFAR-10、CIFAR-100、Fashion-MNIST、ImageNet、ImageNet-10、Imagenette、Imagewoof、MNIST
    在这里插入图片描述
    如果是自定义的数据集,需要注意的是每一类的所有数据需要放到同一个文件夹下面,如同 cifar10 数据集一样,在 train/val/test 文件夹下分别建立每一类的子文件夹,其中可以存放全部图片,也可以有多层嵌套路径,注意:train/val/test下的文件夹名称和数量 要保持一致,否则训练出来的指标会很差
    在这里插入图片描述

  3. parser.add_argument('--epochs', type=int, default=10)
    就是训练的迭代轮数

  4. parser.add_argument('--imgsz', '--img', '--img-size', type=int, default=128, help='train, val image size (pixels)')
    训练时 图片的尺寸大小

  5. parser.add_argument('--nosave', action='store_true', help='only save final checkpoint')
    不保存中间每个epoch的权重,如果需要保存的话,将其设置为 False

  6. parser.add_argument('--cache', type=str, nargs='?', const='ram', help='--cache images in "ram" (default) or "disk"')
    选择数据的读取方式,ram方式为一次性将所有的数据读取到内存里,以为内存与显存的传输速度高,因此训练市场可以极大降低,前提是内存够大,如果没有足够大的内存的话,可以算法disk硬盘读取,效率略低

  7. parser.add_argument('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    选择训练设备,可以选择:“cup, mps, cuda”(MPS:Apple Metal Performance Shaders)

  8. parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
    数据集加载时的线程数

  9. parser.add_argument('--project', default=ROOT / 'runs/train-cls', help='save to project/name')
    项目保存路径及名称

  10. parser.add_argument('--name', default='exp', help='save to project/name')
    每次训练的子文件名

  11. parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
    如果已经存在保存文件名/路径,可以覆盖保存

  12. parser.add_argument('--pretrained', nargs='?', const=True, default=True, help='start from i.e. --pretrained False')
    是否使用预训练权重(前提是必须是torchvision中的模型,官方提供预训练接口的模型才有用)

  13. parser.add_argument('--optimizer', choices=['SGD', 'Adam', 'AdamW', 'RMSProp'], default='Adam', help='optimizer')
    优化器选择,此处官方配置好了 [‘SGD’, ‘Adam’, ‘AdamW’, ‘RMSProp’] 优化器,如果需要其他优化器,需用小伙伴自行配置

  14. parser.add_argument('--lr0', type=float, default=0.001, help='initial learning rate')
    优化器的初始学习率

  15. parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing epsilon')
    label-smoothing 方法,对 label进行 smoothing 处理

  16. parser.add_argument('--cutoff', type=int, default=None, help='Model layer cutoff index for Classify() head')
    裁切模型的 classify分支 的层数,model.model = model.model[:cutoff]

  17. parser.add_argument('--dropout', type=float, default=None, help='Dropout (fraction)')
    随机失效部分神经元,dropout处理

  18. parser.add_argument('--verbose', action='store_true', help='Verbose mode')
    冗余模式,记录中间的模型日志

  19. parser.add_argument('--seed', type=int, default=0, help='Global training seed')
    全局随机种子

  20. parser.add_argument('--local_rank', type=int, default=-1, help='Automatic DDP Multi-GPU argument, do not modify')
    如果小伙伴有多卡,可以采用,此方法可以自动调用多个显卡的资源,即DDP 模式,-1 为不采用


train() 函数

train() 函数前面都是一些模型配置

  1. 模型训练保存路径,以及配置训练日志,默认情况下,模型训练保存 一个 last.pt 和 best.pt在这里插入图片描述
  2. 数据集下载,如果是官方的数据集,直接 对 --data 设置数据集名称即可(完成路径也是可以的),如果是自己的数据集,需要设置数据集路径,只需要给到 train 的上一级目录即可
    在这里插入图片描述
  3. 数据集构建,此处将读取数据集的类别数以及加载数据集,此处默认是以 test 为验证集的,如果没有test 备份选择 val。如果需要用 val 当验证集,手动改为 val即可。再次提示:train 下的文件夹名称和数量需要和 验证集下的保持一致,否则模型性能很低,且无法提升(惨痛的教训!)
    在这里插入图片描述
  4. 构建模型,此处需要注意一点,作为分类模型,模型的输出层必须和数据集的类别数量保持一致,必须!!!
    如果不使用 torchvision中的模型,只需要将 model 赋值为自己的模型即可
    在这里插入图片描述
  5. 日志保存模型等信息,以及加载 数据和标签,此处的数据加载器采用的是迭代器方式,因此采用 nest(iter());然后是优化器设置,学习率、调度器(scheduler)设置 和 EMA配置;最后是损失函数criterion。
    在这里插入图片描述
  6. 进行完上面所有的参数配置,真正的模型训练还在下面这个循环里
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/971396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Redis——数据结构介绍

Redis是一个key-value的数据库,key一般是String类型,不过value的类型是多样的: String:hello wordHash:{name:"Jack",age:21}List:[A -> B -> C -> D]Set:{A,B,C}SortedSet…

盖子的c++小课堂——第二十二讲:2维dp

前言 大家好,我又来更新了,今天终于有时间了aaaaaaaa 破500粉了,我太高兴了哈哈哈哈哈哈(别看IP地址,我去北京旅游回来了,他没改回来),然后我马上就成为创作者一年了,希…

航空货运站AAT EDI 解决方案

Asia Airfreight Terminal (AAT)是一个航空货运站,总部设在香港国际机场,是亚洲首屈一指的运输枢纽。 AAT旨在成为世界上最好的航空货运站,将围绕成本竞争力和服务效率,客户服务,创新和员工承诺的业务战略来构建。 | 业…

Gradle下载安装教程

1、Gradle 入门 1.1、Gradle 简介 Gradle 是一款Google 推出的基于 JVM、通用灵活的项目构建工具,支持 Maven,JCenter 多种第三方仓库;支持传递性依赖管理、废弃了繁杂的xml 文件,转而使用简洁的、支持多种语言(例如:java、groo…

grpc + springboot + mybatis-plus 动态配置数据源

前言 这是我在这个网站整理的笔记,关注我,接下来还会持续更新。 作者:神的孩子都在歌唱 grpc springboot mybatis-plus 动态配置数据源 一. 源码解析1.1 项目初始化1.2 接口请求时候 二. web应用三. grpc应用程序 一. 源码解析 1.1 项目初…

Logback日志记录只在控制台输出sql,未写入日志文件【解决】

原因:持久层框架对于Log接口实现方式不一样,日记记录的位置及展示方式也也不一样 mybatis-plus:configuration:log-impl: org.apache.ibatis.logging.stdout.StdOutImpl # sql只会打印到控制台不会输出到日志文件种mybatis-plus:configuration:log-impl…

链路追踪Skywalking应用实战

目录 1 Skywalking应用2 agent下载3 agent应用3.1 应用名配置3.2 IDEA集成使用agent3.3 生产环境使用agent 4 Rocketbot4.1 Rocketbot-仪表盘4.2 Rocketbot-拓扑图4.3 追踪4.4 性能分析4.5 告警4.5.1 警告规则详解4.5.2 Webhook规则4.5.3 自定义Webhook消息接收 1 Skywalking应…

微服务·架构组件之服务注册与发现-Nacos

微服务组件架构之服务注册与发现之Nacos Nacos服务注册与发现流程 服务注册:Nacos 客户端会通过发送REST请求的方式向Nacos Server注册自己的服务,提供自身的元数据,比如ip地址、端口等信息。 Nacos Server接收到注册请求后,就会…

ChatGPT 案例实战趋势分析面积图制作

面积图使用HTML,JS,Echarts 来完成代码可以使用ChatGPT AIGC 来实现代码编写。 完整的代码复制如下: <!DOCTYPE html> <html> <head><meta charset="utf-8"><script src="https://cdn.bootcdn.net/ajax/libs/echarts/5.2.2/echa…

EasyPOI处理excel、CSV导入导出

1 简介 使用POI在导出导出excel、导出csv、word时代码有点过于繁琐&#xff0c;好消息是近两年在开发市场上流行一种简化POI开发的类库&#xff1a;easyPOI。从名称上就能发现就是为了简化开发。 能干什么&#xff1f; Excel的快速导入导出,Excel模板导出,Word模板导出,可以…

【kubernetes】Harbor部署及KubeSphere使用私有仓库Harbor

私有仓库Harbor https://goharbor.io/ 内容学习于马士兵云原生课程 Harbor部署 部署docker及docker-compose 略 获取Harbor安装文件 https://github.com/goharbor/harbor/releases/download/v2.4.1/harbor-offline-installer-v2.4.1.tgz tar -zxvf harbor-offline-installe…

线程安全缓存ConcurrentLinkedHashMap,Kotlin

线程安全缓存ConcurrentLinkedHashMap&#xff0c;Kotlin LinkedHashMap实现LRU缓存cache机制&#xff0c;Kotlin_zhangphil的博客-CSDN博客* * 基于Java LinkedList,实现Android大数据缓存策略 * 作者&#xff1a;Zhang Phil * 原文出处&#xff1a;http://blog.csdn.net/zha…

无法将类型为“Newtonsoft.Json.Linq.JObject”的对象转换为类型“Newtonsoft.Json.Linq.JArray”解决方法

对于“Newtonsoft.Json.Linq.JObject”的对象强制类型转换为类型“Newtonsoft.Json.Linq.JArray”报错 第一的图为对象{“*************”:“********”} 第二个图片为数组[{“…”:“…”}] 在我这里进行强制转换对象转换为类型“Newtonsoft.Json.Linq.JArray”报错. 那我们…

【java】【项目实战】[外卖十一]项目优化(Ngnix)

目录 一、Nginx概述 1、Nginx介绍 2、Nginx下载和安装 3、Nginx目录结构 二、Nginx命令 1、查看版本 2、检查配置文件正确性 3、启动和停止 4、重新加载配置文件 三、Nginx配置文件结构 1、全局块 2、events块 3、http块 四、Nginx具体应用 1、部署静态资源 2、…

mysql基于AES_ENCRYPTAES_DECRYPT实现密码的加密与解密

1.直接使用AES_ENCRYPT&&AES_DECRYPT函数导致的问题。 执行语句 select AES_ENCRYPT(cd123,key) 结果 加密过后的字符串是一串很奇怪的字符。 尝试使用上面加密过后的字符解密。 select AES_DECRYPT(u5£d|#,key) 结果 并未成功的解密 2.解决办法 使用 hex(…

kubernetes——RBAC鉴权

简介 基于角色的访问控制(RBAC)是一种基于组织中各个用户的角色来调节对计算机或网络资源的访问的方法。 目的&#xff1a;防止k8s里的pod&#xff08;会运行程序&#xff09;能随意获取整个集群里的信息和访问集群里的资源 概念 Rule&#xff1a;规则&#xff0c;一组属于…

图解SQL查询之模糊查询技巧:如何使用LIKE对数据进行筛选

模糊查询是一种特殊的条件查询方式&#xff0c;它允许根据模式匹配来查找符合特定条件的数据。在 SQL 中&#xff0c;我们使用 LIKE 关键字来进行模糊查询。在 LIKE 模糊查询中&#xff0c;有两种常用的通配符&#xff1a; 百分号&#xff08;%&#xff09;&#xff1a;表示任…

合宙Air724UG LuatOS-Air LVGL API控件-图片 (Image)

图片 (Image) 图片IMG是用于显示图像的基本对象类型&#xff0c;图像来源可以是文件&#xff0c;或者定义的符号。 示例代码 -- 创建图片控件 img lvgl.img_create(lvgl.scr_act(), nil) -- 设置图片显示的图像 lvgl.img_set_src(img, "/lua/luatos.png") -- 图片…

SpringMVC:从入门到精通

一、SpringMVC是什么 SpringMVC是Spring提供的一个强大而灵活的web框架&#xff0c;借助于注解&#xff0c;Spring MVC提供了几乎是POJO的开发模式【POJO是指简单Java对象&#xff08;Plain Old Java Objects、pure old java object 或者 plain ordinary java object&#xff0…

社区团购新玩法,生鲜蔬菜配货发货小程序商城

在当前的电商市场中&#xff0c;生鲜市场具有巨大的潜力和发展空间。为了满足消费者的需求&#xff0c;许多生鲜店正在寻找创新的方法来提高销售和客户满意度。其中&#xff0c;制作一个个性且功能强大的生鲜小程序商城是一个非常有效的策略。以下是在乔拓云平台上制作生鲜小程…