【大模型学习】多模态大模型进行偏好优化

news2024/9/21 0:49:17

一、简介

训练模型以理解并预测人类偏好是一项复杂的任务。传统方法如SFT(监督微调)通常需要较高的成本,因为这些算法需要对数据进行特定标签的标注。偏好优化(Preference Optimization)作为一种替代方案,可以简化这一过程并提供更准确的结果。通过对候选回答的对比和排序,而不是赋予固定的标签,偏好优化能够更高效地捕捉人类偏好的细微差别。

虽然偏好优化已经在大语言模型中广泛使用,但现在它也可以应用于视觉语言模型(VLM)。得益于TRL(Transformer Reinforcement Learning)的开发,现在我们可以使用TRL对VLM进行直接偏好优化(Direct Preference Optimization)。本文将介绍使用TRL和DPO对视觉语言模型进行训练的全过程。

二、偏好数据集

进行偏好优化,首先需要有一个能体现用户偏好的数据集。在双项选择的设定下,相应的数据一般包含一个提示词(Prompt)和两个候选回答,其中一个被标记为选中(chosen),另一个被标记为淘汰(rejected)。模型需要学习选择正确的回答,而不是被淘汰的回答。下图展示了一个例子:

❔ 问题: 有多少个家庭?

  • ❌ 被淘汰的回答: 图片没有提供关于家庭的信息。
  • ✅ 选中的回答: 图片显示了一个工会组织的表格,包含18000个家庭。

尽管选中的回答也不是完全正确(应该是18000000个家庭),但比被淘汰的回答更好。

本文将使用openbmb/RLAIF-V-Dataset作为示例数据集,该数据集包含超过83000条标注数据。可以通过以下代码查看数据集:

from datasets import load_dataset

dataset = load_dataset("openbmb/RLAIF-V-Dataset", split="train[:1%]")
sample = dataset[1]
sample["image"].show()
sample["question"]
'how many families?'
sample["rejected"]
'The image does not provide any information about families.'
sample["chosen"]
'The image shows a Union Organization table setup with 18,000 families.'

我们将要训练的 VLM 模型需要文本和图像同时作为输入,所以这里的第一步还是要对数据集格式进行改造。一条数据应该被结构化成能模拟人机对话的形式。用户提供一个提示语,其中包含一张图片和一个问题,然后模型需要能够给出一个回答。我们用以下代码实现格式转换:

from datasets import features
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b", do_image_splitting=False)

def format(example):
    # Prepare the input for the chat template
    prompt = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": example["question"]}],
        },
    ]
    chosen = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["chosen"]}],
        },
    ]
    rejected = [
        {
            "role": "assistant",
            "content": [{"type": "text", "text": example["rejected"]}],
        },
    ]
    # Apply the chat template
    prompt = processor.apply_chat_template(prompt, tokenize=False)
    chosen = processor.apply_chat_template(chosen, tokenize=False)
    rejected = processor.apply_chat_template(rejected, tokenize=False)
    # Resize the image to ensure it fits within the maximum allowable
    # size of the processor to prevent OOM errors.
    max_size = processor.image_processor.size["longest_edge"]
    example["image"].thumbnail((max_size, max_size))
    return {"images": [example["image"]], "prompt": prompt, "chosen": chosen, "rejected": rejected}

# Apply the formatting function to the dataset,
# remove columns to end up with only "images", "prompt", "chosen", "rejected" columns
dataset = dataset.map(format, remove_columns=dataset.column_names)

# Make sure that the images are decoded, it prevents from storing bytes.
# More info here https://github.com/huggingface/blog/pull/2148#discussion_r1667400478
f = dataset.features
f["images"] = features.Sequence(features.Image(decode=True)) # to avoid bytes
dataset = dataset.cast(f)

完成了格式转换,我们来看看第一条数据:

>>> dataset[1]
{'images': [<PIL.JpegImagePlugin.JpegImageFile image mode=L size=980x812 at 0x154505570>],
 'prompt': 'User:<image>how many families?<end_of_utterance>\n',
 'rejected': 'Assistant: The image does not provide any information about families.<end_of_utterance>\n',
 'chosen': 'Assistant: The image shows a Union Organization table setup with 18,000 families.<end_of_utterance>\n'}

三、训练

3.1 训练需要多大的 GPU 显存?

以微调1B的模型为例子,假设模型的的每个参数用32bit存储,32bit=4byte。

每个参数通常以浮点数形式存储。FP32(32位浮点数)每个参数占用4字节的存储空间,而BF16(16位浮点数)每个参数占用2字节的存储空间。

需要用到GPU的部分:模型权重(需要加载进去)、梯度(更新参数)、优化器(状态量,SGD和Adam占用的显存空间不一样)、激活值等等

1 Byte = 1 \times 10^{-9} GB

  • 模型权重1B = 1b x 4 byte = 4GB;
  • 梯度的显存需求与模型权重相同 4GB;
  • 以Adam优化器(LLM用的多)为例,Adam需要维护模型的参数、每个参数的动量和平方梯度信息,因此占用的显存大约是模型权重的3倍 [一阶动量估计(类似动量)、二阶动量估计(平方梯度)];

注意,优化器都是用FP32进行存储的,因为大量的小值累加(sum、mean)操作,如果用FP16进行会损失精度,太小的值用FP16会表示为0。

  • 激活值(中间结果),反向传播和前向传播会用到,这边只是简单起见,bs=1,和模型参数一样是4GB,实际上这个计算推导很复杂,后面有机会再写~,同时Transformer中激活值和序列长度以平方次数增长;
  • 输入数据:跟Batch size、样本I大小有关系,就是B x I x 4 字节,这边暂时忽略;
参数来源计算公式显存需求
要训练的模型8 \times 10^9 \times 432 GB
参考模型(这个任务额外要的,防止模型发生偏移,和要训练的模型一样大)8 \times 10^9 \times 432 GB
梯度8 \times 10^9 \times 432 GB
优化器状态量3 \times 8 \times 10^9 \times 472 GB
合计168 GB

可以使用量化、LoRA 等技术来大幅度地减少显存需求,让训练可以进行。

3.2 使用 bfloat16 和 LoRA 后的显存需求

参数来源计算公式显存需求
要训练的模型8 \mathrm{G} \times 216 GB
参考模型8 \mathrm{G} \times 216 GB
梯度55 \mathrm{M} \times 20.1 GB
优化器状态量3 \times 55 \mathrm{M} \times 20.3 GB
合计32.4 GB

四、微调Llava 1.5和PaliGemma等模型

TRL的DPO实现已支持Idefics2、Llava 1.5和PaliGemma,同时TRL也在努力支持更多的模型。最简单的调用方法是使用TRL提供的示例脚本。例如,如果你想微调PaliGemma,可以使用以下命令:

accelerate launch examples/scripts/dpo_visual.py \
    --dataset_name HuggingFaceH4/rlaif-v_formatted \
    --model_name_or_path google/paligemma-3b-pt-224 \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 32 \
    --dataset_num_proc 32 \
    --output_dir dpo_paligemma_rlaif-v \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing \
    --use_peft \
    --lora_target_modules=all-linear

五、可视化结果

下表展示了一些可视化的结果:

ImageQuestionIdefics2Idefics2+DPO
Are there two ships in this image?YesNo
Is the ground uneven in this image?NoYes
Is there one shovel in this image?YesNo

六、参考链接

[1] https://huggingface.co/docs/peft/en/index

[2] https://cloud.google.com/vertex-ai/generative-ai/docs/model-garden/lora-qlora?hl=zh-cn

[3] https://huggingface.co/blog/zh/dpo_vlm

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1994646.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

云计算任务调度优化matlab仿真,对比蚁群优化和蛙跳优化

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 4.1 ACO蚁群优化 4.2 蛙跳优化 5.完整程序 1.程序功能描述 云计算任务调度优化,优化目标位任务消耗时间&#xff0c;调度后的经济效益以及设备功耗&#xff0c;对比蚁群优化算法和蛙跳优化…

【IEEE独立出版 | EI稳定检索】第三届人工智能、物联网和云计算技术国际会议(AIoTC 2024)

【IEEE独立出版 | EI稳定检索】 第三届人工智能、物联网和云计算技术国际会议&#xff08;AIoTC 2024&#xff09; 2024 3rd International Conference on Artificial Intelligence, Internet of Things and Cloud Computing Technology 2024年9月13-15日 | 中国武汉 AIoTC …

mysql 日志爆满,删除日志文件,定时清理日志

今天发现网站不能正常访问&#xff0c;于是登陆服务器查找问题。 机智的我随手用命令&#xff1a;df -l 发现 硬盘爆满了&#xff0c;于是就知道问题所在了。 Filesystem 1K-blocks Used Available Use% Mounted on/dev/xvda1 20641404 16963004 16929876 10…

使用 Elastic 和 Mistral 构建多语言 RAG(二)

这篇文章是之前的文章 “使用 Elastic 和 Mistral 构建多语言 RAG&#xff08;一&#xff09;” 的续篇。在这篇文章中&#xff0c;我将展示如何在本地部署中完成在那篇文章中的实现。 注意&#xff1a;由于 semantic text 从 8.15 版本开始提供&#xff0c;你需要至少 8.15 及…

Go框架选战:Gin、Echo、Fiber的终极较量

Gin 优点: 高性能: 优化以处理高并发和低延迟请求。易于上手: 对于熟悉 Go 的开发者来说&#xff0c;API 设计直观&#xff0c;学习曲线低。社区支持强: 广泛使用&#xff0c;有大量第三方中间件和教程。 缺点: 相比于其他框架如 Echo&#xff0c;Gin缺乏内置的验证支持Gin…

万字长文揭秘高性能架构

从零开始学架构系列文章&#xff1a; 从零开始学架构——概念和基础 从零开始学架构——万字长文揭秘高性能架构 从零开始学架构——高可用架构 从零开始学架构——可扩展架构 高性能存储 关系数据库 互联网业务兴起之后&#xff0c;海量用户加上海量数据的特点&#xff0…

无人机之民用无人机用途分类篇

一、航拍无人机 用于航拍摄影和电影制作&#xff0c;提供空中视角的拍摄服务。可用于电影制作、广告拍摄、房地产销售等。 二、物流无人机 用于快递和货物运输&#xff0c;提高物流效率&#xff0c;可以到达传统配送方式难以覆盖的地区&#xff0c;在突发事件如自然灾害、疫…

keepalived工作原理和使用方式

keepalived是什么 keepalived是集群管理中保证集群高可用的一个服务软件&#xff0c;用来防止单点故障。 keepalived主要有三个模块 分别是core、check和vrrp。core模块为keepalived的核心&#xff0c;负责主进程的启动、维护以及全局配置文件的加载和解析。check负责健康检…

怎么根据企业特点提供个性化的六西格玛培训?

近年来&#xff0c;六西格玛作为一种强大的质量管理方法&#xff0c;以其数据驱动、流程优化和减少缺陷为核心&#xff0c;被众多企业视为提升竞争力的关键工具。然而&#xff0c;并非所有企业都能直接套用标准的六西格玛培训体系&#xff0c;因为每个企业的文化、行业特性、发…

顺序队列和链式队列的基本操作

顺序队列 函数说明&#x1f603;&#xff1a; InitStack( &s)&#xff1a;初始化栈 StackEmpty(s)&#xff1a;判断一个栈是否为空 Push(& s, x)&#xff1a;进栈 Pop(&s, &x)&#xff1a;出栈 GetTop(s,&x)&#xff1a;读栈顶元素 show(s)&#xff1a;读出…

Vue3项目框架搭建

前言 大多时候是在别人搭建好的项目上开发需求&#xff0c;突然要自己从新项目搭建开始&#xff0c;纯纯赶鸭子上架&#xff0c;参考一些项目&#xff0c;试着搭建的&#xff0c;记录一下历程&#xff0c;主要怕忘了。有些地方本该贴上代码截图更好&#xff0c;但是我此刻手头…

Vue引入使用iconfont字体图标

由于element-ui或element-plus提供的图标有时候并不能满足日常需求,所以这篇介绍一下前端引入阿里巴巴矢量图标库使用,不止是vue使用,不限于vue2、vue3,html或是其他框架也是同样的道理,只要引入都是同样可以使用的。 1. 首先进入阿里巴巴矢量图标库官网 官网:https://…

弱智吧:大模型变聪明,有我一份贡献【大模型VS弱智吧,谁聪明?谁弱智?】

「被门夹过的核桃&#xff0c;还能补脑吗&#xff1f;」 在中文网络上流传着这样一段话&#xff1a;弱智吧里没有弱智。 百度「弱智吧」是个神奇的地方&#xff0c;在这里人人都说自己是弱智&#xff0c;但大多聪明得有点过了头。最近几年&#xff0c;弱智吧的年度总结文章都可…

算法——决策树

简介&#xff1a;个人学习分享&#xff0c;如有错误&#xff0c;欢迎批评指正。 一、什么是决策树&#xff1f; 决策树&#xff08;decision tree&#xff09;&#xff1a;决策树是一种树形结构的监督学习算法&#xff0c;广泛应用于分类任务和回归任务中。它通过递归地将数据…

豆瓣的ip地址怎样修改:探索显示机制与实用操作

在数字化时代&#xff0c;网络空间成为了我们日常生活不可或缺的一部分。豆瓣&#xff0c;作为一个集书籍、电影、音乐评论及社交功能于一体的综合性平台&#xff0c;其用户遍布全球。然而&#xff0c;有时我们可能因为隐私保护、网络限制或特定需求而希望修改在豆瓣上显示的IP…

【STM32 FreeRTOS】任务

使用 RTOS 的实时应用程序可以被构建为一组独立的任务。每个任务在自己的上下文中执行&#xff0c;不依赖于系统内的其他任务或 RTOS 调度器本身。在任何时间点&#xff0c;应用程序中只能执行一个任务&#xff0c;实时 RTOS 调度器负责决定所要执行的任务。因此&#xff0c; R…

Figure 02 机器人发布:未来AI的巅峰还是泡沫中的救命稻草?

引言 近日&#xff0c;Figure AI 公司发布了其最新的机器人产品 Figure 02&#xff0c;引发了广泛关注。作为 Figure AI 的第二代人形机器人&#xff0c;Figure 02 的推出引发了关于它是否是“地表最强”机器人的讨论。同时&#xff0c;由于 OpenAI 的技术支持&#xff0c;这款…

Java Web —— 第三天(Ajax+组件)

Ajax 概念: Asynchronous JavaScript And XML&#xff0c;异步的JavaScript和XML。 作用: 数据交换:通过Aiax可以给服务器发送请求&#xff0c;并获服务器响应的数据 异步交互:可以在不重新加载整个页面的情况下&#xff0c;服务器交换数据并更新部分网页的技术&#xff0c…

Java开发笔记--通用基础数据校验的设计

最近在开发一个功能&#xff0c;对排水管网的基础数据(包括管井、管道、泵站&#xff0c;雨水口&#xff0c;雨水口线&#xff0c;泵站&#xff0c;污水处理厂&#xff0c;排口等)的导入进行校验。 以字段为纬度&#xff0c;考虑二个方面的校验&#xff1a;数据库唯一&#xf…

RHCA III之路---EX436-9

RHCA III之路---EX436-9 1. 题目2. 解题2.1 安装apache2.2 配置页面2.3 配置selinux和防火墙2.4 创建资源 3. 确认 1. 题目 2. 解题 考试时会给你个url,从url下载index.html并放入默认目录 2.1 安装apache 3个节点分别安装 yum install -y httpd2.2 配置页面 nodea上执行 …