【Transformers基础入门篇2】基础组件之Pipeline

news2024/11/23 15:22:47

文章目录

  • 一、什么是Pipeline
  • 二、查看PipeLine支持的任务类型
  • 三、Pipeline的创建和使用
    • 3.1 根据任务类型,直接创建Pipeline,默认是英文模型
    • 3.2 指定任务类型,再指定模型,创建基于指定模型的Pipeline
    • 3.3 预先加载模型,再创建Pipeline
    • 3.4 使用Gpu进行推理
    • 3.5 查看Device
    • 3.6 测试一下耗时
    • 3.7 确定的Pipeline的参数
  • 四、Pipeline的背后实现


本文为 https://space.bilibili.com/21060026/channel/collectiondetail?sid=1357748的视频学习笔记

项目地址为:https://github.com/zyds/transformers-code


一、什么是Pipeline

  • 将数据预处理、模型调用、结果后处理三部分组装成的流水线,如下流程图
  • 使我们能够直接输入文本便获得最终的答案,不需要我们关注细节
ToKenizer
Model
PostProcessing
Raw text
Input IDs
Logits
Predictions
我觉得不太行
101, 2769, 6230, 2533, 679, 1922, 6121, 8013, 102
0.9736, 0.0264
Positive:0.9736

二、查看PipeLine支持的任务类型

from transformers.pipelines import SUPPORTED_TASKS
from pprint import pprint
for k, v in SUPPORTED_TASKS.items():
    print(k, v)

输出但其概念PipeLine支持的任务类型以及可以调用的
举例输出:

audio-classification {'impl': <class 'transformers.pipelines.audio_classification.AudioClassificationPipeline'>, 'tf': (), 'pt': (<class 'transformers.models.auto.modeling_auto.AutoModelForAudioClassification'>,), 'default': {'model': {'pt': ('superb/wav2vec2-base-superb-ks', '372e048')}}, 'type': 'audio'}
  • key: 任务的名称,如音频分类
  • v:关于任务的实现,如具体哪个Pipeline,有没有TF模型,有没有pytorch模型, 模型具体是哪一个
    在这里插入图片描述

三、Pipeline的创建和使用

3.1 根据任务类型,直接创建Pipeline,默认是英文模型

from transformers import pipeline
pipe = pipeline("text-classification") # 根据pipeline直接创建一个任务类
pipe("very good") # 测试一个句子,输出结果

3.2 指定任务类型,再指定模型,创建基于指定模型的Pipeline

注,这里我已经将模型离线下载到本地了

# https://huggingface.co/models
pipe = pipeline("text-classification", model="./models/roberta-base-finetuned-dianping-chinese")

3.3 预先加载模型,再创建Pipeline

rom transformers import AutoModelForSequenceClassification, AutoTokenizer

# 这种方式,必须同时指定model和tokenizer
model = AutoModelForSequenceClassification.from_pretrained("./models_roberta-base-finetuned-dianping-chinese")
tokenizer = AutoTokenizer.from_pretrained("./models_roberta-base-finetuned-dianping-chinese")
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)

3.4 使用Gpu进行推理

pipe = pipeline("text-classification", model="./models_roberta-base-finetuned-dianping-chinese", device=0)

3.5 查看Device

pipe.model.device

3.6 测试一下耗时

import torch
import time
times = []
for i in range(100):
    torch.cuda.synchronize()
    start = time.time()
    pipe("我觉得不太行!")
    torch.cuda.synchronize()
    end = time.time()
    times.append(end - start)
print(sum(times) / 100)

3.7 确定的Pipeline的参数

# 先创建一个pipeline
qa_pipe = pipeline("question-answering", model="../../models/models")
qa_pipe

输出
在这里插入图片描述QuestionAnsweringPipeline
在这里插入图片描述
查看定义,会告诉我们这个pipeline该如何使用

class QuestionAnsweringPipeline(ChunkPipeline):
    """
    Question Answering pipeline using any `ModelForQuestionAnswering`. See the [question answering
    examples](../task_summary#question-answering) for more information.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> oracle = pipeline(model="deepset/roberta-base-squad2")
    >>> oracle(question="Where do I live?", context="My name is Wolfgang and I live in Berlin")
    {'score': 0.9191, 'start': 34, 'end': 40, 'answer': 'Berlin'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This question answering pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task. See the
    up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=question-answering).
    """

进入pipeline,看__call__,查看可以支持的更多的参数
列出了更多的参数

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            args ([`SquadExample`] or a list of [`SquadExample`]):
                One or several [`SquadExample`] containing the question and context.
            X ([`SquadExample`] or a list of [`SquadExample`], *optional*):
                One or several [`SquadExample`] containing the question and context (will be treated the same way as if
                passed as the first positional argument).
            data ([`SquadExample`] or a list of [`SquadExample`], *optional*):
                One or several [`SquadExample`] containing the question and context (will be treated the same way as if
                passed as the first positional argument).
            question (`str` or `List[str]`):
                One or several question(s) (must be used in conjunction with the `context` argument).
            context (`str` or `List[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                `question` argument).
            topk (`int`, *optional*, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood). Note that we return less than
                topk answers if there are not enough options available within the context.
            doc_stride (`int`, *optional*, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split in several chunks
                with some overlap. This argument controls the size of that overlap.
            max_answer_len (`int`, *optional*, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (`int`, *optional*, defaults to 384):
                The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
                model. The context will be split in several chunks (using `doc_stride` as overlap) if needed.
            max_question_len (`int`, *optional*, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (`bool`, *optional*, defaults to `False`):
                Whether or not we accept impossible as an answer.
            align_to_words (`bool`, *optional*, defaults to `True`):
                Attempts to align the answer to real words. Improves quality on space separated langages. Might hurt on
                non-space-separated languages (like Japanese or Chinese)

        Return:
            A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:

            - **score** (`float`) -- The probability associated to the answer.
            - **start** (`int`) -- The character start index of the answer (in the tokenized version of the input).
            - **end** (`int`) -- The character end index of the answer (in the tokenized version of the input).
            - **answer** (`str`) -- The answer to the question.
        """

如下面的例子

我们输出问题:中国的首都是哪里? 给的上下文是:中国的首都是北京

qa_pipe(question="中国的首都是哪里?", context="中国的首都是北京")

在这里插入图片描述

如果通过 max_answer_len参数来限定输出的最大长度,会进行强行截断

qa_pipe(question="中国的首都是哪里?", context="中国的首都是北京", max_answer_len=1)

在这里插入图片描述

四、Pipeline的背后实现

  • step1 初始化组件,Tokenizer,model
# step1 初始化tokenizer, model
tokenizer = AutoTokenizer.from_pretrained("../../models/models_roberta-base-finetuned-dianping-chinese")
model = AutoModelForSequenceClassification.from_pretrained("../../models/models_roberta-base-finetuned-dianping-chinese")
  • step2 预处理
# 预处理,返回pytorch的tensor,是一个dict
input_text = "我觉得不太行!"
inputs = tokenizer(input_text, return_tensors="pt")
inputs

在这里插入图片描述

  • step3 模型预测
res = model(**inputs)
res

在这里插入图片描述
预测的结果,包括的内容有点多,如loss,logits等

  • step4 结果后处理
logits = res.logits
logits = torch.softmax(logits, dim=-1)
pred = torch.argmax(logits).item()
result = model.config.id2label.get(pred)
result

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2161122.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

用二维码收集信息时,在后台可以查看、统计哪些数据?

大家都知道&#xff0c;在二维码上关联表单&#xff0c;就可以扫码填写信息了。那么&#xff0c;收集到的数据在哪里查看&#xff1f;具体可以查看到哪些数据呢&#xff1f; 如果是用草料二维码平台搭建的二维码&#xff0c;前往后台&#xff0c;在表单列表中找到对应的表单&a…

智能Ai语音机器人的应用价值有哪些?

随着时间的推移&#xff0c;人工智能的发展越来越成熟&#xff0c;智能时代也离人们越来越近&#xff0c;近几年人工智能越来越火爆&#xff0c;人工智能的应用已经开始渗透到各行各业&#xff0c;与生活交融&#xff0c;成为人们无法拒绝&#xff0c;无法失去的一个重要存在。…

【解密 Kotlin 扩展函数】命名参数和默认值(十三)

导读大纲 1.0.1 命名参数1.0.2 默认参数值 上一节讲述如何自定义 joinToString 函数来代替集合的默认字符串表示 文末遗留下几个待优化问题–传送门 1.0.1 命名参数 我们要解决的第一个问题涉及函数调用的可读性 例如,请看下面的joinToString调用: joinToString(collection,&…

MyBatis深度剖析:从入门到精通的实践指南

前言 什么是mybatis&#xff1f; MyBatis是一款优秀的持久层框架&#xff0c;用于简化Java应用程序与数据库之间的交互 什么是框架&#xff0c;为什么需要框架技术&#xff1f; 框架技术 是一个应用程序的半成品提供可重用的公共结构按一定规则组织的一组组件框架优势&#x…

【自动化测试】Appium 生态工具以及Appium Desktop如何安装和使用

引言 Appium 是一个开源的自动化测试框架&#xff0c;用于测试原生、移动 Web 和混合应用程序。它支持 iOS、Android 和 Windows 平台。Appium 生态系统包含多个工具和库&#xff0c;这些工具和库可以与 Appium 一起使用&#xff0c;以提高移动应用的自动化测试效率 文章目录 引…

Java面试指南(基础篇)

文章目录 前言01 Java语言的特点02 JVM、JRE及JDK的关系03 Java和C的区别04 基本数据类型05 类型转换06 自动装箱与拆箱07 String的不可变性08 字符常量和字符串常量的区别09 字符串常量池10 String 类的常用方法11 String和StringBuffer、StringBuilder的区别12 switch 是否能…

舒服了!学大模型必看的学习书籍来了

最近整理了日前市面上一大波大模型的书&#xff0c;已经打包成pdf了&#xff0c;大家有需要的&#xff0c;可以自行添加获取&#xff0c;纯福利&#xff0c;无套路&#xff0c;添加后说明是哪本书&#xff0c;会直接给大家&#xff01;&#xff08;文末获取&#xff09; 部分书…

IO 多路转接之 epoll

文章目录 IO 多路转接之 epoll1、IO 多路转接之 poll1.1、poll 函数1.2、poll 函数返回值1.3、Socket 就绪条件1.3.1、读就绪1.3.2、写就绪1.3.3、异常就绪 1.4、poll 的优点1.5、poll 的缺点1.6、poll 改写 select 2、IO 多路转接之 epoll2.1、epoll 函数2.2、epoll_create2.3…

Leetcode 反转链表

使用递归 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}* ListNode(int val) { this.val val; }* ListNode(int val, ListNode next) { this.val val; this.next next; }* }*/ class S…

超低排放燃气锅炉

在全球环保浪潮的推动下&#xff0c;超低排放燃气锅炉以其卓越的环保性能和高效能源利用&#xff0c;正逐渐成为现代热能供应的主力军。作为传统锅炉的升级版&#xff0c;超低排放燃气锅炉不仅在技术上实现了质的飞跃&#xff0c;更在环保和节能方面树立了新的标杆。朗观视觉小…

linux入门到实操-10 控制台显示和输出重定向、监控文件变化、软连接

教程来源&#xff1a;B站视频BV1WY4y1H7d3 3天搞定Linux&#xff0c;1天搞定Shell&#xff0c;清华学神带你通关_哔哩哔哩_bilibili 整理汇总的课程内容笔记和课程资料&#xff08;包含课程同版本linux系统文件等内容&#xff09;&#xff0c;供大家学习交流下载&#xff1a;…

【Delphi】扩展现有组件创建新的 FireMonkey 组件(步骤一)

本例中演示将TLabel控件扩展成TClockLabel新控件。具体如下&#xff1a; 步骤 1 - 使用新建组件向导创建组件 1. 菜单选择 Component -> New Component。 2. 在新建组件向导的第一页&#xff0c;选择 FireMonkey for Delphi &#xff1a; 3. 在 “Ancestor Component ”页…

【最新华为OD机试E卷-支持在线评测】爱吃蟠桃的孙悟空(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 💻 ACM金牌🏅️团队 | 大厂实习经历 | 多年算法竞赛经历 ✨ 本系列打算持续跟新华为OD-E/D卷的多语言AC题解 🧩 大部分包含 Python / C / Javascript / Java / Cpp 多语言代码 👏 感谢大家的订阅➕ 和 喜欢�…

解决windows上VMware的ubuntu虚拟机不能拷贝和共享

困扰多时的VMware虚拟机不能复制拷贝和不能看到共享文件夹的问题&#xff0c;终于解决了~ 首先确定你已经开启了复制拷贝和共享文件夹&#xff0c;并且发现不好用。。。 按照下面方式解决这个问题。 1&#xff0c;删除当前的vmware tools。 sudo apt-get remove --purge ope…

【Redis技术进阶之路】「原理分析系列开篇」揭秘分析客户端和服务端网络通信交互实现(客户端篇)

揭秘高效存储模型与数据结构底层实现 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 客户端和服务器Redis服务器IO多路复用RedisClient结构 客户端属性分析套接字描述符客户端的分类伪客户端…

【二十五】【QT开发应用】无边窗窗口鼠标拖动窗口移动,重写mousePressEvent,mouseMoveEvent函数

在 Qt 中&#xff0c;可以通过在自定义的类中重载 mousePressEvent 和 mouseMoveEvent 函数来捕获鼠标按下和移动事件&#xff0c;以便实现例如拖动窗口等功能。 mousePressEvent 和 mouseMoveEvent分别是鼠标按下事件和鼠标移动事件。这两个函数是QT中本身就存在的函数&#…

prithvi WxC气象模型

NASA发布了prithvi WxC气象模型发布 Prithvi是NASA开源的模型&#xff0c;被誉为全球最大的开源地理空间大模型。昨天晚上逛X平台&#xff0c;我看到Prithvi模型又来了新成员&#xff1a;prithvi WxC。 NASA和IBM创建了一个基于MERRA-2数据的天气和气候AI基础模型—Prithvi Wx…

C++ :借助栈完成二叉树的非递归遍历

二叉树的传统访问分为&#xff1a;前序、中序、后序、层序。 其中前三者是递归访问&#xff0c;但是递归是有缺陷的&#xff0c;树太深就会栈溢出。 因此本文我们思考如何使用非递归的方法来完成遍历。 1. 前序遍历 要迭代⾮递归实现⼆叉树前序遍历&#xff0c;⾸先还是要借…

【计算机组成原理】实验一:运算器输入锁存器数据写实验

目录 实验要求 实验目的 主要集成电路芯片及其逻辑功能 实验原理 实验内容及步骤 实验内容 思考题 实验要求 利用CP226实验箱上的K16&#xff5e;K23二进制拨动开关作为DBUS数据输入端&#xff0c;其它开关作为控制信号的输入端&#xff0c;将通过K16&#xff5e;K23设定…

无人经济已经 next level 了吗?

01 从无人售货机开始… 晚上 11 点下班回到小区&#xff0c;顺便去驿站取个快递&#xff0c;走进驿站发现四周空无一人&#xff0c;把快递放在机器上滴一声就可以走人了。走的时候在旁边的无人超市里拿一袋方便面&#xff0c;当做加班的安慰……发现了吗&#xff0c;无人门店…