Hugging Face Transformer 的APIs应用实例

news2024/12/23 1:54:05

拥抱面变压器 API 简要摘要

一、说明 

        Hugging Face 的变压器库提供了广泛的 API,可用于处理各种 NLP 任务的预训练变压器模型。在本教程中,我们将探讨主要 API 并提供示例来帮助你了解它们的用法。

二、导入模型 

1. 分词器接口:

        分词器 API 用于预处理文本数据,并将其标记为转换器模型的输入特征。

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "Hello, how are you?"
encoded_input = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
print(encoded_input)

2. 模型接口:

        模型 API 允许您为各种 NLP 任务加载预先训练的转换器模型。

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
outputs = model(**encoded_input)
print(outputs)

3. 管道接口:

        管道 API 简化了将预训练模型用于特定任务(如文本生成、问答、情绪分析等)的过程。

from transformers import pipeline
text_generation = pipeline("text-generation", model="gpt2")
generated_text = text_generation("Once upon a time")
print(generated_text)

三、训练过程

4. 培训师接口:

        训练器 API 可帮助您针对特定任务在自定义数据集上微调预训练的模型。

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments("fine_tuned_model")

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
trainer.train()

5. 数据集接口:

        数据集 API 允许您为 NLP 任务加载和预处理各种数据集。

from datasets import load_dataset
dataset = load_dataset("glue", "sst2")
print(dataset)

6. 评估接口:

        评估 API 可帮助您计算模型预测的评估指标。

from datasets import load_metric

metric = load_metric("accuracy")
predictions = model.predict(test_dataset)
accuracy = metric.compute(predictions=predictions["predictions"], references=predictions["labels"])
print(accuracy)

7. 使用自定义指标进行微调:

        您可以为培训师 API 定义自己的自定义指标。

from datasets import load_metric

metric = load_metric("accuracy")
predictions = model.predict(test_dataset)
accuracy = metric.compute(predictions=predictions["predictions"], references=predictions["labels"])
print(accuracy)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_custom_metrics,
)

trainer.train()

四、训练模型存取

8. 从检查点加载模型:

        您可以加载预先训练的模型检查点并恢复训练或将其用于推理。

model_checkpoint = "fine_tuned_model/checkpoint-1000"
loaded_model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint)

9. 模型推理:

        加载模型后,您可以使用它对新数据进行推理。

new_text = "This is a new text to classify"
encoded_input = tokenizer(new_text, padding=True, truncation=True, return_tensors="pt")
outputs = loaded_model(**encoded_input)
print(outputs)

五、训练模型序列化

10. 模型序列化:

        可以使用 and 方法保存和加载预先训练的模型。.save_pretrained().from_pretrained()

loaded_model.save_pretrained("saved_model")
reloaded_model = AutoModelForSequenceClassification.from_pretrained("saved_model")

        通过探索这些 API,您可以利用拥抱面变压器的强大功能来预处理数据、微调模型、执行推理并评估各种 NLP 任务的性能。该图书馆的灵活性和广泛的文档使其成为NLP从业者和研究人员的宝贵资源  拉克什·拉杰普罗希特

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/932937.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

浅谈分布式共识算法概念与演进

分布式共识是指在分布式系统中,多个节点之间达成共识的过程。 分布式共识的意义在于确保分布式系统中各个节点之间的数据一致性。通过分布式共识算法,可以使得多个节点针对某个状态达成一致,从而保证系统中各个节点之间的数据一致性。这对于…

应知道的16个Python基础知识

列表推导式 # 列表推导式,用一行代码生成一个有规律的列表 # 列表推导式,用一行代码生成一个有规律的列表 import randomlist_comprehension =[i for i in range(10)] print(list_comprehension)list_comprehension2 =[(x,y)for x in range(4) for y in range(5,10)] print(…

手写Spring源码——实现一个简单的spring framework

这篇文章主要带大家实现一个简单的Spring框架,包含单例、多例bean的获取,依赖注入、懒加载等功能。 一、创建Java项目 首先,需要创建一个Java工程,名字就叫spring。 创建完成后,如下图,再依次创建三级包 二…

Linux系统编程系列之进程基础

一、什么是进程 关于进程的定义很多,这里讲一种比较直接的,进程就是程序中的代码和数据被加载到内存中运行的过程,就是程序的执行过程。进程是动态的,而程序是静态的。程序存储在硬盘里,进程只有在程序被执行后&#x…

生信分析Python实战练习 1 | 视频18

开源生信 Python教程 生信专用简明 Python 文字和视频教程 源码在:https://github.com/Tong-Chen/Bioinfo_course_python 目录 背景介绍 编程开篇为什么学习Python如何安装Python如何运行Python命令和脚本使用什么编辑器写Python脚本Python程序事例Python基本语法 数…

STM32 Cubemx配置串口收发

文章目录 前言注意事项Cubemx配置printf重定向修改工程属性修改源码 测试函数 前言 最近学到了串口收发,简单记录一下注意事项。 注意事项 Cubemx配置 以使用USART1为例。 USART1需配置成异步工作模式Asynchronous。 并且需要使能NVIC。 printf重定向 我偏向…

使用cgroup工具对服务器某些/全部用户进行计算资源限制

使用cgroup工具对服务器某些/全部用户进行计算资源限制 主要介绍,如何对指定/所有用户进行资源限定(这里主要介绍cpu和内存占用限制),防止某些用户大量占用服务器计算资源,影响和挤占他人正常使用服务器。 安装cgrou…

Transformer代码计算过程全解

条件设置 batch_size1 src_len 8 # 源句子的最大长度 根据这个进行padding的填充 tgt_len 7 # 目标输入句子的最大长度 根据这个进行padding的填充 d_model512 # embedding的维度 d_ff2048 # 全连接层的维度 h_head8 # Multi-Head Attention 的…

【C++】—— C++11之可变参数模板

前言: 在C语言中,我们谈论了有关可变参数的相关知识。在C11中引入了一个新特性---即可变参数模板。本期,我们将要介绍的就是有关可变参数模板的相关知识!!! 目录 序言 (一)可变参…

深度学习10:Attention 机制

目录 Attention 的本质是什么 Attention 的3大优点 Attention 的原理 Attention 的 N 种类型 Attention 的本质是什么 Attention(注意力)机制如果浅层的理解,跟他的名字非常匹配。他的核心逻辑就是「从关注全部到关注重点」。 Attention…

ServiceManager接收APP的跨进程Binder通信流程分析

现在一起来分析Server端接收(来自APP端)Binder数据的整个过程,还是以ServiceManager这个Server为例进行分析,这是一个至下而上的分析过程。 在分析之前先思考ServiceManager是什么?它其实是一个独立的进程,由init解析i…

windows11不允许安装winpcap4.1.3

问题:下载安装包后在安装时显示与电脑系统不兼容,不能安装。 原因:winpcap是一个用于Windows操作系统的网络抓包库,有一些安全漏洞,存在被黑客攻击的风险。Windows11为了加强系统安全而禁用了这个库,因此不…

java.8 - java -overrideoverload 重写和重载

重写(Override) 重写是子类对父类的允许访问的方法的实现过程进行重新编写, 返回值和形参都不能改变。即外壳不变,核心重写! 重写的好处在于子类可以根据需要,定义特定于自己的行为。 也就是说子类能够根据需要实现父类的方法。 重写方法不…

【GAMES202】Real-Time Environment Mapping1—实时环境光照1

一、Distance field soft shadows Inigo Quilez :: computer graphics, mathematics, shaders, fractals, demoscene and more (iquilezles.org) 在开始我们的实时环境光照之前,我们再说一种现在的实现实时软阴影的方式,也就是Distance field soft shado…

SpringBoot实现文件上传和下载笔记分享(提供Gitee源码)

前言:这边汇总了一下目前SpringBoot项目当中常见文件上传和下载的功能,一共三种常见的下载方式和一种上传方式,特此做一个笔记分享。 目录 一、pom依赖 二、yml配置文件 三、文件下载 3.1、使用Spring框架提供的下载方式 3.2、通过IOUti…

分布式 - 服务器Nginx:一小时入门系列之 return 指令

文章目录 1. return 指令语法2. return code URL 示例3. return code text 示例4. return URL 示例 1. return 指令语法 return指令用于立即停止当前请求的处理,并返回指定的HTTP状态码和响应头信息,它可以用于在Nginx中生成自定义错误页面,…

分布式事务-seata框架

文章目录 分布式事务0.学习目标1.分布式事务问题1.1.本地事务1.2.分布式事务1.3.演示分布式事务问题 2.理论基础2.1.CAP定理2.1.1.一致性2.1.2.可用性2.1.3.分区容错2.1.4.矛盾 2.2.BASE理论2.3.解决分布式事务的思路 3.初识Seata3.1.Seata的架构3.2.部署TC服务3.3.微服务集成S…

CAPL - Panel和TestModule结合实现测试项可选

目录 一、定义脚本编号和脚本组编号 1、测试组定义 2、测试脚本编号定义

【C++】初步认识模板

🏖️作者:malloc不出对象 ⛺专栏:C的学习之路 👦个人简介:一名双非本科院校大二在读的科班编程菜鸟,努力编程只为赶上各位大佬的步伐🙈🙈 目录 前言一、泛型编程二、函数模板2.1 函…

Java10(异常处理)

0.复习面向对象 1.异常的体系结构 异常:在Java语言中,将程序执行中发生的不正常情况.(开发中的语法错误和逻辑错误不是异常) 异常事件分两类(它们上一级为java.lang.Throwable): Error Java虚拟机无法解决的严重问…