Question: when generating with a Hugging Face model through transformers, we can obtain the generated token ids from model.generate(). How can I inspect the logits (i.e., the generation probabilities) corresponding to those generated ids?
In Hugging Face's Transformers library, when you call the .generate() method of a generative model (such as GPT-2, XLNet, or a Qwen-style causal LM), the per-token logits are not returned by default. You can, however, access them in the following ways:
1. Custom generation via .generate(): pass output_scores=True (together with return_dict_in_generate=True) when calling .generate() to obtain the scores for each generated token.
The following example shows how to get the logits when using .generate():
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Load the pretrained model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Input text
input_text = "The answer to the ultimate question of life, the universe, and everything is"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Generate; return_dict_in_generate=True is required, otherwise generate()
# returns a plain tensor and .sequences/.scores are not available
gen_output = model.generate(input_ids, max_length=50, output_scores=True, return_dict_in_generate=True)

# The generated ids and the per-step scores
generated_ids = gen_output.sequences
logits = gen_output.scores

# logits is a tuple in which each element is a tensor of shape
# [batch_size, vocab_size], one per generation step (these are the scores
# after generate()'s logits processors; with default greedy settings they
# coincide with the raw logits)
# To inspect the first generated token's logits:
first_token_logits = logits[0][0]

# Convert logits to probabilities with softmax
probabilities = torch.nn.functional.softmax(first_token_logits, dim=-1)

# Now you can inspect each token's probability
print(probabilities)
Here max_length=50 is set, so the total output (prompt included) is at most 50 tokens; generation may stop earlier at an EOS token, and you can raise the limit to let the model emit more tokens.
In the code above, logits[0][0] represents the logits of the first generated token. Since the model generally generates several tokens, the logits tuple contains one entry per generation step.
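To recover the probability the model assigned to each chosen token from these per-step scores, a softmax plus gather suffices. Below is a minimal sketch using dummy tensors in place of generate()'s output (transformers also provides model.compute_transition_scores() for exactly this purpose):

```python
import torch

# Stand-ins for gen_output.scores (one [batch, vocab] tensor per step)
# and for the ids generated at those steps
vocab_size = 6
scores = [torch.randn(1, vocab_size) for _ in range(3)]   # 3 generation steps
generated = torch.tensor([[2, 5, 0]])                     # id chosen at each step

step_logits = torch.stack(scores, dim=1)                  # [batch, steps, vocab]
step_probs = torch.softmax(step_logits, dim=-1)
# Probability of the token that was actually generated at each step
chosen_probs = step_probs.gather(-1, generated.unsqueeze(-1)).squeeze(-1)
print(chosen_probs)                                       # shape [batch, steps]
```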
2. Manual decoding: if you want finer control over the generation process, you can implement the decoding loop yourself, calling the model's forward pass at each step to obtain the logits and then choosing the next token as you see fit.
Here is a simplified manual-decoding example:
input_text = "The answer to the ultimate question of life, the universe, and everything is"
input_ids = tokenizer.encode(input_text, return_tensors='pt')

# Initialize
past = None
next_input = input_ids            # feed the full prompt on the first step only
for _ in range(50):               # generate up to 50 tokens
    outputs = model(next_input, past_key_values=past)
    next_token_logits = outputs.logits[:, -1, :]
    past = outputs.past_key_values
    # Apply softmax here if you want probabilities
    next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
    # Choose the next token (here: greedy argmax)
    next_token = torch.argmax(next_token_logits, dim=-1)
    # Append the chosen token to the running sequence
    input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
    # When reusing the KV cache, only the new token is fed on the next step
    next_input = next_token.unsqueeze(-1)
    # Save the logits here if you need them
    # ...

# Finally, decode the generated ids
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
print(generated_text)
Note that in practice you will usually want more machinery, such as avoiding repeated tokens, temperature scaling, and top-k or top-p sampling.
Either approach gives you the logits at every generation step, so you can analyze the full probability distribution of the generated tokens.
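As a hedged sketch of the sampling tweaks just mentioned (dummy logits, not tied to any particular model), temperature plus top-k sampling on a single step's logits could look like this:

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int = 50, temperature: float = 0.8) -> torch.Tensor:
    """Sample one token id per batch row from temperature-scaled, top-k-filtered logits."""
    scaled = logits / temperature
    topk_vals, topk_idx = torch.topk(scaled, k=k, dim=-1)
    probs = torch.nn.functional.softmax(topk_vals, dim=-1)  # renormalize over the top k
    choice = torch.multinomial(probs, num_samples=1)        # index into the top-k set
    return topk_idx.gather(-1, choice).squeeze(-1)          # map back to vocabulary ids

# Toy example: batch of 1, vocabulary of 10; with k=2 the result is
# always one of the two largest logits (ids 1 or 4 here)
logits = torch.tensor([[0.1, 5.0, 0.2, 0.3, 4.9, 0.0, 0.1, 0.2, 0.3, 0.4]])
token_id = sample_top_k(logits, k=2)
```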
One more detail: calling
outputs = model.generate(inputs.cuda(), max_length=64, output_scores=True, return_dict_in_generate=True)
i.e. adding the two parameters output_scores and return_dict_in_generate to model.generate, yields two important pieces of generation information:
generated_ids = outputs.sequences
logits = outputs.scores
Some further important details, about past:
The kv cache it holds is a transformers.cache_utils.DynamicCache object, and the important members are past.key_cache and past.value_cache. Both are lists whose length equals num_hidden_layers in the model's config.json (for the DeepSeek distilled-Qwen series, 28 for the 1.5B model and 64 for the 32B).
The key cache and value cache entries have identical shapes:
torch.Size([1, 2, 1643, 128]) torch.Size([1, 2, 1643, 128])
config.json contains no numbers directly matching 1643 or 128.
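The shape [1, 2, 1643, 128] is plausibly [batch, num_key_value_heads, cached_seq_len, head_dim]. Assuming a Qwen2-style config for the 1.5B distill (hidden_size 1536, num_attention_heads 12, num_key_value_heads 2; these values should be checked against the local config.json), the 2 and the 128 fall out directly, while 1643 is an accumulated cache length rather than a config value:

```python
# Hypothetical config values, to be verified against the model's config.json
hidden_size = 1536
num_attention_heads = 12
num_key_value_heads = 2

# head_dim is usually hidden_size / num_attention_heads even when
# config.json does not list it explicitly
head_dim = hidden_size // num_attention_heads
print(num_key_value_heads, head_dim)   # 2 128
```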
Starting Jupyter on the server:
jupyter notebook --ip=0.0.0.0 --port=8888 --no-browser --allow-root
Then map the port to the local machine (take the token from the server side):
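The port mapping is commonly done with an SSH tunnel (user and server are placeholders for your own credentials):

```shell
# Forward local port 8888 to the server's 8888, then open
# http://localhost:8888 locally and paste the token printed by
# jupyter on the server
ssh -L 8888:localhost:8888 user@server
```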
Erratum, taking manual decoding as the example.
Test script:
# Manually decode and inspect the probabilities
def demo_4(mid=0):
    from transformers import AutoTokenizer, AutoModelForCausalLM
    import torch
    model_paths = [
        "/data/nishome/wangyinglin/caoyang/DeepSeek-R1-Distill-Qwen-1.5B",
        "/data/nishome/wangyinglin/yangyitong/DeepSeek-R1-Distill-Qwen-32B",
    ]
    model_path = model_paths[mid]
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).cuda()
    prompt = "很久很久以前,"
    inputs = tokenizer.encode(prompt, return_tensors="pt").cuda()
    past = None
    for i in range(64):
        # Note: the FULL current sequence is re-fed every round while past is
        # also passed, so each round appends the whole sequence's keys/values
        # to the cache again (this is what the log below shows)
        outputs = model(inputs, past_key_values=past)
        next_token_logits = outputs.logits[:, -1, :]
        past = outputs.past_key_values
        next_token_probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
        # next_token_id = torch.argmax(next_token_logits, dim=-1)
        # inputs = torch.cat([inputs, next_token_id.unsqueeze(-1)], dim=-1)
        next_token_ids = torch.topk(next_token_logits, k=2, dim=-1).indices
        with open(f"logging-{mid}.txt", 'a', encoding="utf8") as f:
            f.write("Round: {}\n".format(i))
            f.write("{}\t{}\n".format(outputs.logits.size(), next_token_logits.size()))
            f.write("{}\t{}\n".format(past.key_cache[0].size(), past.key_cache[-1].size()))
            f.write("{}\t{}\n".format(past.value_cache[0].size(), past.value_cache[-1].size()))
            f.write("{}\t{}\n".format(inputs.size(), next_token_ids.size()))  # [1, 5], [1, 2]
            f.write("--------------------------------------------------------\n")
        # Append the 2nd-ranked candidate (index -1 of the top-2)
        inputs = torch.cat([inputs, next_token_ids[:, -1].unsqueeze(-1)], dim=-1)
    generated_text = tokenizer.decode(inputs[0], skip_special_tokens=True)
    with open("demo4-generate.txt", 'w', encoding="utf8") as f:
        f.write(generated_text)
Output:
Round: 0
torch.Size([1, 5, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 5, 128]) torch.Size([1, 2, 5, 128])
torch.Size([1, 2, 5, 128]) torch.Size([1, 2, 5, 128])
torch.Size([1, 5]) torch.Size([1, 2])
--------------------------------------------------------
Round: 1
torch.Size([1, 6, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 11, 128]) torch.Size([1, 2, 11, 128])
torch.Size([1, 2, 11, 128]) torch.Size([1, 2, 11, 128])
torch.Size([1, 6]) torch.Size([1, 2])
--------------------------------------------------------
Round: 2
torch.Size([1, 7, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 18, 128]) torch.Size([1, 2, 18, 128])
torch.Size([1, 2, 18, 128]) torch.Size([1, 2, 18, 128])
torch.Size([1, 7]) torch.Size([1, 2])
--------------------------------------------------------
Round: 3
torch.Size([1, 8, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 26, 128]) torch.Size([1, 2, 26, 128])
torch.Size([1, 2, 26, 128]) torch.Size([1, 2, 26, 128])
torch.Size([1, 8]) torch.Size([1, 2])
--------------------------------------------------------
Round: 4
torch.Size([1, 9, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 35, 128]) torch.Size([1, 2, 35, 128])
torch.Size([1, 2, 35, 128]) torch.Size([1, 2, 35, 128])
torch.Size([1, 9]) torch.Size([1, 2])
--------------------------------------------------------
Round: 5
torch.Size([1, 10, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 45, 128]) torch.Size([1, 2, 45, 128])
torch.Size([1, 2, 45, 128]) torch.Size([1, 2, 45, 128])
torch.Size([1, 10]) torch.Size([1, 2])
--------------------------------------------------------
(Rounds 6-51 omitted: each round repeats the same pattern, with the kv cache's third dimension growing by the current sequence length)
Round: 52
torch.Size([1, 57, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 1643, 128]) torch.Size([1, 2, 1643, 128])
torch.Size([1, 2, 1643, 128]) torch.Size([1, 2, 1643, 128])
torch.Size([1, 57]) torch.Size([1, 2])
--------------------------------------------------------
(Rounds 53-62 omitted: same pattern)
Round: 63
torch.Size([1, 68, 151936]) torch.Size([1, 151936])
torch.Size([1, 2, 2336, 128]) torch.Size([1, 2, 2336, 128])
torch.Size([1, 2, 2336, 128]) torch.Size([1, 2, 2336, 128])
torch.Size([1, 68]) torch.Size([1, 2])
--------------------------------------------------------
The sequence dimension of outputs.logits rises step by step, and the third dimension of the kv cache grows by the current number of tokens each round (5 => 11 => 18 => 26, and so on), i.e. the increments form an arithmetic progression, accumulating like the row counts of a lower-triangular matrix. This is exactly what happens when the full sequence is re-fed at every step while past_key_values is also passed: each round appends the entire sequence's keys and values to the cache again. (In correct incremental decoding, where only the newest token is fed once the cache exists, the cache would instead grow by exactly one position per step.) The trailing 128 is the per-head dimension (head_dim), not the full hidden_size.
As for round k: the cache's third dimension is n + (n + 1) + (n + 2) + … + (n + k) = (k+1)n + k(k+1)/2. Here n = 5 and k = 63, which gives exactly 2336, matching the log.
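The closed form can be checked directly against the logged values (n = 5 prompt tokens; rounds 0, 52, and 63 all appear in the log above):

```python
def cache_len(n: int, k: int) -> int:
    """Cache length after round k when the full sequence of length n + i
    is re-fed (and appended to the cache) at every round i = 0..k."""
    return (k + 1) * n + k * (k + 1) // 2

print(cache_len(5, 0))    # 5
print(cache_len(5, 52))   # 1643, the shape seen mid-log
print(cache_len(5, 63))   # 2336, the final shape
```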