使用预训练模型自动续写文本的四种方法

news2025/1/16 21:18:10

作者:皮皮雷 来源:投稿
编辑:学姐

这篇文章以中文通用领域文本生成为例,介绍四种常用的模型调用方法。在中文文本生成领域,huggingface上主要有以下比较热门的pytorch-based预训练模型:

本文用到了其中的uer/gpt2-chinese-cluecorpussmall和hfl/chinese-xlnet-base,它们都是在通用领域文本上训练的。

但是要注意有些模型(如CPM-Generate共有26亿参数)模型文件较大,GPU有限的情况下可能会OOM。

依赖包:transformers 4

本文使用的例句来源于豆瓣爬下的部分书评。

方法1:transformers.pipline

简介:

直接调用transformers里面的pipline。

源码及参数选择参考:

https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/pipelines#transformers.pipeline

缺点:不能以batch形式生成句子,不能并行,大规模调用的时候时间复杂度较高。

from transformers import pipeline

#this pipline can only generate text one by one
generator = pipeline(
    'text-generation', 
    model="uer/gpt2-chinese-cluecorpussmall",  #可以直接写huggingface上的模型名,也可以写本地的模型地址
    device = 1)

text_inputs = ["客观、严谨、浓缩",
                "地摊文学……",
                "什么鬼玩意,",
                "豆瓣水军果然没骗我。",
                "这是一本社会新闻合集",
                "风格是有点学古龙嘛?但是不好看。"]

sent_gen = generator(text_inputs, 
                        max_length=100, 
                        num_return_sequences=2,
                        repetition_penalty=1.3, 
                        top_k = 20) 
#返回的sent_gen 形如#[[{'generated_text':"..."},{}],[{},{}]]

for i in sent_gen:
    print(i)

方法2:transformers中的TextGenerationPipeline类

源码及参数选择参考:

https://huggingface.co/docs/transformers/v4.17.0/en/main_classes/pipelines#transformers.TextGenerationPipeline

优点:相较方法1,可以设置batch size。

from transformers import BertTokenizer, GPT2LMHeadModel, TextGenerationPipeline

tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
model = GPT2LMHeadModel.from_pretrained("uer/gpt2-chinese-cluecorpussmall")

text_generator = TextGenerationPipeline(model, tokenizer, batch_size=3, device=1)
text_inputs = ["客观、严谨、浓缩",
                "地摊文学……",
                "什么鬼玩意,",
                "豆瓣水军果然没骗我。",
                "这是一本社会新闻合集",
                "风格是有点学古龙嘛?但是不好看。"]

gen = text_generator(text_inputs, 
                    max_length=100, 
                    repetition_penalty=10.0, 
                    do_sample=True, 
                    num_beams=5,
                    top_k=10)

for sent in gen:
    gen_seq = sent[0]["generated_text"]
    print("")
    print(gen_seq.replace(" ",""))

方法3:transformers通用方法,直接加载模型

源码及参数选择参考:

https://github.com/huggingface/transformers/blob/c4d4e8bdbd25d9463d41de6398940329c89b7fb6/src/transformers/generation_utils.py#L101

缺点:封装度较差,代码较为冗长。

优点:由于是transformers调用模型的通用写法,和其他模型(如bert)的调用方式相似,(如tokenizer的使用),可以举一反三。

from transformers import AutoTokenizer, AutoModelWithLMHead
import torch, os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
tokenizer = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
model = AutoModelWithLMHead.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
config=model.config

print(config)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
texts = ["客观、严谨、浓缩",
                "地摊文学……",
                "什么鬼玩意,",
                "豆瓣水军果然没骗我。",
                "这是一本社会新闻合集",
                "风格是有点学古龙嘛?但是不好看。"]
#用batch输入的时候一定要设置padding
encoding = tokenizer(texts, return_tensors='pt', padding=True).to(device)

with torch.no_grad():
    generated_ids = model.generate(**encoding, 
    max_length=200, 
    do_sample=True, #default = False
    top_k=20, #default = 50
    repetition_penalty=3.0 #default = 1.0, use float
    ) 
generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

for l in generated_texts:
    print(l)

方法4:Simple Transformers

简介:Simple Transformers基于HuggingFace的Transformers,对特定的NLP经典任务做了高度的封装。在参数的设置上也较为灵活,可以通过词典传入参数。模型的定义和训练过程非常直观,方便理解整个AI模型的流程,很适合NLP新手使用。

simple transformers 指南:

https://simpletransformers.ai/docs/language-generation-model/

优点:这个包集成了微调的代码,不仅可以直接做生成,进一步微调也非常方便。

缺点:有些中文模型不能直接输入huggingface上的模型名称进行自动下载,会报错找不到tokenizer文件,需要手动下载到本地。

$ pip install simpletransformers

下载中文生成模型到本地文件夹 models/chinese-xlnet-base

from simpletransformers.language_generation import LanguageGenerationModel
# import logging

# logging.basicConfig(level=logging.INFO)
# transformers_logger = logging.getLogger("transformers")
# transformers_logger.setLevel(logging.WARNING)

model = LanguageGenerationModel("xlnet", #model type
"models/chinese-xlnet-base", #包含 .bin file的文件路径
args={"max_length": 50, "repetition_penalty": 1.3,"top_k":100})
prompts =["客观、严谨、浓缩",
                "地摊文学……",
                "什么鬼玩意,",
                "豆瓣水军果然没骗我。",
                "这是一本社会新闻合集",
                "风格是有点学古龙嘛?但是不好看。"]
for prompt in prompts:
    # Generate text using the model. Verbose set to False to prevent logging generated sequences.
    generated = model.generate(prompt, verbose=False)
    print(generated)

观察:用gpt2-chinese-cluecorpussmall生成的文本

参数设置:

max_length=100
repetition_penalty=10.0
do_sample=True
top_k=10

注:每一段文字的开头(标蓝)是预先给定的prompt

PS:乍一看生成语句的流利度和自然度都较好,还挺像人话的;而且有些句子能够按照“书评”的方向写。但仔细看就会发现噪音较多,而且容易“自由发挥”而跑题。这就是自由文本生成的常见问题:因为过于自由而不可控。

那么如何将生成的文本限定在想要的格式或领域中呢?这就是可控文本生成的研究范围了。一个较为常见的做法是对GPT-2作增量训练,让模型熟悉当前的语境。

总结

本文列举和比较了四种使用pytorch调用生成式模型做文本生成的方式。分别是:

① transformers自带的pipline

② transformers中的TextGenerationPipeline类

③ transformers通用方法,直接加载模型

④ Simple Transformers

这些方法各有优缺点。如果需要后续微调,建议使用③或④。如果只是简单地体验生成效果,建议使用①和②,但是方法①不能以batch形式输入,速度较慢。

关注下方《学姐带你玩AI》🚀🚀🚀

回复“ACL”免费获取NLP顶会必读高分论文

包含文本生成等20个细分方向

码字不易,欢迎大家点赞评论收藏!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/402091.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

RFID在技术在工业产线上的应用

RFID在技术在工业产线上的应用一工业产线需求制造业生产线几乎每月都要损耗大量物料,并且生产结果与预期因为有误差而影响交货的情况时有发生,生产线也往往因人为原因造成种种误差。将RFID标签贴在生产物料或产品上,可自动记录产品的数量、规…

学完Java只能在互联网公司任职吗?

当然不是只有互联网公司需要软件,需要开发技术人员,传统行业、新经济领域都有软件项目需求;Java也不是只能做网站、企业应用,还可以用于嵌入式、游戏…… 互联网时代的手机、智能电视、家具、机械设备等各种有形产品都将会嵌入智…

二、Neo4j源码研究系列 - 单步调试

二、Neo4j源码研究系列 - 单步调试 一、背景介绍 上一篇我们已经把了neo4j的源码准备以及打包流程完成了,本篇将讲解如何对neo4j进行单步调试。对于不了解如何编译打包neo4j的读者,请阅读《一、Neo4j源码研究系列 - 源代码准备》。 大纲: …

【改机教程】iOS系统去除小黑条,改拍照声、拨号音、键盘音,不用越狱,支持所有机型

大家好,上次给大家分享了几个iOS系统免越狱改机教程 今天带来最新的教程,这次修改利用的是同一个漏洞,由外网大神 tamago 开发,国内大神冷风 进行汉化和优化 可以修改的地方包括 去除底部小黑条 dock栏透明 桌面文件夹透明 桌面…

golang 占位符还傻傻分不清?

xdm ,写 C/C 语言的时候有格式控制符,例如 %s , %d , %c , %p 等等 在写 golang 的时候,也是有对应的格式控制符,也叫做占位符,写这个占位符,需要有对应的数据与之对应,不能瞎搞 基本常见常用…

Cobalt Strike---(2)

数据管理 Cobalt Strike 的团队服务器是行动期间Cobalt Strike 收集的所有信息的中间商。Cobalt Strike 解析来 自它的 Beacon payload 的输出,提取出目标、服务和凭据。 如果你想导出 Cobalt Strike 的数据,通过 Reporting → Export Data 。Cobalt Str…

CentOS7自签SSL证书并配置nginx

一、生成SSL证书 1、安装依赖包 yum install -y openssl openssl-devel 2、生成私钥,会让你输入一个 4~2048 位的密码,你需要暂时记住这个密码 openssl genrsa -des3 -out server.key 2048 输入两遍相同的密码 3、生成CSR(Certificate Signing Request …

Postgresql-12.5 visual studio-2022 windows 添加pg工程并调试

pg内核学习,记录一下 文章目录安装包编译安装VS添加Postgresql工程调试源码安装包 (1)perl下载 https://www.perl.org/get.html (2)diff下载 http://gnuwin32.sourceforge.net/packages/diffutils.htm (…

23届非科班选手秋招转码指南

1.秋招情况介绍 1.1自我介绍 我是一名23届非科班转码选手,本硕均就读于某211院校机械专业,秋招共计拿下12份offer,包括大疆创新、海康威视、联发科技、理想汽车、中电28、阳光电源等各行业、各种性质企业的意向。主要的投递岗位为嵌入式软件…

若依微服务版在定时任务里面跨模块调用服务

第一步 在被调用的模块中添加代理 RemoteTaskFallbackFactory.java: package com.ruoyi.rpa.api.factory;import com.ruoyi.common.core.domain.R; import com.ruoyi.rpa.api.RemoteTaskService; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springf…

【springmvc】执行流程

SpringMVC执行流程 原理图 1、SpringMVC常用组件 DispatcherServlet:前端控制器,不需要工程师开发,由框架提供 作用:统一处理请求和响应,整个流程控制的中心,由它调用其它组件处理用户的请求 HandlerMa…

Windows7,10使用:Vagrant+VirtualBox 安装 centos7

一、Vagrant,VirtualBox 是什么二、版本说明1、win7下建议安装版本2、win10下建议安装版本三、Windows7下安装1、安装Vagrant2、安装VirtualBox3、打开VirtualBox,配置虚拟机默认安装地址四、windows7下载.box文件,安装centos 71、下载一个.b…

拐点!新能源车交付均价首次「低于」燃油车,智能电动成新爆点

2023年开局,随着特斯拉打响新能源汽车市场的「价格战」首炮,除部分燃油车品牌(仍依赖自身多年的用户和品牌积累的溢价能力)没有跟进之外,几乎所有的新能源车型都在进行车型价格的下调。 而数据也在反映市场的拐点即将来…

深入理解Zookeeper的ZAB协议

ZAB是什么ZAB(Zookeeper Atomic Broadcast):Zookeeper原子广播ZAB是为了保证Zookeeper数据一致性而产生的算法(指的是Zookeeper集群模式)。它不仅能解决正常情况下的数据一致性问题,还可以保证主节点发生宕…

最全的论文写作技巧(建议收藏)

近10年来,笔者有幸多次参与教学论文的评审工作,在此,特将教学论文写作的步骤及相关问题整理汇总如下: 一、选定论题 (一)论题在文中的地位与作用 严格地讲,论文写作是从选定论题开始的。选题…

Android源码分析 - Parcel 与 Parcelable

0. 相关分享 Android-全面理解Binder原理 Android特别的数据结构(二)ArrayMap源码解析 1. 序列化 - Parcelable和Serializable的关系 如果我们需要传递一个Java对象,通常需要对其进行序列化,通过内核进行数据转发,…

这几个群,程序员可千万不要进!

震惊!某摸鱼网站惊现肾结石俱乐部! (图源V2EX) 无关地域、无关性别,各位程序员们在肾结石这个病上面有着出奇一致的反应。诸如此类的各种职业病在我们的生活中更是十分常见。 也可能是到年纪了,在办公室…

ATTCK v12版本战术介绍——提权(一)

一、引言在前几期文章中我们介绍了ATT&CK中侦察、资源开发、初始访问、执行、持久化战术理论知识及实战研究,通过实战场景验证行之有效的检测规则、防御措施,本期我们为大家介绍ATT&CK 14项战术中提权战术(一)&#xff0c…

计算机图形学09:二维观察之点的裁剪

作者:非妃是公主 专栏:《计算机图形学》 博客地址:https://blog.csdn.net/myf_666 个性签:顺境不惰,逆境不馁,以心制境,万事可成。——曾国藩 文章目录专栏推荐专栏系列文章序一、二维观察基本…

设计模式4——行为型模式

行为型模式用于描述程序在运行时复杂的流程控制,即描述多个类或对象之间怎样相互协作共同完成单个对象都无法单独完成的任务,它设计算法与对象间职责的分配。 行为型模式分为类行为模式和对象行为模式,前者采用继承机制来在类间分派行为&…