浅谈LLAMA2核心函数generate源码

news2024/11/26 6:19:45

在学习LLAMA2的generate源码之前,先介绍Temperature超参数及sample_top_p的原理。

Temperature

Temperature 是一个超参数,可用于控制生成语言模型中生成文本的随机性和创造性。用于调整模型的softmax输出层中预测词的概率。

softmax函数:
p ( x i ) = e x i ∑ j = 1 V e x j p\left(x_i\right)=\frac{e^{x_i}}{\sum_{j=1}^V e^{x_j}} p(xi)=j=1Vexjexi

Temperature 参数(T)添加到softmax函数:
p ( x i ) = e x i T ∑ j = 1 V e x j T p\left(x_i\right)=\frac{e^{\frac{x_i}{T}}}{\sum_{j=1}^V e^{\frac{x_j}{T}}} p(xi)=j=1VeTxjeTxi
Temperature参数通常设置为 0.1 到 1.0 之间(T=1时形变为标准的Softmax函数),下图分别显示了 x i / T x_i/T xi/T在5:0.5和5:0.1时的图像(紫线为softmax,黑线为添加T参数的softmax),可以看到:

  • 当T值更大时,函数图像会变的更加的平缓,预测词的概率被拉平,这意味着所有词被选择的可能性更大。 这会产生更有创意和多样化的文本,因为模型更有可能生成不寻常或意想不到的词。

  • 当T值更小时,函数图像会变的更加的陡峭,预测词的概率会变尖锐,这意味着选择最有可能的词的概率更高。 这会产生更保守和可预测的文本,因为模型不太可能生成意想不到或不寻常的词。

在这里插入图片描述

x i / T x_i/T xi/T=5:0.5

在这里插入图片描述

x i / T 5 = 0.1 x_i/T5=0.1 xi/T5=0.1

小结:Temperature 参数是文本生成模型中用于控制生成文本的随机性和创造性的一个重要的超参数。

sample_top_p

在这里插入图片描述

平缓和陡峭的概率分布图-文献【2】

采样意味着根据当前条件概率分布随机选择输出词 ,使用采样方法时文本生成本身不再是确定性的。对单词序列进行采样时的大问题: 模型通常会产生不连贯的乱码。在LLAMA2中,缓解这一问题的方式是通过top_p(也称:nucleus sampling)

def sample_top_p(probs, p):
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # 归一化
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    # multinomial为多项式抽样函数
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

sample_top_p函数的作用:每个时间步,按照字出现的概率由高到底排序,当概率之和大于top-p的时候,就不取后面的样本了。然后对取到的这些字的概率重新归一化后,进行采样。这样做的好处是,既保证了质量,又增加了适当的随机性。

核心函数generate()

这一块直接在代码中进行注释:

def generate(
        self,
        prompt_tokens: List[List[int]],  # 输入的提示
        max_gen_len: int,  # 最大生成长度
        temperature: float = 0.6,  # 影响生成文本的随机性
        top_p: float = 0.9,  # 用于决定采样过程中保留的 token 集合的概率阈值
        logprobs: bool = False,  # 是否返回每个 token 的对数概率
        echo: bool = False,  # 是否返回输入的提示
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
    # ---------------------------初始化长度为 total_len tokens张量,并填充 pad_id----------------------------------
    params = self.model.params
    bsz = len(prompt_tokens)
    assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

    min_prompt_len = min(len(t) for t in prompt_tokens)
    max_prompt_len = max(len(t) for t in prompt_tokens)
    assert max_prompt_len <= params.max_seq_len
    total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

    pad_id = self.tokenizer.pad_id
    tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
    # 将prompt_tokens中的token复制到tokens张量中。
    for k, t in enumerate(prompt_tokens):
        tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    if logprobs:
        # 创建一个与tokens相同形状的token_logprobs张量,并用0填充
        token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

    prev_pos = 0
    eos_reached = torch.tensor([False] * bsz, device="cuda")
    input_text_mask = tokens != pad_id
    # -------------------------------------------------------------

    for cur_pos in range(min_prompt_len, total_len):
        # 调用模型的forward方法获取logits
        logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if logprobs:
            # 计算token level的logprobs
            token_logprobs[:, prev_pos + 1: cur_pos + 1] = -F.cross_entropy(
                input=logits.transpose(1, 2),
                target=tokens[:, prev_pos + 1: cur_pos + 1],
                reduction="none",
                ignore_index=pad_id,
            )
        # 根据温度参数和top_p参数对logits进行softmax和采样,得到下一个token
        if temperature > 0:
            # sample_top_p函数对probs进行采样
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            # 将logits中概率最大的token作为下一个token。
            next_token = torch.argmax(logits[:, -1], dim=-1)

        next_token = next_token.reshape(-1)
        # only replace token if prompt has already been generated
        next_token = torch.where(
            input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
        )
        # tokens张量更新
        tokens[:, cur_pos] = next_token
        eos_reached |= (~input_text_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id
        )
        prev_pos = cur_pos
        # 检查是否已经生成了所有的eos token,如果是则停止生成
        if all(eos_reached):
            break

    if logprobs:
        # token_logprobs列表化
        token_logprobs = token_logprobs.tolist()
    out_tokens, out_logprobs = [], []
    for i, toks in enumerate(tokens.tolist()):
        # cut to max gen len
        # 对于 tokens 张量中的每一行(即每一个生成的序列),如果 echo 参数为假,则去掉提示部分
        start = 0 if echo else len(prompt_tokens[i])
        toks = toks[start: len(prompt_tokens[i]) + max_gen_len]
        probs = None
        if logprobs:
            probs = token_logprobs[i][start: len(prompt_tokens[i]) + max_gen_len]
        # cut to eos tok if any
        # 存在结束标记,则去掉结束标记之后的部分
        if self.tokenizer.eos_id in toks:
            eos_idx = toks.index(self.tokenizer.eos_id)
            toks = toks[:eos_idx]
            probs = probs[:eos_idx] if logprobs else None
        out_tokens.append(toks)
        out_logprobs.append(probs)
    # 返回生成的tokens和对数概率(如果logprobs参数为真)
    return (out_tokens, out_logprobs if logprobs else None)

总结

本文介绍了Temperature以及sample_top_p的原理,并且阅读了LLAMA2的核心生成函数的源码。关于更多细节实现,请关注llama源码。

参考文献

【1】https://github.com/facebookresearch/llama/blob/main/llama/generation.py

【2】The Curious Case of Neural Text Degeneration

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/867743.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

JDK内置SPI机制、服务提供发现机制

SPI的全称是Service Provider Interface服务提供接口&#xff0c;是JDK内置的一种 服务提供发现机制&#xff0c;例如我们常用的数据库驱动Driver&#xff0c;就是基于SPI来做的。 运行机制&#xff1a; 服务的调用方需要调用服务提供方的服务&#xff0c;如果在调用方中直接…

插入、希尔、归并、快速排序(java实现)

目录 插入排序 希尔排序 归并排序 快速排序 插入排序 排序原理&#xff1a; 1.把所有元素分为两组&#xff0c;第一组是有序已经排好的&#xff0c;第二组是乱序未排序。 2.将未排序一组的第一个元素作为插入元素&#xff0c;倒序与有序组比较。 3.在有序组中找到比插入…

大语言模型之一 Attention is all you need ---Transformer

大语言模型已经在很多领域大显身手&#xff0c;其应用包括只能写作、音乐创作、知识问答、聊天、客服、广告文案、论文、新闻、小说创作、润色、会议/文章摘要等等领域。在商业上模型即产品、服务即产品、插件即产品&#xff0c;任何形态的用户可触及的都可以是产品&#xff0c…

面试题:ArrayList扩容时扩容多少?

大家好&#xff0c;我是你们的小米&#xff01;今天要和大家一起来探讨一个在Java面试中经常被问到的问题&#xff1a;“ArrayList扩容时扩容多少&#xff1f;”相信很多小伙伴都在面试中遇到过这个问题&#xff0c;那么接下来&#xff0c;我就为大家详细解析一下这个问题&…

Vue3+Ts+Vite项目全局配置Element-Plus主题色

概述 我找了很多博客&#xff0c;想全局配置Elmenet-Plus组件主题色&#xff0c;但都没有效果。所以有了这篇博客&#xff0c;希望能对你有所帮助&#xff01;&#xff01;&#xff01; 文章目录 概述一、先看效果二、创建全局颜色文件2.1 /src/styles 下新建 element-plus.sc…

C#应用处理传入参数 - 开源研究系列文章

今天介绍关于C#的程序传入参数的处理例子。 程序的传入参数应用比较普遍&#xff0c;特别是一个随操作系统启动的程序&#xff0c;需要设置程序启动的时候不显示主窗体&#xff0c;而是在后台运行&#xff0c;于是就有了传入参数问题&#xff0c;比如传入/h或者/min等等。所以此…

湘大 XTU OJ 1260 Completed String 题解(非常详细):建立数组下标和数组元素之间的映射关系 ~scanf

一、链接 1260 Completed String 二、题目 题目描述 给一个字符串&#xff0c;请判断字符串是否出现了所有的英文字母&#xff08;不区分大小写&#xff09;。 输入 每行一个只含英文字母的字符串&#xff0c;长度不超过1000。 输出 每行输出一个样例的结果&#xff0c…

SpringBoot案例-部门管理-新增

根据页面原型&#xff0c;明确需求 页面原型 需求 阅读接口文档 接口文档链接如下&#xff1a; 【腾讯文档】SpringBoot案例所需文档 https://docs.qq.com/doc/DUkRiTWVaUmFVck9N 思路分析 前端在输入要新增的部门名称后&#xff0c;会以JSON格式将数据传入至后端&#xf…

基于Python实现的有限元方程求解程序附源码

问题描述 根据已知下列非齐次两点边值问题(1.2.28) { L u − d d x ( p d u d x ) q u f , a < x < b , u ( a ) α , u ′ ( b ) β , \begin{cases} \boldsymbol{L} u-\frac{\mathrm{d}}{\mathrm{d} x}\left(p \frac{\mathrm{d} u}{\mathrm{~d} x}\right)q uf, a…

markdown命令模板

markdown快速入门(typora) 1、代码块 //代码块语 public static void main(String[] args){}//linux下spring项目的启动命令 # java -jar blog start ## 2、标题&#xff1a;java # 一级标题 ## 二级标题 ### 三级标题 #### 四级标题 ##### 五级标题 ###### 六级标题3、字体 …

STM32 LL库+STM32CubeMX--点亮板载LED

一、前期准备 硬件&#xff1a;STM32F103C8T6开发板调试工具&#xff1a;DAPLink(本次使用)或USB-TTL开发环境&#xff1a;STM32CubeMX、Keil、Vscode(可选)板载LED&#xff1a;PC13(低电平点亮) 二、STM32CubeMX配置 1.选择芯片型号&#xff1a; 2.配置外设时钟&#xff1a…

步入React正殿 - 事件处理

目录 扩展学习资料 React事件和DOM事件 和传统DOM事件处理异同 this关键字的处理 this关键字 在JSX中使用bind方法 在构造函数中使用bind方法 使用箭头函数【推荐】 向事件处理程序传递参数【不跨组件】 向父组件传递参数 /src/App.js /src/components/listItem.jsx…

【MySQL--->数据库基础】

文章目录 [TOC](文章目录) 一、基本概念二、实际应用中的数据库三、mysql的架构四、mysql语句分类五、存储引擎查看 一、基本概念 mysql本质是一个CS模式的网络服务,mysql是客户端,mysqld是服务端,提供高效的数据存取方案.数据库系统简单来说是一个数据集合加上管理这个数据集…

Java旋转数组中的最小数字(图文详解版)

目录 1.题目描述 2.题解 分析 具体实现 方法一&#xff08;遍历&#xff09;&#xff1a; 方法二&#xff08;排序&#xff09;&#xff1a; 方法三&#xff08;二分查找&#xff09;&#xff1a; 1.题目描述 有一个长度为 n 的非降序数组&#xff0c;比如[1,2,3,4,5]&a…

【LeetCode 75】第二十六题(394)字符串解码

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码运行结果&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 给我们字符串&#xff0c;让我们解码&#xff0c;那么该怎么解码呢&#xff0c;被括号【】包裹起来的字符串需要扩展成括号左边第…

一百五十一、Kettle——Linux上安装的kettle8.2开启carte服务以及配置子服务器

一、目的 kettle8.2在Linux上安装好可以启动界面、并且可以连接MySQL、Hive、ClickHouse等数据库后&#xff0c;准备在Linux上启动kettle的carte服务 二、实施步骤 &#xff08;一&#xff09;carte服务文件路径 kettle的Linux运行的carte服务文件是carte.sh &#xff08;二…

grafana部署

一、前言 grafana是一款用于将prometheus收集的数据通过ui展示出来的组件&#xff0c;可以直观的看到每个数据的情况和指标&#xff0c;grafana有很多的ui展示模板可以使用 二、部署 这里我使用docker部署 先查找一下镜像 docker search grafana 创建存放grafana数据的目录…

C++初阶之一篇文章教会你list(理解和使用)

list&#xff08;理解和使用&#xff09; 什么是list特点和优势基本操作示例用法与其他序列式容器&#xff08;如 std::vector 和 std::deque&#xff09;相比&#xff0c;std::list 显著的区别和优势成员类型 list构造函数1. default (1)2. fill (2)3.range (3)4. copy (4) li…

无涯教程-Perl - opendir函数

描述 此函数使用readdir函数打开目录EXPR,并将其与DIRHANDLE关联以进行处理。 语法 以下是此函数的简单语法- opendir DIRHANDLE, EXPR返回值 如果成功,此函数将返回true。 例 以下是显示其基本用法的示例代码- #!/usr/bin/perl -w$dirname"/tmp";opendir ( …

MySQL~事务的四大特性和隔离级别

事务的四大特性 1.原子性&#xff1a;一个事务&#xff08;transaction&#xff09;中的所有操作&#xff0c;要么全部完成&#xff0c;要么全部不完成。事务在执行过程中发生错误&#xff0c;会被回滚&#xff08;Rollback&#xff09;到事务开始前的状态&#xff0c;就像这个…