基于“动手学强化学习”的知识点(二):第 15 章 模仿学习(gym版本 >= 0.26)

news2025/3/18 17:43:01

第 15 章 模仿学习(gym版本 >= 0.26)

  • 摘要

摘要

本系列知识点讲解基于动手学强化学习中的内容进行详细的疑难点分析!具体内容请阅读动手学强化学习!


对应动手学强化学习——模仿学习


# -*- coding: utf-8 -*-


import gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utils


class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1)


class ValueNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim):
        super(ValueNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return self.fc2(x)


class PPO:
    ''' PPO算法,采用截断方式 '''
    def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,
                 lmbda, epochs, eps, gamma, device):
        self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.critic = ValueNet(state_dim, hidden_dim).to(device)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
        self.gamma = gamma
        self.lmbda = lmbda
        self.epochs = epochs  # 一条序列的数据用于训练轮数
        self.eps = eps  # PPO中截断范围的参数
        self.device = device

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(self.device)
        probs = self.actor(state)
        '''根据概率分布创建一个离散分类分布对象,用于采样离散动作。
        离散的概率模型。
        '''
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()

    def update(self, transition_dict):    
        
        processed_state = []
        for s in transition_dict['states']:
            if isinstance(s, tuple):
                # 如果元素是元组,则取元组的第一个元素
                processed_state.append(s[0])
            else:
                processed_state.append(s)
        # states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)
        states = torch.tensor(processed_state, dtype=torch.float).to(self.device)
        actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)
        rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)
        next_states = torch.tensor(transition_dict['next_states'],dtype=torch.float).to(self.device)
        dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)
        '''
        计算 TD 目标(即回归目标):
                            td_target=r+γ×V(s′)×(1−done)
        '''
        td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)
        '''计算 TD 残差(或优势估计的基础):当前状态的 TD 目标减去当前 critic 估计的状态价值。'''
        td_delta = td_target - self.critic(states)
        '''调用辅助函数(在 rl_utils 模块中定义)计算优势函数,通常使用广义优势估计(GAE)。'''
        advantage = rl_utils.compute_advantage(self.gamma, self.lmbda, td_delta.cpu()).to(self.device)
        '''
        先将状态输入 actor 网络得到动作概率分布(例如 shape 为 (batch_size, action_dim))。
        使用 .gather(1, actions) 选出每个样本所执行动作对应的概率(注意 actions 的形状必须匹配)。
        取对数得到旧的对数概率,再 detach() 阻断梯度流,保存旧策略下的概率值。
        '''
        old_log_probs = torch.log(self.actor(states).gather(1, actions)).detach()

        for _ in range(self.epochs):
            '''在当前策略下重新计算所有样本的对数概率,与旧对数概率进行比较。'''
            log_probs = torch.log(self.actor(states).gather(1, actions))
            '''计算概率比率,即新旧策略的概率之比,用于 PPO 的 clip 损失计算。'''
            ratio = torch.exp(log_probs - old_log_probs)
            '''计算无截断的策略目标,乘上优势值。'''
            surr1 = ratio * advantage
            '''对 ratio 进行截断,确保其在 [1−ϵ,1+ϵ] 范围内(例如 [0.8, 1.2]),然后乘以优势。'''
            surr2 = torch.clamp(ratio, 1 - self.eps, 1 + self.eps) * advantage  # 截断
            '''PPO 算法的目标是最大化最小值,因此这里取两者中的较小值再取负号作为损失。对整个 batch 求均值。'''
            actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数
            '''计算 critic 的均方误差(MSE)损失:当前 critic 估计与 TD 目标之间的误差,对整个 batch 取平均。'''
            critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
            self.actor_optimizer.zero_grad()
            self.critic_optimizer.zero_grad()
            actor_loss.backward()
            critic_loss.backward()
            self.actor_optimizer.step()
            self.critic_optimizer.step()


actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

env_name = 'CartPole-v0'
env = gym.make(env_name)
if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)

return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)



def sample_expert_data(n_episode):
    states = []
    actions = []
    for episode in range(n_episode):
        state = env.reset()
        done = False
        while not done:
            action = ppo_agent.take_action(state)
            states.append(state)
            actions.append(action)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # 可合并 terminated 和 truncated 标志
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state = next_state
    processed_states = []
    for s in states:
        if isinstance(s, tuple):
            # 如果元素是元组,则取元组的第一个元素
            processed_states.append(s[0])
        else:
            processed_states.append(s)
    return np.array(processed_states), np.array(actions)


if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)

n_samples = 30  # 采样30个数据
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]


class BehaviorClone:
    def __init__(self, state_dim, hidden_dim, action_dim, lr):
        self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)

    def learn(self, states, actions):
        """解释:定义一个学习函数,接收一批专家数据中的状态和动作,用于更新策略网络。"""
        states = torch.tensor(states, dtype=torch.float).to(device)
        actions = torch.tensor(actions).view(-1, 1).to(device)
        '''
        - 将 states 输入 policy 网络,得到每个状态下所有动作的概率分布,
          假设输出形状为 (batch_size, action_dim);
        - 使用 .gather(1, actions.long()) 从概率分布中取出对应专家动作的概率(注意动作需要转换为长整型索引);
        - 对这些概率取对数,得到对数概率(log likelihood)。
        '''
        log_probs = torch.log(self.policy(states).gather(1, actions.long()))
        # log_probs = torch.log(self.policy(states).gather(1, actions))
        '''计算行为克隆的损失,即负对数似然损失。对所有样本的负对数概率取均值。'''
        bc_loss = torch.mean(-log_probs)  # 最大似然估计

        self.optimizer.zero_grad()
        bc_loss.backward()
        self.optimizer.step()

    def take_action(self, state):
        if isinstance(state, tuple):
            state = state[0]
        state = torch.tensor([state], dtype=torch.float).to(device)
        probs = self.policy(state)
        action_dist = torch.distributions.Categorical(probs)
        action = action_dist.sample()
        return action.item()


def test_agent(agent, env, n_episode):
    return_list = []
    for episode in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # 可合并 terminated 和 truncated 标志
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
    return np.mean(return_list)


if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
np.random.seed(0)

lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []

with tqdm(total=n_iterations, desc="进度条") as pbar:
    for i in range(n_iterations):
        sample_indices = np.random.randint(low=0, high=expert_s.shape[0], size=batch_size)
        bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])
        current_return = test_agent(bc_agent, env, 5)
        test_returns.append(current_return)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})
        pbar.update(1)

iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()






class Discriminator(nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(Discriminator, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, 1)

    def forward(self, x, a):
        cat = torch.cat([x, a], dim=1)
        x = F.relu(self.fc1(cat))
        return torch.sigmoid(self.fc2(x))
    
class GAIL:
    def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):
        print(state_dim, action_dim, hidden_dim)
        self.discriminator = Discriminator(state_dim, hidden_dim, action_dim).to(device)
        self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)
        self.agent = agent

    def learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):

        expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)
        expert_actions = torch.tensor(expert_a).to(device)

        processed_state = []
        for s in agent_s:
            if isinstance(s, tuple):
                # 如果元素是元组,则取元组的第一个元素
                processed_state.append(s[0])
            else:
                processed_state.append(s)
        agent_states = torch.tensor(processed_state, dtype=torch.float).to(device)
        agent_actions = torch.tensor(agent_a).to(device)
        '''作用:将专家动作转换为 one-hot 编码形式,转换为浮点数。'''
        expert_actions = F.one_hot(expert_actions.long(), num_classes=2).float()
        agent_actions = F.one_hot(agent_actions.long(), num_classes=2).float()

        expert_prob = self.discriminator(expert_states, expert_actions)
        agent_prob = self.discriminator(agent_states, agent_actions)
        '''
        作用:计算二元交叉熵损失(BCE):
        - 对 agent 数据,目标标签设为 1(即希望判别器认为 agent 数据为“真”),
          损失为 BCE(agent_prob, 1);
        - 对专家数据,目标标签设为 0(希望判别器认为专家数据为“假”),
          损失为 BCE(expert_prob, 0)。
        - 然后将两部分损失相加。
        '''
        discriminator_loss = nn.BCELoss()(
            agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(
                expert_prob, torch.zeros_like(expert_prob))
        self.discriminator_optimizer.zero_grad()
        discriminator_loss.backward()
        self.discriminator_optimizer.step()

        '''
        作用:利用判别器对 agent 数据输出计算奖励:
        - 计算 –log(agent_prob) 作为奖励信号(当 agent_prob 较小时,奖励较高,鼓励 agent 模仿专家);
        - detach() 阻断梯度,转移到 CPU 并转换为 numpy 数组,方便后续传入 agent.update。
        '''
        rewards = -torch.log(agent_prob).detach().cpu().numpy()
        transition_dict = {
            'states': agent_s,
            'actions': agent_a,
            'rewards': rewards,
            'next_states': next_s,
            'dones': dones
        }
        self.agent.update(transition_dict)



if not hasattr(env, 'seed'):
    def seed_fn(self, seed=None):
        env.reset(seed=seed)
        return [seed]
    env.seed = seed_fn.__get__(env, type(env))
# env.seed(0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda, epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []

with tqdm(total=n_episode, desc="进度条") as pbar:
    for i in range(n_episode):
        episode_return = 0
        state = env.reset()
        done = False
        state_list = []
        action_list = []
        next_state_list = []
        done_list = []
        while not done:
            action = agent.take_action(state)
            result = env.step(action)
            if len(result) == 5:
                next_state, reward, done, truncated, info = result
                done = done or truncated  # 可合并 terminated 和 truncated 标志
            else:
                next_state, reward, done, info = result
            # next_state, reward, done, _ = env.step(action)
            state_list.append(state)
            action_list.append(action)
            next_state_list.append(next_state)
            done_list.append(done)
            state = next_state
            episode_return += reward
        return_list.append(episode_return)
        gail.learn(expert_s, expert_a, state_list, action_list, next_state_list, done_list)
        if (i + 1) % 10 == 0:
            pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})
        pbar.update(1)    
    
    
iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2317344.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2025-03-17 Unity 网络基础1——网络基本概念

文章目录 1 网络1.1 局域网1.2 以太网1.3 城域网1.4 广域网1.5 互联网(因特网)1.6 万维网1.7 小结 2 IP 地址2.1 IP 地址2.2 端口号2.3 Mac 地址2.4 小结 3 客户端与服务端3.1 客户端3.2 服务端3.3 网络游戏中的客户端与服务端 1 网络 ​ 在没有网络之前…

springboot441-基于SpringBoot的校园自助交易系统(源码+数据库+纯前后端分离+部署讲解等)

💕💕作者: 爱笑学姐 💕💕个人简介:十年Java,Python美女程序员一枚,精通计算机专业前后端各类框架。 💕💕各类成品Java毕设 。javaweb,ssm&#xf…

浅谈数据分析及数据思维

目录 一、数据分析及数据分析思维?1.1 数据分析的本质1.2 数据分析思维的本质1.2.1 拥有数据思维的具体表现1.2.2 如何培养自己的数据思维1.2.2.1 书籍1.2.2.2 借助工具1.2.2.3 刻意练习 二、数据分析的价值及必备能力?2.1 数据分析的价值2.1.1 现状分析…

自定义uniapp组件,以picker组件为例

编写目的 本文说明基于vue3定义uniapp组件的关键点: 1、一般定义在components文件夹创建组件,组件与页面已经没有明确的语法格式区别,所以可以与页面的语法保持一致 ; 2、组件定义后使用该组件的页面不需要引用组件即可使用&am…

【操作系统安全】任务4:Windows 系统网络安全实践里常用 DOS 命令

目录 一、引言 二、网络信息收集类命令 2.1 ipconfig 命令 2.1.1 功能概述 2.1.2 实例与代码 2.2 ping 命令 2.2.1 功能概述 2.2.2 实例与代码 2.3 tracert 命令 2.3.1 功能概述 2.3.2 实例与代码 三、网络连接与端口管理类命令 3.1 netstat 命令 3.1.1 功能概述…

【从零开始学习计算机科学】信息安全(二)物理安全

【从零开始学习计算机科学】信息安全(二)物理安全 物理安全物理安全的涵义物理安全威胁常见物理安全问题物理安全需求规划物理安全需求设备安全防盗和防毁机房门禁系统机房入侵检测和报警系统防电磁泄漏防窃听设备管理设备维护设备的处置和重复利用设备的转移电源安全电源调整…

LeetCode hot 100—验证二叉搜索树

题目 给你一个二叉树的根节点 root ,判断其是否是一个有效的二叉搜索树。 有效 二叉搜索树定义如下: 节点的左子树只包含 小于 当前节点的数。节点的右子树只包含 大于 当前节点的数。所有左子树和右子树自身必须也是二叉搜索树。 示例 示例 1&#…

MongoDB 可观测性最佳实践

MongoDB 介绍 MongoDB 是一个高性能、开源的 NoSQL 数据库,它采用灵活的文档数据模型,非常适合处理大规模的分布式数据。MongoDB 的文档存储方式使得数据结构可以随需求变化而变化,提供了极高的灵活性。它支持丰富的查询语言,允许…

论文阅读笔记——LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS

LoRA 论文 传统全面微调&#xff0c;对每个任务学习的参数与原始模型相同&#xff1a; m a x Φ ∑ ( x , y ) ∈ Z ∑ t 1 ∣ y ∣ l o g ( P Φ ( y t ∣ x , y < t ) ) 式(1) max_{\Phi}\sum_{(x,y)\in Z}\sum^{|y|}_{t1}log(P_{\Phi}(y_t|x,y<t)) \qquad \text{式(…

UE5中 Character、PlayerController、PlayerState、GameMode和GameState核心类之间的联动和分工·

1. GameMode 与 GameState 关系描述 GameMode&#xff1a;定义游戏规则和逻辑&#xff0c;控制游戏的开始、进行和结束。GameState&#xff1a;存储和同步全局游戏状态&#xff0c;如得分、时间、胜利条件等。 联动方式 GameMode初始化GameState&#xff1a;GameMode在游戏…

Ubuntu24.04 启动后突然进入tty,无法进入图形界面

问题描述 昨晚在编译 Android AOSP 14 后&#xff0c;进入了登录页面&#xff0c;但出现了无法输入密码的情况&#xff0c;且无法正常关机&#xff0c;只能强制重启。重启后&#xff0c;系统只能进入 TTY 页面&#xff0c;无法进入图形界面。 问题排查 经过初步排查&#x…

搭建主从服务器

任务需求 客户端通过访问 www.nihao.com 后&#xff0c;能够通过 dns 域名解析&#xff0c;访问到 nginx 服务中由 nfs 共享的首页文件&#xff0c;内容为&#xff1a;Very good, you have successfully set up the system. 各个主机能够实现时间同步&#xff0c;并且都开启防…

jenkins 配置邮件问题整理

版本&#xff1a;Jenkins 2.492.1 插件&#xff1a; A.jenkins自带的&#xff0c; B.安装功能强大的插件 配置流程&#xff1a; 1. jenkins->系统配置->Jenkins Location 此处的”系统管理员邮件地址“&#xff0c;是配置之后发件人的email。 2.配置系统自带的邮件A…

JVM中常量池和运行时常量池、字符串常量池三者之间的关系

文章目录 前言常量池&#xff08;Constant Pool&#xff09;运行时常量池&#xff08;Runtime Constant Pool&#xff09;字符串常量池&#xff08;String Literal Pool&#xff09;运行时常量池和字符串常量池位置变化方法区与永久代和元空间的关系三者之间的关系常量池与运行…

Mysql篇——SQL优化

本篇将带领各位了解一些常见的sql优化方法&#xff0c;学到就是赚到&#xff0c;一起跟着练习吧~ SQL优化 准备工作 准备的话我们肯定是需要一张表的&#xff0c;什么表都可以&#xff0c;这里先给出我的表结构&#xff08;表名&#xff1a;userinfo&#xff09; 通过sql查看…

FPGA|Verilog-SPI驱动

最近准备蓝桥杯FPGA的竞赛&#xff0c;因为感觉官方出的IIC的驱动代码思路非常好&#xff0c;写的内容非常有逻辑并且规范。也想学习一下SPI的协议&#xff0c;所以准备自己照着写一下。直到我打开他们给出的SPI底层驱动&#xff0c;我整个人傻眼了&#xff0c;我只能说&#x…

Windows11 新机开荒(二)电脑优化设置

目录 前言&#xff1a; 一、注册微软账号绑定权益 二、此电脑 桌面图标 三、系统分盘及默认存储位置更改 3.1 系统分盘 3.2 默认存储位置更改 四、精简任务栏 总结&#xff1a; 前言&#xff1a; 本文承接上一篇 新机开荒&#xff08;一&#xff09; 上一篇文章地址&…

关于deepseek R1模型分布式推理效率分析

1、引言 DeepSeek R1 采用了混合专家&#xff08;Mixture of Experts&#xff0c;MoE&#xff09;架构&#xff0c;包含多个专家子网络&#xff0c;并通过一个门控机制动态地激活最相关的专家来处理特定的任务 。DeepSeek R1 总共有 6710 亿个参数&#xff0c;但在每个前向传播…

揭秘大数据 | 9、大数据从何而来?

在科技发展史上&#xff0c;恐怕没有任何一种新生事物深入人心的速度堪比大数据。 如果把2012年作为数据量爆发性增长的第一年&#xff0c;那么短短数年&#xff0c;大数据就红遍街头巷尾——从工业界到商业界、学术界&#xff0c;所有的行业都经受了大数据的洗礼。从技术的迭…

使用Dependency Walker和Beyond Compare快速排查dll动态库损坏或被篡改的问题

目录 1、问题描述 2、用Dependency Walker工具打开qr.dll库&#xff0c;查看库与库的依赖关系以及接口调用情况&#xff0c;定位问题 3、使用Beyond Compare工具比较一下正常的msvcr100d.dll和问题msvcr100d.dll的差异 4、最后 C软件异常排查从入门到精通系列教程&#xff…