使用KTO进行更好、更便宜、更快速的LLM对齐

news2024/11/14 21:22:05

KTO全称为Kahneman-Tversky Optimisation,这种对齐方法使在我们的数据上对大型语言模型(LLM)进行对齐变得前所未有地容易和便宜,而且不会损害性能。大型语言模型的成功在很大程度上得益于与人类反馈的对齐。如果ChatGPT曾经拒绝回答您的问题,很可能是因为它被训练为避免说出有争议的内容。然而,对于公司来说,对他们自己的LLM进行对齐一直是困难的。下面我们简单介绍下KTO方法,这种方法可以提高LLM的整体性能和质量,同时节省成本。

大规模对齐LLM

LLM对齐对于优化性能至关重要,但一直以来都很困难,因为:

标准的对齐方法,即带有人类反馈的强化学习(RLHF),有许多复杂的部分,很多开源项目已经努力使其工作。
对齐方法期望以偏好的形式获得反馈(例如,对于输入X,输出A比B更好)。利用人类注释工作的这种反馈很快就会变得非常昂贵,并且也可能导致数据冲突。人类自己的评分主观性强,因此需要大量努力来定义输出A如何定量优于输出B。

这两个因素意味着,对于大多数组织来说,自己的LLM大规模对齐历史上是不可能的。但这一差距正在缩小。斯坦福研究人员最近用一种称为直接偏好优化(DPO)的技术解决了第一个问题,这在数学上等同于RLHF,同时更加简单,使得对齐对于开源努力变得可行。

剩下的瓶颈是数据。只有少数几个包含文本上人类偏好的公共数据集,而且它们是通用的。例如,如果你想要人类对两种LLM输出更准确地判断意大利经济状况的反馈,你需要咨询专业人士。但获取此类数据很昂贵,无论您是直接付费还是要求员工花费宝贵的时间提供反馈。

克服数据瓶颈

在Contextual AI项目中,作者已经找到了克服这一数据瓶颈的方法。通过研究经济学家Kahneman和Tversky关于人类决策的工作,设计了一种不需要像“输入X的输出A胜过输出B”这样的偏好的对齐方法。相反,对于输入X,我们只需要知道输出Y是可取的还是不可取的。这种单一反馈是丰富的:每个公司都有可以标记为可取(例如,销售成功)或不可取(例如,没有销售)的客户互动数据。

通过在三个公共数据集(Anthropic HH、Stanford Human Preferences 和 Open Assistant)的组合上对齐从 1B 到 30B 的模型,将 KTO 与现有方法进行比较。然后,遵循现在的标准做法,使用 GPT-4 将对齐模型的各代与数据集中提供的人类首选基线进行比较。


与其他对齐模型相比,Kahneman-Tversky 优化在性能上大幅提升,无论是标准微调还是 DPO

更多原理和细节请参考:

https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf

代码

https://github.com/ContextualAI/HALOs

class SimpleKTOTrainer(UnpairedPreferenceTrainer):
   """A simple version of KTO meant to introduce you to the HALOs repo."""
   def loss(self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
   """Compute the Kahneman-Tversky loss for a batch of policy and reference model log probabilities. 
   For each batch of n/2 chosen examples and n/2 rejected examples (belonging to n different inputs), calculate the loss as follows.

   If generation y ~ p_chosen, where x' ~ are the examples with rejected generations, we have the 'chosen' loss:
       L(x, y) := 1 - sigmoid(beta * (log p_policy(y|x) - log p_reference(y|x) - KL(p_policy(y_rejected|x') || p_reference(y_rejected|x')))
   If generation y ~ p_rejected, , where x' ~ are the examples with chosen generations, we have the 'rejected' loss:
       L(x, y) := 1 - sigmoid(beta * KL(p_policy(y_chosen|x') || p_reference(y_chosen|x')) - [log p_policy(y|x) - log p_reference(y|x)])
   """
   chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
   rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

   chosen_logratios = (policy_chosen_logps - reference_chosen_logps)
   rejected_logratios = (policy_rejected_logps - reference_rejected_logps)

   losses = torch.cat((1 - F.sigmoid(self.config.loss.beta * (chosen_logratios - rejected_KL)), 1 - F.sigmoid(self.config.loss.beta * (chosen_KL - rejected_logratios))), 0)

   chosen_rewards = self.config.loss.beta * (policy_chosen_logps - reference_chosen_logps).detach()
   rejected_rewards = self.config.loss.beta * (policy_rejected_logps - reference_rejected_logps).detach()

   return losses, chosen_rewards, rejected_rewards

另外在最近有人对IPO/DPO/KTO性能做了对比:

对于 Zephyr 模型,我们观察到最好的性能是以最低的成本实现的。这在所有三种测试的算法中都是一致的,社区的一个有趣的后续实验是在 0.0-0.2 范围内进行细粒度扫描。虽然 DPO 可以获得最高的 MT Bench 分数,但我们发现 KTO(配对)在除一种设置之外的所有设置中都取得了更好的结果。IPO虽然有更强的理论保证,但除了一种情况外,在所有情况下似乎都比基本模型更糟糕。

https://huggingface.co/blog/pref-tuning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1397607.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

禅道使用教程

禅道的使用 一.禅道的使用1.添加部门和批量添加用户2.以产品经理的身份登录进行使用和操作2.1创建产品2.2创建模块2.3添加产品计划2.4添加产品需求2.5创建项目2.6设置团队 3.项目经理使用禅道3.1关联需求3.2批量分解,给人员分配任务3.3假设项目完成开发,项目经理创建版本 4.测试…

《数字图像处理-OpenCV/Python》连载:傅里叶变换与频域滤波

《数字图像处理-OpenCV/Python》连载:空间滤波之高斯滤波器 本书京东 优惠购书链接 https://item.jd.com/14098452.html 本书CSDN 独家连载专栏 https://blog.csdn.net/youcans/category_12418787.html 第 11 章 傅里叶变换与频域滤波 空间图像滤波是图像与滤波器核…

根据基因名批量查找它的Uniprot编号

背景: 前几天老师交给我一个任务,给我一个基因列表,让我查找它们所编码的蛋白质的蛋白质序列。我上了一下uniprot数据库,发现这个任务可以分成两步: 找到这个基因在Uniprot数据库中所对应的蛋白质编码根据蛋白质编码…

街机模拟游戏逆向工程(HACKROM)教程:[12]68K汇编-程序流控制

在之前的文章中,我们测试过一些简短的一小段程序,这些程序都有一个共同的程序运行流程,就是一句一句地向下执行,比如: movea.l #$325, a0 * ↓move.b #$01, (a0) * ↓move.b #$02, $01(a…

【软件测试常见Bug清单】

软件测试中,bug的类型有很多种,比如:代码错误、界面优化、设计缺陷、需求补充和用户体验等; 一般情况下,需求补充和设计缺陷比较好区分,但是代码错误、界面优化和用户体验区分不是很明显; 下面…

主动轮廓——计算机视觉中的图像分割方法

​ 一、说明 简单来说,计算机视觉就是为计算机提供类似人类的视觉。作为人类,我们很容易识别任何物体。我们可以很容易地识别山丘、树木、土地、动物等,但计算机没有眼睛,也没有大脑,因此它很难识别任何图像。计算机只…

PostgreSQL 的对象层次

所有的数据库离开数据量来谈性能都是耍流氓。 就你那几万条的数据库,用啥都行,典型的就是怎么方便怎么来。 不过 PostgreSQL 上手确实比 MySQL 概念更多。 PostgreSQL 比 MySQL 多了一层。 PostgreSQL 是从PostgreSQL 是从 Database,到 S…

RK3568平台 LT9211转接芯片调试笔记

一.简介 龙讯LT9211是一个高性能转换器,支持MIPI LVDS TTL两两之间转换。 使用此款芯片大部分为MIPI与LVDS进行互相转换。 下图为LT9211的典型应用图: 二.LT9211原理图 三.车载显示器和摄像头系统 四.调试LT9211输出 MIPI数据 (1&#xf…

【Linux install】Ubuntu和win双系统安装及可能遇到的所有问题

文章目录 1.前期准备1.1关闭快速启动和安全启动1.1.1 shell命令行进入BIOS1.1.2 windows设置中高级启动1.1.3 在开机时狂按某个键进入BIOS1.1.4 关闭Fast boot和Secure boot 1.2 制作启动盘1.3 划分磁盘空间1.3.1 查看目前的虚拟内存大小 2.开始安装2.1 使用启动盘启动2.1.1 法…

洛谷NOIP2002 普及组 选数 +NOIP1999普及组 回文数

两道日常的练习题&#xff0c;废话不多说&#xff0c;直接上题上代码&#xff1a; 这道题目的难点在于怎样去根据一个不同的k值&#xff0c;通过代码来实现将所有符合题目要求的数字相加并且不重复的功能。下面请看代码&#xff0c;会有详细的讲解&#xff1a; #include<io…

又聊代码重构

今天有幸和一位朋友聊了一下代码的重构。回来之后感觉不够尽兴&#xff0c;所以决定再来输出一篇。 代码来至于今天下午的提交。 重构是对代码的觉知和业务的逻辑的进一步归纳总结 只有开发者对代码的不断觉察和理解&#xff0c;才会产生重构代码的念头。因此&#xff0c;驱动…

GO 中如何防止 goroutine 泄露

文章目录 概述如何监控泄露一个简单的例子泄露情况分类chanel 引起的泄露发送不接收接收不发送nil channel真实的场景 传统同步机制MutexWaitGroup 总结参考资料 今天来简单谈谈&#xff0c;Go 如何防止 goroutine 泄露。 概述 Go 的并发模型与其他语言不同&#xff0c;虽说它…

小白水平理解面试经典题目LeetCode 121 Best Time to Buy and Sell Stock

121 Best Time to Buy and Sell Stock (买卖股票的最佳时机) 你好&#xff0c;2024年的第一个月&#xff0c;又是秋风萧瑟天气凉&#xff0c;草木摇落露为霜。.。。在这个特殊的时代&#xff0c;作为我们普通的一个打工人&#xff0c;我们用这道题&#xff0c;开启对这个不符合…

菜鸟关于做前、后端的整理(html、js),以及疑问

涉及到后端的接口py&#xff0c;前端html和js 这三部分就按照如下格式放到server项目主路径下&#xff0c;这样后端机可以作为一个前端server main.pystaticmain.jsmain.htmlhtml 首先是html要设定网页的显示 <!DOCTYPE html> <html> <head><title>…

小米,我请你不要将卖手机那套话术带进汽车圈

文 | AUTO芯球 ​作者 | 雷歌 当你们用卖手机时那一套营销话术玩汽车&#xff0c;整个汽车圈都被你们逗乐了。 这不&#xff0c;在被用户问到“贵公司汽车有哪些驾驶模式”时&#xff0c;你们声称自己有16.8亿种驾驶模式。 你小米说这话的逻辑&#xff0c;不就是将加速、转…

网络安全最大的威胁:洞察数字时代的风险之巅

在数字化时代&#xff0c;网络安全问题越发突显&#xff0c;企业和个人都面临着来自多方面的威胁。究竟网络安全领域的最大威胁是什么&#xff1f;本文将深入探讨这一问题&#xff0c;揭示数字空间中最为严重的威胁。 1. 恶意软件的肆虐&#xff1a; 恶意软件一直是网络安全的…

29、WEB攻防——通用漏洞SQL注入增删改查盲注延迟布尔报错

文章目录 盲注增删改查 盲注 概念&#xff1a;在注入过程中&#xff0c;获取的数据不能回显至前端页面&#xff0c;此时我们需要利用一些方法进行判断或尝试&#xff0c;这个过程被称为盲注。 解决&#xff1a;常规的联合查询注入不行的情况。 分类&#xff1a; 基于布尔的SQ…

Leetcode2957. 消除相邻近似相等字符

Every day a Leetcode 题目来源&#xff1a;2957. 消除相邻近似相等字符 解法1&#xff1a;遍历 分类讨论 遍历字符串 word&#xff0c;比较相邻的 3 个元素 word[i - 1]、word[i] 和 word[i 1]&#xff0c;记 left_distance abs(mid - left)&#xff0c;right_distance…

739.每日温度 496.下一个更大元素 I

739.每日温度 496.下一个更大元素 I 739.每日温度 力扣题目链接(opens new window) 请根据每日 气温 列表&#xff0c;重新生成一个列表。对应位置的输出为&#xff1a;要想观测到更高的气温&#xff0c;至少需要等待的天数。如果气温在这之后都不会升高&#xff0c;请在该位…

(初研) Sentence-embedding fine-tune notebook

由于工作需要&#xff0c;需要对embedding模型进行微调&#xff0c;我调用了几种方案&#xff0c;都比较繁琐。先记录一个相对简单的方案。以下内容并不一定正确&#xff0c;请刷到的大佬给予指正&#xff0c;不胜感激&#xff01;&#xff01;&#xff01; 一.对BGE模型&…