[Study Notes] LLM + RL

1 Synthetic Data and Model Collapse

Nature (main journal): https://doi.org/10.1038/s41586-024-07566-y

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
from datasets import load_dataset
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import os
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['https_proxy'] = 'http://127.0.0.1:7890'

from IPython.display import Image

1.1 Recursively Generated Data and Model Collapse

  • nature cover: AI models collapse when trained on recursively generated data
    • https://www.nature.com/articles/s41586-024-07566-y
    • The Curse of Recursion: Training on Generated Data Makes Models Forget
      • https://arxiv.org/abs/2305.17493
  • Model Collapse refers to a degenerative learning process where models start forgetting improbable events over time, as the model becomes poisoned with its own projection of reality.
    • forgetting improbable events
    • The PPL plot shows a longer tail: later-generation models start producing samples the original model would never generate;
  • Pay attention to the experimental design
    • controlled variable: no data preserved vs. 10% of real data preserved
    • metric: PPL
  • This kind of data gets used unknowingly, because real internet data is already heavily mixed with AIGC content that is hard to tell apart from human text, especially after the release of GPT-4 in March 2023;


High-probability events get overestimated and low-probability events get underestimated: the inherent bias that comes with data imbalance.

As a result the model forgets low-probability events, and the model degrades.


Only one metric, PPL, is used; the experimental design compares using no real data at all against preserving 10% of the real data.

You end up using synthetic data without realizing it, because large amounts of it are already on the internet (the distribution of real data has been polluted).

1.2 Three Sources of Error


These three errors (statistical approximation error, functional expressivity error, and functional approximation error) keep deepening as the model is trained generation after generation, and they deepen in different ways; the functional expressivity one is purely linear.

1.3 Theoretical Intuition

import numpy as np
import matplotlib.pyplot as plt

# Number of states and samples per generation
N = 4      # number of states
M = 50     # samples per generation
generations = 20  # total number of generations

# Initialize with an (approximately) uniform distribution
current_distribution = np.ones(N) / N

# Record the distribution at selected generations
selected_generations = [0, 5, 10, 15]
distributions = {gen: None for gen in selected_generations}
distributions[0] = current_distribution.copy()

for gen in range(1, generations + 1):
    # Sample from the current distribution
    samples = np.random.choice(N, size=M, p=current_distribution)
    
    # Compute the new distribution (empirical frequencies)
    new_distribution = np.zeros(N)
    unique, counts = np.unique(samples, return_counts=True)
    new_distribution[unique] = counts / M
    
    # Update the current distribution
    current_distribution = new_distribution
    
    # Record the distribution at the selected generations
    if gen in selected_generations:
        distributions[gen] = current_distribution.copy()
    
    # Check whether only one state remains (model collapse)
    if np.count_nonzero(current_distribution) == 1:
        print(f"Model collapsed at generation {gen}.")
        # Fill in the distributions for the remaining selected generations
        for future_gen in selected_generations:
            if future_gen > gen and distributions[future_gen] is None:
                distributions[future_gen] = current_distribution.copy()
        break

# Plot the PMF at the selected generations
colors = ['blue', 'green', 'orange', 'red']
labels = [f"Generation {gen}" for gen in selected_generations]

x = np.arange(N)  # state indices

plt.figure(figsize=(10, 6))

for idx, gen in enumerate(selected_generations):
    if distributions[gen] is not None:
        plt.bar(x + idx*0.2, distributions[gen], width=0.2, color=colors[idx], label=labels[idx])

plt.xlabel("State", fontsize=14)
plt.ylabel("Probability", fontsize=14)
plt.title("PMF Evolution Over Generations", fontsize=16)
plt.xticks(x + 0.3, [f"State {i}" for i in x])
plt.legend()
plt.show()


In the resulting plot, the red bar has already vanished at State 3; this is what collapse looks like for a discrete uniform distribution.

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse

# Fix the random seed for reproducibility
np.random.seed(42)

# Initial mean and covariance matrix
mean = np.array([0, 0])
cov = np.array([[5, 2], [2, 3]])

# Samples per generation and total number of generations
M = 100    # samples per generation
generations = 100  # total number of generations

# Record the mean and covariance of every generation
means = [mean]
covariances = [cov]

for gen in range(1, generations + 1):
    # Sample from the current distribution
    samples = np.random.multivariate_normal(mean, cov, size=M)
    
    # Estimate the new mean and covariance (unbiased estimates)
    new_mean = np.mean(samples, axis=0)
    new_cov = np.cov(samples, rowvar=False, bias=False)
    
    # Update the mean and covariance
    mean = new_mean
    cov = new_cov
    
    # Record them
    means.append(mean)
    covariances.append(cov)

# Generations to plot
selected_generations = [0, 25, 50, 75]

# Colors
colors = ['blue', 'green', 'orange', 'red']
labels = [f"Generation {gen}" for gen in selected_generations]

# Scatter plot
plt.figure(figsize=(8, 8))

for idx, gen in enumerate(selected_generations):
    if gen <= generations:
        # Look up the recorded mean and covariance
        mean = means[gen]
        cov = covariances[gen]
        
        # Draw samples from the recorded distribution for visualization
        samples = np.random.multivariate_normal(mean, cov, size=M)
        
        # Scatter plot of the samples
        plt.scatter(samples[:, 0], samples[:, 1], alpha=0.5, color=colors[idx], label=labels[idx])
        
        # Covariance ellipse
        eigenvalues, eigenvectors = np.linalg.eigh(cov)
        order = eigenvalues.argsort()[::-1]
        eigenvalues, eigenvectors = eigenvalues[order], eigenvectors[:, order]
        angle = np.degrees(np.arctan2(*eigenvectors[:,0][::-1]))
        width, height = 2 * np.sqrt(eigenvalues)
        ellip = Ellipse(xy=mean, width=width, height=height, angle=angle, edgecolor=colors[idx], fc='None', lw=2)
        plt.gca().add_artist(ellip)
        
plt.xlabel("X-axis", fontsize=14)
plt.ylabel("Y-axis", fontsize=14)
plt.title("Model Collapse in Multidimensional Gaussian", fontsize=16)
plt.legend()
plt.grid(True)
plt.axis('equal')
plt.show()

The original distribution is a nearly circular ellipse; by generation 75 it has collapsed into a very thin, elongated one.

1.4 The PPL Metric

  • Auto-regressive negative log likelihood loss
    $$\begin{split} &L=-\frac1N\sum_i^N \log P(y_i)\\ &PPL=\exp\left(-\frac1N\sum_i^N \log P(y_i)\right)\\ &PPL=\exp(L) \end{split}$$
    • minimize L == minimize PPL
  • For example, if a language model has a PPL of 30, it means that on average the model is as uncertain as if it had to choose between 30 equally probable options for the next word.

PPL is simply the exponential of the average negative log-likelihood.
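
As a quick sanity check of the "30 equally probable options" reading above: if a model assigned probability 1/30 to every token, its average NLL would be log 30 and its perplexity exactly 30. A tiny sketch (the uniform probabilities are of course hypothetical):

import numpy as np

probs = np.full(100, 1.0 / 30)     # hypothetical per-token probabilities
nll = -np.mean(np.log(probs))      # average negative log-likelihood L
print(np.exp(nll))                 # PPL = exp(L), approximately 30.0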

# Check whether a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load the dataset
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
texts = dataset['text']
len(texts) # 4358
# Perplexity helper
def calculate_perplexity(model, tokenizer, text):
    encodings = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
    input_ids = encodings.input_ids.to(device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss.item()
    ppl = torch.exp(torch.tensor(loss)).item()
    return ppl
# Load the GPT-2 model and tokenizer
model_name_1 = 'gpt2'
tokenizer_1 = GPT2TokenizerFast.from_pretrained(model_name_1)
model_1 = GPT2LMHeadModel.from_pretrained(model_name_1).to(device)
model_1.eval()

# Load the GPT-2-XL model and tokenizer
model_name_2 = 'gpt2-xl'
tokenizer_2 = GPT2TokenizerFast.from_pretrained(model_name_2)
model_2 = GPT2LMHeadModel.from_pretrained(model_name_2).to(device)
model_2.eval()

# Compute perplexities
num_samples = 500  # number of samples; adjust to your compute budget
sample_length = 100  # maximum character length per sample

ppl_values_1 = []
ppl_values_2 = []

for i in range(num_samples):
    text = texts[i]
    if len(text.strip()) == 0:
        continue
    text = text.strip()[:sample_length]

    # Perplexity under GPT-2
    ppl_1 = calculate_perplexity(model_1, tokenizer_1, text)
    # Perplexity under GPT-2-XL
    ppl_2 = calculate_perplexity(model_2, tokenizer_2, text)

    # Filter out outliers
    if 1 < ppl_1 < 1e4 and 1 < ppl_2 < 1e4:
        ppl_values_1.append(ppl_1)
        ppl_values_2.append(ppl_2)

# Take the log of the perplexities
log_ppl_values_1 = np.log(ppl_values_1)
log_ppl_values_2 = np.log(ppl_values_2)

# Kernel density estimate of the log-perplexity distributions
plt.figure(figsize=(10, 6))

sns.kdeplot(log_ppl_values_1, fill=True, bw_adjust=0.5, color='blue', label='GPT-2')
sns.kdeplot(log_ppl_values_2, fill=True, bw_adjust=0.5, color='red', label='GPT-2-XL')

plt.xlabel('Log Perplexity')
plt.ylabel('Density')
plt.title('Comparison of Log Perplexity Distributions')
plt.legend()
plt.show()


In this plot, the larger model's log-perplexity distribution has a lower mean and a thinner tail; lower PPL is better.

In the paper's PPL figure, perplexity is computed on the wiki dataset across recursive generations: the Generation-0 distribution is fairly regular, but generation after generation the mean of the distribution drifts and a pronounced long tail appears. If 10% of the real data is preserved, however, the collapse is much slower.


2 O1-like Step-by-Step Slow Thinking with an Open-Source LLM (ollama, streamlit)

References:

  • https://github.com/bklieger-groq/g1 (open-source project)
  • 2 classic queries (questions models easily get wrong)
    • Which is larger, 0.9 or 0.11? (GPT-4o answers this correctly only about 30% of the time, but the prompting trick below raises that to about 70%)
    • How many Rs are in strawberry?

2.1 ollama(Structured outputs)

  • Download the latest version of ollama, then pip install -U ollama
    • https://ollama.com/blog/structured-outputs
  • Releasing resources:
    • even after an ollama run llama3.1 session is ended with /bye, the model is not unloaded automatically;
    • curl http://localhost:11434/api/generate -d '{"model": "qwen2.5", "keep_alive": 0}'
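
The same unload request can be sent from Python; a minimal sketch using requests against the endpoint from the curl command above (qwen2.5 is just the model name used in that example):

import requests

# Ask the local ollama server to unload the model immediately (keep_alive=0).
resp = requests.post(
    "http://localhost:11434/api/generate",
    json={"model": "qwen2.5", "keep_alive": 0},
)
print(resp.status_code, resp.text[:200])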

2.2 dynamic CoT(o1-like CoT)

"""You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys. USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3. BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO. IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS. CONSIDER YOU MAY BE WRONG, AND IF YOU ARE WRONG IN YOUR REASONING, WHERE IT WOULD BE. FULLY TEST ALL OTHER POSSIBILITIES. YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO. DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.

Example of a valid JSON response:
```json
{
    "title": "Identifying Key Information",
    "content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
    "next_action": "continue"
}```
"""
  • messages is conversational:
    • system
    • user (query)
    • assistant (seeded)
    • assistant (step by step)
    • assistant (step by step)
  • Each new step is appended to messages, which is what makes the reasoning process dynamic.

2.3 streamlit run

  • Run from the command line: streamlit run struct_llama_reasoning_app.py

llama_reasoning_app.py

import streamlit as st
import ollama
import os
import json
import time

def make_api_call(messages, max_tokens, is_final_answer=False):
    for attempt in range(3):
        try:
            response = ollama.chat(
                model="llama3.1:latest",
                messages=messages,
                options={"temperature":0.2, "num_predict":max_tokens},
                format='json',
            )
            return json.loads(response['message']['content'])
        except Exception as e:
            if attempt == 2:
                if is_final_answer:
                    return {"title": "Error", "content": f"Failed to generate final answer after 3 attempts. Error: {str(e)}"}
                else:
                    return {"title": "Error", "content": f"Failed to generate step after 3 attempts. Error: {str(e)}", "next_action": "final_answer"}
            time.sleep(1)  # Wait for 1 second before retrying

def generate_response(prompt):
    messages = [
        {"role": "system", "content": """You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys. USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3. BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO. IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS. CONSIDER YOU MAY BE WRONG, AND IF YOU ARE WRONG IN YOUR REASONING, WHERE IT WOULD BE. FULLY TEST ALL OTHER POSSIBILITIES. YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO. DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.

Example of a valid JSON response:
```json
{
    "title": "Identifying Key Information",
    "content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
    "next_action": "continue"
}```
"""},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
    ]
    
    steps = []
    step_count = 1
    total_thinking_time = 0
    
    while True:
        start_time = time.time()
        step_data = make_api_call(messages, 300)
        end_time = time.time()
        thinking_time = end_time - start_time
        total_thinking_time += thinking_time
        
        steps.append((f"Step {step_count}: {step_data['title']}", step_data['content'], thinking_time))
        
        messages.append({"role": "assistant", "content": json.dumps(step_data)})
        
        if step_data['next_action'] == 'final_answer' or step_count > 25: # Maximum of 25 steps to prevent infinite thinking time. Can be adjusted.
            break
        
        step_count += 1

        # Yield after each step for Streamlit to update
        yield steps, None  # We're not yielding the total time until the end

    # Generate final answer
    messages.append({"role": "user", "content": "Please provide the final answer based on your reasoning above."})
    
    start_time = time.time()
    final_data = make_api_call(messages, 200, is_final_answer=True)
    end_time = time.time()
    thinking_time = end_time - start_time
    total_thinking_time += thinking_time
    
    steps.append(("Final Answer", final_data['content'], thinking_time))

    yield steps, total_thinking_time

def main():
    st.set_page_config(page_title="g1 prototype", page_icon="🧠", layout="wide")
    
    st.title("g1: Using Llama-3.1 8b on local to create o1-like reasoning chains")
    
    st.markdown("""
    This is an early prototype of using prompting to create o1-like reasoning chains to improve output accuracy. It is not perfect and accuracy has yet to be formally evaluated. It is powered by Ollama.
                
    Open source [repository here](https://github.com/bklieger-groq)
    """)
    
    # Text input for user query
    user_query = st.text_input("Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")
    
    if user_query:
        st.write("Generating response...")
        
        # Create empty elements to hold the generated text and total time
        response_container = st.empty()
        time_container = st.empty()
        
        # Generate and display the response
        for steps, total_thinking_time in generate_response(user_query):
            with response_container.container():
                for i, (title, content, thinking_time) in enumerate(steps):
                    if title.startswith("Final Answer"):
                        st.markdown(f"### {title}")
                        st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
                    else:
                        with st.expander(title, expanded=True):
                            st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
            
            # Only show total time when it's available at the end
            if total_thinking_time is not None:
                time_container.markdown(f"**Total thinking time: {total_thinking_time:.2f} seconds**")

if __name__ == "__main__":
    main()

struct_llama_reasoning_app.py

import streamlit as st
import ollama
import os
import json
import time

from pydantic import BaseModel
from typing import Literal

class ReasoningStep(BaseModel):
    title: str
    content: str
    next_action: Literal["continue", "final_answer"]

class FinalAnswer(BaseModel):
    title: str
    content: str

def make_api_call(messages, max_tokens, is_final_answer=False):
    for attempt in range(3):
        try:
            format_schema = ReasoningStep if not is_final_answer else FinalAnswer
            response = ollama.chat(
                model="llama3.1:latest",
                messages=messages,
                options={"temperature":0.2, "num_predict":max_tokens},
                format=format_schema.model_json_schema(),
            )
            return format_schema.model_validate_json(response.message.content)
        except Exception as e:
            if attempt == 2:
                if is_final_answer:
                    return FinalAnswer(title="Error", content=f"Failed to generate final answer after 3 attempts. Error: {str(e)}")
                else:
                    return ReasoningStep(title="Error", 
                                         content=f"Failed to generate step after 3 attempts. Error: {str(e)}", next_action="final_answer")
            time.sleep(1)  # Wait for 1 second before retrying

def generate_response(prompt):
    messages = [
        {"role": "system", "content": """You are an expert AI assistant that explains your reasoning step by step. For each step, provide a title that describes what you're doing in that step, along with the content. Decide if you need another step or if you're ready to give the final answer. Respond in JSON format with 'title', 'content', and 'next_action' (either 'continue' or 'final_answer') keys. USE AS MANY REASONING STEPS AS POSSIBLE. AT LEAST 3. BE AWARE OF YOUR LIMITATIONS AS AN LLM AND WHAT YOU CAN AND CANNOT DO. IN YOUR REASONING, INCLUDE EXPLORATION OF ALTERNATIVE ANSWERS. CONSIDER YOU MAY BE WRONG, AND IF YOU ARE WRONG IN YOUR REASONING, WHERE IT WOULD BE. FULLY TEST ALL OTHER POSSIBILITIES. YOU CAN BE WRONG. WHEN YOU SAY YOU ARE RE-EXAMINING, ACTUALLY RE-EXAMINE, AND USE ANOTHER APPROACH TO DO SO. DO NOT JUST SAY YOU ARE RE-EXAMINING. USE AT LEAST 3 METHODS TO DERIVE THE ANSWER. USE BEST PRACTICES.

Example of a valid JSON response:
```json
{
    "title": "Identifying Key Information",
    "content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
    "next_action": "continue"
}```
"""},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
    ]
    
    steps = []
    step_count = 1
    total_thinking_time = 0
    
    while True:
        start_time = time.time()
        step_data = make_api_call(messages, 300)
        end_time = time.time()
        thinking_time = end_time - start_time
        total_thinking_time += thinking_time
        
        steps.append((f"Step {step_count}: {step_data.title}", step_data.content, thinking_time))
        
        messages.append({"role": "assistant", "content": step_data.model_dump_json()})
        
        if step_data.next_action == 'final_answer' or step_count > 25: # Maximum of 25 steps to prevent infinite thinking time. Can be adjusted.
            break
        
        step_count += 1

        # Yield after each step for Streamlit to update
        yield steps, None  # We're not yielding the total time until the end

    for msg in messages:
        print(msg['role'], msg['content'][:20])
        
    # Generate final answer
    messages.append({"role": "user", "content": "Please provide the final answer based on your reasoning above."})
    
    start_time = time.time()
    final_data = make_api_call(messages, 200, is_final_answer=True)
    end_time = time.time()
    thinking_time = end_time - start_time
    total_thinking_time += thinking_time
    
    steps.append(("Final Answer", final_data.content, thinking_time))

    yield steps, total_thinking_time

def main():
    st.set_page_config(page_title="g1 prototype", page_icon="🧠", layout="wide")
    
    st.title("g1: Using Llama-3.1 8b on local to create o1-like reasoning chains")
    
    st.markdown("""
    This is an early prototype of using prompting to create o1-like reasoning chains to improve output accuracy. It is not perfect and accuracy has yet to be formally evaluated. It is powered by Ollama.
                
    Open source [repository here](https://github.com/bklieger-groq)
    """)
    
    # Text input for user query
    user_query = st.text_input("Enter your query:", placeholder="e.g., How many 'R's are in the word strawberry?")
    
    if user_query:
        st.write("Generating response...")
        
        # Create empty elements to hold the generated text and total time
        response_container = st.empty()
        time_container = st.empty()
        
        # Generate and display the response
        for steps, total_thinking_time in generate_response(user_query):
            with response_container.container():
                for i, (title, content, thinking_time) in enumerate(steps):
                    if title.startswith("Final Answer"):
                        st.markdown(f"### {title}")
                        st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
                    else:
                        with st.expander(title, expanded=True):
                            st.markdown(content.replace('\n', '<br>'), unsafe_allow_html=True)
            
            # Only show total time when it's available at the end
            if total_thinking_time is not None:
                time_container.markdown(f"**Total thinking time: {total_thinking_time:.2f} seconds**")

if __name__ == "__main__":
    main()

Note this part of the code above:

Example of a valid JSON response:
```json
{
    "title": "Identifying Key Information",
    "content": "To begin solving this problem, we need to carefully examine the given information and identify the crucial elements that will guide our solution process. This involves...",
    "next_action": "continue"
}```
"""},
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": "Thank you! I will now think step by step following my instructions, starting at the beginning after decomposing the problem."}
    ]

An assistant message is seeded into the conversation, which is a very clever prompting trick; the all-caps passages above are used for emphasis (stressing that the model must output its reasoning steps).

system + user + assistant (seeded)


3 [LLM+RL] Beam Search Decoding Strategy in model.generate

References:

  • https://huggingface.co/blog/how-to-generate

  • https://huggingface.co/blog/constrained-beam-search

  • https://huggingface.co/spaces/HuggingFaceH4/blogpost-scaling-test-time-compute

  • https://github.com/huggingface/search-and-learn

  • Generate multiple candidate solutions iteratively by maintaining a fixed number of “beams” or active paths $N$;

  • In the first iteration, sample $N$ independent steps from the LLM with temperature $T$ to introduce diversity in the responses. These steps are usually defined by a stopping criterion like terminating on a new line \n or double new line \n\n.

  • Score each step with the PRM and select the top $N/M$ steps as candidates for the next round of generation. Here $M$ denotes the “beam width” of a given active path. As in Best-of-N, we used the “last” reduction to score the partial solutions at each iteration.

  • Expand the steps selected in step (3) by sampling $M$ next steps in the solution.

  • Repeat steps (3) and (4) until the EOS token is reached or the maximum search depth is exceeded.

  • greedy search => beam search

    • greedy search: keep only the top-1 logit token
      • [batch_size, seq_length (growing)]
    • beam search: keep more candidates, i.e. the beam width
      • [batch_size * num_beams, seq_length (growing)]
  • model(input_ids): a single step;

  • model.generate(input_ids): many steps, autoregressive generation;

    • max_length: input length + newly generated tokens
from transformers import AutoTokenizer, AutoModelForCausalLM

prefixes = ["Once upon a time", "Hi I am a"]
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

input_ids = tokenizer(prefixes, return_tensors="pt").input_ids
output_ids = model.generate(input_ids, num_beams=3, max_length=20)

output_text = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
for text in output_text:
    print(text)

This prints the decoded results; you just set the num_beams argument, which is very simple.

greedy_output = model.generate(input_ids, max_length=20)
greedy_output
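
The same greedy behaviour can be reproduced manually by chaining single-step model(input_ids) calls; a minimal sketch reusing the model and tokenizer loaded above:

import torch

ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
for _ in range(16):
    with torch.no_grad():
        logits = model(ids).logits                              # one step: [batch, seq_len, vocab]
    next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)     # keep only the top-1 token
    ids = torch.cat([ids, next_id], dim=1)                      # the sequence grows by one token
print(tokenizer.decode(ids[0]))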

We can also trace beam search step by step:

  • $\log p_1+\log p_2=\log (p_1\cdot p_2)$
import torch
import torch.nn.functional as F

def show_beam_search_steps(model, tokenizer, prefix, num_beams=3, max_steps=3):
    # Convert the input text to token ids
    input_ids = tokenizer(prefix, return_tensors="pt").input_ids
    
    # Initialize the beam state
    current_beams = [(input_ids, 0)]  # (sequence, score)
    
    print(f"\nProcessing prefix: '{prefix}'")
    
    # Run beam search step by step
    for step in range(max_steps):
        candidates = []
        print(f"\nStep {step + 1}:")
        
        # Expand each current beam
        for beam_ids, beam_score in current_beams:
            # Forward pass
            with torch.no_grad():
                outputs = model(beam_ids)
                next_token_logits = outputs.logits[:, -1, :]
                next_token_probs = F.softmax(next_token_logits, dim=-1)
            
            # Take the num_beams most likely next tokens
            values, indices = torch.topk(next_token_probs, num_beams)
            
            # Create a new candidate for each possible next token
            for value, index in zip(values[0], indices[0]):
                new_ids = torch.cat([beam_ids, index.unsqueeze(0).unsqueeze(0)], dim=1)
                new_score = beam_score + torch.log(value).item()
                candidates.append((new_ids, new_score))
                
                # Print the candidate
                new_text = tokenizer.decode(new_ids[0])
                print(f"Candidate: {new_text}({new_ids[0].tolist()}) score: {new_score:.4f}")
        
        # Keep the best num_beams candidates
        candidates.sort(key=lambda x: x[1], reverse=True)
        current_beams = candidates[:num_beams]
        print("\nSelected beams:")
        for beam_ids, beam_score in current_beams:
            print(f"beam: {tokenizer.decode(beam_ids[0])}({beam_ids[0].tolist()}) score: {beam_score:.4f}")
            
show_beam_search_steps(model, tokenizer, prefixes[0])

Output:


Processing prefix: 'Once upon a time'

Step 1:
Candidate: Once upon a time,([7454, 2402, 257, 640, 11]) score: -0.8512
Candidate: Once upon a time the([7454, 2402, 257, 640, 262]) score: -2.7396
Candidate: Once upon a time I([7454, 2402, 257, 640, 314]) score: -3.2029

Selected beams:
beam: Once upon a time,([7454, 2402, 257, 640, 11]) score: -0.8512
beam: Once upon a time the([7454, 2402, 257, 640, 262]) score: -2.7396
beam: Once upon a time I([7454, 2402, 257, 640, 314]) score: -3.2029

Step 2:
Candidate: Once upon a time, the([7454, 2402, 257, 640, 11, 262]) score: -3.0524
Candidate: Once upon a time, I([7454, 2402, 257, 640, 11, 314]) score: -3.6055
Candidate: Once upon a time, it([7454, 2402, 257, 640, 11, 340]) score: -4.0718
Candidate: Once upon a time the world([7454, 2402, 257, 640, 262, 995]) score: -6.5612
Candidate: Once upon a time the sun([7454, 2402, 257, 640, 262, 4252]) score: -7.6559
Candidate: Once upon a time the people([7454, 2402, 257, 640, 262, 661]) score: -7.7589
Candidate: Once upon a time I was([7454, 2402, 257, 640, 314, 373]) score: -4.8048
Candidate: Once upon a time I had([7454, 2402, 257, 640, 314, 550]) score: -5.7436
Candidate: Once upon a time I thought([7454, 2402, 257, 640, 314, 1807]) score: -6.5309

Selected beams:
beam: Once upon a time, the([7454, 2402, 257, 640, 11, 262]) score: -3.0524
beam: Once upon a time, I([7454, 2402, 257, 640, 11, 314]) score: -3.6055
beam: Once upon a time, it([7454, 2402, 257, 640, 11, 340]) score: -4.0718

Step 3:
Candidate: Once upon a time, the world([7454, 2402, 257, 640, 11, 262, 995]) score: -7.0757
Candidate: Once upon a time, the people([7454, 2402, 257, 640, 11, 262, 661]) score: -8.2539
Candidate: Once upon a time, the two([7454, 2402, 257, 640, 11, 262, 734]) score: -8.3031
Candidate: Once upon a time, I was([7454, 2402, 257, 640, 11, 314, 373]) score: -5.5660
Candidate: Once upon a time, I had([7454, 2402, 257, 640, 11, 314, 550]) score: -6.2778
Candidate: Once upon a time, I would([7454, 2402, 257, 640, 11, 314, 561]) score: -6.8437
Candidate: Once upon a time, it was([7454, 2402, 257, 640, 11, 340, 373]) score: -5.1921
Candidate: Once upon a time, it seemed([7454, 2402, 257, 640, 11, 340, 3947]) score: -6.7970
Candidate: Once upon a time, it would([7454, 2402, 257, 640, 11, 340, 561]) score: -6.8182

Selected beams:
beam: Once upon a time, it was([7454, 2402, 257, 640, 11, 340, 373]) score: -5.1921
beam: Once upon a time, I was([7454, 2402, 257, 640, 11, 314, 373]) score: -5.5660
beam: Once upon a time, I had([7454, 2402, 257, 640, 11, 314, 550]) score: -6.2778

We can try the other prefix as well:

show_beam_search_steps(model, tokenizer, prefixes[1])
"""

Processing prefix: 'Hi I am a'

Step 1:
Candidate: Hi I am a big([17250, 314, 716, 257, 1263]) score: -3.8471
Candidate: Hi I am a very([17250, 314, 716, 257, 845]) score: -4.0766
Candidate: Hi I am a little([17250, 314, 716, 257, 1310]) score: -4.1127

Selected beams:
beam: Hi I am a big([17250, 314, 716, 257, 1263]) score: -3.8471
beam: Hi I am a very([17250, 314, 716, 257, 845]) score: -4.0766
beam: Hi I am a little([17250, 314, 716, 257, 1310]) score: -4.1127

Step 2:
Candidate: Hi I am a big fan([17250, 314, 716, 257, 1263, 4336]) score: -4.2283
Candidate: Hi I am a big believer([17250, 314, 716, 257, 1263, 29546]) score: -7.1364
Candidate: Hi I am a big supporter([17250, 314, 716, 257, 1263, 15525]) score: -8.3071
Candidate: Hi I am a very good([17250, 314, 716, 257, 845, 922]) score: -6.7408
Candidate: Hi I am a very nice([17250, 314, 716, 257, 845, 3621]) score: -7.1981
Candidate: Hi I am a very happy([17250, 314, 716, 257, 845, 3772]) score: -7.3774
Candidate: Hi I am a little bit([17250, 314, 716, 257, 1310, 1643]) score: -6.2787
Candidate: Hi I am a little confused([17250, 314, 716, 257, 1310, 10416]) score: -7.0489
Candidate: Hi I am a little disappointed([17250, 314, 716, 257, 1310, 11679]) score: -7.2741

Selected beams:
beam: Hi I am a big fan([17250, 314, 716, 257, 1263, 4336]) score: -4.2283
beam: Hi I am a little bit([17250, 314, 716, 257, 1310, 1643]) score: -6.2787
beam: Hi I am a very good([17250, 314, 716, 257, 845, 922]) score: -6.7408

Step 3:
Candidate: Hi I am a big fan of([17250, 314, 716, 257, 1263, 4336, 286]) score: -4.3084
Candidate: Hi I am a big fan and([17250, 314, 716, 257, 1263, 4336, 290]) score: -8.1861
Candidate: Hi I am a big fan.([17250, 314, 716, 257, 1263, 4336, 13]) score: -8.3988
Candidate: Hi I am a little bit of([17250, 314, 716, 257, 1310, 1643, 286]) score: -8.6324
Candidate: Hi I am a little bit worried([17250, 314, 716, 257, 1310, 1643, 7960]) score: -9.4857
Candidate: Hi I am a little bit older([17250, 314, 716, 257, 1310, 1643, 4697]) score: -9.5333
Candidate: Hi I am a very good person([17250, 314, 716, 257, 845, 922, 1048]) score: -9.3998
Candidate: Hi I am a very good friend([17250, 314, 716, 257, 845, 922, 1545]) score: -9.8805
Candidate: Hi I am a very good student([17250, 314, 716, 257, 845, 922, 3710]) score: -10.3733

Selected beams:
beam: Hi I am a big fan of([17250, 314, 716, 257, 1263, 4336, 286]) score: -4.3084
beam: Hi I am a big fan and([17250, 314, 716, 257, 1263, 4336, 290]) score: -8.1861
beam: Hi I am a big fan.([17250, 314, 716, 257, 1263, 4336, 13]) score: -8.3988
"""

4 [LLM + RL] Kimi k1.5 Paper Walkthrough and Highlights

https://github.com/chunhuizhang/llm_rl/blob/main/tutorials/r1-k1.5/k1.5.ipynb

  • k1.5 and R1
    • Chinese reasoning models: better understanding and use of Chinese, and both are domestic standouts;
      • explorers and reproducers of the OpenAI reasoning-LLM approach
      • i.e., how to use RL to better elicit long-chain reasoning in LLMs;
    • k1.5 is probably the first technical report Kimi has released
    • k1.5 and R1 are best read side by side: they complement each other and cross-validate several findings
      • k1.5 gives richer, more comprehensive details;
      • both drop MCTS, value functions, and PRMs (process reward models) in pursuit of simplicity and scaling;
        • Simplistic Framework.

(What was dropped is not necessarily bad.)

  • test cases
    • Using the numbers {1,3,5,37}, create an equation that equals {24}. You can use basic arithmetic operations (+, -, *, /) and each number can only be used once. (An unsolvable problem? In fact, CoT struggles with many problems that can only be settled by enumeration, such as dynamic-programming-style puzzles; those are really a job for a program rather than something humans solve by thinking.)
      • Alright, ... Let me think ... Hmm, ... Alternatively, ... Wait,
    • Using the ci form "Yi Qin E", compose a poem on the theme of homesickness.
      • Not merely a language task: it carries many hard constraints, such as tonal patterns, rhyme, and repetition;
    • The school organizes an outing. The class monitor brings at most 150 packs of wet wipes. Split evenly among 40 people there are 7 packs left over; split among 25 people, 2 packs are left over. How many packs did the monitor bring?
      • 127: 127 = 40*3 + 7; 127 = 25*5 + 2 (this one is already easy; models generally get it)

1 reasoning models

  • think aloud: making the thinking process concrete and visible (a very nice concept)
    • thinking process
    • thinking steps
    • thinking tokens
  • long context, long cot;
  • emergence
    • planning, evaluation, reflection, exploration

2 data

Where does the data come from?
How is data quality defined?

  • data quality
    • Diverse Coverage;
    • Balanced Difficulty;
    • Accurate Evaluability; => reward design;
      • rule-based reward design (vs. the learned/neural reward modeling used in RLHF); a minimal sketch follows this list
      • $r(x,y,y^\star)$
  • Where the data comes from (very, very much an engineering effort)
    • data triplets
      • questions
        • we employ automatic filters to select questions that require rich reasoning and are straightforward to evaluate.
      • cot
      • answer
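
A minimal sketch of what a rule-based reward $r(x,y,y^\star)$ could look like for math-style answers (the number-extraction regex and exact-match rule are illustrative assumptions, not the paper's actual implementation):

import re

def rule_based_reward(completion: str, reference_answer: str) -> float:
    """Return 1.0 if the last number in the completion matches the reference answer."""
    numbers = re.findall(r"-?\d+\.?\d*", completion)
    if not numbers:
        return 0.0
    return 1.0 if numbers[-1] == reference_answer else 0.0

print(rule_based_reward("... so the monitor brought 127 packs.", "127"))  # 1.0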

The RL algorithm

  • PG (Policy gradient)
    • TRPO, PPO (GRPO)

What k1.5 uses is not quite the same as R1.

Improved policy optimization

PG => PPO;

$$\mathbb E_{\tau\sim \pi_\theta}\left[\nabla_\theta \log\pi_\theta(\tau)\cdot R(\tau)\right]$$

  • On-policy: the agent being learned and the agent interacting with the env are the same;
    • learn while doing; low data efficiency;
    • policy-gradient (on-policy)
  • Off-policy: the agent being learned and the agent interacting with the env are different;
    • learn by watching someone else act; higher data efficiency;
      $$\mathbb E_{\tau\sim \pi_{\theta'}}\left[\nabla_\theta \log\pi_\theta(\tau)\cdot R(\tau)\frac{\pi_\theta(\tau)}{\pi_{\theta'}(\tau)}\right]$$
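
To make the importance ratio concrete, here is a tiny sketch that estimates an expectation under one distribution from samples drawn by another (plain importance sampling, not a full off-policy RL update):

import torch
import torch.distributions as dist

pi_theta = dist.Normal(0.0, 1.0)   # the policy we care about
pi_old = dist.Normal(0.5, 1.0)     # the policy that generated the data

x = pi_old.sample((100_000,))      # "off-policy" samples
ratio = torch.exp(pi_theta.log_prob(x) - pi_old.log_prob(x))

f = x ** 2                         # any function of the samples
print((ratio * f).mean())          # importance-weighted estimate of E_{pi_theta}[x^2], roughly 1
print(pi_theta.sample((100_000,)).pow(2).mean())  # on-policy estimate, for comparison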

5 [LLM+RL] R1 Paper Walkthrough: SFT vs. RL, RL Basics and GRPO Details, plus a Discussion of Reproduction Efforts

Outline:

  • DeepSeekMath

    • Data Pipeline
    • GRPO
  • R1

    • R1-Zero (AlphaZero-style)
    • distillation
  • kimi-1.5: first tech report

    • cross-validates and complements R1
    • Long2short Distillation
  • S1

    • analogous to DeepSeek's distillation process
  • Reasoning Models (Think aloud)

    • reasoning tokens/steps/processes
      • Long-CoT
    • Expert CoT => Learned CoT
  • Data Curation

  • SFT v.s. RL

    • teach vs. incentivize (i.e., do not teach the LLM directly; induce and stimulate it instead)
    • RL highlights & roadmap
  • Unified Paradigm

    • PG: SFT, TRPO, DPO, GRPO
  • Reproduction: data, algorithms, infra(structure)

More thinking time brings higher accuracy (OpenAI made the same claim earlier about test-time thinking); the figure in the R1 paper illustrates this, and R1 directly surpasses o1.

Likewise, the more the model thinks, the longer the output; the Kimi paper reaches a similar conclusion.

In short: longer thinking => longer output => higher accuracy.

If the LLM reasons in a one-shot fashion, pushing a single line of thought to the end, it is very likely to be wrong; that is why the context window has to be large enough during RL training.

Below is a word-riddle example:


The model keeps trying and verifying until it reaches a solution that the reward model scores highly.

This raises a question, the classic "biggest ear of wheat in the field" problem: when should the search stop? Should it backtrack to an earlier solution? This could be a research direction.

The prompt template for this reflection mechanism is as follows:


In the second dialogue round, only the first round's input and output are fed back; the CoT is hidden because it is too long.


On to the second slide.


These points are from the OpenAI o1 launch event:

Hard problems should not consume the same amount of compute as easy ones.

Rumor has it that OpenAI hired PhDs from top schools to annotate reasoning processes (at several hundred dollars an hour).

If humanity were coordinated enough, say a million experts each contributing 100 reasoning traces, that amount of data should be enough for a model to imitate human reasoning.

But letting the LLM polish its own CoT with RL works better; OpenAI o1 researchers have said that the CoT the LLM refines is even better than human-annotated CoT.


A few papers related to data curation.
DeepSeekMath has two main focuses:

  • construction of the seed dataset
  • GRPO


The seed dataset can be GSM8K.

  1. FastText model: decide whether a page is math content
  2. then crawl for more data
  3. annotate: once a site is labeled as math content, every resource under that URL is treated as math content

The aim is domain relevance, and building the corpus up quickly.

This is a piece of systems engineering, and it has a lot of value as a reference.
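
A minimal sketch of how such a binary "is this math content?" filter could be trained with the fasttext library (the file name, labels, and threshold are hypothetical, not DeepSeekMath's actual pipeline):

import fasttext

# math_filter.train: one labeled example per line, e.g.
#   __label__math      Prove that the sum of two even integers is even ...
#   __label__not_math  Top 10 travel destinations for 2024 ...
model = fasttext.train_supervised(input="math_filter.train", epoch=5, wordNgrams=2)

labels, probs = model.predict("Solve x^2 - 5x + 6 = 0", k=1)
if labels[0] == "__label__math" and probs[0] > 0.9:
    print("keep this page for the math corpus")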

Later in this session there is some discussion of GRPO. Its implementation is currently somewhat debated: it looks simple, but it needs a triple loop, so it may not be all that efficient.


6 [LLM+RL] Understanding the GRPO Formula and the TRL GRPOTrainer Implementation (advantage and loss computation)

https://github.com/chunhuizhang/llm_rl/blob/main/tutorials/r1-k1.5/trl_grpo.ipynb

RL roadmap

  • Basic concepts: on-policy vs. off-policy
    • understand the full objective ($\pi_\theta, \pi_{\theta_{old}}, \pi_{ref}$) and how it is computed;
    • which probability distribution is sampled to obtain the data;
  • GRPO <= PPO(CLIP) <= TRPO <= PG (policy gradient)
    • PPO: GAE (Generalized Advantage Estimation), TD-error

$$\nabla_\theta J(\pi_\theta)=\mathbb E_{\tau \sim (\pi_\theta, T)}\left[\sum_{t=0}^T\nabla_\theta \log\pi_\theta(a_t|s_t)R(\tau)\right]$$
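
A toy REINFORCE-style sketch of this gradient, using a two-armed bandit with a softmax policy (purely illustrative; this is not the GRPO trainer):

import torch

logits = torch.zeros(2, requires_grad=True)     # policy parameters theta
rewards = torch.tensor([1.0, 0.0])              # arm 0 pays off, arm 1 does not
opt = torch.optim.SGD([logits], lr=0.5)

for _ in range(200):
    probs = torch.softmax(logits, dim=-1)
    a = torch.multinomial(probs, 1).item()      # sample an action from pi_theta
    loss = -torch.log(probs[a]) * rewards[a]    # -log pi(a) * R, the REINFORCE loss
    opt.zero_grad(); loss.backward(); opt.step()

print(torch.softmax(logits, dim=-1))            # probability mass shifts toward arm 0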

grpo demo

dataset: 7473;

https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb#file-grpo_demo-py

  • Single 4090:
    export CUDA_VISIBLE_DEVICES=0
    python grpo_demo.py
    
    • 4/4 * 4 => 4,
      • 1868;
    • about 13 hours;
  • Dual 4090s, DDP (accelerate map)
    • (4/4 * 4 * 2) => 8;
      • 934;
    • under 10 hours;
  • Dual 4090s (accelerate config)
    • deepspeed stage-2/3;
    • fsdp

GRPOConfig & GRPOTrainer

  • GRPOConfig
    • num_generations=8,
    • old
      • per_device_train_batch_size=1, * gradient_accumulation_steps=8,
        • per_device_train_batch_size * gradient_accumulation_steps * world_size ==> train_batch
    • new: https://github.com/huggingface/trl/pull/2776#issue-2833772774
      • per_device_train_batch_size:
        • it now represents the number of generations per device.
      • per_device_train_batch_size/num_generations * gradient_accumulation_steps
        • per_device_train_batch_size/num_generations: prompts per device
        • hence per_device_train_batch_size must be divisible by num_generations;
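
A small arithmetic check of the new semantics described above (all values are hypothetical):

per_device_train_batch_size = 8      # now: completions per device
num_generations = 8                  # completions sampled per prompt
gradient_accumulation_steps = 4
world_size = 2                       # number of GPUs

prompts_per_device = per_device_train_batch_size // num_generations          # 1
prompts_per_update = prompts_per_device * gradient_accumulation_steps * world_size
print(prompts_per_device, prompts_per_update)                                # 1 8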

GRPO

  • GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training.
  • four main steps:
    • Generating completions,
      • At each training step, we sample a batch of prompts and generate a set of $G$ completions (num_generations) for each prompt (denoted $o_i$).
    • computing the advantage,
      • $\hat A_{i,t}=\frac{r_i-\mu(\mathbf r)}{\sigma(\mathbf r)}$ (a numerical sketch follows the loss snippet below)
      • Outcome supervision provides the normalized reward at the end of each output $o_i$ and sets the advantages $\hat A_{i,t}$ of all tokens in the output to that normalized reward
    • estimating the KL divergence, (token-level see the figure)
      • https://huggingface.co/docs/trl/main/en/grpo_trainer#estimating-the-kl-divergence
      • per_token_kl = torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
    • and computing the loss.
      • $\pi_{ref}, (\pi_{\theta_{old}}, \pi_\theta)$
      • https://github.com/huggingface/trl/issues/2608
# x - x.detach() allows for preserving gradients from x
per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
per_token_loss = -(per_token_loss - self.beta * per_token_kl)
loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
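
A numerical sketch of the group-normalized advantage $\hat A_{i,t}$ described above (the rewards are made up, and the real trainer also sums several reward functions and broadcasts the result over completion tokens):

import torch

num_generations = 4
# hypothetical scalar rewards for 2 prompts x 4 completions each, flattened
rewards = torch.tensor([1.0, 0.0, 0.0, 1.0,   0.0, 0.0, 1.0, 0.0])

grouped = rewards.view(-1, num_generations)              # (n_prompts, G)
mean = grouped.mean(dim=1, keepdim=True)
std = grouped.std(dim=1, keepdim=True)
advantages = ((grouped - mean) / (std + 1e-4)).view(-1)  # one advantage per completion
print(advantages)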

pipeline

  • grpo_trainer.py
    • _prepare_inputs()
    return {
        "prompt_ids": prompt_ids,
        "prompt_mask": prompt_mask,
        "completion_ids": completion_ids,
        "completion_mask": completion_mask,
        "ref_per_token_logps": ref_per_token_logps,
        "advantages": advantages,
    }
    
    • compute_loss()

qwen2.5 vocab size: 151936

  • prompts => tokenizer.apply_chat_template => model.generate

    • prompt_completion_ids = prompt_ids + prompt_completion_ids
    • rewards_per_func: (n_prompts, len(reward_funcs))
    • ref_per_token_logps ($\pi_{ref}(q, o)$), per_token_logps ($\pi_\theta(q,o)$)
      • completion token level
        • selective_log_softmax(logits, index) (sketched after this list)
          • logits.shape: (n_prompts, n_completion, n_vocab)
          • index.shape: (n_prompts, n_completion)
          • => (n_prompts, n_completion)
            $\exp(\log{\pi'}-\log{\pi})=\frac{\pi'}{\pi}$
  • The current implementation only keeps $\pi_{\theta}$ and $\pi_{ref}$; there is no separate $\pi_{\theta_{old}}$

    • https://github.com/huggingface/trl/issues/2608
      • The policy model only has a single update following each exploration stage. (deepseekmath)
      • $\pi_\theta$ is updated only once per rollout of a group of generations, rather than several times;
        • this corresponds to "for GRPO iteration = 1, ..., μ" with μ == 1
    • $\pi_{\theta_{old}}=\pi_\theta$
    • torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
      • per_token_logps.detach() is excluded from gradient computation;
    • clip is effectively unused; the term is just $\frac{\pi_\theta}{\pi_{\theta_{old}}}A=1\cdot A$ (ratio * advantage)
      • the ratio equals 1, which is always inside $(1-\epsilon, 1+\epsilon)$;
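
A minimal sketch of what selective_log_softmax does: gather the log-probability of each chosen completion token from the full vocabulary distribution (shapes as in the notes above, with random tensors as placeholders):

import torch
import torch.nn.functional as F

n_prompts, n_completion, n_vocab = 2, 5, 11
logits = torch.randn(n_prompts, n_completion, n_vocab)
index = torch.randint(n_vocab, (n_prompts, n_completion))    # sampled completion token ids

log_probs = F.log_softmax(logits, dim=-1)                    # (n_prompts, n_completion, n_vocab)
per_token_logps = torch.gather(log_probs, -1, index.unsqueeze(-1)).squeeze(-1)
print(per_token_logps.shape)                                 # torch.Size([2, 5])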

training monitor

  • You should rely mostly on the reward. And keep an eye on the generations (risk of reward hacking)
    • https://github.com/huggingface/trl/issues/2703

7 [LLM+RL] The KL Divergence in GRPO: Reverse vs. Forward, and the Unbiased Estimator (Schulman)

https://www.bilibili.com/video/BV1rsAXe7EZs

Beyond doubt, as the GRPO objective is written, the KL divergence in it is the reverse KL.


$$\mathcal{J}_{GRPO}(\theta) = \mathbb{E}_{q \sim P(Q), \{o_i\}_{i=1}^G \sim \pi_{\theta_{old}}(O|q)} \left[ \frac{1}{G} \sum_{i=1}^G \frac{1}{|o_i|} \sum_{t=1}^{|o_i|} \min \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})} \hat{A}_{i,t}, \text{clip} \left( \frac{\pi_\theta(o_{i,t}|q, o_{i,<t})}{\pi_{\theta_{old}}(o_{i,t}|q, o_{i,<t})}, 1-\varepsilon, 1+\varepsilon \right) \hat{A}_{i,t} \right) - \beta D_{KL} (\pi_\theta \| \pi_{ref}) \right] \tag{3}$$

https://timvieira.github.io/blog/post/2014/10/06/kl-divergence-as-an-objective-function

  • The reference GRPO implementation uses the reverse KL divergence, not the forward KL divergence.
    When the source computes the loss, it takes a clipped-advantage term and adds β times a KL divergence.
  • https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py (the latest version of the trl library already implements GRPO)
    • $L=\frac1n\sum\beta D_{KL}(q\|p)+A$
    • $q$ is the new trained model, and $p$ is the original reference model ($q$ is a parameterized distribution that has to be trained)
    • $p$ is the distribution being approached;
  • grpo
    • $q(x) = \pi_{\theta}(o_{i,t}|q,o_{i,<t})$
    • $p(x) = \pi_{ref}(o_{i,t}|q,o_{i,<t})$


$$\mathbb{D}_{KL}[\pi_{\theta}\|\pi_{ref}] = \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_{\theta}(o_{i,t}|q, o_{i,<t})} - \log \frac{\pi_{ref}(o_{i,t}|q, o_{i,<t})}{\pi_{\theta}(o_{i,t}|q, o_{i,<t})} - 1$$

This KL estimator is unbiased and has low variance.

John Schulman’s blog

http://joschu.net/blog/kl-approx.html (blog post)

Let $r=\frac{p(x)}{q(x)}$.

First estimator: $-\log r=\log \frac{q(x)}{p(x)}$

$$kl(q\|p)=\sum_x q(x)\log \frac{q(x)}{p(x)}=\mathbb E_{x\sim q}\left[\log \frac{q(x)}{p(x)}\right]$$

  • it has high-variance, as it’s negative for half of the samples, whereas KL is always positive.
    • the estimate is negative whenever the sampled point falls where $q(x)\lt p(x)$;

Why do roughly half of the samples of $\log \frac{q(x)}{p(x)}$ turn out negative? Here is a quick check:

import torch.distributions as dis
import torch
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
torch.sum((q.log_prob(x) - p.log_prob(x)) < 0) # tensor(4799598)

Unbiased, but the variance is too large.

Second estimator: $\frac12(\log r)^2$

This one is more involved; see John Schulman's blog post for the details: http://joschu.net/blog/kl-approx.html

Third estimator: $r-1-\log r$

This is also the estimator that GRPO adopts.

  • $kl(q\|p)=r-1-\log r$ (without $q$)
  • It is guaranteed non-negative because $\log x\leq x-1$; this matters, since the KL divergence is always non-negative. A derivation:

$$\begin{split} D_{KL}(q\|p)&=\sum q\log \frac{q}{p}\\ &=-\sum q\log\frac{p}{q}\\ &=1-\sum q\log\frac{p}{q}-1\\ &=\sum p-\sum q\log\frac{p}{q}-\sum q\\ &=\sum q\frac{p}{q}-\sum q\log\frac{p}{q}-\sum q\\ &=\sum q\left[\frac{p}{q}-\log\frac{p}{q}-1\right] \end{split}$$

import matplotlib.pyplot as plt
import numpy as np
xs = np.arange(0.01, 5, 0.01)
plt.plot(xs, np.log(xs), label=r'$\log x$')   # plot against x so the inequality is visible pointwise
plt.plot(xs, xs - 1, label='$x-1$')
plt.legend()

The plot above illustrates $\log x \leq x - 1$.

import torch.distributions as dis
p = dis.Normal(loc=0, scale=1)
q = dis.Normal(loc=0.1, scale=1)
x = q.sample(sample_shape=(10_000_000,))
truekl = dis.kl_divergence(q, p)
print("true", truekl)
logr = p.log_prob(x) - q.log_prob(x)
k1 = -logr
k2 = logr ** 2 / 2
k3 = (logr.exp() - 1) - logr
for k in (k1, k2, k3):
    print(k.mean(), (k.mean() - truekl) / truekl, k.std() / truekl)

In the code above, truekl = dis.kl_divergence(q, p) seems to be written the wrong way round: it should be the reverse direction rather than the forward one, so the arguments ought to be swapped (KL divergence is not symmetric; KL(q‖p) and KL(p‖q) are completely different values).

Output:

true tensor(0.0050)
tensor(0.0050) tensor(0.0031) tensor(19.9963)
tensor(0.0050) tensor(0.0021) tensor(1.4167)
tensor(0.0050) tensor(-0.0004) tensor(1.4153)

Reverse vs Forward KL

  • Integrate along the given axis using the composite trapezoidal rule.
def approx_kl(gmm_1, gmm_2, xs):
    ys = gmm_1.pdf(xs) * (gmm_1.logpdf(xs) - gmm_2.logpdf(xs))
    return np.trapz(ys, xs)
  • https://www.tuananhle.co.uk/notes/reverse-forward-kl.html
    • Start from $p$ (the target distribution), since it is fixed; the object being optimized is the parameterized $q_\phi$
    • Reverse KL: Zero-Forcing/Mode-Seeking
      • $q\log\frac{q}{p}$ (where $p$ is close to 0, $q$ must also be close to 0, otherwise the term blows up; that is the zero-forcing property, and mode-seeking is a consequence of zero-forcing, though not an inevitable one)
        • forces $q_\phi$ to be zero where $p$ is zero
          • zero-forcing => mode seeking (not necessarily)
        • not always mode-seeking (subplot 2/3)
    • Forward KL: Mass-Covering/Mean-Seeking
      • $p\log\frac{p}{q}$ (wherever $p$ has mass, $q$ must not be close to 0, otherwise the term blows up; that is mass-covering, so $q$ has to fit every region where $p$ has mass, which is mean-seeking)
        • there is some mass under $q_\phi$ wherever there is some mass under $p$
        • $q$ is zero-avoiding: it avoids going to 0;

In the figure from the linked note, the solid line is $p$, a mixture of Gaussians; the short dashes show the reverse-KL fit and the long dashes the forward-KL fit.

The last two subplots show typical mode-seeking, where the fit locks onto one of the modes of the distribution, but that is not inevitable: in the third subplot, for example, the forward and reverse fits are essentially the same.
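
As a quick numerical check of the asymmetry, the approx_kl helper above can be evaluated in both directions; a sketch using two scipy.stats normals in place of the mixtures (approx_kl is repeated here so the snippet runs on its own):

import numpy as np
from scipy import stats

def approx_kl(d1, d2, xs):
    ys = d1.pdf(xs) * (d1.logpdf(xs) - d2.logpdf(xs))
    return np.trapz(ys, xs)

p = stats.norm(0.0, 1.0)
q = stats.norm(1.0, 2.0)
xs = np.linspace(-15, 15, 3001)

print(approx_kl(q, p, xs))   # reverse KL(q||p), about 1.31
print(approx_kl(p, q, xs))   # forward KL(p||q), about 0.44, a different value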
