使用huggingface微调预训练模型

news2024/9/24 7:22:05

官方教程:https://huggingface.co/docs/transformers/training

准备数据集(基于datasets库)

train.json 数据格式:

{"source":"你是谁?", "target":"我是恁爹"}
{"source":"你多大了?", "target":"我五百岁了哦!"}
{"source": "我想吃火锅。", "target": "好的,我帮您找附近的火锅店。"}
{"source": "明天天气如何?", "target": "明天有雨,记得带伞。"}

加载数据集:

from datasets import load_dataset

data_files = {"train": "train.json", "validation": "validation.json"}
dataset = load_dataset('./data', data_files=data_files)

Tokenize 数据集:

def tokenize_function(example):
    encoded = tokenizer(example["source"], truncation=True, padding="max_length", max_length=128)
    # seq2seq模型需:with tokenizer.as_target_tokenizer():
    encoded["labels"] = tokenizer(example["target"], truncation=True, padding="max_length", max_length=128)["input_ids"]
    return encoded

# batched=True 可批量处理数据
tokenized_dataset = dataset.map(encode_data, batched=True)
print(tokenized_dataset.column_names)

{‘train’: [‘source’, ‘target’, ‘input_ids’, ‘attention_mask’, ‘labels’], ‘validation’: [‘source’, ‘target’, ‘input_ids’, ‘attention_mask’, ‘labels’]}

字段更名(仅供学习):

tokenized_dataset = tokenized_dataset.rename_column("source","sources")
tokenized_dataset

DatasetDict({
train: Dataset({
features: [‘sources’, ‘target’, ‘input_ids’, ‘attention_mask’, ‘labels’],
num_rows: 28
})
validation: Dataset({
features: [‘sources’, ‘target’, ‘input_ids’, ‘attention_mask’, ‘labels’],
num_rows: 28
})
})

移除某列(仅供学习):

tokenized_dataset = tokenized_dataset.remove_columns(["sources","target"])
tokenized_dataset

DatasetDict({
train: Dataset({
features: [‘input_ids’, ‘attention_mask’, ‘labels’],
num_rows: 28
})
validation: Dataset({
features: [‘input_ids’, ‘attention_mask’, ‘labels’],
num_rows: 28
})
})

随机打乱并取子集(仅供学习):

small_train_dataset = tokenized_dataset["train"].shuffle(seed=42).select(range(3))
small_eval_dataset = tokenized_dataset["validation"].shuffle(seed=42).select(range(3))
small_train_dataset

Dataset({
features: [‘input_ids’, ‘attention_mask’, ‘labels’],
num_rows: 3
})

DataLoader 定义(data_collator)

类似 torch.utils.data.DataLoadercollate_fn,用来处理训练集、验证集。官方提供了下面这些 Collator:
在这里插入图片描述

上一小节 tokenize_function 函数的作用是将原始数据集中的每个样本编码为模型可接受的输入格式,包括对输入和标签的分词、截断和填充等操作,最终返回一个包含 input_idslabels 的字典。它主要用于 dataset.map 函数中对数据集进行转换。

DataCollatorForLanguageModeling 类的作用是将一批样本组合成一个训练用的 mini-batch,它会将每个样本的 input_ids 合并成一个大的矩阵,每个样本的 attention_mask 合并成一个大的矩阵,每个样本的 labels 合并成一个大的向量。它还会计算出训练时需要用到的特殊的 mask 矩阵,用于在训练时计算 loss。在使用 Trainer 训练模型时,DataCollatorForLanguageModeling 类需要传入到 Trainerdata_collator 参数中,它会在每个训练步骤中将训练数据组合成 mini-batch,并且计算 loss 和其他统计信息。

在这里插入图片描述

模型训练(基于 Trainer)

在这里插入图片描述

通过 TrainingArguments 设置超参:

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir='./results/train_xxxx',  # 保存模型和日志的目录
    num_train_epochs=10,  # 训练轮数
    per_device_train_batch_size=4,  # 训练时每个 GPU 上的 batch size
    per_device_eval_batch_size=4,  # 验证时每个 GPU 上的 batch size
    warmup_steps=50,  # 学习率 warmup 步数
    learning_rate=2e-5,  # 初始学习率
    logging_dir='./logs',  # 日志保存目录
    logging_steps=50,  # 每隔多少步打印一次训练日志
    evaluation_strategy='epoch',  # 在哪些时间步骤上评估性能:'no', 'steps', 'epoch'
    save_total_limit=3,  # 保存的模型数量上限
    save_strategy='epoch', # 模型保存策略,'steps':每隔多少步保存一次,'epoch':每个epoch保存一次
    gradient_accumulation_steps=2,  # 每多少个 batch 合并为一个,等于期望的 batch size / per_device_train_batch_size
)

通过 Trainer 加载训练器

from transformers import Trainer

# 定义 Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer
)

# 开始训练
trainer.train()

但上面只能根据验证集的 loss 评估,如果需要针对任务设置指标,则参考如下:
在这里插入图片描述
因此,需要定义一个 compute_metrics 方法,用于计算任务指标(可以用 evaluate 库),并传给 Trainer。
其中 predictions 是通过 trainer
这样,训练时就会输出 compute_metrics 中自定义的指标了:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/376596.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FSP:Flow of Solution Procedure (CVPR 2017) 原理与代码解析

paper:A Gift From Knowledge Distillation: Fast Optimization, Network Minimization and Transfer Learningcode:https://github.com/HobbitLong/RepDistiller/blob/master/distiller_zoo/FSP.py背景深度神经网络DNN逐层生成特征。更高层的特征更接近…

内存数据库的设计与实现(已在大型项目中应用)

一、概况 1、设计总图 组成,由Redis集群缓存,普通缓存,传统数据库,各类数据驱动 2、内存数据库的增删改查,分页查询 组成,由数据查询,分页查询,数据存储,数据修改,数据删除 3、内存数据库的驱动 组成,由驱动适配器,普通缓存驱动,Redis缓存驱动 4、内存数据库与…

C++常见类型及占用内存表

GPS生产厂家在定义数据的时候都会有一定的数据类型,例如double、int、float等,我们知道它们在内存中都对应了一定的字节大小,而我在实际使用时涉及到了端序的问题(大端序高字节在前,小端序低字节在前)&…

redis主从同步:如何实现数据一致

Redis 提供了主从库模式,以保证数据副本的一致,主从库之间采用的是读写分离的方式。读操作:主库、从库都可以接收;写操作:首先到主库执行,然后,主库将写操作同步给从库。和mysql差不多。但是同步…

自动驾驶专题介绍 ———— 毫米波雷达

文章目录介绍工作原理特点性能参数应用厂家介绍 毫米波雷达是工作在毫米波波段探测的雷达,与普通雷达相似,是通过发射无线电信号并接收反射信号来测量物体间的距离。毫米波雷达工作频率为30~300GHz(波长为1 - 10mm),波长介于厘米波和光波之间…

【数据挖掘实战】——家用电器用户行为分析及事件识别(BP神经网络)

项目地址:Datamining_project: 数据挖掘实战项目代码 目录 一、背景和挖掘目标 1、问题背景 2、原始数据 3、挖掘目标 二、分析方法与过程 1、初步分析 2、总体流程 第一步:数据抽取 第二步:探索分析 第三步:数据的预处…

为什么负责任的技术始于数据治理

每个组织都处理数据,但并非每个组织都将其数据用作业务资产。但是,随着数据继续呈指数级增长,将数据视为业务资产正在成为竞争优势。 埃森哲的一项研究发现,只有 33% 的公司“足够信任他们的数据,能够有效地使用它并从…

色环电阻的阻值如何识别

这种是色环电阻,其外表有一圈圈不同颜色的色环,现在在一些电器和电源电路中还有使用。下面的两种色环电阻它颜色还不一样,一个蓝色,一个土黄色,其实这个蓝色的属于金属膜色环电阻,外表涂的是一层金属膜&…

Qt新手入门指南 - 如何创建模型/视图(四)

每个UI开发人员都应该了解ModelView编程,本教程的目标是为大家提供一个简单易懂的介绍。Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写,所有平台无差别运行,更提供了几乎所有开发过程中需要用到的工具。如今&#xff…

AJAX介绍及其应用

1.1 AJAX 简介 AJAX全称为 Asynchronous JavaScript and XML ,就是异步的js和xml。通过AJAX可以在浏览器中向服务器发送异步请求,最大的优势,无刷新获取数据。AJAX不是新的编程语言,而是一种现有的标准组合再一起使用的新方式 应…

scanpy 单细胞分析API接口使用案例

参考:https://zhuanlan.zhihu.com/p/537206999 https://scanpy.readthedocs.io/en/stable/api.html scanpy python包主要分四个模块: 1)read 读写模块、 https://scanpy.readthedocs.io/en/stable/api.html#reading 2)pp Prepr…

springBoot自动装配原理探究springBoot配置类Thymeleaf模板引擎

微服务 微服务是一种架构风格,由于单体架构不利于团队协作完成并且代码量较大,后期维护成本较高,逐渐有了微服务架构。微服务是将一个项目拆分成不同的服务,各个服务之间相互独立互不影响,互相通过轻量级机制通信比如…

(转载)STM32与LAN9252构建EtherCAT从站

目录 (一):项目简介 EtherCAT及项目简述 LAN9252工作模式 整体开发流程 移植要处理的问题 代码层面的工作 开发中使用的工具 (二):SSC的使用 SSC简介和下载 SSC构建协议栈文件和XML &#xff08…

爬虫数据解析-正则表达式

数据解析-正则表达式 正则表达式 正则编写规则简介 字符含义.匹配除换行符以外的任意字符|A|B表示:匹配正则表达式条件A或B^匹配字符串的开始(在集合[]里表示"非")的意思$匹配字符串的结束{n}重复n次{,n}重复小于n次{n,}重复n次或更多次{n,…

2023软件测试金三银四常见的软件测试面试题-【抓包和网络协议篇】

八、抓包与网络协议 8.1 抓包工具怎么用 我原来的公司对于抓包这块,在App的测试用得比较多。我们会使用fiddler抓取数据检查结果,定位问题,测试安全,制造弱网环境; 如:抓取数据通过查看请求数据,请求行&…

经验 // 指标异常了怎么办?

本文参考了数据万花筒的文章,结合我自己工作经验。希望给大家一些帮助。 指标异常排查,是数据分析师的工作重点之一,是各行各业数据分析师都绕不开的话题。 本文试图回答: 1、指标波动的影响因素有哪些? 2、如何快速…

Web3中文|泰勒·斯威夫特演唱会票务闹乌龙,NFT票务急需普及

2022年底,美国艺人Taylor Swift(泰勒斯威夫特)的2023年巡回演唱会Eras Tour门票开始出票。作为当今世界最受欢迎的流行歌手之一,四年多没举办大型巡演无疑积攒了大量的粉丝需求。但是在2022年11月15日开放预售的当天,售…

数据驱动下的物种保护,拯救生命的“特效药”

如果给出这样      一张猎豹的图片      我们能否通过图中有限的信息      判断它的年龄、健康状况      以及所属族群?      如果你是一名研究动物的专家,你可能会从其花纹和斑点中获取一定量的信息,但对于大多数人以及一线的动物保护者来说,它可能只是一…

imx6ull——I2C驱动

I2C基本介绍 SCL 为高电平,SDA 出现下降沿:起始位 SCL 位高电平,SDA出现上升沿:停止位 主机——从机地址(ack)——寄存器地址(ack)——数据(ack) 重点:先是写&#xff0c…

context.Context

context.Context前言一、为什么要context二、context有什么用三、基本数据结构3.1、context包的整体工作机制3.2 基本接口和结构体3.3 API函数3.4 辅助函数3.5 context用法3.6 使用 context 传递数据的争议总结参考资料前言 context是go语言的一个并发包,一个标准库…