文章目录
- 1 合成数据与模型坍缩(model collapse),
- 1.1 递归生成数据与模型坍缩
- 1.2 三种错误
- 1.3 理论直觉
- 1.4 PPL指标
- 2 基于开源 LLM 实现 O1-like step by step 慢思考(slow thinking),ollama,streamlit
- 2.1 ollama(Structured outputs)
- 2.2 dynamic CoT(o1-like CoT)
- 2.3 streamlit run
- 3 [LLM+RL] model.generate 之 beam search decoding strategy(束搜索)
- 4 [LLM + RL] kimi 1.5 论文导读与 highlights
- 1 reasoning models
- 2 data
- RL算法
- 5 [LLM+RL] R1 论文导读,SFT vs. RL,RL 基础以及 GRPO 细节,以及一系列复现工作讨论
- 6 [LLM+RL] 理解 GRPO 公式原理及 TRL GrpoTrainer 代码实现(advantage 与 loss 计算)
- RL roadmap
- grpo demo
- GRPO
- 7 [LLM+RL] GRPO 中的 KL div(散度),reverse vs. forward,以及无偏估计(Schulman)
1 合成数据与模型坍缩(model collapse),
nature正刊:https://doi.org/10.1038/s41586-024-07566-y
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['https_proxy'] = 'http://127.0.0.1:7890'
from IPython.display import Image
1.1 递归生成数据与模型坍缩
- nature cover: AI models collapse when trained on recursively generated data
- https://www.nature.com/articles/s41586-024-07566-y
- The Curse of Recursion: Training on Generated Data Makes Models Forget
- https://arxiv.org/abs/2305.17493
- Model Collapse refers to a degenerative learning process where models start forgetting improbable events over time, as the model becomes poisoned with its own projection of reality.
- forgetting improbable events
- ppl 的图:更长的尾部。后期代的模型开始生成原始模型永远不会生成的样本;
- 关注下实验设计
- 控制变量:no data preserved vs. 10% data preserved
- metrics:PPL
- 不自知地会去利用这样的数据,因为现实的互联网数据已大量地混入 aigc 的数据,真假难辨,尤其是2023年3月,GPT4 发布之后;
高概率的事件会被高估,低概率的事件会被低估,也就是数据不平衡带来的固有偏差。
从而模型遗忘低概率的事件,导致模型退化。
指标只用了一个PPL,实验设计上,分为完全不用真实数据 和 保留10%的真实数据 的对比。
你其实会不自知的使用了合成数据,因为互联网上已经出现了大量的合成数据(真实数据的分布被污染了)
1.2 三种错误
这三种误差随着模型的训练迭代会不断地加深。这三种误差加深的方式是不同的,functional expressive完全是线性的。
1.3 理论直觉
import numpy as np
import matplotlib.pyplot as plt
# 定义状态数量和每代的样本数
N = 4 # 状态数量
M = 50 # 每代的样本数
generations = 20 # 总共的代数
# 初始化为近似均匀分布
current_distribution = np.ones(N) / N
# 记录指定代数的分布
selected_generations = [0, 5, 10, 15]
distributions = {gen: None for gen in selected_generations}
distributions[0] = current_distribution.copy()
for gen in range(1, generations + 1):
# 从当前分布中抽样
samples = np.random.choice(N, size=M, p=current_distribution)
# 计算新的分布(频率)
new_distribution = np.zeros(N)
unique, counts = np.unique(samples, return_counts=True)
new_distribution[unique] = counts / M
# 更新当前分布
current_distribution = new_distribution
# 如果是选定的代数,记录分布
if gen in selected_generations:
distributions[gen] = current_distribution.copy()
# 检查是否只剩下一个状态(模型坍塌)
if np.count_nonzero(current_distribution) == 1:
print(f"Model collapsed at generation {gen}.")
# 填充剩余代数的分布
for future_gen in selected_generations:
if future_gen > gen and distributions[future_gen] is None:
distributions[future_gen] = current_distribution.copy()
break
# 绘制指定代数的pmf
colors = ['blue', 'green', 'orange', 'red']
labels = [f"Generation {gen}" for gen in selected_generations]
x = np.arange(N) # 状态的索引
plt.figure(figsize=(10, 6))
for idx, gen in enumerate(selected_generations):
if distributions[gen] is not None:
plt.bar(x + idx*0.2, distributions[gen], width=0.2, color=colors[idx], label=labels[idx])
plt.xlabel("State", fontsize=14)
plt.ylabel("Probability", fontsize=14)
plt.title("PMF Evolution Over Generations", fontsize=16)
plt.xticks(x + 0.3, [f"State {i}" for i in x])
plt.legend()
plt.show()
到state3时,红色的bar已经消失了,这是离散均匀分布的一个情况。
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
# 设置随机种子以保证结果可重复
np.random.seed(42)
# 定义初始均值和协方差矩阵
mean = np.array([0, 0])
cov = np.array([[5, 2], [2, 3]])
# 定义每代的样本数和总代数
M = 100 # 每代样本数
generations = 100 # 总共的代数
# 记录每一代的均值和协方差矩阵
means = [mean]
covariances = [cov]
for gen in range(1, generations + 1):
# 从当前分布中抽样
samples = np.random.multivariate_normal(mean, cov, size=M)
# 计算新的均值和协方差矩阵(无偏估计)
new_mean = np.mean(samples, axis=0)
new_cov = np.cov(samples, rowvar=False, bias=False)
# 更新均值和协方差矩阵
mean = new_mean
cov = new_cov
# 记录
means.append(mean)
covariances.append(cov)
# 选择要绘制的代数
selected_generations = [0, 25, 50, 75]
# 定义颜色
colors = ['blue', 'green', 'orange', 'red']
labels = [f"Generation {gen}" for gen in selected_generations]
# 绘制散点图
plt.figure(figsize=(8, 8))
for idx, gen in enumerate(selected_generations):
if gen <= generations:
# 从记录中获取均值和协方差矩阵
mean = means[gen]
cov = covariances[gen]
# 从当前分布中抽样用于可视化
samples = np.random.multivariate_normal(mean, cov, size=M)
# 绘制散点图
plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5, color=colors[idx], label=labels[idx])
# 绘制协方差椭圆
eigenvalues, eigenvectors = np.linalg.eigh(cov)
order = eigenvalues.argsort()[::-1]
eigenvalues, eigenvectors = eigenvalues[order], eigenvectors[:, order]
angle = np.degrees(np.arctan2(*eigenvectors[:,0][::-1]))
width, height = 2 * np.sqrt(eigenvalues)
ellip = Ellipse(xy=mean, width=width, height=height, angle=angle, edgecolor=colors[idx], fc='None', lw=2)
plt.gca().add_artist(ellip)
plt.xlabel("X-axis", fontsize=14)
plt.ylabel("Y-axis", fontsize=14)
plt.title("Model Collapse in Multidimensional Gaussian", fontsize=16)
plt.legend()
plt.grid(True)
plt.axis('equal')
plt.show()
原始是一个接近圆的椭圆,到75代时,已经是非常细长的一个长条椭圆了。
1.4 PPL指标
- Auto-regressive negative log likelihood loss
L = − 1 N ∑ i N log P ( y i ) P P L = exp ( − 1 N ∑ i N log P ( y i ) ) P P L = exp ( L ) \begin{split} &L=-\frac1N\sum_i^N \log P(y_i)\\ &PPL=\exp\left(-\frac1N\sum_i^N \log P(y_i)\right)\\ &PPL=\exp(L) \end{split} L=−N1i∑NlogP(yi)PPL=exp(−N1i∑NlogP(yi))PPL=exp(L)- minimize L == minimize PPL
- For example, if a language model has a PPL of 30, it means that on average the model is as uncertain as if it had to choose between 30 equally probable options for the next word.
PPL是对数似然的一个指数形式
# 检查是否有可用的 GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载数据集
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
texts = dataset['text']
len(texts) # 4358
# 定义困惑度计算函数
def calculate_perplexity(model, tokenizer, text):
encodings = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
input_ids = encodings.input_ids.to(device)
with torch.no_grad():
outputs = model(input_ids, labels=input_ids)
loss = outputs.loss.item()
ppl = torch.exp(torch.tensor(loss)).item()
return ppl
# 加载 GPT-2 模型和分词器
model_name_1 = 'gpt2'
tokenizer_1 = GPT2TokenizerFast.from_pretrained(model_name_1)
model_1 = GPT2LMHeadModel.from_pretrained(model_name_1).to(device)
model_1.eval()
# 加载 GPT-2-XL 模型和分词器
model_name_2 = 'gpt2-xl'
tokenizer_2 = GPT2TokenizerFast.from_pretrained(model_name_2)
model_2 = GPT2LMHeadModel.from_pretrained(model_name_2).to(device)
model_2.eval()
# 计算困惑度
num_samples = 500 # 样本数量,可以根据计算资源调整
sample_length = 100 # 每个样本的最大字符长度
ppl_values_1 = []
ppl_values_2 = []
for i in range(num_samples):
text = texts[i]
if len(text.strip()) == 0:
continue
text = text.strip()[:sample_length]
# 计算 GPT-2 的困惑度
ppl_1 = calculate_perplexity(model_1, tokenizer_1, text)
# 计算 GPT-2-XL 的困惑度
ppl_2 = calculate_perplexity(model_2, tokenizer_2, text)
# 过滤异常值
if 1 < ppl_1 < 1e4 and 1 < ppl_2 < 1e4:
ppl_values_1.append(ppl_1)
ppl_values_2.append(ppl_2)
# 对困惑度取对数
log_ppl_values_1 = np.log(ppl_values_1)
log_ppl_values_2 = np.log(ppl_values_2)
# 绘制困惑度分布的核密度估计图
plt.figure(figsize=(10, 6))
sns.kdeplot(log_ppl_values_1, fill=True, bw_adjust=0.5, color='blue', label='GPT-2')
sns.kdeplot(log_ppl_values_2, fill=True, bw_adjust=0.5, color='red', label='GPT-2-XL')
plt.xlabel('Log Perplexity')
plt.ylabel('Density')
plt.title('Comparison of Log Perplexity Distributions')
plt.legend()
plt.show()
上面这个图均值变低,尾部也变低,PPL越低越好吧。
PPL那个图里,在wiki数据集上算PPL,然后做迭代,初始Generation 0的分布是相对均匀了,然后一代一代生成之后,分布均值发生偏移,而且长尾现象很明显。但如果保留10%的数据,则坍缩的速度要慢很多。
2 基于开源 LLM 实现 O1-like step by step 慢思考(slow thinking),ollama,streamlit
参考资料:
- https://github.com/bklieger-groq/g1(开源项目)
- 2 classical query(容易犯错的问题)
- Which is larger, 0.9 or 0.11?(这个gpt4o只有30%的准确率,但是用下面的prompt寄巧能达到70%)
- How many Rs are in strawberry?
2.1 ollama(Structured outputs)
- 下载最新版 ollama,然后
pip install -U ollama
- https://ollama.com/blog/structured-outputs
- 资源释放:
- 还包括
ollama run llama3.1
对话结束之后输入/bye
还是不会自动资源释放; curl http://localhost:11434/api/generate -d '{"model": "qwen2.5", "keep_alive": 0}'
- 还包括
2.2 dynamic CoT(o1-like CoT)
"""You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys. USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3. BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO. IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS. CONSIDER YOU MAY BE WRONG, AND IF YOU ARE WRONG IN YOUR REASONING, WHERE IT WOULD BE. FULLY TEST ALL OTHER POSSIBILITIES. YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO. DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.
Example of a valid JSON response:
```json
{
"title": "Identifying Key Information",
"content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
"next_action": "continue"
}```
"""
- messages 是对话式的
- system
- user (query)
- assistant (植入的)
- assistant (step by step)
- assistant (step by step)
- …
- 不断地追加进 messages,实现 dynamic 的 reasoning process
2.3 streamlit run
- 命令行执行如下命令:
streamlit run struct_llama_reasoning_app.py
llama_reasoning_app.py
import streamlit as st
import ollama
import os
import json
import time
def make_api_call(messages, max_tokens, is_final_answer=False):
for attempt in range(3):
try:
response = ollama.chat(
model="llama3.1:latest",
messages=messages,
options={"temperature":0.2, "num_predict":max_tokens},
format='json',
)
return json.loads(response['message']['content'])
except Exception as e:
if attempt == 2:
if is_final_answer:
return {"title": "Error", "content": f"Failed to generate final answer after 3 attempts. Error: {str(e)}"}
else:
return {"title": "Error", "content": f"Failed to generate step after 3 attempts. Error: {str(e)}", "next_action": "final_answer"}
time.sleep(1) # Wait for 1 second before retrying
def generate_response(prompt):
messages = [
{"role": "system", "content": """You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys. USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3. BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO. IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS. CONSIDER YOU MAY BE WRONG, AND IF YOU ARE WRONG IN YOUR REASONING, WHERE IT WOULD BE. FULLY TEST ALL OTHER POSSIBILITIES. YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO. DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.
Example of a valid JSON response:
```json
{
"title": "Identifying Key Information",
"content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
"next_action": "continue"
}```
"""},
{"role": "user", "content": prompt},
{"role": "assistant", "content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
]
steps = []
step_count = 1
total_thinking_time = 0
while True:
start_time = time.time()
step_data = make_api_call(messages, 300)
end_time = time.time()
thinking_time = end_time - start_time
total_thinking_time += thinking_time
steps.append((f"Step {step_count}: {step_data['title']}", step_data['content'], thinking_time))
messages.append({"role": "assistant", "content": json.dumps(step_data)})
if step_data['next_action'] == 'final_answer' or step_count > 25: # Maximum of 25 steps to prevent infinite thinking time. Can be adjusted.
break
step_count += 1
# Yield after each step for Streamlit to update
yield steps, None # We're not yielding the total time until the end
# Generate final answer
messages.append({"role": "user", "content": "Please provide the final answer based on your reasoning above."})
start_time = time.time()
final_data = make_api_call(messages, 200, is_final_answer=True)
end_time = time.time()
thinking_time = end_time - start_time
total_thinking_time += thinking_time
steps.append(("Final Answer", final_data['content'], thinking_time))
yield steps, total_thinking_time
def main():
st.set_page_config(page_title="g1 prototype", page_icon="🧠", layout="wide")
st.title("g1: Using Llama-3.1 8b on local to create o1-like reasoning chains")
st.markdown("""
This is an early prototype of using prompting to create o1-like reasoning chains to improve output accuracy. It is not perfect and accuracy has yet to be formally evaluated. It is powered by Ollama.
Open source [repository here](https://github.com/bklieger-groq)
""")
# Text input for user query
user_query = st.text_input("Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")
if user_query:
st.write("Generating response...")
# Create empty elements to hold the generated text and total time
response_container = st.empty()
time_container = st.empty()
# Generate and display the response
for steps, total_thinking_time in generate_response(user_query):
with response_container.container():
for i, (title, content, thinking_time) in enumerate(steps):
if title.startswith("Final Answer"):
st.markdown(f"### {title}")
st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
else:
with st.expander(title, expanded=True):
st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
# Only show total time when it's available at the end
if total_thinking_time is not None:
time_container.markdown(f"**Total thinking time: {total_thinking_time:.2f} seconds**")
if __name__ == "__main__":
main()
struct_llama_reasoning_app.py
import streamlit as st
import ollama
import os
import json
import time
from pydantic import BaseModel
from typing import Literal
class ReasoningStep(BaseModel):
title: str
content: str
next_action: Literal["continue", "final_answer"]
class FinalAnswer(BaseModel):
title: str
content: str
def make_api_call(messages, max_tokens, is_final_answer=False):
for attempt in range(3):
try:
format_schema = ReasoningStep if not is_final_answer else FinalAnswer
response = ollama.chat(
model="llama3.1:latest",
messages=messages,
options={"temperature":0.2, "num_predict":max_tokens},
format=format_schema.model_json_schema(),
)
return format_schema.model_validate_json(response.message.content)
except Exception as e:
if attempt == 2:
if is_final_answer:
return FinalAnswer(title="Error", content=f"Failed to generate final answer after 3 attempts. Error: {str(e)}")
else:
return ReasoningStep(title="Error",
content=f"Failed to generate step after 3 attempts. Error: {str(e)}", next_action="final_answer")
time.sleep(1) # Wait for 1 second before retrying
def generate_response(prompt):
messages = [
{"role": "system", "content": """You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys. USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3. BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO. IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS. CONSIDER YOU MAY BE WRONG, AND IF YOU ARE WRONG IN YOUR REASONING, WHERE IT WOULD BE. FULLY TEST ALL OTHER POSSIBILITIES. YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO. DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.
Example of a valid JSON response:
```json
{
"title": "Identifying Key Information",
"content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
"next_action": "continue"
}```
"""},
{"role": "user", "content": prompt},
{"role": "assistant", "content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
]
steps = []
step_count = 1
total_thinking_time = 0
while True:
start_time = time.time()
step_data = make_api_call(messages, 300)
end_time = time.time()
thinking_time = end_time - start_time
total_thinking_time += thinking_time
steps.append((f"Step {step_count}: {step_data.title}", step_data.content, thinking_time))
messages.append({"role": "assistant", "content": step_data.model_dump_json()})
if step_data.next_action == 'final_answer' or step_count > 25: # Maximum of 25 steps to prevent infinite thinking time. Can be adjusted.
break
step_count += 1
# Yield after each step for Streamlit to update
yield steps, None # We're not yielding the total time until the end
for msg in messages:
print(msg['role'], msg['content'][:20])
# Generate final answer
messages.append({"role": "user", "content": "Please provide the final answer based on your reasoning above."})
start_time = time.time()
final_data = make_api_call(messages, 200, is_final_answer=True)
end_time = time.time()
thinking_time = end_time - start_time
total_thinking_time += thinking_time
steps.append(("Final Answer", final_data.content, thinking_time))
yield steps, total_thinking_time
def main():
st.set_page_config(page_title="g1 prototype", page_icon="🧠", layout="wide")
st.title("g1: Using Llama-3.1 8b on local to create o1-like reasoning chains")
st.markdown("""
This is an early prototype of using prompting to create o1-like reasoning chains to improve output accuracy. It is not perfect and accuracy has yet to be formally evaluated. It is powered by Ollama.
Open source [repository here](https://github.com/bklieger-groq)
""")
# Text input for user query
user_query = st.text_input("Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")
if user_query:
st.write("Generating response...")
# Create empty elements to hold the generated text and total time
response_container = st.empty()
time_container = st.empty()
# Generate and display the response
for steps, total_thinking_time in generate_response(user_query):
with response_container.container():
for i, (title, content, thinking_time) in enumerate(steps):
if title.startswith("Final Answer"):
st.markdown(f"### {title}")
st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
else:
with st.expander(title, expanded=True):
st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
# Only show total time when it's available at the end
if total_thinking_time is not None:
time_container.markdown(f"**Total thinking time: {total_thinking_time:.2f} seconds**")
if __name__ == "__main__":
main()
注意到上面的代码里:
Example of a valid JSON response:
```json
{
"title": "Identifying Key Information",
"content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
"next_action": "continue"
}```
"""},
{"role": "user", "content": prompt},
{"role": "assistant", "content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
]
植入了一个assistant,这是非常妙的一种prompt寄巧,包括上面用大写部分表示强调(强调输出推理步骤)
system + assistant + user
3 [LLM+RL] model.generate 之 beam search decoding strategy(束搜索)
参考资料:
-
https://huggingface.co/blog/how-to-generate
-
https://huggingface.co/blog/constrained-beam-search
-
https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute
-
https://github.com/huggingface/search-and-learn
-
Generate multiple candidate solutions iteratively by maintaining a fixed number of “beams” or active paths N N N;
-
In the first iteration, sample N N N independent steps from the LLM with temperature T T T to introduce diversity in the responses. These steps are usually defined by a stopping criterion like terminating on a new line
\n
or double new line\n\n
. -
Score each step with the PRM and select the top N / M N/M N/M steps as candidates for the next round of generation. Here M M M denotes the “beam width” of a given active path. As in Best-of-N, we used the “last” reduction to score the partial solutions at each iteration.
-
Expand the steps selected in step (3) by sampling M M M next steps in the solution.
-
Repeat steps (3) and (4) until the EOS token is reached or the maximum search depth is exceeded.
-
greedy search => beam search
- greedy search:只选择 top1 logit 的 token
[batch_size, seq_length inc]
- beam search: 增加候选的数量,即束宽度:beam width
[batch_size * num_beams, seq_length inc]
- greedy search:只选择 top1 logit 的 token
-
model(input_ids)
:是一步; -
model.generate(input_ids)
:是多步,autoregressive 的生成;- max_length: input + max_new_length
from transformers import AutoTokenizer, AutoModelForCausalLM
prefixes = ["Once upon a time", "Hi I am a"]
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
input_ids = tokenizer(prefixes, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, num_beams=3, max_length=20)
output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
for text in output_text:
print(text)
这样输出解码结果,直接调num_beams参数即可,很简单。
greedy_output = model.generate(input_ids, max_length=20)
greedy_output
可以step by step
- log p 1 + log p 2 = log ( p 1 ⋅ p 2 ) \log p_1+\log p_2=\log (p_1\cdot p_2) logp1+logp2=log(p1⋅p2)
import torch
import torch.nn.functional as F
def show_beam_search_steps(model, tokenizer, prefix, num_beams=3, max_steps=3):
# 将输入文本转换为 token ids
input_ids = tokenizer(prefix, return_tensors="pt").input_ids
# 初始化 beam 状态
current_beams = [(input_ids, 0)] # (sequence, score)
print(f"\n开始处理前缀: '{prefix}'")
# 对每一步进行 beam search
for step in range(max_steps):
candidates = []
print(f"\n第 {step + 1} 步:")
# 对每个当前的 beam 进行扩展
for beam_ids, beam_score in current_beams:
# 获取模型输出
with torch.no_grad():
outputs = model(beam_ids)
next_token_logits = outputs.logits[:, -1, :]
next_token_probs = F.softmax(next_token_logits, dim=-1)
# 获取前 num_beams 个最可能的下一个 token
values, indices = torch.topk(next_token_probs, num_beams)
# 为每个可能的下一个 token 创建新的候选项
for value, index in zip(values[0], indices[0]):
new_ids = torch.cat([beam_ids, index.unsqueeze(0).unsqueeze(0)], dim=1)
new_score = beam_score + torch.log(value).item()
candidates.append((new_ids, new_score))
# 打印当前候选项
new_text = tokenizer.decode(new_ids[0])
print(f"候选项: {new_text}({new_ids[0].tolist()}) 分数: {new_score:.4f}")
# 选择前 num_beams 个最佳候选项
candidates.sort(key=lambda x: x[1], reverse=True)
current_beams = candidates[:num_beams]
print("\n选择的 beam:")
for beam_ids, beam_score in current_beams:
print(f"beam: {tokenizer.decode(beam_ids[0])}({beam_ids[0].tolist()}) 分数: {beam_score:.4f}")
show_beam_search_steps(model, tokenizer, prefixes[0])
输出结果:
开始处理前缀: 'Once upon a time'
第 1 步:
候选项: Once upon a time,([7454, 2402, 257, 640, 11]) 分数: -0.8512
候选项: Once upon a time the([7454, 2402, 257, 640, 262]) 分数: -2.7396
候选项: Once upon a time I([7454, 2402, 257, 640, 314]) 分数: -3.2029
选择的 beam:
beam: Once upon a time,([7454, 2402, 257, 640, 11]) 分数: -0.8512
beam: Once upon a time the([7454, 2402, 257, 640, 262]) 分数: -2.7396
beam: Once upon a time I([7454, 2402, 257, 640, 314]) 分数: -3.2029
第 2 步:
候选项: Once upon a time, the([7454, 2402, 257, 640, 11, 262]) 分数: -3.0524
候选项: Once upon a time, I([7454, 2402, 257, 640, 11, 314]) 分数: -3.6055
候选项: Once upon a time, it([7454, 2402, 257, 640, 11, 340]) 分数: -4.0718
候选项: Once upon a time the world([7454, 2402, 257, 640, 262, 995]) 分数: -6.5612
候选项: Once upon a time the sun([7454, 2402, 257, 640, 262, 4252]) 分数: -7.6559
候选项: Once upon a time the people([7454, 2402, 257, 640, 262, 661]) 分数: -7.7589
候选项: Once upon a time I was([7454, 2402, 257, 640, 314, 373]) 分数: -4.8048
候选项: Once upon a time I had([7454, 2402, 257, 640, 314, 550]) 分数: -5.7436
候选项: Once upon a time I thought([7454, 2402, 257, 640, 314, 1807]) 分数: -6.5309
选择的 beam:
beam: Once upon a time, the([7454, 2402, 257, 640, 11, 262]) 分数: -3.0524
beam: Once upon a time, I([7454, 2402, 257, 640, 11, 314]) 分数: -3.6055
beam: Once upon a time, it([7454, 2402, 257, 640, 11, 340]) 分数: -4.0718
第 3 步:
候选项: Once upon a time, the world([7454, 2402, 257, 640, 11, 262, 995]) 分数: -7.0757
候选项: Once upon a time, the people([7454, 2402, 257, 640, 11, 262, 661]) 分数: -8.2539
候选项: Once upon a time, the two([7454, 2402, 257, 640, 11, 262, 734]) 分数: -8.3031
候选项: Once upon a time, I was([7454, 2402, 257, 640, 11, 314, 373]) 分数: -5.5660
候选项: Once upon a time, I had([7454, 2402, 257, 640, 11, 314, 550]) 分数: -6.2778
候选项: Once upon a time, I would([7454, 2402, 257, 640, 11, 314, 561]) 分数: -6.8437
候选项: Once upon a time, it was([7454, 2402, 257, 640, 11, 340, 373]) 分数: -5.1921
候选项: Once upon a time, it seemed([7454, 2402, 257, 640, 11, 340, 3947]) 分数: -6.7970
候选项: Once upon a time, it would([7454, 2402, 257, 640, 11, 340, 561]) 分数: -6.8182
选择的 beam:
beam: Once upon a time, it was([7454, 2402, 257, 640, 11, 340, 373]) 分数: -5.1921
beam: Once upon a time, I was([7454, 2402, 257, 640, 11, 314, 373]) 分数: -5.5660
beam: Once upon a time, I had([7454, 2402, 257, 640, 11, 314, 550]) 分数: -6.2778
可以多试几次:
show_beam_search_steps(model, tokenizer, prefixes[1])
"""
开始处理前缀: 'Hi I am a'
第 1 步:
候选项: Hi I am a big([17250, 314, 716, 257, 1263]) 分数: -3.8471
候选项: Hi I am a very([17250, 314, 716, 257, 845]) 分数: -4.0766
候选项: Hi I am a little([17250, 314, 716, 257, 1310]) 分数: -4.1127
选择的 beam:
beam: Hi I am a big([17250, 314, 716, 257, 1263]) 分数: -3.8471
beam: Hi I am a very([17250, 314, 716, 257, 845]) 分数: -4.0766
beam: Hi I am a little([17250, 314, 716, 257, 1310]) 分数: -4.1127
第 2 步:
候选项: Hi I am a big fan([17250, 314, 716, 257, 1263, 4336]) 分数: -4.2283
候选项: Hi I am a big believer([17250, 314, 716, 257, 1263, 29546]) 分数: -7.1364
候选项: Hi I am a big supporter([17250, 314, 716, 257, 1263, 15525]) 分数: -8.3071
候选项: Hi I am a very good([17250, 314, 716, 257, 845, 922]) 分数: -6.7408
候选项: Hi I am a very nice([17250, 314, 716, 257, 845, 3621]) 分数: -7.1981
候选项: Hi I am a very happy([17250, 314, 716, 257, 845, 3772]) 分数: -7.3774
候选项: Hi I am a little bit([17250, 314, 716, 257, 1310, 1643]) 分数: -6.2787
候选项: Hi I am a little confused([17250, 314, 716, 257, 1310, 10416]) 分数: -7.0489
候选项: Hi I am a little disappointed([17250, 314, 716, 257, 1310, 11679]) 分数: -7.2741
选择的 beam:
beam: Hi I am a big fan([17250, 314, 716, 257, 1263, 4336]) 分数: -4.2283
beam: Hi I am a little bit([17250, 314, 716, 257, 1310, 1643]) 分数: -6.2787
beam: Hi I am a very good([17250, 314, 716, 257, 845, 922]) 分数: -6.7408
第 3 步:
候选项: Hi I am a big fan of([17250, 314, 716, 257, 1263, 4336, 286]) 分数: -4.3084
候选项: Hi I am a big fan and([17250, 314, 716, 257, 1263, 4336, 290]) 分数: -8.1861
候选项: Hi I am a big fan.([17250, 314, 716, 257, 1263, 4336, 13]) 分数: -8.3988
候选项: Hi I am a little bit of([17250, 314, 716, 257, 1310, 1643, 286]) 分数: -8.6324
候选项: Hi I am a little bit worried([17250, 314, 716, 257, 1310, 1643, 7960]) 分数: -9.4857
候选项: Hi I am a little bit older([17250, 314, 716, 257, 1310, 1643, 4697]) 分数: -9.5333
候选项: Hi I am a very good person([17250, 314, 716, 257, 845, 922, 1048]) 分数: -9.3998
候选项: Hi I am a very good friend([17250, 314, 716, 257, 845, 922, 1545]) 分数: -9.8805
候选项: Hi I am a very good student([17250, 314, 716, 257, 845, 922, 3710]) 分数: -10.3733
选择的 beam:
beam: Hi I am a big fan of([17250, 314, 716, 257, 1263, 4336, 286]) 分数: -4.3084
beam: Hi I am a big fan and([17250, 314, 716, 257, 1263, 4336, 290]) 分数: -8.1861
beam: Hi I am a big fan.([17250, 314, 716, 257, 1263, 4336, 13]) 分数: -8.3988
"""
4 [LLM + RL] kimi 1.5 论文导读与 highlights
https://github.com/chunhuizhang/llm_rl/blob/main/tutorials/r1-k1.5/k1.5.ipynb
- k1.5 和 R1
- 中文 reasoning models:更好的中文理解及运用,都是国产之光;
- openai reasoning llms 路线的探索和复现者
- 即如何利用 RL 更好地激化 LLMs 的长链推理能力;
- k1.5 应该是第一个次 kimi 发布的 technical report
- k1.5 和 R1 可以对比着看,互相补充,交叉验证一些内容
- k1.5 的细节更为丰富,全面;
- 都舍去了 mcts、value function、prm(process reward models),追求 simple & scaling;
- Simplistic Framework.
- 中文 reasoning models:更好的中文理解及运用,都是国产之光;
(舍去的东西未必是不好的)
- test cases
- Using the numbers {1,3,5,37}, create an equation that equals {24}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.(无解的问题?其实像很多不能通过穷举解决的问题CoT就很难做,就是动态规划哪些,但是本身这是程序的事情,并不是人类可以用思考解决的)
Alright, ... Let me think ... Hmm, ... Alternatively, ... Wait,
- 用忆秦娥的词牌,创作一手乡愁主题的词。
- 不是简单的语言问题,还包括很多很难的约束,平仄、押韵、重复等等;
- 学校组织出游,班长带了不超过150包的湿巾纸,如果40个人平均分则多7包,25个人平均分则多2包。问班长共带了多少包湿巾纸?
- 127:127 = 40*3 + 7;127 = 25*5 + 2(这个已经很简单了,都能做出来)
- Using the numbers {1,3,5,37}, create an equation that equals {24}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.(无解的问题?其实像很多不能通过穷举解决的问题CoT就很难做,就是动态规划哪些,但是本身这是程序的事情,并不是人类可以用思考解决的)
1 reasoning models
- think aloud:思考的具象化(这是个很好的概念)
- thinking process
- thinking steps
- thinking tokens
- long context, long cot;
- emergence
- planning, evaluation, reflection, exploration
2 data
数据从哪里来;
数据质量的定义;
- data quality
- Diverse Coverage;
- Balanced Difficulty;
- Accurate Evaluability; => reward design;
- rule-based reward design (vs. RLHF 中的 learned/neural reward modeling)
- r ( x , y , y ⋆ ) r(x,y,y^\star) r(x,y,y⋆)
- 数据从哪里来(非常非常 engineering)
- data triplets
- questions
- we employ automatic filters to select questions that require rich reasoning and are straightforward to evaluate.
- cot
- answer
- questions
- data triplets
RL算法
- PG (Policy gradient)
- TRPO, PPO (GRPO)
K1.5用的还跟R1不太一样
改进的策略优化
pg => PPO;
E τ ∼ π θ [ ∇ θ log π θ ( τ ) ⋅ R ( τ ) ] \mathbb E_{\tau\sim \pi_\theta}[\nabla_\theta \log\pi_\theta(\tau)\cdot R(\tau)] Eτ∼πθ[∇θlogπθ(τ)⋅R(τ)]
- On-policy: the agent learned and the agent interacting with the env is same;
- 边实践边学习;数据利用率低;
- policy-gradient (on-policy)
- Off-policy: the agent learned and the agent interacting with the env is different;
- 观察他人学习;数据利用率高;
E τ ∼ π θ ′ [ ∇ θ log π θ ( τ ) ⋅ R ( τ ) π θ ( τ ) π θ ′ ( τ ) ] \mathbb E_{\tau\sim \pi_{\theta'}}\left[\nabla_\theta \log\pi_\theta(\tau)\cdot R(\tau)\frac{\pi_\theta(\tau)}{\pi_{\theta'}(\tau)}\right] Eτ∼πθ′[∇θlogπθ(τ)⋅R(τ)πθ′(τ)πθ(τ)]
- 观察他人学习;数据利用率高;
5 [LLM+RL] R1 论文导读,SFT vs. RL,RL 基础以及 GRPO 细节,以及一系列复现工作讨论
提纲:
-
DeepSeekMath
- Data Pipeline
- GRPO
-
R1
- R1-Zero (AlphaZero-style)
- distillation
-
kimi-1.5: first tech report
- 与R1交叉验证,互相补充
- Long2short Distillation
-
S1
- 类比DeepSeek蒸馏的过程
-
Reasoning Models (Think aloud)
- reasoning tokens/steps/processes
- Long-CoT
- Expert CoT => Learned CoT
- reasoning tokens/steps/processes
-
Data Curation
-
SFT v.s. RL
- teach v.s. incentive(这个意思是说LLM不要去教它,而是去诱导和刺激它)
- RL hightlights & Roadmap
-
Unified Paradigm
- PG: SFT, TRPO, DPO, GRPO
-
复现:数据,算法,infra(架构)
思考的时间越多,就会有更高的准确率(更多的思考时间,之前OpenAI也有这个论断),下图就是说明(这个是R1的论文,直接超越o1)
同理思考的越多,输出的长度越长,下面KIMI的论文有类似的结论:
总之:思考越久=>输出越长=>正确率越高
LLM在推理时,如果只是一锤子买卖,一条路推下去,很容易是错的,因此在强化学习中,context window要足够大
下面是一个猜字谜的例子:
现在就是不断验证尝试,直到达到一个reward model给到的一个score高的解
这里会有一个问题,就是找麦田里最大麦穗的问题,何时停止?是否回溯用之前的解?这可能是一个研究方向。
然后这种反思机制的prompt template如下:
然后第二轮对话输入的只是第一轮的输入和输出,隐去了CoT,因为CoT太长了。
然后到第二张PPT
这里是OpenAI o1发布会的观点👆
复杂问题 不应该 像 简单问题 消耗 同样的 计算量
有传言OpenAI雇佣了名校博士去标注Reasoning过程(时薪几百刀)
如果人类足够团结,比如找一百万个专家,每人贡献100条推理路径,这个数据量应该足以让模型拟人了。
但是让LLM用RL自己去打磨更好,OpenAI o1的研究员说LLM打磨的CoT比人类标注的还要好。
几篇与Data curation相关的paper
DeepSeekMath主要是两个重点:
- 种子数据集的构建
- GRPO
种子数据集可以是GSM8K
- FastText Model:判断问题是否是数学题
- 然后去爬数据集
- 标注,一旦标注一个网站是数学题,该URL下所有的resource都认为是数学题
追求一种domain的相关性,并且很快的构建起来
这是一个system engineering的工作,其实是有很强的参考价值的。
之后这期还有一些关于GRPO的讨论,GRPO目前实现上有一些争议,主要是在于它看起来很简单,但是需要三重循环,可能效率不见得很高
6 [LLM+RL] 理解 GRPO 公式原理及 TRL GrpoTrainer 代码实现(advantage 与 loss 计算)
https://github.com/chunhuizhang/llm_rl/blob/main/tutorials/r1-k1.5/trl_grpo.ipynb
RL roadmap
- 基本概念: on-policy vs. off-policy
- 理解复杂的公式( π θ , π θ o l d , π r e f \pi_\theta,\pi_{\theta_{old}}, \pi_{ref} πθ,πθold,πref),理解计算过程;
- 对哪个概率分布进行采样获得数据;
- GRPO <= PPO(CLIP) <= TRPO <= PG (policy gradient)
- PPO: GAE (Generalized Advantage Estimation), TD-error
∇ θ J ( π θ ) = E τ ∼ ( π θ , T ) [ ∑ t = 0 T ∇ θ log π θ ( a t ∣ s t ) R ( τ ) ] \nabla_\theta J(\pi_\theta)=\mathbb E_{\tau \sim (\pi_\theta, T)}\left[\sum_{t=0}^T\nabla_\theta \log\pi_\theta(a_t|s_t)R(\tau)\right] ∇θJ(πθ)=Eτ∼(πθ,T)[t=0∑T∇θlogπθ(at∣st)R(τ)]
grpo demo
dataset: 7473;
https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py
- 单 4090:
export CUDA_VISIBLE_DEVICES=0 python grpo_demo.py
- 4/4 * 4 => 4,
- 1868;
- 13 小时;
- 4/4 * 4 => 4,
- 双 4090,ddp (accelerate map)
- (4/4 * 4 * 2) => 8;
- 934;
- < 10小时;
- (4/4 * 4 * 2) => 8;
- 双 4090,(accelerate config)
- deepspeed stage-2/3;
- fsdp
GRPOConfig & GRPOTrainer
- GRPOConfig
- num_generations=8,
- old
- per_device_train_batch_size=1, * gradient_accumulation_steps=8,
- per_device_train_batch_size * gradient_accumulation_steps * world_size ==> train_batch
- per_device_train_batch_size=1, * gradient_accumulation_steps=8,
- new: https://github.com/huggingface/trl/pull/2776#issue-2833772774
- per_device_train_batch_size:
- it now represents the number of generations per device.
- per_device_train_batch_size/num_generations * gradient_accumulation_steps
- per_device_train_batch_size/num_generations: prompts per device
- 也因此要求,per_device_train_batch_size 必须能被 num_generations 整除;
- per_device_train_batch_size:
GRPO
- GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training.
- four main steps:
- Generating completions,
- At each training step, we sample a batch of prompts and generate a set of
G
G
G completions(
num_generations
) for each prompt (denoted as o i o_i oi).
- At each training step, we sample a batch of prompts and generate a set of
G
G
G completions(
- computing the advantage,
- A ^ i , t = r i − μ ( r ) σ ( r ) \hat A_{i,t}=\frac{r_i-\mu(\mathbf r)}{\sigma(\mathbf r)} A^i,t=σ(r)ri−μ(r)
- Outcome supervision provides the normalized reward at the end of each output o i o_i oi and sets the advantages A ^ i , t \hat A_{i,t} A^i,t of all tokens in the output as the normalized reward
- estimating the KL divergence, (token-level see the figure)
- https://huggingface.co/docs/trl/main/en/grpo_trainer#estimating-the-kl-divergence
per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
- and computing the loss.
- π r e f , ( π θ o l d , π θ ) \pi_{ref}, (\pi_{\theta_{old}}, \pi_\theta) πref,(πθold,πθ)
- https://github.com/huggingface/trl/issues/2608
- Generating completions,
# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
pipeline
- grpo_trainer.py
_prepare_inputs()
return { "prompt_ids": prompt_ids, "prompt_mask": prompt_mask, "completion_ids": completion_ids, "completion_mask": completion_mask, "ref_per_token_logps": ref_per_token_logps, "advantages": advantages, }
compute_loss()
qwen2.5 vocab size: 151936
-
prompts => tokenizer.apply_chat_template => model.generate
- prompt_completion_ids = prompt_ids + prompt_completion_ids
- rewards_per_func: (n_prompts, len(reward_funcs))
- ref_per_token_logps(
π
r
e
f
(
q
,
o
)
\pi_{ref}(q, o)
πref(q,o)), per_token_logps (
π
θ
(
q
,
o
)
\pi_\theta(q,o)
πθ(q,o))
- completion token level
- selective_log_softmax(logits, index)
- logits.shape: (n_prompts, n_completion, n_vocab)
- index.shape: (n_prompts, n_complection)
- => (n_prompts, n_complection)
exp ( log π ′ − log π ) = π ′ π \exp(\log{\pi'}-\log{\pi})=\frac{\pi'}{\pi} exp(logπ′−logπ)=ππ′
- selective_log_softmax(logits, index)
- completion token level
-
目前的实现只有 π θ , π r e f \pi_{\theta}, \pi_{ref} πθ,πref,没有 π θ o l d \pi_{\theta_{old}} πθold
- https://github.com/huggingface/trl/issues/2608
- The policy model only has a single update following each exploration stage. (deepseekmath)
-
π
θ
\pi_\theta
πθ 每次(rollout a group generations)只进行一次更新,而不是多次更新;
- 对应
for GRPO iteration = 1, . . . , 𝜇
(𝜇 == 1)
- 对应
- π θ o l d = π θ \pi_{\theta_{old}}=\pi_\theta πθold=πθ
torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
- per_token_logps.detach() 不参与计算图的梯度计算;
- 没有用到 clip,只有
π
θ
π
θ
o
l
d
A
=
1
⋅
A
\frac{\pi_\theta}{\pi_{\theta_{old}}}A=1\cdot A
πθoldπθA=1⋅A(ratio * advantage)
- ratio = 1,一定在 ( 1 − ϵ , 1 + ϵ ) (1-\epsilon, 1+\epsilon) (1−ϵ,1+ϵ) 之间的;
- https://github.com/huggingface/trl/issues/2608
training monitor
- You should rely mostly on the reward. And keep an eye on the generations (risk of reward hacking)
- https://github.com/huggingface/trl/issues/2703
7 [LLM+RL] GRPO 中的 KL div(散度),reverse vs. forward,以及无偏估计(Schulman)
https://www.bilibili.com/video/BV1rsAXe7EZs
毋庸置疑,就GRPO中的公式而言,KL散度在其中是reverse的
J G R P O ( θ ) = E q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) [ 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ min ( π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) A ^ i , t , clip ( π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) , 1 − ε , 1 + ε ) A ^ i , t ) − β D K L ( π θ ∣ ∣ π r e f ) ] (3) \mathcal{J}_{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{old}}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})} \hat{A}_{i,t}, \text{clip} \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})}, 1-\varepsilon, 1+\varepsilon \right) \hat{A}_{i,t} \right) - \beta D_{KL} (\pi_\theta || \pi_{ref}) \right] \tag{3} JGRPO(θ)=Eq∼P(Q),{oi}i=1G∼πθold(O∣q) G1i=1∑G∣oi∣1t=1∑∣oi∣min(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)A^i,t,clip(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t),1−ε,1+ε)A^i,t)−βDKL(πθ∣∣πref) (3)
https://timvieira.github.io/blog/post/2014/10/06/kl-divergence-as-an-objective-function
- The reference GRPO implementation uses the reverse KL divergence, not the forward KL divergence.
源码里算loss的时候是做了一个clipped的advantage,然后加上β乘以一个KL散度 - https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py(trl库最新版本已经把grpo给实现掉了)
- L = 1 n ∑ β D K L ( q ∥ p ) + A L=\frac1n\sum\beta D_{KL}(q\|p)+A L=n1∑βDKL(q∥p)+A
- q q q is the new trained model, and p p p is the original reference model.(q就是一个带参数的分布,需要迅雷)
- p p p 待逼近的分布;
- grpo
- q ( x ) = π θ ( o i , t ∣ q , o i , < t ) q(x) = \pi_{\theta}(o_{i,t}|q,o_{i,<t}) q(x)=πθ(oi,t∣q,oi,<t)
- p ( x ) = π r e f ( o i , t ∣ q , o i , < t ) p(x) = \pi_{ref}(o_{i,t}|q,o_{i,<t}) p(x)=πref(oi,t∣q,oi,<t)
J G R P O ( θ ) = E q ∼ P ( Q ) , { o i } i = 1 G ∼ π θ o l d ( O ∣ q ) [ 1 G ∑ i = 1 G 1 ∣ o i ∣ ∑ t = 1 ∣ o i ∣ min ( π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) A ^ i , t , clip ( π θ ( o i , t ∣ q , o i , < t ) π θ o l d ( o i , t ∣ q , o i , < t ) , 1 − ε , 1 + ε ) A ^ i , t ) − β D K L ( π θ ∣ ∣ π r e f ) ] (3) \mathcal{J}_{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{old}}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})} \hat{A}_{i,t}, \text{clip} \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})}, 1-\varepsilon, 1+\varepsilon \right) \hat{A}_{i,t} \right) - \beta D_{KL} (\pi_\theta || \pi_{ref}) \right] \tag{3} JGRPO(θ)=Eq∼P(Q),{oi}i=1G∼πθold(O∣q) G1i=1∑G∣oi∣1t=1∑∣oi∣min(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t)A^i,t,clip(πθold(oi,t∣q,oi,<t)πθ(oi,t∣q,oi,<t),1−ε,1+ε)A^i,t)−βDKL(πθ∣∣πref) (3)
D K L [ π θ ∣ ∣ π r e f ] = π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − log π r e f ( o i , t ∣ q , o i , < t ) π θ ( o i , t ∣ q , o i , < t ) − 1 \mathbb{D}_{KL}[\pi_{\theta}||\pi_{ref}] = \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_{\theta}(o_{i,t}|q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_{\theta}(o_{i,t}|q, o_{i,<t})} - 1 DKL[πθ∣∣πref]=πθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−logπθ(oi,t∣q,oi,<t)πref(oi,t∣q,oi,<t)−1
这个KL散度是一个无偏、且低方差的估计👆
John Schulman’s blog
http://joschu.net/blog/kl-approx.html,博客
r = p ( x ) q ( x ) r=\frac{p(x)}{q(x)} r=q(x)p(x)
第一种估计: − log r = log q ( x ) p ( x ) -\log r=\log \frac{q(x)}{p(x)} −logr=logp(x)q(x)
k l ( q ∥ p ) = ∑ x q ( x ) log q ( x ) p ( x ) = E x ∼ q [ log q ( x ) p ( x ) ] kl(q\|p)=\sum_x q(x)\log \frac{q(x)}{p(x)}=\mathbb E_{x\sim q}\left[\log \frac{q(x)}{p(x)}\right] kl(q∥p)=x∑q(x)logp(x)q(x)=Ex∼q[logp(x)q(x)]
- it has high-variance, as it’s negative for half of the samples, whereas KL is always positive.
- 采样点落在 q ( x ) < p ( x ) q(x)\lt p(x) q(x)<p(x) 时,
为什么说 log q ( x ) p ( x ) \log \frac{q(x)}{p(x)} logp(x)q(x)会有大约一半的样本得出来的结果是负的呢?下面是一个验证过程:
import torch.distributions as dis
import torch
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
torch.sum((q.log_prob(x) - p.log_prob(x)) < 0) # tensor(4799598)
虽然无偏,但方差过大
第二种估计: 1 2 ( log r ) 2 \frac12(\log r)^2 21(logr)2
这个比较复杂,可以自行查看John Schulman的博客内容:http://joschu.net/blog/kl-approx.html
第三种估计: r − 1 − log r r-1-\log r r−1−logr
这也是GRPO采取的一种估计方式
- k l ( q ∥ p ) = r − 1 − log r kl(q\|p)=r-1-\log r kl(q∥p)=r−1−logr (without q q q)
- 保证非负 log x ≤ x − 1 \log x\leq x-1 logx≤x−1,这点很重要,因为KL散度总是非负的,下面是一个证明:
D K L ( q ∥ p ) = ∑ q log q p = − ∑ q log p q = 1 − ∑ q log p q − 1 = ∑ p − ∑ q log p q − ∑ q = ∑ q p q − ∑ q log p q − ∑ q = ∑ q [ p q − log p q − 1 ] \begin{split} D_{KL}(q\|p)&=\sum q\log \frac{q}p\\ &=-\sum q\log\frac pq\\ &=1-\sum q\log\frac pq-1\\ &=\sum p-\sum q\log\frac pq-\sum q\\ &=\sum q\frac{p}q-\sum q\log\frac pq-\sum q\\ &=\sum q\left[\frac{p}{q}-\log\frac pq-1\right] \end{split} DKL(q∥p)=∑qlogpq=−∑qlogqp=1−∑qlogqp−1=∑p−∑qlogqp−∑q=∑qqp−∑qlogqp−∑q=∑q[qp−logqp−1]
import matplotlib.pyplot as plt
import numpy as np
xs = np.arange(0.01, 5, 0.01)
plt.plot(np.log(xs), label=r'$\log x$')
plt.plot(xs-1, label='$x-1$')
plt.legend()
log
x
<
=
x
−
1
\log x <= x - 1
logx<=x−1的图示
import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(q, p)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
print(k.mean(), (k.mean() - truekl) / truekl, k.std() / truekl)
上面的代码中truekl = dis.kl_divergence(q, p)
应该是写错了,它应该是一个reverse的,而非forward的,因此应该是倒过来(KL散度并不对称,qp和pq的值完全不一样)
输出结果:
true tensor(0.0050)
tensor(0.0050) tensor(0.0031) tensor(19.9963)
tensor(0.0050) tensor(0.0021) tensor(1.4167)
tensor(0.0050) tensor(-0.0004) tensor(1.4153)
Reverse vs Forward KL
- Integrate along the given axis using the composite trapezoidal rule.
def approx_kl(gmm_1, gmm_2, xs):
ys = gmm_1.pdf(xs) * (gmm_1.logpdf(xs) - gmm_2.logpdf(xs))
return np.trapz(ys, xs)
- https://www.tuananhle.co.uk/notes/reverse-forward-kl.html
- 从 p p p(目标分布)出发,因为它是确定的,优化的目标是参数化的 q ϕ q_\phi qϕ
- Reverse KL: Zero-Forcing/Mode-Seeking
-
q
log
q
p
q\log\frac qp
qlogpq(p接近0的时候q也要接近0,不然就无穷大了,这就是zero-forcing的性质,mode-seeking则是zero-forcing的一个推断,但不必然)
- forces
q
ϕ
q_\phi
qϕ to be zero where
p
p
p is zero
- zero-forcing => mode seeking (不必然)
- not always mode-seeking (subplot 2/3)
- forces
q
ϕ
q_\phi
qϕ to be zero where
p
p
p is zero
-
q
log
q
p
q\log\frac qp
qlogpq(p接近0的时候q也要接近0,不然就无穷大了,这就是zero-forcing的性质,mode-seeking则是zero-forcing的一个推断,但不必然)
- Forward KL: Mass-Covering/Mean-Seeking
-
p
log
p
q
p\log\frac pq
plogqp(p有值时,q就不能是一个接近0的值,否则就无穷大了,这就是mass-covering,因此就要把所有p有值的地方都拟合出来,这就是mean-seeking)
- there is some mass under q ϕ q_\phi qϕ wherever there is some mass under p p p
- q q q zero avoiding:避免出现 0;
-
p
log
p
q
p\log\frac pq
plogqp(p有值时,q就不能是一个接近0的值,否则就无穷大了,这就是mass-covering,因此就要把所有p有值的地方都拟合出来,这就是mean-seeking)
上图中,实线是p,即一个混合高斯分布,短的虚线是reverse的一个KL散度,长的虚线是forwar的一个KL散度
看最后两个图,典型的是一个mode-seeking,即拟合到双方的分布的模式上,但这不必然,比如看第3个图,前向和反向是基本相同的。