Transformers 安装及 google-t5/t5-small 机器翻译示例

news2024/12/23 12:12:33

文章目录

  • Github
  • 文档
  • 推荐文章
  • 简介
  • 安装
  • 官方示例
  • google-t5/t5-small
  • 使用脚本进行训练
    • Pytorch
  • 机器翻译
    • 数据集下载
    • 数据集格式转换

Github

  • https://github.com/huggingface/transformers

文档

  • https://huggingface.co/docs/transformers/index
  • https://github.com/huggingface/transformers/blob/main/i18n/README_zh-hans.md

推荐文章

  • http://jalammar.github.io/illustrated-transformer/

简介

Transformers是一种基于注意力机制(Attention Mechanism)的神经网络模型,广泛应用于自然语言处理(Natural Language Processing)任务中,如机器翻译、文本生成和文本分类等。

传统的序列模型(如循环神经网络)在处理长距离依赖时可能遇到困难,而Transformers通过引入注意力机制来解决这个问题。注意力机制使得模型能够在序列中对不同位置的信息进行加权关注,从而捕捉到全局的上下文信息。

在Transformers中,输入序列首先被分别编码为查询(Query)、键(Key)和值(Value)向量。通过计算查询与键的相似度,得到注意力分数,再将注意力分数与值相乘并加权求和,即可得到最终的上下文表示。这种自注意力机制允许模型在编码器和解码器中自由交换信息,从而更好地处理长距离依赖关系。

Transformer模型的核心组件是多层的自注意力机制和前馈神经网络。它的架构被广泛应用于许多重要的NLP任务,其中最著名的是BERT(Bidirectional Encoder Representations from Transformers),它在多项NLP任务上取得了突破性的性能。

除了NLP领域,Transformers模型也被应用于计算机视觉和其他领域,用于处理序列建模和生成任务。它已经成为深度学习中非常重要和有影响力的模型架构之一。

安装

pip install transformers
# PyTorch(推荐)
pip install 'transformers[torch]'
# TensorFlow 2.0
pip install 'transformers[tf-cpu]'
  • M1 / ARM 用户在安装 TensorFLow 2.0 之前,需要安装以下内容
brew install cmake
brew install pkg-config
  • 验证是否安装成功
python -c "from transformers import pipeline; print(pipeline('sentiment-analysis')('we love you'))"

在这里插入图片描述

注意: 以上验证操作需要“连网”,否则因无法下载文件而出现报错。

官方示例

from transformers import pipeline

# 使用情绪分析流水线
classifier = pipeline('sentiment-analysis')
classifier('We are very happy to introduce pipeline to the transformers repository.')
  • 输出结果
[{'label': 'POSITIVE', 'score': 0.9996980428695679}]

在这里插入图片描述

google-t5/t5-small

  • https://huggingface.co/google-t5/t5-small

在这里插入图片描述

Google的T5(Text-To-Text Transfer Transformer)是由Google Research开发的一种多功能的基于Transformer的模型。T5-small是T5模型的一个较小的变体,专为涉及自然语言理解和生成任务而设计。

  1. Transformer架构:与其它模型类似,T5-small采用了Transformer架构,该架构在各种自然语言处理(NLP)任务中表现出色。

  2. 多功能性:T5-small的设计理念是将所有的NLP任务都看作文本到文本的转换问题,使得模型可以通过简单地调整输入和输出来适应不同的任务。

  3. 预训练和微调:T5-small通常通过大规模的无监督预训练来学习通用的语言表示,然后通过有监督的微调来适应特定任务,如问答、摘要生成等。

  4. 应用广泛:由于其灵活性和性能,在各种NLP应用中都有广泛的应用,包括机器翻译、文本生成、情感分析等。

  • 下载 google-t5/t5-small 模型
# 模型大小 4.49G
git clone https://huggingface.co/google-t5/t5-small
  • 安装依赖库
pip install 'transformers[torch]'
pip install sentencepiece
  • 文本生成示例
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Step 1: 加载预训练的T5 tokenizer和模型
tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

while True:
    # Step 2: 接收用户输入
    input_text = input("请输入要生成摘要的文本 (输入 'exit' 结束): ")
    
    if input_text.lower() == 'exit':
        print("程序结束。")
        break
    
    # 使用tokenizer对输入文本进行编码
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids

    # Step 3: 进行生成
    # 使用model.generate来生成文本
    output = model.generate(input_ids, max_length=50, num_beams=4, early_stopping=True)

    # Step 4: 解码输出
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)

    # 打印输入和输出结果
    print("输入:", input_text)
    print("输出:", output_text)
    print("=" * 50)  # 分隔符,用来区分不同输入的输出结果

在这里插入图片描述

使用脚本进行训练

  • https://huggingface.co/docs/transformers/run_scripts

  • 从源代码安装 Transformers

git clone https://github.com/huggingface/transformers
cd transformers
pip install .
  • 将当前的 Transformers 克隆切换到特定版本
# 本地分支
git branch
# 远程分支
git branch -a
# 切换分支 v4.41.2,因为当前安装的版本是 v4.41.2
git checkout tags/v4.41.2
  • 安装依赖库
# 安装用于处理人类语言数据的工具集库
pip install nltk
# 安装用于计算ROUGE评估指标库
pip install rouge_score

Pytorch

示例脚本从 🤗 Datasets库下载并预处理数据集。然后,该脚本使用Trainer在支持摘要的架构上微调数据集。以下示例展示了如何在CNN/DailyMail数据集上微调T5-small。由于训练方式的原因,T5 模型需要额外的参数。此提示让 T5 知道这是一项摘要任务。

cd transformers/examples/pytorch/summarization
pip install -r requirements.txt
python run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

注意: 家用机上训练非常耗时,建议租用GPU服务器进行测试。

  • 数据缓存目录
# Linux/macOS
cd ~/.cache/huggingface
# Windows
C:\Users\{your_username}\.cache\huggingface
  • datasets
2.6G	cnn_dailymail
798M	downloads

机器翻译

数据集下载

  • https://huggingface.co/datasets/wmt/wmt16

在这里插入图片描述

数据集格式转换

pip install pandas
import pandas as pd
import jsonlines

# 输入和输出文件路径
input_parquet_file = './input_file.parquet'
output_jsonl_file = './output_file.jsonl'

# 加载 Parquet 文件
df = pd.read_parquet(input_parquet_file)

# 将数据写入 JSONLines 文件
with jsonlines.open(output_jsonl_file, 'w') as writer:
    for index, row in df.iterrows():
        json_record = {
            "source_text": row['source_column'],  # 替换成实际的源语言列名
            "target_text": row['target_column']   # 替换成实际的目标语言列名
        }
        writer.write(json_record)
  • train.jsonl
{ "cs": "Následný postup na základě usnesení Parlamentu: viz zápis", "en": "Action taken on Parliament's resolutions: see Minutes" }
  • validation.jsonl
{ "en": "UN Chief Says There Is No Military Solution in Syria", "ro": "Șeful ONU declară că nu există soluții militare în Siria" }
cd examples/pytorch/translation
pip install -r requirements.txt
python run_translation.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --source_lang en \
    --target_lang ro \
    --source_prefix "translate English to Romanian: " \
    --dataset_name wmt16 \
    --dataset_config_name ro-en \
    --train_file ./train.jsonl \
    --validation_file ./validation.jsonl \
    --output_dir /tmp/tst-translation \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --overwrite_output_dir \
    --predict_with_generate

注意: 家用机上训练非常耗时,建议租用GPU服务器进行测试。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1871528.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

《昇思25天学习打卡营第1天|基本介绍》

文章目录 前言:今日所学: 前言: 今天非常荣幸的收到了昇思25天学习打卡营的邀请。昇思MindSpore作为华为昇腾AI全栈的重要一员,他支持端、边、云独立的和协同的统一训练和推理框架,有着易于开发、执行效率高、全场景框…

以Bert训练为例,测试torch不同的运行方式,并用torch.profile+HolisticTraceAnalysis分析性能瓶颈

以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTraceAnalysis分析性能瓶颈 1.参考链接:2.性能对比3.相关依赖或命令4.测试代码5.HolisticTraceAnalysis代码6.可视化A.优化前B.优化后 以Bert训练为例,测试torch不同的运行方式,并用torch.profileHolisticTra…

深入剖析 Android 网络开源库 Retrofit 的源码详解

文章目录 概述一、Retrofit 简介Android主流网络请求库 二、Retrofit 源码剖析1. Retrofit 网络请求过程2. Retrofit 实例构建2.1 Retrofit.java2.2 Retrofit.Builder()2.2.1 Platform.get()2.2.2 Android 平台 2.3 Retrofit.Builder().baseUrl()2.4 Retrofit.Builder.client()…

Windows的内核对象

内核对象句柄特定于进程。 也就是说,进程必须创建 对象或打开现有对象以获取内核对象句柄。 内核句柄上的每个进程限制为 2^24。 但是,句柄存储在分页池中,因此可以创建的实际句柄数取决于可用内存。 可以在 32 位 Windows 上创建的句柄数明显低于 2^24。 任何进程都可以为…

Golang | Leetcode Golang题解之第201题数字范围按位与

题目&#xff1a; 题解&#xff1a; func rangeBitwiseAnd(m int, n int) int {for m < n {n & (n - 1)}return n }

Linux技能篇-恢复lvm物理卷

项目场景&#xff1a; 今天遇到一个很有意思的故障&#xff0c;我用虚拟机来还原了当前的故障场景。 首先来看&#xff0c;系统中只有一个lvn卷组 我们给系统中添加一块磁盘&#xff0c;使用pvcreate创建物理卷 pvcreate /dev/sdb并将容量添加到当前的卷组中 创建一个lvm逻辑…

基于Spring Boot医护人员排班系统

设计技术&#xff1a; 开发语言&#xff1a;Java数据库&#xff1a;MySQL技术&#xff1a;SpringbootMybatisvue 工具&#xff1a;IDEA、Maven、Navicat 主要功能&#xff1a; 医护类型管理 医护人员排班系统的系统管理员可以对医护类型添加修改删除以及查询操作。具体界面…

Opencv学习项目7——face_recognition

前面两篇博客解决了安装dlib库的问题和numpy和dlib不兼容的问题&#xff0c;今天开始做人脸识别第一个项目 我们可以从网上下载一张带有人脸的图片或者自己电脑有的也可以&#xff0c;我这里使用lyf的图片进行演示 加载图像文件 img1 face_recognition.load_image_file(lyf1.…

mac菜单栏应用管理软件:Bartender 4 for Mac 中文激活版

Bartender 4 是一款由Bearded Men Games开发的适用于Mac操作系统的应用程序&#xff0c;它被设计用来优化和美化Mac菜单栏的功能。自从macOS Big Sur开始&#xff0c;Mac的菜单栏可以自定义&#xff0c;用户可以添加和移除各种图标。Bartender 4就是在这个背景下应运而生&#…

论文阅读Vlogger: Make Your Dream A Vlog

摘要 论文介绍了一个名为“Vlogger”的通用人工智能系统&#xff0c;它能够根据用户的描述生成分钟级的视频博客&#xff08;vlog&#xff09;。与通常只有几秒钟的短视频不同&#xff0c;vlog通常包含复杂的故事情节和多样化的场景&#xff0c;这对现有的视频生成方法来说是一…

CPPTest设计分析

目录 1 概述2 设计3 扩展Output3.1 扩展实例 1 概述 CppTest是一个可移植、功能强大但简单的单元测试框架&#xff0c;用于处理C中的自动化测试。重点在于可用性和可扩展性。支持多种输出格式&#xff0c;并且可以轻松添加新的输出格式。 CppTest下载地址Sourceforge Github地…

django学习入门系列之第三点《伪类简单了解》

文章目录 hover&#xff08;伪类&#xff09;after&#xff08;伪类&#xff09;往期回顾 hover&#xff08;伪类&#xff09; 伪类指的是用冒号加的 hover样式指的是&#xff0c;当用户光标移动到设定区域后&#xff0c;所执行的用法 如&#xff1a; <!DOCTYPE html>…

视频网站系统

摘 要 随着互联网的快速发展和人们对视频内容的需求增加&#xff0c;视频网站成为了人们获取信息和娱乐的重要平台。本论文基于SpringBoot框架&#xff0c;设计与实现了一个视频网站系统。首先&#xff0c;通过对国内外视频网站发展现状的调研&#xff0c;分析了视频网站的背景…

静态资源服务器

上一章【认识 MIME 和 HTTP】。 我们认识和了解了 MIME 的概念和作用&#xff0c;也简单地学习了通过浏览器控制台查看请求和返回的用法。 通过对不同的 HTML、CSS、JS 文件进行判断&#xff0c;设置不同的 MIME 值&#xff0c;得以让我们的浏览器正正确地接收和显示不同的文…

2-18 基于matlab的关于联合对角化盲源分离算法的二阶盲识别(SOBI)算法

基于matlab的关于联合对角化盲源分离算法的二阶盲识别&#xff08;SOBI&#xff09;算法。通过联合对角化逼近解混矩阵。构建的四组信号&#xff0c;并通过认为设置添加噪声比例&#xff0c;掩盖信号信息。通过SOBI算法实现了解混。程序已调通&#xff0c;可直接运行。 2-18联合…

技术速递|Visual Studio Code 的 .NET MAUI 扩展现已正式发布

作者&#xff1a;Maddy Montaquila 排版&#xff1a;Alan Wang 今天&#xff0c;我们非常高兴地宣布 .NET MAUI VS Code 扩展插件结束了预览阶段&#xff0c;并将包含一些期待已久的新功能 - 包括 XAML IntelliSense 和 Hot Reload&#xff01; 什么是 .NET MAUI 扩展插件&…

成功解决​​​​​​​TypeError: __call__() got an unexpected keyword argument ‘first_int‘

成功解决TypeError: __call__() got an unexpected keyword argument first_int 目录 解决问题 解决思路 解决方法 T1、直接调用原始函数 T2、检查装饰器实现 T3、使用不同的调用方式 解决问题 result = multiply(**arguments) File "D:\ProgramData\Anaconda3\Li…

BFS:队列+树的宽搜

一、二叉树的层序遍历 . - 力扣&#xff08;LeetCode&#xff09; 该题的层序遍历和以往不同的是需要一层一层去遍历&#xff0c;每一次while循环都要知道在队列中节点的个数&#xff0c;然后用一个for循环将该层节点走完了再走下一层 class Solution { public:vector<vec…

【教程】简介nccl-test工具

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ GitHub - NVIDIA/nccl-tests: NCCL TestsNCCL Tests. Contribute to NVIDIA/nccl-tests development by creating an account on GitHub.https://githu…