Transformers x SwanLab:可视化NLP模型训练

news2024/11/18 12:17:52

HuggingFace 的 Transformers 是目前最流行的深度学习训框架之一(100k+ Star),现在主流的大语言模型(LLaMa系列、Qwen系列、ChatGLM系列等)、自然语言处理模型(Bert系列)等,都在使用Transformers来进行预训练、微调和推理。

在这里插入图片描述
SwanLab是一个深度学习实验管理与训练可视化工具,由西安电子科技大学团队打造,融合了Weights & Biases与Tensorboard的特点,能够方便地进行 训练可视化、多实验对比、超参数记录、大型实验管理和团队协作,并支持用网页链接的方式分享你的实验。

在这里插入图片描述

你可以使用Transformers快速进行模型训练,同时使用SwanLab进行实验跟踪与可视化。
下面将用一个Bert训练,来介绍如何将Transformers与SwanLab配合起来:

1. 代码中引入SwanLabCallback

from swanlab.integration.huggingface import SwanLabCallback

SwanLabCallback是适配于Transformers的日志记录类。

SwanLabCallback可以定义的参数有:

  • project、experiment_name、description 等与 swanlab.init 效果一致的参数, 用于SwanLab项目的初始化。
  • 你也可以在外部通过swanlab.init创建项目,集成会将实验记录到你在外部创建的项目中。

2. 传入Trainer

from swanlab.integration.huggingface import SwanLabCallback
from transformers import Trainer, TrainingArguments
...

# 实例化SwanLabCallback
swanlab_callback = SwanLabCallback()

trainer = Trainer(
    ...
    # 传入callbacks参数
    callbacks=[swanlab_callback],
)

3. 案例-Bert训练

查看在线实验过程:BERT-SwanLab

在这里插入图片描述

下面是一个基于Transformers框架,使用BERT模型在imdb数据集上做微调,同时用SwanLab进行可视化的案例代码

"""
用预训练的Bert模型微调IMDB数据集,并使用SwanLabCallback回调函数将结果上传到SwanLab。
IMDB数据集的1是positive,0是negative。
"""

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from swanlab.integration.huggingface import SwanLabCallback
import swanlab

def predict(text, model, tokenizer, CLASS_NAME):
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits).item()

    print(f"Input Text: {text}")
    print(f"Predicted class: {int(predicted_class)} {CLASS_NAME[int(predicted_class)]}")
    return int(predicted_class)
# 加载IMDB数据集
dataset = load_dataset('imdb')

# 加载预训练的BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# 定义tokenize函数
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

# 对数据集进行tokenization
tokenized_datasets = dataset.map(tokenize, batched=True)

# 设置模型输入格式
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# 加载预训练的BERT模型
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 设置训练参数
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_first_step=100,
    # 总的训练轮数
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # 单卡训练
)

CLASS_NAME = {0: "negative", 1: "positive"}

# 设置swanlab回调函数
swanlab_callback = SwanLabCallback(project='BERT',
                                   experiment_name='BERT-IMDB',
                                   config={'dataset': 'IMDB', "CLASS_NAME": CLASS_NAME})

# 定义Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    callbacks=[swanlab_callback],
)

# 训练模型
trainer.train()

# 保存模型
model.save_pretrained('./sentiment_model')
tokenizer.save_pretrained('./sentiment_model')

# 测试模型
test_reviews = [
    "I absolutely loved this movie! The storyline was captivating and the acting was top-notch. A must-watch for everyone.",
    "This movie was a complete waste of time. The plot was predictable and the characters were poorly developed.",
    "An excellent film with a heartwarming story. The performances were outstanding, especially the lead actor.",
    "I found the movie to be quite boring. It dragged on and didn't really go anywhere. Not recommended.",
    "A masterpiece! The director did an amazing job bringing this story to life. The visuals were stunning.",
    "Terrible movie. The script was awful and the acting was even worse. I can't believe I sat through the whole thing.",
    "A delightful film with a perfect mix of humor and drama. The cast was great and the dialogue was witty.",
    "I was very disappointed with this movie. It had so much potential, but it just fell flat. The ending was particularly bad.",
    "One of the best movies I've seen this year. The story was original and the performances were incredibly moving.",
    "I didn't enjoy this movie at all. It was confusing and the pacing was off. Definitely not worth watching."
]

model.to('cpu')
text_list = []
for review in test_reviews:
    label = predict(review, model, tokenizer, CLASS_NAME)
    text_list.append(swanlab.Text(review, caption=f"{label}-{CLASS_NAME[label]}"))

if text_list:
    swanlab.log({"predict": text_list})

swanlab.finish()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4. 相关链接

  • Transformers文档:🤗 Transformers
  • SwanLab官网:SwanLab - 在线AI实验平台,一站式跟踪、比较、分享你的模型
  • SwanLab官方文档:SwanLab官方文档 | 先进的AI团队协作与模型创新引擎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1708038.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Web3探索加密世界:空投常见类型有哪些?附操作教程

每种空投类型都有独特的特征和目的,我们需要了解不同类型的加密空投。本文给大家介绍的是流行的加密货币空投类型,以及一般空投是如何做的。感兴趣的话请看下去。 一、空投常见类型 1、持有者空投 持有者空投向钱包中持有一定数量数字货币的人免费发放…

基于MyBatisPlus表结构维护工具

SuperTable表结构维护工具 一、简述 用于同步表实体与数据库表结构,同步建表、删改字段、索引,种子数据的工具… 一、开发环境 JDK:JDK8SpringBoot:2.7.2MyBatisPlus: 3.5.6MySQL: 5.7其他依赖:略 二、特性 表结…

大话C语言:第18篇 函数

1 函数概述 函数是c语言的功能单位,实现一个功能可以封装一个函数来实现。函数是一种可重用的代码块,用于执行特定任务或完成特定功能。函数主要作用是对具备相同逻辑的代码进行封装,提高代码的编写效率,实现对代码的重用。例如&a…

【Spring】深入学习AOP编程思想的实现原理和优势

【切面编程】深入学习AOP编程思想的实现原理和优势 前言AOP的由来及AOP与代理的关系AOP的实现方式详解静态代理动态代理 AOP的应用场景记录日志权限控制数据库事务控制缓存处理异常处理监控管理 AOP的底层实现全流程解析Spring AOP的简介动态代理的实现原理Spring AOP的实现原理…

流水账(CPU设计实战)——lab3

Lab3 Rewrite V1.0 版本控制 版本描述V0V1.0相对V0变化: 修改了文件名,各阶段以_stage结尾(因为if是关键词,所以module名不能叫if,遂改为if_stage,为了统一命名,将所有module后缀加上_stage&a…

关于Windows中桌面窗口管理器的知识,看这篇文章就可以了

序言 你打开了任务管理器,发现了一个叫做“桌面窗口管理器”的东西,它是恶意软件吗?它应该在任务管理器吗?如果它应该在那里,它的作用什么?以下是你需要了解的所有信息。 什么是桌面窗口管理器 Desktop Window Manager(dwm.exe)是一个合成窗口管理器,可以在Windows…

SB-OSC,最新的 MySQL Schema 在线变更方案

目前主流的 MySQL 在线变更方案有两个: 基于 trigger 的 pt-online-schema-change基于 binlog 的 gh-ost 上周 Sendbird 刚开源了他们的 MySQL Schema 在线变更方案 SB-OSC: Sendbird Online Schema Change。 GitHub 上刚刚 25 颗星星,绝对新鲜出炉。 …

Java | Leetcode Java题解之第104题二叉树的最大深度

题目&#xff1a; 题解&#xff1a; class Solution {public int maxDepth(TreeNode root) {if (root null) {return 0;}Queue<TreeNode> queue new LinkedList<TreeNode>();queue.offer(root);int ans 0;while (!queue.isEmpty()) {int size queue.size();wh…

go语言基准测试Benchmark 最佳实践-冒泡排序和快速排序算法基准测试时间复杂度对比

在go语言中Benchmark基准测试( 在后缀为_test.go的文件中&#xff0c;函数原型为 func BenchmarkXxx(b *testing.B) {}的函数 )可以用来帮助我们发现代码的性能和瓶颈&#xff0c; 其最佳实践 应该是我们最常用的 冒泡排序和快速排序的测试了&#xff0c;废话不说&#xff0c;直…

玄机平台应急响应—webshell查杀

1、前言 这篇文章说一下应急响应的内容&#xff0c;webshell查杀呢是应急响应的一部分。那么什么是应急响应呢&#xff0c;所谓的应急响应指的是&#xff0c;当网站突然出现异常情况或者漏洞时&#xff0c;能够马上根据实际问题进行分析&#xff0c;然后及时解决问题。 2、应…

软件游戏找不到xinput1_3.dll如何修复?分享几种有效的解决方法

当您在使用电脑过程中遇到系统提示“xinput1_3.dll文件丢失”时&#xff0c;这可能会导致某些游戏或应用程序无法正常运行&#xff0c;因为xinput1_3.dll是与DirectX相关的一个重要动态链接库文件&#xff0c;主要负责处理游戏控制器输入。为了解决这个问题&#xff0c;我通过查…

华为实训课笔记 2024

华为实训 5/205/215/225/235/275/28 5/20 5/21 5/22 5/23 5/27 5/28

蓝桥杯嵌入式国赛笔记(4):多路AD采集

1、前言 蓝桥杯的国赛会遇到多路AD采集的情况&#xff0c;这时候之前的单路采集的方式就不可用了&#xff0c;下面介绍两种多路采集的方式。 以第13届国赛为例 2、方法一&#xff08;配置通道&#xff09; 2.1 使用CubeMx配置 设置IN13与IN17为Single-ended 在Parameter S…

新增长100人研讨会:台州制造业企业共探数字驱动下的业绩增长策略

2024年5月17日&#xff0c;纷享销客联合鑫磊压缩机&#xff0c;在台州举办了一场主题为“数字化驱动下的业绩增长策略”的研讨会。本次会议汇聚台州多家制造行业的10余位数字化管理者&#xff0c;共同探讨在数字化转型浪潮中&#xff0c;制造业如何实现业绩的持续增长。 鑫磊压…

【hackmyvm】Slowman靶机

文章目录 主机探测端口探测FTP匿名登录 目录探测hydra爆破mysql爆破zip------fcrackzip爆破密码-----john提权 主机探测 ┌──(root㉿kali)-[/home/kali] └─# fping -ag 192.168.9.1/24 2>/dev/null 192.168.9.221 主机192.168.9.224 靶机端口探测 ┌──(roo…

蓝桥杯嵌入式国赛笔记(3):其他拓展板程序设计(温、湿度传感器、光敏电阻等)

目录 1、DS18B20读取 2、DHT11 2.1 宏定义 2.2 延时 2.3 设置引脚输出 2.4 设置引脚输入 2.5 复位 2.6 检测函数 2.7 读取DHT11一个位 2.7.1 数据位为0的电平信号显示 2.7.2 数据位为1的电平信号显示 2.8 读取DHT11一个字节 2.9 DHT11初始化 2.10 读取D…

【接口自动化_05课_Pytest接口自动化简单封装与Logging应用】

一、关键字驱动--设计框架的常用的思路 封装的作用&#xff1a;在编程中&#xff0c;封装一个方法&#xff08;函数&#xff09;主要有以下几个作用&#xff1a;1. **代码重用**&#xff1a;通过封装重复使用的代码到一个方法中&#xff0c;你可以在多个地方调用这个方法而不是…

【Linux学习】进程间通信 (3) —— System V (1)

下面是有关进程通信中 System V 的相关介绍&#xff0c;希望对你有所帮助&#xff01; 小海编程心语录-CSDN博客 目录 1. System V IPC 1. 消息队列 msg 消息队列的使用方法 1.1 消息队列的创建 1.2 向消息队列发送消息 1.3 从消息队列接收消息 1.4 使用msgctl函数显式地…

Java面试八股之对threadLocal是怎么理解的

对threadLocal是怎么理解的 概念与特点&#xff1a;ThreadLocal是Java提供的一个类&#xff0c;它允许你创建线程局部变量。每个线程都拥有自己的ThreadLocal变量副本&#xff0c;彼此之间互不影响&#xff0c;实现了变量在线程间的隔离。这意味着&#xff0c;即使多个线程使用…