理解 logits_to_keep = logits_to_keep + 1 在 _get_per_token_logps 中的作用

news2025/2/24 14:21:21

理解 logits_to_keep = logits_to_keep + 1_get_per_token_logps 中的作用

source: anaconda3/envs/xxx/lib/python3.10/site-packages/trl/trainer/grpo_trainer.py

 def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
        logits = model(
            input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
        ).logits  # (B, L, V)
        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred

        input_ids = input_ids[:, -logits_to_keep:]
        # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
        # See https://github.com/huggingface/trl/issues/2770
        logits = logits[:, -logits_to_keep:]

        # Compute the log probabilities for the input tokens.
        token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
        # use a loop to reduce memory peak
        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
        token_log_probs = token_logits - logsumexp_values  # log_softmax = logits - log(sum(exp(logits)))
        return token_log_probs

_get_per_token_logps 这个函数中,logits_to_keep 控制了要保留的 logits 数量,用于计算每个 token 的对数概率。
但这里有一个关键点:

logits_to_keep = logits_to_keep + 1

为什么需要加 1?
因为在 Transformer 语言模型(如 GPT)中,模型的 logits 预测的是下一个 token,所以如果我们只保留 logits_to_keeplogits数量是不够的
为了确保对齐,我们先多取一个 logits,然后再手动丢弃最后一个 logits,这样 logitsinput_ids 就能正确对齐。


1. 为什么需要 logits_to_keep + 1

1.1 自回归模型的 logits 预测的是下一个 token

在 Transformer 语言模型中,模型的 logits 形状通常是:

logits.shape = (B, L, V)

其中:

  • B:batch_size
  • L:序列长度
  • V:词表大小(vocab size)

模型在生成 logits 时,每个 logits[i] 实际上是用于预测下一个 token,而不是当前 token:

logits[:, 0, :]  ->  用于预测 input_ids[:, 1]
logits[:, 1, :]  ->  用于预测 input_ids[:, 2]
...
logits[:, L-1, :]  ->  用于预测 input_ids[:, L](即下一个 token)

input_ids 只包含当前 token,并不包含 “下一个 token” 的真实值,因此我们需要手动去掉最后一个 logits,让它和 input_ids 对齐。


2. 代码执行步骤

2.1 假设 input_ids.shape = (1, 5)

假设 logits_to_keep = 3,那么:

  • logits_to_keep + 1 = 4,即多取一个 logits
  • 模型返回的 logits.shape = (1, 6, V),因为 logits_to_keep+1=4,再加上可能的 padding,会得到 6 个 logits

2.2 关键代码

步骤 1:调用模型
logits = model(
    input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits

此时 logits.shape = (B, L, Vocab),其中 L = logits_to_keep + 1

步骤 2:删除最后一个 logits
logits = logits[:, :-1, :]

这样 logits 的形状就变成 (B, L-1, V),让它正确对应 input_ids[:, -logits_to_keep:]

步骤 3:对齐 input_ids
input_ids = input_ids[:, -logits_to_keep:]

这里 input_ids[:, -logits_to_keep:] 取的是最后 logits_to_keep 个 token,确保 logitsinput_ids 一一对应。


3. 示例代码

3.1 假设 input_ids = [5, 8, 2, 3, 9]logits_to_keep = 3

logits_to_keep + 1 让模型生成 4logits
logits.shape = (1, 5, V)  # 5 个 token,分别预测下一个 token
Token真实 input_idslogits 预测
15用于预测 8
28用于预测 2
32用于预测 3
43用于预测 9
59(无用,预测下一个 token)
② 手动删除最后一个 logits
logits = logits[:, :-1, :]  # 丢弃最后一个预测

最终 logits 形状:

logits.shape = (1, 4, V)  # 只保留前 4 个 logits

这样 logitsinput_ids[:, -logits_to_keep:] 对齐:

logits  对应 input_ids = [8, 2, 3]

4. 如果不加 +1 会发生什么?

如果 logits_to_keep 不加 1,那么:

  • logits 数量input_ids 少 1 个,导致维度对不上。
  • 计算 log_probslogits.gather(dim=-1, index=input_ids.unsqueeze(-1)) 会报错,或者索引到错误的 logits。

5. 结论

步骤目的
logits_to_keep + 1获取一个额外的 logits,避免数据对不齐
logits[:, :-1, :]删除最后一个 logits,确保与 input_ids 对齐
input_ids[:, -logits_to_keep:]选取最后 logits_to_keep 个 token 计算 log_probs

核心逻辑

因为 logits 预测的是下一个 token,所以要多取 1 个,然后手动删除最后一个
这样 logitsinput_ids 维度对齐,确保计算正确的 log_probs

🚀 理解这个逻辑对于实现 Transformer 语言模型的 loss 计算至关重要! 🚀

如果 logits_to_keep 不加 +1 会发生什么?

假设:

  • input_ids = [5, 8, 2, 3, 9]
  • logits_to_keep = 3
  • logits.shape = (B, L, V), 其中 L=5,表示 5 个 token,每个 token 的 logits 是一个 Vocab 大小的概率分布。

1. 正确做法(logits_to_keep + 1

如果 logits_to_keep + 1

  • logits_to_keep = 3 + 1 = 4
  • 让模型输出 4 个 logits,即:
    logits.shape = (1, 4, V)
    
  • 然后 删除最后一个 logitslogits[:, :-1, :]),得到:
    logits.shape = (1, 3, V)  # 3 个 logits,对应 input_ids 的最后 3 个 token
    
  • 此时 logitsinput_ids[:, -3:] = [8, 2, 3] 维度匹配,可以正确计算 log_probs

2. 错误示例(如果不加 +1

如果不加 +1,直接 logits_to_keep = 3,那么:

  • 模型只会返回 3logits
    logits.shape = (1, 3, V)  # 只保留 3 个 logits
    
  • 然后 logits[:, :-1, :] 会让 logits 变成:
    logits.shape = (1, 2, V)  # 只有 2 个 logits
    
  • input_ids[:, -logits_to_keep:] 仍然是:
    input_ids[:, -3:] = [8, 2, 3]  # 3 个 token
    
  • 这样,gather 操作:
    token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
    
    将会报错,因为 logits.shape = (1, 2, V),但 input_ids.shape = (1, 3),维度不匹配!
错误示例代码
import torch

# 模拟 logits (batch_size=1, sequence_length=2, vocab_size=5)
logits = torch.tensor([
    [[2.0, 1.0, 0.5, -1.0, 0.2],  # logit for token 8
     [0.1, -0.5, 2.2, 1.5, 0.0]]  # logit for token 2
])  # shape = (1, 2, 5)

# input_ids 仍然有 3 个 token
input_ids = torch.tensor([[8, 2, 3]])  # shape = (1, 3)

# 错误的 gather 操作
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

错误信息:

RuntimeError: Expected index with dimension 3, but got dimension 4 for input tensor.

这个错误表明 logits 只有 2 个 token,而 input_ids 仍然有 3 个 token,导致 gather 操作失败!


3. 错误情况总结

情况logits.shape (B, L, V)input_ids.shape (B, L)是否匹配?
正确:logits_to_keep + 1 后删掉最后一个 logits(1, 3, V)(1, 3)匹配
错误:不加 +1(1, 2, V)(1, 3)不匹配,报错

🔴 结论:
如果不加 +1,最终 logits 会比 input_ids 少 1 个 token,导致 gather 操作失败,无法正确计算 log_probs


4. 关键结论

  1. logits_to_keep + 1 确保 logits 先比 input_ids 多一个,然后删掉最后一个 logits,使两者对齐。
  2. 不加 +1,最终 logitsinput_ids 少 1 个,导致 gather 维度错误,代码会报错。
  3. 在自回归模型中,logits 预测的是下一个 token,所以要手动调整,以确保 logitsinput_ids 一一对应。

🚀 正确理解 logits_to_keep + 1 是构建 Transformer 语言模型损失计算的关键! 🚀

如果不加 +1,可以不执行 logits = logits[:, :-1, :] 吗?

不可以!如果不加 +1,并且 不执行 logits[:, :-1, :] 这个操作,最终 logitsinput_ids 的对齐仍然会出问题,导致错误的 token 对数概率计算。


1. 代码逻辑分析

1.1 logits_to_keep + 1 的作用

logits = model(
    input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1
).logits  # (B, L, V)
  • logits_to_keep + 1 让模型输出logits_to_keep 多 1 个 logits
  • 这样,logits.shape = (B, L+1, V)(多 1 个 token 预测的 logits)。

1.2 logits[:, :-1, :] 的作用

logits = logits[:, :-1, :]  # (B, L-1, V)
  • 这一步 删除最后一个 logits,确保 logits 只用于计算 input_ids 对应 token 的概率。
  • 如果不执行这一步,则 logits.shape = (B, L, V),这就会导致 logitsinput_ids 多 1 个 token,维度不匹配。

1.3 input_ids[:, -logits_to_keep:] 作用

input_ids = input_ids[:, -logits_to_keep:]
  • 这一步 只保留 logits_to_keep 个 token 的 input_ids,确保 input_idslogits 维度匹配。

2. 如果不加 +1,但仍然执行 logits[:, :-1, :],会发生什么?

如果 logits_to_keep 没有 +1,但仍然执行:

logits = logits[:, :-1, :]
  • logits 数量会比 input_ids 少 1 个。
  • logits.shape = (B, logits_to_keep - 1, V)
  • input_ids.shape = (B, logits_to_keep)

这会导致:

token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

报错,因为 logits.shape[1]input_ids.shape[1] 不匹配!


3. 如果不加 +1,并且不执行 logits[:, :-1, :],会发生什么?

假设 logits_to_keep = 3,并且不加 +1,那么:

logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep).logits
  • logits.shape = (B, logits_to_keep, V)
  • input_ids[:, -logits_to_keep:] 仍然是 (B, logits_to_keep)
  • logitsinput_ids 维度看似匹配,但实际上错位了!

错位的原因:

  • logits[:, i, :] 对应的是 input_ids[:, i+1](预测的是下一个 token),而不是 input_ids[:, i]
  • 这会导致 gather 取到错误的 logits,计算的 log_probs 也是错的。
示例

假设:

input_ids = [[5, 8, 2, 3, 9]]  # 长度 5
logits_to_keep = 3

如果 不加 +1,且不 logits[:, :-1, :]

logits[:, 0, :]  # 实际预测 input_ids[:, 1] (8)
logits[:, 1, :]  # 实际预测 input_ids[:, 2] (2)
logits[:, 2, :]  # 实际预测 input_ids[:, 3] (3)  ❌ 但被错误匹配到 input_ids[:, 2]

最终 gather 取到的是错位的 logits!


4. 结论

情况logits.shapeinput_ids.shape结果
正确:加 +1 并执行 logits[:, :-1, :](B, logits_to_keep, V)(B, logits_to_keep)匹配正确
错误:不加 +1,但仍然执行 logits[:, :-1, :](B, logits_to_keep - 1, V)(B, logits_to_keep)维度不匹配,gather 报错
错误:不加 +1,且不执行 logits[:, :-1, :](B, logits_to_keep, V)(B, logits_to_keep)错位,计算错误的 log_probs

核心总结

  1. logits_to_keep + 1logits 先多 1 个,再删掉最后 1 个,以正确对齐 input_ids
  2. 如果不 +1,但仍然 logits[:, :-1, :],最终 logitsinput_ids 少 1 个,导致 gather 失败。
  3. 如果不 +1,且不 logits[:, :-1, :],最终 logitsinput_ids 看似匹配,但会错位,计算错误的 log_probs

🚀 正确理解 logits_to_keep + 1 是确保 Transformer 语言模型 log_prob 计算正确的关键! 🚀

后记

2025年2月21日19点32分于上海。在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2304431.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

论文笔记-WSDM2025-ColdLLM

论文笔记-WSDM2025-Large Language Model Simulator for Cold-Start Recommendation ColdLLM&#xff1a;用于冷启动推荐的大语言模型模拟器摘要1.引言2.前言3.方法3.1整体框架3.1.1行为模拟3.1.2嵌入优化 3.2耦合漏斗ColdLLM3.2.1过滤模拟3.2.2精炼模拟 3.3模拟器训练3.3.1LLM…

基于 Python Django 的校园互助平台(附源码,文档)

博主介绍&#xff1a;✌Java徐师兄、7年大厂程序员经历。全网粉丝13w、csdn博客专家、掘金/华为云等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;&#x1f3fb; 不…

智慧废品回收小程序php+uniapp

废品回收小程序&#xff1a;数字化赋能环保&#xff0c;开启资源循环新时代 城市垃圾治理难题&#xff0c;废品回收小程序成破局关键 随着城市化进程加速与消费水平提升&#xff0c;我国生活垃圾总量逐年攀升&#xff0c;年均增速达5%-8%&#xff0c;其中超30%为可回收物。然…

网页版的俄罗斯方块

1、新建一个txt文件 2、打开后将代码复制进去保存 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>俄…

创建虚拟环境以及配置对应的项目依赖

文章目录 首先创建一个虚拟环境&#xff0c;创建一个名字为myenv,并且版本为xxx的虚拟环境 conda create --name myenv pythonxxx激活虚拟环境 conda activate myenv下载所需的依赖&#xff0c;如果有requirements.txt文件 pip install -r requirements.txt容易出现的错误&a…

网络安全第三次练习

一、实验拓扑 二、实验要求 配置真实DNS服务信息&#xff0c;创建虚拟服务&#xff0c;配置DNS透明代理功能 三、需求分析 1.创建用户并配置认证策略 2.安全策略划分接口 3.ip与策略配置 四、实验步骤 1.划分安全策略接口 2.创建用户并进行策略认证 3.配置安全策略 4.NAT配…

写大论文的word版本格式整理,实现自动生成目录、参考文献序号、公式序号、图表序号

前情提要&#xff1a;最近开始写大论文&#xff0c;发现由于内容很多导致用老方法一个一个改的话超级麻烦&#xff0c;需要批量自动化处理&#xff0c;尤其是序号&#xff0c;在不断有增添删减的情况时序号手动调整很慢也容易出错&#xff0c;所以搞一个格式总结&#xff0c;记…

STM32——HAL库开发笔记22(定时器3—呼吸灯实验)(参考来源:b站铁头山羊)

本文利用前几节所学知识来实现一个呼吸灯实验&#xff1a;两颗led灯交替呼吸。 一、STM32CubeMX配置 step1&#xff1a;配置调试接口 step2&#xff1a;配置定时器 定时器1位于APB2总线上&#xff0c;如上图所示。 step3&#xff1a;配置时基单元 按照下图配置 时钟来源配置…

玩转 Java 与 Python 交互,JEP 库来助力

文章目录 玩转 Java 与 Python 交互&#xff0c;JEP 库来助力一、背景介绍二、JEP 库是什么&#xff1f;三、如何安装 JEP 库&#xff1f;四、JEP 库的简单使用方法五、JEP 库的实际应用场景场景 1&#xff1a;数据处理场景 2&#xff1a;机器学习场景 3&#xff1a;科学计算场…

【单片机毕业设计14-基于stm32c8t6的智能宠物养护舱系统设计】

【单片机毕业设计14-基于stm32c8t6的智能宠物养护舱系统设计】 前言一、功能介绍二、硬件部分三、软件部分总结 前言 &#x1f525;这里是小殷学长&#xff0c;单片机毕业设计篇14-基于stm32c8t6的智能宠物养护舱系统设计 &#x1f9ff;创作不易&#xff0c;拒绝白嫖可私 一、功…

DevEco Studio常用快捷键以及如何跟AndroidStudio的保持同步

DevEco Studio快捷键 DevEco Studio是华为推出的用于开发HarmonyOS应用的集成开发环境&#xff0c;它提供了丰富的快捷键以提高开发效率&#xff0c;以下为你详细介绍不同操作场景下的常用快捷键&#xff1a; 通用操作快捷键 操作描述Windows/Linux 快捷键Mac 快捷键打开设置窗…

[Windows] 全国油价实时查询,可具体到城市

[Windows] 全国油价实时查询&#xff0c;可具体到城市 链接&#xff1a;https://pan.xunlei.com/s/VOJnS3aOPeBwGaSvS0O0E1hwA1?pwdx83j# 出于代码练习的目的&#xff0c;调用公共免费api做的py程序&#xff0c;已经一键打包&#xff0c;双击启动即可 使用&#xff1a;选择…

【CSS】---- CSS 变量,实现样式和动画函数复用

1. 前言 本文介绍 CSS 的自定义属性(变量)来实现样式、动画等 CSS 的复用。都是知道在 CSS 和 JS 复用一个很重要的事情,比如 JS 的函数封装,各个设计模式的使用等等,CSS 中样式的复用,同样重要。MDN 使用 CSS 自定义属性(变量):自定义属性(有时候也被称作CSS 变量或…

装修流程图: 装修前准备 → 设计阶段 → 施工阶段 → 安装阶段 → 收尾阶段 → 入住

文章目录 引言I 毛坯房装修的全流程**1. 装修前准备****1.1 确定装修预算****1.2 选择装修方式****1.3 选择装修公司****1.4 办理装修手续****2. 设计阶段****2.1 量房****2.2 设计方案****2.3 确认方案****3. 施工阶段****3.1 主体拆改****3.2 水电改造****3.3 防水工程****3.…

【论文解读】《Training Large Language Models to Reason in a Continuous Latent Space》

论文链接 1. 背景与动机 语言空间与推理的矛盾 目前大多数大语言模型&#xff08;LLMs&#xff09;在解决复杂问题时采用链式思维&#xff08;Chain-of-Thought, CoT&#xff09;方法&#xff0c;即利用自然语言逐步推导出答案。然而&#xff0c;论文指出&#xff1a; 自然语言…

深度剖析 C 语言函数递归:原理、应用与优化

在 C 语言的函数世界里&#xff0c;递归是一个独特且强大的概念。它不仅仅是函数调用自身这么简单&#xff0c;背后还蕴含着丰富的思想和广泛的应用。今天&#xff0c;让我们跟随这份课件&#xff0c;深入探索函数递归的奥秘。 一、递归基础&#xff1a;概念与思想 递归是一种…

goredis常见基础命令

基本操作 //删除键 exists,err: rdb.Exists(ctx,"key").Result() if err!nil{panic(err) } if exists>0{err rdb.Del(ctx,"key").Err()if err!nil{panic(err)} }string类型 //设置一个键值对 //0表示没有过期时间 err:rdb.Set(ctx,"key1",…

【Linux网络】序列化、守护进程、应用层协议HTTP、Cookie和Session

⭐️个人主页&#xff1a;小羊 ⭐️所属专栏&#xff1a;Linux 很荣幸您能阅读我的文章&#xff0c;诚请评论指点&#xff0c;欢迎欢迎 ~ 目录 1、序列化和反序列化2、守护进程2.1 什么是进程组&#xff1f;2.2 什么是会话&#xff1f; 3、应用层协议HTTP3.1 HTTP协议3.2 HT…

system verilog的流操作符

流操作符&#xff0c;有分为操作对象是一整个数组和单独的数据两种&#xff0c;例如bit [7:0] a[4]和bit [31:0] b&#xff0c;前者操作对象是数组&#xff0c;后者是单独一个较大位宽的数。 流操作符有<<和>>&#xff0c;代表从右向左打包和从左向右打包。 打包的…

LLM2CLIP论文学习笔记:强大的语言模型解锁更丰富的视觉表征

1. 写在前面 今天分享的一篇论文《LLM2CLIP: P OWERFUL L ANGUAGE M ODEL U NLOCKS R ICHER V ISUAL R EPRESENTATION》&#xff0c; 2024年9月微软和同济大学的一篇paper&#xff0c; 是多模态领域的一篇工作&#xff0c;主要探索了如何将大模型融合到Clip模型里面来进一步提…