基于torch.compile和gptfast代码风格实现ChatGLM模型推理加速

news2024/11/22 10:25:07

       

目录

一、ChatGLM模型代码重构迁移

二、推理的代码重构

三、效果分析对比

参考文章


         torch2.0发布以后模型训练和推理可以实现一行代码加速,试用之后发现效果并不明显。随后gptfast项目也发布,表明它确实是可以实现模型推理的加速,看来之前试用是打开方式不对。最近参考gptfast项目,实现了对ChatGLM模型推理的加速,主要的原理是借助torch.compile对模型推理过程中构建计算图,实现加速。本文的重点工作就是展示模型代码和推理逻辑的迁移实现,以及加速效果的对比,当然这个方案比VLLM和tensort-LLM肯定是差了点,这个不是本文的重点,后面有空了也把vllm和tensort-LLM也写写博客对比一下效率。

一、ChatGLM模型代码重构迁移

      这个工作是真的不是特别好做,需要对模型结构和模型输入输出非常熟悉,同时也要对gptfast项目迁移原则比较熟悉,才能比较快的迁移成功。核心原则是不能有tensor切片操作,同时kvcache这种也要写成固定的长度,计算过程中不断的去填充更新,同时还要放在模型的结构外侧作为一个参数传入,加速才有效果。还有一个点要注意注意力计算的实现,由于torch更新了scaled_dot_product_attention使得最大长度的定长的矩阵计算注意力,和之前动态逐步增加长度的值是一样的,这个是注意力计算中tensor切片改写的前提(验证过确实是一样的)。细节的地方需要注意kvcache的维度形状,解码过程中不同阶段(首次forward和kvcache存在后的)模型输入的full_attention_mask是不一样的。

整体结构

class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = (
            128
        )
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:

        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents

注意模型的输入新增的有input_pos,模型解码token的位置,kv_caches;模型基本模块上没有变化,精简其中的一下预处理逻辑和分支,主要就是要让torch.compile()能完成计算图的构建。

kvcache模块

class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: S, k_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache

        v_out = self.v_cache
        k_val = k_val.transpose(0, 2).contiguous()

        v_val = v_val.transpose(0, 2).contiguous()
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        k_out = k_out.transpose(0, 2).contiguous()

        v_out = v_out.transpose(0, 2).contiguous()

        return k_out, v_out

模块中各个变量的维度信息都标注好了,作用就是kv缓存载体以及更新逻辑提供一个方法。

其他模块就不一一介绍了,注意selfattention中kvcache的更新

整个模型的代码如下:

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from math import gcd
from functools import reduce
import math


def find_multiple(n: int, *args: Tuple[int]) -> int:
    k = reduce(lambda x, y: x * y // gcd(x, y), args + (1,))
    if n % k == 0:
        return n
    return n + k - (n % k)


class CoreAttention(torch.nn.Module):
    def __init__(self, config, layer_number):
        super(CoreAttention, self).__init__()
        self.config = config
        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32

        self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        coeff = self.layer_number
        self.norm_factor *= coeff
        self.coeff = coeff
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        if attention_mask is None:
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             is_causal=True)
        else:
            attention_mask = ~attention_mask
            context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer,
                                                                             attention_mask)

        context_layer = context_layer.permute(2, 0, 1, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.reshape(*new_context_layer_shape)

        return context_layer


def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]
    sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    # truncate to support variable sizes
    rope_cache = rope_cache[:sq]
    xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
    rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
            xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
        ],
        -1,
    )
    x_out2 = x_out2.flatten(3)
    return torch.cat((x_out2, x_pass), dim=-1)


class RotaryEmbedding(nn.Module):
    def __init__(self, dim, original_impl=False, device=None, dtype=None):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, device=device).to(dtype=dtype) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.dim = dim
        self.original_impl = original_impl

    def forward_impl(
            self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
    ):
        """Enhanced Transformer with Rotary Position Embedding.

        Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
        transformers/rope/__init__.py. MIT License:
        https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
        """
        # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.outer(seq_idx, theta).float()

        cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)

        # this is to mimic the behaviour of complex32, else we will get different results
        if dtype in (torch.float16, torch.bfloat16, torch.int8):
            cache = cache.bfloat16() if dtype == torch.bfloat16 else cache.half()
        return cache

    def forward(self, max_seq_len, offset=0):
        return self.forward_impl(
            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
        )


class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (2, max_batch_size, max_seq_length, 128)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: S, k_val: [S, B, H, D]
        assert input_pos.shape[0] == k_val.shape[0]
        k_out = self.k_cache

        v_out = self.v_cache
        k_val = k_val.transpose(0, 2).contiguous()

        v_val = v_val.transpose(0, 2).contiguous()
        k_out[:, :, input_pos] = k_val.clone()
        v_out[:, :, input_pos] = v_val.clone()
        k_out = k_out.transpose(0, 2).contiguous()

        v_out = v_out.transpose(0, 2).contiguous()

        return k_out, v_out


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


class TransformerGLM(nn.Module):
    def __init__(self, config, device) -> None:
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        rotary_dim = (
            128
        )
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
        self.layers = nn.ModuleList(TransformerBlock(config, i, device) for i in range(config.num_layers))
        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)

        self.output_layer = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.seq_length = config.seq_length

    def forward(self, input_ids,
                position_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.BoolTensor] = None,
                input_pos=None,
                is_input_mask=False,
                kv_caches=None
                ) -> Tensor:

        inputs_embeds = self.embedding(input_ids)
        inputs_embeds = inputs_embeds.transpose(0, 1).contiguous()
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        rotary_pos_emb = rotary_pos_emb[position_ids]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()

        presents = ()
        for i, layer in enumerate(self.layers):
            inputs_embeds, kv_cache = layer(inputs_embeds, rotary_pos_emb=rotary_pos_emb, input_pos=input_pos,
                                            attention_mask=attention_mask, kv_cache=kv_caches[i])
            presents = presents + (kv_cache,)
        hidden_states = self.final_layernorm(inputs_embeds)
        lm_logits = self.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()
        return lm_logits, presents


class MLP(torch.nn.Module):
    """MLP.
    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return F.silu(x[0]) * x[1]

        self.activation_func = swiglu

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            # **_config_to_kwargs(config)
        )

    def forward(self, hidden_states):
        # [s, b, 4hp]
        intermediate_parallel = self.dense_h_to_4h(hidden_states)
        intermediate_parallel = self.activation_func(intermediate_parallel)
        # [s, b, h]
        output = self.dense_4h_to_h(intermediate_parallel)
        return output


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.config = config
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        # 32
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size

        self.num_multi_query_groups_per_partition = config.multi_query_group_num
        self.qkv_hidden_size = (
                self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
        )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         # device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear,
                               device=device,
                               # **_config_to_kwargs(config)
                               )

    def forward(
            self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        (query_layer, key_layer, value_layer) = mixed_x_layer.split(
            [
                self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
            ],
            dim=-1,
        )

        query_layer = query_layer.view(
            query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        key_layer = key_layer.view(
            key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.view(
            value_layer.size()[:-1]
            + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
        )

        # apply relative positional encoding (rotary embedding)
        query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
        key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # 更新kvcache
        cache_k, cache_v = kv_cache
        cache_k[input_pos] = key_layer
        cache_v[input_pos] = value_layer
        key_layer = cache_k.clone()
        value_layer = cache_v.clone()
        kv_cache = (key_layer, value_layer)

        key_layer = key_layer.unsqueeze(-2)
        key_layer = key_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        key_layer = key_layer.contiguous().view(
            key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )
        value_layer = value_layer.unsqueeze(-2)
        value_layer = value_layer.expand(
            -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
        )
        value_layer = value_layer.contiguous().view(
            value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
        )

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask=attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================

        output = self.dense(context_layer)

        return output, kv_cache


class TransformerBlock(nn.Module):
    def __init__(self, config, layer_number, device) -> None:
        super().__init__()
        self.hidden_dropout = config.hidden_dropout
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                       dtype=config.torch_dtype)
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layernorm_epsilon, device=device,
                                                dtype=config.torch_dtype)
        self.mlp = MLP(config, device=device)

    def forward(self, hidden_states, rotary_pos_emb, input_pos, attention_mask=None, kv_cache=None):
        layernorm_output = self.input_layernorm(hidden_states)
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            rotary_pos_emb,
            input_pos,
            attention_mask=attention_mask,
            kv_cache=kv_cache
        )
        residual = hidden_states
        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        layernorm_input = residual + layernorm_input
        layernorm_output = self.post_attention_layernorm(layernorm_input)
        mlp_output = self.mlp(layernorm_output)
        residual = layernorm_input
        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        output = residual + output
        return output, kv_cache


class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)


if __name__ == '__main__':
    import os

    os.environ['CUDA_VISIBLE_DEVICES'] = "1"
    from transformers import AutoConfig

    model_path = "./chatglm2-6b-merge"
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model = TransformerGLM(config, device=None)
    for name, _ in model.named_parameters():
        print(name)

二、推理的代码重构

推理方法,也就是重写transformer模型中的generate这个方法,对于一次生成可以分为第一次解码forward阶段和余下的解码forward阶段。实现分别如下,只实现了greedy search 策略:

@torch.no_grad()
def first_decode_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids=input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=False,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches


@torch.no_grad()
def decode_one_token_batch(model, input_ids, position_ids, input_pos, attention_mask, kv_caches):
    logits, kv_caches = model(input_ids, position_ids=position_ids, input_pos=input_pos, is_input_mask=True,
                              attention_mask=attention_mask, kv_caches=kv_caches)
    logits = logits[:, -1:]
    next_tok = torch.argmax(logits, dim=-1)
    return next_tok, kv_caches

主要是得到解码过程中模型输出的token和kv_caches。特别要注意的是不能把这两个方法封装到一个类中,然后再进行torch.compile这样模型能正确输出结果,但是推理速度没有提升的,也就是torch.compile并没有生效。

整体的generate逻辑,包含停止符号,模型的初始输入、kvcaches初始化以及attention_mask输入的变化、position_ids的输入变化,batch推理是padding的加入。

def generate_own_batch(model,
                 inputs,
                 sampling_kwargs,
                 eos_token,
                       max_seq_length, max_batch_size):
    device = inputs['input_ids'].device
    cache_shape = (max_seq_length, max_batch_size, 2, 128)
    dtype = torch.bfloat16
    kv_caches = [(torch.zeros(cache_shape, dtype=dtype).to(device), torch.zeros(cache_shape, dtype=dtype).to(device))
                 for _ in range(model.config.num_layers)]

    input_ids = inputs['input_ids']

    ori_input_ids = input_ids.clone()
    position_ids = inputs['position_ids']

    input_pos = []
    for _ in range(max_batch_size):
        pos = list(range(0,input_ids.shape[1]))
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=input_ids.device)

    # input_pos = torch.arange(0, input_ids.shape[1], device=input_ids.device)
    next_token, kv_caches = first_decode_batch(model, input_ids, position_ids, input_pos, None, kv_caches)

    full_attention_mask = torch.ones(max_batch_size, 1, 1, max_seq_length).to(device).bool()
    full_attention_mask[:, :, :, input_pos] = False

    # pading部分为true
    for i in range(full_attention_mask.shape[0]):
        for j in range(input_ids.shape[1]):
            if input_ids[i, j] == 0:
                full_attention_mask[i, :, :, j] = True

    input_ids = torch.cat((input_ids, next_token.clone()), dim=1)
    num_new_tokens = sampling_kwargs["max_length"]
    T = input_ids.size()[1]

    position_ids = position_ids[:,-1:]

    input_pos = []
    for _ in range(max_batch_size):
        pos = [T]
        input_pos.append(pos)
    input_pos = torch.tensor(input_pos, device=next_token.device, dtype=torch.long)

    # position_ids = torch.tensor([[T - 1]], device=next_token.device, dtype=torch.long)
    # input_pos = torch.tensor([T], device=input_ids.device, dtype=torch.long)

    for i in range(num_new_tokens):
        input_pos += 1
        # Actually better for Inductor to codegen attention here
        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
            full_attention_mask[:, :, :, input_pos] = False
            next_token, kv_caches = decode_one_token_batch(model, next_token, position_ids, input_pos,
                                                     full_attention_mask, kv_caches)

            input_ids = torch.cat((input_ids, next_token.clone()), dim=1)

            if (input_ids == eos_token).sum(dim=1).all():
                break

            position_ids += 1
            # token = next_token.tolist()
            # token = next_token.tolist()[0]
            # generated_tokens.append(token)
    return input_ids, ori_input_ids

推理核心逻辑

    model = TransformerGLM(config=config, device=None)
    checkpoint_dir = Path(model_path)
    model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
    converted_state_dict = model.state_dict()

    gen_kwargs = {"max_length": 200, "num_beams": 1,
                  "do_sample": False, "top_p": 0.8,
                  "temperature": 0.95
                  }
    device = "cuda:0"
    model_path = "./chatglm2-6b-merge"
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    eos_token = tokenizer.eos_token_id
......
# 编译加速
    global first_decode_batch, decode_one_token_batch
    decode_one_token_batch = torch.compile(decode_one_token_batch, mode="reduce-overhead", fullgraph=True)
    first_decode_batch = torch.compile(first_decode_batch, dynamic=True, fullgraph=True)

......
generate_own_batch(model, inputs, gen_kwargs, eos_token, max_seq_length, max_batch_size)

核心点在于把解码两阶段的函数使用torch.compile函数包裹一下,真实解码过程中就会进行解码加速。

三、效果分析对比

展示一下glm模型ori、compile、和compile+int8,bs=1,max_seq_length= 1000的情况下的推理速度和效果的对比,7Bglm模型,模型输入prompt如下:

[
    "你好",
    "你是谁呀?",
    "你能做什么呀?",
    "你真厉害",
    "真棒呀",
    "再见了",
    "给我推荐一部电影",
    "你知道明天天气怎么样吗?"
]

ori原始transformer的推理效果如下:

使用compile后效果如下:

compile+int8效果如下:

可以看到相同的模型和相同的数据在bs=1下,原始模型推理速度31.7 tokens/s,compile的推理速度68.1 tokens/s,110.9 tokes/s;加速效果确实比较明显。

业务领域上的实验,这里也可以给一个结论,数据就不展示了,业务上生成的token数目每次推理大都在20 tokens以内,结果如下:

这次的分享就到这里为止了,这个迁移后的模型和推理在我们公司的服务端还有个问题,我们服务端采用的多进程异步来实现web服务的,这个gptfast的服务化的集成显示int8不生效,而且bs=1时候的推理加速并没有线下加速效果明显,具体原因一直没有弄明白,可能是其他进程占用服务器资源,导致torch.compile加速失效或者降低。

参考文章

gpt-fast实战

modeling_chatglm

gpt-fast项目源码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1532146.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

修改约束

目录 修改约束 创建数据库 添加约束 删除约束 Oracle从入门到总裁:​​​​​​https://blog.csdn.net/weixin_67859959/article/details/135209645 修改约束 如果说表结构的修改还在可以容忍的范畴之内,那么约束的修改是绝对 100% 禁止的 所有的约束一定要在…

微信小程序项目实战遇到的问题

我们以学生成绩平台来作为例子。这是我们想得到的效果。 以下是完整代码: index.js // index.js Page({//页面的初始数据data: {hello: 欢迎进入微信小程序的编程世界,score: 80,userArray: [{name: 张三,score: [66, 77, 86, 70, 90]},{name: 李四,score: [88, 7…

jwt以及加密完善博客系统

目录 一、背景 二、传统登陆功能&强制登陆功能 1、传统的实现方式 2、session存在的问题 三、jwt--令牌技术 1、实现过程 2、令牌内容 3、生成令牌 4、检验令牌 四、JWT登陆功能&强制登陆功能 1、JWT实现登陆功能 2、强制登陆功能 3、运行效果 五、加密/加…

全栈的自我修养 ———— 微信小程序开发电脑测试api请求正常,移动端请求异常!!

小编今天也是在电脑测试时候发送请求http到服务器是可以通的,但是到了手机端就不可以了,经过小编仔细钻研,终于发现了以下问题!!!!!!!!&#xff0…

Java面试题(Spring篇)

💟💟前言 ​ 友友们大家好,我是你们的小王同学😗😗 今天给大家打来的是 Java面试题(Spring篇) 希望能给大家带来有用的知识 觉得小王写的不错的话麻烦动动小手 点赞👍 收藏⭐ 评论📄 小王的主页…

Hive SQL必刷练习题:日期交叉问题(两种思路)

思路一: ​ 首先想到的是借助炸裂函数,一行变成多行,就可以进行去重操作,然后再统计日期。 用到炸裂函数,就首先需要可以拿到起始和终止日期差大小的数组,然后再炸裂​ 那这个指定长度数组怎么获取呢&…

STM32初识3

中断和事件 什么是中断? 中断是指计算机运行过程中,出现某些意外情况需主机干预时,机器能自动停止正在运行的 程序并转入处理新情况的程序,处理完毕后又返回原被暂停的程序继续运行。 什么是EXTI? …

外包干了1个月,技术明显进步。。。

我是一名大专生,自19年通过校招进入湖南某软件公司以来,便扎根于功能测试岗位,一晃便是近四年的光阴。今年8月,我如梦初醒,意识到长时间待在舒适的环境中,已让我变得不思进取,技术停滞不前。更令…

Docker实现在线代码运行平台-源码分析

一、背景 也许大家平时会有用到一些在线代码运行的网址, 方便我们做一些语法测试或者学习。 例如执行Python代码、Java代码、Shell代码等等, 有时候自己本地环境不具备的时候做一些简单的脚本测试还是蛮实用和方便的。 例如这个在线运行的代码站点: https://r.xjq.icu/ 但是你…

ORB-SLAM2安装和运行教程

前言 默认已经安装了cmake、git 、gcc 和g 更新apt库,更新软件列表 sudo apt-get update版本 以下库均可以通过apt-get方式安装,但是非常不建议这样做,版本不同,最后会报很多错,强烈建议源码安装!!&…

LeetCode 17 / 100

目录 普通数组最大子数组和合并区间轮转数组除自身以外数组的乘积缺失的第一个正数 LeetCode 53. 最大子数组和 LeetCode 56. 合并区间 LeetCode 189. 轮转数组 LeetCode 238. 除自身以外数组的乘积 LeetCode 41. 缺失的第一个正数 普通数组 最大子数组和 给你一个整数数组 …

练习 9 Web [SUCTF 2019]CheckIn (未拿到flag)

上传图片格式的木马文件&#xff1a; 返回 <? in contents!,存在PHP代码检测 上传非图片格式文件&#xff1a; 返回 不允许非image 修改木马PHP代码规避检测 <? ?> 改为 < script language“php”>< /script ><?php eval($_POST[shell]);?>…

多模态大语言模型的 (R) 演变:调查

目录 1. Introduction2. 赋予LLMs多模态能力2.1 大型语言模型2.2 视觉编码器2.3 视觉到语言适配器2.4 多模式训练 3. 使用 MLLM 处理视觉任务 连接文本和视觉模式在生成智能中起着至关重要的作用。因此&#xff0c;受大型语言模型成功的启发&#xff0c;大量研究工作致力于多模…

206.翻转链表

给你单链表的头节点 head &#xff0c;请你反转链表&#xff0c;并返回反转后的链表。 输入&#xff1a;head [1,2,3,4,5] 输出&#xff1a;[5,4,3,2,1]示例 2&#xff1a; 输入&#xff1a;head [1,2] 输出&#xff1a;[2,1]示例 3&#xff1a; 输入&#xff1a;head [] 输…

仿牛客网开发笔记

用到Spring的 一些 核心技术 Spring Framework Spring Core IOC 、AOP > 管理对象的一种思想 IOC > 面向对象的管理思想 AOP > 面向切面的管理思想Spring Data Access 》访问数据库的功能 Transaction、Spring MyBatis Transaction 》管理事务Spring MyBat…

Offer必备算法16_字符串_四道力扣题详解(由易到难)

目录 ①力扣14. 最长公共前缀 解析代码1&#xff08;两两比较&#xff09; 解析代码2&#xff08;统一比较&#xff09; ②力扣5. 最长回文子串 解析代码&#xff08;中心拓展&#xff09; ③力扣67. 二进制求和 解析代码 ④力扣43. 字符串相乘 解析代码&#xff08;无…

Kotlin runBlocking CoroutineScope synchronized简单死锁场景

Kotlin runBlocking CoroutineScope synchronized简单死锁场景 import kotlinx.coroutines.*fun main(args: Array<String>) {runBlocking {val lock1 Any()val lock2 Any()CoroutineScope(Dispatchers.IO).launch {repeat(10) {println("A-$it 申请 lock1...&quo…

【linux】环境基础|开发工具|gcc|yum|vim|gdb|make|git

目录 ​编辑 Linux 软件包管理器 yum 软件包: 操作&#xff1a; 拓展&#xff1a;lrzsz简介 Linux开发工具 Linux编辑器-vim使用 vim 的基本概念 命令模式 插入模式 底行模式 vim 命令模式的操作指令 vim 底行模式的操作命令 Linux编译器-gcc/g使用 功能 格…

Java学习六—面向对象

一、关于面向对象 1.1简介 Java 是一种面向对象编程语言&#xff0c;其核心思想是面向对象编程&#xff08;Object-Oriented Programming&#xff0c;OOP&#xff09;。 面向对象编程是一种程序设计范式&#xff0c;它将数据与操作数据的方法&#xff08;函数&#xff09;捆…

有了std::thread,为什么还需要引入std::jthread?

C进阶专栏&#xff1a;http://t.csdnimg.cn/HGkeZ 目录 1.前言 2.std::is_invocable_v 3.std::jthread 3.1.构造函数 3.2.std::jthread无需join/detach使用实例 3.3.std::jthread处理外部请求中断实 3.4.处理中断请求示例代码 4.特性 5.总结 1.前言 C11以来提供了C原…