【SSL-RL】自监督强化学习:自预测表征 (SPR)算法

news2024/11/24 6:20:47

        📢本篇文章是博主强化学习(RL)领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在👉强化学习专栏:

       【强化学习】(44)---《自监督强化学习:自预测表征 (SPR)算法》

自监督强化学习:自预测表征 (SPR)算法

目录

1. 引言

2. SPR算法的核心思想

2.1 潜在状态表示学习

2.2 潜在状态的多步预测

2.3 一致性损失

2.4 总损失函数

3. SPR算法的工作流程

3.1 数据编码

3.2 潜在状态预测

3.3 一致性损失优化

3.4 策略学习

[Python] SPR算法的实现示例

[Experiment] SPR算法的应用示例

[Notice]  代码解析

4. SPR的优势与挑战

5. 结论


1. 引言

        自预测表征,Self-Predictive Representations (SPR)算法 是一种用于自监督强化学习的算法,旨在通过学习预测未来的潜在状态来帮助智能体构建有用的状态表示。SPR在强化学习任务中无需依赖稀疏或外部奖励,通过自监督学习的方法获得环境的潜在结构和动态信息。这种方法特别适合高维观测环境(如图像)或部分可观测的任务。

        SPR的关键目标是通过让智能体在潜在空间中预测未来的状态,从而形成对环境的理解,使得智能体可以高效地进行策略学习和探索。


2. SPR算法的核心思想

        SPR的核心思想是训练一个模型,使其能够在潜在空间中预测未来的状态表示。这种潜在表示应当具备描述环境动态和指导智能体决策的能力。SPR包含以下主要要素:

  • 潜在状态的预测(Latent State Prediction):SPR训练模型在潜在空间中预测未来的潜在状态,而不是直接在观测空间中进行预测,从而减少状态空间的复杂性。
  • 多步预测(Multi-step Prediction):SPR不仅预测下一步的潜在状态,还进行多步预测,从而捕捉环境的长时间依赖关系。
  • 一致性损失(Consistency Loss):通过一个自监督一致性损失,确保潜在空间的预测能够准确反映未来的真实状态。

2.1 潜在状态表示学习

        在SPR中,环境的高维观测( o_t ) 首先通过编码器 ( f_\theta )映射到低维潜在空间中的状态表示( z_t )。公式上,潜在状态表示为:

[ z_t = f_\theta(o_t) ]

其中,( \theta )是编码器的参数。该潜在表示( z_t )应该包含与任务相关的关键信息,以便用于预测未来的潜在状态。

2.2 潜在状态的多步预测

        SPR使用一个预测网络( g_\phi )来预测未来的潜在状态。预测网络的输入是当前潜在状态( z_t ) 和当前的动作序列,输出是未来的潜在状态预测 ( \hat{z}_{t+k} ),其中( k )是预测的步数。公式表示如下:

[ \hat{z}{t+k} = g\phi(z_t, a_t, \dots, a_{t+k-1}) ]

        这种多步预测的设计能够让SPR捕捉到长时间依赖关系,使得潜在表示更加稳定和有效。

2.3 一致性损失

        为了确保模型的预测能力,SPR设计了一个一致性损失,用于约束预测的潜在状态与真实的潜在状态保持一致。一致性损失通过最小化预测的潜在状态( \hat{z}{t+k} )和真实潜在状态( z{t+k} )之间的差异来实现。公式如下:

[ L_{\text{consistency}} = \sum_{k=1}^K | \hat{z}{t+k} - z{t+k} |^2 ]

其中,( K )是预测的步数。一致性损失确保了模型在潜在空间中的预测能够准确反映未来的实际状态,从而形成稳定的状态表示。

2.4 总损失函数

        SPR的训练损失函数综合了多步预测的一致性损失,最终的损失函数为:

[ L_{\text{SPR}} = L_{\text{consistency}} ]

        通过优化一致性损失,SPR可以学习到对环境动态有用的潜在表示,从而帮助智能体更好地理解和探索环境。


3. SPR算法的工作流程

3.1 数据编码

        在每个时间步 ( t ),环境的高维观测( o_t )被编码器 ( f_\theta )映射到低维的潜在表示( z_t )。该表示保留了当前观测中的关键信息,同时降低了数据维度。

3.2 潜在状态预测

        通过预测网络( g_\phi ),SPR在潜在空间中预测未来的潜在状态( \hat{z}_{t+k} )。这使得模型能够在低维空间中进行未来状态的预测,而不需要直接预测高维观测。

3.3 一致性损失优化

        通过最小化一致性损失,SPR模型在潜在空间中优化预测,使得潜在表示能够准确地反映环境的动态变化。

3.4 策略学习

        一旦学习到稳定的潜在状态表示,SPR可以与常规的强化学习算法(如DQN、PPO等)结合,将潜在状态作为输入,优化策略。此时,强化学习算法在低维潜在空间中工作,从而显著提高了学习效率。


[Python] SPR算法的实现示例

        以下是一个简化的SPR实现示例,展示如何通过编码器、预测网络和一致性损失来实现潜在表示的自监督学习。

        🔥若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱📌,以便于及时分享给您(私信难以及时回复)。

"""《SPR算法的实现示例》
    时间:2024.11
    作者:不去幼儿园
"""
import torch
import torch.nn as nn
import torch.optim as optim

# 定义SPR模型类
class SPR(nn.Module):
    def __init__(self, obs_dim, act_dim, latent_dim):
        super(SPR, self).__init__()
        self.encoder = Encoder(obs_dim, latent_dim)
        self.predictor = Predictor(latent_dim, act_dim, latent_dim)

    def forward(self, obs, actions):
        latent_state = self.encoder(obs)
        predicted_latent = self.predictor(latent_state, actions)
        return latent_state, predicted_latent

# 定义编码器和预测网络
class Encoder(nn.Module):
    def __init__(self, obs_dim, latent_dim):
        super(Encoder, self).__init__()
        self.fc1 = nn.Linear(obs_dim, 64)
        self.fc2 = nn.Linear(64, latent_dim)
        self.relu = nn.ReLU()

    def forward(self, obs):
        x = self.relu(self.fc1(obs))
        latent_state = self.fc2(x)
        return latent_state

class Predictor(nn.Module):
    def __init__(self, latent_dim, act_dim, latent_output_dim):
        super(Predictor, self).__init__()
        self.fc1 = nn.Linear(latent_dim + act_dim, 64)
        self.fc2 = nn.Linear(64, latent_output_dim)
        self.relu = nn.ReLU()

    def forward(self, latent_state, actions):
        x = torch.cat([latent_state, actions], dim=1)
        x = self.relu(self.fc1(x))
        predicted_latent = self.fc2(x)
        return predicted_latent

# 训练SPR模型
def train_spr_model(spr_model, obs_batch, actions_batch, next_obs_batch, optimizer):
    latent_state, predicted_latent = spr_model(obs_batch, actions_batch)
    next_latent_state = spr_model.encoder(next_obs_batch)

    # 计算一致性损失
    consistency_loss = torch.mean((predicted_latent - next_latent_state) ** 2)

    # 更新模型参数
    optimizer.zero_grad()
    consistency_loss.backward()
    optimizer.step()

# 示例用法
obs_dim = 64
act_dim = 32
latent_dim = 16
spr_model = SPR(obs_dim, act_dim, latent_dim)
optimizer = optim.Adam(spr_model.parameters(), lr=1e-3)

# 假设有批量数据
obs_batch = torch.randn(64, obs_dim)
actions_batch = torch.randn(64, act_dim)
next_obs_batch = torch.randn(64, obs_dim)

# 训练模型
train_spr_model(spr_model, obs_batch, actions_batch, next_obs_batch, optimizer)

[Experiment] SPR算法的应用示例

        在强化学习任务中,SPR可以帮助智能体在没有奖励信号的情况下学习环境的动态结构,并建立有效的潜在状态表示。此潜在状态表示能够用于增强常规强化学习算法的性能,特别是在稀疏奖励或复杂观测场景中。以下是SPR与常规强化学习算法(如DQN或PPO)结合使用的应用示例。

应用流程

  1. 环境初始化:创建强化学习环境,定义观测和动作空间的维度。
  2. SPR模型初始化:创建SPR模型,包括编码器和预测器网络。
  3. 强化学习算法初始化:例如使用DQN智能体,将SPR提取的潜在表示作为状态输入。
  4. 训练循环
    • 潜在状态编码:通过SPR模型的编码器,将环境观测映射到潜在状态。
    • 策略选择:在潜在空间中使用DQN选择最优动作。
    • 环境交互与反馈:执行动作,环境返回奖励和下一个观测。
    • 潜在状态的多步预测:使用SPR的预测器网络对未来的潜在状态进行预测,并计算一致性损失。
    • 更新模型和策略:根据一致性损失优化SPR模型,并根据奖励优化DQN策略。
# 定义DQN智能体
class DQNAgent:
    def __init__(self, state_dim, action_dim, lr=1e-3):
        self.q_network = nn.Sequential(
            nn.Linear(state_dim, 64),
            nn.ReLU(),
            nn.Linear(64, action_dim)
        )
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=lr)

    def select_action(self, state):
        with torch.no_grad():
            q_values = self.q_network(state)
            action = q_values.argmax().item()
        return action

    def update(self, states, actions, rewards, next_states, dones):
        q_values = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze()
        with torch.no_grad():
            max_next_q_values = self.q_network(next_states).max(1)[0]
            target_q_values = rewards + (0.99 * max_next_q_values * (1 - dones))
        loss = torch.mean((q_values - target_q_values) ** 2)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

实例训练:

# 训练循环
spr_model = SPR(obs_dim, act_dim, latent_dim)
dqn_agent = DQNAgent(state_dim=latent_dim, action_dim=env.action_space.n)
spr_optimizer = optim.Adam(spr_model.parameters(), lr=1e-3)

for episode in range(num_episodes):
    obs = env.reset()
    done = False
    episode_reward = 0
    
    while not done:
        # 编码当前观测到潜在状态
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        latent_state = spr_model.encoder(obs_tensor)

        # 选择动作
        action = dqn_agent.select_action(latent_state)
        next_obs, reward, done, _ = env.step(action)
        
        # 更新SPR模型
        next_obs_tensor = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0)
        spr_model.update(obs_tensor, torch.tensor([action]), next_obs_tensor, spr_optimizer)
        
        # 更新DQN智能体
        dqn_agent.update(latent_state, torch.tensor([action]), torch.tensor([reward]), spr_model.encoder(next_obs_tensor), torch.tensor([done]))

        obs = next_obs
        episode_reward += reward

    print(f"Episode {episode + 1}: Total Reward = {episode_reward}")

[Notice]  代码解析

  • 潜在状态表示学习:SPR模型将高维观测编码为潜在状态,简化了状态表示的维度。
  • 一致性损失优化:SPR模型在潜在空间中通过预测未来的潜在状态进行优化,从而帮助智能体理解环境的动态结构。
  • 策略优化:DQN智能体在潜在空间中选择最优动作,并通过环境反馈的奖励更新策略。

        由于博文主要为了介绍相关算法的原理应用的方法,缺乏对于实际效果的关注,算法可能在上述环境中的效果不佳或者无法运行,一是算法不适配上述环境,二是算法未调参和优化,三是没有呈现完整的代码,四是等等。上述代码用于了解和学习算法足够了,但若是想直接将上面代码应用于实际项目中,还需要进行修改。


4. SPR的优势与挑战

优势

  1. 减少维度和复杂性:通过在低维潜在空间中进行预测和策略学习,SPR减少了高维观测带来的计算复杂性。
  2. 捕捉环境动态:SPR通过多步预测和一致性损失,使得模型能够捕捉环境的长期依赖关系。
  3. 无奖励学习:SPR可以在没有奖励信号的情况下构建有用的状态表示,特别适合稀疏奖励或无奖励的环境。

挑战

  1. 预测误差积累:在多步预测中,预测误差可能会积累,从而影响潜在表示的稳定性。
  2. 超参数敏感性:多步预测的步数 ( K ) 和一致性损失的权重可能需要在不同任务中进行调优。
  3. 潜在空间的解释性:SPR学习的潜在表示可能缺乏解释性,特别是在复杂的观测中。

5. 结论

        Self-Predictive Representations (SPR)是一种有前景的自监督强化学习方法,通过在潜在空间中预测未来的状态来构建有用的状态表示。SPR不仅可以减少环境观测的复杂性,还能够捕捉环境的长期动态关系,对于部分可观测的任务尤其有效。未来,SPR在处理复杂环境、稀疏奖励和多智能体系统中的应用具有广阔的研究和应用前景。

参考文献:

  • Pathak, D., et al. (2017). "Curiosity-driven Exploration by Self-supervised Prediction." ICML.
  • Hafner, D., et al. (2019). "Learning Latent Dynamics for Planning from Pixels." ICML.
  • Dosovitskiy, A., et al. (2021). "Image Transformer." NeurIPS.

 更多自监督强化学习文章,请前往:【自监督强化学习】专栏 


     文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者添加VX:Rainbook_2,联系作者。✨

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2239365.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Docker部署kafka集群

1,编写Docker Compose文件 编写一个docker-compose.yml文件来定义服务,以下定义了一个Zookeeper服务和三个Kafka Broker服务: 注意:把10.0.8.4替换成宿主IP version: 3.8services:zookeeper:image: bitnami/zookeeper:latestconta…

Qt滑动条美化自定义

效果展示 主要代码 头文件 下面是hi控件的头文件,我们继承一个Qt原生的滑动条类QSlider,然后在基类的基础上进行自定义,我会对重要的变量进行解析: class XSlider : public QSlider {Q_OBJECT public:explicit XSlider(QWidget…

wordpress实用功能A5资源网同款 隐藏下载框 支付框 需要登录才能查看隐藏的内容

实用功能 隐藏下载框 支付框 需要登录才能查看隐藏的内容, 个人网站防天朝申查实测有效 。 登录前,未登录: 登录后,已登录: 功能说明 该代码段的主要功能是隐藏支付框并为未 登录用户显示一条提示信息,告知他们需要…

Vue 简单入手

前端工程化(Front-end Engineering)指的是在前端开发中,通过一系列工具、流程和规范的整合,以提高开发效率、代码质量和可维护性的一种技术和实践方法。其核心目的是使得前端开发变得更高效、可扩展和可维护。 文章目录 一、Vue 项…

Spingboot 定时任务与拦截器(详细解释)

在 boot 环境中,一般来说,要实现定时任务,我们有两中方案,一种是使用 Spring 自带的定时 任务处理器 Scheduled 注解,另一种就是使用第三方框架 Quartz , Spring Boot 源自 SpringSpringMVC &#…

无人机+无人车+无人狗+无人船:互通互联技术探索详解

关于“无人机无人车机器狗(注:原文中的“无人狗”可能是一个笔误,因为在实际技术领域中,常用的是“机器狗”这一术语)无人船”的互通互联技术,以下是对其的详细探索与解析: 一、系统架构与关键…

ima.copilot-腾讯智能工作台

一、产品描述 ima.copilot是腾讯推出的基于腾讯混元大模型技术的智能工作台,通过先进的人工智能技术,为用户提供了一个全新的搜读写体验,让知识管理变得更加智能和高效。它不仅是一个工具,更是一个智能的伙伴,能够帮助…

集合卡尔曼滤波(EnsembleKalmanFilter)的MATLAB例程(三维、二维)

本 M A T L A B MATLAB MATLAB代码实现了一个三维动态系统的集合卡尔曼滤波(Ensemble Kalman Filter, EnKF)示例。代码的主要目的是通过模拟真实状态和测量值,使用 EnKF 方法对动态系统状态进行估计。 文章目录 参数设置初始化真实状态定义状…

【动手学电机驱动】STM32-FOC(5)基于 IHM03 的无感 FOC 控制

STM32-FOC(1)STM32 电机控制的软件开发环境 STM32-FOC(2)STM32 导入和创建项目 STM32-FOC(3)STM32 三路互补 PWM 输出 STM32-FOC(4)IHM03 电机控制套件介绍 STM32-FOC(5&…

光老化测试的三种试验:紫外老化、氙灯老化、碳弧灯老化

光老化是指材料在阳光照射下,由于紫外线、热和氧气的共同作用而发生的物理和化学变化。这种现象对纺织材料、塑料材料、涂料材料和橡胶材料的应用有显著影响。这些材料户外家具、汽车内饰和户外供水排水管道、建筑外墙涂料、汽车漆面、船舶涂料、汽车轮胎、密封件、…

VMWare安装包及安装过程

虚拟机基本使用 检查自己是否开启虚拟化 如果虚拟化没有开启,需要自行开启:百度加上自己电脑的品牌型号,进入BIOS界面开启 什么是虚拟机 所谓的虚拟机,就是在当前计算机系统中,又开启了一个虚拟系统 这个虚拟系统&…

消费的外部性

大学宿舍,遇到在你睡觉的时候开外放不戴耳机的室友,但中午12点,室友却在那拉上窗帘睡觉。能带饭吗?能代签到吗?能倒个垃圾吗?能带个外卖吗?自己永远麻烦别人,你要让他帮个忙又这推那…

易趋亮相2024 PMI项目管理大会

11月9日-10日,2024 PMI项目管理大会在广州圆满举办,项目管理行业优秀代表企业——易趋(隶属深圳市蓝云软件有限公司),携最新产品和解决方案亮相本次展会。 (主论坛现场) 本届大会以“‘项’有所成 行以致远…

边缘计算与推理算力:智能时代的加速引擎

在数据量爆炸性增长的今天,边缘计算与推理算力正成为推动智能应用的关键力量。智能家居、自动驾驶、工业4.0等领域正在逐步从传统的云端计算转向边缘计算,而推理算力的加入,为边缘计算提供了更强的数据处理能力和实时决策能力。本文将探讨边缘…

基于matlab的CNN食物识别分类系统,matlab深度学习分类,训练+数据集+界面

文章目录 前言🎓一、数据集准备🎓二、模型训练🍀🍀1.初始化🍀🍀2.加载数据集🍀🍀3.划分数据集,并保存到新的文件夹🍀🍀4.可视化数据集&#x1f34…

马斯克万卡集群AI数据中心引发的科技涟漪:智算数据中心挑战与机遇的全景洞察

一、AI 爆发重塑数据中心格局 随着AI 技术的迅猛发展,尤其是大模型的崛起,其对数据中心产生了极为深远的影响。大模型以其数以亿计甚至更多的参数和对海量数据的处理需求,成为了 AI 发展的核心驱动力之一,同时也为数据中心带来了…

移远通信亮相骁龙AI PC生态科技日,以领先的5G及Wi-Fi产品革新PC用户体验

PC作为人们学习、办公、娱乐的重要工具,已经深度融入我们的工作和生活。随着物联网技术的快速发展,以及人们对PC性能要求的逐步提高,AI PC成为了行业发展的重要趋势。 11月7-8日,骁龙AI PC生态科技日在深圳举办。作为高通骁龙的重…

Unity资源打包Addressable资源保存在项目中

怎么打包先看“Unity资源打包Addressable AA包” 其中遗留一个问题,下载下来的资源被保存在C盘中了,可不可以保存在项目中呢?可以。 新建了一个项目,路径与“Unity资源打包Addressable AA包”都不相同了 1.创建资源缓存路径 在…

postman变量和脚本功能介绍

1、基本概念——global、collection、environment 在postman中,为了更好的管理各类变量、测试环境以及脚本等,创建了一些概念,包括:globals、collection、environment。其实在postman中,最上层还有一个Workspaces的概…

为什么汽车电源正在用 48V 取代 12V

欧姆定律也有利于 48 伏电源 假设您需要为汽车的起动电机供电。可能存在以下静态和动态特征: 电源电压:12V 额定电流:40A 额定功率:480W 标称平均阻抗:0.3Ω 浪涌电流:150A 浪涌功率:1,8…