[NeurlPS 2022] STaR 开源代码实现解读

news2024/12/13 3:22:56
  • STaR 方法代码开源,这里给出一个中文代码解读地址:repo
  • 入口点:iteration_train.py
  • 关键代码:device_train.py, device_inference.py, and create_finetune_tfrecords.py
  • 基于 JAX、RAY,在 Google TPU 上实现;

入口点:iteration_train.py

if __name__ == "__main__":
    args = parse_args()
    print(args)
    task = args.task                                                                    # 选择数据集/任务:论文中有 CommonsenseQA、GSM8K
    experiment_name = "_".join(sys.argv[1:])                                            # 实验参数以_分割,拼接在一起命名
    experiment_name = ''.join(ch for ch in experiment_name if ch.isalnum() or ch == "_")# 确保 name 只有字母、数字、下划线(符合文件命名格式)
    if args.no_prompt:
        eval_seq = 128 + args.gen_length
    os.makedirs(f"configs/{experiment_name}", exist_ok=True)
    shutil.copy(f"configs/qa_base.json", f"configs/{experiment_name}/base.json")        # 复制一份实验配置模版
    prev_config = f"configs/{experiment_name}/base.json"                                # 实验配置模版的路径(后续代码会修改这个复制文件)
    new_json = make_first_config()

    os.makedirs(f'data/{experiment_name}', exist_ok=True)
    os.makedirs(f'{task}/{experiment_name}', exist_ok=True)
    os.makedirs(f'result_logs/', exist_ok=True)
    with open(f"result_logs/{experiment_name}.txt", "a+") as f:
        print("================================", file=f)                               # 类似 f.write
        print(args, file=f)
    for cur_iter in range(1, args.n_iters):                                             # 论文中的外循环迭代次数,重复多少次 STaR 微调方法
        exp_iteration = f"{experiment_name}_{cur_iter}"
        gen_train() # Generate the training set
        train_set = gen_records() # Create the tfrecords from the data                  # "{experiment_name}/{exp_iteration}.index"
        config_name = gen_config(train_set) # Create the new configuration file         # 核心是修改 total_steps
        train_model() # Train the new model
        eval_model() # Evaluate the new model
        prev_config = config_name  # Prepare for next iteration
        if args.copy_n > 0:
            copy_files()                                                                # [TODO] 复制上次外循环的一些配置文件,暂时不知道有啥用

parse_args() 标准的解析命令行参数,但是这里代码参数非常多。论文中,对一些技术细节写的比较模糊或者看不明白,这里需要结合代码分析。、

启动命令参数 parse_args()

  • 说明:对于 bool 参数,在启动命令中带 --bool_params 或者不带这个参数即可提现,不用具体赋值
参数取值范围默认值说明
--no_promptbooltrueeval时是否移出prompts (不用few-shot prompting,训练默认都是用的,对比实验不用)
--base_epochsfloat1.0第一次 iter 的 epoch
--add_epochsfloat0.2不同 iter 中需要 add 的 epoch
--few_shot_trainboolfalse是否使用 few-shot 训练
--steady_growboolfalse是否使用固定数量的 epoch
--start_stepsfloat40.0第一次外循环的步数(不同外循环步数可能不同)
--exponential_growboolfalse是否使用指数增长
--add_stepsfloat20.0steady_grow 配对参数,每次迭代中增加的步数
--grow_stepsfloat1.2exponential_grow 配对参数,每次迭代中按比例增长
--p_rationalizationfloat1.0使用合理化的错误样本比例
--p_show_hint_savefloat0.0保存合理化提示的比例 [TODO]
--rationalizeboolfalse是否使用合理化
--start_iterint1起始迭代数
--n_itersint64外部循环迭代的最大次数 (论文中的外循环,使用多少次 STaR 微调)
--copy_nint0每次迭代中需要复制的文件数
--n_train_samplesint10000训练样本数
--gradient_accumulation_stepsint8梯度累积的步数 Batch size
--taskstr“commonsenseqa”运行的任务类型 ,论文中有 CommonsenseQA、GSM8K 两个数据集
--directboolfalse是否直接预测(不使用scratchpad)
--gen_lengthint96生成输出的长度
--sequence_countint10每个batch的平均序列数量
--base_model_locationstr“gs://checkpoint-bucket/step_383500/”微调模型的检查点路径
--dry_runboolfalse是否进行快速运行以可视化输出
--skip_evalboolfalse是否跳过评估(例如算术任务)

训练epoch、step是否随着外循环迭代而增长?

epoch 控制参数:
step 控制参数:steady_grow、exponential_grow 或者都不选。三选一。选了 steady_grow、exponential_grow 分别还有一个配对的配置参数:add_steps、grow_steps(比例)。不选的话根据下面计算步数:

# Count data points
        total_count = 0
        for cur_file in sorted(os.listdir(record_folder(cur_iter - 1)), key=lambda x: int(x.split('.')[0].split("_")[-1])):
            with open(f"{record_folder(cur_iter - 1)}/{cur_file}", encoding='utf-8') as train_file:
                train_file_text = train_file.read()
                total_count += len(train_file_text.split("\n\n"))
                print(len(train_file_text.split("\n\n")))
        train_epochs = args.base_epochs + args.add_epochs * (cur_iter - 1)
        cur_steps = int(total_count * train_epochs // (args.gradient_accumulation_steps * args.sequence_count))
        return cur_steps

配置文件

qa_base.json

configs/qa_base.json 是实验的基础配置文件,运行实验会复制这个 template 然后不断修改这里的 value。

{
    "layers": 28,
    "d_model": 4096,
    "n_heads": 16,
    "n_vocab": 50400,
    "norm": "layernorm",
    "pe": "rotary",
    "pe_rotary_dims": 64,
    "seq": 1536, // 模型上下文窗口长度
    "cores_per_replica": 8, // device_inference 中用到,模型并行的参数,模型要分散到多个cores上来进行模型的计算
    "per_replica_batch": 1,	// device_inference 中用到,数据并行的参数,数据并行中每个模块并行的batch大小
    "gradient_accumulation_steps": 8, // 始终是 args.gradient_accumulation_steps
    "warmup_steps": 100,
    "anneal_steps": 300000,
    "lr": 1e-06,
    "end_lr": 1e-06,
    "weight_decay": 0.0,
    "total_steps": 383500,	   // 来自 get_n_steps(),有三种配置模式,见上面
    "tpu_size": 8,
    "p_rationalization": 1.0, // 始终是 args.p_rationalization
    "bucket": "checkpoint-bucket",			// 模型 ckpt 存储桶名
    "model_dir": "full_qa_4",				// 模型存储路径
    "train_set": "qa_train_4.index",
    "val_set": {
      "index": "qa.val.index"
    },
    "eval_harness_tasks": [
      "lambada",
      "piqa",
      "hellaswag",
      "winogrande",
      "mathqa",
      "pubmedqa"
    ],
    "val_batches": 100,
    "val_every": 10000,
    "ckpt_every": 10000,
    "keep_every": 10000,
    "name": "slow_grow_full_epoch_0",			// 这里会不断修改为 "{experiment_name}_0"
    "wandb_project": "full_6",	// wandb是一个日志服务,这里是日志记录的所属项目
    "comment": "",
    "target_save_folder": "commonsenseqa/iterative_full/iterative_full_0", // 文件存储所在文件夹路径
    "target_save": "commonsenseqa/slow_grow_full_epoch/slow_grow_full_epoch_0/slow_grow_full_epoch_0.txt" // 文件存储位置:文件和 name 同名,target_save_folder+name+".txt"
  }

训练核心代码

外层调用:iteration_train.py

调用侧代码(iteration_train.py):

# main:
    for cur_iter in range(1, args.n_iters):                                             # 论文中的外循环迭代次数,重复多少次 STaR 微调方法
        exp_iteration = f"{experiment_name}_{cur_iter}"
        gen_train() # Generate the training set (第一次不执行)
        train_set = gen_records() # Create the tfrecords from the data                  # "{experiment_name}/{exp_iteration}.index"
        config_name = gen_config(train_set) # Create the new configuration file         # 核心是修改 total_steps
        train_model() # Train the new model

在训练前,需要先生成训练数据集(rationale generation)。核心是:gen_train(),然后通过 train_model() 开始微调模型。

def gen_records():
    gen_cmd = f'python3 create_finetune_tfrecords.py {record_folder(cur_iter - 1)} {record_folder(cur_iter - 1)}'
    print(f"Creating records for finetuning {cur_iter}: {gen_cmd}")
    if not args.dry_run and (cur_iter >= args.start_iter):
        os.system(gen_cmd)
    train_set = f"{experiment_name}/{exp_iteration}.index"
    with open(f"data/{train_set}", "w") as new_data_file:
        new_data_file.write(f"{record_folder(cur_iter - 1)}.tfrecords")
    return train_set
def train_model():
    model_cmd = f"python3 device_train.py --config {config_name} --tune-model-path={args.base_model_location}"
    print(f"Train model {cur_iter}: {model_cmd}")
    if not args.dry_run and (cur_iter >= args.start_iter):
        os.system(model_cmd)

rationale generation 代码 gen_train:device_inference.py

device_inference.py

参数取值范围默认值说明
--configstrNone配置文件路径
--directboolfalse是否直接预测(不使用scratchpad)
--rationalizeboolfalse是否使用合理化
--no_promptboolfalseeval时是否移出prompts (不用few-shot prompting,训练默认都是用的,对比实验不用)
--few_shot_trainboolfalse训练时是否移除few-shot-prompts
--show_hint_promptboolfalse是否需要提示提示
--splitstr“dev”split的数据集(train,dev) gen_train里是–split=train,eval_model 里是 dev
--dataset_modestr“cqa”使用的数据集(注意cqa在另一个文件默认值是全写,有代码做了兼容,这里默认值不能改,必须是cqa)
--n_train_samplesint3000训练样本数量
--gen_lengthint96生成长度
--eval_batch_sizeint8评估时的批量大小
--p_show_hint_savefloat0.0保存合理化提示的比例
--ckpt_stepint-1要评估的检查点,-1表示最终检查点
--eval_seqint-1序列长度,-1表示使用参数文件中的配置 (seq是模型上下文tokens最大长度)

此时传入的参数是:

  • prev_config:用的上次迭代的配置,因为这里用上一次学习好的模型来生成数据集;
  • gen_length 输出长度;
    if args.no_prompt:
        eval_seq = 128 + args.gen_length

如果按默认值,这里gen_length是128+96=224

  • p_show_hint_save:合理化相关的参数
  • n_train_samples:训练样本,默认是 10000(论文里始终保持这个数)
def gen_train():
    train_cmd = f"python3 device_inference.py --config={prev_config} --split=train --gen_length={args.gen_length} --p_show_hint_save={args.p_show_hint_save} "
    if task != "commonsenseqa":
        train_cmd += f" --dataset_mode={task} "
    if args.rationalize:
        train_cmd += " --rationalize "
    if args.few_shot_train:
        train_cmd += " --few_shot_train "
    if cur_iter > 1 and args.no_prompt:
        train_cmd += f" --no_prompt --eval_seq {eval_seq} "
    train_cmd += f" --n_train_samples={args.n_train_samples} "
    train_cmd += f" >> result_logs/{experiment_name}.txt"
    print(f"Generating training set {cur_iter} using model {cur_iter - 1}: {train_cmd}")
    if not args.dry_run and (cur_iter >= args.start_iter):
        if (cur_iter == 1) and os.path.exists(record_folder(0) + f"/{experiment_name}_0.txt"):
            print("First file cached") # 第一次不执行
        else:
            os.system(train_cmd)

注意:第一次运行 gen_train 的时候不执行,需要先微调后才执行合理化。

接下来分析 device_inference.py 中的代码:

if __name__ == "__main__":
    # 参数解析
    args = parse_args()
    print(args)
    split = args.split                              # 'dev'
    params = json.load(smart_open(args.config))     # smart_open 是一个用于打开文件的函数,支持多种文件格式和存储后端,本地文件,aws s3,gcs 等等

    # 初始化 wandb
    project = params.get("wandb_project", "mesh-transformer-jax")               # 日志服务所属的项目,随便什么值,这里不重要
    experiment_details = params["name"].split("_")
    wandb_name = "_".join(experiment_details[:-1])
    wandb_iteration = int(experiment_details[-1])
    wandb.init(project=project, name=wandb_name, config=params, resume=True)    # resume=True: 表示如果有相同名称的实验已经存在,则恢复该实验的状态,而不是创建一个新的实验。

    # 根据配置加载不同的 prompt 设置
    prompts_file = "prompts.txt" if not args.direct else "prompts_direct.txt"   # 默认不带 direct,即用带 few-shot 和 rationales 的 prompt
    prompts_file = f"{args.dataset_mode}/{prompts_file}"                        
    if args.no_prompt:
        commonsense_prompts = []
    else:
        with basic_open(prompts_file) as prompts:
            commonsense_prompts = prompts.read().split("\n\n")
    prompts_hint_file = "prompts_answer_key.txt" if not args.direct else "prompts_direct_answer_key.txt"
    prompts_hint_file = f"{args.dataset_mode}/{prompts_hint_file}"
    if args.no_prompt and not args.show_hint_prompt:
        commonsense_prompts_hint = []
    else:
        with basic_open(prompts_hint_file) as prompts:
            commonsense_prompts_hint = prompts.read().split("\n\n")

    # 参数设置
    per_replica_batch = params["per_replica_batch"]                             # 数据并行参数:1
    cores_per_replica = params["cores_per_replica"]                             # 模型并行参数:模型并行中的每个 replica 的核心数,默认是 8
    target_save = params["target_save"] if split != "dev" else f'{args.dataset_mode}/new_dev.txt'
    seq = params["seq"] if args.eval_seq == -1 else args.eval_seq
    hint_seq = seq
    set_opt(params)

    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)   # (replica 数量,每个 replica 的核心数)
    devices = np.array(jax.devices()).reshape(mesh_shape)                       # 为每个 replica 划分 cores,形成一个资源分配矩阵
    ckpt_path = get_ckpt_path(params, args.ckpt_step)                           # 默认用最新的 ckpt
    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):                     # 并行策略的维度:dp,数据并行,mp,模型并行
        network = load_model(params, ckpt_path, devices, mesh_shape)

        dataset = get_dataset(args)
        dataset_keys = set([datakey for datakey, _ in dataset])

        total_batch = per_replica_batch * jax.device_count() // cores_per_replica * args.eval_batch_size    # 数据并行侧,一次性输入的数据 batch 大小
        gen_params = {"top_p": np.ones(total_batch) * 0.9, "temp": np.ones(total_batch) * 0.01}             # top_p: 控制生成文本的多样性的一种采样策略, Nucleus Sampling; temp: 温度参数,用于控制生成文本的随机性。温度越高,生成的文本越随机;温度越低,生成的文本越确定。

        accurate_count = eval_examples(dataset, commonsense_prompts, commonsense_prompts_hint, direct=args.direct)
        for cur_key, cur_counts in accurate_count.items():
            print(f"{split}, {cur_key}, {get_score(cur_counts)}")
            wandb.log({f"{split}_{cur_key}_accuracy": get_score(cur_counts), "iteration": wandb_iteration})

  • 最开始,参数解析,注意一方面参数来自于外层调用传入的(前文分析了),另一部分来自配置文件 json;
  • 初始化 wandb:Weights & Biases(通常简称为 WandB)是一个用于机器学习实验管理和可视化的工具。它提供了一系列功能,帮助研究人员和开发者更好地跟踪、管理和可视化他们的机器学习实验。
  • 然后是根据配置加载不同的 prompt 设置
    • arg.direct:不用带 rationales 的 prompt,默认是用;
    • 加载不带合理化(但有rationales或者无rationales的配置)/ 或者不使用 few-shot;
    • 加载带合理化(hint)的 prompt (且带有 rationales);
  • 然后是从config读一些配置:注意数据集分 train、dev
# seq 是模型上下文窗口长度,input tokens 不能超过这个
seq = params["seq"] if args.eval_seq == -1 else args.eval_seq
hint_seq = seq
    "cores_per_replica": 8, // device_inference 中用到,模型并行的参数,模型要分散到多个cores上来进行模型的计算
    "per_replica_batch": 1,	// device_inference 中用到,数据并行的参数,数据并行中每个模块并行的batch大小
  • replica 指的应该是大模型并行的其中一个部分。per_replica_batch 是数据并行的参数。cores_per_replica 是每个 replia 分配的核心数,是模型并行的参数,模型要分散到多个cores上来进行模型的计算。
    • 数据并行:数据并行是将训练数据分割成多个小批次,并在多个设备上并行处理这些小批次。每个设备都有一个完整的模型副本,计算梯度后再进行参数更新。
    • 模型并行:模型并行是将一个模型的不同部分分布在多个计算设备上。适用于模型非常大,以至于单个设备无法容纳整个模型的情况。
    mesh_shape = (jax.device_count() // cores_per_replica, cores_per_replica)   # (replica 数量,每个 replica 的核心数)
    devices = np.array(jax.devices()).reshape(mesh_shape)                       # 为每个 replica 划分 cores,形成一个资源分配矩阵
    ckpt_path = get_ckpt_path(params, args.ckpt_step)                           # 默认用最新的 ckpt
    with jax.experimental.maps.mesh(devices, ('dp', 'mp')):                     # 并行策略的维度:dp,数据并行,mp,模型并行

注意:eval_batch_size 主要是 cache 样本,样本缓存到这个数,才执行(减少模型io开销)。

eval_examples

def eval_examples(data_examples, few_shot_prompts, few_shot_prompts_hint, direct=False):
    accurate_count = {}
    tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

    main_examples, hint_examples = [], []
    pbar = tqdm(data_examples, smoothing=0)
    for data_example in pbar:   # 逐个遍历:而单个样本的执行和合理化样本的执行都是 cache 到一个 batch 再执行
        main_examples.append(data_example)
        if len(main_examples) == args.eval_batch_size:  # 默认值 8
            successful_examples = eval_batch(           # 评估
                main_examples, few_shot_prompts, seq, tokenizer,
                args.gen_length, gen_params, accurate_count, target_save, direct=direct
            )
            for example_idx, example in enumerate(main_examples):
                if (example_idx not in successful_examples) and (random.random() < params.get('p_rationalization', 1.)): # p_rationalization 默认值是 1
                    hint_examples.append(example)   # 如果回答失败,加入 hint 合理化样本中
            main_examples = [] # 清空队列

        if args.rationalize and len(hint_examples) >= args.eval_batch_size: # 合理化
            cur_hint_examples = hint_examples[:args.eval_batch_size]
            cur_hint_examples = [                                           # hint 样本修改 key
                (hint_example_key + "_r", hint_example) for hint_example_key, hint_example in cur_hint_examples
            ]
            eval_batch(                                                     # 评估
                cur_hint_examples, few_shot_prompts_hint, hint_seq, tokenizer,
                args.gen_length, gen_params, accurate_count, target_save, hint=True, direct=direct  # 开启 hint 合理化
            )
            hint_examples = hint_examples[args.eval_batch_size:]            # 清空当前合理化的样本
        pbar.set_description(f"{split} " + ", ".join([
            f"{cur_key}: {get_score(cur_counts):0.4f}" for cur_key, cur_counts in accurate_count.items()
        ]))
    return accurate_count

eval_batch

def eval_batch(examples, few_shot_prompts, seq, tok, gen_length, gen_params, accuracy, target_save, hint=False, direct=False):
    batch = examples_to_batch(examples, few_shot_prompts, seq, tok, hint=hint, direct=direct, p_show_hint_save=args.p_show_hint_save)   # 把example批处理成合适的prompt
    output = network.generate(batch["padded_batch"], batch["lengths"], gen_length, gen_params)    # 实际上执行输出的代码
    return eval_output(                                                                           # 评估输出结果,记录回答正确的样本
        output, batch["answers"], batch["base_context"], batch["classes"], accuracy, target_save, tok, direct=direct
    )

def examples_to_batch(data_examples, few_shot_prompts, seq, tokenizer, hint=False, direct=False, p_show_hint_save=0.1):
    batch = {
        "base_context": [],
        "initial_batch": [],
        "lengths": [],
        "padded_batch": [],
        "answers": [],
        "classes": []                                   # 分类
    }
    for data_class, data_example in data_examples:
        batch['classes'].append(data_class)
        # Context, without the few-shot prompt
        hintless_base_context = question_to_context(data_example, hint=False, dataset_mode=args.dataset_mode, direct=direct)    # 不带 hint
        base_context = question_to_context(data_example, hint=hint, dataset_mode=args.dataset_mode, direct=direct)
        if args.dataset_mode == "arithmetic":
            few_shot_prompts = base_context.split("\n\n")[:-1]
            base_context = base_context.split("\n\n")[-1]
            hintless_base_context = hintless_base_context.split("\n\n")[-1]

        if random.random() < p_show_hint_save:  # 默认是 0
            hintless_base_context = base_context

        # We always want to act as if no hint was given
        if args.few_shot_train:
            if args.dataset_mode == "arithmetic":
                raise NotImplementedError
            else:
                save_context = "\n\n".join(commonsense_prompts) + "\n\n"
                save_context += hintless_base_context
                batch['base_context'].append(save_context)
        else:
            batch['base_context'].append(hintless_base_context)

        # Input tokens
        if args.no_prompt:
            context = ""
        else:
            context = "\n\n".join(few_shot_prompts) + "\n\n"            # 最终prompt部分 1:默认带 few-shot

        context += base_context                                         # 最终prompt部分 2:当前问题(可能带有合理化)
        tokens = tokenizer.encode(context)                              # tokenizer
        batch['initial_batch'].append(tokens)
        # Input lengths
        batch['lengths'].append(len(tokens))
        # Padded tokens
        provided_ctx = len(tokens)
        pad_amount = max(seq - provided_ctx, 0)                         # seq 是最大窗口长度,如果不够这个长度需要 pad
        if provided_ctx > seq:
            tokens = tokens[-seq:]                                      # 如果超出,需要截断
        batch['padded_batch'].append(np.pad(tokens, ((pad_amount, 0),)).astype(np.uint32))

        # Answer
        if args.dataset_mode == "arithmetic":
            if len(data_example.split("\n")) >= 3:
                target = data_example.split("\n")[-3]
            else:
                target = "invalid"
        elif args.dataset_mode == "cqa":
            target = data_example['answerKey']
        elif args.dataset_mode == "gsm":
            target = data_example['answer'].split("#### ")[-1]
        batch['answers'].append(target)
    batch["lengths"] = np.asarray(batch["lengths"], dtype=np.uint32)
    batch["padded_batch"] = np.array(batch["padded_batch"])
    return batch
def question_to_context(data_example, hint=False, dataset_mode='cqa', direct=False):
    """"
    将问题转为 prompt

    - hint: 是否开启合理化
    """
    if dataset_mode == 'cqa':
        context = f"Q: {data_example['question']['stem']}\nAnswer Choices:\n"
        for choice in data_example['question']['choices']:
            if hint and (choice['label'].lower() == data_example['answerKey'].lower()):
                context += f"({choice['label'].lower()}) {choice['text']} (CORRECT)\n"
            else:
                context += f"({choice['label'].lower()}) {choice['text']}\n"
        context += "A:"
    elif dataset_mode == 'gsm':
        context = f"Q: {data_example['question']}"
        if hint:
            chosen_hint = data_example['answer']                # gsm 竟然直接把答案作为 hint
            context += f" ({chosen_hint})"
        context += "\nA:"
    elif dataset_mode == "arithmetic":
        context = ""
        for example_split, next_example_split in zip(data_example.split('Target:')[:-1], data_example.split('Target:')[1:]):
            if direct and "</scratch>" in example_split:
                context += example_split.split("</scratch>")[-1]
            else:
                context += example_split
            context += "Target:"
            if hint:
                context += " " + next_example_split.split("\n")[-5]
    return context

eval_output

def eval_output(output, answers, context, example_classes, accuracy, target_save, tokenizer, show=False, direct=False, endoftext="<|endoftext|>"):
    """
    评估输出结果,统计准确率,并将成功的示例保存到指定文件中。

    参数:
    - output (list): 模型的输出结果。
    - answers (list): 正确答案列表。
    - context (list): 上下文列表。
    - example_classes (list): 示例类别列表。
    - accuracy (dict): 用于统计准确率的字典。
    - target_save (str): 成功示例保存的文件路径。
    - tokenizer (transformers.PreTrainedTokenizer): 用于处理文本的分词器。
    - show (bool, optional): 是否打印成功示例到控制台。默认为 False。
    - direct (bool, optional): 是否使用直接预测,跳过scratchpad。默认为 False。
    - endoftext (str, optional): 用于标记文本结束的字符串。默认为 "<|endoftext|>"。

    返回:
    - list: 成功示例的索引列表。
    """
    successful_examples = []
    enum_outputs = enumerate(output[1][0][:, :, 0])
    for (idx, o), target, cur_base_context, example_class in zip(enum_outputs, answers, context, example_classes):
        cur_output = tokenizer.decode(o)
        output_numbers = cur_output.split('\n')
        if example_class not in accuracy:
            accuracy[example_class] = {'accurate': 0, 'total': 0}
        accuracy[example_class]['total'] += 1
        if len(output_numbers) == 0:
            continue
        try:
            if args.dataset_mode == "cqa":
                output_numbers = output_numbers[0]
                if "<|endoftext|>" in output_numbers:
                    output_numbers = output_numbers.split("<|endoftext|>")[0]
                output_prediction = output_numbers[-3]                                  # 选项
            elif args.dataset_mode == "gsm":
                output_prediction = ""
                for line_idx, line in enumerate(output_numbers):
                    if "####" in line:
                        output_numbers = "\n".join(output_numbers[:line_idx + 1])
                        if "<|endoftext|>" in output_numbers:
                            output_numbers = output_numbers.split("<|endoftext|>")[0]
                        output_prediction = output_numbers.split("####")[-1].strip()
                        break
            elif args.dataset_mode == "arithmetic":
                if len(output_numbers) == 0:
                    continue
                elif "<|endoftext|>" in output_numbers:
                    prediction_index = output_numbers.index("<|endoftext|>") - 1
                elif "</scratch>" in output_numbers:
                    prediction_index = output_numbers.index("</scratch>") + 1
                    if prediction_index == len(output_numbers):
                        continue
                else:
                    if direct and len(output_numbers) > 1:
                        prediction_index = 1
                    else:
                        prediction_index = 0
                output_prediction = output_numbers[prediction_index]                      # 计算结果

            if "<|endoftext|>" in output_prediction:
                output_prediction = output_prediction.split("<|endoftext|>")[0]

            correct = output_prediction.lower() == target.lower()                         # 判断输出是否和目标一致
            if correct:
                accuracy[example_class]['accurate'] += 1                                  # 回答正确,计数++
                with basic_open(target_save, 'a+') as new_train_f:
                    if args.dataset_mode == "cqa" or args.dataset_mode == "gsm":
                        new_example = cur_base_context + output_numbers + endoftext       # 正确回答的样本作为新的训练样本
                    elif args.dataset_mode == "arithmetic":
                        if args.few_shot_train:
                            raise NotImplementedError
                        joined_output = "\n".join(output_numbers[:prediction_index + 1])
                        if "<|endoftext|>" in joined_output:
                            joined_output = joined_output.split("<|endoftext|>")[0]
                        new_example = cur_base_context + joined_output + endoftext       # 正确回答的样本作为新的训练样本
                    if show:
                        print(new_example)
                    print(new_example, file=new_train_f, end="")                         # 把回答正确的样本写入文件中
                successful_examples.append(idx)
        except IndexError:
            pass
    return successful_examples

合理化部分代码总结

结合代码以及论文解读[NeurlPS 2022] STaR: Self-Taught Reasoner Bootstrapping Reasoning With Reasoning 现在重新来理解论文。


论文基本思路是,先给出few-shot,让模型参考few-shot在回答answer前带上rationales,如果回答不正确,就加上hint回答,最终把回答正确的样本留下进行下一轮微调。

在具体代码实现上,首先在 eval_examples 中,对样本做了个 batch 级别的 cache,每满8个,才执行对应的推理(回答)。这里维护了两个cache 队列,一个是回答正确的队列,一个是直接回答失败的队列(因此,用合理化修改了原始prompt)。两个队列分别满8时分别执行重新的回答操作,具体是在 eval_batch 中实现。先通过 examples_to_batch 对 batch 样本批量处理prompt,比如加上few-shot template 等等(或者加上hint)。然后批量推理。然后通过eval_output评估是否回答正确。如果没有回答正确,那么加入hint的样本中。所有回答正确的样本都会保存作为下一次微调的数据集【注意,对于合理化的样本,保存的问题不带hint】。

所以,根据这个实现,再回答阅读论文中的问题:

注意:这个标里的细节。文字部分说“Note the final STaR model is trained on 78.2% of the training dataset with rationale generation, and an additional 8.5% from rationalization”,而表格里不带合理化的STaR准确率只有68.8%,这里78.2%和68.8%有个差值!这里要怎么理解:因为带有合理化后,fine tune,导致模型处理hard问题的能力提升,所以在之后的实验中,部分问题不需要合理化就可以解出,所以涨了近10个点。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2258552.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

欢迪迈手机商城设计与实现

文末获取源码和万字论文&#xff0c;制作不易&#xff0c;感谢点赞支持。 题目&#xff1a;欢迪迈手机商城设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管…

【鸿蒙实战开发】数据的下拉刷新与上拉加载

本章介绍 本章主要介绍 ArkUI 开发中最常用的场景下拉刷新, 上拉加载&#xff0c;在本章中介绍的内容在实际开发过程当中会高频的使用,所以同学们要牢记本章的内容。下面就让我们开始今天的讲解吧&#xff01; List 组件 在 ArkUI 中List容器组件也可以实现数据滚动的效果&a…

UnityShaderLab 实现程序化形状(一)

1.实现一个长宽可变的矩形&#xff1a; 代码&#xff1a; fixed4 frag (v2f i) : SV_Target{return saturate(length(saturate(abs(i.uv - 0.5)-0.13)))/0.03;} 2.实现一个半径可变的圆形&#xff1a; 代码&#xff1a; fixed4 frag (v2f i) : SV_Target{return (distance(a…

高阶数据结构--B树B+树实现原理B树模拟实现--Java

目录 一、B-树概念 二、B-树插入分析 1.用序列{53, 139, 75, 49, 145, 36, 101}构建B树的过程如下&#xff1a; 2.插入过程总结 三、B树插入实现 四、B树 1.B树概念 2.B树的特性 五、B树应用 1.索引 2.Mysql索引 3.InnoDB 一、B-树概念 1970 年&#xff0c; R.Bayer 和…

网络安全——防火墙

基本概念 防火墙是一个系统&#xff0c;通过过滤传输数据达到防止未经授权的网络传输侵入私有网络&#xff0c;阻止不必要流量的同时允许必要流量进入。防火墙旨在私有和共有网络间建立一道安全屏障&#xff0c;因为网上总有黑客和恶意攻击入侵私有网络来破坏&#xff0c;防火…

基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 多图推理

基于Qwen2-VL模型针对LaTeX OCR任务进行微调训练 - 多图推理 flyfish 基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_LoRA配置如何写 基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_单图推理 基于Qwen2-VL模型针对LaTeX_OCR任务进行微调训练_-_原模型_单图推理 基于Q…

Ant Design Pro实战--day01

下载nvm https://nvm.uihtm.com/nvm-1.1.12-setup.zip 下载node.js 16.16.0 //非此版本会报错 nvm install 16.16.0 安装Ant Design pro //安装脚手架 npm i ant-design/pro-cli -g //下载项目 pro create myapp //选择版本 simple 安装依赖 npm install 启动umi yarn add u…

一、为什么要学习麒麟?

麒麟认证&#xff1a;开启职业晋升之门 当前&#xff0c;就业难已经成为一个普遍的社会问题。许多大学生毕业后面临着找工作的困境&#xff0c;他们往往发现自己很难找到满意的职位。即使有幸找到了工作&#xff0c;也经常需要应对工作压力大、薪资低等问题。除此之外&#xff…

python如何减小维度

ravel&#xff08;&#xff09;&#xff1a;将多维数组拉平&#xff08;一维&#xff09;。 flatten&#xff08;&#xff09;&#xff1a;将多维数组拉平&#xff0c;并拷贝一份。 squeeze&#xff08;&#xff09;&#xff1a;除去多维数组中&#xff0c;维数为1的维度&…

未来已来:人工智能如何重塑我们的生活与工作

引言 未来的生活和工作场景正从想象走向现实。想象一下&#xff0c;一个清晨&#xff0c;语音助手已经为你安排好一天的任务&#xff0c;自动驾驶汽车准时送你上班&#xff0c;智能冰箱提醒你需要补充的食材。曾经只存在于科幻小说中的场景&#xff0c;如今正在我们的身边实现。…

Adminer源码编译 精简语言中英文和基本使用方法

Adminer是一个小而强悍的基于web的数据库管理工具&#xff0c; 官方默认支持几十种语言&#xff0c;但是对于中国的用户而言只需要有中文和英文就够了&#xff0c;其他语言基本无用。这就需要我们下载Adminer源码自己编译 Adminer.php , 如下图所示 adminer 中英文语言精简版本…

字符编码讲解(C#)

在学习和编码的过程中&#xff0c;极容易遇到如下概念&#xff0c;他们有些是字符编码&#xff0c;有些是涉及的相关概念&#xff0c;接下来我将围绕下面的熟悉又陌生的概念做详细解释&#xff0c;并且梳理其之间的关系 UTF8&#xff0c; Unicode &#xff0c;ASCII&#xff0…

Mac备忘录表格中换行(`Option` + `Return`(回车键))

在Mac的ARM架构设备上&#xff0c;如果你使用的是Apple的原生“备忘录”应用来创建表格&#xff0c;换行操作可以通过以下步骤来实现&#xff1a; 在单元格中换行&#xff1a; 双击你想要编辑的单元格你可以输入文本&#xff0c;按Option&#xff08;⌥&#xff09; Enter来插…

Windows中将springboot项目运行到docker的容器中

0&#xff0c;先打包好项目&#xff0c;再启动docker 1&#xff0c;在Java项目根目录下创建一个名为Dockerfile的文件&#xff08;没有扩展名&#xff09;&#xff0c;并添加以下内容。 # 使用OpenJDK的基础镜像 FROM openjdk:8-jdk-alpine# 设置工作目录 WORKDIR /app# 将项…

HBU深度学习实验14.5-循环神经网络(1.5)

梯度爆炸实验 造成简单循环网络较难建模长程依赖问题的原因有两个&#xff1a;梯度爆炸和梯度消失。一般来讲&#xff0c;循环网络的梯度爆炸问题比较容易解决&#xff0c;一般通过权重衰减或梯度截断可以较好地来避免&#xff1b;对于梯度消失问题&#xff0c;更加有效的方式…

ZED相机应用

下载SDK wget https://stereolabs.sfo2.cdn.digitaloceanspaces.com/zedsdk/3.6/ZED_SDK_Ubuntu18_cuda11.5_v3.6.5.run 安装 ./ZED_SDK_Ubuntu18_cuda11.5_v3.6.5.run skip_python 测试 cd /usr/local/zed/tools ls ZED_Calibration ZED_Depth_Viewer ZED_Diagnostic ZED_E…

伟测科技再融资11.75亿:增收不增利,毛利率近年来持续下滑

《港湾商业观察》施子夫 王璐 12月9日&#xff0c;上海证券交易所上市审核委员会召开2024年第34次上市审核委员会审议会议&#xff0c;审议上海伟测半导体科技股份有限公司(再融资)&#xff08;以下简称&#xff0c;伟测科技&#xff1b;688372.SH&#xff09;事项。 今年8月…

Java爬虫设计:淘宝商品详情接口数据获取

1. 概述 淘宝商品详情接口&#xff08;如Taobao.item_get&#xff09;允许开发者通过编程方式&#xff0c;以JSON格式实时获取淘宝商品的详细信息&#xff0c;包括商品标题、价格、销量等。本文档将介绍如何设计一个Java爬虫来获取这些数据。 2. 准备工作 在开始之前&#x…

如何绕过IP禁令

网站、游戏和应用程序可以屏蔽特定IP地址&#xff0c;从而阻止使用该IP地址的任何人访问其服务。这称为IP禁令。管理员可以出于多种原因&#xff08;例如发出过多请求或可疑活动&#xff09;屏蔽IP地址。但是&#xff0c;这些禁令会使收集数据或访问在线内容变得更加困难。 一…

AI生成不了复杂前端页面?也许有解决方案了

在2024年&#xff0c;编程成为了人工智能领域最热门的赛道。AI编程技术正以惊人的速度进步&#xff0c;但在生成前端页面方面&#xff0c;AI的能力还是饱受质疑。自从ScriptEcho平台上线以来&#xff0c;我们收到了不少用户的反馈&#xff0c;他们表示&#xff1a;“生成的页面…