VLM(视觉语言模型)与DeepSeek R1(奖励机制)如何结合

news2025/2/25 22:08:40

VLM(视觉语言模型)与DeepSeek R1(奖励机制)如何结合

flyfish

VLM的传统训练依赖于监督学习(直接拟合问答对),而规则奖励函数通常用于强化学习(通过试错和奖励反馈优化策略)。这两种方式如何结合?

源码来自
VLM-R1/src/open-r1-multimodal/src/open_r1/grpo_rec.py

# 导入 debugpy 库,用于调试,当前代码中被注释掉,若需要调试可取消注释
# import debugpy
# try:
#     # 5678 是 VS Code 调试配置中的默认附加端口。除非指定主机和端口,否则主机默认为 127.0.0.1
#     debugpy.listen(("localhost", 9501))
#     print("Waiting for debugger attach")
#     debugpy.wait_for_client()
# except Exception as e:
#     pass

# 导入操作系统相关功能的库
import os
# 导入正则表达式库,用于字符串匹配和处理
import re
# 导入日期时间处理库
from datetime import datetime
# 导入数据类装饰器和字段定义类,用于定义数据类
from dataclasses import dataclass, field
# 导入可选类型注解,用于表示某个参数可以为 None
from typing import Optional

# 导入 Pillow 库中的 Image 类,用于处理图像
from PIL import Image
# 导入 PyTorch 中的数据集基类
from torch.utils.data import Dataset
# 导入 Qwen2VL 条件生成模型
from transformers import Qwen2VLForConditionalGeneration

# 导入自定义的数学验证模块中的解析和验证函数
from math_verify import parse, verify
# 导入自定义的 Qwen2VLGRPOTrainer 类
from open_r1.trainer import Qwen2VLGRPOTrainer
# 导入 TRL 库中的 GRPO 配置、训练器、模型配置、脚本参数、解析器和 PEFT 配置获取函数
from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
# 导入 Transformers 库中的训练参数类
from transformers import TrainingArguments
# 导入 YAML 文件处理库
import yaml
# 导入 JSON 文件处理库
import json
# 导入随机数生成库
import random
# 导入数学计算库
import math

# ----------------------- 修复当前版本 transformers 中的 flash attention 错误 -----------------------
# 导入 Qwen2_5_VL 模型中的相关类和函数
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLVisionFlashAttention2, apply_rotary_pos_emb_flashatt, flash_attn_varlen_func
# 导入 PyTorch 库
import torch
# 导入元组类型注解
from typing import Tuple

# 自定义 Qwen2_5_VLVisionFlashAttention2 类的前向传播函数
def custom_forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
    # 获取隐藏状态的序列长度
    seq_length = hidden_states.shape[0]
    # 通过 qkv 层得到查询、键、值张量,并进行形状调整和维度置换
    q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    # 如果没有提供位置嵌入,则根据旋转位置嵌入计算余弦和正弦值
    if position_embeddings is None:
        # 打印一次警告信息,提示 RoPE 嵌入计算方式的变化
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        # 拼接旋转位置嵌入
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        # 计算余弦值
        cos = emb.cos().float()
        # 计算正弦值
        sin = emb.sin().float()
    else:
        # 从位置嵌入中获取余弦和正弦值
        cos, sin = position_embeddings
        # 将余弦值转换为浮点类型
        cos = cos.to(torch.float)
        # 将正弦值转换为浮点类型
        sin = sin.to(torch.float)
    # 应用旋转位置嵌入到查询和键张量
    q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
    # 去除查询张量的额外维度
    q = q.squeeze(0)
    # 去除键张量的额外维度
    k = k.squeeze(0)

    # 计算最大序列长度
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
    # 调用 flash 注意力函数计算注意力输出
    attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
        seq_length, -1
    )
    # 通过投影层得到最终的注意力输出
    attn_output = self.proj(attn_output)
    return attn_output

# 将自定义的前向传播函数赋值给 Qwen2_5_VLVisionFlashAttention2 类的 forward 方法
Qwen2_5_VLVisionFlashAttention2.forward = custom_forward


# ----------------------- 主脚本 -----------------------
# 定义 GRPOScriptArguments 数据类,继承自 ScriptArguments
@dataclass
class GRPOScriptArguments(ScriptArguments):
    """
    用于 GRPO 训练脚本的脚本参数。

    参数:
        reward_funcs (`list[str]`):
            奖励函数列表。可能的值: 'accuracy', 'format'。
    """

    # 奖励函数列表,默认包含 'accuracy' 和 'format'
    reward_funcs: list[str] = field(
        default_factory=lambda: ["accuracy", "format"],
        metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format'"}
    )
    # 图像的最大像素数,默认为 12845056
    max_pixels: Optional[int] = field(
        default=12845056,
        metadata={"help": "Maximum number of pixels for the image"}
    )
    # 图像的最小像素数,默认为 3136
    min_pixels: Optional[int] = field(
        default=3136,
        metadata={"help": "Minimum number of pixels for the image"}
    )
    # 图像的根目录,默认为 None
    image_root: Optional[str] = field(
        default=None,
        metadata={"help": "Root directory of the image"}
    )

# 定义系统提示信息,用于指导模型的对话生成
SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
    "<think> reasoning process here </think><answer> answer here </answer>"
)

# 定义 LazySupervisedDataset 类,继承自 Dataset
class LazySupervisedDataset(Dataset):
    def __init__(self, data_path: str, script_args: GRPOScriptArguments):
        # 调用父类的构造函数
        super(LazySupervisedDataset, self).__init__()
        # 保存脚本参数
        self.script_args = script_args
        # 初始化数据字典列表
        self.list_data_dict = []

        # 如果数据文件是 YAML 格式
        if data_path.endswith(".yaml"):
            # 打开 YAML 文件
            with open(data_path, "r") as file:
                # 加载 YAML 数据
                yaml_data = yaml.safe_load(file)
                # 获取数据集列表
                datasets = yaml_data.get("datasets")
                # 文件格式应为:
                # datasets:
                #   - json_path: xxxx1.json
                #     sampling_strategy: first:1000
                #   - json_path: xxxx2.json
                #     sampling_strategy: end:3000
                #   - json_path: xxxx3.json
                #     sampling_strategy: random:999

                # 遍历每个数据集
                for data in datasets:
                    # 获取 JSON 文件路径
                    json_path = data.get("json_path")
                    # 获取采样策略,默认为 'all'
                    sampling_strategy = data.get("sampling_strategy", "all")
                    # 初始化采样数量为 None
                    sampling_number = None

                    # 如果 JSON 文件是 JSONL 格式
                    if json_path.endswith(".jsonl"):
                        # 初始化当前数据字典列表
                        cur_data_dict = []
                        # 打开 JSONL 文件
                        with open(json_path, "r") as json_file:
                            # 逐行读取文件
                            for line in json_file:
                                # 解析每行 JSON 数据并添加到当前数据字典列表
                                cur_data_dict.append(json.loads(line.strip()))
                    # 如果 JSON 文件是 JSON 格式
                    elif json_path.endswith(".json"):
                        # 打开 JSON 文件
                        with open(json_path, "r") as json_file:
                            # 加载 JSON 数据到当前数据字典列表
                            cur_data_dict = json.load(json_file)
                    else:
                        # 如果文件类型不支持,抛出异常
                        raise ValueError(f"Unsupported file type: {json_path}")

                    # 如果采样策略包含冒号
                    if ":" in sampling_strategy:
                        # 分割采样策略和采样数量
                        sampling_strategy, sampling_number = sampling_strategy.split(":")
                        # 如果采样数量包含百分比符号
                        if "%" in sampling_number:
                            # 计算采样数量
                            sampling_number = math.ceil(int(sampling_number.split("%")[0]) * len(cur_data_dict) / 100)
                        else:
                            # 将采样数量转换为整数
                            sampling_number = int(sampling_number)

                    # 应用采样策略
                    if sampling_strategy == "first" and sampling_number is not None:
                        # 取前 sampling_number 个样本
                        cur_data_dict = cur_data_dict[:sampling_number]
                    elif sampling_strategy == "end" and sampling_number is not None:
                        # 取后 sampling_number 个样本
                        cur_data_dict = cur_data_dict[-sampling_number:]
                    elif sampling_strategy == "random" and sampling_number is not None:
                        # 随机打乱样本
                        random.shuffle(cur_data_dict)
                        # 取前 sampling_number 个样本
                        cur_data_dict = cur_data_dict[:sampling_number]
                    # 打印从当前 JSON 文件加载的样本数量
                    print(f"Loaded {len(cur_data_dict)} samples from {json_path}")
                    # 将当前数据字典列表添加到总数据字典列表
                    self.list_data_dict.extend(cur_data_dict)
        else:
            # 如果文件类型不支持,抛出异常
            raise ValueError(f"Unsupported file type: {data_path}")

    def __len__(self):
        # 返回数据字典列表的长度
        return len(self.list_data_dict)

    def __getitem__(self, i):
        # 定义将示例转换为对话格式的函数
        def make_conversation(example):
            return {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": example["problem"]}
                ]
            }

        # 问题模板,用于包含图像的对话
        QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> tags and then output the final answer in <answer> </answer> tags. Output the final answer in JSON format."

        # 定义将包含图像的示例转换为对话格式的函数
        def make_conversation_image(example):
            return {
                "prompt": [
                    # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
                    {
                        "role": "user",
                        "content": [
                            {"type": "image"},
                            {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])}
                        ]
                    }
                ]
            }

        # 获取指定索引的示例
        example = self.list_data_dict[i]
        # 获取图像根目录
        image_root = self.script_args.image_root
        # 如果示例中包含图像信息
        if 'image' in example:
            # 构建图像路径
            image_path = os.path.join(image_root, example['image'])
            # 如果图像文件不存在
            while not os.path.exists(image_path):
                # 打印警告信息
                print(f"Warning: Image {image_path} not found, randomly selecting another image")
                # 随机选择一个新的索引
                new_index = random.randint(0, len(self.list_data_dict)-1)
                # 获取新的示例
                example = self.list_data_dict[new_index]
                # 构建新的图像路径
                image_path = os.path.join(image_root, example['image'])
            # 打开图像并转换为 RGB 格式
            image = Image.open(image_path).convert("RGB")
        else:
            # 如果示例中不包含图像信息,图像为 None
            image = None

        return {
            'image': image,
            'problem': example['problem'],
            'solution': example['solution'],
            'prompt': make_conversation_image(example)['prompt'] if 'image' in example else make_conversation(example)['prompt']
        }

'''
    如果模型预测的边界框与真实边界框的交并比(IoU)大于 0.5,则奖励为 1.0,否则为 0.0。
    这是一种硬奖励,未来可能使用软奖励会更好。
'''
def iou_reward(completions, solution, **kwargs):
    # 定义计算交并比的函数
    def iou(box1, box2):
        # 计算交集的左上角坐标
        inter_x1 = max(box1[0], box2[0])
        inter_y1 = max(box1[1], box2[1])
        # 计算交集的右下角坐标
        inter_x2 = min(box1[2]-1, box2[2]-1)
        inter_y2 = min(box1[3]-1, box2[3]-1)
        # 如果交集存在
        if inter_x1 < inter_x2 and inter_y1 < inter_y2:
            # 计算交集面积
            inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
        else:
            # 交集面积为 0
            inter = 0
        # 计算并集面积
        union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
        # 返回交并比
        return float(inter)/union

    # 获取完成内容列表
    contents = [completion[0]["content"] for completion in completions]
    # 初始化奖励列表
    rewards = []
    # 获取当前时间并格式化
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    # 定义答案标签的正则表达式模式
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    # 定义边界框的正则表达式模式
    bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    # 遍历完成内容和真实解决方案
    for content, sol in zip(contents, solution):
        # 初始化奖励为 0.0
        reward = 0.0
        # 尝试进行符号验证
        try:
            # 在完成内容中查找答案标签
            content_answer_match = re.search(answer_tag_pattern, content)
            if content_answer_match:
                # 获取答案内容
                content_answer = content_answer_match.group(1).strip()
                # 在答案内容中查找边界框
                bbox_match = re.search(bbox_pattern, content_answer)
                if bbox_match:
                    # 获取边界框坐标
                    bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                    # 如果交并比大于 0.5
                    if iou(bbox, sol) > 0.5:
                        # 奖励为 1.0
                        reward = 1.0
        except Exception:
            # 如果验证失败,继续下一个验证方法
            pass

        # 将奖励添加到奖励列表
        rewards.append(reward)
        # 如果处于调试模式
        if os.getenv("DEBUG_MODE") == "true":
            # 获取日志路径
            log_path = os.getenv("LOG_PATH")
            # 打开日志文件并追加记录
            with open(log_path, "a") as f:
                # 记录当前时间和奖励信息
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                # 记录完成内容
                f.write(f"Content: {content}\n")
                # 记录真实解决方案
                f.write(f"Solution: {sol}\n")
    return rewards


def format_reward(completions, **kwargs):
    """奖励函数,用于检查完成内容是否符合特定格式。"""
    # 定义格式的正则表达式模式
    # pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
    # 获取完成内容列表
    completion_contents = [completion[0]["content"] for completion in completions]
    # 检查每个完成内容是否符合格式
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    # 根据匹配结果生成奖励列表
    return [1.0 if match else 0.0 for match in matches]


# 奖励函数注册表,将奖励函数名称映射到对应的函数
reward_funcs_registry = {
    "accuracy": iou_reward,
    "format": format_reward,
}


def main(script_args, training_args, model_args):
    # 根据脚本参数中的奖励函数名称,从注册表中获取对应的奖励函数
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    # 打印奖励函数列表
    print("reward_funcs:", reward_funcs)

    # 加载数据集
    dataset = LazySupervisedDataset(script_args.dataset_name, script_args)

    # 选择训练器类,这里使用自定义的 Qwen2VLGRPOTrainer
    trainer_cls = Qwen2VLGRPOTrainer
    # 初始化 GRPO 训练器
    trainer = trainer_cls(
        model=model_args.model_name_or_path,  # 模型名称或路径
        reward_funcs=reward_funcs,  # 奖励函数列表
        args=training_args,  # 训练参数
        train_dataset=dataset,  # 训练数据集
        eval_dataset=None,  # 评估数据集,这里设为 None
        peft_config=get_peft_config(model_args),  # PEFT 配置
        attn_implementation=model_args.attn_implementation,  # 注意力实现方式
        max_pixels=script_args.max_pixels,  # 图像最大像素数
        min_pixels=script_args.min_pixels,  # 图像最小像素数
        torch_dtype=model_args.torch_dtype,  # PyTorch 数据类型
    )

    # 开始训练模型
    trainer.train()

    # 保存模型到指定的输出目录
    trainer.save_model(training_args.output_dir)
    # 如果设置了将模型推送到 Hub
    if training_args.push_to_hub:
        # 将模型推送到 Hub,并指定数据集名称
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


if __name__ == "__main__":
    # 创建 TrlParser 对象,用于解析脚本参数、训练配置和模型配置
    parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
    # 解析命令行参数和配置
    script_args, training_args, model_args = parser.parse_args_and_config()
    # 调用主函数开始训练
    main(script_args, training_args, model_args)

代码中的两个关键奖励函数 format_rewardiou_reward

1. 格式奖励函数 format_reward

函数定义和功能
def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?\{.*\[\d+,\s*\d+,\s*\d+,\s*\d+\].*\}.*?</answer>"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]

此函数的主要功能是检查模型生成的完成内容是否符合特定的格式要求。具体来说,它期望模型的输出满足以下格式:

  • 包含 <think></think> 标签,用于包裹思考过程。
  • 包含 <answer></answer> 标签,用于包裹答案。
  • 答案部分需要是一个 JSON 格式,并且其中包含一个由四个整数组成的列表,通常可以理解为表示边界框的坐标。
实现步骤
  1. 定义正则表达式模式pattern 是一个正则表达式,用于描述期望的输出格式。
  2. 提取完成内容completion_contentscompletions 中提取出每个完成内容的文本部分。
  3. 检查格式匹配matches 使用 re.fullmatch 函数检查每个完成内容是否完全匹配正则表达式模式。
  4. 生成奖励列表:根据匹配结果,为每个完成内容生成一个奖励值,如果匹配则为 1.0,否则为 0.0。
作用

通过这个奖励函数,模型在训练过程中会被激励去生成符合特定格式的输出,有助于规范模型的回答结构,使得输出更易于解析和使用。

2. 交并比(IoU)奖励函数 iou_reward

函数定义和功能
def iou_reward(completions, solution, **kwargs):
    def iou(box1, box2):
        inter_x1 = max(box1[0], box2[0])
        inter_y1 = max(box1[1], box2[1])
        inter_x2 = min(box1[2]-1, box2[2]-1)
        inter_y2 = min(box1[3]-1, box2[3]-1)
        if inter_x1 < inter_x2 and inter_y1 < inter_y2:
            inter = (inter_x2-inter_x1+1)*(inter_y2-inter_y1+1)
        else:
            inter = 0
        union = (box1[2]-box1[0])*(box1[3]-box1[1]) + (box2[2]-box2[0])*(box2[3]-box2[1]) - inter
        return float(inter)/union
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    answer_tag_pattern = r'<answer>(.*?)</answer>'
    bbox_pattern = r'\[(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*),\s*(\s*-?\d*\.?\d+\s*)\]'
    for content, sol in zip(contents, solution):
        reward = 0.0
        try:
            content_answer_match = re.search(answer_tag_pattern, content)
            if content_answer_match:
                content_answer = content_answer_match.group(1).strip()
                bbox_match = re.search(bbox_pattern, content_answer)
                if bbox_match:
                    bbox = [int(bbox_match.group(1)), int(bbox_match.group(2)), int(bbox_match.group(3)), int(bbox_match.group(4))]
                    if iou(bbox, sol) > 0.5:
                        reward = 1.0
        except Exception:
            pass
        rewards.append(reward)
        if os.getenv("DEBUG_MODE") == "true":
            log_path = os.getenv("LOG_PATH")
            with open(log_path, "a") as f:
                f.write(f"------------- {current_time} Accuracy reward: {reward} -------------\n")
                f.write(f"Content: {content}\n")
                f.write(f"Solution: {sol}\n")
    return rewards

此函数的主要功能是评估模型预测的边界框与真实边界框之间的重叠程度,并根据交并比(IoU)值给予奖励。

实现步骤
  1. 定义 IoU 计算函数iou 函数用于计算两个边界框的交并比。它首先计算两个边界框的交集面积和并集面积,然后将交集面积除以并集面积得到 IoU 值。
  2. 提取完成内容contentscompletions 中提取出每个完成内容的文本部分。
  3. 查找答案和边界框:使用正则表达式 answer_tag_pattern 查找完成内容中的答案部分,再使用 bbox_pattern 查找答案中的边界框坐标。
  4. 计算 IoU 并给予奖励:对于每个完成内容,提取预测的边界框坐标,与真实边界框计算 IoU 值。如果 IoU 值大于 0.5,则给予 1.0 的奖励,否则给予 0.0 的奖励。
  5. 日志记录(可选):如果设置了调试模式(DEBUG_MODEtrue),则将每个完成内容的奖励信息记录到日志文件中。
作用

通过这个奖励函数,模型在训练过程中会被激励去预测更准确的边界框,提高目标检测的精度。同时,结合格式奖励函数,可以让模型不仅准确预测边界框,还能以规定的格式输出结果。

监督学习与规则奖励函数强化学习的结合方式

1. 数据层面的结合
  • 利用监督数据初始化模型:在开始强化学习训练之前,使用监督学习的方式对视觉语言模型(VLM)进行预训练。通过直接拟合问答对数据,让模型学习到基本的语言和视觉特征表示以及问题回答的模式。例如,在代码中使用 LazySupervisedDataset 类加载数据集,这些数据可以作为监督学习阶段的训练数据,让模型初步学习到如何根据问题和图像生成答案。
  • 监督数据作为强化学习的参考:在强化学习的过程中,监督学习的数据可以作为参考来评估模型的输出。例如,在 iou_reward 函数中,通过比较模型预测的边界框与真实边界框的交并比(IoU)来给予奖励,这里的真实边界框就是监督学习中的标签信息。
2. 训练过程的结合
  • 分阶段训练:先进行监督学习训练,让模型收敛到一个较好的初始状态。然后再切换到强化学习阶段,使用规则奖励函数来进一步优化模型的策略。在代码中,虽然没有明确体现分阶段训练的逻辑,但可以在实际应用中先使用监督学习的方法对 Qwen2VLForConditionalGeneration 模型进行训练,然后再使用 Qwen2VLGRPOTrainer 进行强化学习训练。
  • 混合训练:在每个训练步骤中,既使用监督学习的损失函数,又使用强化学习的奖励函数。例如,可以将监督学习的交叉熵损失和强化学习的奖励损失加权求和,作为总的损失函数来更新模型参数。这样可以让模型在学习过程中既考虑到直接拟合标签的准确性,又考虑到长期的奖励优化。
3. 奖励函数设计结合监督信息
  • 准确性奖励:如 iou_reward 函数,将模型输出与监督学习中的标签进行比较,根据比较结果给予奖励。这种奖励函数可以促使模型在强化学习过程中输出更接近真实标签的结果,从而结合了监督学习的信息。
  • 格式奖励format_reward 函数可以确保模型输出的格式符合特定要求,这可以看作是一种规则约束。同时,这种格式要求也可以是在监督学习阶段就定义好的,从而将监督学习中的格式规范融入到强化学习的奖励机制中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2306041.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

FFMPEG编码容错处理解决办法之途径----升级库文件

在qt开发环境下接收网络数据&#xff0c;调用ffmpeg解码播放视频&#xff0c;出现闪屏现象&#xff0c;具体现象可以使用操作系统自带的ffplay播放器播放原始视频流可复现&#xff1b;而使用操作系统自带的mpv播放器播放视频则不会出现闪屏&#xff1b;闪屏时会报Could not fin…

uniapp h5端和app端 使用 turn.js

前提:添加页后,添加页与当前页会重叠在一起,不知道为什么,没有找到解决办法 1.h5端 <template><view class"container"><view id"flipbook"><view class"page page1">Page 1</view><view class"page pag…

【入门音视频】音视频基础知识

&#x1f308;前言&#x1f308; 这个系列在我学习过程中&#xff0c;对音视频知识归纳总结的笔记。因为音视频相关讲解非常稀少&#xff0c;所以我希望通过这个音视频系列&#xff0c;跟大家一起学习音视频&#xff0c;希望减少初学者在学习上的压力。同时希望也欢迎指出文章的…

数据结构☞泛型

一.基础定义与应用方向 1.定义&#xff1a; 一般的类和方法&#xff0c;只能使用具体的类型 : 要么是基本类型&#xff0c;要么是自定义的类。如果要编写可以 应用于多种类型 的代码&#xff0c;这种刻板的限制对代码的束缚就会很大。----- 来源《 Java 编程思想》对泛型的介…

hot100-二叉树

二叉树 二叉树递归 相当于这个的顺序来回调换 class Solution {private List<Integer> res new ArrayList<>();public List<Integer> inorderTraversal(TreeNode root) {if(root null)return res;inorderTraversal(root.left);res.add(root.val);inorde…

嵌入式项目:STM32刷卡指纹智能门禁系统

本文详细介绍基于STM32的刷卡指纹智能门禁系统。 获取资料/指导答疑/技术交流/选题/帮助&#xff0c;请点链接&#xff1a; https://gitee.com/zengzhaorong/share_contact/blob/master/stm32.txt 1 系统功能 1.1 功能概述 本系统由STM32硬件端&#xff08;下位机&#xff09;…

短剧小程序系统源码

短剧小程序系统源码 今天我要向大家介绍的是最新作品——短剧小程序系统源码。这不仅仅是一款简单的播放工具&#xff0c;它背后蕴含的强大功能能够帮助你的短剧业务实现质的飞跃&#xff01; 为什么说这款源码很厉害&#xff1f; 首先&#xff0c;在当今竞争激烈的市场环境…

C#中级教程(2)——走进 C# 面向对象编程:从基础到进阶的深度探索

一、为什么选择面向对象编程 在软件开发的演进过程中&#xff0c;随着程序规模和复杂度的不断增加&#xff0c;传统的编程方式逐渐暴露出局限性。面向对象编程应运而生&#xff0c;它就像是一位智慧的组织者&#xff0c;将程序中的功能进行模块化划分。每个模块各司其职&#x…

基于SpringBoot的“流浪动物救助系统”的设计与实现(源码+数据库+文档+PPT)

基于SpringBoot的“流浪动物救助系统”的设计与实现&#xff08;源码数据库文档PPT) 开发语言&#xff1a;Java 数据库&#xff1a;MySQL 技术&#xff1a;SpringBoot 工具&#xff1a;IDEA/Ecilpse、Navicat、Maven 系统展示 系统功能结构图 局部E-R图 系统首页界面 系统…

基于WebRTC与AI大模型接入EasyRTC:打造轻量级、高实时、强互动的嵌入式音视频解决方案

随着物联网和嵌入式技术的快速发展&#xff0c;嵌入式设备对实时音视频通信的需求日益增长。然而&#xff0c;传统的音视频解决方案往往存在体积庞大、实时性差、互动体验不佳等问题&#xff0c;难以满足嵌入式设备的资源限制和应用场景需求。 针对以上痛点&#xff0c;本文将介…

Windows - 通过ssh打开带有图形界面的程序 - 一种通过计划任务的曲折实现方式

Windows(奇思妙想) - 通过ssh打开带有图形界面的程序 - 一种通过计划任务的曲折实现方式 前言 Windows启用OpenSSH客户端后就可以通过SSH的方式访问Windows了。但是通过SSH启动的程序&#xff1a; 无法显示图形界面会随着SSH进程的结束而结束 于是想到了一种通过执行“计划…

RT-Thread+STM32L475VET6——USB鼠标模拟

文章目录 前言一、板载资源二、具体步骤1.配置icm20608传感器2.打开CubeMX进行USB配置3. 配置USB3.1 打开USB驱动3.2 声明USB3.3 剪切stm32xxxx_hal_msp.c中的void HAL_PCD_MspInit(PCD_HandleTypeDef* hpcd)和void HAL_PCD_MspDeInit(PCD_HandleTypeDef* hpcd)函数至board.c3.…

计算机毕业设计SpringBoot+Vue.js母婴商城(源码+LW文档+PPT+讲解+开题报告)

温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 温馨提示&#xff1a;文末有 CSDN 平台官方提供的学长联系方式的名片&#xff01; 作者简介&#xff1a;Java领…

Teigha(ODA<Open Design Alliance>_开放设计联盟)——cad c# 二次开发

需将dll库文件与exe文件放同一路径下&#xff0c;运行exe即可执行。 using System; using System.Collections.Generic; using System.ComponentModel; using System.Data; using System.Drawing; using System.IO; using System.Linq; using System.Text; using System.Thread…

idea 部署 AJ-Report 启动的注意事项

AJ-Report 入门参考&#xff1a; AJ-Report 初学(入门教程) gitee 下载&#xff1a;https://gitee.com/anji-plus/report/releases 根据上面提供的 gitee 下载链接&#xff0c;点击直接下载 最上面的就是最新版本的&#xff0c;旧版本往下拉就可以找到&#xff0c;有三个下载…

智能化客户行为轨迹分析:AI视频监控在大型商场的技术方案

项目背景&#xff1a;为了提升顾客体验并支持精准营销&#xff0c;卖场或商场需要通过智能化手段分析客户在商场内的行为路线。 一、具体需求 1、行为路径分析&#xff1a;跟踪顾客在商场内的移动轨迹&#xff0c;了解顾客的购物习惯和偏好。 2、高频活动区域识别&#xff1a…

Denoising Diffusion Restoration Models论文解读

论文要点 恢复的线性逆问题可以使用预训练的DDPM完成&#xff1a;1. 将降质矩阵使用SVD&#xff0c;得到分解矩阵&#xff1b;2. 使用分解矩阵将图像投影到降质类型间共享的谱空间&#xff1b;3. 谱空间中执行DDPM。 评价 同Track的方法同样很多&#xff0c;比如后续的DDNM、…

基于SpringBoot的校园消费点评管理系统

作者&#xff1a;计算机学姐 开发技术&#xff1a;SpringBoot、SSM、Vue、MySQL、JSP、ElementUI、Python、小程序等&#xff0c;“文末源码”。 专栏推荐&#xff1a;前后端分离项目源码、SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏&#xff1a;…

【NLP 38、激活函数 ④ GELU激活函数】

别盲目&#xff0c;别着急&#xff0c;慢慢走&#xff0c;没事的 —— 25.2.24 一、定义与数学表达式 GELU&#xff08;Gaussian Error Linear Unit&#xff0c;高斯误差线性单元&#xff09;是一种结合概率分布的非线性激活函数&#xff0c;其核心思想是通过输入值服从标准正…

QT:paintEvent、QPainter、QPaintDevice

paintEvent 介绍 在 Qt 编程中&#xff0c;paintEvent 是 QWidget 类中的一个非常重要的虚函数&#xff0c;用于处理绘图事件。当一个 QWidget 或其派生类的实例需要进行重绘操作时&#xff0c;Qt 会自动调用该控件的 paintEvent 函数。 触发时机 窗口首次显示&#xff1a;当…