Transformers x SwanLab:可视化NLP模型训练(2025最新版)

news2025/3/24 22:19:01

HuggingFace 的 Transformers 是目前最流行的深度学习训框架之一(100k+ Star),现在主流的大语言模型(LLaMa系列、Qwen系列、ChatGLM系列等)、自然语言处理模型(Bert系列)等,都在使用Transformers来进行预训练、微调和推理。

在这里插入图片描述

SwanLab 是一个深度学习实验管理与训练可视化工具,融合了Weights & Biases与Tensorboard的特点,能够方便地进行 训练可视化、多实验对比、超参数记录、大型实验管理和团队协作,并支持用网页链接的方式分享你的实验。

你可以使用Transformers快速进行模型训练,同时使用SwanLab进行实验跟踪与可视化。

transformers>=4.50.0 的版本,已官方集成了SwanLab
如果你的版本低于4.50.0,请使用SwanLabCallback集成。

1. 一行代码完成集成

只需要在你的训练代码中,找到TrainingArguments部分,添加report_to="swanlab"参数,即可完成集成。

from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    ...,
    report_to="swanlab" # [!code ++]
)

trainer = Trainer(..., args=args)

2. 自定义项目名

默认下,项目名会使用你运行代码的目录名

如果你想自定义项目名,可以设置SWANLAB_PROJECT环境变量:

::: code-group

import os
os.environ["SWANLAB_PROJECT"]="qwen2-sft"
export SWANLAB_PROJECT="qwen2-sft"
set SWANLAB_PROJECT="qwen2-sft"

:::

3. 案例代码:Bert文本分类

import evaluate
import numpy as np
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)


dataset = load_dataset("yelp_review_full")

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

metric = evaluate.load("accuracy")

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

training_args = TrainingArguments(
    output_dir="test_trainer",
    num_train_epochs=3,
    logging_steps=50,
    report_to="swanlab", # [!code ++]
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

4. SwanLabCallback集成

如果你使用的是Transformers<4.50.0的版本,或者你希望更灵活地控制SwanLab的行为,则可以使用SwanLabCallback集成。

4.1 引入SwanLabCallback

from swanlab.integration.transformers import SwanLabCallback

SwanLabCallback是适配于Transformers的日志记录类。

SwanLabCallback可以定义的参数有:

  • project、experiment_name、description 等与 swanlab.init 效果一致的参数, 用于SwanLab项目的初始化。
  • 你也可以在外部通过swanlab.init创建项目,集成会将实验记录到你在外部创建的项目中。

4.2 传入Trainer

from swanlab.integration.transformers import SwanLabCallback
from transformers import Trainer, TrainingArguments

...

# 实例化SwanLabCallback
swanlab_callback = SwanLabCallback(project="hf-visualization")

trainer = Trainer(
    ...
    # 传入callbacks参数
    callbacks=[swanlab_callback],
)

trainer.train()

4.3 完整案例代码

import evaluate
import numpy as np
import swanlab
from swanlab.integration.transformers import SwanLabCallback
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)


dataset = load_dataset("yelp_review_full")

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

tokenized_datasets = dataset.map(tokenize_function, batched=True)

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

metric = evaluate.load("accuracy")

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

training_args = TrainingArguments(
    output_dir="test_trainer",
    # 如果只需要用SwanLab跟踪实验,则将report_to参数设置为”none“
    report_to="none",
    num_train_epochs=3,
    logging_steps=50,
)

# 实例化SwanLabCallback
swanlab_callback = SwanLabCallback(experiment_name="TransformersTest")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
    # 传入callbacks参数
    callbacks=[swanlab_callback],
)

trainer.train()

4.4 GUI效果展示

超参数自动记录:

在这里插入图片描述

指标记录:

在这里插入图片描述

4.5 拓展:增加更多回调

试想一个场景,你希望在每个epoch结束时,让模型推理测试样例,并用swanlab记录推理的结果,那么你可以创建一个继承自SwanLabCallback的新类,增加或重构生命周期函数。比如:

class NLPSwanLabCallback(SwanLabCallback):    
    def on_epoch_end(self, args, state, control, **kwargs):
        test_text_list = ["example1", "example2"]
        log_text_list = []
        for text in test_text_list:
            result = model(text)
            log_text_list.append(swanlab.Text(result))
            
        swanlab.log({"Prediction": test_text_list}, step=state.global_step)

上面是一个在NLP任务下的新回调类,增加了on_epoch_end函数,它会在transformers训练的每个epoch结束时执行。

查看全部的Transformers生命周期回调函数:链接

5. 环境变量

参考:HuggingFace Docs: transformers.integrations.SwanLabCallback

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2320390.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

VSCode 抽风之 两个conda环境同时在被激活

出现了神奇的(toolsZCH)(base) 提示符&#xff0c;如下图所示&#xff1a; 原因大概是&#xff1a;conda 环境的双重激活&#xff1a;可能是 conda 环境没有被正确清理或初始化&#xff0c;导致 base 和 toolsZCH 同时被激活。 解决办法就是 &#xff1a;conda deactivate 两次…

Mybatis的基础操作——03

写mybatis代码的方法有两种&#xff1a; 注解xml方式 本篇就介绍XML的方式 使用XML来配置映射语句能够实现复杂的SQL功能&#xff0c;也就是将sql语句写到XML配置文件中。 目录 一、配置XML文件的路径&#xff0c;在resources/mapper 的目录下 二、写持久层代码 1.添加mappe…

React:React主流组件库对比

1、Material-UI | 官网 | GitHub | GitHub Star: 94.8k Material-UI 是一个实现了 Google Material Design 规范的 React 组件库。 Material UI 包含了大量预构建的 Material Design 组件&#xff0c;覆盖导航、滑块、下拉菜单等各种常用组件&#xff0c;并都提供了高度的可定制…

python每日十题(6)

】函数定义&#xff1a;函数是指一组语句的集合通过一个名字&#xff08;函数名&#xff09;封装起来&#xff0c;要想执行这个函数&#xff0c;只需要调用其函数名即可。函数能提高应用的模块性和代码的重复利用率 在Python语言中&#xff0c;用关键字class来定义类 在Python语…

1.Go - Hello World

1.安装Go依赖 https://go.dev/dl/ 根据操作系统选择适合的依赖&#xff0c;比如windows&#xff1a; 2.配置环境变量 右键此电脑 - 属性 - 环境变量 PS&#xff1a; GOROOT&#xff1a;Go依赖路径&#xff1b; GOPATH&#xff1a;Go项目路径&#xff1b; …

优先队列 priority_queue详解

说到&#xff0c;priority_queue优先队列。必须先要了解啥是堆与运算符重载(我在下方有解释)。 否则只知皮毛&#xff0c;极易忘记寸步难行。 但在开头&#xff0c;还是简单的说下怎么用 首先&#xff0c;你需要调用 #include <queue> 在main函数中&#xff0c;声明…

《信息系统安全》(第一次上机实验报告)

实验一 &#xff1a;网络协议分析工具Wireshark 一 实验目的 学习使用网络协议分析工具Wireshark的方法&#xff0c;并用它来分析一些协议。 二实验原理 TCP/IP协议族中网络层、传输层、应用层相关重要协议原理。网络协议分析工具Wireshark的工作原理和基本使用规则。 三 实…

简要分析IPPROTO_TCP参数

IPPROTO_TCP是操作系统或网络编程中定义的一个 协议号常量&#xff0c;用于标识 传输控制协议&#xff08;TCP&#xff09;。其核心作用是 在传输层指定使用TCP协议&#xff0c;确保数据通过TCP的可靠传输机制进行通信。 一、定义与值 头文件&#xff1a;定义在<netinet/in.…

JavaScript与客户端开发

1、简介 简单的讲&#xff0c;JavaScript是一种脚本语言&#xff0c;为网站提供了一种在客户端运行程序的手段&#xff0c;通过它可以实现客户端数据验证、网页特效等功能。 JavaScript是一种基于对象和事件驱动&#xff08;不懂啥意思&#xff0c;暂不管它&#xff09;&…

基于CNN的FashionMNIST数据集识别5——GoogleNet模型

源码 import torch from torch import nn from torchsummary import summaryclass Inception(nn.Module):def __init__(self, in_channels, c1, c2, c3, c4):super().__init__()self.ReLu nn.ReLU()#路径1self.p1_1 nn.Conv2d(in_channelsin_channels, out_channelsc1, kern…

JVM垃圾回收笔记01-垃圾回收算法

文章目录 前言1. 如何判断对象可以回收1.1 引用计数法1.2 可达性分析算法查看根对象哪些对象可以作为 GC Root ?对象可以被回收&#xff0c;就代表一定会被回收吗&#xff1f; 1.3 引用类型1.强引用&#xff08;StrongReference&#xff09;2.软引用&#xff08;SoftReference…

【初探数据结构】树与二叉树

&#x1f4ac; 欢迎讨论&#xff1a;在阅读过程中有任何疑问&#xff0c;欢迎在评论区留言&#xff0c;我们一起交流学习&#xff01; &#x1f44d; 点赞、收藏与分享&#xff1a;如果你觉得这篇文章对你有帮助&#xff0c;记得点赞、收藏&#xff0c;并分享给更多对数据结构感…

蓝桥杯备考:二分答案之路标设置

最大距离&#xff0c;找最小空旷指数值&#xff0c;我们是很容易想到用二分的&#xff0c;我们再看看这个答案有没有二段性 是有这么个二段性的&#xff0c;我们只要二分就行了&#xff0c;但是二分的check函数是有点不好想的&#xff0c;我们枚举空旷值的时候&#xff0c;为了…

回调方法传参汇总

文章目录 0. 引入问题1. 父子组件传值1.1 父传子&#xff1a;props1.2 子传父&#xff1a;$emit1.3 双向绑定&#xff1a;v-model 2. 多个参数传递3. 父组件监听方法传递其他值3.1 $event3.2 箭头方法 4. 子组件传递多个参数&#xff0c;父组件传递本地参数4.1 箭头函数 … 扩…

XSS基础靶场练习

目录 1. 准备靶场 2. PASS 1. Level 1&#xff1a;无过滤 源码&#xff1a; 2. level2&#xff1a;转HTML实体 htmlspecialchars简介&#xff1a; 源码 PASS 3. level3:转HTML深入 源码&#xff1a; PASS 4. level4:过滤<> 源码&#xff1a; PASS: 5. level5:过滤on 源码…

Redis核心机制(一)

目录 Redis的特性 1.速度快 2.以键值对方式进行存储 3.丰富的功能 4.客户端语言多 5.持久化 6.主从复制 7.高可用和分布式 Redis使用场景 Redis核心机制——持久化 RDB bgsave执行流程 ​编辑 AOF AOF重写流程 3.混合持久化&#xff08;RDBAOF&#xff09; Red…

QGroupBox取消勾选时不禁用子控件

默认情况下&#xff0c;QGroupBox取消勾选会自动禁用子控件&#xff0c;如下图所示 那么如何实现取消勾选时不禁用子控件呢&#xff1f; 实现很简单&#xff0c;直接上代码了 connect(ui->groupBox, &QGroupBox::toggled, this, [](bool checked){if (checked false){…

MyBatis-Plus 自动填充:优雅实现创建/更新时间自动更新!

目录 一、什么是 MyBatis-Plus 自动填充&#xff1f; &#x1f914;二、自动填充的原理 ⚙️三、实际例子&#xff1a;创建时间和更新时间字段自动填充 ⏰四、注意事项 ⚠️五、总结 &#x1f389; &#x1f31f;我的其他文章也讲解的比较有趣&#x1f601;&#xff0c;如果喜欢…

canvas数据标注功能简单实现:矩形、圆形

背景说明 基于UI同学的设计&#xff0c;在市面上找不到刚刚好的数据标注工具&#xff0c;遂决定自行开发。目前需求是实现图片的矩形、圆形标注&#xff0c;并获取标注的坐标信息&#xff0c;使用canvas可以比较方便的实现该功能。 主要功能 选中图形&#xff0c;进行拖动 使…

【UI设计】一些好用的免费图标素材网站

阿里巴巴矢量图标库https://www.iconfont.cn/国内最大的矢量图标库之一&#xff0c;拥有 800 万 图标资源。特色功能包括团队协作、多端适配、定制化编辑等&#xff0c;适合企业级项目、电商设计、中文产品开发等场景。IconParkhttps://iconpark.oceanengine.com/home字节跳动…