基于paddlex图像分类模型训练(二):训练自己的分类模型、熟悉官方demo

news2024/11/15 18:25:11

0. 前言

相关系列博文:基于paddlex图像分类模型训练(一):图像分类数据集切分:文件夹转化为imagenet训练格式

代码在线运行:

https://aistudio.baidu.com/aistudio/projectdetail/5440569

1. 官方demo:6类蔬菜分类

在这里插入图片描述

1.1 百度6类蔬菜数据集下载(各200张,共1200)

import paddlex as pdx
from paddlex import transforms as T

# 下载和解压蔬菜分类数据集
veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./')

下载后的数据集结构

aistudio@jupyter-40397-5440569:~/work/vegetables_cls$ tree -L 1
.
├── bocai
├── changqiezi
├── hongxiancai
├── huluobo
├── labels.txt
├── test_list.txt
├── train_list.txt
├── val_list.txt
├── xihongshi
└── xilanhua
6 directories, 4 files

1.2 训练

原始代码:https://github.com/PaddlePaddle/PaddleX/blob/develop/tutorials/train/image_classification/mobilenetv3_small.py

import paddlex as pdx
from paddlex import transforms as T

# 下载和解压蔬菜分类数据集
veg_dataset = 'https://bj.bcebos.com/paddlex/datasets/vegetables_cls.tar.gz'
pdx.utils.download_and_decompress(veg_dataset, path='./')

# 定义训练和验证时的transforms
# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
train_transforms = T.Compose(
    [T.RandomCrop(crop_size=224), T.RandomHorizontalFlip(), T.Normalize()])

eval_transforms = T.Compose([
    T.ResizeByShort(short_size=256), T.CenterCrop(crop_size=224), T.Normalize()
])

# 定义训练和验证所用的数据集
# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
train_dataset = pdx.datasets.ImageNet(
    data_dir='vegetables_cls',
    file_list='vegetables_cls/train_list.txt',
    label_list='vegetables_cls/labels.txt',
    transforms=train_transforms,
    shuffle=True)

eval_dataset = pdx.datasets.ImageNet(
    data_dir='vegetables_cls',
    file_list='vegetables_cls/val_list.txt',
    label_list='vegetables_cls/labels.txt',
    transforms=eval_transforms)

# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/visualdl.md
num_classes = len(train_dataset.labels)
model = pdx.cls.MobileNetV3_small(num_classes=num_classes)

# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/classification.md
# 各参数介绍与调整说明:https://github.com/PaddlePaddle/PaddleX/tree/develop/docs/parameters.md
model.train(
    num_epochs=10,
    train_dataset=train_dataset,
    train_batch_size=32,
    eval_dataset=eval_dataset,
    lr_decay_epochs=[4, 6, 8],
    learning_rate=0.01,
    save_dir='output/mobilenetv3_small',
    use_vdl=True)

训练结果 (百度aistudio )

在这里插入图片描述

1.3 预测

为了验证实用性,从百度随意下载两张图片


'''
代码来源:
https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/prediction.md
'''
import paddlex as pdx
test_jpg = 'fanqie.jpg'
model = pdx.load_model('output/mobilenetv3_small/best_model/')
result = model.predict(test_jpg)
print("Predict Result: ", result)
# Predict Result:  [{'category_id': 4, 'category': 'xihongshi', 'score': 0.7541489}]

在这里插入图片描述
在这里插入图片描述

2. 训练自己的动漫分类模型

2.1 数据集

在这里插入图片描述

2.2 训练

代码参考:新增超轻量分类模型PPLCNet,在Intel CPU上,单张图像预测速度约5ms,ImageNet-1K数据集上Top1识别准确率达到80.82%,超越ResNet152的模型效果

import paddlex as pdx
from paddlex import transforms as T



# 定义训练和验证时的transforms
# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/transforms/transforms.md
train_transforms = T.Compose(
    [T.RandomCrop(crop_size=224), T.RandomHorizontalFlip(), T.Normalize()])

eval_transforms = T.Compose([
    T.ResizeByShort(short_size=256), T.CenterCrop(crop_size=224), T.Normalize()
])

# 定义训练和验证所用的数据集
# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/datasets.md
train_dataset = pdx.datasets.ImageNet(
    data_dir='anime_cls_2',
    file_list='anime_cls_2/train_list.txt',
    label_list='anime_cls_2/labels.txt',
    transforms=train_transforms,
    shuffle=True)

eval_dataset = pdx.datasets.ImageNet(
    data_dir='anime_cls_2',
    file_list='anime_cls_2/val_list.txt',
    label_list='anime_cls_2/labels.txt',
    transforms=eval_transforms)

# 初始化模型,并进行训练
# 可使用VisualDL查看训练指标,参考https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/visualdl.md
num_classes = len(train_dataset.labels)
model = pdx.cls.PPLCNet(num_classes=num_classes, scale=1)

# API说明:https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/apis/models/classification.md
# 各参数介绍与调整说明:https://github.com/PaddlePaddle/PaddleX/tree/develop/docs/parameters.md
model.train(
    num_epochs=10,
    pretrain_weights='IMAGENET',
    train_dataset=train_dataset,
    train_batch_size=16,
    eval_dataset=eval_dataset,
    lr_decay_epochs=[4, 6, 8],
    learning_rate=0.1,
    save_dir='output/pplcnet',
    log_interval_steps=10,
    label_smoothing=.1,
    use_vdl=True)

训练时间大约1分钟
在这里插入图片描述

2.3 预测

import paddlex as pdx
test_jpg = 'https://img1.baidu.com/it/u=642615975,3013253527&fm=253&fmt=auto&app=138&f=JPEG?w=501&h=500'
model = pdx.load_model('output/pplcnet/best_model/')
result = model.predict(test_jpg)
print("Predict Result: ", result)

附录

在线训练素材数据集

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/196103.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

spring框架之注解开发

Spring是轻代码而重配置的框架,配置比较繁重,影响开发效率,所以注解开发是一种趋势。 让我们来看看注解开发之前是如何定义bean的? ① BrandDemo.java ② applicationContext.xml ③Test.java 一、注解开发定义bean 组件扫描 二…

信用卡APP评测系列——工银e生活5.0打造个人生活服务平台,引领用户美好生活

易观:中国信用卡市场规模增速趋稳,线上成为存量用户经营主阵地, APP用户高质量经营成为新发力点,也是业务良性增长保障,对此,银行机构着力用户体验竞相升级信用卡APP。工商银行顺势升级工银e生活APP5.0版&a…

第二章.神经网络—3层神经网络的实现,输出层设计

第二章.神经网络 2.3 三层神经网络的实现 1.各层间信号传递的实现 1).示意图: 2).公式: ①.用数学式表示a1(1): ②.用矩阵表示第一层的加权和: 3).实现: import numpy as np# 3层神经网络的实现# 参数初始化 def i…

华数杯B题——校任务尝试

一、背景说明 根据影响社会稳定的因素,以及颜色革命,来衡量社会稳定性,判断社会风险 社会预警指标体系是由一系列经过理论遴选的敏感指标组成的一种测量社会危机现象及其运行过程的指标系统,它作为一种特定的测量工具和手段&…

Java设计模式--原型模式

概念:用原型实例(最初的)指定创建对象的种类,并且通过拷贝这些原型,创建新的对象。(自我复制能力)1.类图原理类图分析Prototype:原型类,声明一个克隆自己的接口ConcreteP…

Go编程规范和性能调优(三)——规范编码和性能优化

文章目录一、本次学习重点内容:二、详细知识点介绍:1、高质量编程简介什么是高质量?编程原则:2、编码规范注释:代码格式:命名规范变量:函数:package:错误和异常处理&…

关于yolov8的训练的一些改动

1、YOLOv8创新改进点: 1.1.Backbone 使用的依旧是CSP的思想,不过YOLOv5中的C3模块被替换成了C2f模块,实现了进一步的轻量化,同时YOLOv8依旧使用了YOLOv5等架构中使用的SPPF模块; 1.2.PAN-FPN 毫无疑问YOLOv8依旧使…

大文件传输软件的优势有哪些?-镭速传输

互联网时代,大数据传输是企业面临的必不可免的问题,可以选择传统的FTP、网盘等方式来传输,对于小型文件或许是有优势的;但是对于大型文件数据的话,也许会出现传输速度慢,数据不可靠的情况,极大的…

python3+requests+unittest:接口自动化测试(一)

简单介绍框架的实现逻辑,参考代码的git地址: GitHub - zhangying123456/python_unittest_interface: pythonunittest接口自动化测试脚本 1.环境准备 python3 pycharm编辑器 2.框架目录展示 (该套代码只是简单入门,有兴趣的可…

Nginx——Keepalived的原理与配置

摘要 Keepalived的作用是检测服务器的状态,如果有一台web服务器宕机,或工作出现故障,Keepalived将检测到,并将有故障的服务器从系统中剔除, 同时使用其他服务器代替该服务器的工作,当服务器工作正常后Keep…

python求解带约束的优化问题

带约束的优化问题可被定义为: 在python中,可以使用scipy的optimize包进行求解,具体求解函数为linprog,下面举例说明求解方法: 假设问题被定义为: 首先,求解最大值问题,我们可以通…

Spring Security 源码解读 :认证总览

Spring Security 提供如下几种认证机制: Username & PasswordOAuth2.0 LoginSAML 2.0 LoginRemember MeJAAS AuthenticationPre-authentication ScenariosX509 Authentication 这里使用Spring Boot 2.7.4版本,对应Spring Security 5.7.3版本 Serv…

LeetCode题目笔记——1588. 所有奇数长度子数组的和

文章目录题目描述题目难度——简单方法一:暴力代码/C代码/Python方法二:前缀和代码/C代码/Python总结题目描述 给你一个正整数数组 arr ,请你计算所有可能的奇数长度子数组的和。 子数组 定义为原数组中的一个连续子序列。 请你返回 arr 中…

MySql性能优化(六)索引监控

文章目录索引监控Handler_read_firstHandler_read_keyHandler_read_lastHandler_read_nextHandler_read_prevHandler_read_rndHandler_read_rnd_next索引监控 SHOW STATUS LIKE Handler_read%解释一下各个参数的含义 Handler_read_first 通过index获取数据的次数 Handler_r…

在cmd中遍历局域网内的IP命令解析

简单的方法 1,直接通过浏览器访问路由器,通过路由器的页面查看。2,网络中很多扫描网络的软件,3,自己使用cmd命令查看 有时候自己也觉得,有简单的方式还用这麻烦的干嘛。但遇到不知道路由的登录密码呢&…

Djiango零基础-快速了解基本框架笔记-附案例

初识Djiango 1. 安装djiango pip install django4.1 -i https://mirrors.aliyun.com/pypi/simple/C:\python38- python.exe- Scripts- pip.exe- djiango-admin.exe 【工具,创建djiango项目】- Lib- 内置模块- site-packages- openpyxl- python-docx- flask- djia…

IPV6实验(2.3)

目标: 一、首先将r2、r3、r4这个公网先弄通 [r2]int gi 0/0/0 [r2-GigabitEthernet0/0/0]ip add 23.1.1.1 24 [r3]int gi 0/0/0 [r3-GigabitEthernet0/0/0]ip add 23.1.1.2 24 [r3-GigabitEthernet0/0/0]int gi 0/0/1 [r3-GigabitEthernet0/0/1]ip add 34.1.1.1 2…

YOLO的学习

如何评价Alexey Bochkovskiy团队提出的YoloV7? - 知乎 1, Selective Search,RCNN和FasterRCNN 机器视觉(CV) 超简指南 选择性搜索 Selective Search_哔哩哔哩_bilibili 【精读RCNN】03选择性搜索,selective search_哔哩哔哩_bilibili …

win10系统安装

系统安装 文章目录系统安装1.工具下载2.制作启动盘3. win 10镜像下载4.进入PE系统1.工具下载 需要准备一个至少16 GB的U盘,工具下载链接 U盘:https://share.weiyun.com/aHhPh16e 迅雷:https://dl.xunlei.com/ win 10 镜像链接&#xff1a…

大咖说·计算讲谈社|当我们在谈目标时,究竟在谈什么?

本讲内容,节选自阿里巴巴研究员吴翰清(道哥)面向团队的内部讲话,经删减整理后,作为【计算讲谈社】第十六讲公开分享。 讲师介绍 吴翰清(道哥):阿里巴巴研究员,阿里巴巴、…