使用DPO微调大模型Qwen2详解

news2024/12/26 9:21:23

简介

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。但传统的RLHF比较复杂,且还需要奖励模型,故DPO方法被提出,其将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。
且huggingface的trl库已经集成了dpo,使用起来非常方便。

本次以QWEN2(蹭热点),为例进行训练,分别介绍单轮对话的DPO多轮对话的DPO,对应的数据集分别如下(均在huggingface):

  • 单轮:lvwerra/stack-exchange-paired
  • 多轮:trl-internal-testing/hh-rlhf-helpful-base-trl-style

通过DPO微调模型大概可以简单的分为两个步骤:
1、将数据处理成所需格式。
2、使用DPOTrainer进行训练

两种形式的dpo代码已集成至github上的大模型训练框架,并做了详细的使用解释及代码位置说明,可见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

项目包括一个每个人都可以以此为基础构建自己的开源大模型训练框架流程、支持主流模型使用deepspeed进行Lora、Qlora、DPO等训练、主流模型的chat template模版、以及一些tricks的从零实现模块。欢迎大家star 共同学习!:

单轮对话构建DpoDataset

标准的DpoDataset数据集,最终的数据集对象应包含这3个条目。条目应命名为:

  • prompt
  • chosen
  • rejected

官方示例

单轮官方示例如下:

dpo_dataset_dict = {
    "prompt": [
        "hello",
        "how are you",
        "What is your name?",
        "What is your name?",
        "Which is the best programming language?",
        "Which is the best programming language?",
        "Which is the best programming language?",
    ],
    "chosen": [
        "hi nice to meet you",
        "I am fine",
        "My name is Mary",
        "My name is Mary",
        "Python",
        "Python",
        "Java",
    ],
    "rejected": [
        "leave me alone",
        "I am not fine",
        "Whats it to you?",
        "I dont have a name",
        "Javascript",
        "C++",
        "C++",
    ],
}

多轮示例为上述提到的数据集,大家可以大概看一下是长这个样子:
在这里插入图片描述

从头开始构建

比较简单的方式是套用官方给的示例,如下所示,只需要将数据集映射为上述我们提到的prompt、chosen、rejected格式,此时传递给DPOTrainer的数据是未编码之前的,DPOTrainer中会自动的给我们进行编码。注意下面并没有添加对应模型的chat template,根据不同模型的template可以在return_prompt_and_responses中自行添加即可。

def return_prompt_and_responses(samples) -> Dict[str, str, str]:
    return {
        "prompt": [
            "Question: " + question + "\n\nAnswer: "
            for question in samples["question"]
        ],
        "chosen": samples["response_j"], # rated better than k
        "rejected": samples["response_k"], # rated worse than j
    }

dataset = load_dataset(
    "lvwerra/stack-exchange-paired",
    split="train",
    data_dir="data/rl"
)
original_columns = dataset.column_names

dataset.map(
    return_prompt_and_responses,
    batched=True,
    remove_columns=original_columns
)


dpo_trainer = DPOTrainer(
    model, # 经 SFT 的基础模型
    model_ref, # 一般为经 SFT 的基础模型的一个拷贝
    beta=0.1, # DPO 的温度超参
    train_dataset=dataset, # 上文准备好的数据集
    tokenizer=tokenizer, # 分词器
    args=training_args, # 训练参数,如: batch size, 学习率等
)

为了便于我们理解数据处理细节及进行一些魔改操作,我们可以从头自己构建一个DpoDataset。
首先,深入DPOTrainer源码可以看到其数据处理操作主要是在tokenize_row函数,如下所示,
在这里插入图片描述
最终返回的是一个batch字典字段,代码部分如下所示:
在这里插入图片描述
在这里插入图片描述
最终返回的字段为:

dict(
            prompt_input_ids,
            prompt_attention_mask,
            chosen_input_ids,
            chosen_attention_mask,
            chosen_labels,
            rejected_input_ids,
            rejected_attention_mask,
            rejected_labels,
        )

主要的__getitem__代码如下所示:

    def __getitem__(self, item):
        data = self.data_list[item]
        data = json.loads(data)  # 将json格式转换为python字典
        prompt =  data['prompt']
        chosen = data['chosen']
        rejected = data['rejected']
        # 对prompt进行编码
        prompt = self.user_format.format(content=prompt, stop_token=self.tokenizer.eos_token)
        if self.system_format is not None:
            system = self.system
            if system is not None:
                system_text = self.system_format.format(content=system)
                input_ids = self.tokenizer.encode(system_text, add_special_tokens=False)
                prompt_input_ids = input_ids + self.tokenizer.encode(prompt, add_special_tokens=False)
        else:
            prompt_input_ids = self.tokenizer.encode(prompt, add_special_tokens=False)



        # 进行回答的input id编码
        chosen = self.assistant_format.format(content=chosen, stop_token=self.tokenizer.eos_token)
        rejected = self.assistant_format.format(content=rejected, stop_token=self.tokenizer.eos_token)

        chosen_input_ids = self.tokenizer.encode(chosen, add_special_tokens=False)
        rejected_input_ids = self.tokenizer.encode(rejected, add_special_tokens=False)

        # 对最大长度进行截断
        longer_response_length = max(len(chosen_input_ids), len(rejected_input_ids))
        # keep end 对prompt截断
        if len(prompt_input_ids) + longer_response_length > self.max_seq_length:
            max_prompt_length = max(self.max_prompt_length, self.max_seq_length - longer_response_length)
            prompt_input_ids = prompt_input_ids[-max_prompt_length:]
        # 如果还不符合则回答截断
        if len(prompt_input_ids) + longer_response_length > self.max_seq_length:
            chosen_input_ids = chosen_input_ids[: self.max_seq_length - len(prompt_input_ids)]
            rejected_input_ids = rejected_input_ids[: self.max_seq_length - len(prompt_input_ids)]

        chosen_labels = [-100] * len(prompt_input_ids) + chosen_input_ids
        chosen_input_ids = prompt_input_ids + chosen_input_ids
        rejected_labels = [-100] * len(prompt_input_ids) + rejected_input_ids
        rejected_input_ids = prompt_input_ids + rejected_input_ids
        assert len(chosen_labels) == len(chosen_input_ids)
        assert len(rejected_labels) == len(rejected_input_ids)

        inputs = dict(
            prompt_input_ids=prompt_input_ids,
            prompt_attention_mask=[1] * len(prompt_input_ids),
            chosen_input_ids=chosen_input_ids,
            chosen_attention_mask=[1] * len(chosen_input_ids),
            chosen_labels=chosen_labels,
            rejected_input_ids=rejected_input_ids,
            rejected_attention_mask=[1] * len(rejected_input_ids),
            rejected_labels=rejected_labels,
        )
        return inputs

适配DPOTrainer

构建完dataset后要适配DPOTrainer,可以看到其需要使用dataset进行一个map操作,这也就是DPOTrainer自动给我们处理数据的入口。
在这里插入图片描述
在我们自建的Dataset类中添加一个map函数映射会self即可:

    def map(self, func, **kwargs):
        return self

多轮对话构建DpoDataset

多轮对话构建我们这里就不自己去写了,直接采用DPOTrainer中自带的数据处理即可。
部分代码如下所示:

        if tokenizer.chat_template is None:
            tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
        train_dataset = load_dataset(data_files=args.train_data_path, path='json')

        def process(row):
            row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
            row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
            return row

        train_dataset = train_dataset.map(process)
        train_dataset = train_dataset['train']
        return train_dataset

完整代码集成至github项目中,具体可参见:

开始Qwen2-8B 多轮和单轮DPO训练

使用DPOTrainer即可开始训练

trainer = DPOTrainer(
            model,
            ref_model=None,
            args=train_args,
            train_dataset=train_dataset,
            tokenizer=tokenizer,
            peft_config=peft_config
        )
dpo_trainer.train()
dpo_trainer.save_model()

总结

两种形式的dpo代码已集成至github上的大模型训练框架,并做了详细的使用解释及代码位置说明,可见:https://github.com/mst272/LLM-Dojo/tree/main/train_args/dpo

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1808213.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【教学类-64-02】20240610色块眼力挑战(二)-2-25宫格色差10-100(10倍)(星火讯飞)

背景需求 以下的色块眼里挑战需要人工筛选图片,非常繁琐。 【教学类-64-01】20240607色块眼力挑战(一)-0-255随机底色-CSDN博客文章浏览阅读446次,点赞12次,收藏5次。【教学类-64-01】20240607色块眼力挑战&#xff…

web入门(1)---6.10

总结: 多做一点NSSCTF的新手赛,了解基本题型,然后打牢基础知识 谢队讲解 攻防世界 Web入门题 讲解_哔哩哔哩_bilibili 题目来源:攻防世界新手区 1.view_source 查看源代码 2.get_post 收获: get方法是直接在url…

攻防世界---misc---BotW-

1、下载附件是一张图片 2、查看图片属性,用winhex分析,没有发现奇怪的地方,用binwalk,接着使用foremost 3、得到两张图片,一张是原图,一张是特殊的字符 4、经过查阅资料得知,这是希卡文字&#…

数据中心基础设施智能运维

数据中心基础设施智能运维 随着科技的飞速发展,数据中心作为信息社会的核心基础设施,扮演着越来越重要的角色。然而,传统的运维模式由于对人力资源的高度依赖,已无法满足现代数据中心对高效、安全和可持续运维的要求。华为的《数…

IO流(转换流)

InputStreamReader(字符输入转换流 ) 解决不同编码时,字符流读取文本内容乱码的问题 public static void main(String[] args) {try (//1.得到文件的原始字节流(GBK的字节流形式)FileInputStream is new FileInputStream("src/666.tx…

Objective-C的初始化方法中,应该如何读写属性

除非有明确的原因需要使用setter, getter, 否则总是应该直接访问, 也就是直接使用实例变量(也称为 iVar)来读写数据 理由: 避免子类覆盖setter方法的影响:若在初始化方法中使用setter方法, 使用此方法实例化子类, 可能会调用子类…

23.汽水兑奖

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/106 题目描…

【Golang】Map 稳定有序遍历的实现与探索:保序遍历之道

【Golang】Map 稳定有序遍历的实现与探索:保序遍历之道 大家好 我是寸铁👊 总结了一篇【Golang】Map 稳定有序遍历的实现与探索:保序遍历之道✨ 喜欢的小伙伴可以点点关注 💝 前言🍎 在计算机科学中,数据结…

从零开始搭建Electron项目之运行例程

最好的学习方式就是:给一段能够运行的代码示例。 本文给出了例程资源,以及运行的步骤。 在国内开发electron有一点特别不好,就是如果不爬梯子,下载依赖容易出错。 一、例程资源 到如下路径下载例程到本地。 GitCode - 全球开发者…

新技术前沿-2023-大模型的本质

大模型时代需要什么样的人才? 1 大模型的本质 特斯拉前AI总监Andrej Karpathy的新教程,涵盖模型推理、训练、微调和新兴大模型操作系统以及安全挑战。 1.1 大模型本质就是两个文件 首先,大模型是什么? 大模型本质就是两个文件…

转型AI产品经理(7):“格式塔原则”如何应用在Chatbot产品中

格式塔原则,又称为完形原则,它是一组关于人类如何感知视觉元素的心理学理论,这些原则说明了大脑如何将分散的视觉元素整合为有意义的整体,即使这些元素本身可能是分离的,帮助我们理解人们如何组织和解释复杂的视觉信息…

C++网络编程基础

文章目录 协议局域网通信IP 地址网络通信的本质tcp 和 udp 协议网络字节序网络主机数据转化接口 协议 协议:收到数据后,多出来的那一部分,也叫一种 “约定”,一整套的自硬件到软件,都有协议,需要有人定制&a…

KUKA机器人KRC5控制柜面板LED显示

对于KUKA机器人新系列控制柜KRC5控制柜来说,其控制柜面板LED布局如下图: 其中①②③④分别为: 1、机器人控制柜处于不同状态时,LED显示如下: 2、机器人控制柜正在运行时: 3、机器人控制柜运行时出现的故障…

金融数据中心能力建设指引

金融数据中心能力建设指引 金融数据中心能力建设指引旨在通过高标准的基础设施建设、完善的数据管理、强大的信息安全防护和业务连续性规划,确保数据中心具备高效、安全、可靠的运行能力,支持金融业务的稳定发展。该指引强调技术创新、标准化管理、人才…

迅为RK3562开发板ARM四核A53核心板瑞芯微国产人工智能Linux安卓

iTOP-3562开发板采用瑞芯微RK3562处理器,内部集成了四核A53Mali G52架构,主频2GHZ,内置1TOPSNPU算力,RK809动态调频。支持OpenGLES1.1/2.0/3.2、0penCL2.0、Vulkan 1.1内嵌高性能2D加速硬件。 内置独立NPU, 算力达 1TOPS,可用于轻…

搭建RocketMQ主从异步集群

搭建RocketMQ主从异步集群 1、RocketMQ集群模式 为了追求更好的性能,RocketMQ的最佳实践方式都是在集群模式下完成的。RocketMQ官方提供了三种集群搭建方式: 2主2从异步通信方式:使用异步方式进行主从之间的数据复制。吞吐量大,…

不同数据库背后的数据存储方案

在大数据和AI时代,数据库成为各类应用不可或缺的重要组成部分。而数据库中的数据依赖存储引擎进行管理,包括数据的存储、查询、更新和删除等。因此,在设计系统时,选择正确的数据库存储引擎方案变得尤为重要。这篇文章将以关系型、…

Switch 之 H3C S5500

System # system view <H3C> system‐view [H3C] quit <H3C># display version [H3C]display version H3C Comware Software, Version 7.1.045, Release 3116# configuration save <H3C> save <H3C> display current‐configuration # factory reset …

程序员副业大揭秘:如何独立开发产品,月入过万,实现工作自由

本文首发于公众号“AntDream”&#xff0c;欢迎微信搜索“AntDream”或扫描文章底部二维码关注&#xff0c;和我一起每天进步一点点 咱们程序员这一行&#xff0c;在外人眼里一直都是高收入的代名词&#xff0c;每每提及&#xff0c;都有羡慕的眼光。然后只有我们自己知道&…