对抗式生成模仿学习(GAIL)

news2024/10/7 17:34:35

目录

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

1.1.2 损失函数

1.1.2.1 固定G,求解令损失函数最大的D

1.1.2.2 固定D,求解令损失函数最小的G

1.2 对抗式生成模仿学习特点

2 对抗式生成模仿学习(GAIL)详细说明

3 参考文献

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

在GAN生成对抗网络中,包含两个模型,一个生成模型,一个判别模型。

  • 生成模型:负责生成看起来真实自然,和原始数据相似的实例。
  • 判别模型:负责判断给出的实例是真实的还是人为伪造的。

生成模型努力去欺骗判别模型,判别模型努力不被欺骗,这样两种模型交替优化训练,都得到了提升。

对于辨别器,如果得到的是生成图片辨别器应该输出0,如果是真实的图片应该输出 1,得到误差梯度反向传播来更新参数。对于生成器,首先由生成器生成一张图片,然后输入给判别器判别并的到相应的误差梯度,然后反向传播这些图片梯度成为组成生成器的权重。直观上来说就是:辨别器不得不告诉生成器如何调整从而使它生成的图片变得更加真实。

1.1.2 损失函数

GAN模型的目标函数:

其中,参考GAN的架构图,字母 V是原始GAN论文中指定用来表示该交叉熵的字母,x 表示任意真实数据,z 表示与真实数据相同结构的任意随机数据,G(z)表示在生成器中基于 z 生成的假数据,而D(x)表示判别器在真实数据 x上判断出的结果,D(G(z))表示判别器在假数据 G(z)上判断出的结果,其中 D(x) 与D(G(z))都是样本为“真”的概率,即标签为1的概率。

上式,主要意思是先固定生成器G,从判别器D的角度令损失最大化,紧接着固定D,从生成器G的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。其中第一个期望\mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right]是所有x都是真实数据时(log(D(x)))的期望,第二个期望\mathbb{E}_{z \sim p(z)} \left[ \log (1 - D(G(z))) \right]是所有数据都是生成数据时log(1-D(G(z)))的期望。可以看出,在求解最优解的过程中存在两个过程:

  • 固定G,求解令损失函数最大的D
  • 固定D,求解令损失函数最小的G

判别网络是一个2分类,目标是分清真实数据和伪造数据,也就是希望D(x) 趋近于1,D(G(z))趋近于0,这也就体现了对抗的思想。G网络的loss是log(1-D(G(z))),D的loss是-(log(D(x)))+log(1-D(G(z)))。

1.1.2.1 固定G,求解令损失函数最大的D

判别器D的输入x有两部分:一部分是真实数据,设其分布为P_{\text{data}}(x);另一部分是生成器生成的数据,参考架构图,生成器接收的数据z服从分布P(z),A输入z经过生成器的计算生成的数据分布设为P_{G}(x)

这两部分这两部分都是判别器D的输入,不同的是,G的输出来自分布P_{G}(x),而真实数据来自分布P_{\text{data}}(x),经过一系列推导后的结果:

可以看出,固定G,将最优的D带入后,此时V(G,D*),实际上是在度量P_{\text{data}}(x)P_{G}(x)之间的JS散度,同KL散度一样,他们之间的分布差异越大,JS散度值也越大。换句话说:保持G不变,最大化V(G,D)就等价于计算JS散度。对于判别器来说,尽可能找出生成器生成的数据与真实数据分布之间的差异,这个差异就是JS散度。

1.1.2.2 固定D,求解令损失函数最小的G

对于生成器来说,让生成器生成的数据分布接近真实数据分布。现在第一步已经求出了最优解的D*,代入损失函数:

在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

1.2 对抗式生成模仿学习特点

逆强化学习(Inverse Reinforcement Learning, IRL)作为一种典型的模仿学习方法,顾名思义,逆强化学习的学习过程与正常的强化学习利用奖励函数学习策略相反,不利用现有的奖励函数,而是试图学出一个奖励函数,并以之指导基于奖励函数的强化学习过程。IRL可以归结为解决从观察到的最优行为中提取奖励函数( Reward Function)的问题,这些最优行为也可以表示为专家策略 。基于IRL的方法交替地在两个过程中交替:一个阶段是使用示范数据来推断一个隐藏的奖励(Reward)或代价( Cost)函数,另一个阶段是使用强化学习基于推断的奖励函数来学习一个模仿策略。IRL的基本准则是:IRL选择奖励函数来优化策略,并且使得任何不同于\Pi _{E}的动作决策尽可能产生更大损失。

对抗式生成模仿学习(Generative Adversarial Imitation Learning,GAIL)是逆强化学习的一种重要实现方法之一。逆强化学习旨在从专家示范的行为中推断环境的奖励函数或者价值函数,而GAIL是逆强化学习的一种实现方式,它利用了生成对抗网络(GAN)的概念来进行模仿学习。

GAIL的关键点在于:

1生成对抗网络: GAIL使用生成对抗网络的框架,其中包括生成器和判别器。

2生成器与判别器: 生成器尝试生成与专家示范行为相似的状态-动作对,而判别器则尝试区分专家行为和生成器生成的行为。

3对抗优化: GAIL使用对抗训练的思想,通过生成器和判别器的对抗优化来使得生成器的输出逼近专家的行为。

GAIL的工作方式使得它在逆强化学习中发挥着重要作用,因为它提供了一种有效的方式来从专家示范中学习环境的奖励结构,以指导智能体的学习行为。通过对抗式生成模仿学习,智能体可以学习并模仿专家的行为,而无需显式地使用环境的奖励信号。

因此,GAIL作为逆强化学习的一种方法,为从专家示范中学习环境的奖励函数或者价值函数提供了一种有效的框架和方法。

2 对抗式生成模仿学习(GAIL)详细说明

 

生成式对抗模仿学习的整体优化流程如图所示。通过 GAIL 方法,策略生成器通过生成类似专家示教样本的探索样本,泛化示教样本的概率分布, 逼近专家示范行为数据,进而实现模仿专家技能的目的。该过程直接优化采样样本的概率分布,计算代价较小且算法通用性更强,实际模仿效果也更好。 

伪代码:

# 初始化策略 π、判别器 D、专家示范数据 D_expert、策略缓冲区 D_policy

函数 GAIL_Training():
    初始化策略 π 的参数
    初始化判别器 D 的参数

    循环 直到收敛 或 达到最大迭代次数:
        # 使用当前策略 π 生成轨迹并存储在策略缓冲区 D_policy 中
        生成 trajectories 使用 π 并存储在 D_policy 中

        # 判别器训练
        循环 discriminator_updates 次数:
            # 从策略缓冲区 D_policy 中采样数据
            采样 (s_policy, a_policy) 从 D_policy 中
            # 从专家示范数据 D_expert 中采样数据
            采样 (s_expert, a_expert) 从 D_expert 中

            # 更新判别器 D
            计算 L_D = -[log(D(s_expert, a_expert)) + log(1 - D(s_policy, a_policy))]
            使用梯度下降法更新判别器参数以最小化 L_D

        # 策略更新
        采样 (s, a, ...) 从 D_policy 中
        计算伪奖励 r = -log(1 - D(s, a))

        # 使用伪奖励 r 更新策略 π
        计算 L_π 使用 PPO 或 其他强化学习方法
        使用梯度下降法更新策略 π 的参数以最大化 L_π

能够表征GAIL流程的主程序如下: 

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam

from .ppo import PPO
from gail_airl_ppo.network import GAILDiscrim


class GAIL(PPO):

    def __init__(self, buffer_exp, state_shape, action_shape, device, seed,
                 gamma=0.995, rollout_length=50000, mix_buffer=1,
                 batch_size=64, lr_actor=3e-4, lr_critic=3e-4, lr_disc=3e-4,
                 units_actor=(64, 64), units_critic=(64, 64),
                 units_disc=(100, 100), epoch_ppo=50, epoch_disc=10,
                 clip_eps=0.2, lambd=0.97, coef_ent=0.0, max_grad_norm=10.0):
        super().__init__(
            state_shape, action_shape, device, seed, gamma, rollout_length,
            mix_buffer, lr_actor, lr_critic, units_actor, units_critic,
            epoch_ppo, clip_eps, lambd, coef_ent, max_grad_norm
        )

        # Expert's buffer.
        self.buffer_exp = buffer_exp

        # Discriminator.
        self.disc = GAILDiscrim(
            state_shape=state_shape,
            action_shape=action_shape,
            hidden_units=units_disc,
            hidden_activation=nn.Tanh()
        ).to(device)

        self.learning_steps_disc = 0
        self.optim_disc = Adam(self.disc.parameters(), lr=lr_disc)
        self.batch_size = batch_size
        self.epoch_disc = epoch_disc

    def update(self, writer):
        self.learning_steps += 1

        for _ in range(self.epoch_disc):
            self.learning_steps_disc += 1

            # Samples from current policy's trajectories.
            states, actions = self.buffer.sample(self.batch_size)[:2]
            # Samples from expert's demonstrations.
            states_exp, actions_exp = \
                self.buffer_exp.sample(self.batch_size)[:2]
            # Update discriminator.
            self.update_disc(states, actions, states_exp, actions_exp, writer)

        # We don't use reward signals here,
        states, actions, _, dones, log_pis, next_states = self.buffer.get()

        # Calculate rewards.
        rewards = self.disc.calculate_reward(states, actions)

        # Update PPO using estimated rewards.
        self.update_ppo(
            states, actions, rewards, dones, log_pis, next_states, writer)

    def update_disc(self, states, actions, states_exp, actions_exp, writer):
        # Output of discriminator is (-inf, inf), not [0, 1].
        logits_pi = self.disc(states, actions)
        logits_exp = self.disc(states_exp, actions_exp)

        # Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].
        loss_pi = -F.logsigmoid(-logits_pi).mean()
        loss_exp = -F.logsigmoid(logits_exp).mean()
        loss_disc = loss_pi + loss_exp

        self.optim_disc.zero_grad()
        loss_disc.backward()
        self.optim_disc.step()

        if self.learning_steps_disc % self.epoch_disc == 0:
            writer.add_scalar(
                'loss/disc', loss_disc.item(), self.learning_steps)

            # Discriminator's accuracies.
            with torch.no_grad():
                acc_pi = (logits_pi < 0).float().mean().item()
                acc_exp = (logits_exp > 0).float().mean().item()
            writer.add_scalar('stats/acc_pi', acc_pi, self.learning_steps)
            writer.add_scalar('stats/acc_exp', acc_exp, self.learning_steps)

3 参考文献

https://zhuanlan.zhihu.com/p/628915533

【强化学习】GAIL_gail算法-CSDN博客

代码:https://github.com/toshikwa/gail-airl-ppo.pytorch.git

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1826445.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java数据库编程

引言 在现代应用开发中&#xff0c;与数据库交互是不可或缺的一部分。Java提供了JDBC&#xff08;Java Database Connectivity&#xff09; API&#xff0c;允许开发者方便地连接到数据库并执行SQL操作。本文将详细介绍Java数据库编程的基础知识&#xff0c;包括JDBC的基本概念…

为什么 JavaScript 在国外逐渐用于前端+后端开发

这个问题其实没人能给出可证伪的结论&#xff0c;那不如干脆给一个感性的答案: 因为阿里“不争气”。 确切的说&#xff0c;因为阿里的nodejs团队没卷赢&#xff0c;至少暂时还没卷赢&#xff0c;没拿到真正有价值的业务场景&#xff0c;做出真正有说服力的案例项目。刚好我有…

如何进行LLM大模型推理优化

解密LLM大模型推理优化本质 一、LLM推理的本质以及考量点 LLM推理聚焦Transformer架构的Decoder以生成文本。过程分两步&#xff1a;首先&#xff0c;模型初始化并加载输入文本&#xff1b;接着&#xff0c;进入解码阶段&#xff0c;模型自回归地生成文本&#xff0c;直至满足…

邮件钓鱼--有无SPF演示--Swaks

目录 临时邮箱网址: Swaks 简单使用说明&#xff1a;(kali自带) 操作流程: 无SPF:(直接伪造发信人) 演示1 演示2 演示3 ​编辑 有SPF:--演示 临时邮箱网址: http://24mail.chacuo.net/ https://www.linshi-email.com/ Swaks 简单使用说明&#xff1a;(kali自带) -t –t…

专题六——模拟

目录 一替换所有的问号 二提莫攻击 三N字形变换 四外观数列 五数青蛙 一替换所有的问号 oj链接&#xff1a;替换所有的问号 思路&#xff1a;简单模拟&#xff1b;注意i0和in是处理越界问题就行&#xff01;&#xff01; class Solution { public:string modifyString…

基于scikit-learn的机器学习分类任务实践——集成学习

一、传统机器学习分类流程与经典思想算法简述 传统机器学习是指&#xff0c;利用线性代数、数理统计与优化算法等数学方式从设计获取的数据集中构建预测学习器&#xff0c;进而对未知数据分类或回归。其主要流程大致可分为七个部分&#xff0c;依次为设计获取数据特征集&#x…

Reactor 网络模型、Java代码实例

文章目录 1. 概述2. Reactor 单线程模型2.1 ByteBufferUtil2.2 服务端代码2.3 客户端2.4 运行截图 3. Reactor多线程模型3.1 服务端代码3.2 运行截图 4. 主从 Reactor多线程模型4.1 服务端代码4.2 运行截图 参考文献 1. 概述 在 I/O 多路复用的场景下&#xff0c;当有数据处于…

ChatTTS-WebUI测试页面项目

概述 分享可以一个专门为对话场景设计的文本转语音模型ChatTTS&#xff0c;例如LLM助手对话任务。它支持英文和中文两种语言。最大的模型使用了10万小时以上的中英文数据进行训练。在HuggingFace中开源的版本为4万小时训练且未SFT的版本. 该模型能够预测和控制细粒度的韵律特…

跪求大数据把我推给做投资交易的红薯!

在qq群里认识了君诺金融Juno Markets外汇交易平台的业务经理&#xff0c;平台上大剌剌的打出20%交易返现活动&#xff0c;一时听信了他们的话在该平台有开户入金做交易&#xff0c;做了这家平台的代理&#xff0c;然而君诺金融Juno Markets平台却不给佣金&#xff0c;我都是属于…

浏览器必备插件:最新Allow copy万能网页复制下载,解锁网页限制!

今天阿星给大家安利一个超级实用的小工具&#xff0c;专治那些“禁止复制”的网页文字。学生党、资料搜集狂人&#xff0c;你们有福了&#xff01; 想象一下&#xff0c;你在网上冲浪&#xff0c;突然遇到一篇干货满满的文章&#xff0c;正想复制下来慢慢品味&#xff0c;结果…

值传递和址传递

值传递 上面的代码是想要交换x&#xff0c;y的值&#xff0c;把x&#xff0c;y传递给swap函数之后&#xff0c;执行下面的操作&#xff1a; 在swap中a和b交换了&#xff0c;但是和x&#xff0c;y没有关系&#xff0c;所以x&#xff0c;y在main中不会变。 址传递 下面再看把x…

springcloud gateway转发websocket请求的404问题定位

一、问题 前端小程序通过springcloud gateway接入并访问后端的诸多微服务&#xff0c;几十个微服务相关功能均正常&#xff0c;只有小程序到后端推送服务的websocket连接建立不起来&#xff0c;使用whireshark抓包&#xff0c;发现在小程序通过 GET ws://192.168.6.100:8888/w…

Apple Intelligence 横空出世!它的独家秘诀在哪里?

在 WWDC 2024 大会上&#xff0c;苹果公司揭晓了自家的生成式 AI 项目——Apple Intelligence&#xff0c;其策略核心在于采用 ⌈ 更为聚焦的小型模型 ⌋ &#xff0c;而非盲目追求大模型的普遍趋势。横空出世的它究竟有什么过人之处&#xff1f;一文带你探究竟&#xff01;生成…

[DDR4] DDR1 ~ DDR4 发展史导论

依公知及经验整理&#xff0c;原创保护&#xff0c;禁止转载。 专栏 《深入理解DDR4》 内存和硬盘是电脑的左膀右臂&#xff0c; 挑起存储的大梁。因为内存的存取速度超凡地快&#xff0c; 但内存上的数据掉电又会丢失&#xff0c;一直其中缓存的作用&#xff0c;就像是我们的工…

2786. 访问数组中的位置使分数最大

这并不是一个难题,但是我确实在做题中得到了一些启发,所以记录一下 先讲一讲这个题目的做法: 首先不难想到这是一个dp问题,(由 i 可以跳到 j ) 而且应该不难, 要不然就不是medium了,doge 那么,暴力的dp就是: dp[j] max (dp[i] nums OR dp[j] dp[i] nums - x) , i<j, 前…

mongodb 集群安装

1. 配置域名 Server1&#xff1a; OS version: CentOS Linux release 8.5.2111 hostnamectl --static set-hostname mongo01 vi /etc/sysconfig/network # Created by anaconda hostnamemong01 echo "192.168.88.20 mong1 mongo01.com mongo02.com" >> /…

【笔记】【矩阵的二分】668. 乘法表中第k小的数

力扣链接&#xff1a;题目 参考地址&#xff1a;参考 思路&#xff1a;二分查找 把矩阵想象成一维的已排好序的数组&#xff0c;用二分法找第k小的数字。 假设m行n列&#xff0c;则对应一维下标范围是从1到mn&#xff0c;初始&#xff1a; l1; rmn; mid(lr)/2 设mid在第i行&a…

【C++11】第一部分(一万六千多字)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 前言 C11简介 统一的列表初始化 &#xff5b;&#xff5d;初始化 std::initializer_list 声明 auto decltype 右值引用和移动语义 左值引用和右值引用 左值引…

C#实现WMI获取硬盘参数

文章目录 背景涉及框架及库WMI查询小工具参数解释U盘移动硬盘本机设备 总结 背景 因为需求需要涉及获取硬盘的SN参数&#xff0c;但是又不想要获取到U盘或移动硬盘设备的SN&#xff0c;所以就浅浅的研究了一下。 以下就是我目前发现的一些参数的作用&#xff0c;够我用了。。。…

QT QFileDialog文件选择对话框

QT QFileDialog文件选择对话框 选择txt或者cpp文件&#xff0c;读取内容并显示 参考&#xff1a; QT写入文件与读取文件内容_qt往一个文件写东西-CSDN博客 #include "QtFilePreview.h" #include "qfiledialog.h" #include "qfile.h" #includ…