使用 DDPO 在 TRL 中微调 Stable Diffusion 模型

news2024/9/27 23:29:42

引言

扩散模型 (如 DALL-E 2、Stable Diffusion) 是一类文生图模型,在生成图像 (尤其是有照片级真实感的图像) 方面取得了广泛成功。然而,这些模型生成的图像可能并不总是符合人类偏好或人类意图。因此出现了对齐问题,即如何确保模型的输出与人类偏好 (如“质感”) 一致,或者与那种难以通过提示来表达的意图一致?这里就有强化学习的用武之地了。

在大语言模型 (LLM) 领域,强化学习 (RL) 已被证明是能让目标模型符合人类偏好的非常有效的工具。这是 ChatGPT 等系统卓越性能背后的主要秘诀之一。更准确地说,强化学习是人类反馈强化学习 (RLHF) 的关键要素,它使 ChatGPT 能像人类一样聊天。

在 Training Diffusion Models with Reinforcement Learning 一文中,Black 等人展示了如何利用 RL 来对扩散模型进行强化,他们通过名为去噪扩散策略优化 (Denoising Diffusion Policy Optimization,DDPO) 的方法针对模型的目标函数实施微调。

在本文中,我们讨论了 DDPO 的诞生、简要描述了其工作原理,并介绍了如何将 DDPO 加入 RLHF 工作流中以实现更符合人类审美的模型输出。然后,我们切换到实战,讨论如何使用 trl 库中新集成的 DDPOTrainer 将 DDPO 应用到模型中,并讨论我们在 Stable Diffusion 上运行 DDPO 的发现。

DDPO 的优势

DDPO 并非解决 如何使用 RL 微调扩散模型 这一问题的唯一有效答案。

在进一步深入讨论之前,我们强调一下在对 RL 解决方案进行横评时需要掌握的两个关键点:

  1. 计算效率是关键。数据分布越复杂,计算成本就越高。

  2. 近似法很好,但由于近似值不是真实值,因此相关的错误会累积。

在 DDPO 之前,奖励加权回归 (Reward-Weighted Regression,RWR) 是使用强化学习微调扩散模型的主要方法。RWR 重用了扩散模型的去噪损失函数、从模型本身采样得的训练数据以及取决于最终生成样本的奖励的逐样本损失权重。该算法忽略中间的去噪步骤/样本。虽然有效,但应该注意两件事:

  1. 通过对逐样本损失进行加权来进行优化,这是一个最大似然目标,因此这是一种近似优化。

  2. 加权后的损失甚至不是精确的最大似然目标,而是从重新加权的变分界中得出的近似值。

所以,根本上来讲,这是一个两阶近似法,其对性能和处理复杂目标的能力都有比较大的影响。

DDPO 始于此方法,但 DDPO 没有将去噪过程视为仅关注最终样本的单个步骤,而是将整个去噪过程构建为多步马尔可夫决策过程 (MDP),只是在最后收到奖励而已。这样做的好处除了可以使用固定的采样器之外,还为让代理策略成为各向同性高斯分布 (而不是任意复杂的分布) 铺平了道路。因此,该方法不使用最终样本的近似似然 (即 RWR 的做法),而是使用易于计算的每个去噪步骤的确切似然 ( )。

如果你有兴趣了解有关 DDPO 的更多详细信息,我们鼓励你阅读 原论文 及其 附带的博文。

DDPO 算法简述

考虑到我们用 MDP 对去噪过程进行建模以及其他因素,求解该优化问题的首选工具是策略梯度方法。特别是近端策略优化 (PPO)。整个 DDPO 算法与近端策略优化 (PPO) 几乎相同,仅对 PPO 的轨迹收集部分进行了比较大的修改。

下图总结了整个算法流程:

ed44347f9a844f55abb1d2351e8a231e.png
dppo rl 流图

DDPO 和 RLHF: 合力增强美观性

RLHF 的一般训练步骤如下:

  1. 有监督微调“基础”模型,以学习新数据的分布。

  2. 收集偏好数据并用它训练奖励模型。

  3. 使用奖励模型作为信号,通过强化学习对模型进行微调。

需要指出的是,在 RLHF 中偏好数据是获取人类反馈的主要来源。

DDPO 加进来后,整个工作流就变成了:

  1. 从预训练的扩散模型开始。

  2. 收集偏好数据并用它训练奖励模型。

  3. 使用奖励模型作为信号,通过 DDPO 微调模型

请注意,DDPO 工作流把原始 RLHF 工作流中的第 3 步省略了,这是因为经验表明 (后面你也会亲眼见证) 这是不需要的。

下面我们实战一下,训练一个扩散模型来输出更符合人类审美的图像,我们分以下几步来走:

  1. 从预训练的 Stable Diffusion (SD) 模型开始。

  2. 在 美学视觉分析 (Aesthetic Visual Analysis,AVA)  数据集上训练一个带有可训回归头的冻结 CLIP 模型,用于预测人们对输入图像的平均喜爱程度。

  3. 使用美学预测模型作为奖励信号,通过 DDPO 微调 SD 模型。

记住这些步骤,下面开始干活:

使用 DDPO 训练 Stable Diffusion

环境设置

首先,要成功使用 DDPO 训练模型,你至少需要一个英伟达 A100 GPU,低于此规格的 GPU 很容易遇到内存不足问题。

使用 pip 安装 trl

pip install trl[diffusers]

主库安装好后,再安装所需的训练过程跟踪和图像处理相关的依赖库。注意,安装完 wandb 后,请务必登录以将结果保存到个人帐户。

pip install wandb torchvision

注意: 如果不想用 wandb ,你也可以用 pip 安装 tensorboard

演练一遍

trl 库中负责 DDPO 训练的主要是 DDPOTrainerDDPOConfig 这两个类。有关 DDPOTrainerDDPOConfig 的更多信息,请参阅 相应文档。trl 代码库中有一个 示例训练脚本。它默认使用这两个类,并有一套默认的输入和参数用于微调 RunwayML 中的预训练 Stable Diffusion 模型。

此示例脚本使用 wandb 记录训练日志,并使用美学奖励模型,其权重是从公开的 Hugging Face 存储库读取的 (因此数据收集和美学奖励模型训练均已经帮你做完了)。默认提示数据是一系列动物名。

用户只需要一个命令行参数即可启动脚本。此外,用户需要有一个 Hugging Face 用户访问令牌,用于将微调后的模型上传到 Hugging Face Hub。

运行以下 bash 命令启动程序:

python stable_diffusion_tuning.py --hf_user_access_token <token>

下表列出了影响微调结果的关键超参数:

参数描述单 GPU 训练推荐值(迄今为止)
num_epochs训练 epoch200
train_batch_size训练 batch size3
sample_batch_size采样 batch size6
gradient_accumulation_steps梯度累积步数1
sample_num_steps采样步数50
sample_num_batches_per_epoch每个 epoch 的采样 batch 数4
per_prompt_stat_tracking是否跟踪每个提示的统计信息。如果为 False,将使用整个 batch 的平均值和标准差来计算优势,而不是对每个提示进行跟踪True
per_prompt_stat_tracking_buffer_size用于跟踪每个提示的统计数据的缓冲区大小32
mixed_precision混合精度训练True
train_learning_rate学习率3e-4

这个脚本仅仅是一个起点。你可以随意调整超参数,甚至彻底修改脚本以适应不同的目标函数。例如,可以集成一个测量 JPEG 压缩度的函数或 使用多模态模型评估视觉文本对齐度的函数 等。

经验与教训

  1. 尽管训练提示很少,但其结果似乎已经足够泛化。对于美学奖励函数而言,该方法已经得到了彻底的验证。

  2. 尝试通过增加训练提示数以及改变提示来进一步泛化美学奖励函数,似乎反而会减慢收敛速度,但对模型的泛化能力收效甚微。

  3. 虽然推荐使用久经考验 LoRA,但非 LoRA 也值得考虑,一个经验证据就是,非 LoRA 似乎确实比 LoRA 能产生相对更复杂的图像。但同时,非 LoRA 训练的收敛稳定性不太好,对超参选择的要求也高很多。

  4. 对于非 LoRA 的超参建议是: 将学习率设低点,经验值是大约 1e-5 ,同时将 mixed_ precision 设置为 None

结果

以下是提示 bearheavendune 微调前 (左) 、后 (右) 的输出 (每行都是一个提示的输出):

微调前微调后
0e41e5c7d35e0aeccf24b8bdba1a90bc.png225abcd69e5fc63a7df77de55efae224.png
e5e83a905422030a3de28613306d018d.pngaee771e788197666035e349c2492999d.png
414abb39837581960a34e68a2a3af9a9.png249a214e3f40b2ddfa5e39f4c4913fcb.png

限制

  1. 目前 trlDDPOTrainer 仅限于微调原始 SD 模型;

  2. 在我们的实验中,主要关注的是效果较好的 LoRA。我们也做了一些全模型训练的实验,其生成的质量会更好,但超参寻优更具挑战性。

总结

像 Stable Diffusion 这样的扩散模型,当使用 DDPO 进行微调时,可以显著提高图像的主观质感或其对应的指标,只要其可以表示成一个目标函数的形式。

DDPO 的计算效率及其不依赖近似优化的能力,在扩散模型微调方面远超之前的方法,因而成为微调扩散模型 (如 Stable Diffusion) 的有力候选。

trl 库的 DDPOTrainer 实现了 DDPO 以微调 SD 模型。

我们的实验表明 DDPO 对很多提示具有相当好的泛化能力,尽管进一步增加提示数以增强泛化似乎效果不大。为非 LoRA 微调找到正确超参的难度比较大,这也是我们得到的重要经验之一。

DDPO 是一种很有前途的技术,可以将扩散模型与任何奖励函数结合起来,我们希望通过其在 TRL 中的发布,社区可以更容易地使用它!

致谢

感谢 Chunte Lee 提供本博文的缩略图。

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/trl-ddpo

原文作者: Luke Meyers,Sayak Paul,Kashif Rasul,Leandro von Werra

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

审校/排版: zhongdongy (阿东)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1130490.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

重装win11,个人记录详细步骤-干货

重装win11&#xff0c;个人记录详细步骤-干货 下载镜像-windows官网 https://www.microsoft.com/zh-cn/software-download/windows11%20 安装的选这个就行 虽然他这里写的是家庭版&#xff0c;进去里面就可以选择其他版本 重装win11有个前提 系统最低要求 本文列出了 Windo…

Databend 开源周报第 116 期

Databend 是一款现代云数仓。专为弹性和高效设计&#xff0c;为您的大规模分析需求保驾护航。自由且开源。即刻体验云服务&#xff1a;https://app.databend.cn 。 Whats On In Databend 探索 Databend 本周新进展&#xff0c;遇到更贴近你心意的 Databend 。 特性预览&#…

AR道具贴纸SDK,创新技术解决方案

增强现实&#xff08;AR&#xff09;技术已经成为企业提升用户体验&#xff0c;增强品牌影响力的重要工具。美摄AR道具贴纸SDK&#xff0c;作为一款领先的AR技术产品&#xff0c;致力于为企业提供一站式的技术解决方案&#xff0c;帮助企业轻松实现AR应用的快速开发和部署。 一…

使用Vscode创建一个C_Hello程序

Vscode用来学习C语言语法确实很方便。问题是安装好了&#xff0c;不会用&#xff0c;或编译失败&#xff0c;也是常有的事情&#xff0c;其中一个原因就是不会创建工作区。下面介绍使用Vscode创建一个C语言工作区。有时候看着很简单&#xff0c;时间久了&#xff0c;我竟然忘记…

Element UI + Vue 新增和编辑共用表单校验无法清除问题(已解决)

问题描述 在新增和编辑过程中大部分情况下 两个表单是一致的&#xff0c;而且编辑也有回显需要&#xff0c;所有绝大多数情况下 都是一个表单两个用处&#xff0c;但是随之而来出现了一个无法清除校验的问题&#xff0c;在先点击编辑后再点击新增会出现校验红字&#xff1a; …

centos搭建elastic集群

1、环境可以在同一台集群上搭建elastic&#xff0c;也可以在三台机器上搭建&#xff0c;这次演示的是在同一台机器搭建机器。 2、下载elastic &#xff1a;https://www.elastic.co/cn/downloads/past-releases#elasticsearch 2、​​​​​​ tar -zxvf elasticsearch-xxx-版…

ubuntu18.04双系统安装(2023最新最详细)以及解决重启后发现进不了Ubuntu问题

目录 一.简介 二.安装教程 1.首先确定了电脑的引导格式是UEFIGPT还是BIOSMBR 2. 使用Windows磁盘管理划分足够的磁盘空间 3. 开始安装 三.重启后发现自动进入WIN10系统了&#xff0c;进不了Ubuntu&#xff1f; 一.简介 Linux是一种自由和开放源代码的操作系统内核&#x…

【嵌入式项目应用】__JSON数据格式(无脑操作移植使用)

目录 一、什么是JSON 二、JSON是什么样子的呢&#xff1f; 三、常用类型 四、和XML的比较 五、CJSON库&#xff1f; 六、CJSON库组包与解析示例 1. 确定协议数据 2. 组JSON数据包实例 操作实例&#xff1a; 完整代码&#xff1a; 3. 解析JSON数据包示例 操作实例 完…

【QT】其他常用控件1

新建项目 scrollArea 滚动 toolBox 插入 tabWidget stackedWidget 切换 索引是0 运行后&#xff0c;没有切换按钮&#xff0c;结合pushbutton&#xff0c;加两个Button 代码 #include "widget.h" #include "ui_widget.h"Widget::Widget(QWidget *parent)…

Deno 的配置文件、框架,标准库

目录 1、配置文件 imports 和scopes tasks lint fmt lock nodeModulesDir npmRegistry compilerOptions 一个全的示例 2、Web框架 2.1 Deno 原生框架 Fresh Aleph Ultra Lume Oak 3、标准库 3.1 版本和稳定性 1、配置文件 Deno支持一个配置文件&#xff0c…

Unity 中使用波浪动画创建 UI 图像

如何使用 只需将此组件添加到画布中的空对象即可。强烈建议您将此对象放入其自己的画布/嵌套画布中&#xff0c;因为它会弄脏每一帧的画布并导致重新生成整个网格。 注意&#xff1a;不支持切片图像。 using System.Collections.Generic; using UnityEngine; using UnityEng…

算法通关村第十一关白银挑战——位运算符的高频算法题

大家好&#xff0c;我是怒码少年小码。 今天讲讲几个位运算的经典算法。 位移的妙用 1. 位1的个数 LeetCode 191&#xff1a;编写一个函数&#xff0c;输入是一个无符号整数&#xff08;以二进制串的形式&#xff09;&#xff0c;返回其二进制表达式中数字位数为 ‘1’ 的个…

html 常见兼容性问题

目录 前言: 用法: 代码: 1. 盒模型差异: 2. 表格布局问题: 3. 浏览器前缀问题: 4. 字体渲染问题: 理解: 讨论: 前言: 在Web开发中&#xff0c;兼容性问题是常见的挑战之一。不同的浏览器和设备可能以不同的方式解释和呈现HTML&#xff0c;导致网页在某些环境下出现问题…

fastjson对象序列化的问题

今天偶然遇到一个fastjson将字符串反序列化为一个对象的时候的问题&#xff0c;就是简单的通过com.alibaba.fastjson.JSON将对象转为字符串&#xff0c;然后再从字符串转换为原类型的对象。 涉及的代码也非常简单 package cn.edu.sgu.www.mhxysy.service.role.impl;import cn…

单片机设计基于STM32的空气净化器设计

**单片机设计介绍&#xff0c;1615[毕设课设]基于STM32的空气净化器设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图pcb设计图 五、 程序六、 文章目录 一 概要 此设计资料主要包含原理图、PCB、源程序、元器件清等资料&#xff0c; 二、功能设计 设计思路 …

uniapp--点击上传图片到oss再保存数据给后端接口

项目采用uniapp与uview2.0组件库 --1.0的也可以参考一下&#xff0c;大差不差 一、项目要求与样式图 点击上传n张图片到oss&#xff0c;然后点击提交给后端 二、思路 1、打开上传按钮&#xff0c;弹出框内出现上传图片和提交按钮 2、点击上传图片区域&#xff0c;打开本地图…

redis持久化之AOF(Append Only File)

1 : AOF 是什么 以日志的形式来记录每个写操作&#xff08;增量保存&#xff09;&#xff0c;将redis执行过的所有写指令记录下来&#xff08;读操作不记 录&#xff09;&#xff0c;只允追加文件但不可改写文件&#xff0c;redis启动之初会读取该文件重新构造数据&#xff0c;…

【Linux】IP协议

文章目录 &#x1f4d6; 前言1. 网络层2. IP协议格式3. IP报文分片和组装3.1 如何分片和组装&#xff1a;3.2 组装的衍生问题&#xff1a; 4. 网段划分&#xff08;重点&#xff09;4.1 子网掩码&#xff1a;4.2 IP地址的数量限制&#xff1a;4.3 私有IP地址和公网IP地址&#…

读高性能MySQL(第4版)笔记19_云端和合规性

1. 如何构建数据库环境 1.1. 托管MySQL 1.2. VM上构建 1.3. 天下没有免费的午餐&#xff0c;每一个选择都伴随着一系列的权衡 2. 托管MySQL 2.1. 服务商提供了一个可访问的数据库设置程序&#xff0c;而不需要用户深入了解MySQL的具体细节 2.2. 使用托管MySQL将缺乏很多的…

vueDay03——计算属性

一、一般场景 当我们需要对某个数据进行简单判断渲染的时候&#xff0c;我们通常会使用如下方法 <div>nginx当前状态&#xff1a;{{ openNginx true ? true : false}} </div> 但是这样就很影响观感&#xff0c;因为渲染出来的只有openNginx的值&#xff0c;而…