微调Hugging Face中图像分类模型

news2025/1/11 6:11:16

前言

  • 本文主要针对Hugging Face平台中的图像分类模型,在自己数据集上进行微调,预训练模型为Googlevit-base-patch16-224模型,模型简介页面。
  • 代码运行于kaggle平台上,使用平台免费GPU,型号P100,笔记本地址,欢迎大家copy & edit
  • Github项目地址,Hugging Face模型微调文档

依赖安装

  • 如果是在本地环境下运行,只需要同时安装3个包就好transformersdatasetsevaluate,即pip install transformers datasets evaluate
  • 在kaggle中因为accelerate包与环境冲突,所以需要从项目源进行安装,即:
import IPython.display as display
! pip install -U git+https://github.com/huggingface/transformers.git
! pip install -U git+https://github.com/huggingface/accelerate.git
! pip install datasets
display.clear_output()
  • 因为安装过程中会产生大量输出,所以使用display.clear_output()清空jupyter notebook的输出。

数据处理

  • 这里使用kaggle中的图像分类公共数据集,5 Flower Types Classification Dataset,数据结构如下:
 - flower_images
	 - Lilly
		 - 000001.jpg
		 - 000002.jpg
		 - ......
	 - Lotus
		 - 001001.jpg
		 - 001002.jpg
		 - ......
	 - Orchid
	 - Sunflower
  • 可以看到flower_images为主文件夹,Lilly,Lotus,Orchid,Sunflower为各类花的种类,每类花的图片数量均为1000张
  • 微调模型图像的数据集读取与加载需要使用datasets包中的load_dataset函数,有关该函数的文档
from datasets import load_dataset
from datasets import load_metric
# 加载本地数据集
dataset = load_dataset("imagefolder", data_dir="/kaggle/input/5-flower-types-classification-dataset/flower_images")
# 整合数据标签与下标
labels = dataset["train"].features["label"].names

label2id, id2label = dict(), dict()
for i, label in enumerate(labels):
    label2id[label] = i
    id2label[i] = label

metric = load_metric("accuracy")
display.clear_output()
  • 如果想要查看图片,可以使用image来访问
example = dataset["train"][0]
example['image'].resize((224, 224))

请添加图片描述

  • 确定想要进行微调的模型,加载其配置文件,这里选择vit-base-patch16-224,关于transfromers包中的AutoImageProcessor类,from_pretrained方法,请参见文档
from transformers import AutoImageProcessor
model_checkpoint = "google/vit-base-patch16-224"
batch_size = 64
image_processor  = AutoImageProcessor.from_pretrained(model_checkpoint)
image_processor 
  • 根据vit-base-patch16-224预训练模型图像标准化参数标准化微调数据集,都是torchvision库中的一些常见变换,这里就不赘述了,重点是preprocess_trainpreprocess_val函数,分别用于标准化训练集与验证集。
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    max_size = image_processor.size.get("longest_edge")

train_transforms = Compose(
        [
            RandomResizedCrop(crop_size),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

val_transforms = Compose(
        [
            Resize(size),
            CenterCrop(crop_size),
            ToTensor(),
            normalize,
        ]
    )

def preprocess_train(example_batch):
    example_batch["pixel_values"] = [
        train_transforms(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

def preprocess_val(example_batch):
    example_batch["pixel_values"] = [val_transforms(image.convert("RGB")) for image in example_batch["image"]]
    return example_batch
  • 划分数据集,并分别将训练集与验证集进行标准化
# 划分训练集与测试集
splits = dataset["train"].train_test_split(test_size=0.1)
train_ds = splits['train']
val_ds = splits['test']

train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

display.clear_output()

微调模型

  • 加载预训练模型使用transformers包中AutoModelForImageClassification类,from_pretrained方法,参考文档
  • 需要注意的是ignore_mismatched_sizes参数,如果你打算微调一个已经微调过的检查点,比如google/vit-base-patch16-224(它已经在ImageNet-1k上微调过了),那么你需要给from_pretrained方法提供额外的参数ignore_mismatched_sizes=True。这将确保输出头(有1000个输出神经元)被扔掉,由一个新的、随机初始化的分类头取代,其中包括自定义数量的输出神经元。你不需要指定这个参数,以防预训练的模型不包括头。
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

model = AutoModelForImageClassification.from_pretrained(model_checkpoint, 
                                                        label2id=label2id,
                                                        id2label=id2label,
                                                        ignore_mismatched_sizes = True)
display.clear_output()
  • 配置训练参数由TrainingArguments函数控制,该函数参数较多,参考文档
model_name = model_checkpoint.split("/")[-1]

args = TrainingArguments(
    f"{model_name}-finetuned-eurosat",
    remove_unused_columns=False,
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    save_total_limit = 5,
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=1,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",)
  • 我解释一下上面出现的一些参数
    • output_dir:模型预测和检查点的输出目录
    • remove_unused_columns:是否自动删除模型转发方法未使用的列
    • evaluation_strategy: 在训练期间采用的评估策略
    • save_strategy:在训练期间采用的检查点保存策略
    • save_total_limit:限制检查点的总数,删除较旧的检查点
    • learning_rateAdamW优化器的初始学习率
    • per_device_train_batch_size:训练过程中GPU/TPU/CPU核心batch大小
    • gradient_accumulation_steps:在执行向后/更新传递之前累积梯度的更新步数
    • per_device_eval_batch_size:评估过程中GPU/TPU/CPU核心batch大小
    • num_train_epochs:要执行的训练时期总数
    • warmup_ratio:用于学习率从0到线性预热的总训练步数的比率
    • logging_steps:记录steps间隔数
    • load_best_model_at_end:是否在训练结束时加载训练期间找到的最佳模型
    • metric_for_best_model:指定用于比较两个不同模型的指标
  • 制定评估指标函数
import numpy as np
import torch

def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return metric.compute(predictions=predictions, references=eval_pred.label_ids)

def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}
  • 传递训练配置,准备开始微调模型,Trainer函数,参考文档
trainer = Trainer(model,
                  args,
                  train_dataset=train_ds,
                  eval_dataset=val_ds,
                  tokenizer=image_processor,
                  compute_metrics=compute_metrics,
                  data_collator=collate_fn,)
  • 同样的,我解释一下上面的一些参数
    • model:训练、评估或用于预测的模型
    • args:调整训练的参数
    • train_dataset:用于训练的数据集
    • eval_dataset:用于评估的数据集
    • tokenizer:用于预处理数据的标记器
    • compute_metrics:将用于在评估时计算指标的函数
    • data_collator:用于从train_dataseteval_dataset的元素列表形成批处理的函数
  • 开始训练,并在训练完成后保存模型权重,模型训练指标变化,模型最终指标。
train_results = trainer.train()
# 保存模型
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
  • 在训练过程中可选择使用wandb平台对训练过程进行实时监控,但需要注册一个账号,获取对应api,个人推荐使用,当然也可以ctrl+q选择退出。
  • 训练输出:
Epoch	Training Loss	Validation Loss	Accuracy
1	0.384800	0.252986	0.948000
2	0.174000	0.094400	0.968000
3	0.114500	0.070972	0.978000
4	0.106000	0.082389	0.972000
5	0.056300	0.056515	0.982000
6	0.044800	0.058216	0.976000
7	0.035700	0.060739	0.978000
8	0.068900	0.054247	0.980000
9	0.057300	0.058578	0.982000
10	0.067400	0.054045	0.980000
11	0.067100	0.051740	0.978000
12	0.039300	0.069241	0.976000
13	0.029000	0.056875	0.978000
14	0.027300	0.063307	0.978000
15	0.038200	0.056551	0.982000
16	0.016900	0.053960	0.984000
17	0.021500	0.049470	0.984000
18	0.031200	0.049519	0.984000
19	0.030500	0.051168	0.984000
20	0.041900	0.049122	0.984000
***** train metrics *****
  epoch                    =         20.0
  total_flos               = 6494034741GF
  train_loss               =       0.1092
  train_runtime            =   0:44:01.61
  train_samples_per_second =       34.062
  train_steps_per_second   =        0.538

wandb平台指标可视化

请添加图片描述

请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述
请添加图片描述

评估模型

metrics = trainer.evaluate()
# some nice to haves:
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

输出:

***** eval metrics *****
  epoch                   =       20.0
  eval_accuracy           =      0.984
  eval_loss               =      0.054
  eval_runtime            = 0:00:11.18
  eval_samples_per_second =     44.689
  eval_steps_per_second   =      0.715

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/651755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【惯性导航】隧道、高架桥、高楼林立弱信号环境室外定位_惯导模块

汽车行驶在路上,视野可能会受到周边的树木、同行的卡车、城市楼群的遮挡,卫星导航系统容易受到周围环境的影响,例如树木楼房等,造成多路径效应,使得定位结果精度降低甚至丢失,尤其是在隧道或者室内环境中&a…

树莓派开Samba协议和Windows电脑共享资料

文章目录 1. 前言2. 树莓派安装和配置Samba2.1. 更新源2.2. 安装Samba软件2.3. 修改Samba配置文件2.4. 重启Samba服务2.5. 添加用户到Samba 3. Windows访问共享目录3.1. 查看树莓派的地址3.2. 打开这个IP地址 4. 报错4.1. 用户名或者密码不正确 1. 前言 虽然出门派很方便&…

C++算法————二分查找

又是鸽了三千万年 马上要打csp了,开始回流学j组的知识了,浅说一下二分吧() --------------------------------------------------------------------------------------------------------------------------------- 二分查找 …

tsx写法

1.安装插件 npm install vitejs/plugin-vue-jsx -D vite.config.ts 配置 import { defineConfig } from vite import vue from vitejs/plugin-vue import vueJsx from vitejs/plugin-vue-jsx; // https://vitejs.dev/config/ export default defineConfig({plugins: [vue(),v…

TC15WProteus仿真DS18B20温度采集报警控制系统STC15W4K32S4

STC15WProteus仿真DS18B20温度采集报警控制系统STC15W4K32S4 Proteus仿真小实验: STC15WProteus仿真DS18B20温度采集报警控制系统STC15W4K32S4 功能: 硬件组成:STC15W4K32S4单片机 LCD1602显示器DS18B20温度传感器蜂鸣器 1.单片机读取DS18…

数据链路层(MAC)、网络层(IP)、传输层(TCP/UDP)抓包分析

目录 OSI七层模型数据包逐层封装头部抓包分析数据包概况数据链路层抓包网络层抓包(IP协议抓包)UDP抓包数据负载抓包 Linux cooked-mode capture OSI七层模型 OSI模型(OSI model),开放式系统互联通信参考模型&#xff…

【读书笔记】《小王子》- [法] 安托万•德•圣埃克苏佩里 / [法国] 安东尼·德·圣-埃克苏佩里

文章目录 Chapter 01Chapter 02Chapter 03Chapter 04Chapter 05Chapter 06Chapter 07Chapter 08Chapter 09 Chapter 01 Chapter 02 “因为我住的地方非常小…” 想起了陀思妥耶夫斯基书中的一句话,“要爱具体的人,不要爱抽象的人;要爱生活本…

给开发者的ChatGPT提示词工程指南

ChatGPT Prompt Engineering for Development 基本大语言模型和指令精调大语言模型的区别: 指令精调大语言模型经过遵从指令的训练,即通过RLHF(基于人类反馈的强化学习)方式在指令上精调过,因而更加有帮助&#xff0…

OverLeaf(LaTeX在线编辑器)制作项目周报

目录 注册及登录 1、登录官网 2、设置语言 制作周报 1、新建项目 2、整体预览 3、分段解析 3.1 页面大小/页边距 3.2 页眉页脚 3.3 标题样式 3.4 内容分栏显示 3.5 调整行间距 3.6 插入图片 3.7 图片和文字排版 3.8 分页 3.9 标题和内容 4、编译和导出 4.1 编…

hutool poi、apache poi实现导入导出以及解析excel

一、前言 看了例子之后后续需要更加深入学习或者更多理解其他API的话,建议看官方文档。hutool项目是中国人维护的,有中文文档,阅读起来很方便。apache poi比较底层一点,可以更加自由去二次开发自己所需的功能。 hutool官方文档 …

zkML零知识机器学习介绍

1. 引言 零知识证明技术的2大基石为: 1)succinctness:相比于直接运行整个计算本身,验证该计算完整性证明要简单很多。2)zero-knowledge:可在不泄露计算隐私的情况下,证明计算的完整性。 生成…

【Java入门】-- Java基础详解之 [数组、冒泡排序]

目录 1.为什么需要数组? 2.数组的介绍 3.数组的快速入门 4.数组的使用 5.动态初始化 6.静态初始化 7.数组的细节 8.数组的赋值机制 9.数组拷贝 10.数组反转 11.二维数组 12.冒泡排序 1.为什么需要数组? 有五个学生,他们英语成绩…

探索不同学习率对训练精度和Loss的影响

验证精度、验证Loss的影响 1 问题 在探索mnist数据集过程中,学习率的不同,对我们的实验结果,各种参数数值的改变有何变化,有何不同。 学习率对精度和损失的影响研究。训练周期100学习率 [0.1, 0.01, 0.001, 0.0001](1) 不同学习率…

蓝牙网关Gateway_数据采集,连接控制,室内定位VDB2602

蓝牙网关,内部集成了WiFi、蓝牙、4G等多种无线通信方式,因此也继承了蓝牙、WiFi的有扫描功能、连接功能、数据透传功能,被应用于智能家居的各种场景中,例如:远程控制BLE装置,接收BLE设备发送的数据&#xf…

线程的创建和使用(一)

1、线程 1.1、线程的概念 一个线程就是一个 "执行流". 每个线程之间都可以按照顺讯执行自己的代码. 多个线程之间 "同时" 执行着多份代码. 1.2、创建线程 方法一:继承Thread类 public class Exe_01 {public static void main(String[] args…

pandas与pyspark计算效率对比

日常工作中,主要还是应用HQL和SparkSQL,数据量大,分布式计算很快; 本地数据处理,一般会使用python的pandas包,api丰富,写法比较简单,但只能利用单核性能跑数,数据量大可…

【MySQL入门】-- 数据库简单的SELECT语句详解

目录 1.SQL分类 2.注释 3.数据导入指令 4.基本的SELECT语句 5.列的别名 6.去重复行 7.显示表结构 8.一些数据库基本操作 1.SQL分类 SQL语言在功能上主要分为三大类: DDL(Data Defintion Language)数据定义语言:定义不同的数据库,表…

【C#】并行编程实战:任务并行性(中)

本章继续介绍任务并行性,因篇幅所限,本章为中篇。 4、取消任务 .NET Framework 提供了以下两个类来支持任务取消: CancellationTokenSource :此类负责创建取消令牌,并将取消请求传递给通过源创建的所有令牌。 Cancell…

关于xinput1_3.dll丢失的详细解决方法

xinput1_3.dll是电脑文件中的dll文件(动态链接库文件)。如果计算机中丢失了某个dll文件,可能会导致某些软件和游戏等程序无法正常启动运行,并且导致电脑系统弹窗报错。 在我们打开软件或者游戏的时候,电脑提示xinput1_…

8、共享模型之工具

目录 8.1 线程池2、ThreadPoolExecutor(及其重要)1) 线程池状态2) 构造方法3) newFixedThreadPool4) newCachedThreadPool5) newSingleThreadExecutor6) 提交任务7) 关闭线程池8) 任务调度线程池 8.1 线程池 2、ThreadPoolExecutor(及其重要…