大模型入门3:理解LLAMA

news2024/12/22 17:55:20

Model

  • a stack of DecoderBlocks(SelfAttention, FeedForward, and RMSNorm)
    在这里插入图片描述
    decoder block 整体结构:最大的区别在pre-norm

x -> norm(x) -> attention() -> residual connect -> norm() -> ffn -> residual connect

class DecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.n_heads = config['n_heads']
        self.dim = config['embed_dim']
        self.head_dim = self.dim // self.n_heads

        self.attention = SelfAttention(config)
        self.feed_forward = FeedForward(config)

        # rms before attention block
        self.attention_norm = RMSNorm(self.dim, eps=config['norm_eps'])

        # rms before  feed forward block
        self.ffn_norm = RMSNorm(self.dim, eps=config['norm_eps'])

    def forward(self, x, start_pos, freqs_complex):

        # (m, seq_len, dim)
        h = x + self.attention.forward(
            self.attention_norm(x), start_pos, freqs_complex)
        # (m, seq_len, dim)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config['vocab_size']
        self.n_layers = config['n_layers']
        self.tok_embeddings = nn.Embedding(self.vocab_size, config['embed_dim'])
        self.head_dim = config['embed_dim'] // config['n_heads']

        self.layers = nn.ModuleList()
        for layer_id in range(config['n_layers']):
            self.layers.append(DecoderBlock(config))

        self.norm = RMSNorm(config['embed_dim'], eps=config['norm_eps'])
        self.output = nn.Linear(config['embed_dim'], self.vocab_size, bias=False)

        self.freqs_complex = precompute_theta_pos_frequencies(
            self.head_dim, config['max_seq_len'] * 2, device=(config['device']))

    def forward(self, tokens, start_pos):
        # (m, seq_len)
        batch_size, seq_len = tokens.shape

        # (m, seq_len) -> (m, seq_len, embed_dim)
        h = self.tok_embeddings(tokens)

        # (seq_len, (embed_dim/n_heads)/2]
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        # Consecutively apply all the encoder layers
        # (m, seq_len, dim)
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)

        # (m, seq_len, vocab_size)
        output = self.output(h).float()
        return output

model = Transformer(config).to(config['device'])
res = model.forward(test_set['input_ids'].to(config['device']), 0)
print(res.size())

RoPE

在这里插入图片描述

def precompute_theta_pos_frequencies(head_dim, seq_len, device, theta=10000.0):

    # theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2]
    # (head_dim / 2)
    theta_numerator = torch.arange(0, head_dim, 2).float()
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)

    # (seq_len)
    m = torch.arange(seq_len, device=device)

    # (seq_len, head_dim / 2)
    freqs = torch.outer(m, theta).float()

    # complex numbers in polar, c = R * exp(m * theta), where R = 1:
    # (seq_len, head_dim/2)
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex

def apply_rotary_embeddings(x, freqs_complex, device):

    # last dimension pairs of two values represent real and imaginary
    # two consecutive values will become a single complex number

    # (m, seq_len, num_heads, head_dim/2, 2)
    x = x.float().reshape(*x.shape[:-1], -1, 2)

    # (m, seq_len, num_heads, head_dim/2)
    x_complex = torch.view_as_complex(x)

    # (seq_len, head_dim/2) --> (1, seq_len, 1, head_dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)

    # multiply each complex number
    # (m, seq_len, n_heads, head_dim/2)
    x_rotated = x_complex * freqs_complex

    # convert back to the real number
    # (m, seq_len, n_heads, head_dim/2, 2)
    x_out = torch.view_as_real(x_rotated)

    # (m, seq_len, n_heads, head_dim)
    x_out = x_out.reshape(*x.shape)

    return x_out.type_as(x).to(device)

RMS norm

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor):
        # (m, seq_len, dim) * (m, seq_len, 1) = (m, seq_len, dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor):
        # weight is a gain parameter used to re-scale the standardized summed inputs
        # (dim) * (m, seq_len, dim) = (m, seq_Len, dim)
        return self.weight * self._norm(x.float()).type_as(x)

KV Caching

在这里插入图片描述
在这里插入图片描述

class KVCache:
    def __init__(self, max_batch_size, max_seq_len, n_kv_heads, head_dim, device):
        self.cache_k = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim)).to(device)
        self.cache_v = torch.zeros((max_batch_size, max_seq_len, n_kv_heads, head_dim)).to(device)

    def update(self, batch_size, start_pos, xk, xv):
        self.cache_k[:batch_size, start_pos :start_pos + xk.size(1)] = xk
        self.cache_v[:batch_size, start_pos :start_pos + xv.size(1)] = xv

    def get(self, batch_size, start_pos, seq_len):
        keys = self.cache_k[:batch_size,  :start_pos + seq_len]
        values = self.cache_v[:batch_size, :start_pos + seq_len]
        return keys, values

Grouped Query Attention

在这里插入图片描述

def repeat_kv(x, n_rep):

    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    else:
        # (m, seq_len, n_kv_heads, 1, head_dim)
        # --> (m, seq_len, n_kv_heads, n_rep, head_dim)
        # --> (m, seq_len, n_kv_heads * n_rep, head_dim)
        return (
            x[:, :, :, None, :]
            .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
            .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
        )

class SelfAttention(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.n_heads = config['n_heads']
        self.n_kv_heads = config['n_kv_heads']
        self.dim = config['embed_dim']
        self.n_kv_heads = self.n_heads if self.n_kv_heads is None else self.n_kv_heads
        self.n_heads_q = self.n_heads
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = self.dim // self.n_heads

        self.wq = nn.Linear(self.dim, self.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)

        self.cache = KVCache(
            max_batch_size=config['max_batch_size'],
            max_seq_len=config['max_seq_len'],
            n_kv_heads=self.n_kv_heads,
            head_dim=self.head_dim,
            device=config['device']
        )

    def forward(self, x, start_pos, freqs_complex):

        # seq_len is always 1 during inference
        batch_size, seq_len, _ = x.shape

        # (m, seq_len, dim)
        xq = self.wq(x)

        # (m, seq_len, h_kv * head_dim)
        xk = self.wk(x)
        xv = self.wv(x)

        # (m, seq_len, n_heads, head_dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)

        # (m, seq_len, h_kv, head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # (m, seq_len, num_head, head_dim)
        xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)

        # (m, seq_len, h_kv, head_dim)
        xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

        # replace the entry in the cache
        self.cache.update(batch_size, start_pos, xk, xv)

        # (m, seq_len, h_kv, head_dim)
        keys, values = self.cache.get(batch_size, start_pos, seq_len)

        # (m, seq_len, h_kv, head_dim) --> (m, seq_len, n_heads, head_dim)
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # (m, n_heads, seq_len, head_dim)
        # seq_len is 1 for xq during inference
        xq = xq.transpose(1, 2)

        # (m, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (m, n_heads, seq_len_q, head_dim) @ (m, n_heads, head_dim, seq_len) -> (m, n_heads, seq_len_q, seq_len)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # (m, n_heads, seq_len_q, seq_len)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (m, n_heads, seq_len_q, seq_len) @ (m, n_heads, seq_len, head_dim) -> (m, n_heads, seq_len_q, head_dim)
        output = torch.matmul(scores, values)

        # ((m, n_heads, seq_len_q, head_dim) -> (m, seq_len_q, dim)
        output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))

        # (m, seq_len_q, dim)
        return self.wo(output)

SwiGlu

def sigmoid(x, beta=1):
    return 1 / (1 + torch.exp(-x * beta))

def swiglu(x, beta=1):
    return x * sigmoid(x, beta)
class FeedForward(nn.Module):
    def __init__(self, config):

        super().__init__()

        hidden_dim = 4 * config['embed_dim']
        hidden_dim = int(2 * hidden_dim / 3)

        if config['ffn_dim_multiplier'] is not None:
            hidden_dim = int(config['ffn_dim_multiplier'] * hidden_dim)

        # Round the hidden_dim to the nearest multiple of the multiple_of parameter
        hidden_dim = config['multiple_of'] * ((hidden_dim + config['multiple_of'] - 1) // config['multiple_of'])

        self.w1 = nn.Linear(config['embed_dim'], hidden_dim, bias=False)
        self.w2 = nn.Linear(config['embed_dim'], hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, config['embed_dim'], bias=False)

    def forward(self, x: torch.Tensor):
        # (m, seq_len, dim) --> (m, seq_len, hidden_dim)
        swish = swiglu(self.w1(x))
        # (m, seq_len, dim) --> (m, seq_len, hidden_dim)
        x_V = self.w2(x)

        # (m, seq_len, hidden_dim)
        x = swish * x_V

        # (m, seq_len, hidden_dim) --> (m, seq_len, dim)
        return self.w3(x)

小结

  • padding 方式

reference

  • llama tech report
  • 源码:transformers
  • 参数量计算: https://zhuanlan.zhihu.com/p/676113501
  • 基于 MLX 的 LLAMA2-13B 的详细分析 - 亚东的文章 - 知乎 https://zhuanlan.zhihu.com/p/677125915
  • 2023年你最喜欢的MLSys相关的工作是什么? - Lin Zhang的回答 - 知乎
  • https://ai.plainenglish.io/understanding-llama2-kv-cache-grouped-query-attention-rotary-embedding-and-more-c17e5f49a6d7
  • https://github.com/wdndev/llama3-from-scratch-zh/blob/main/llama3/model.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2132465.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

从零到一:构建你的第一个AI项目(实战教程)

引言 欢迎来到AI世界的初学者指南!在这个实战教程中,我们将一步步构建一个基础的AI项目,让你从零开始,亲手体验人工智能的魅力。我们的目标是让即使没有任何编程或AI背景的你,也能通过本教程完成一个小型的AI应用。今天…

《程序猿之设计模式实战 · 装饰者模式》

📢 大家好,我是 【战神刘玉栋】,有10多年的研发经验,致力于前后端技术栈的知识沉淀和传播。 💗 🌻 CSDN入驻不久,希望大家多多支持,后续会继续提升文章质量,绝不滥竽充数…

Python 求亲和数

亲和数(Amicable Numbers)是指两个不同的正整数,它们的真因数(即除去本身的所有因数)之和与对方的数相等。 def sum_of_proper_divisors(n):"""计算一个数的真因子之和"""divisors_su…

SpringBoot闲一品交易平台

SpringBoot闲一品交易平台 #vue项目实战 #计算机项目 #java项目 SpringBoot闲一品交易平台通过运用软件工程原理和开发方法,借助Spring Boot框架,旨在实现零食交易信息的高效管理,提升用户的购物体验和满意度。 技术栈 开发语言:…

用于安全研究的 Elastic Container Project

作者:来自 Elastic Andrew Pease•Colson Wilhoit•Derek Ditch 使用 Docker 启动 Elastic Stack 序言 Elastic Stack 是一个模块化数据分析生态系统。虽然这允许工程灵活性,但建立开发实例进行测试可能很麻烦。建立 Elastic Stack 的最简单方法是使用…

Day09-StatefuleSet控制器

Day09-StatefuleSet控制器 0、昨日内容回顾1、StatefulSets控制器1.1 StatefulSet概述1.2 StatefulSets控制器-网络唯一标识之headless1.3 StatefulSets控制器-独享存储 2、metric-server2.1 metric-server概述2.2 部署metric-server:2.3 hpa案例 3、helm概述3.1 安装helm3.2 h…

RabbitMQ 高级特性——持久化

文章目录 前言持久化交换机持久化队列持久化消息持久化 前言 前面我们学习了 RabbitMQ 的高级特性——消息确认,消息确认可以保证消息传输过程的稳定性,但是在保证了消息传输过程的稳定性之后,还存在着其他的问题,我们都知道消息…

【rpg像素角色】俯视角-行走动画

制作像素角色的俯视角行走动画并不像看上去那么复杂,尤其是在你已经完成了角色的4个方向站立姿势之后(其中左右方向可以通过水平翻转实现)。接下来,我会一步步为你讲解如何制作行走动画。 1. 理解行走规律 在制作行走动画之前&am…

Spring Boot集成Akka Stream快速入门Demo

1.什么是Akka Stream? Akka Streams是一个用于处理和传输元素序列的库。它建立在Akka Actors之上,使流的摄入和处理变得简单。由于它是建立在Akka Actors之上的,它为Akka现有的actor模型提供了一个更高层次的抽象。Akka流由3个主要部分组成-…

从0开始学习RocketMQ:快速部署启动

快速部署 快速部署一个单节点单副本 RocketMQ 服务,并完成简单的消息收发。 安装Apache RocketMQ 下载地址:RocketMQ官网下载 这里我们下载二进制包:rocketmq-all-5.3.0-bin-release.zip 直接解压即可:tar -zxvf rocketmq-all…

光伏开发:工商业光伏的流程管理全面解析

一、项目准备阶段 1、资源寻觅与沟通 首要任务是寻找适合的工商业屋顶或空地资源,并与业主初步交流,了解其意向、屋顶条件及用电情况。这一阶段的关键在于建立信任关系,为后续工作奠定基础。 2、资料收集与核查 全面收集业主资料&#xff…

算法练习题26——多项式输出(模拟)

输入格式 输入共有 2 行 第一行 1 个整数,n,表示一元多项式的次数。 第二行有 n1 个整数,其中第 i 个整数表示第 n−i1 次项的系数,每两个整数之间用空格隔开。 输出格式 输出共 1 行,按题目所述格式输出多项式。…

Navicat BI 中创建自定义字段:计算字段

在数据库设计和开发中,避免存储任何可以从其他字段计算或重建的数据是一种惯例。因此,在 Navicat BI 中构建图表时,你可能会缺少一些数据。但这不是问题,因为 Navicat BI 提供了专门用于此目的的计算字段。在今天的博客中&#xf…

网站按钮检测系统源码分享

网站按钮检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

浅谈MVC设计模式

1 前言 1.1 内容概要 熟悉使用JSON工具,完成Java对象(Map)和Json字符串之间的相互转换(注意提供构造器和getter/setter方法) 注意事项:不管使用的是什么JSON工具,都要提供类的无参构造方法和…

基于SpringBoot的宠物寄领养网站管理系统

作者:计算机学姐 开发技术:SpringBoot、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:前后端分离项目源码、SpringBoot项目源码、SSM项目源码 系统展示 【前后端分离】基于JavaSpringBootVueMySQL的宠物寄领养网站…

北斗卫星系统信号介绍

覆盖范围亚太区域全球范围 卫星数量35颗区域服务卫星30颗全球服务卫星 信号频段B1I, B2IB1C, B2a, B3, 兼容GPS/Galileo 定位精度区域内10米全球2.5~5米,中国内更高 新增功能区域短报文通信全球短报文通信、星基增强、精密定位 抗干扰能力相对有限更强 互操作…

无人机 PX4 飞控 | 如何检测状态估计EKF性能

无人机 PX4 飞控 | 如何检测状态估计EKF性能 前言检查EKF性能缺少pyulog问题解决脚本崩溃没有输出文件生成对应文件 结语 前言 ECL (Estimation and Control Library,估计和控制库),其中的状态估计使用扩展卡尔曼滤波算法&#x…

图像检测【YOLOv5】——深度学习

Anaconda的安装配置:(Anaconda是一个开源的Python发行版本,包括Conda、Python以及很多安装好的工具包,比如:numpy,pandas等,其中conda是一个开源包和环境管理器,可以用于在同一个电脑…

计算机网络基本概述

欢迎大家订阅【计算机网络】学习专栏,开启你的计算机网络学习之旅! 文章目录 前言一、网络的基本概念二、集线器、交换机和路由器三、互连网与互联网四、网络的类型五、互连网的组成1. 边缘部分2. 核心部分 六、网络协议 前言 计算机网络是现代信息社会…