心法利器[89] | 实用文本生成中的解码方法

news2025/1/11 2:46:13

心法利器

本栏目主要和大家一起讨论近期自己学习的心得和体会,与大家一起成长。具体介绍:仓颉专项:飞机大炮我都会,利器心法我还有。

2022年新一版的文章合集已经发布,累计已经60w字了,获取方式看这里:CS的陋室60w字原创算法经验分享-2022版。(2023在路上了!)

往期回顾

  • 心法利器[84] | 最近面试小结

  • 心法利器[85] | 算法技术和职业规划

  • 心法利器[86] | 毕业4年的算法工程师:进步再进步

  • 心法利器[87] | 填志愿:AI算法方向过来人的建议

  • 心法利器[88] | 有关大模型幻觉问题的思考

最近大模型挺火的,在学习的过程中,偶然间发现在解码上,似乎能有不少花样,而且通过调整似乎也能得到很不一样的回复内容,而且这也是文本生成中很关键的一块,所以最近趁机就把这块内容学习了一下。

本文主要参考了这篇的内容:

  • 英文版:https://huggingface.co/blog/how-to-generate

  • 中文版:https://huggingface.co/blog/zh/how-to-generate

为什么需要解码

在文章中,所有的解码都是指代的自回归式的生成任务,简单的可以理解为每个词的预测其实都是基于上文的词的概率分布对这个位置进行的预测,说白了就是一个很简单的条件概率。然而,模型预测出来的,其实是每个位置的概率分布,即这个位置下每个词在这个位置的可能性,而所谓的解码,就是根据这一系列的概率分布,在每一步选择最优的词汇,从而最终输出一个句子。

假设模型的初始化如下:

import tensorflow as tf
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer


tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# add the EOS token as PAD token to avoid warnings
model = TFGPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)

贪心解码(greedy search)

顾名思义,就是在每一步都选择概率最大的词,这也是速度最快的解码方式了,这个图直接就能用看出来:

1b93fab6e2f80182bf09f2a51efa6965.png
greedy search

从图里可以看到,每一步其实选择的都是概率最高的那个分支。然后是代码:

# encode context the generation is conditioned on
input_ids = tokenizer.encode('I enjoy walking with my cute dog', return_tensors='tf')

# generate text until the output length (which includes the context length) reaches 50
greedy_output = model.generate(input_ids, max_length=50)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(greedy_output[0], skip_special_tokens=True))

理论上这个似乎是合理的,但实际上的输出是这样的:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll ever be able to walk with my dog.

I'm not sure if I'll

大家可以发现,生成的过程中开始重复了,主要原因是贪心搜索只关注眼前最大值,而忽略了后续可能有整体最大的选择,因此我们吸纳更多的选择来综合打分,尽可能选择一个全局最优解。

束搜素(beam search)

很显然,因为词汇过多,所以我们是无法再每一步都遍历所有的可能,这将会是的复杂度,因此我们倒是可以选TOPN的词汇来往后搜索即可。

640e510d5b607a6b8b98208d095206ef.png
beam_search

假设这次每次的选择都是TOP2,其实会发现,总结下来似乎确实是能找到更好的解码结果。

# activate beam search and early_stopping
beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I'm not sure if I'll ever be able to walk with him again. I'm not sure if I'll

但我们也看到,内容依旧会有重复,从模型的输出角度,我们知道模型总会选出一定范围内最大的选择,有些话确实可能会因为上文而重复循环出现,因此我们需要对这些容易重复的内容进行一定的惩罚。

# set no_repeat_ngram_size to 2
beam_output = model.generate(
    input_ids, 
    max_length=50, 
    num_beams=5, 
    no_repeat_ngram_size=2, 
    early_stopping=True
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(beam_output[0], skip_special_tokens=True))
Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with him again.

I've been thinking about this for a while now, and I think it's time for me to take a break

这样重复就得到缓解甚至避免了。当然了,这种避免重复的方式还是需要避免,尤其是某些带有关键话题的内容,如果约束了,某些主题词可能出现的次数就太少了。

另外,我们还可以用num_return_sequences这个字段来控制输出句子的个数,有更多选择也会在一些场景比较方便,注意num_return_sequences<num_beams`。

采样(sampling)

采样,就是在对每一个位置预测的时候,以该位置的概率分布随机选择输出词,这种方式最大的特点就是增加了随机性(注意,是特点,有的时候这样做可能是负效果的,除非固定了random_seed)。

设置采样的开关在do_sample这个字段,为True的时候,就启动了do_sample。

# set seed to reproduce results. Feel free to change the seed though to get different results
tf.random.set_seed(0)

# activate sampling and deactivate top_k by setting top_k sampling to 0
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

这里仔细看看输出:

Output:
----------------------------------------------------------------------------------------------------
I enjoy walking with my cute dog. He just gave me a whole new hand sense."

But it seems that the dogs have learned a lot from teasing at the local batte harness once they take on the outside.

"I take

从这里看,输出似乎流畅,但会出现一定的不合理性,核心原因是因为这个采样,运气不好会找到一些不合适的单词的,为了缓解这个问题,可以通过设置温度来进行调整,这个温度实际上是加载softmax中的,用于锐化或拉平这个概率分布,一般温度越小差异越大,此时,概率高的词汇概率会变得更高,从而更容易被选择,从而缓解选出不太可能的词汇的问题。

# set seed to reproduce results. Feel free to change the seed though to get different results
tf.random.set_seed(0)

# use temperature to decrease the sensitivity to low probability candidates
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=0, 
    temperature=0.7
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

Top-K采样和Top-p采样

Top-K采样是指在选择的时候,最大的K个词会被选择出来,选出来后重新归一化,再来进行采样,这种方式能更大限度避免选出不太可能的词汇。

# set top_k to 50
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_k=50
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

但是在某些时候还是有问题,概率分布有时候倾向性可能很明显,有时候又会不那么明显,如果是按照强硬的个数条件进行选择,此时仍有可能选到后面的词汇概率仍旧非常低,此时又有了top-p采样,即按照累计概率进行采样,当前N个词汇的累计概率大于我们预设的概率时,就会停止采样。

# deactivate top_k sampling and sample only from 92% most likely words
sample_output = model.generate(
    input_ids, 
    do_sample=True, 
    max_length=50, 
    top_p=0.92, 
    top_k=0
)

print("Output:\n" + 100 * '-')
print(tokenizer.decode(sample_output[0], skip_special_tokens=True))

小结

有关上面的多个生成方案,其实只是通过某个方式串起来而已,他们之间可能没有明显的上下位关系,而是一个优劣势互补的关系,很多时候可能我们要进经过一些筛选。

另外,generate里面,其实有很多可供控制的参数,具体的大家可以参考这几篇文章:

  • https://huggingface.co/docs/transformers/v4.30.0/en/generation_strategies#customize-text-generation

  • https://blog.csdn.net/muyao987/article/details/125917234

9e0f27af632c4e0061e45724bf322b48.png

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/736396.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

[CVPR‘23] PanoHead: Geometry-Aware 3D Full-Head Synthesis in 360 deg

论文&#xff5c;项目 总结&#xff1a; 任务&#xff1a;3D human head synthesis现有问题&#xff1a;GANs无法在「in-the-wild」「single-view」的图片情况下&#xff0c;生成360度人像解决方案&#xff1a;1&#xff09;提出了two-stage self-adaptive image alignment&am…

C++ 设计模式之策略模式

文章目录 一、简介二、场景三、举个栗子四、小结参考资料 一、简介 策略模式的定义很简单&#xff1a;即创建一系列的算法,把它们一个个封装起来 , 并且使它们可相互替换&#xff08;用扩展的方式来面对未来变化&#xff09;。在GoF一书中将其定位为一种“对象行为式模式”&…

vs code insiders 配置c语言

vs code insiders 配置c语言 1.下载插件 2.再配置代码 &#xff08;1&#xff09;launch.json {// Use IntelliSense to learn about possible attributes.// Hover to view descriptions of existing attributes.// For more information, visit: https://go.microsoft.com/…

操作系统的可扩展访问控制

访问控制是操作系统安全的基石&#xff0c;当前的操作系统已部署了很多访问控制的模型&#xff1a;Unix和Windows NT多用户安全&#xff1b;SELinux中的类型执行&#xff1b;反恶意软件产品&#xff1b;Apple OS X&#xff0c;Apple iOS和Google Android中的应用沙盒&#xff1…

RNN介绍

时间序列的表示 [seq_len, batch_size, vec ] seq_len表示一个句子通常有多少个单词或者一个序列有多少个时间段,batch_size表示同时多个样本,vec表示单词的编码长度 请问rnn和lstm中batchsize和timestep的区别是什么? - 知乎 (zhihu.com) import torch import torch.nn …

MYSQL05高级_查看修改存储引擎、InnoDB和MyISAM对比、其他存储引擎介绍

文章目录 ①. 查看、修改存储引擎②. InnoDB和MyISAM对比③. Archive引擎 - 归档④. Blackhole引擎丢数据⑤. CSV - 引擎⑥. Memory引擎 - 内存表⑦. Federated引擎 - 访问远程表⑧. Merge引擎 - 管理多个MyISAM⑨. NDB引擎 - 集群专用 ①. 查看、修改存储引擎 ①. 查看mysql提…

Spring Boot原理分析(一):项目启动(上)——@SpringBootApplication

文章目录 〇、准备工作一、SpringBootApplication.java源码解析1.源码2.自定义注解3.组合注解4.注解ComponentScan过滤器 5.注解SpringBootConfigurationConfiguration 6.注解EnableAutoConfiguration 本文章是Spring Boot源码解读与原理分析系列博客的第一篇&#xff0c;将会介…

Mac(M1Pro)下运行ChatGLM2

最近很多人都尝试在M1/M2芯片下跑chatglm/chatglm2&#xff0c;结果都不太理想&#xff0c;或者是说要32G内存才可以运行。本文使用cpu基于chatglm-cpp运行chatglm2的int4版本。开了多个网页及应用的情况下&#xff08;包括chatglm2)&#xff0c;总体内存占用9G左右。chatglm2可…

PYTHON+YOLOV5+OPENCV,实现数字仪表自动读数,并将读数结果进行输出显示和保存

最近完成了一个项目&#xff0c;利用pythonyolov5实现数字仪表的自动读数&#xff0c;并将读数结果进行输出和保存&#xff0c;现在完成的7788了&#xff0c;写个文档记录一下&#xff0c;若需要数据集和源代码可以私信。 最后实现的结果如下&#xff1a; 项目过程 首先查阅文…

从单体到SpringBoot/SpringCloud微服务架构无感升级的最佳实践

目录导读 从单体到SpringBoot/SpringCloud微服务架构无感升级的最佳实践1. 业务背景2. 当前问题3. 升级方案3.1 架构设计4. 详细设计4.1 迁移阻碍4.2 解决思路 5. 实现过程5.1 认证兼容改造5.2 抽象业务流程5.2.1 抽象业务的思路5.2.2 抽象业务的抽象编码5.2.3 抽象业务的具体实…

BFF网关模式开发指南

BFF是近些年新衍生出来的一种开发模式&#xff0c;或者说是一种适配模式的系统&#xff0c;BFF全称为Backend OF Front意为后端的前端&#xff0c;为了适配微服务模式下前端后端系统接口调用混乱而出现的。在如今微服务盛行的趋势下&#xff0c;大型系统中划分出了数十个服务模…

前端优化的一些方向

对于浏览器来说&#xff0c;加载网页的过程可以分为两部分&#xff0c;下载文档并响应&#xff08;5%左右&#xff09;&#xff0c;下载各种组件&#xff08;95%左右&#xff09;。 而对比大部分优秀网页来说下载文档&#xff08;10%~ 20%&#xff09;&#xff0c;下载组件&…

23_7第一周LeetCode刷题回顾

目录 1. 两数之和2. 两数相加3.无重复字符的最长子串4.寻找两个正序数组的中位数5.最长回文子串6.N 形变换7.整数反转8.字符串转整数&#xff08;atoi&#xff09;9.回文数10. 正则表达式匹配11. 盛最多水的容器12. 整数转罗马数字13. 罗马数字转整数14. 最长公共前缀15.三数之…

MyBatis中的动态SQL(sql标签、where标签、set标签、批量增加与批量删除)

目录 sql标签 ​编辑 where标签 set标签 foreach标签 批量增加 批量删除 将基础SQL语句中重复性高的增加它的复用性&#xff0c;使得sql语句的灵活性更强 sql标签<sql> <sql id"text">select * from user</sql><select id"selectA…

如何在苹果商店发布App?

一、介绍 众所周知&#xff0c;苹果对于自家产品的安全问题十分重视&#xff0c;他们有严格的一套审核标准和流程&#xff0c;当我们想要在苹果商店发布一款App的时候就需要经过重重艰难险阻&#xff0c;克服不少繁杂的问题去完成这项工作。 另外有一点需要注意的是&#xff…

C语言库函数strcpy学习

strcpy是C语言的一个标准库函数&#xff1b; strcpy把含有\0结束符的字符串复制到另一个地址空间&#xff0c;返回值的类型为char*。 原型声明&#xff1a;char *strcpy(char* dest, const char *src); 头文件&#xff1a;#include <string.h> 和 #include <stdio.h&g…

领域驱动设计(三) - 快速开始 - 【3/3】事件风暴

使用DDD的最终目的是深入学习业务如何运作。然后基于学习试验、质疑、再学习和重建模的过程。过程中面临的最大挑战是如何快速学习&#xff0c;并且在保证学习质量的前提下压缩学习时间&#xff08;你的学习是需要公司付工资的&#xff09;。 事件风暴就是一种相对高效的分析工…

【电子学会】2023年05月图形化二级 -- 接水果

接水果 天上掉落各种水果下来&#xff0c;有草莓、苹果、香蕉&#xff0c;快拿大碗去接住水果吧。 1. 准备工作 &#xff08;1&#xff09;导入背景Blue Sky&#xff1b; &#xff08;2&#xff09;删除小猫角色&#xff0c;导入角色Bowl、Apple、Strawberry、Bananas。 2.…

【技能实训】DMS数据挖掘项目-Day03

文章目录 任务5【任务5.1】基础信息实体类【任务5.2.1】继承DataBase类&#xff0c;重构日志类【任务5.2.2】继承DataBase类&#xff0c;重构物流实体类【任务5.2.3】创建物流、日志测试类&#xff0c;测试任务5.2中的程序&#xff0c;演示物流信息、日志信息的采集及打印输出 …

【Redis】Transaction(事务)

&#x1f3af;前言 Redis事务是一个组有多个Redis命令的集合&#xff0c;这些命令可以作为一个原子操作来执行。 Redis事务通常用于以下两种情况&#xff1a; 保证操作的原子性&#xff1a;在多个命令的执行过程中&#xff0c;如果有一个命令执行失败&#xff0c;整个事务都需…