使用 Huggingface Trainer 对自定义数据集进行文本分类

news2025/1/11 13:00:08

文本分类是一项常见的 NLP 任务,它根据文本的内容定义文本的类型、流派或主题。Huggingface🤗 Transformers 提供 API 和工具来轻松下载和训练最先进的预训练模型。Huggingface Transformers 支持 PyTorch、TensorFlow 和 JAX 之间的框架互操作性。模型还可以导出为 ONNX 和 TorchScript 等格式,以便在生产环境中部署。

在这里插入图片描述
此博客将指导您使用 Huggingface Transformers 对自定义数据集的 Distillbert 进行微调。

DistilBERT 是一种小型、快速、便宜且轻便的 Transformer 模型,通过提取 BERT 基础进行训练。它比bert-base-uncased少了 40% 的参数,运行速度提高了 60%,同时保留了 BERT 在 GLUE 语言理解基准测试中超过 95% 的性能

本博客源代码:Google Colab

1.安装transformers库

按照以下链接中的这些说明安装 huggingface 库:https 😕/huggingface.co/docs/datasets/v1.11.0/installation.html

!pip install transformers
!pip install sentencepiece

我们将在本教程中使用 IMDb 数据集。IMDB 数据集是一个大型电影评论数据集。这是一个用于二元情感分类的数据集,包含比以前的基准数据集多得多的数据。他们提供了一组 25,000 条高度极端的电影评论用于训练,25,000 条用于测试。还有其他未标记的数据可供使用。我们可以通过以下步骤下载数据集。

2.下载数据集

!wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz 
!tar -xf aclImdb_v1.tar.gz

此数据被组织到每个示例一个文本文件的文件夹中pos。read_imdb_split函数将帮助我们读取和处理数据集。neg

from pathlib import Path

def read_imdb_split(split_dir):
    split_dir = Path(split_dir)
    texts = []
    labels = []
    for label_dir in ["pos", "neg"]:
        for text_file in (split_dir/label_dir).iterdir():
            texts.append(text_file.read_text())
            labels.append(0 if label_dir is "neg" else 1)

    return texts, labels

train_texts, train_labels = read_imdb_split('aclImdb/train')
test_texts, test_labels = read_imdb_split('aclImdb/test')

让我们使用 Scikit-learn 中的train_test_split实用程序从训练数据集创建训练和验证集。我们将使用验证数据集使用自定义度量函数来评估和调整我们的模型。

from sklearn.model_selection import train_test_split
train_texts, val_texts, train_labels, val_labels = train_test_split(train_texts, train_labels, test_size=.2)

3.预处理输入数据

下一步是对模型的输入数据进行预处理。与清理数据一起,标记化是任何 NLP 管道中的第一步。Tokenizer 将非结构化字符串转换为适合机器学习的数字数据结构。由于我们将使用预训练的 DistilBert,因此我们将使用 DistilBert 分词器进行分词。

from transformers import DistilBertTokenizerFast
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

4.数据集对象

torch.utils.data.Dataset 是 Pytorch 中表示数据集的抽象类。我们的自定义数据集将继承Dataset并覆盖以下方法:

__len__以便len(dataset)返回数据集的大小。
__getitem__支持索引,这样dataset[i]可以用来获取第i个样本。

import torch

class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_dataset = IMDbDataset(train_encodings, train_labels)
val_dataset = IMDbDataset(val_encodings, val_labels)
test_dataset = IMDbDataset(test_encodings, test_labels)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/45755.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JAVA学习-java基础讲义01

java基础讲义一 java语言1.1 java语言介绍1.1.1 什么是java1.1.2 java之父1.1.3 java语言发展史1.2 java语言的特点二 java环境搭建相关2.1 Java环境介绍2.1.1 虚拟机介绍2.1.2 JVM介绍2.2 Java跨平台2.2.1 跨平台2.2.2 跨平台原理2.3 java运行过程2.4 JDK、JRE、JVM关系图2.4.…

JaVers:自动化数据审计

在开发应用程序时,我们经常需要存储有关数据如何随时间变化的信息。此信息可用于更轻松地调试应用程序并满足设计要求。在本文中,我们将讨论 JaVers 工具,该工具允许您通过记录数据库实体状态的更改来自动执行此过程。 Javers如何工作&#x…

RT-thread lts-v3.1.x版本,GD32F450以太网,上电之后有一定概率ping不通问题处理。

先给结论 官方驱动没有按照GD32F4XX手册要求,等待ENET_DMA_CTL第20bit清0后再写 synopsys_emac.c 文件,void EMAC_FlushTransmitFIFO(struct rt_synopsys_eth * ETHERNET_MAC)函数,增加一句判断即可解决。 /*** Clears the ETHERNET transm…

Kotlin高仿微信-第4篇-主页-消息

Kotlin高仿微信-项目实践58篇详细讲解了各个功能点,包括:注册、登录、主页、单聊(文本、表情、语音、图片、小视频、视频通话、语音通话、红包、转账)、群聊、个人信息、朋友圈、支付服务、扫一扫、搜索好友、添加好友、开通VIP等众多功能。 Kotlin高仿…

Android Studio / IDEA 调试金手指:live template自动打印方法名以及所有变量

ctrl alt s 搜设置,template,结果是在 live template 区域设置代码模板的,不知这功能和直播有何关系,live stream? live template 就是自动完成一段代码。比如输入 fori,然后ctrl空格补全循环体&#xf…

Apache-DButils以及Druid(德鲁伊) 多表连接查询的解决方案:两种

Apache-DButils以及Druid(德鲁伊) 多表连接查询的问题 每博一文案 张爱玲说,于千万人之中,遇到你所要遇到的人,于千万年之中,时间的无涯的荒野里,没有 早一步,也没有晚一步,刚巧赶上了。 人生海…

iPhone开机密码什么时候会用到?忘记了怎么办?

iPhone的开机密码也是屏幕解锁密码,它的作用还是很重要的。一般用在: 解锁手机手机重启后解锁手机系统更新后第一次解锁手机手机连接电脑需要信任设备Face ID或指纹解锁失败三次后连接Apple Watch后第一次解锁手机 虽然我们现在经常使用其他的解锁方式&…

马斯克特斯拉内部邮件火了:痛恨开会,少说黑话

金磊 羿阁 发自 凹非寺量子位 | 公众号 QbitAI马斯克给员工的一封内部邮件火了。鼓励员工拒绝开会、公司规定不合理可以不遵守……俨然一个为员工着想的好老板。一开始人们还奇怪马斯克的画风怎么变这么快,后来才发现原来这是他6年前写的。对象也不是推特员工&#…

BCN点击试剂:1516551-46-4,BCN-succinimidylester,BCN NHS

●中文名:丙烷环辛炔-活性酯,BCN-琥珀酰亚胺酯 ●英文名:BCN-NHS, BCN-NHS 酯,BCN-活性酯,BCN-succinimidylester 【产品理化指标】: CAS号: 1516551-46-4 分子式:C15H17…

58 - 类模板的概念和意义

---- 整理自狄泰软件唐佐林老师课程 1. 思考 在C中是否能够将泛型的思想应用于类? 1.1 类模板 一些类主要用于存储和组织数据元素类中数据组织的方式和数据元素的具体类型无关 如:数组类、链表类、Stack类、Queue类,等 C中模板的思想应用于…

【LeetCode】No.103. Binary Tree Zigzag Level Order Traversal -- Java Version

题目链接:https://leetcode.com/problems/binary-tree-zigzag-level-order-traversal/ 1. 题目介绍(Binary Tree Zigzag Level Order Traversal) Given the root of a binary tree, return the zigzag level order traversal of its nodes’…

【网络编程】第二章 网络套接字(socket+UDP协议程序)

🏆个人主页:企鹅不叫的博客 ​ 🌈专栏 C语言初阶和进阶C项目Leetcode刷题初阶数据结构与算法C初阶和进阶《深入理解计算机操作系统》《高质量C/C编程》Linux ⭐️ 博主码云gitee链接:代码仓库地址 ⚡若有帮助可以【关注点赞收藏】…

html实训大作业《基于HTML+CSS+JavaScript红色文化传媒网站(20页)》

🎉精彩专栏推荐 💭文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 💂 作者主页: 【主页——🚀获取更多优质源码】 🎓 web前端期末大作业: 【📚毕设项目精品实战案例 (10…

群晖修改默认端口为80、443

写之前哔哔两句 我这个人是个有强迫症的人,本来群晖用的好好的,然后觉得为什么还要输入5000、5001端口呢? 然后我就尝试着去修改端口,想修改为40、443的时候,结果提示端口被保留,这我哪能忍,ss…

springboot整合canal

该篇博客是基于前两篇的基础上来实现的,如果没有看过可以看一下前面的步骤 使用docker搭建 MYSQL主从_极速小乌龟的博客-CSDN博客docker 上面搭建mysql主从服务器https://blog.csdn.net/qq_35771266/article/details/128101019?spm1001.2014.3001.5501 ShardingS…

Matlab optimtool优化阵列天线的幅相激励

摘要: 阵列天线的激励幅度和相位控制着其方向图形状。例如锥削分布的幅度可实现低副瓣、递变相位激励可改变波束指向,采用幅相综合控制则可实现平顶波束、余割平方等波束赋形。下面介绍利用Matlab optimtool优化阵列天线的幅相激励实现上述需求。 推文…

超市结算系统|Springboot+Vue通用超市结算收银系统

作者主页:编程千纸鹤 作者简介:Java、前端、Pythone开发多年,做过高程,项目经理,架构师 主要内容:Java项目开发、毕业设计开发、面试技术整理、最新技术分享 收藏点赞不迷路 关注作者有好处 项目编号&…

JavaScript基础语法(变量)

JavaScript基础语法(变量) 学习路线:JavaScript基础语法(输出语句)->JavaScript基础语法(变量)->JavaScript基础语法(数据类型)->JavaScript基础语法&#xff…

(一)DepthAI-python相关接口:OAK Device

消息快播:OpenCV众筹了一款ROS2机器人rae,开源、功能强、上手简单。来瞅瞅~ 编辑:OAK中国 首发:oakchina.cn 喜欢的话,请多多👍⭐️✍ 内容可能会不定期更新,官网内容都是最新的,请查…

SuperMap iPortal 与独立代理服务的 session 共享通过redis配置实现

作者:yx 文章目录前言一、支持的Tomcat系列二、使用步骤1.将 /lib 中所有的 jar 拷贝到 tomcat/lib 目录2.给 tomcat 添加一个系统环境变量 "catalina.base",变量取值为 tomcat 的根目录3、修改 redis 的相关配置4、在 【SuperMap iPortal / i…