huggingface笔记:LLama 2

news2024/11/24 13:30:37

1 前提tip

1.1 使用什么数据类型训练模型?

  • Llama2模型是使用bfloat16训练的
    • 上传到Hub的检查点使用torch_dtype = 'float16',这将通过AutoModel API将检查点从torch.float32转换为torch.float16。
    • 在线权重的数据类型通常无关紧要,这是因为模型将首先下载(使用在线检查点的数据类型),然后转换为torch的默认数据类型(变为torch.float32),最后,如果配置中提供了torch_dtype,则会使用它。
  • 不建议在float16中训练模型,因为已知会产生nan;因此,应该在bfloat16中训练模型

1.2 Llama2 的tokenizer

  • LlaMA的tokenizer是基于sentencepiece的BPE模型。
  • sentencepiece的一个特点是,在解码序列时,如果第一个令牌是词的开头(例如“Banana”),令牌器不会在字符串前添加前缀空格。 

2  transformers.LlamaConfig

根据指定的参数实例化LLaMA模型,定义模型架构。使用默认值实例化配置将产生与LLaMA-7B相似的配置

2.1 参数介绍

vocab_size

(int, 可选,默认为32000) — LLaMA模型的词汇量大小。

定义 通过调用LlamaModel时传递的inputs_ids表示的不同令牌的数量。

hidden_size(int, 可选,默认为4096) — 隐藏表示的维度
intermediate_size(int, 可选,默认为11008) — MLP表示的维度
num_hidden_layers (int, 可选,默认为32) — 解码器中的隐藏层数量
num_attention_heads(int, 可选,默认为32) — 解码器中每个注意力层的注意力头数。
hidden_act(str或函数, 可选,默认为"silu") — 解码器中的非线性激活函数
max_position_embeddings

(int, 可选,默认为2048) — 该模型可能使用的最大序列长度。

Llama 1支持最多2048个令牌,Llama 2支持最多4096个,CodeLlama支持最多16384个。

initializer_range(float, 可选,默认为0.02) — 用于初始化所有权重矩阵的截断正态初始化器的标准差
rms_norm_eps(float, 可选,默认为1e-06) — rms归一化层使用的epsilon
use_cache(bool, 可选,默认为True) — 模型是否应返回最后的键/值注意力
pad_token_id(int, 可选) — 填充令牌id
bos_token_id(int, 可选,默认为1) — 开始流令牌id
eos_token_idint, 可选,默认为2) — 结束流令牌id
attention_bias(bool, 可选,默认为False) — 在自注意力过程中的查询、键、值和输出投影层中是否使用偏置
attention_dropout(float, 可选,默认为0.0) — 注意力概率的丢弃率
mlp_bias(bool, 可选,默认为False) — 在MLP层中的up_proj、down_proj和gate_proj层中是否使用偏置

2.2 举例

from transformers import LlamaModel, LlamaConfig


configuration = LlamaConfig()
# 默认是Llama-7B的配置
configuration
'''
LlamaConfig {
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 32,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "transformers_version": "4.41.0",
  "use_cache": true,
  "vocab_size": 32000
}
'''

3 transformers.LlamaTokenizer 

构建一个Llama令牌器。基于字节级Byte-Pair-Encoding。

默认的填充令牌未设置,因为原始模型中没有填充令牌。

3.1 参数介绍

vocab_file(str) — 词汇文件的路径
unk_token

(str或tokenizers.AddedToken, 可选, 默认为"<unk>") — 未知令牌。

不在词汇表中的令牌将被设置为此令牌。

bos_token

(str或tokenizers.AddedToken, 可选, 默认为"<s>") — 预训练期间使用的序列开始令牌。

可以用作序列分类器令牌

eos_token

(str或tokenizers.AddedToken, 可选, 默认为"</s>") — 序列结束令牌

pad_token

(str或tokenizers.AddedToken, 可选) — 用于使令牌数组大小相同以便批处理的特殊令牌。

在注意力机制或损失计算中将其忽略。

add_bos_token(bool, 可选, 默认为True) — 是否在序列开始处添加bos_token。
add_eos_token(bool, 可选, 默认为False) — 是否在序列结束处添加eos_token。
use_default_system_prompt(bool, 可选, 默认为False) — 是否使用Llama的默认系统提示。

4  transformers.LlamaTokenizerFast

4.1 参数介绍

vocab_file(str) —SentencePiece文件(通常具有.model扩展名),包含实例化分词器所需的词汇表。
tokenizer_file(str, 可选) — 分词器文件(通常具有.json扩展名),包含加载分词器所需的所有内容。
clean_up_tokenization_spaces(bool, 可选, 默认为False) — 解码后是否清理空格,清理包括移除潜在的如额外空格等人工痕迹。
unk_token

(str或tokenizers.AddedToken, 可选, 默认为"<unk>") — 未知令牌。

不在词汇表中的令牌将被设置为此令牌。

bos_token

(str或tokenizers.AddedToken, 可选, 默认为"<s>") — 预训练期间使用的序列开始令牌。

可以用作序列分类器令牌

eos_token

(str或tokenizers.AddedToken, 可选, 默认为"</s>") — 序列结束令牌

pad_token

(str或tokenizers.AddedToken, 可选) — 用于使令牌数组大小相同以便批处理的特殊令牌。

在注意力机制或损失计算中将其忽略。

add_bos_token(bool, 可选, 默认为True) — 是否在序列开始处添加bos_token。
add_eos_token(bool, 可选, 默认为False) — 是否在序列结束处添加eos_token。
use_default_system_prompt(bool, 可选, 默认为False) — 是否使用Llama的默认系统提示。

 4.2 和 LlamaTokenizer的对比

调用from_pretrained从huggingface获取已有的tokenizer时,可以使用AutoTokenizer和LlamaTokenizerFast,不能使用LlamaTokenizer

from transformers import AutoTokenizer

tokenizer1=AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer1.encode('Hello world! This is a test file for Llama tokenizer')
#[128000, 9906, 1917, 0, 1115, 374, 264, 1296, 1052, 369, 445, 81101, 47058]


from transformers import LlamaTokenizerFast

tokenizer2=LlamaTokenizerFast.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer2.encode('Hello world! This is a test file for Llama tokenizer')
#[128000, 9906, 1917, 0, 1115, 374, 264, 1296, 1052, 369, 445, 81101, 47058]
from transformers import LlamaTokenizer

tokenizer3=LlamaTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer3.encode('Hello world! This is a test file for Llama tokenizer')

5 LlamaModel

5.1 参数介绍

  • config
    • 模型配置类,包含模型的所有参数
    • 使用配置文件初始化不会加载与模型相关的权重,只加载配置
    • 使用from_pretrained() 方法以加载模型权重

5.2 介绍

  • LLaMA 模型的基本形式,输出原始隐藏状态,顶部没有任何特定的头部。
    • 这个模型继承自 PreTrainedModel
    • 此模型也是 PyTorch torch.nn.Module 的子类
  • 解码器包括 config.num_hidden_layers 层。每层是一个 LlamaDecoderLayer

使用前面的LlamaConfig:

from transformers import LlamaModel, LlamaConfig


configuration = LlamaConfig()
# 默认是Llama-7B的配置

model=LlamaModel(configuration)
model
'''
LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
'''
model.from_pretrained('meta-llama/Meta-Llama-3-8B')
model
'''
LlamaModel(
  (embed_tokens): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x LlamaDecoderLayer(
      (self_attn): LlamaSdpaAttention(
        (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        (rotary_emb): LlamaRotaryEmbedding()
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
        (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm()
      (post_attention_layernorm): LlamaRMSNorm()
    )
  )
  (norm): LlamaRMSNorm()
)
'''

5.3 forward方法

参数:

input_ids

torch.LongTensor,形状为 (batch_size, sequence_length))

输入序列token在词汇表中的索引

索引可以通过 AutoTokenizer 获取【PreTrainedTokenizer.encode()】

attention_mask

torch.Tensor,形状为 (batch_size, sequence_length),可选)

避免对填充标记索引执行注意力操作的掩码。

掩码值在 [0, 1] 中选择:

  • 1 表示未被掩盖的标记,
  • 0 表示被掩盖的标记。
inputs_embeds

(torch.FloatTensor,形状为 (batch_size, sequence_length, hidden_size)

选择性地,可以直接传递嵌入表示,而不是传递 input_ids

output_attentions(布尔值,可选)— 是否返回所有注意力层的注意力张量。
output_hidden_states(布尔值,可选)— 是否返回所有层的隐藏状态

5.3.1 举例

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")


a=m.forward(inputs.input_ids)

a.keys()
#odict_keys(['last_hidden_state', 'past_key_values'])


a.last_hidden_state
'''
tensor([[[ 4.0064, -0.4994, -1.9927,  ..., -3.7454,  0.8413,  2.6989],
         [-1.5624,  0.5211,  0.1731,  ..., -1.5174, -2.2977, -0.3990],
         [-0.7521, -0.4335,  1.0871,  ..., -0.7031, -1.8011,  2.0173],
         ...,
         [-3.5611, -0.2674,  1.7693,  ..., -1.3848, -0.4413, -1.6342],
         [-1.2451,  1.5639,  1.5049,  ...,  0.5092, -1.2059, -2.3104],
         [-3.2812, -2.2462,  1.8884,  ...,  3.7066,  1.2010,  0.2117]]],
       grad_fn=<MulBackward0>)
'''

a.last_hidden_state.shape
#torch.Size([1, 13, 4096])


len(a.past_key_values)
#32


a.past_key_values[1][0].shape
#torch.Size([1, 8, 13, 128])

6 LlamaForCausalLM 

用于对话系统

6.1 forward方法

参数:

input_ids

torch.LongTensor,形状为 (batch_size, sequence_length))

输入序列token在词汇表中的索引

索引可以通过 AutoTokenizer 获取【PreTrainedTokenizer.encode()】

attention_mask

torch.Tensor,形状为 (batch_size, sequence_length),可选)

避免对填充标记索引执行注意力操作的掩码。

掩码值在 [0, 1] 中选择:

  • 1 表示未被掩盖的标记,
  • 0 表示被掩盖的标记。
inputs_embeds

(torch.FloatTensor,形状为 (batch_size, sequence_length, hidden_size)

选择性地,可以直接传递嵌入表示,而不是传递 input_ids

output_attentions(布尔值,可选)— 是否返回所有注意力层的注意力张量。
output_hidden_states(布尔值,可选)— 是否返回所有层的隐藏状态

6.2  举例

from transformers import LlamaForCausalLM

m1=LlamaForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B')
m1
'''
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
'''

结构是一样的

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

tokenizer.batch_decode(m1.generate(inputs.input_ids, max_length=30))
'''
['<|begin_of_text|>Hey, are you conscious? Can you talk to me? Can you hear me? Are you sure you can hear me? Can you understand what']
'''


m1.generate(inputs.input_ids, max_length=30)
'''
tensor([[128000,  19182,     11,    527,    499,  17371,     30,   3053,    499,
           3137,    311,    757,     30,   3053,    499,   6865,    757,     30,
           8886,    499,   2103,  27027,     30,   3053,    499,   1518,    757,
             30,   3053,    499]])
'''

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1689923.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于springboot+vue+Mysql的校园台球厅人员与设备管理系统

开发语言&#xff1a;Java框架&#xff1a;springbootJDK版本&#xff1a;JDK1.8服务器&#xff1a;tomcat7数据库&#xff1a;mysql 5.7&#xff08;一定要5.7版本&#xff09;数据库工具&#xff1a;Navicat11开发软件&#xff1a;eclipse/myeclipse/ideaMaven包&#xff1a;…

在docker中运行SLAM十四讲程序

《十四讲》的示例程序依赖比较多&#xff0c;而且系统有点旧。可以在容器中运行。 拉取镜像 docker pull ddhogan/slambook:v0.1这个docker对应的github&#xff1a;HomeLH/slambook2-docker 拉下来之后&#xff0c;假如是Windows系统&#xff0c;需要使用XLaunch用于提供X11…

Playwright 隐藏浏览器指纹特征:注入stealth.min.js

引言 浏览器指纹技术通过分析用户的浏览器和操作系统信息来识别用户&#xff0c;这包括浏览器类型、版本、插件、屏幕分辨率等。在自动化测试和爬虫操作中&#xff0c;这些信息可能会暴露脚本的身份&#xff0c;导致被目标网站阻止。Playwright是一个跨浏览器的自动化库&#…

Spring Security整合Gitee第三方登录

文章目录 学习链接环境准备1. 搭建基本web应用引入依赖ThirdApp启动类创建index页面application.yml配置访问测试 2. 引入security引入依赖ProjectConfig访问测试 第三方认证简介注册gitee客户端实现1引入依赖application.yml配置文件创建index.html页面启动类InfoControllerPr…

html中table表格的行、列怎么进行合并

在HTML中&#xff0c;使用 <table> 元素来创建表格&#xff0c;而行&#xff08;tr&#xff09;和列&#xff08;td或th&#xff09;的合并可以通过 colspan和 rowspan 属性来实现。这两个属性允许单个表格单元格&#xff08;td或th&#xff09;跨越多个列或行。 colspa…

【FixBug】超级大Json转POJO失败

今天遇到了一个问题&#xff1a;使用Jackson将一个超级大的JSON字符串转换POJO失败&#xff0c;debug看没问题&#xff0c;将JSON字符串粘贴到main方法中测试&#xff0c;提示错误信息如下&#xff1a; 自己猜测是因为字符串超长导致转换时先截断字符串导致JSON格式不正确&…

QT5.15.2及以上版本安装

更新时间&#xff1a;2024-05-20 安装qt5.15以上版本 系统&#xff1a;ubuntu20.04.06 本文安装&#xff1a;linux-5.15.2 下载安装 # 安装编译套件g sudo apt-get install build-essential #安装OpenGL sudo apt-get install libgl1-mesa-dev# 下载qt安装器 https://downl…

安卓 逆向高级-人均瑞数

引言&#xff1a; JS 爬虫&#xff0c;绕不过瑞数这道坎&#xff0c;卡的死死的。一般网上的教程就是补环境什么的&#xff0c;我尝试了&#xff0c;可以但是比较麻烦。 今天说一种&#xff0c;秒过的方式&#xff0c;抗并发。那就是牛逼的RPC&#xff0c;hook JS 技术。 前期…

RedisTemplateAPI:List

文章目录 ⛄介绍⛄List的常见命令有⛄RedisTemplate API❄️❄️添加缓存❄️❄️将List放入缓存❄️❄️设置过期时间(单独设置)❄️❄️获取List缓存全部内容&#xff08;起始索引&#xff0c;结束索引&#xff09;❄️❄️从左或从右弹出一个元素❄️❄️根据索引查询元素❄…

【日记】跟奇安信斗智斗勇,败下阵来(416 字)

正文 今天一个客户都没有&#xff0c;让我快怀疑我们银行是不是要倒闭了…… 因为内外网 u 盘不知所踪&#xff0c;所以重新制了一个。深刻体会到了奇安信有多烂。有两个 u 盘&#xff0c;奇安信似乎把主控写坏了&#xff0c;插上电脑有反应&#xff0c;但是看不见盘符&#xf…

游戏陪玩/在线租号/任务系统网站源码

源码介绍 游戏陪玩系统/在线租号系统/小姐姐陪玩任务系统/网游主播任务威客平台源码/绝地吃鸡LOL在线下单/带手机端/声优线上游戏任务系统网站源码 界面美观,功能齐全,已对接支付,安装教程放源码压缩包里了! 界面截图 源码下载 https://download.csdn.net/download/huayula…

【Fiddler抓包工具】第四节.断点设置和弱网测试

文章目录 前言一、断点设置 1.1 全局断点 1.2 局部断点 1.3 打断点的几种常用命令 1.4 篡改响应报文二、弱网测试 2.1 网络限速 2.2 精准限速总结 前言 一、断点设置 1.1 全局断点 特点&#xff1a; 中断Fiddler捕获的所有请求&#xff0c;包括…

系统架构师考试(九)

TCP/IP协议族 SMTP是简单邮件传输协议 DNS 域名解析协议 URL - IP&#xff0c;通过URL解析ip是哪一台电脑 DHCP 动态IP地址分配的协议 SNMP 简单网络管理协议 TFTP 简单文件管理协议 ICMP 是网络中差错校验&#xff0c;差错报错的协议 IGMP G是组&#xff0c;组…

cuda11.2安装哪个版本的tensorflow-gpu

在官网上找到这个表格&#xff0c;因为自己的电脑一直配置的11.2的cuda&#xff0c;所以也不想换&#xff0c;最好就是安装一般能适应该版本的tensorflow&#xff0c;我配置了python3.8的环境&#xff0c;然后进行 pip install tensorflow-gpu2.6 回车就会自动从清华镜像上进…

基于vue3速学angular

因为工作原因&#xff0c;需要接手新的项目&#xff0c;新的项目是angular框架的&#xff0c;自学下和vue3的区别&#xff0c;写篇博客记录下&#xff1a; 参考&#xff1a;https://zhuanlan.zhihu.com/p/546843290?utm_id0 1.结构上&#xff1a; vue3:一个vue文件&#xff…

金职优学:分析央国企面试如何通关?

在当今竞争激烈的就业市场中&#xff0c;中央和国有企业&#xff08;以下简称“央国企”&#xff09;的面试机会对求职者来说是非常有吸引力的。这些企业通常拥有稳定的发展前景、良好的薪酬福利和广阔的职业发展空间。但是&#xff0c;要想成功通过央国企的面试&#xff0c;求…

力扣HOT100 - 31. 下一个排列

解题思路&#xff1a; 数字是逐步增大的 步骤如下&#xff1a; class Solution {public void nextPermutation(int[] nums) {int i nums.length - 2;while (i > 0 && nums[i] > nums[i 1]) i--;if (i > 0) {int j nums.length - 1;while (j > 0 &&…

2024年二建准考证打印入口已开通!

24年二建将于6月1日、2日举行&#xff0c;目前西藏、陕西准考证打印入口已开通&#xff0c;各省也将陆续开始准考证打印工作。 2024二建考试时间安排 2024二建准考证打印时间 二建准考证打印须知 01 准考证打印信息显示空白怎么办? 1)使用电脑自带的浏览器重新试一下。 2)…

【全开源】填表统计预约打卡表单系统FastAdmin+ThinkPHP+UniApp

简化流程&#xff0c;提升效率 一、引言&#xff1a;传统表单处理的局限性 在日常工作和生活中&#xff0c;我们经常会遇到需要填写表单、统计数据和预约打卡等场景。然而&#xff0c;传统的处理方式往往效率低下、易出错&#xff0c;且不利于数据的统计和分析。为了解决这些…

Spring Web MVC介绍及详细教程

目录 1.什么是Spring Web MVC&#xff1f; 1.1 MVC定义 1.2 Spring MVC与MVC关系 2.为什么要学习Spring MVC 3.项目创建 4.Spring MVC连接 4.1 RequestMapping 4.2 PostMapping和GetMapping 5.Spring MVC参数获取 5.1 获取单个参数 5.2 获取多个参数 5.3 获取普通对…