Llama开源代码详细解读(2)

news2024/9/9 5:06:31

FlashAttention

if is_flash_attn_available(): # 检查flashattention的可用性
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

FlashAttention是Tranformer模型中用于改进注意力机制的技术,主要目的是减少计算复杂度和内存占用。

  • flash_attn_func用于标准的flashattention计算。
  • flash_attn_varlen_func用于处理变长序列(长度未能确定)的flashattention计算。
  • index_first_axis用于处理第一个索引轴。
  • pad_input将数据进行填充处理,从而确定长度。
  • unpad_input将填充后的输入还原为原始形态。

Logging模块

logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "LlamaConfig"

创建了名为logger的日志记录器对象,__name__用于保存模块的名称,确保每个模块都有自己的日志记录器。
_CONFIG_FOR_DOC前面带有下划线,因此可以看出其代表一个模块的内部变量。

get_unpad_data模块

def _get_unpad_data(padding_mask):
    seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )

该模块的作用是padding_mask提取非填充的数据,分为以下几步:

  1. seqlens_in_batch计算量每个张量的有效长度,sum()函数计算每个张量的有效长度。dim等于-1意味着以最按照最后一个维度进行求和,如果是二维,就可以理解为对跨列操作,即计算了每一行非填充元素的个数。
  2. indices获取了非填充元素的索引。flatten()函数将张量展开成一维,as_tuple为flase意味着返回不是元组形式而是二维矩阵形式,由于返回的是二维矩阵,因此我们需要flatten()再次展平成一维。
    ——为什么不返回元组呢?
    如果返回元组,那么返回的格式是包含一个一维张量的元组,然后还需要从元组中取出这个一维张量,类似:
torch.nonzero(padding_mask.flatten(), as_tuple=True)[0]

这样比较麻烦,不如直接返回二维数组再展平。
3. max_seqlen_in_batch获取了在seqlens_in_batch中的最大值并返回(即长度最长的那一个),然后 item()函数的作用是将一个元素的张量转换为python对应的标量,即一个数。
4. cu_seqlens计算累计长度并进行填充。cumsum()函数用于计算指定维度的累计和,(1,0)意味着只在左边添加一个元素,右边不添加。F.pad()是为张量进行填充的函数。这对于处理变长序列非常有用,因为即获得了每个序列的开始索引,容易确定起始和结束位置。
5. 最终返回的包括:非零元素的索引,左边填充过了的累计长度,最长序列的长度。
从而,达到了提取非填充数据的目的。

_make_causal_mask模块

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

该模块用于生成因果掩码,通常用于双向自注意力机制。具体来说,该模块保证在计算注意力时,只能看到当前时间步之前的信息,而看不到未来的,来保持因果关系。

输入参数

input_ids_shape: torch.Size:输入张量的形状,通常为(batch_size, target_length)。
dtype: torch.dtype:用于生成掩码的张量类型。
device: torch.device:指定设备是GPU还是CPU。
past_key_values_length: int = 0:过去的键值对长度,用于增量计算。

步骤

  • 获取批大小(bsz)和目标长度(tgt_len)。
  • 创建初始掩码:形状为(tgt_len,tgt_len),所有的dtpye设置为最小,通常为负无穷。torch.full()函数创建一个均为min的矩阵
  • 设置掩码条件:mask_cond生成一个mask最后一个维度大小-1长度的序列,并放置在指定设备
  • mask.masked_fill:将下三角矩阵设置为0。这里用到了pytorch的广播机制,将一个行为1和列为1的向量扩充进行比较,从而将下三角都变为了0。
  • 将掩码转换为指定的数据类型。
  • 如果past_key_values_length>0,那么就在最后一个维度拼接上一个(tgt_len, past_key_values_length)的张量,这是为了在处理增量计算时,能够考虑过去的键值对。
    其中,zeros()函数创建了(tgt_len,past_key_value_length)的全零矩阵,用cat()在mask前添加了一个全0块。
  • 最终将遮罩的维度扩展为四维形状(bsz, 1, tgt_len, tgt_len + past_key_values_length),并返回。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1961486.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

盐分反演关键:批量计算常用的盐分指数反演变量

盐分反演关键&#xff1a;批量计算常用的盐分指数反演变量 一、引言 盐分指数反演是遥感应用中的一个重要方面&#xff0c;尤其在农业和环境监测中有着广泛的应用。通过遥感影像&#xff0c;研究人员可以高效地获取和分析地表盐分信息&#xff0c;为土地管理和作物生产提供重…

YOLOX+PyQt5交通路口智能监测平台设计与实现

1.概述 交通要道的路口上人车穿行&#xff0c;特别是上下班早高峰&#xff0c;且时常发生交通事故。因此对交通路口的车流量和人流量的监测必不可少。 2.检测模型 使用的检测模型为YOLOX模型&#xff0c;模型权重为训练VOC数据集得来&#xff0c;其中包括了二十个类别&#…

ONLYOFFICE 协作空间 2.6 已发布:表单填写房间、LDAP、优化房间和文件管理等

更新后的 ONLYOFFICE 协作空间带来了超过 20 项新功能和优化&#xff0c;让工作更加高效和舒适。阅读本文了解详情。 表单填写房间 这次更新增加了一种新的房间类型&#xff0c;可在 ONLYOFFICE 协作空间中组织简单的表单填写流程。 通过表单填写房间&#xff0c;目前可以完成…

仓库物品与装备物品位置更换

一、装备物品与选中的仓库物品位置交换 1、准备工作 2、Inventory Items 3、给Warehouse添加Grid Layout Group组件 4、复制Inventory Items&#xff0c;设置Grid Layout Group组件 5、创建文本ItemName和ItemDescription 6、设置物品数据 (1) 创建 ItemData.cs using Syst…

Spring boot tomcat 读写超时时间设置

yaml配置 connection-timeout: 20000 server:port: 9898servlet:context-path: /testtomcat:connection-timeout: 20000max-connections: 250accept-count: 300 spring源码设置自定义tomcat参数 customizeConnector(connector); Overridepublic WebServer getWebServer(Serv…

【MySQL】表的约束{ 常见约束 空属性 默认值 列描述comment zerofill 主键 复合主键 自增长 唯一键 外键 }

文章目录 常见约束空属性默认值列描述commentzerofill主键复合主键自增长唯一键外键 2.总结 真正约束字段的是数据类型&#xff0c;但是数据类型约束很单一&#xff0c;需要有一些额外的约束&#xff0c;更好的保证数据的合法性&#xff0c;从业务逻辑角度保证数据的正确性。比…

MySQL基础练习题12-使用唯一标识码替换员工ID

题目&#xff1a;展示每位用户的 唯一标识码&#xff08;unique ID &#xff09;&#xff1b;如果某位员工没有唯一标识码&#xff0c;使用 null 填充即可。 准备数据 分析数据 题目&#xff1a;展示每位用户的 唯一标识码&#xff08;unique ID &#xff09;&#xff1b;如果…

一, 创建工程,引入依赖

一&#xff0c; 创建工程&#xff0c;引入依赖 文章目录 一&#xff0c; 创建工程&#xff0c;引入依赖创建工程工程间的关系的建立配置各个工程当中的 pow 配置信息&#xff0c;相关的依赖父工程(也就是总项目工程)的 pow 配置demo-module06-generate 模块中pow 配置&#xff…

基于IEC61499标准的在线工业编程平台open61499

基于IEC61499标准的在线工业编程平台open61499是一个专为工业自动化领域设计的编程环境&#xff0c;它遵循IEC 61499标准&#xff0c;为开发者提供了一种高效、灵活的方式来创建、配置和管理分布式控制系统&#xff08;DCS&#xff09;的应用程序。以下是对open61499的详细解析…

LeetCode热题 翻转二叉树、二叉树最大深度、二叉树中序遍历

目录 一、翻转二叉树 1.1 题目链接 1.2 题目描述 1.3 解题思路 二、二叉树最大深度 2.1 题目链接 2.2 题目描述 2.3 解题思路 三、二叉树中序遍历 3.1 题目链接 3.2 题目描述 3.3 解题思路 一、翻转二叉树 1.1 题目链接 翻转二叉树 1.2 题目描述 1.3 解题思路 根…

【多模态大模型】 BLIP in ICML 2022

一、引言 论文&#xff1a; BLIP: Bootstrapping Language-Image Pre-training for Unified Vision-Language Understanding and Generation 作者&#xff1a; Salesforce Research 代码&#xff1a; BLIP 特点&#xff1a; 该方法分别使用ViT和BERT进行图像和文本特征提取&am…

【changchain-community安装失败】‘EntryPoints‘ object has no attribute ‘get‘报错解决

在安装changchain-community时报错信息如下&#xff1a; WARNING: Keyring is skipped due to an exception: EntryPoints object has no attribute get ERROR: Could not find a version that satisfies the requirement changchain-community ERROR: No matching distributio…

进程间通信与线程间通信的方法汇总

目录 一、进程间通信机制 管道(pipe)&#xff1a; 命名管道(FIFO)&#xff1a; 消息队列(MQ)&#xff1a; 信号量(semaphore)&#xff1a; 共享内存(shared memory)&#xff1a; 信号(signal)&#xff1a; 内存映射(mapped memory)&#xff1a; 内存映射和共享内存的区…

华杉研发九学习日记20 LinkedHashMap TreeMap Arrays 函数式接口 方法引用

华杉研发九学习日记20 一&#xff0c;LinkedHashMap 与HashMap相比&#xff0c;key是有序的 Map<Integer,String> map new LinkedHashMap<Integer,String>(); map.put(1, "one"); map.put(2, "two"); map.put(3, "three"); map.…

GitHub Desktop commit文件到repository

1. Clone a repository到本地 2. 在本地仓库修改/添加需要提交的文件或者文档 3. 添加comments并commit 4. 提交完成&#xff0c;点击Push origin提交代码到Github远程仓库 上传成功后&#xff0c;刷新Github网站页面就会出现上传的项目

鸿蒙应用框架开发【自绘编辑框】 输入法框架

自绘编辑框 介绍 本示例通过输入法框架实现自会编辑框&#xff0c;可以绑定输入法应用&#xff0c;从输入法应用输入内容&#xff0c;显示和隐藏输入法。 效果预览 使用说明 1.点击编辑框可以绑定并拉起输入法&#xff0c;可以从输入法键盘输入内容到编辑框。 2.可以点击a…

SSM老人服务管理系统小程序-计算机毕业设计源码91022

摘 要 21世纪的今天&#xff0c;随着社会的不断发展与进步&#xff0c;人们对于信息科学化的认识&#xff0c;已由低层次向高层次发展&#xff0c;由原来的感性认识向理性认识提高&#xff0c;管理工作的重要性已逐渐被人们所认识&#xff0c;科学化的管理&#xff0c;使信息存…

跨网段 IP 地址通信故障分析

现如今计算机网络的规模和复杂性不断增加&#xff0c;跨网段通信成为网络运行中的常见需求。但如果设备处于不同网段且路由设置出现偏差时就会导致通信故障&#xff0c;严重影响网络的正常运行和数据传输。 1.跨网段通信的基本原理 跨网段通信依赖于路由器的路由功能。路由器根…

vue3.0 入门基础知识汇总【1】 全面 精简 推荐

这篇博文主要对一些刚入门vue框架的同学&#xff0c;以及对vue基本知识进行巩固的&#xff0c;最后就是精简一下基本知识&#xff0c;以方便自己查看&#xff0c;感谢参考&#xff0c;有问题评论区交流&#xff0c;谢谢。 目录 1.component组件的基本结构和使用 2.method方法…

全网最适合入门的面向对象编程教程:28 类和对象的Python实现-Python编程原则、哲学和规范大汇总

全网最适合入门的面向对象编程教程&#xff1a;28 类和对象的 Python 实现-Python 编程原则、哲学和规范大汇总 摘要&#xff1a; 本文主要介绍了在使用 Python 进行面向对象编程时&#xff0c;Python 异常处理的原则-“请求谅解&#xff0c;而非许可”&#xff0c;以及软件设…