基于Optuna的transformers模型自动调参

news2025/1/18 13:56:12

文章目录

  • 一、导入相关包
  • 二、加载数据集
  • 三、划分数据集
  • 四、数据集预处理
  • 五、创建模型(区别一)
  • 六、创建评估函数
  • 七、创建 TrainingArguments(区别二)
  • 八、创建 Trainer(区别三)
  • 九、模型训练
  • 十、模型训练(自动搜索)(区别四)
  • 启动 tensorboard

  • 以文本分类为例

六、Trainer和文本分类
image.png
image.png


一、导入相关包

!pip install transformers datasets evaluate accelerate
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

二、加载数据集

dataset = load_dataset("csv", data_files="./ChnSentiCorp_htl_all.csv", split="train")
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset
'''
Dataset({
    features: ['label', 'review'],
    num_rows: 7765
})
'''

三、划分数据集

datasets = dataset.train_test_split(test_size=0.1)
datasets
'''
DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 777
    })
})
'''

四、数据集预处理

import torch

tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")

def process_function(examples):
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    tokenized_examples["labels"] = examples["label"]
    return tokenized_examples

tokenized_datasets = datasets.map(process_function, batched=True, 
                                  remove_columns=datasets["train"].column_names)
tokenized_datasets
'''
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 6988
    })
    test: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 777
    })
})
'''

五、创建模型(区别一)

def model_init():
    model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")
    return model

六、创建评估函数

import evaluate

acc_metric = evaluate.load("accuracy")
f1_metirc = evaluate.load("f1")
def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metirc.compute(predictions=predictions, references=labels)
    acc.update(f1)
    return acc

七、创建 TrainingArguments(区别二)

  • logging_steps=500为了防止多次训练 log 太多可以增大 logging_steps
train_args = TrainingArguments(output_dir="./checkpoints",      # 输出文件夹
                               per_device_train_batch_size=64,  # 训练时的batch_size
                               per_device_eval_batch_size=128,  # 验证时的batch_size
                               logging_steps=500,               # log 打印的频率
                               evaluation_strategy="epoch",     # 评估策略
                               save_strategy="epoch",           # 保存策略
                               save_total_limit=3,              # 最大保存数
                               learning_rate=2e-5,              # 学习率
                               weight_decay=0.01,               # weight_decay
                               metric_for_best_model="f1",      # 设定评估指标
                               load_best_model_at_end=True)     # 训练完成后加载最优模型

八、创建 Trainer(区别三)

  • 没有指定 model而是指定 model_init
from transformers import DataCollatorWithPadding
trainer = Trainer(model_init=model_init, 
                  args=train_args, 
                  train_dataset=tokenized_datasets["train"], 
                  eval_dataset=tokenized_datasets["test"], 
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)


# 之前
from transformers import DataCollatorWithPadding
trainer = Trainer(model=model,
                  args=train_args,
                  train_dataset=tokenized_datasets["train"],
                  eval_dataset=tokenized_datasets["test"],
                  data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
                  compute_metrics=eval_metric)

九、模型训练

trainer.train()

十、模型训练(自动搜索)(区别四)

!pip install optuna
  • 使用默认的超参数空间
  • compute_objective=lambda x: x["eval_f1"]中的 x是指的评价函数的返回值,在这里因为没有显示的指定评价函数返回值的 key,所以 f1key采用默认值 eval_f1
trainer.hyperparameter_search(compute_objective=lambda x: x["eval_f1"], direction="maximize", n_trials=10)
  • 自定义超参数空间
    • 可以在default_hp_space_optuna 函数中增加 trainer 的选项
def default_hp_space_optuna(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 1, 5),
        "seed": trial.suggest_int("seed", 1, 40),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [4, 8, 16, 32, 64]),
        "optim": trial.suggest_categorical("optim", ["sgd", "adamw_hf"]),
    }

trainer.hyperparameter_search(hp_space=default_hp_space_optuna, compute_objective=lambda x: x["eval_f1"], direction="maximize", n_trials=10)

启动 tensorboard

  • 进入运行日志文件夹
    • 终端启动
!tensorboard --logdir runs
  • jupyter 启动
# 运行这行代码将加载 TensorBoard并允许我们将其用于可视化
%reload_ext tensorboard 
%tensorboard --logdir=./runs/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1215669.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

使用 Redis 构建轻量的向量数据库应用:图片搜索引擎(一)

本篇文章聊聊更轻量的向量数据库方案:Redis。 以及基于 Redis 来快速实现一个高性能的本地图片搜索引擎,在本地环境中,使用最慢的稠密向量检索方式来在一张万图片中查找你想要的图片,总花费时间都不到十分之一秒。 写在前面 接着…

C++进阶-STL 常用算法列举

STL 常用算法列举 概述常用遍历算法for_each 遍历容器transfrom 搬运容器到另一个容器中 常用查找函数find 查找元素find_if 按条件查找元素adjacent_find 查找相邻重复元素binary_search 二分查找法count 统计元素个数count_if 按条件统计元素个数 常用排序算法sort 对元素内内…

Linux C/C++全栈开发知识图谱(后端/音视频/游戏/嵌入式/高性能网络/存储/基础架构/安全)

众所周知,在所有的编程语言中,C语言是一门颇具学习难度,需要很长学习周期的编程语言。甚至很多人经常听到一句调侃的话语——“C,从入门到放弃”。 C界的知名书籍特别多,从简单到高端书籍,许多书籍都是C之…

8086与8088

一、8086与8088概述 8088/8086都是16位微处理器,内部运算器和寄存器都是16位的,同样具有20位地址线8088/8086都是由执行单元(EU)和总线接口部件(BIU)两大部分构成指令系统和寻址能力都相同,两种CPU是兼容的8088被称作准十六位的、是紧继8086…

为什么红帽Linux值得学习?红帽Linux是什么

为什么红帽Linux值得学习 物联网、云计算、大数据,或许你都耳熟能详,但是Linux可能你却感觉到有点陌生。这些未来趋势的行业使用的嵌入式、c、JAVA、PHP等底层应用软件都是在Linux操作系统上Linux运维工程师作为移动互联网的关键支撑岗位,缺口…

Redis内存淘汰机制

Redis内存淘汰机制 引言 Redis 启动会加载一个配置&#xff1a; maxmemory <byte> //内存上限 默认值为 0 (window版的限制为100M)&#xff0c;表示默认设置Redis内存上限。但是真实开发还是需要提前评估key的体量&#xff0c;提前设置好内容上限。 此时思考一个问题…

避坑指南!!在树莓派4b上安装Pycharm以及无法使用终端的问题解决!!

一、下载Pycharm–linux安装包 这里是踩的第一个坑&#xff0c;一开始我下载的是pycharm2023-linux版本的。后面发现缺少很多东西&#xff0c;安装不成功。后面下载了低版本的Pycharm才可以。我下载的是2020版本的。注意&#xff1a;在下载安装包时&#xff0c;一定要在window…

小米手机获取电池健康度

目录 方法一&#xff1a;使用Bug反馈功能 1. 打开拨号界面&#xff0c;输入*#*#284#*#* 2. 导出结果&#xff0c;等待即可 3. 找到这个压缩文件 4. 解压缩【我这里直接拷贝到电脑中操作&#xff0c;手机同理】 4.1 解压&#xff1a; 4.2 将得到的新的压缩文档解压 5. 打…

Stages—研发过程可视化建模和管理平台

产品概述 Stages是美国UL Solutions旗下UL Method Park GmbH的产品&#xff0c;用于帮助企业定义、管理、发布、控制、优化其研发过程&#xff0c;同时使其研发过程符合CMMI、ASPICE、ISO26262等标准。Stages的核心理念是把过程理论和实际项目进行有机结合。Stages聚焦于研发过…

web环境实现一键式安装启动

部署的痛点 一般在客户环境安装web环境&#xff0c;少说需要花费1-2小时。一般需要安装jdk、nginx、mysql、redis等 等你接触到了inno setup &#xff0c;你有可能会节约更少的时间去部署。也有可能是一个不懂技术的人&#xff0c;都可以进行操作的。废话不多说&#xff0c;接…

JS原生-弹框+阿里巴巴矢量图

效果&#xff1a; 代码&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content&q…

LeetCode - 160. 相交链表(C语言,配图)

思路&#xff1a; 1. 我们算出两个链表的长度lenA&#xff0c;lenB。我们在这里最后判断一下&#xff0c;两个链表的尾节点是否是一样的&#xff0c;如果是相交链表&#xff0c;那它们的尾节点一定是一样的。 2. 算出长链表和短链表的差距n&#xff08;n | lenA- lenB |&#…

【分享课】11月16日晚19:30PostgreSQL分享课:PG缓存管理器主题

PostsreSQL分享课分享主题: PG缓存管理器主题 直播分享平台&#xff1a;云贝教育视频号 时间&#xff1a;11月16日 周四晚 19: 30 分享内容: 缓冲区管理器结构 缓冲区管理器的工作原理 环形缓冲区 脏页的刷新

《洛谷深入浅出基础篇》 p3370字符串哈希——hash表

上链接&#xff1a; P3370 【模板】字符串哈希 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)https://www.luogu.com.cn/problem/P3370#submit上题干&#xff1a; 就是说&#xff0c;给你N个字符串&#xff0c;然后让你判断&#xff0c;这N个字符串里面有多少不同的字符串&am…

谭巍主任科普:面对HPV感染挑战,迈出关键一步!

在当今社会&#xff0c;HPV(人乳头瘤病毒)感染的问题日益凸显&#xff0c;它不仅限于性传播疾病&#xff0c;更上升为一种公共卫生问题。据统计&#xff0c;全球范围内每年有大量人群感染HPV&#xff0c;其中部分地区感染率高达50%。正因如此&#xff0c;谭巍主任投身于科普事业…

表单校验wed第十九章

常见的表单验证 一。表单选择器 属性过滤选择器 &#xff1a;selected 选中所有的下拉元素 &#xff1a;checked&#xff1a;选项元素 &#xff1a;disabled 不可用元素 &#xff1a;enable 所有可用元素 二。字符串演示 1.判断非空 isNaN(j) :判断是否为数字 2.表…

flutter绘制弧形进度条

绘制: import package:jade/utils/JadeColors.dart; import package:flutter/material.dart; import dart:math as math; import package:flutter_screenutil/flutter_screenutil.dart;class ArcProgressBar extends StatefulWidget{const ArcProgressBar({Key key}) : super…

蘑菇街获得mogujie商品详情 API 返回值说明

速卖通API接口是速卖通平台提供的一种数据交换接口&#xff0c;可以帮助卖家快速获取平台上的商品信息、订单信息、用户信息等数据&#xff0c;以便在自己的应用程序中进行展示、管理或分析。 速卖通API接口可以通过以下步骤进行使用&#xff1a; 注册速卖通账号并获取API密钥…

SystemVerilog学习 (7)——面向对象编程

一、概述 对结构化编程语言,例如Verilog和C语言来讲&#xff0c;它们的数据结构和使用这些数据结构的代码之间存在很大的沟壑。数据声明、数据类型与操作这些数据的算法经常放在不同的文件里,因此造成了对程序理解的困难。 Verilog程序员的境遇比C程序员更加棘手,因为 Verilog …

React父组件怎么调用子组件的方法

调用方法&#xff1a;1、类组件中的调用可以利用React.createRef()、ref的函数式声明或props自定义onRef属性来实现&#xff1b;2、函数组件、Hook组件中的调用可以利用useImperativeHandle或forwardRef抛出子组件ref来实现。 【程序员必备开发工具推荐】Apifox一款免费API管理…