SwanLab入门深度学习:BERT IMDB文本情感分类

news2024/11/24 7:41:47

基于BERT模型的IMDB电影评论情感分类,是NLP经典的Hello World任务之一。

这篇文章我将带大家使用SwanLab、transformers、datasets三个开源工具,完成从数据集准备、代码编写、可视化训练的全过程。

观察了一下,中文互联网上似乎很少有能直接跑起来的BERT训练代码和教程,所以也希望这篇文章可以帮到大家。
在这里插入图片描述

  • 代码:完整代码直接看本文第5节
  • 模型与数据集:百度云,提取码: u9gi
  • 实验过程:BERT-SwanLab
  • SwanLab:https://swanlab.cn
  • transformers:https://github.com/huggingface/transformers
  • datasets:https://github.com/huggingface/datasets

1.环境安装

我们需要安装以下这4个Python库:

transformers>=4.41.0
datasets>=2.19.1
swanlab>=0.3.3

一键安装命令:

pip install transformers datasets swanlab

他们的作用分别是:

  1. transformers:HuggingFace出品的深度学习框架,已经成为了NLP(自然语言处理)领域最流行的训练与推理框架。代码中用transformers主要用于加载模型、训练以及推理。
  2. datasets:同样是HuggingFace出品的数据集工具,可以下载来自huggingface社区上的数据集。代码中用datasets主要用于下载、加载数据集。
  3. swanlab:在线训练可视化和超参数记录工具,官网,可以记录整个实验的超参数、指标、训练环境、Python版本等,并可是化成图表,帮助你分析训练的表现。代码中用swanlab主要用于记录指标和可视化。

本文的代码测试于transformers4.41.0、datasets2.19.1、swanlab==0.3.3,更多库版本可查看SwanLab记录的Python环境。

2.加载BERT模型

BERT模型我们直接下载来自HuggingFace上由Google发布的bert-case-uncased预训练模型。

执行下面的代码,会自动下载模型权重并加载模型:

from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# 加载预训练的BERT tokenizer
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

如果国内下载比较慢的话,可以在这个百度云(提取码: u9gi)下载后,把bert-base-uncased文件夹放到根目录,然后改写上面的代码为:

model = AutoModelForSequenceClassification.from_pretrained('./bert-base-uncased', num_labels=2)

3.加载IMDB数据集

IMDB数据集(Internet Movie Database Dataset)是自然语言处理(NLP)领域中一个非常著名和广泛使用的数据集,主要应用于文本情感分析任务。

IMDB数据集源自全球最大的电影数据库网站Internet Movie Database(IMDb),该网站包含了大量的电影、电视节目、纪录片等影视作品信息,以及用户对这些作品的评论和评分。
数据集包括50,000条英文电影评论,这些评论被标记为正面或负面情感,用以进行二分类任务。其中,25,000条评论被分配为训练集,另外25,000条则作为测试集。训练集和测试集都保持了平衡的正负样本比例,即各含50%的正面评论和50%的负面评论.

在这里插入图片描述
我们同样直接下载HuggingFace上的imdb数据集,执行下面的代码,会自动下载数据集并加载:

from datasets import load_dataset

# 加载IMDB数据集
dataset = load_dataset('imdb')

如果国内下载比较慢的话,可以在这个百度云(提取码: u9gi)下载后,把imdb文件夹放到根目录,然后改写上面的代码为:

dataset = load_dataset('./imdb')

4.集成SwanLab

因为swanlab已经和transformers框架做了集成,所以将SwanLabCallback类传入到trainercallbacks参数中即可实现实验跟踪和可视化:

from swanlab.integration.huggingface import SwanLabCallback

# 设置swanlab回调函数
swanlab_callback = SwanLabCallback()

...

# 定义Transformers Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    # 传入swanlab回调函数
    callbacks=[swanlab_callback],
)

想了解更多关于SwanLab的知识,请看SwanLab官方文档。

5.开始训练!

训练过程看这里:BERT-SwanLab。

在首次使用SwanLab时,需要去官网注册一下账号,然后在用户设置复制一下你的API Key。

在这里插入图片描述
然后在终端输入swanlab login:

swanlab login

把API Key粘贴进去即可完成登录,之后就不需要再次登录了。

完整的训练代码:

"""
用预训练的Bert模型微调IMDB数据集,并使用SwanLabCallback回调函数将结果上传到SwanLab。
IMDB数据集的1是positive,0是negative。
"""

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from swanlab.integration.huggingface import SwanLabCallback
import swanlab

def predict(text, model, tokenizer, CLASS_NAME):
    inputs = tokenizer(text, return_tensors="pt")

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        predicted_class = torch.argmax(logits).item()

    print(f"Input Text: {text}")
    print(f"Predicted class: {int(predicted_class)} {CLASS_NAME[int(predicted_class)]}")
    return int(predicted_class)

# 加载IMDB数据集
dataset = load_dataset('imdb')

# 加载预训练的BERT tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# 定义tokenize函数
def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True)

# 对数据集进行tokenization
tokenized_datasets = dataset.map(tokenize, batched=True)

# 设置模型输入格式
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# 加载预训练的BERT模型
model = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 设置训练参数
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_first_step=100,
    # 总的训练轮数
    num_train_epochs=3,
    weight_decay=0.01,
    report_to="none",
    # 单卡训练
)

CLASS_NAME = {0: "negative", 1: "positive"}

# 设置swanlab回调函数
swanlab_callback = SwanLabCallback(project='BERT',
                                   experiment_name='BERT-IMDB',
                                   config={'dataset': 'IMDB', "CLASS_NAME": CLASS_NAME})

# 定义Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['test'],
    callbacks=[swanlab_callback],
)

# 训练模型
trainer.train()

# 保存模型
model.save_pretrained('./sentiment_model')
tokenizer.save_pretrained('./sentiment_model')

# 测试模型
test_reviews = [
    "I absolutely loved this movie! The storyline was captivating and the acting was top-notch. A must-watch for everyone.",
    "This movie was a complete waste of time. The plot was predictable and the characters were poorly developed.",
    "An excellent film with a heartwarming story. The performances were outstanding, especially the lead actor.",
    "I found the movie to be quite boring. It dragged on and didn't really go anywhere. Not recommended.",
    "A masterpiece! The director did an amazing job bringing this story to life. The visuals were stunning.",
    "Terrible movie. The script was awful and the acting was even worse. I can't believe I sat through the whole thing.",
    "A delightful film with a perfect mix of humor and drama. The cast was great and the dialogue was witty.",
    "I was very disappointed with this movie. It had so much potential, but it just fell flat. The ending was particularly bad.",
    "One of the best movies I've seen this year. The story was original and the performances were incredibly moving.",
    "I didn't enjoy this movie at all. It was confusing and the pacing was off. Definitely not worth watching."
]

model.to('cpu')
text_list = []
for review in test_reviews:
    label = predict(review, model, tokenizer, CLASS_NAME)
    text_list.append(swanlab.Text(review, caption=f"{label}-{CLASS_NAME[label]}"))

if text_list:
    swanlab.log({"predict": text_list})

swanlab.finish()

训练可视化过程:
在这里插入图片描述

训练大概需要6G左右的显存,我在一块3090上跑了,1个epoch大概要12~13分钟时间。

训练的推理结果:

在这里插入图片描述
这里我生成了10个比较简单的测试文本,微调后的BERT模型基本都能答对。

至此,我们顺利完成了用BERT预训练模型微调IMDB数据的训练过程~

相关链接

  • 代码:完整代码直接看本文第5节
  • 模型与数据集:百度云,提取码: u9gi
  • 实验过程:BERT-SwanLab
  • SwanLab:https://swanlab.cn
  • transformers:https://github.com/huggingface/transformers
  • datasets:https://github.com/huggingface/datasets

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1686055.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

光敏聚酰亚胺(PSPI)行业技术壁垒较高 本土企业已具备相关产品量产能力

光敏聚酰亚胺(PSPI)行业技术壁垒较高 本土企业已具备相关产品量产能力 光敏聚酰亚胺(PSPI)又称光敏PI、光刻胶用聚酰亚胺,指将光敏基团引入聚酰亚胺分子链中制成的高性能有机材料。PSPI拥有极佳耐热性、化学稳定性、热…

新书推荐:6.1 if语句

计算机语言和人类语言类似,人类语言是为了解决人与人之间交流的问题,而计算机语言是为了解决程序员与计算机之间交流的问题。程序员编写的程序就是计算机的控制指令,控制计算机的运行。借助于编译工具,可以将各种不同的编程语言的…

Jenkins工具系列 —— 通过钉钉API 发送消息

文章目录 钉钉环境搭建使用钉钉API接口 发送消息机器人安全设置使用自定义关键词机器人安全设置使用加签方式 资料下载 钉钉环境搭建 在jenkins安装钉钉插件以及小机器人,这部分内容可参考:插件 钉钉发送消息 使用钉钉API接口 发送消息 机器人安全设置…

Day02:LeedCode977. 有序数组的平方 209.长度最小的子数组 59.螺旋矩阵II

详解:Day2:LeedCode977. 有序数组的平方 209.长度最小的子数组 59.螺旋矩阵II-CSDN博客 977. 有序数组的平方 给你一个按 非递减顺序 排序的整数数组 nums,返回 每个数字的平方 组成的新数组,要求也按 非递减顺序 排序。 示例 1: 输入&#…

Docker Compose快速入门

本教程旨在通过指导您开发基本Python web应用程序来介绍Docker Compose的基本概念。 使用Flask框架,该应用程序在Redis中提供了一个命中计数器,提供了如何在web开发场景中应用Docker Compose的实际示例。 即使您不熟悉Python,这里演示的概念也…

Window Linux 权限提升

#基础点: 0、为什么我们要学习权限提升转移技术: 简单来说就是达到目的过程中需要用到它 心里要想着我是谁 我在哪 我要去哪里 1、具体有哪些权限需要我们了解掌握的: 后台权限,数据库权限,Web权限,用户权…

React 中Redux结合React-Redux使用类组件版本(一)

一、Redux是什么? 1.Redux是一个专门用于状态管理的js库 2.它可以用在React、Angular、Vue的项目中,但基本与React配合使用。 3.作用:集中式管理React应用中多个组件共享的状态。 二、Redux 工作流程 三、Redux的三个核心概念 1.action 动…

线上研讨会 | 探索非标自动化产线行业的数转智改之路

报名链接: 2024 达索系统工业大发展在线研讨会 (tbh5.com)

azure gpt 技术教程教学 | 在Azure OpenAI 上部署GPT-4o

Azure OpenAI GPT-4o是OpenAI推出的最新旗舰级人工智能模型。GPT-4o模型设计为能够实时对音频、视觉和文本进行推理,这是迈向更自然人机交互的重要一步。该模型的一大特点是能够处理多种类型的数据输入和输出,包括文本、音频和图像,实现了跨模…

521源码-在线客服-CRMChat网页版客服系统 UNIAPP 全方位在线客服系统源码与管理体系平台

CRMChat客服系统:基于Swoole4Tp6RedisVueMysql构建的高效沟通桥梁 CRMChat是一款独立且高性能的在线客服系统,它结合了Swoole4、Tp6、Redis、Vue以及Mysql等先进技术栈,为用户提供了卓越的在线沟通体验。该系统不仅支持在Pc端、移动端、小程…

软考 软件设计师 场景分析题 速成篇

文章目录 试题一:数据流图💖 基本图形元素1. 外部实体2. 数据存储3. 加工4. 数据流 📚 例题(1)实体名称(2)数据存储名称(3)数据流① 父子图平衡② 加工有输入有输出④ 数…

visual studio snippet常用注释片段

Visual Studio 2022 添加自定义代码片段_vs2022 代码片段-CSDN博客 dclass.snippet: <?xml version"1.0" encoding"utf-8"?> <CodeSnippets xmlns"http://schemas.microsoft.com/VisualStudio/2005/CodeSnippet"> …

Pip,whl,源码编译安装Python库

pip安装 pip 是 Python 包管理工具&#xff0c;用于安装和管理 Python 包。pip 是 Python 开发中不可或缺的工具&#xff0c;能够帮助开发者轻松地管理项目所需的各种库和依赖。无论是安装新包、升级现有包还是卸载不需要的包&#xff0c;pip 都提供了简单而强大的命令来完成这…

YOLOv5改进 | 主干网络 | 用repvgg模块替换Conv【教程+代码 】

&#x1f4a1;&#x1f4a1;&#x1f4a1;本专栏所有程序均经过测试&#xff0c;可成功执行&#x1f4a1;&#x1f4a1;&#x1f4a1; 尽管Ultralytics 推出了最新版本的 YOLOv8 模型。但YOLOv5作为一个anchor base的目标检测的算法&#xff0c;YOLOv5可能比YOLOv8的效果更好。…

【02】GeoScene Enterprise(Windows)许可更新

如果在Windows环境下部署了GeoScene Enterprise基础环境&#xff0c;也就是部署了server、portal、datastore、web adaptor四大组件&#xff0c;当试用许可到期后&#xff0c;拿到新的许可想要更新许可&#xff0c;从而使得软件能够正常工作&#xff0c;下述步骤是更新GeoScene…

WebRTC 音频抗弱网技术

实时音视频通话一直是我们通信行业必不可少的一门技术&#xff0c;并且近今年音视频边缘设备产品涌现出很多设备&#xff0c;然而&#xff0c;在当今网络环境中&#xff0c;网络传输质量确常常无法得到有效的保障&#xff0c;那么&#xff0c;在当今弱网环境下&#xff0c;如何…

DeepRec Extension 打造稳定高效的分布式训练

DeepRec Extension 即 DeepRec 扩展&#xff0c;在 DeepRec 训练推理框架之上&#xff0c;围绕大规模稀疏模型分布式训练&#xff0c;我们从训练任务的视角提出了自动弹性训练&#xff0c;分布式容错等功能&#xff0c;进一步提升稀疏模型训练的整体效率&#xff0c;助力 DeepR…

Vue3:动态路由+子页面(新增、详情页)动态路由配置(代码全注释)

文章目录 实现思路调用后端接口获取用户权限获取页面权限动态绑定到路由对象中动态添加子页面路由 实现思路 emm&#xff0c;项目中使用动态路由实现根据后端返回的用户详情信息&#xff0c;动态将该用户能够访问的页面信息&#xff0c;动态生成并且绑定到路由对象中。但是后…

【leetcode面试经典150题】-80. 删除有序数组中的重复项 II

【leetcode面试经典150题】-80. 删除有序数组中的重复项 II 1 题目介绍2 个人解题思路2.1 代码2.2 思路 3 官方题解 1 题目介绍 给你一个有序数组 nums &#xff0c;请你 原地 删除重复出现的元素&#xff0c;使得出现次数超过两次的元素只出现两次 &#xff0c;返回删除后数组…

MongoDB基础入门到深入(七)建模、调优

文章目录 系列文章索引十一、MongoDB开发规范十二、MongoDB调优1、三大导致MongoDB性能不佳的原因2、影响MongoDB性能的因素3、MongoDB性能监控工具&#xff08;1&#xff09;mongostat&#xff08;2&#xff09;mongotop&#xff08;3&#xff09;Profiler模块&#xff08;4&a…