Transformers 安装与基本使用

news2024/11/25 22:51:03

文章目录

  • Github
  • 文档
  • 推荐文章
  • 简介
  • 安装
  • 官方示例
  • 中文情感分析模型
    • 分词器 Tokenizer
    • 填充 Padding
    • 截断 Truncation
  • google-t5/t5-small
  • 使用脚本进行训练
    • Pytorch
  • 机器翻译
    • 数据集下载
    • 数据集格式转换

Github

  • https://github.com/huggingface/transformers

文档

  • https://huggingface.co/docs/transformers/index
  • https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hans.md

推荐文章

  • http://jalammar.github.io/illustrated-transformer/

简介

Transformers是一种基于注意力机制(Attention Mechanism)的神经网络模型,广泛应用于自然语言处理(Natural Language Processing)任务中,如机器翻译、文本生成和文本分类等。

传统的序列模型(如循环神经网络)在处理长距离依赖时可能遇到困难,而Transformers通过引入注意力机制来解决这个问题。注意力机制使得模型能够在序列中对不同位置的信息进行加权关注,从而捕捉到全局的上下文信息。

在Transformers中,输入序列首先被分别编码为查询(Query)、键(Key)和值(Value)向量。通过计算查询与键的相似度,得到注意力分数,再将注意力分数与值相乘并加权求和,即可得到最终的上下文表示。这种自注意力机制允许模型在编码器和解码器中自由交换信息,从而更好地处理长距离依赖关系。

Transformer模型的核心组件是多层的自注意力机制和前馈神经网络。它的架构被广泛应用于许多重要的NLP任务,其中最著名的是BERT(Bidirectional Encoder Representations from Transformers),它在多项NLP任务上取得了突破性的性能。

除了NLP领域,Transformers模型也被应用于计算机视觉和其他领域,用于处理序列建模和生成任务。它已经成为深度学习中非常重要和有影响力的模型架构之一。

安装

pip install transformers
# PyTorch(推荐)
pip install 'transformers[torch]'
# TensorFlow 2.0
pip install 'transformers[tf-cpu]'
  • M1 / ARM 用户在安装 TensorFLow 2.0 之前,需要安装以下内容
brew install cmake
brew install pkg-config
  • 验证是否安装成功
python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('we love you'))"

在这里插入图片描述

注意: 以上验证操作需要“连网”,否则因无法下载文件而出现报错。

官方示例

from transformers import pipeline

# 使用情绪分析流水线
classifier = pipeline('sentiment-analysis')
classifier('We are very happy to introduce pipeline to the transformers repository.')
  • 输出结果
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]

在这里插入图片描述

中文情感分析模型

  • https://huggingface.co/IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

中文的RoBERTa-wwm-ext-base在数个情感分析任务微调后的版本

git clone https://huggingface.co/IDEA-CCNL/Erlangshen-Roberta-110M-Sentiment
from transformers import BertForSequenceClassification, BertTokenizer
import torch

# 加载预训练模型和分词器
tokenizer = BertTokenizer.from_pretrained('Erlangshen-Roberta-110M-Sentiment')
model = BertForSequenceClassification.from_pretrained('Erlangshen-Roberta-110M-Sentiment')

# 待分类的文本
text = '今天心情不好'

# 对文本进行编码并转换为张量,然后输入模型中
input_ids = torch.tensor([tokenizer.encode(text)])
output = model(input_ids)

# 对输出的logits进行softmax处理,得到分类概率
probabilities = torch.nn.functional.softmax(output.logits, dim=-1)

# 打印输出分类概率
print(probabilities)
  • 输出
tensor([[0.9551, 0.0449]], grad_fn=<SoftmaxBackward0>)
from transformers import pipeline

# 使用pipeline函数加载预训练的情感分析模型,并进行情感分析
classifier = pipeline("sentiment-analysis", model="Erlangshen-Roberta-110M-Sentiment")

# 对输入文本进行情感分析
result = classifier("今天心情很好")

# 打印输出结果
print(result)
  • 输出
[{'label': 'Positive', 'score': 0.9374911785125732}]
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline

# 加载预训练模型和分词器
model_path = "Erlangshen-Roberta-110M-Sentiment"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 创建情感分析的pipeline
classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer)

# 对文本进行情感分析
result = classifier("今天心情很好")
print(result)
  • 输出
[{'label': 'Positive', 'score': 0.9374911785125732}]

分词器 Tokenizer

from transformers import AutoTokenizer

# 加载预训练模型的分词器
tokenizer = AutoTokenizer.from_pretrained("Erlangshen-Roberta-110M-Sentiment")

# 对文本进行编码
encoded_input = tokenizer("今天心情很好")
print(encoded_input)

# 解码已编码的输入,还原原始文本
decoded_input = tokenizer.decode(encoded_input["input_ids"])
print(decoded_input)
  • 输出
{'input_ids': [101, 791, 1921, 2552, 2658, 2523, 1962, 102],
'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0],
'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}
[CLS] 今 天 心 情 很 好 [SEP]

填充 Padding

模型的输入需要具有统一的形状(shape)。

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Erlangshen-Roberta-110M-Sentiment")

batch_sentences = ["今天天气真好", "今天天气真好,适合出游"]
encoded_inputs = tokenizer(batch_sentences, padding=True)
print(encoded_inputs)
  • 输出
{'input_ids': [
[101, 791, 1921, 1921, 3698, 4696, 1962, 102, 0, 0, 0, 0, 0], 
[101, 791, 1921, 1921, 3698, 4696, 1962, 8024, 6844, 1394, 1139, 3952, 102]], 
'token_type_ids': [
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
], 
'attention_mask': [
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], 
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
]}

截断 Truncation

句子模型无法处理,可以将句子进行截断。

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Erlangshen-Roberta-110M-Sentiment")

batch_sentences = ["今天天气真好", "今天天气真好,适合出游"]
# return_tensors pt(PyTorch模型) tf(TensorFlow模型)
encoded_inputs = tokenizer(batch_sentences, padding=True, truncation=True, return_tensors="pt")
print(encoded_inputs)
  • 输出
{'input_ids': tensor([
   [ 101,  791, 1921, 1921, 3698, 4696, 1962,  102,    0,    0,   0,    0,    0],
   [ 101,  791, 1921, 1921, 3698, 4696, 1962, 8024, 6844, 1394, 1139, 3952, 102]
  ]), 
 'token_type_ids': tensor([
   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
   [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
  ]), 
  'attention_mask': tensor([
   [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
   [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
  ])
}

google-t5/t5-small

  • https://huggingface.co/google-t5/t5-small

在这里插入图片描述

Google的T5(Text-To-Text Transfer Transformer)是由Google Research开发的一种多功能的基于Transformer的模型。T5-small是T5模型的一个较小的变体,专为涉及自然语言理解和生成任务而设计。

  1. Transformer架构:与其它模型类似,T5-small采用了Transformer架构,该架构在各种自然语言处理(NLP)任务中表现出色。

  2. 多功能性:T5-small的设计理念是将所有的NLP任务都看作文本到文本的转换问题,使得模型可以通过简单地调整输入和输出来适应不同的任务。

  3. 预训练和微调:T5-small通常通过大规模的无监督预训练来学习通用的语言表示,然后通过有监督的微调来适应特定任务,如问答、摘要生成等。

  4. 应用广泛:由于其灵活性和性能,在各种NLP应用中都有广泛的应用,包括机器翻译、文本生成、情感分析等。

  • 下载 google-t5/t5-small 模型
# 模型大小 4.49G
git clone https://huggingface.co/google-t5/t5-small
  • 安装依赖库
pip install 'transformers[torch]'
pip install sentencepiece
  • 文本生成示例
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Step 1: 加载预训练的T5 tokenizer和模型
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

while True:
    # Step 2: 接收用户输入
    input_text = input("请输入要生成摘要的文本 (输入 'exit' 结束): ")
    
    if input_text.lower() == 'exit':
        print("程序结束。")
        break
    
    # 使用tokenizer对输入文本进行编码
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Step 3: 进行生成
    # 使用model.generate来生成文本
    output = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True)

    # Step 4: 解码输出
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # 打印输入和输出结果
    print("输入:", input_text)
    print("输出:", output_text)
    print("=" * 50)  # 分隔符,用来区分不同输入的输出结果

在这里插入图片描述

使用脚本进行训练

  • https://huggingface.co/docs/transformers/run_scripts

  • 从源代码安装 Transformers

git clone https://github.com/huggingface/transformers
cd transformers
pip install .
  • 将当前的 Transformers 克隆切换到特定版本
# 本地分支
git branch
# 远程分支
git branch -a
# 切换分支 v4.41.2,因为当前安装的版本是 v4.41.2
git checkout tags/v4.41.2
  • 安装依赖库
# 安装用于处理人类语言数据的工具集库
pip install nltk
# 安装用于计算ROUGE评估指标库
pip install rouge_score

Pytorch

示例脚本从 🤗 Datasets库下载并预处理数据集。然后,该脚本使用Trainer在支持摘要的架构上微调数据集。以下示例展示了如何在CNN/DailyMail数据集上微调T5-small。由于训练方式的原因,T5 模型需要额外的参数。此提示让 T5 知道这是一项摘要任务。

cd transformers/examples/pytorch/summarization
pip install -r requirements.txt
python run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

注意: 家用机上训练非常耗时,建议租用GPU服务器进行测试。

  • 数据缓存目录
# Linux/macOS
cd ~/.cache/huggingface
# Windows
C:\Users\{your_username}\.cache\huggingface
  • datasets
2.6G	cnn_dailymail
798M	downloads

机器翻译

数据集下载

  • https://huggingface.co/datasets/wmt/wmt16

在这里插入图片描述

数据集格式转换

pip install pandas
import pandas as pd
import jsonlines

# 输入和输出文件路径
input_parquet_file = './input_file.parquet'
output_jsonl_file = './output_file.jsonl'

# 加载 Parquet 文件
df = pd.read_parquet(input_parquet_file)

# 将数据写入 JSONLines 文件
with jsonlines.open(output_jsonl_file, 'w') as writer:
    for index, row in df.iterrows():
        json_record = {
            "source_text": row['source_column'],  # 替换成实际的源语言列名
            "target_text": row['target_column']   # 替换成实际的目标语言列名
        }
        writer.write(json_record)
  • train.jsonl
{ "cs": "Následný postup na základě usnesení Parlamentu: viz zápis", "en": "Action taken on Parliament's resolutions: see Minutes" }
  • validation.jsonl
{ "en": "UN Chief Says There Is No Military Solution in Syria", "ro": "Șeful ONU declară că nu există soluții militare în Siria" }
cd examples/pytorch/translation
pip install -r requirements.txt
python run_translation.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --source_lang en \
    --target_lang ro \
    --source_prefix "translate English to Romanian: " \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --train_file ./train.jsonl \
    --validation_file ./validation.jsonl \
    --output_dir /tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

注意: 家用机上训练非常耗时,建议租用GPU服务器进行测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1872152.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

上海亚商投顾:沪指震荡下跌 多只银行股创年内新高

上海亚商投顾前言&#xff1a;无惧大盘涨跌&#xff0c;解密龙虎榜资金&#xff0c;跟踪一线游资和机构资金动向&#xff0c;识别短期热点和强势个股。 一.市场情绪 三大指数昨日震荡调整&#xff0c;沪指尾盘跌近1%&#xff0c;深成指、创业板指均跌超1.5%。 板块概念方面&a…

六西格玛项目实战:数据驱动,手机PCM率直线下降

在当前智能手机市场日益竞争激烈的背景下&#xff0c;消费者对手机质量的要求达到了前所未有的高度。PCM&#xff08;可能指生产过程中的某种不良率或缺陷率&#xff09;作为影响手机质量的关键因素&#xff0c;直接关联到消费者满意度和品牌形象。为了应对这一挑战&#xff0c…

ICCV2023知识蒸馏相关论文速览

Paper1 Spatial Self-Distillation for Object Detection with Inaccurate Bounding Boxes 摘要原文: Object detection via inaccurate bounding box supervision has boosted a broad interest due to the expensive high-quality annotation data or the occasional inevit…

Windows CMD:快速入门

文字目录 一、概述二、常用命令2.1 切换盘符2.2 查看当前盘符下的所有文件2.3 进入单级目录2.4 返回上一级目录2.5 进入多级目录2.6 回到盘符目录2.7 清屏2.8 退出 三、练习 一、概述 CMD 是 Command的缩写&#xff0c;即命令的意思&#xff0c;它的作用是利用命令的方式来操作…

Linux-笔记 使用SCP命令传输文件报错 :IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY!

前言 使用scp命令向开发板传输文件发生报错&#xff0c;报错见下图; 解决 rm -rf /home/<用户名>/.ssh/known_hosts 此方法同样适用于使用ssh命令连接开发板报错的情况。 参考 https://blog.csdn.net/westsource/article/details/6636096

大模型赋能全链路可观测性:运维效能的革新之旅

目录 全链路可观测工程与大模型结合---提升运维效能 可观测性&#xff08;Observability&#xff09;在IT系统中的应用及其重要性 统一建设可观测数据 统一建设可观测数据的策略与流程 全链路的构成和监控形态 云上的全链路可视方案 为什么一定是Copilot 大模型的Copilo…

基于iview.viewUI实现行合并(无限制/有限制合并)【已验证可正常运行】

1.基于iview.viewUI实现行合并&#xff08;列之间没有所属对应关系&#xff0c;正常合并&#xff09; 注&#xff1a;以下代码来自于GPT4o&#xff1a;国内直连GPT4o 只需要修改以下要合并的列字段&#xff0c;就可以方便使用啦 mergeFields: [majorNo, devNam, overhaulAdvic…

【EXCEL技巧】Excel如何将数字前面的0去掉

Excel文件中经常会遇到数据是0001345这种&#xff0c;那么&#xff0c;如何将数字前面的0去掉呢&#xff1f;今天和大家分享方法。 首先&#xff0c;选中一列空的单元格&#xff0c;然后在单元格中输入公式TEXT(D3,0)&#xff0c;这里的D3指的是前面带有0的数据的位置 回车之后…

Linux基础- 使用 Apache 服务部署静态网站

目录 零. 简介 一. linux安装Apache 二. 创建网页 三. window访问 修改了一下默认端口 到 8080 零. 简介 Apache 是世界使用排名第一的 Web 服务器软件。 它具有以下一些显著特点和优势&#xff1a; 开源免费&#xff1a;可以免费使用和修改&#xff0c;拥有庞大的社区支…

小程序备案小程序认证双系统

​打造安全合规的线上平台 &#x1f50d; 一、引言&#xff1a;为何需要小程序备案与认证&#xff1f; 在数字化快速发展的今天&#xff0c;小程序已成为企业、个人展示自身、提供服务的重要窗口。然而&#xff0c;随着小程序数量的快速增长&#xff0c;安全、合规等问题也逐渐…

jenkins设置定时构建语法

一、设置定时 定时构建的语法是*** * * * ***。 第一个*表示分钟&#xff0c;取值范围是0~59。例如&#xff0c;5 * * * *表示每个小时的第5分钟会构建一次&#xff1b;H/15 * * * 或/15 * * * 表示每隔15分钟构建一次&#xff1b; 第2个表示小时&#xff0c;取值范围是0~23。…

深度解析RocketMq源码-IndexFile

1.绪论 在工作中&#xff0c;我们经常需要根据msgKey查询到某条日志。但是&#xff0c;通过前面对commitLog分析&#xff0c;producer将消息推送到broker过后&#xff0c;其实broker是直接消息到达broker的先后顺序写入到commitLog中的。我们如果想根据msgKey检索一条消息无疑…

Embedding 、词嵌入、向量模型说的是一回事么?AI是如何理解世界?AI人不能不看的Embedding白话科普!

在AI理解世界的过程中&#xff0c;向量模型扮演着一个至关重要的角色&#xff0c;甚至可以说它是AI大模型用以构建和理解复杂数据的基础&#xff0c;也是对不同形态数据的一种标准化的“浓缩”。它能够将语言、图像、声音等多样化的信息&#xff0c;转化为一种通用的、数学化的…

知乎正通过乱码来干扰必应/谷歌等爬虫,从而限制中文数据集被用于AI训练

有用户反馈称使用微软必应搜索和谷歌搜索发现存在不少知乎乱码内容&#xff0c;即搜索结果里知乎内容的标题和正文内容都可能是乱码的&#xff0c;但抓取的正文前面一些段落内容可以正常查看。考虑到此前知乎已经屏蔽除百度和搜狗以外的所有搜索引擎爬虫 (蜘蛛 / 机器人)&#…

《数字图像处理与机器视觉》案例二(基于边缘检测和数学形态学焊缝图像处理)

一、前言 焊缝是评价焊接质量的重要标志&#xff0c;人工检测方法存在检测标准不统一&#xff0c;检测精度低&#xff0c;焊缝视觉检测技术作为一种重要的质量检测方法&#xff0c;正逐渐在各行各业中崭露头角。把焊缝准确的从焊接工件中准确分割出来是焊缝评价的关键一步&…

使用模板方法设计模式封装 socket 套接字并实现Tcp服务器和客户端 简单工厂模式设计

文章目录 使用模板方法设计模式封装套接字使用封装后的套接字实现Tcp服务器和客户端实现Tcp服务器实现Tcp客户端 工厂模式 使用模板方法设计模式封装套接字 可以使用模块方法设计模式来设计套接字 socket 的封装 模板方法&#xff08;Template Method&#xff09;设计模式是一…

百度ueditor如何修改图片的保存位置

背景 编辑器的保存图片是设置有默认规则的&#xff0c;但是服务器上一般会把图片路径设置为软连接&#xff0c;所以我就需要更改编辑器保存图片的路径&#xff0c;要不然&#xff0c;每次有新的部署&#xff0c;上一次上传的图片就会失效。先来看看编辑器默认的保存路径吧&…

目标检测算法之RT-DETR

RT-DETR算法理解 BackgroundModel ArchitectureEfficient Hybrid EncoderUncertainty-minimal Query Selection 总结 Background Real-time Detection Transformer&#xff08;RT-DETR&#xff09;是一个基于tranformer的实时推理目标检测模型。RT-DETR是2023年百度发布的一个…

七天速通javaSE:第五天 数组进阶

文章目录 前言一、二维数组二、Arrays类1.toString打印数组内各元素1.1 示例1.2 自己实现内部逻辑 2. sort升序排列3. fill数组填充&#xff08;重新赋值&#xff09;4.equals比较数组元素是否相等 三、冒泡排序 前言 本文将学习二维数组、arrays类以及冒泡排序 一、二维数组 …

重生奇迹MU新手攻略:如何一步步往大佬发展

装备强化攻略&#xff1a; 提纯装备&#xff1a;通过提纯装备可以提升基础属性&#xff0c;选择合适的装备进行提纯可以获得更好的效果。 镶嵌宝石&#xff1a;使用宝石进行装备镶嵌可以增加装备的属性&#xff0c;根据需要选择适合的宝石进行镶嵌。 洗练装备&#xff1a;通…