使用自己的数据集Fine-tune PaddleHub预训练模型

news2024/10/1 3:20:46

使用自己的数据Fine-tune PaddleHub预训练模型

果农需要根据水果的不同大小和质量进行产品的定价,所以每年收获的季节有大量的人工对水果分类的需求。基于人工智能模型的方案,收获的大堆水果会被机械放到传送带上,模型会根据摄像头拍到的图片,控制仪器实现水果的自动分拣,节省了果农大量的人力。


图5:水果在工厂传送带上自动分类

下面我们就看看如果采集到少量的桃子数据,如何基于PaddleHub对ImageNet数据集上预训练模型进行Fine-tune,得到一个更有效的模型。桃子分类数据集取自AI Studio公开数据集桃脸识别,该桃脸识别数据集中已经将所有桃子的图片分为2个文件夹,一个是训练集一个是测试集;每个文件夹中有4个分类,分别是B1、M2、R0、S3。


图6:自动分类结果示意

实现迁移学习,包括如下步骤:

  1. 安装PaddleHub
  2. 数据准备
  3. 模型准备
  4. 训练准备

下面将根据这四个主要步骤,展示如何利用PaddleHub实现finetune。

1. 安装PaddleHub

paddlehub安装可以使用pip完成安装,如下:

# 安装并升级PaddleHub,使用百度源更稳定、更迅速
pip install paddlehub==2.1 -i https://mirror.baidu.com/pypi/simple

2. 数据准备

在本次教程提供的数据文件中,已经提供了分割好的训练集、验证集、测试集的索引和标注文件。如果用户利用PaddleHub迁移CV类任务使用自定义数据,则需要自行切分数据集,将数据集切分为训练集、验证集和测试集。需要三个文本文件来记录对应的图片路径和标签,此外还需要一个标签文件用于记录标签的名称。

├─data: 数据目录	
  ├─train_list.txt:训练集数据列表	
  ├─test_list.txt:测试集数据列表	
  ├─validate_list.txt:验证集数据列表	
  ├─label_list.txt:标签列表	
  └─……	

训练集、验证集和测试集的数据列表文件的格式如下,列与列之间以空格键分隔。

图片1路径 图片1标签	
图片2路径 图片2标签	
...

label_list.txt的格式如下:

分类1名称	
分类2名称	
...	

准备好数据后即可使用PaddleHub完成数据读取器的构建,实现方法如下所示:构建数据读取Python类,并继承paddle.io.Dataset这个类完成数据读取器构建。在定义数据集时,需要预先定义好对数据集的预处理操作,并且设置好数据模式。在数据集定义中,需要重新定义__init____getitem____len__三个部分。示例如下:

import os

import paddle
import paddlehub as hub


class DemoDataset(paddle.io.Dataset):
    def __init__(self, transforms, num_classes=4, mode='train'):	
        # 数据集存放位置
        self.dataset_dir = "./work/peach-classification"  #dataset_dir为数据集实际路径,需要填写全路径
        self.transforms = transforms
        self.num_classes = num_classes
        self.mode = mode

        if self.mode == 'train':
            self.file = 'train_list.txt'
        elif self.mode == 'test':
            self.file = 'test_list.txt'
        else:
            self.file = 'validate_list.txt'
        
        self.file = os.path.join(self.dataset_dir , self.file)
        with open(self.file, 'r') as file:
            self.data = file.read().split('\n')[:-1]
            
    def __getitem__(self, idx):
        img_path, grt = self.data[idx].split(' ')
        img_path = os.path.join(self.dataset_dir, img_path)
        im = self.transforms(img_path)
        return im, int(grt)


    def __len__(self):
        return len(self.data)

将训练数据输入模型之前,我们通常还需要对原始数据做一些数据处理的工作,比如数据格式的规范化处理,或增加一些数据增强策略。

构建图像分类模型的数据读取器,负责将桃子dataset的数据进行预处理,以特定格式组织并输入给模型进行训练。

如下数据处理策略,只做了两种操作:

  1. 指定输入图片的尺寸,并将所有样本数据统一处理成该尺寸。
  2. 对所有输入图片数据进行归一化处理。

对数据预处理及加载数据集的示例如下:

import paddlehub.vision.transforms as T

transforms = T.Compose(
        [T.Resize((256, 256)),
         T.CenterCrop(224),
         T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])],
        to_rgb=True)

peach_train = DemoDataset(transforms)
peach_validate =  DemoDataset(transforms, mode='val')

3. 模型准备

我们要在PaddleHub中选择合适的预训练模型来Fine-tune,由于桃子分类是一个图像分类任务,这里采用Resnet50模型,并且是采用ImageNet数据集Fine-tune过的版本。这个预训练模型是在图像任务中的一个“万金油”模型,Resnet是目前较为有效的处理图像的网络结构,50层是一个精度和性能兼顾的选择,而ImageNet又是计算机视觉领域公开的最大的分类数据集。所以,在不清楚选择什么模型好的时候,可以优先以这个模型作为baseline。

使用PaddleHub,不需要重新手写Resnet50网络,可以通过一行代码实现模型的调用。

#安装预训练模型
! hub install resnet50_vd_imagenet_ssld==1.1.0
import paddlehub as hub

model = hub.Module(name='resnet50_vd_imagenet_ssld', label_list=["R0", "B1", "M2", "S3"])

4. 训练准备

定义好模型,也准备好数据后,我们就可以开始设置训练的策略。Paddle2.2提供了多种优化器选择,如SGD, Adam, Adamax等。

from paddlehub.finetune.trainer import Trainer

import paddle

optimizer = paddle.optimizer.Adam(learning_rate=0.001, parameters=model.parameters())
trainer = Trainer(model, optimizer, checkpoint_dir='img_classification_ckpt', use_gpu=True) 
trainer.train(peach_train, epochs=10, batch_size=16, eval_dataset=peach_validate, save_interval=1)

#打印
[2023-02-25 10:08:53,462] [   TRAIN] - Epoch=1/10, Step=10/375 loss=0.9796 acc=0.6250 lr=0.001000 step/sec=1.85 | ETA 00:33:46
[2023-02-25 10:08:54,244] [   TRAIN] - Epoch=1/10, Step=20/375 loss=0.6388 acc=0.7625 lr=0.001000 step/sec=12.78 | ETA 00:19:20
[2023-02-25 10:08:55,029] [   TRAIN] - Epoch=1/10, Step=30/375 loss=0.5733 acc=0.7375 lr=0.001000 step/sec=12.75 | ETA 00:14:31
[2023-02-25 10:08:55,827] [   TRAIN] - Epoch=1/10, Step=40/375 loss=0.2518 acc=0.9062 lr=0.001000 step/sec=12.53 | ETA 00:12:08
[2023-02-25 10:08:56,615] [   TRAIN] - Epoch=1/10, Step=50/375 loss=0.1935 acc=0.9250 lr=0.001000 step/sec=12.69 | ETA 00:10:41
[2023-02-25 10:08:57,428] [   TRAIN] - Epoch=1/10, Step=60/375 loss=0.1949 acc=0.9375 lr=0.001000 step/sec=12.31 | ETA 00:09:45
[2023-02-25 10:08:58,238] [   TRAIN] - Epoch=1/10, Step=70/375 loss=0.1502 acc=0.9563 lr=0.001000 step/sec=12.34 | ETA 00:09:05
[2023-02-25 10:08:59,023] [   TRAIN] - Epoch=1/10, Step=80/375 loss=0.1275 acc=0.9500 lr=0.001000 step/sec=12.73 | ETA 00:08:34
[2023-02-25 10:08:59,807] [   TRAIN] - Epoch=1/10, Step=90/375 loss=0.1811 acc=0.9187 lr=0.001000 step/sec=12.76 | ETA 00:08:09

其中Adam:

  • learning_rate: 全局学习率。默认为1e-3;
  • parameters: 待优化模型参数。

运行配置

Trainer 主要控制Fine-tune的训练,包含以下可控制的参数:

  • model: 被优化模型;
  • optimizer: 优化器选择;
  • use_gpu: 是否使用gpu;
  • use_vdl: 是否使用vdl可视化训练过程;
  • checkpoint_dir: 保存模型参数的地址;
  • compare_metrics: 保存最优模型的衡量指标;

trainer.train 主要控制具体的训练过程,包含以下可控制的参数:

  • train_dataset: 训练时所用的数据集;
  • epochs: 训练轮数;
  • batch_size: 训练的批大小,如果使用GPU,请根据实际情况调整batch_size;
  • num_workers: works的数量,默认为0;
  • eval_dataset: 验证集;
  • log_interval: 打印日志的间隔, 单位为执行批训练的次数。
  • save_interval: 保存模型的间隔频次,单位为执行训练的轮数。

当Fine-tune完成后,我们使用模型来进行预测,实现如下:

import paddle
import paddlehub as hub

result = model.predict(['./work/peach-classification/test/M2/0.png'])
print(result)

# 打印:
[{'M2': 0.99999964}]

以上为加载模型后实际预测结果(这里只测试了一张图片),返回的是预测的实际效果,可以看到我们传入待预测的是M2类别的桃子照片,经过Fine-tune之后的模型预测的效果也是M2,由此成功完成了桃子分类的迁移学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/370285.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Jina 3.14 版本发布!支持独立部署Executor

Jina 是一个 MLOps 框架,赋能开发者在云上构建多模态、跨模态的应用程序。Jina 能够将 PoC 提升为生产就绪服务。基础设施的复杂性交给 Jina,开发者能够直接轻松使用高级解决方案和云原生技术。🌟 GitHubhttps://github.com/jina-ai/jina/rel…

【华为OD机试模拟题】用 C++ 实现 - 敏感字段加密(2023.Q1)

最近更新的博客 华为OD机试 - 入栈出栈(C++) | 附带编码思路 【2023】 华为OD机试 - 箱子之形摆放(C++) | 附带编码思路 【2023】 华为OD机试 - 简易内存池 2(C++) | 附带编码思路 【2023】 华为OD机试 - 第 N 个排列(C++) | 附带编码思路 【2023】 华为OD机试 - 考古…

Dubbo 源码解读:负载均衡策略

概览 org.apache.dubbo包下META-INF/dubbo/internal/org.apache.dubbo.rpc.cluster.LoadBalance中内部spi实现类有以下几种: randomorg.apache.dubbo.rpc.cluster.loadbalance.RandomLoadBalance roundrobinorg.apache.dubbo.rpc.cluster.loadbalance.RoundRobinL…

这次说说腾讯的一场 35K—55K 的 Android 高工面试

一、面试的由来 事情是这样的,因为跟公司发展一些想法的不同,早在十月份的时候就有了跳槽的想法,但是碍于老大的面子就一直就没有跟人事说出口,打算着等到年后金三银四在试试跳槽。 但是发生一件事终于让我忍不住了,…

阿里限量出产Elasticsearch学习手册,确定不心动?

前言只有光头才能变强。不知道大家的公司用Elasticsearch多不多,反正我公司的是有在用的。平时听同事们聊天肯定避免不了不认识的技术栈,例如说:把数据放在引擎,从引擎取出数据等等。如果对引擎不了解的同学,就压根听不…

一,下载iPerf3最新源代码

本文目录普通下载方式:git下载方式:普通下载方式: 如果你只是要阅读源代码,不涉及到编译安装修改源代码,那么可以简单的通过此方式下载代码。如果你希望编译安装修改源代码,那么建议通过git来进行源代码的…

软件测试之测试模型

软件测试的发展 1960年代是调试时期(测试即调试) 1960年 - 1978年 论证时期(软件测试是验证软件是正确的)和 1979年 - 1982年 破坏性测试时期(为了发现错误而执行程序的过程) 1983年起,软件测…

计算机网络笔记、面试八股(三)—— HTTPS协议

本章目录3. HTTPS协议3.1 HTTPS协议简介3.2 SSL/TLS协议3.2.1 SSL/TLS功能的实现3.3 HTTP和HTTPS的区别3.4 HTTPS协议的优点3.5 HTTPS协议的缺点3.6 HTTPS协议的工作流程3.7 HTTPS是如何解决HTTP的缺点的3.7.1 解决内容可能被窃听的问题——加密3.7.1.1 方法1.对称加密3.7.1.2 …

go cobra初试

cobra开源地址 https://github.com/spf13/cobra cobra是什么 Cobra is a library for creating powerful modern CLI applications. Cobra is used in many Go projects such as Kubernetes, Hugo, and GitHub CLI to name a few. This list contains a more extensive lis…

web中git漏洞的形成的原理及使用

目录 1.Git漏洞的成因 1.不正确的权限设置: 2.代码注入漏洞: 3.未经身份验证的访问: 4.非安全传输: 5.跨站脚本攻击(XSS): 2.git泄露环境的搭建 git init: git add&#xff1…

JWT(token) 认证

JSON Web Token(缩写 JWT)是目前最流行的跨域认证解决方案,本文介绍它的原理和用法。一、跨域认证的问题互联网服务离不开用户认证。一般流程是下面这样。1、用户向服务器发送用户名和密码。2、服务器验证通过后,在当前对话&#…

计算机网络笔记、面试八股(二)——HTTP协议

本章目录2. HTTP协议2.1 HTTP协议简介2.2 HTTP协议的优点2.3 HTTP协议的缺点2.4 HTTP协议属于哪一层2.5 HTTP通信过程2.6 常见请求方法2.7 GET和POST的区别2.8 请求报文与响应报文2.8.1 HTTP请求报文2.8.2 HTTP响应报文2.9 响应状态码2.10 HTTP 1.0和1.1的区别2.10.1 长连接2.1…

Lecture6 逻辑斯蒂回归(Logistic Regression)

目录 1 常用数据集 1.1 MNIST数据集 1.2 CIFAR-10数据集 2 课堂内容 2.1 回归任务和分类任务的区别 2.2 为什么使用逻辑斯蒂回归 2.3 什么是逻辑斯蒂回归 2.4 Sigmoid函数和饱和函数的概念 2.5 逻辑斯蒂回归模型 2.6 逻辑斯蒂回归损失函数 2.6.1 二分类损失函数 2.…

3-1 图文并茂说明raid0,raid1, raid10, raid01, raid5等原理

文章目录简介RAID类型RAID0RAID1RAID5RAID6RAID10RAID01RAID对比图简介 一、RAID 是什么? RAID ( Redundant Array of Independent Disks )即独立磁盘冗余阵列,简称为「磁盘阵列」,其实就是用多个独立的磁盘组成在一起…

Jenkins第一讲

目录 一、Jenkins 1.1 敏捷开发与持续集成 1.1.1 敏捷开发 1.1.2 持续集成 1.2 持续集成工具 1.2.1 jenkins和hudson 1.2.2 技术组合 1.2.3 部署方式对比 1.3 安装Jenkins 1.3.1 下载Jenkins的war包 1.3.2 开启Jenkins 1.4 Jenkins全局安全配置 1.5 使用Jenkins部…

InfluxDB docker安装与界面的使用

influxdb github主页:https://github.com/influxdata/influxdb chronograf github主页:https://github.com/influxdata/chronograf Docker安装InfluxDB docker run -p 8086:8086 --name influxdb-dev influxdb:latest这里博主安装的是2.2.1版本 然后…

Python学习-----排序问题2.0(sort()函数和sorted()函数)

目录 前言: 1.sort() 函数 示例1:阿斯克码比较 示例2:(设置reverse,由大到小排序) 示例3:基于key排序(传入一个参数) 示例4:key的其他应用 2.sorted() …

平时技术积累很少,面试时又会问很多这个难题怎么破?别慌,没事看看这份Java面试指南,解决你的小烦恼!

前言技术面试是每个程序员都需要去经历的事情,随着行业的发展,新技术的不断迭代,技术面试的难度也越来越高,但是对于大多数程序员来说,工作的主要内容只是去实现各种业务逻辑,涉及的技术难度并不高&#xf…

Allegro如何画Photoplot_Outline操作指导

Allegro如何画Photoplot_Outline操作指导 在用Allegro进行PCB设计的时候,最后进行光绘输出前,Photoplot_Outline是必备一个图形,所有在Photoplot_Outline中的图形将被输出,Photoplot_Outline以外的图形都将不被输出。 如何绘制Photoplot_Outline,具体操作如下 点击Shape点…

视觉人培训团队把它称之为,工业领域人类最伟大的软件创造,它的名字叫Halcon

目前为止,世界上综合能力强大的机器视觉软件,,它的名字叫Halcon。 视觉人培训团队把它称之为,工业领域人类最伟大的软件创造,它的名字叫Halcon。 持续不断更新最新的图像技术,软件综合能力持续提升。 综…