【gpt生成文本的回复的原理和代码,通俗思路清晰】

news2024/9/20 22:16:18

首先介绍了贪婪解码
其次为增家多样性,用温度系数和TopK增加采样
真实的采样步骤 1、topk备选tokens 2、用维度系数大于1让概率平衡一下,3.再用softmax,4.根据概率分布采样

1、贪婪解码

# 之前,我们总是使用torch.argmax采样最大概率的标记作为下一个标记。
import torch
vocab = { 
    "closer": 0,
    "every": 1, 
    "effort": 2, 
    "forward": 3,
    "inches": 4,
    "moves": 5, 
    "pizza": 6,
    "toward": 7,
    "you": 8,
} 

inverse_vocab = {v: k for k, v in vocab.items()}

# 假设input是 "every effort moves you", 模型返回的logits值为下面tensor中的数值:
next_token_logits = torch.tensor(
    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)

probas = torch.softmax(next_token_logits, dim=0)
next_token_id = torch.argmax(probas).item()

# 下一个标记:
print(inverse_vocab[next_token_id])
#

2、增加多样性

为了增加多样性,我们可以使用torch.multinomial(probs, num_samples=1)从概率分布中采样下一个标记。

# 是根据概率probs抽样tokens
torch.manual_seed(123)
sample = [torch.multinomial(probas, num_samples=1).item() for i in range(1_0)]
print(sample)
set(sample)

3、温度系数

“温度缩放”只是将logits除以一个大于0的数字的高级说法。
大于1的温度值:softmax后导致更均匀分布。
小于1的温度值: softmax(更尖锐或更高峰)的分布。

def softmax_with_temperature(logits, temperature):
    scaled_logits = logits / temperature
    return torch.softmax(scaled_logits, dim=0)

# Temperature values
temperatures = [1, 0.1, 5]  # Original, higher confidence, and lower confidence

# Calculate scaled probabilities
scaled_probas = [softmax_with_temperature(next_token_logits, T) for T in temperatures]

# Plotting
x = torch.arange(len(vocab))
bar_width = 0.15

fig, ax = plt.subplots()
for i, T in enumerate(temperatures):
    # 条形图的绘制,ax.bar()函数里面的参数分别为条形的x轴位置、高度、宽度、图例标签
    rects = ax.bar(x + i * bar_width, scaled_probas[i], bar_width, label=f'Temperature = {T}')

ax.set_ylabel('Probability')
ax.set_xticks(x)
ax.set_xticklabels(vocab.keys(), rotation=90)
ax.legend()

plt.tight_layout()
# plt.savefig("temperature-plot.pdf")
plt.show()

在这里插入图片描述

4、TopK备选

为了能够使用更高的温度来增加输出的多样性,并降低无意义句子出现的概率,我们可以将采样的标记限制在最可能的前k个标记中:
也就是在采样之前,只选topK备选的tokens,代码如下:

top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)

print("Top logits:", top_logits)
print("Top positions:", top_pos)
# Top logits: tensor([6.7500, 6.2800, 4.5100])
# Top positions: tensor([3, 7, 0])

# 通过这步,余下的token 的概率为-inf
new_logits = torch.where(
    condition=next_token_logits < top_logits[-1],
    input=torch.tensor(float('-inf')), 
    other=next_token_logits
)

print(new_logits)
# tensor([4.5100,   -inf,   -inf, 6.7500,   -inf,   -inf,   -inf, 6.2800,   -inf])

# 3 然后softmax
topk_probas = torch.softmax(new_logits, dim=0)
print(topk_probas)

4 、归结为文本生成函数

def generate(model, idx, max_new_tokens, context_size, temperature, top_k=None):

    # 循环与之前相同:获取logits,并仅关注最后一步。
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]

        # 使用top_k采样对logits值进行过滤
        if top_k is not None:
            # 仅保留top_k的值
            top_logits, _ = torch.topk(logits, top_k)
            min_val = top_logits[:, -1]
            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)

        # 使用温度缩放
        if temperature > 0.0:
            logits = logits / temperature

            # 使用softmax函数得到概率
            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)

            # 从概率分布中采样
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)

        # 否则和之前的generate_simple函数中的处理相同,使用argmax函数取得概率最大的token
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # 和之前相同的序列拼接处理
        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2034696.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

年薪30万+,TOP大厂月薪10万+....网络安全工程师凭什么?

时代飞速发展&#xff0c;我们的工作、生活乃至整个社会的运转都越来越依赖于网络。也因此&#xff0c;网络的无处不在带来了前所未有的安全风险。 从个人隐私泄露到企业机密被盗&#xff0c;再到国家关键基础设施遭受攻击&#xff0c;网络安全问题无处不在&#xff0c;威胁着…

SQL之使用存储过程循环插入数据

1、已经创建了任务日志表 CREATE TABLE t_task_log (id bigint NOT NULL AUTO_INCREMENT,task_id bigint NOT NULL COMMENT 任务ID,read_time bigint NOT NULL COMMENT 单位秒&#xff0c;读取耗时,write_time bigint NOT NULL COMMENT 单位秒&#xff0c;写入耗时,read_size …

8月13日学习笔记 LVS

一.描述以及工作原理 1. 什么是LVS linux virtural server的简称&#xff0c;也就是linxu虚拟机服务器&#xff0c;这是一个 由章文嵩博士发起的开源项目&#xff0c;官网是 http://www.linuxvirtualserver.org,现在lvs已经是linux内核标 准的一部分&#xff0c;使用lvs可以达…

网络剪枝——network-slimming 项目复现

目录 文章目录 目录网络剪枝——network-slimming 项目复现clone 存储库Baselinevgg训练结果 resnet训练结果 densenet训练结果 Sparsityvgg训练结果 resnet训练结果 densenet训练结果 Prunevgg命令结果 resnet命令结果 densenet命令结果 Fine-tunevgg训练结果 resnet训练结果 …

5个小众宝藏软件看看有没有你喜欢的

冷门APP分享来啦&#xff0c;这5个小众宝藏软件看看有没有你喜欢的吧&#xff01; 1.space登月计划 从地球到月球的大概距离是3.84亿米&#xff0c;而登月得消耗掉大约3.2亿千卡的能量。一个人想单飞登月得花上万年。 但在space上&#xff0c;可以和小伙伴一起合作玩登月游戏…

记录Java使用websocket

实现场景&#xff1a;每在小程序中添加一条数据时&#xff0c;后台将主动推送一个标记给PC端&#xff0c;PC端接收到标记将进行自动播放音频。 import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import or…

GitHub 2FA中国认证教程

1. 问题描述 在github上有过代码贡献的账号在登录时需要进行2FA双重身份验证。 这是github官方给出的关于2FA的解释&#xff1a; 官方文章地址&#xff1a;点击进入 这是登录时2FA的验证界面&#xff1a; 我们需要使用扩展程序解析这个二维码拿到2FA验证码&#xff0c;填入二维…

python爬虫滑块验证及各种加密函数(基于ddddocr进行的一层封装)

git链接: https://github.com/JOUUUSKA/spider_toolsbox 这里写目录标题 一.识别验证码1、识别英文&#xff0b;数字验证码2、识别滑块验证码3、识别点选验证码 一.识别验证码 git链接: https://github.com/JOUUUSKA/spider_toolsbox 创作不易记得stars 1、识别英文&#xf…

Arduino控制带编码器的直流电机速度

Arduino DC Motor Speed Control with Encoder, Arduino DC Motor Encoder 作者 How to control dc motor with encoder:DC Motor with Encoder Arduino, Circuit Diagram:Driving the Motor with Encoder and Arduino:Control DC motor using Encoder feedback loop: How …

一文读懂Xinstall专属链接推广,轻松解决App运营痛点!

随着互联网的飞速发展&#xff0c;App推广和运营面临着前所未有的挑战。传统的营销方式已经难以适应多变的市场环境&#xff0c;而Xinstall专属链接推广应运而生&#xff0c;成为解决App获客难题的新利器。本文将深入探讨Xinstall专属链接推广如何帮助推广者触达更多用户&#…

MacOS vue-cli为2.9.6 无法升级的解决方案

背景 今天需要验证plop工具做前端工程化实践&#xff0c;打算使用vue3方式&#xff0c;结果发现vue-cli 2.9.6一直无法升级成功&#xff0c;也无法通过vue-cli生成vue3模板工程&#xff0c;测试了几把后&#xff0c;最终升级vue-cli成功&#xff0c;为了能给出现同样问题的小伙…

上瘾模型与产品激励系统

​产品要增加客户粘性&#xff0c;使产品深入人心就需要让用户对产品上瘾。如何使用户对产品上瘾&#xff1f;对于产品来说&#xff0c;就需要建立产品的激励系统。 产品的激励系统要做的事就是对用户进行激励&#xff0c;就是让用户主动完成产品或服务想要他们做的事情。 那么…

重启人生计划-勇敢者先行

&#x1f973;&#x1f973;&#x1f973; 茫茫人海千千万万&#xff0c;感谢这一刻你看到了我的文章&#xff0c;感谢观赏&#xff0c;大家好呀&#xff0c;我是最爱吃鱼罐头&#xff0c;大家可以叫鱼罐头呦~&#x1f973;&#x1f973;&#x1f973; 如果你觉得这个【重启人生…

分布式与微服务详解

1. 单机架构 只有一台机器&#xff0c;这个机器负责所有的工作 &#xff08;这里假定一个电商网站&#xff09; 现在大部分公司的产品都是单机架构 。 2. 分布式架构 一台机器的硬件资源是有限的&#xff0c;服务器处理请求是需要占用硬件资源的&#xff0c;如果业务增长&a…

前端学习笔记-JS篇-01

JS基础Day1-01-必看-基本软件以及准备工作_哔哩哔哩_bilibili JavaScript介绍 是什么 1.JavaScript (是什么?) 是一种运行在客户端(浏览器)的编程语言&#xff0c;实现人机交互效果2.作用(做什么?) 网页特效(监听用户的一些行为让网页作出对应的反馈)表单验证(针对表单…

streampark-使用记录-备忘

1、重新部署的任务会读历史配置&#xff08;包括错误配置&#xff09;&#xff0c;即使点击确认了也无效 解决&#xff1a;复制新的任务&#xff0c;修改ckeckpoint 路径&#xff08;重要&#xff09; 2、任务启动报错&#xff0c;即使后续把脚本改正确或者复制其他脚本过来执…

什么是 Java?

探索 Java&#xff0c;一种多功能且功能强大的编程语言。释放其构建强大应用程序的潜力。 前言 简单来说&#xff0c;Java 是一种用于开发软件应用程序的面向对象设计的编程语言。截至 2019 年&#xff0c;它是世界上最受欢迎的编程语言&#xff0c;尤其是因为它是开源的&#…

MySQL 的 InnoDB 缓冲池里有什么?--InnoDB存储梳理(二)

文章目录 缓冲池的配置介绍一张表 INNODB_BUFFER_POOL_PAGES字段解释 缓冲池的配置 以下配置的意思&#xff0c;缓冲池在内存中的大小为20M&#xff1b;只有1个缓冲池实例&#xff1b;每一块的大小&#xff0c;插入缓冲占的百分比 # InnoDB 缓存池配置 innodb_buffer_pool_si…

Spring Boot 3.x Web单元测试最佳实践

上一篇&#xff1a;Spring Boot 3.x Rest API统一异常处理最佳实践 下一篇&#xff1a;Spring Boot 3.x Filter实战&#xff1a;记录请求日志 Spring Boot为我们提供了非常便捷的web层Rest API单元测试的API&#xff0c;这种开发能力也是小伙伴必须要掌握的。如何对数据库、中…

【简历】扬州某一本大学:前端秋招简历指导,面试通过率低

注&#xff1a;为保证用户信息安全&#xff0c;姓名和学校等信息已经进行同层次变更&#xff0c;内容部分细节也进行了部分隐藏 简历说明 这是25届一本前端同学的简历。这是一个老牌一本学校&#xff0c;老牌一本定位求职层次&#xff0c;可以从传统的中厂上升到大厂。学历可以…