【huggingface系列学习】Using Transformers

news2025/1/7 18:45:05

文章目录

  • 前言
  • Using Transformers
    • 使用tokenizer预处理
      • Tokenizer详解
      • Loading and saving
        • 加载
        • 保存
      • Encoding
      • Decoding
    • Model
      • 创建一个Transformer
      • 不同的加载方法
      • 模型保存
      • 使用模型进行推理

前言

  • 因实验中遇到很多 huggingface-transformers 模型和操作,因此打算随着 course 从头理一下
  • 这个系列将会持续更新
  • 后续应该也会学习一下fairseq框架

Using Transformers

我们以一个完整的样例开始,看看在处理的过程中到底发生了什么

from transformers import pipeline

classifier = pipeline("sentiment-analysis")
classifier(
    [
        "I've been waiting for a HuggingFace course my whole life.",
        "I hate this so much!",
    ]
)

[{‘label’: ‘POSITIVE’, ‘score’: 0.9598047137260437},
{‘label’: ‘NEGATIVE’, ‘score’: 0.9994558095932007}]

这个pipeline包括三个部分:预处理,将输入输入模型中,后处理

在这里插入图片描述

使用tokenizer预处理

和其他模型一样,transformer不能直接处理原始文本,我们首先用tokenizer将文本转换成模型可以理解的 numbers。Tokenizer 有以下几个任务

  • 将输入分成words, subwords 或者 symbols 等 token
  • 将每个 token 映射成一个数字
  • 添加额外的可能对模型有用的输入

我们使用预训练的Tokenizer,通过 AutoTokenizer class 和其 from_pretrained() 方法来加载

from transformers import AutoTokenizer

checkpoint = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

现在,我们可以向Tokenizer 中输入文本

raw_inputs = [
    "I've been waiting for a HuggingFace course my whole life.",
    "I hate this so much!",
]
inputs = tokenizer(raw_inputs, padding=True, truncation=True, return_tensors="pt")
print(inputs)

>>{
    'input_ids': tensor([
        [  101,  1045,  1005,  2310,  2042,  3403,  2005,  1037, 17662, 12172, 2607,  2026,  2878,  2166,  1012,   102],
        [  101,  1045,  5223,  2023,  2061,  2172,   999,   102,     0,     0,     0,     0,     0,     0,     0,     0]
    ]), 
    'attention_mask': tensor([
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
    ])
}

我们可以输入一个句子或者一个句子列表,同时指定想要得到的 tensor 的类型

Transformers model 只接受 tensor 作为输入

输出的结果是包含两个 key 的字典:

  • input_ids
  • attention_mask

Tokenizer详解

上面说过,Tokenizer的功能就是将原始文本转换成模型可以理解的形式(numbers)。

分离文本的方式有很多,比如python 中的 .split() 函数,按照空格来将文本分离成 words。我们还可以用标点符号来分隔,使用这种tokenizer,最后会得到一个很大的“词典”,a vocabulary is defined by the total number of independent tokens that we have in our corpus。每个词都会被分配一个 id(从0开始),模型利用 id 来区分词。

不同分词方式详见:Tokenizers - Hugging Face Course

Loading and saving

基于两个方法: from_oretrained()save_pretrained()。这些方法会保存 tokenizer 使用的算法(类似模型结构)和使用的词典(类似模型权重)

加载

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") # 根据checkpoint名字自动找到对应的class

#还可以直接加载特定的 tokennizer
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

保存

tokenizer.save_pretrained("directory_on_my_computer")

Encoding

我们来看看 input_ids 是如何生成的(encode),encode分成两步:

  • tokenization(split text into tokens)

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    sequence = "Using a Transformer network is simple"
    tokens = tokenizer.tokenize(sequence)
    
  • 转换到 input IDs:当我们用.from_pretrained()实例化一个 tokenizer 时,会下载一个词典。我们通过词典来完成映射

    ids = tokenizer.convert_tokens_to_ids(tokens)
    

Decoding

Decoding做的是,当我们提供 ids 时(其实就是词汇表中token的索引),能得到ids对应的token。这时我们可以使用

decoded_string = tokenizer.decode([7993, 170, 11303, 1200, 2443, 1110, 3014])
print(decoded_string)
  • decode()函数不仅将索引恢复成token,还能将属于同一个单词的 token 组合在一起,生成一个可读的句子

Model

创建一个Transformer

我们以BERT为例,实例化BERT的第一件事就是加载一个 configuration 对象

from transformers import BertConfig, BertModel

# Building the config
config = BertConfig()

# Building the model from the config
model = BertModel(config)

不同的加载方法

上面展示的是模型随机初始化的加载方式,同样,我们可以加载预训练模型

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-cased")

权重会被下载并保存到cache(默认路径是*~/.cache/huggingface/transformers*),通过设定HF_HOME环境变量可以定制cache folder

模型保存

model.save_pretrained("directory_on_my_computer")

这会保存两个文件:

  • config.json:包括构建模型结构必要的属性,还包括一些 metadata(上次保存使用的transformer版本等)

  • pytorch_model.bin:包括所有的模型权重

  • 这两个文件是相辅相成的,一个可以知道模型架构,一个可以提供模型参数

使用模型进行推理

import torch
sequences = ["Hello!", "Cool.", "Nice!"]
encoded_sequences = [
    [101, 7592, 999, 102],
    [101, 4658, 1012, 102],
    [101, 3835, 999, 102],
]
model_inputs = torch.tensor(encoded_sequences)
output = model(model_inputs)
  • model可以接受很多不同的参数,但是只有 input_ids 是必须的

Tensors只接受矩形的数据,如果每一条数据的长度不同,转换成tensor会报错

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/336821.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

剖析字节案例,火山引擎 A/B 测试 DataTester 如何“嵌入”技术研发流程

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 日前,在 WOT 全球创新技术大会上,火山引擎 DataTester 技术负责人韩云飞做了关于字节跳动 A/B 测试平台的分享。DataTester 是字节跳动内部应…

Roboguide与TIA V16通讯

软件需求:1. roboguide;2. TIA V16;3. KEPServer; 在之前的文章中介绍过KEPServer与TIA V16的通讯,此处不再介绍。接下来,介绍roboguide与KEPServer的仿真通讯。 创建一个roboguide项目。选择【外部设备】➡【添加外部设备】 选择【OPC Server】➡【OK】 OPC服务器名称命…

linux安装并配置nginx

菜鸟教程 一 . Nginx安装和部署 1.输入指令,下载相关的依赖包 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-develYUM(Yellow dog Updater, Modified)为多个Linux发行版的前端软件包管理器 -y 是参数,默认不要确认, rp…

对话 ChatGPT:现象级 AI 应用,将如何阐释「研发效能管理」?

ChatGPT 已然是 2023 开年至今,互联网上最热的话题没有之一。从去年的 AI 图片生成,到 ChatGPT,再到现在各种基于大模型的应用如雨后春笋般出现……在人们探讨技术无限可能的同时,另一个更深刻的命题也不可回避地浮现出来&#xf…

汽摩仪表快检盒

不怕失业 ​ ​最近大火的ChatGPT说要取代程序员,老婆子惊慌失措,跟着糟老头憋屈,咸鱼想靠软件翻身,这下白瞎了。 ​温州寄来了汽车燃油预热控制板,绍兴又寄来了发动机仪表,昆山的尾门在路上,都…

如何成为java架构师?2023版Java架构师学习路线总结完成,真实系统有效,一切尽在其中

导读 从初级Java工程师成长为Java架构师,你需要走很长的路,很多有计划的人在学习之初就在做准备。你知道Java架构师学习路线该怎么走吗?成为一个优秀的Java架构师究竟需要学什么?接下来就跟小编一起揭晓答案。 架构师是一个充满挑战的职业&#xff0…

Python自定义模块

到目前为止,读者已经掌握了导入 Python 标准库并使用其成员(主要是函数)的方法,接下来要解决的问题是,怎样自定义一个模块呢?Python 模块就是 Python 程序,换句话说,只要是 Python 程…

Swagger自动生成api文档

Swagger自动生成api文档Swagger是什么Swagger底层原理使用方式1修改pom文件2启动类中加入注解EnableSwagger23加入SpringFoxConfig.java4加入WebMvcConfig.java文件5 给Web 服务的接口加注解访问可视化页面Swagger是什么 Swagger 是一个规范和完整的框架,用于生成、…

C经典小游戏之扫雷

编译环境:VS022 目录 1.算法思路 2.代码模块 2.1 game.h 2.2 game.cpp 2.3 test.cpp 3.重点分析 4.金句省身 1.算法思路 主要采用二维数组进行实现,设置两个二维数组,一个打印结果,即为游戏界面显示的效果,一个用…

值类型和引用类型

一、值类型和引用类型示例: 值类型:基本数据类型系列,如:int,float,bool,string,数组和结构体等。 引用类型:如:指针,slice切片,map&a…

windows wireshark抓到未加入组的组播消息

现象 在Windows上开启wireshark,抓到了大量地址为239.255.255.251的组播包。 同时,根据组播相关命令,调用netsh interface ipv4 show joins,显示当前并没加入 239.255.255.251 组播组。 解决 根据IGMP Snooping,I…

《机器学习》学习笔记

第 2 章 模型评估与选择 2.1 经验误差与过拟合 精度:精度1-错误率。如果在 mmm 个样本中有 aaa 个样本分类错误,则错误率 Ea/mEa/mEa/m,精度 1−a/m1-a/m1−a/m。误差:一般我们把学习器的实际预测输出与样本的真实输出之间的差…

MySQL---单表查询、多表查询

一、单表查询 素材: 表名:worker-- 表中字段均为中文,比如 部门号 工资 职工号 参加工作 等 CREATE TABLE worker ( 部门号 int(11) NOT NULL, 职工号 int(11) NOT NULL, 工作时间 date NOT NULL, 工资 float(8,2) NOT NULL, 政治面貌 v…

STM32驱动RC522

STM32驱动RC522开发环境:STM32CUBEMXKeil5使用平台:STM32F401CCU6该内容由网上内容改编,若不合适,请联系删除。一、使用STM32CUBEMX配置SPI二、驱动部分三、主函数调用四、移值攻略开发环境:STM32CUBEMXKeil5 使用平台&#xff1…

力扣:珠玑妙算(详解)

前言:内容包括四大模块:题目,代码实现,大致思路,代码解读 题目: 珠玑妙算游戏(the game of master mind)的玩法如下。 计算机有4个槽,每个槽放一个球,颜色可…

电力系统网架规划MATLAB程序分享

网架数据展示:完整程序:close all;clear all;clc;warning off; % 去除警告 tic; % tic用来保存当前时间,而后使用toc来记录程序完成时间%% 基本参数T12; % 典型日 8-19h % 8-19h 负荷各时段负荷总量total_P_LOAD[828,1001,1105,1105,994,1105…

STM32CubeMX+SPI+FATFS读写SD卡

一、软件硬件说明软件:STM32CubeMX V6.6.1 /KEIL5 V5.29硬件:正点原子mini开发板,SD卡,通过SPI方式驱动SD卡,用的是SPI1接口以上内容来源于正点原子mini开发板手册,SD卡的详细介绍也可以去查阅这个手册。二、STM32Cube…

Ethercat系列(3)TWCat3下抓包实例分析

简介研究Ethercat协议,必须知道数据包格式,以及其真实含义。以一个真实的数据包来学习是最有效的。Twcat3下用wireshark抓包,需要设置一下混杂模式,否则不能直接抓到Ethercat数据包。Twcat抓包设置在正确加载驱动器配置文件后&…

可深度二次开发的智能插座 工业化物联网多场景的高定系统服务商

物联网时代,各类物联网需求越来越迫切。物联网设备呈现出爆发式增长。同时近年来国家不断出台相关的法规政策,为物联网行业发展创造机遇,三大运营商积极部署NB-IOT网络建设,建成90万NB-IoT基站。据统计2012-2022年期间&#xff0c…

缺省参数+函数重载+构造函数

目录 一、缺省参数 (一)缺省参数概念 (二)缺省参数分类 1. 全缺省参数 2. 半缺省参数(缺省部分参数) 3. 注意 二、函数重载 (一)基本概念 (二)举例 …