HuggingFace (transformers) 自定义图像数据集、使用 DeiT 模型、Trainer 进行训练回归任务

news2025/1/16 2:54:21

资料

Hugging Face 官方文档:https://huggingface.co/
Hugging Face 代码链接:https://github.com/huggingface/transformers

1. 环境准备

  1. 创建 conda 环境
  2. 激活 conda 环境
  3. 下载 transformers 依赖
  4. 下载 transformers 中需要处理数据集的依赖
  5. 下载 pytorch 依赖,因为这里使用的 transformers 是基于 PyTorch 实现的,所以需要导入 pytorch 依赖
  6. 下载 tensorboard 依赖。训练过程中,使用 TensorBoard 可视化
conda create -n hugging python=3.7 
conda activate hugging
conda install -c huggingface transformers
conda install datasets
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
conda install tensorboard
  1. 打开 PyCharm,配置 Interpreter
    依次点击:File -> Settings:
    在这里插入图片描述
    然后选择刚才创建的 conda 环境
    在这里插入图片描述

2 任务及数据集描述

需求说明:有一个视线估计任务,输入为人脸图像,输出为该人脸图像在手机屏幕上的注视点坐标 (x, y)。

数据集的目录结构如下:

\GazeCapture_new
	-- Image
		-- 00002
			-- face
				-- 00000.jpg
				-- 00001.jpg
				-- .....
			-- grid
				-- .....
			-- left
				-- ....
			-- right
				-- .....
		-- 00003
			-- face
				-- .....
			-- grid
				-- .....
			-- left
				-- ....
			-- right
				-- .....
		-- ......
	-- Label
		-- train
			-- 00002.label
			-- .....
		-- test
			-- 03024.label
			-- .....
		-- val
			-- ......

每一个标签文件中的内容,如 00002.label 存储的内容

Face Left Right Grid Xcam, Ycam Xdot, Ydot Device
00002\face\00000.jpg 00002\left\00000.jpg 00002\right\00000.jpg 00002\grid\00000.jpg 1.064,-6.0055 160,284 iPhone6
00002\face\00001.jpg 00002\left\00001.jpg 00002\right\00001.jpg 00002\grid\00001.jpg 1.064,-6.0055 160,284 iPhone6
00002\face\00002.jpg 00002\left\00002.jpg 00002\right\00002.jpg 00002\grid\00002.jpg 1.064,-6.0055 160,284 iPhone6
00002\face\00003.jpg 00002\left\00003.jpg 00002\right\00003.jpg 00002\grid\00003.jpg 1.064,-6.0055 160,284 iPhone6
.......
  • Face 表示脸部图片的存储路径。
  • Left 表示左眼图片的存储路径。
  • Right 表示右眼图片的存储路径。
  • Grid 表示网格图片的存储路径。
  • Xcam, Ycam 是标签,表示人脸图片对应的视线位置的 (x, y) 坐标,单位为厘米。 后续的训练过程使用这两个值作为标签。
  • Xdot, Ydot 表示人脸图片对应的视线位置的 (x, y) 坐标,单位为像素。
  • Device 表示采集设备型号。

如果想要使用我的数据集,先把代码跑通,这里提供我使用的部分数据集作为参考,但由于不是完整的数据集,所以训练效果不是很好,仅供跑通代码作为参考。
https://drive.google.com/file/d/1gM-wzkaEcnw0GEKQ2eedpYlvjuqhp3gA/view?usp=sharing

3. DataSet

!!!注意:Dataset 一定不要完全粘贴我的代码,一定要按照自己的数据集编写对应代码。只有以下几点需要和我一模一样:

  1. 自定义类继承 Dataset,自定义的类名可以自行命名。
  2. 重写 __init____len____getitem__这三个方法,方法内的具体逻辑根据自己的数据集修改。
  3. __getitem__ 方法的返回值形式一定要是 {"labels": xxx, "pixel_values": xxx}
import os.path

from torch.utils.data import Dataset
from transform import transform
import numpy as np

# 读取数据,如果是训练数据,随即打乱数据顺序
def get_label_list(label_path):
    # 存储所有标签文件中的所有内容
    full_lines = []
    # 获取所有标签文件的名称,如 00002.label, 00003.label, ......
    label_names = os.listdir(label_path)
    # 遍历每一个标签文件,并读取其中内容
    for label_name in label_names:
        # 标签文件全路径,如 D:\datasets\GazeCapture_new\Label\train\00002.label
        label_abs_path = os.path.join(label_path, label_name)
        # 读取每一个标签文件中的内容
        with open(label_abs_path) as flist:
            # 存储该标签文件中的所有内容
            full_line = []
            for line in flist:
                full_line.append(line.strip())
            # 移除首行表头 'Face Left Right Grid Xcam, Ycam Xdot, Ydot Device'
            full_line.pop(0)
            full_lines.extend(full_line)
    return full_lines


class GazeCaptureDataset(Dataset):
    def __init__(self, root_path, data_type):
        self.data_dir = root_path
        # 标签文件的根路径,如 D:\datasets\GazeCapture_new\Label\train
        label_root_path = os.path.join(root_path + '/Label', data_type)
        # 获取所有标签文件中的所有内容
        self.full_lines = get_label_list(label_root_path)
        # 每一行内容的分隔符
        self.delimiter = ' '
        # 数据集长度,也就是一共有多少个图片
        self.num_samples = len(self.full_lines)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 标签文件的一行,对应一个训练实例
        line = self.full_lines[idx]
        # 将标签文件中的一行内容按照分隔符进行分割
        Face, Left, Right, Grid, XYcam, XYdot, Device = line.split(self.delimiter)
        # 获取网络的输入:人脸图片
        face_path = os.path.join(self.data_dir + '/Image/', Face)
        # 读取人脸图像
        with open(face_path, 'rb') as f:
            img = f.read()
        # 将人脸图像进行格式转化:缩放、裁剪、标准化
        pixel_values = transform(img)
        # 获取标签值
        labels = np.array(XYcam.split(","), np.float32)
        # 注意返回值的形式一定要是 {"labels": xxx, "pixel_values": xxx}
        result = {"labels": labels}
        result["pixel_values"] = pixel_values
        return result

transform.py 工具类的代码如下:

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import cv2
from PIL import Image


# 定义decode_image函数,将图片转为Numpy格式r
def decode_image(img, to_rgb=True):
    data = np.frombuffer(img, dtype='uint8')
    img = cv2.imdecode(data, 1)
    if to_rgb:
        assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
            img.shape)
        img = img[:, :, ::-1]

    return img


# 定义resize_image函数,对图片大小进行调整
def resize_image(img, size=None, resize_short=None, interpolation=-1):
    interpolation = interpolation if interpolation >= 0 else None
    if resize_short is not None and resize_short > 0:
        resize_short = resize_short
        w = None
        h = None
    elif size is not None:
        resize_short = None
        w = size if type(size) is int else size[0]
        h = size if type(size) is int else size[1]
    else:
        raise ValueError("invalid params for ReisizeImage for '\
            'both 'size' and 'resize_short' are None")

    img_h, img_w = img.shape[:2]
    if resize_short is not None:
        percent = float(resize_short) / min(img_w, img_h)
        w = int(round(img_w * percent))
        h = int(round(img_h * percent))
    else:
        w = w
        h = h
    if interpolation is None:
        return cv2.resize(img, (w, h))
    else:
        return cv2.resize(img, (w, h), interpolation=interpolation)


# 定义crop_image函数,对图片进行裁剪
def crop_image(img, size):
    if type(size) is int:
        size = (size, size)
    else:
        size = size  # (h, w)

    w, h = size
    img_h, img_w = img.shape[:2]
    w_start = (img_w - w) // 2
    h_start = (img_h - h) // 2

    w_end = w_start + w
    h_end = h_start + h
    return img[h_start:h_end, w_start:w_end, :]


# 定义normalize_image函数,对图片进行归一化
def normalize_image(img, scale=None, mean=None, std=None, order= ''):
    if isinstance(scale, str):
        scale = eval(scale)
    scale = np.float32(scale if scale is not None else 1.0 / 255.0)
    mean = mean if mean is not None else [0.485, 0.456, 0.406]
    std = std if std is not None else [0.229, 0.224, 0.225]

    shape = (3, 1, 1) if order == 'chw' else (1, 1, 3)
    mean = np.array(mean).reshape(shape).astype('float32')
    std = np.array(std).reshape(shape).astype('float32')

    if isinstance(img, Image.Image):
        img = np.array(img)
    assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
    # 对图片进行归一化
    return (img.astype('float32') * scale - mean) / std


# 定义to_CHW_image函数,对图片进行通道变换,将原通道为‘hwc’的图像转为‘chw‘
def to_CHW_image(img):
    if isinstance(img, Image.Image):
        img = np.array(img)
    # 对图片进行通道变换
    return img.transpose((2, 0, 1))


# 图像预处理方法汇总
def transform(data, mode='train'):

    # 图像解码
    data = decode_image(data)
    # 图像缩放
    data = resize_image(data, resize_short=224)
    # 图像裁剪
    data = crop_image(data, size=224)
    # 标准化
    data = normalize_image(data, scale=1./255., mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    # 通道变换
    data = to_CHW_image(data)
    return data

4. 训练

from transformers import TrainingArguments
from transformers import DeiTForImageClassification
from torch import nn
from transformers import Trainer
from transformers import DeiTConfig
from dataset import GazeCaptureDataset

# 数据集根路径
root_path = r"D:\datasets\GazeCapture_new"
# 1.定义 Dataset
train_dataset = GazeCaptureDataset(root_path, data_type='train')
val_dataset = GazeCaptureDataset(root_path, data_type='val')

# 2.定义 DeiT 图像模型
'''
num_labels 表示图像的输出值为 2,即 (x, y) 两个坐标值
problem_type="regression" 表示任务是回归任务
'''
configuration = DeiTConfig(num_labels=2, problem_type="regression")
model = DeiTForImageClassification(configuration)


# 3.训练
## 3.1 训练参数
'''
output_dir:模型预测和 checkpoint 的输出目录。
evaluation_strategy 训练过程中采用的验证策略。可能的取值有:
    "no": 训练过程中不验证
    "steps": 在每个 eval_steps 中执行(并记录)验证。
    "epoch": 在每个 epoch 结束时进行验证。
eval_steps=100:每 100 次训练执行一次验证。
per_device_train_batch_size/per_device_eval_batch_size:用于训练/验证的 batch size。
logging_dir:TensorBoard 日志目录。默认为 *output_dir/runs/CURRENT_DATETIME_HOSTNAME*。
logging_steps=50:每隔 50 步写入 TensorBoard
save_strategy 训练期间采用的 checkpoint 保存策略。可能取值为:
    "no": 训练期间不保存 checkpoint
    "epoch": 每个 epoch 结束后保存 checkpoint
    "steps": 每个 save_steps 结束后保存 checkpoint
save_steps=100:每 100 次训练保存一次 checkpoint
'''
training_args = TrainingArguments(output_dir="gaze_trainer",
                                  evaluation_strategy="steps",
                                  eval_steps=100,
                                  per_device_train_batch_size=2,
                                  per_device_eval_batch_size=2,
                                  logging_dir='./logs',
                                  logging_steps=50,
                                  save_strategy="steps",
                                  save_steps=100)
## 3.2 自定义 Trainer
class RegressionTrainer(Trainer):
    # 重写计算 loss 的函数
    def compute_loss(self, model, inputs, return_outputs=False):
        # 获取标签值
        labels = inputs.get("labels")
        # 获取输入值
        x = inputs.get("pixel_values")
        # 模型输出值
        outputs = model(x)
        logits = outputs.get('logits')
        # 定义损失函数为平滑 L1 损失
        loss_fct = nn.SmoothL1Loss()
        # 计算输出值和标签的损失
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

## 3.3 定义Trainer对象:
trainer = RegressionTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset
)

## 3.4 开始训练:
trainer.train()

更多 Trainer 参数参考:https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments

5. 查看 Tensorboard

在这里插入图片描述

在当前工程目录下,打开命令行,执行

(hugging) PS D:\PycharmProjects\hugging> tensorboard --logdir ./logs

然后打开浏览器,访问 http://localhost:6006/ ,即可看到训练过程的 TensorBoard 可视化结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/166121.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

win10录屏文件在哪?如何找到录制后的文件

在工作和学习中,我们会遇到需要使用录屏工具录制电脑屏幕的情况,很多小伙伴在录制完win10电脑屏幕之后,找不到录制的视频文件。win10录屏文件在哪?今天小编教大家如何找到电脑录屏文件和录制win10电脑屏幕的方法,如果您…

带你认识QOwnNotes

导读QOwnNotes 是一款自由而开源的笔记记录和待办事项的应用,可以运行在 Linux、Windows 和 mac 上。这款程序将你的笔记保存为纯文本文件,它支持 Markdown 支持,并与 ownCloud 云服务紧密集成。 QOwnNotes 的亮点就是它集成了 ownCloud 云服…

数据量大也不卡的bi软件有哪些?

用过数据分析软件的都知道,很多的软件在数据量不算特别大的时候还好,分析效率、响应速度都不慢,但一旦使用的数据量超过一定范围,系统就会明显变慢,甚至崩溃。随着企业业务的发展扩张,数据分析的精细化&…

Linksys WRT路由器刷入OpenWrt与原厂固件双固件及切换

Linksys路由器OpenWrt与原厂固件双固件刷入及切换双固件机制使用原厂固件刷其他固件使用原厂固件切换启动分区使用OpenWrt刷入Sysupgrade使用OpenWrt刷入Img使用OpenWrt切换分区通用的硬切换分区(三次重启)双固件机制 新机器默认有一个原厂固件&#xf…

详解分布式系统核心概念——CAP、CP和AP

最近研究Sykwalking,当调研 oap如何进行集群部署时发现:skywalking oap 之间本身不能搭建集群,需要一个集群管理器来组建集群,它支持nacos、zookeeper、Kubernetes、Consul、Etcd 五种集群管理器。我重点比较了nacos和zookeeper&a…

python中的闭包和装饰器

目录 一.闭包 1.闭包的用途和用法 简单闭包 2.nonlocal关键字的作用 ATM闭包实现 注意事项 小结 二.装饰器 装饰器的一般写法(闭包写法) 装饰器的语法糖写法 一.闭包 1.闭包的用途和用法 先看如下代码: 通过全局变量account_amount来…

【Python学习】条件和循环

前言 往期文章 【Python学习】列表和元组 【Python学习】字典和集合 条件控制 简单来说:当判断的条件为真时,执行某种代码逻辑,这就是条件控制。 那么在讲条件控制之前,可以给大家讲一个程序员当中流传的比较真实的一个例子…

CUDA规约算法(加和)

1.block内相邻元素规约(线程不连续) 上图为1个block内的16个线程的操作示意: 第0个线程会和第1,2,4,8发生关系 第2个线程会和第3个线程发生关系 第4个线程会和第5,6个线程发生关系 ... 以上…

这7个网络设备配置接口基本参数要牢记,从此接口相关配置不用怕!

本文给大家介绍网络设备配置接口基本参数,包括接口描述信息、接口流量统计时间间隔功能以及开启或关闭接口。 进入接口视图 背景信息 对接口进行基本配置前,需要进入接口视图。 操作步骤 执行命令system-view,进入系统视图。执行命令inte…

Widget小组件

目录 技能点 Widget背调 a. 设计定位 b. Widget小组件限制 c. Widget小组件 开发须知 d. 什么是 SwiftUI App Group 数据共享 a. 配置 App Groups 1、开发者账号配置,并更新pp证书 2、Xcode配置 b. 缓存数据共享-代码实现 1、文件存储 2. 沙盒存储&…

【MySQL】运算符及相关函数详解

序号系列文章3【MySQL】MySQL基本数据类型4【MySQL】MySQL表的七大约束5【MySQL】字符集与校对集详解6【MySQL】MySQL单表操作详解文章目录前言MySQL运算符1,算术运算符1.1,算术运算符的基本使用1.2,常用数学函数的基本使用2,比较…

vulnhub DC系列 DC-7

总结:社工尝试 目录 下载地址 漏洞分析 信息收集 ssh webshell 命令执行 提权 下载地址 DC-7.zip (Size: 939 MB)Download: http://www.five86.com/downloads/DC-7.zipDownload (Mirror): https://download.vulnhub.com/dc/DC-7.zip漏洞分析 信息收集 这里还…

代码随想录算法训练营第13天 239.滑动窗口最大值、347. 前 K 个高频元素

代码随想录算法训练营第13天 239.滑动窗口最大值、347. 前 K 个高频元素 滑动窗口最大值 力扣题目链接(opens new window) 给定一个数组 nums,有一个大小为 k 的滑动窗口从数组的最左侧移动到数组的最右侧。你只可以看到在滑动窗口内的 k 个数字。滑动窗口每次只…

YonBuilder 应用构建教程之移动端扩展

YonBuilder 移动端扩展 在上一篇文章中,我们通过对员工信息实体的移动端页面构建来对 YonBuilder 移动端配置的基础流程进行了简单的介绍,本篇文章则通过之前搭建的出入库实体来进行扩展,主要介绍如何在移动端中添加跳转页面的功能以及通过函…

大连理工大学(开发区校区)2023年新生赛(验题人题解)

难度分布 根据排行榜情况,大致分布如下: Easy:AIDE Middle:CJF Hard:GBH 题解 A. Hello World.(题意实现) 直接输出Hello world. I. lgl想签到(题意实现) 统计周…

组件优化 - 多project方案

背景 经销商项目目前是混合项目,有oc、swift、flutter,并对应各自的一些三方库,并随着需求的增加,项目代码体积也越来越大,编译速度也相应的慢了很多,这也严重影响了开发速度,故目前的期望是可…

Linux:git工具

文章目录一.git的下载二.如何使用git将代码传到远端仓库2.1在gitee上新建一个仓库2.2克隆仓库到本地git clone2.3将文件添加到本地仓库git add2.4将代码提交到本地仓库git commit -m2.5将本地仓库的内容传到远端仓库中git push三.git的一些其它使用3.1git log查看日志3.2git rm…

【魅力开源】第5集:通过Odoo实现将EXCEL表费用明细,快速导入到ERP总账系统生成凭证

文章目录前言一、拿到这样的一张表二、实现过程1. 控制器(Controller)2. 模型(Model)3. 视图(View)4. 返回生成的凭证号最后前言 这是一个小功能。 财务小姐姐每个月需要不少的时间去手录费用凭证,这个功能可以实现将半天一天时间内完成的事情,在1小时内…

204:vue+openlayers 学习Attribution各种API,示例展示自定义版权信息

第204个 点击查看专栏目录 本示例的目的是介绍如何在vue+openlayers项目中个性化修改版权信息,这里主要涉及到Attribution各种属性的设置,所以这里先列出属性的信息,然后用示例来展示如何使用。 名称类型说明classNamestring (默认为“ol-attribution”)CSS 类名。targetH…

Acwing---1219.移动距离

移动距离1.题目2.基本思想3.代码实现1.题目 X星球居民小区的楼房全是一样的,并且按矩阵样式排列。 其楼房的编号为 1,2,3… 当排满一行时,从下一行相邻的楼往反方向排号。 比如:当小区排号宽度为 6 时,开始情形如下&#xff1a…