从FasterTransformer源码解读开始了解大模型(2.1)代码通读03

news2025/1/12 7:50:47

从FasterTransformer源码解读开始了解大模型(2.2)代码解读03-forward函数

写在前面的话

本篇的内容继续解读forward函数,从650行开始进行解读

零、输出Context_embeddings和context_cum_log_probs的参数和逻辑

从653行开始,会从输入的请求tensors中读取一个配置,如果请求中配置了is_return_context_embeddings参数并设置为true时,则会在返回参数中增加一个context_embeddings的tensor,这个tensor中包含的数据是输入经过了ContextDecoder过程的所有的层之后的logits,并对其进行求和。可能有些类似于强化学习(RLHF)之类的场景会用到这里的输出,所以在这里做了一层准备。

跳转到1093行,可以看见,这里调用了invokeSumLengthDimension这个kennels,来将context_decoder_output_buf这块buf中的显存数据拷贝到输出tensor的context_embeddings中,可以简单看一下这个kernel的实现,在src/fastertransformer/kernels/gpt_kernels.cu中

template<typename T>
void invokeSumLengthDimension(float*       out_buf,
                              const T*     in_buf,
                              const size_t batch_size,
                              const size_t input_length,
                              const size_t hidden_dim,
                              cudaStream_t stream)
{
    dim3 gridSize(batch_size);
    dim3 blockSize(256);

    sum_length_dimension<<<gridSize, blockSize, 0, stream>>>(out_buf, in_buf, batch_size, input_length, hidden_dim);
}


template<typename T>
__global__ void sum_length_dimension(
    float* out_buf, const T* in_buf, const size_t batch_size, const size_t input_length, const size_t hidden_dim)
{
    const int bidx = blockIdx.x;

    for (int hidx = threadIdx.x; hidx < hidden_dim; hidx += blockDim.x) {
        float accum = 0.0f;
        for (int step = 0; step < input_length; step++) {
            accum += static_cast<float>(in_buf[(bidx * input_length + step) * hidden_dim + hidx]);
        }
        out_buf[bidx * hidden_dim + hidx] = accum;
    }
}

从kernel中可以看出,这个kernel任务划分为按照batch_size维度进行了grid task分配,并为每个任务划分了256个线程。在kernel内部则是将每个batch内所有输入的logits按照输入length维度进行了累加,并拷贝到输出buf中。

类似的一个设置和处理环节,在783行还处理了is_return_context_cum_log_probs,这里会将contextDecoder完成之后的输出进行cum log计算。其中主要的处理逻辑函数是403行的computeContextCumLogProbs函数,进入这个函数后,在425到502行的处理逻辑是,首先将context_decoder的输出进行layernorm,然后使用cublas矩阵乘,将词表维度的logits计算出来。在此之后,在504行,则会使用invokeLogProbFromLogits这个kernel.

template<typename T>
void invokeLogProbFromLogits(float*       cum_log_probs,
                             const T*     logits,
                             const int*   input_ids,
                             const int*   input_lengths,
                             const size_t max_input_length,
                             const size_t batch_size,
                             const size_t vocab_size,
                             const size_t vocab_size_padded,
                             void*        workspace,
                             const size_t workspace_size,
                             cudaStream_t stream,
                             const bool   batch_first)
{
    // A batched version of log prob computation.
    //
    // cum_log_probs: [batch_size]
    // logits: [max_input_length, batch_size, vocab_size] or [batch_size, max_input_length, vocab_size]
    // input_ids: [max_input_length, batch_size] or [max_input_length, batch_size]
    // input_lengths: [batch_size]
    // workspace: workspace buffer of size at least sizeof(float) * max_input_length * batch_size.

    FT_LOG_DEBUG(__PRETTY_FUNCTION__);
    // block_size should be multiple of 32 to use warpReduceMax.
    const int block_size = vocab_size < 1024 ? (vocab_size + 31) / 32 * 32 : 1024;
    assert(block_size % 32 == 0);
    assert(workspace != nullptr && workspace_size >= sizeof(float) * max_input_length * batch_size);
    assert(vocab_size <= vocab_size_padded);

    float* log_probs = reinterpret_cast<float*>(workspace);
    int    gx        = batch_first ? batch_size : max_input_length - 1;
    int    gy        = batch_first ? max_input_length - 1 : batch_size;
    dim3   grid(gx, gy);
    log_probs_kernel<T><<<grid, block_size, 0, stream>>>(log_probs,
                                                         logits,
                                                         input_ids,
                                                         input_lengths,
                                                         max_input_length,
                                                         batch_size,
                                                         vocab_size,
                                                         vocab_size_padded,
                                                         batch_first);
    accumulate_log_probs<<<batch_size, block_size, 0, stream>>>(
        cum_log_probs, log_probs, input_lengths, max_input_length, batch_size, batch_first);
}

__global__ void accumulate_log_probs(float*       cum_log_probs,
                                     const float* log_probs,
                                     const int*   lengths,
                                     const size_t max_input_length,
                                     const size_t batch_size,
                                     const bool   batch_first)
{
    // Accumulate the log probability along with the sequence dimension.
    //   cum_log_probs[j] = sum_i log(softmax(logits))[ids[i,j]]
    //
    // cum_log_probs: [batch_size], cumulative log probability
    // log_probs: [max_length - 1, batch_size] or [batch_size, max_length - 1],
    //   log probability of each token
    // lengths: [batch_size], sequence lengths
    // batch_size: [1], batch_size. in case of beam > 1, batch x beam.

    int bidx = blockIdx.x;   // batch dim
    int tidx = threadIdx.x;  // step dim

    if (bidx < batch_size) {
        int length = lengths[bidx];
        // reposition logits to data for the current batch.
        log_probs += batch_first ? bidx * (max_input_length - 1) : bidx;
        int   stride      = batch_first ? 1 : batch_size;  // stride along with seq dim.
        float local_accum = 0.0f;
        for (int step = tidx; step < length - 1; step += blockDim.x) {
            local_accum += static_cast<float>(log_probs[step * stride]);
        }
        float accum = blockDim.x <= 32 ? warpReduceSum(local_accum) : blockReduceSum<float>(local_accum);
        if (tidx == 0) {
            cum_log_probs[bidx] = accum;
        }
    }
}

在任务划分阶段,grid的x维度是batch维度,而y维度则是输入长度,进入kernel后,则是按照batch维度对任务的log_probs进行了ReduceSum,并写入到cum_log_probs中。

一、跳过prompt和prefix阶段

回到653行,接下来是ft中处理其自身特殊的prompt和prefix的逻辑,但我们的源码解读可以先跳过这一段。之所以要跳过这一段是在当前主要的大模型处理逻辑中,并不会需要在推理引擎这一层过多地关注prompt和prefix的逻辑。在对话的大模型agent中,对于第n次对话的输入/输出,其真正要进行forward的输入,往往是由前n-1轮模型的输出和用户的输入共同组成的。从推理引擎的角度来看,prefill阶段在真正端到端时延的占比并不算高,所以专门设计一个处理prompt的逻辑(还需要常驻的显存空间)多少显得有些得不偿失了

二、正常的逻辑起始和输入展开

让我们直接来到865行,考虑以下一种场景来简化我们的源码解读的结构:beam_width=1,tp=1,pp=1,use_shared_contexts不启用,也就是说,没有beam search,在单机单卡上进行一次简单的,不共享contexts的推理服务的进行。

从866行开始,由于是计算起始,所以先对一些计算中的buff进行清空。在870行,由于不进行beam search,所以也不需要对beam search用到的buf进行清空。

在877到887行,如果词表大小不对齐的话,需要将词表计算的权重进行对齐拷贝。

在889到905行,不使用shared context的话,这一段的逻辑也不需要进行处理。在914行处理prompt的逻辑也可以跳过,直接进入955行的处理逻辑。在955行,由于我们的输入是带batch的,所以需要将batch进行展开(tiled),这里使用的是kernel invokeTileGptPromptInputs

void invokeTileGptPromptInputs(int*         tiled_input_ids,
                               int*         tiled_input_lengths,
                               int*         tiled_prompt_lengths,
                               const int*   input_ids,
                               const int*   input_lengths,
                               const int*   prefix_prompt_lengths,
                               const int    batch_size,
                               const int    beam_width,
                               const int    max_input_length,
                               cudaStream_t stream)
{
    dim3 grid(batch_size, beam_width);
    dim3 block(min(1024, max_input_length));
    if (prefix_prompt_lengths != nullptr) {
        tileGptPromptInputs<true><<<grid, block, 0, stream>>>(tiled_input_ids,
                                                              tiled_input_lengths,
                                                              tiled_prompt_lengths,
                                                              input_ids,
                                                              input_lengths,
                                                              prefix_prompt_lengths,
                                                              max_input_length);
    }
    else {
        tileGptPromptInputs<false><<<grid, block, 0, stream>>>(tiled_input_ids,
                                                               tiled_input_lengths,
                                                               tiled_prompt_lengths,
                                                               input_ids,
                                                               input_lengths,
                                                               prefix_prompt_lengths,
                                                               max_input_length);
    }
}

template<bool PREFIX_PROMPT>
__global__ void tileGptPromptInputs(int*       tiled_input_ids,
                                    int*       tiled_input_lengths,
                                    int*       tiled_prompt_lengths,
                                    const int* input_ids,
                                    const int* input_lengths,
                                    const int* prefix_prompt_lengths,
                                    const int  max_input_length)
{
    if (threadIdx.x == 0) {
        tiled_input_lengths[blockIdx.x * gridDim.y + blockIdx.y] = input_lengths[blockIdx.x];
        if (PREFIX_PROMPT) {
            tiled_prompt_lengths[blockIdx.x * gridDim.y + blockIdx.y] = prefix_prompt_lengths[blockIdx.x];
        }
    }
    for (int index = threadIdx.x; index < max_input_length; index += blockDim.x) {
        tiled_input_ids[(blockIdx.x * gridDim.y + blockIdx.y) * max_input_length + index] =
            input_ids[blockIdx.x * max_input_length + index];
    }
}

在任务划分阶段,按照batch维度进行划分,每个任务起了至少1024个线程来进行拷贝。由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再处理

在kernel中,由于没有beam width,那么gridDim.y就会是1。在kernel中,首先将input_length拷贝到tiled_input_lengths,之后再按照batch维度进行处理,将一个batch*max_input_length的数据进行展开,并拷贝到对应的buf,这有利于我们进行接下来的后续embedding和MHA计算
在这里插入图片描述

下一回预告

下一回继续讲解forward函数中的处理逻辑,会简单讲解embedding, pre_layernorm,以及进入attention_layer之后的初步讲解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1904845.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Python实现ABC人工蜂群优化算法优化随机森林回归模型(RandomForestRegressor算法)项目实战

说明&#xff1a;这是一个机器学习实战项目&#xff08;附带数据代码文档视频讲解&#xff09;&#xff0c;如需数据代码文档视频讲解可以直接到文章最后获取。 1.项目背景 人工蜂群算法(Artificial Bee Colony, ABC)是由Karaboga于2005年提出的一种新颖的基于群智能的全局优化…

LeetCode Hard|124.二叉树中的最大路径和

力扣题目链接 题目解读&#xff1a; 二叉树路径的定义即从1.任意节点出发&#xff0c;到达任意节点&#xff1b;2.该路径至少包含一个节点&#xff0c;且不一定经过跟节点&#xff1b;3.求所有可能路径和的最大值。 也就是说路径途径一个节点只能选择来去两个方向 考虑一个二叉…

微信公众平台测试账号本地微信功能测试说明

使用场景 在本地测试微信登录功能时&#xff0c;因为微信需要可以互联网访问的域名接口&#xff0c;所以本地使用花生壳做内网穿透&#xff0c;将前端服务的端口和后端服务端口进行绑定&#xff0c;获得花生壳提供的两个外网域名。 微信测试账号入口 绑定回调接口 回调接口的…

2024年06月CCF-GESP编程能力等级认证Python编程二级真题解析

本文收录于专栏《Python等级认证CCF-GESP真题解析》&#xff0c;专栏总目录&#xff1a;点这里&#xff0c;订阅后可阅读专栏内所有文章。 一、单选题&#xff08;每题 2 分&#xff0c;共 30 分&#xff09; 第 1 题 小杨父母带他到某培训机构给他报名参加CCF组织的GESP认证…

声明队列和交换机 + 消息转换器

目录 1、声明队列和交换机 方法一&#xff1a;基于Bean的方式声明 方法二&#xff1a;基于Spring注解的方式声明 2、消息转换器 1、声明队列和交换机 方法一&#xff1a;基于Bean的方式声明 注&#xff1a;队列和交换机的声明是放在消费者这边的&#xff0c;这位发送的人他…

力扣206

题目 给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1]示例 3&#xff1a; 输…

【排序算法】—— 快速排序

快速排序的原理是交换排序&#xff0c;其中qsort函数用的排序原理就是快速排序&#xff0c;它是一种效率较高的不稳定函数&#xff0c;时间复杂度为O(N*longN)&#xff0c;接下来就来学习一下快速排序。 一、快速排序思路 1.整体思路 以升序排序为例&#xff1a; (1)、首先随…

PTA甲级1005:Spell It Right

错误代码&#xff1a; #include<iostream> #include<vector> #include<unordered_map> using namespace std;int main() {unordered_map<int, string> map {{0, "zero"}, {1, "one"}, {2, "two"}, {3, "three&qu…

有一个日期(Date)类的对象和一个时间(Time)类的对象,均已指定了内容,要求一次输出其中的日期和时间

可以使用友元成员函数。在本例中除了介绍有关友元成员函数的简单应用外&#xff0c;还将用到类的提前引用声明&#xff0c;请读者注意。编写程序&#xff1a; 运行结果&#xff1a; 程序分析&#xff1a; 在一般情况下&#xff0c;两个不同的类是互不相干的。display函…

实验六 图像的傅立叶变换

一&#xff0e;实验目的 1了解图像变换的意义和手段&#xff1b; 2熟悉傅立叶变换的基本性质&#xff1b; 3熟练掌握FFT变换方法及应用&#xff1b; 4通过实验了解二维频谱的分布特点&#xff1b; 5通过本实验掌握利用MATLAB编程实现数字图像的傅立叶变换。 6评价人眼对图…

股票Level-2行情是什么,应该怎么使用,从哪里获取数据

行情接入方法 level2行情websocket接入方法-CSDN博客 相比传统的股票行情&#xff0c;Level-2行情为投资者打开了更广阔的视野&#xff0c;不仅限于买一卖一的表面数据&#xff0c;而是深入到市场的核心&#xff0c;提供了十档乃至千档的行情信息&#xff08;沪市十档&#…

关于MCU-Cortex M7的存储结构(flash与SRAM)

MCU并没有DDR&#xff0c;所以他把代码存储在flash上&#xff0c;临时变量和栈运行在SRAM上。之所以这么做是因为MCU的CPU频率很低&#xff0c;一般低于500MHZ&#xff0c;flash的读取速度能够满足CPU的取指需求&#xff0c;但flash 的写入速度很慢&#xff0c;所以引入了SRAM …

【数据结构与算法】快速排序挖坑法

&#x1f493; 博客主页&#xff1a;倔强的石头的CSDN主页 &#x1f4dd;Gitee主页&#xff1a;倔强的石头的gitee主页 ⏩ 文章专栏&#xff1a;《数据结构与算法》 期待您的关注 ​

go语言day11 错误 defer(),panic(),recover()

错误&#xff1a; 创建错误 1&#xff09;fmt包下提供的方法 fmt.Errorf(" 格式化字符串信息 " &#xff0c; 空接口类型对象 ) 2&#xff09;errors包下提供的方法 errors.New(" 字符串信息 ") 创建自定义错误 需要实现error接口&#xff0c;而error接口…

【植物大战僵尸杂交版】获取+存档插件

文章目录 一、还记得《植物大战僵尸》吗&#xff1f;二、在哪下载&#xff0c;怎么安装&#xff1f;三、杂交版如何进行存档功能概述 一、还记得《植物大战僵尸》吗&#xff1f; 最近&#xff0c;一款曾经在15年前风靡一时的经典游戏《植物大战僵尸》似乎迎来了它的"文艺复…

C++初学者指南-5.标准库(第一部分)--迭代器

C初学者指南-5.标准库(第一部分)–迭代器 Iterators 文章目录 C初学者指南-5.标准库(第一部分)--迭代器 Iterators1.默认正向迭代器2.反向迭代器3.基于迭代器的循环4.示例&#xff1a;交换相邻的一对元素5.迭代器范围6.迭代器范围中的元素数量7. 总结&#xff1a;迭代器 指向某…

Vue笔记11-Composition API的优势

Options API存在的问题 使用传统Options API中&#xff0c;新增或者修改一个需求&#xff0c;就需要分别在data&#xff0c;methods&#xff0c;computed里修改&#xff0c;而这些选项分布在代码的各个地方&#xff0c;中间还穿插着其他Optional API&#xff0c;如果代码量上来…

国产化新标杆:TiDB 助力广发银行新一代总账系统投产上线

随着全球金融市场的快速发展和数字化转型的深入推进&#xff0c;金融科技已成为推动银行业创新的核心力量。特别是在当前复杂多变的经济环境下&#xff0c;银行业务的高效运作和风险管理能力显得尤为重要。总账系统作为银行会计信息系统的核心&#xff0c;承载着记录、处理和汇…

运维锅总详解系统启动流程

本文详细介绍Linux及Windows系统启动流程&#xff0c;并分析了它们启动流程的异同以及造成这种异同的原因。希望本文对您理解系统的基本启动流程有所帮助&#xff01; 一、Linux系统启动流程 Linux 系统的启动流程可以分为几个主要阶段&#xff0c;从电源开启到用户登录。每个…

FPGA-UDP实验

1. 以太网简介 水晶头的规格就是RJ45千兆网一般指的就是UDP千兆网PHY芯片是用来协商用的 协商匹配最低的通信速率 1.1. OSI模型 7层 ![外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传](https://img-home.csdnimg.cn/images/20230724024159.png?origin…