强化学习笔记之【SAC算法】

news2024/11/24 8:35:33

强化学习笔记之【SAC算法】


前言:

本文为强化学习笔记第三篇,第一篇讲的是Q-learning和DQN,第二篇DDPG,第三篇TD3

TD3比DDPG少了一个target_actor网络,其它地方有点小改动

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle

博客园本文链接:https://www.cnblogs.com/hassle/p/18459320


文章目录

  • 强化学习笔记之【SAC算法】
      • 前言:
      • 一、SAC算法
      • 二、SAC算法Latex解释
      • 三、SAC五大网络和模块
        • 3.1 Actor 网络
        • 3.2 Critic1 和 Critic2 网络
        • 3.3 Target Critic1 和 Target Critic2 网络
        • 3.4 软更新模块
        • 3.5 总结

STAND ALONE COMPLEX = S . A . C

首先,我们需要明确,Q-learning算法发展成DQN算法,DQN算法发展成为DDPG算法,而DDPG算法发展成TD3算法,TD3算法发展成SAC算法

Soft Actor-Critic (SAC) 是一种基于策略梯度的深度强化学习算法,它具有最大化奖励与最大化熵(探索性)的双重目标。SAC 通过引入熵正则项,使策略在决策时具有更大的随机性,从而提高探索能力。

一、SAC算法

OK,先用伪代码让你们感受一下SAC算法

# 定义 SAC 超参数
alpha = 0.2               # 熵正则项系数
gamma = 0.99              # 折扣因子
tau = 0.005               # 目标网络软更新参数
lr = 3e-4                 # 学习率

# 初始化 Actor、Critic、Target Critic 网络和优化器
actor = ActorNetwork()                      # 策略网络 π(s)
critic1 = CriticNetwork()                   # 第一个 Q 网络 Q1(s, a)
critic2 = CriticNetwork()                   # 第二个 Q 网络 Q2(s, a)
target_critic1 = CriticNetwork()            # 目标 Q 网络 1
target_critic2 = CriticNetwork()            # 目标 Q 网络 2

# 将目标 Q 网络的参数设置为与 Critic 网络相同
target_critic1.load_state_dict(critic1.state_dict())
target_critic2.load_state_dict(critic2.state_dict())

# 初始化优化器
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=lr)
critic1_optimizer = torch.optim.Adam(critic1.parameters(), lr=lr)
critic2_optimizer = torch.optim.Adam(critic2.parameters(), lr=lr)

# 经验回放池(Replay Buffer)
replay_buffer = ReplayBuffer()

# SAC 训练循环
for each iteration:
    # Step 1: 从 Replay Buffer 中采样一个批次 (state, action, reward, next_state)
    batch = replay_buffer.sample()
    state, action, reward, next_state, done = batch

    # Step 2: 计算目标 Q 值 (y)
    with torch.no_grad():
        # 从 Actor 网络中获取 next_state 的下一个动作
        next_action, next_log_prob = actor.sample(next_state)
        
        # 目标 Q 值的计算:使用目标 Q 网络的最小值 + 熵项
        target_q1_value = target_critic1(next_state, next_action)
        target_q2_value = target_critic2(next_state, next_action)
        min_target_q_value = torch.min(target_q1_value, target_q2_value)

        # 目标 Q 值 y = r + γ * (最小目标 Q 值 - α * next_log_prob)
        target_q_value = reward + gamma * (1 - done) * (min_target_q_value - alpha * next_log_prob)

    # Step 3: 更新 Critic 网络
    # Critic 1 损失
    current_q1_value = critic1(state, action)
    critic1_loss = F.mse_loss(current_q1_value, target_q_value)

    # Critic 2 损失
    current_q2_value = critic2(state, action)
    critic2_loss = F.mse_loss(current_q2_value, target_q_value)

    # 反向传播并更新 Critic 网络参数
    critic1_optimizer.zero_grad()
    critic1_loss.backward()
    critic1_optimizer.step()

    critic2_optimizer.zero_grad()
    critic2_loss.backward()
    critic2_optimizer.step()

    # Step 4: 更新 Actor 网络
    # 通过 Actor 网络生成新的动作及其 log 概率
    new_action, log_prob = actor.sample(state)

    # 计算 Actor 的目标损失:L = α * log_prob - Q1(s, π(s))
    q1_value = critic1(state, new_action)
    actor_loss = (alpha * log_prob - q1_value).mean()

    # 反向传播并更新 Actor 网络参数
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

    # Step 5: 软更新目标 Q 网络参数
    with torch.no_grad():
        for param, target_param in zip(critic1.parameters(), target_critic1.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

        for param, target_param in zip(critic2.parameters(), target_critic2.parameters()):
            target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

二、SAC算法Latex解释

1、初始化 Actor、Critic1、Critic2、TargetCritic1 、TargetCritic2 网络
2、Buffer中采样 (state, action, reward, next_state)

3、Actor 输入 next_state 对应输出 next_action 和 next_log_prob
4、Actor 输入 state 对应输出 new_action 和 log_prob
5、Critic1 和 Critic2 分别输入next_state 和 next_action 取其中较小输出经熵正则计算得 target_q_value

6、使用 MSE_loss(Critic1(state, action), target_q_value) 更新 Critic1
7、使用 MSE_loss(Critic2(state, action), target_q_value) 更新 Critic2
8、使用 (alpha * log_prob - critic1(state, new_action)).mean() 更新 Actor


三、SAC五大网络和模块

SAC 算法 中,Actor、Critic1、Critic2、Target Critic1 和 Target Critic2 网络是核心模块,它们分别用于输出动作、评估状态-动作对的价值,并通过目标网络进行稳定的更新。

3.1 Actor 网络

Actor 网络用于在给定状态下输出一个高斯分布的均值和标准差(即策略)。它是通过神经网络近似的随机策略。用于选择动作。

import torch
import torch.nn as nn

class ActorNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(ActorNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mean_layer = nn.Linear(256, action_dim)  # 输出动作的均值
        self.log_std_layer = nn.Linear(256, action_dim)  # 输出动作的log标准差

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        mean = self.mean_layer(x)  # 输出动作均值
        log_std = self.log_std_layer(x)  # 输出 log 标准差
        log_std = torch.clamp(log_std, min=-20, max=2)  # 限制标准差范围
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = torch.exp(log_std)  # 将 log 标准差转为标准差
        normal = torch.distributions.Normal(mean, std)
        action = normal.rsample()  # 通过重参数化技巧进行采样
        log_prob = normal.log_prob(action).sum(-1)  # 计算 log 概率
        return action, log_prob


3.2 Critic1 和 Critic2 网络

Critic 网络用于计算状态-动作对的 Q 值,SAC 使用两个 Critic 网络(Critic1 和 Critic2)来缓解 Q 值的过估计问题。

class CriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(CriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q_value_layer = nn.Linear(256, 1)  # 输出 Q 值

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)  # 将 state 和 action 作为输入
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.q_value_layer(x)  # 输出 Q 值
        return q_value


3.3 Target Critic1 和 Target Critic2 网络

Target Critic 网络的结构与 Critic 网络相同,用于稳定 Q 值更新。它们通过软更新(即在每次训练后慢慢接近 Critic 网络的参数)来保持训练的稳定性。

class TargetCriticNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(TargetCriticNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.q_value_layer = nn.Linear(256, 1)  # 输出 Q 值

    def forward(self, state, action):
        x = torch.cat([state, action], dim=-1)  # 将 state 和 action 作为输入
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.q_value_layer(x)  # 输出 Q 值
        return q_value

3.4 软更新模块

在 SAC 中,目标网络会通过软更新逐渐逼近 Critic 网络的参数。每次更新后,目标网络参数会按照 ττ 的比例向 Critic 网络的参数靠拢。

def soft_update(critic, target_critic, tau=0.005):
    for param, target_param in zip(critic.parameters(), target_critic.parameters()):
        target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

3.5 总结
  1. 初始化网络和参数:
    • Actor 网络:用于选择动作。
    • Critic 1 和 Critic 2 网络:用于估计 Q 值。
    • Target Critic 1 和 Target Critic 2:与 Critic 网络架构相同,用于生成更稳定的目标 Q 值。
  2. 目标 Q 值计算:
    • 使用目标网络计算下一状态下的 Q 值。
    • 取两个 Q 网络输出的最小值,防止 Q 值的过估计。
    • 引入熵正则项,计算公式: y = r + γ ⋅ min ⁡ ( Q 1 , Q 2 ) − α ⋅ log ⁡ π ( a ∣ s ) y=r+\gamma\cdot\min(Q_1,Q_2)-\alpha\cdot\log\pi(a|s) y=r+γmin(Q1,Q2)αlogπ(as)
  3. 更新 Critic 网络:
    • 最小化目标 Q 值与当前 Q 值的均方误差 (MSE)。
  4. 更新 Actor 网络:
    • 最大化目标损失: L = α ⋅ log ⁡ π ( a ∣ s ) − Q 1 ( s , π ( s ) ) L=\alpha\cdot\log\pi(a|s)-Q_1(s,\pi(s)) L=αlogπ(as)Q1(s,π(s)),即在保证探索的情况下选择高价值动作。
  5. 软更新目标网络:
    • 软更新目标 Q 网络参数,使得目标网络参数缓慢向当前网络靠近,避免振荡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2208136.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Android Studio 和 MATLAB 中 gradle无法下载或下载过慢问题的解决 2024-10-08

系统环境: win10 64bit , MATLAB 2022b 1.从第三方镜像下载gradle包 如 腾讯镜像站 : 腾讯软件源gradle 选择需要的版本进行下载: 这里我选择首图中需要的 gradle-7.0.2-all.zip 2.完成 将下载好的文件放置下列路径 C:\Users\Administrator(这里替换成你所使用的用户名)\…

vscode显示.vscode文件

对于我这样的vscode新手来说,刚开始,都不知道如何生成.vscode文件,敢肯定的是,有很多同学和我一样,也不知道如何生成.vscode文件。 这个的话,我选择了第一个 然后提示报错也没事,因为已经生成…

月之暗面推出 Kimi 探索版:搜索量暴增 10 倍,精读 500 页信息,开启 AI 搜索新纪元

月之暗面推出 Kimi 探索版:搜索量暴增 10 倍,精读 500 页信息,开启 AI 搜索新纪元 作者: 猫头虎 最近,国产 AI 独角兽公司月之暗面再度刷新了我们的认知,其推出的Kimi 探索版凭借自主 AI 搜索能力&#…

【寄存器开发速成】半小时入门寄存器开发(基于STM32的寄存器开发简明教程)

一.认识寄存器 寄存器(register)是CPU(中央处理器)的组成部分,是一种直接整合到cpu中的有限的高速访问速度的存储器,它是有一些与非门组合组成的,分为通用寄存器和特殊寄存器。 寄存器是CPU的最…

产品经理,真有35岁这道坎吗?

前言 在职场生涯的某个阶段,产品经理们往往会面临一个普遍的疑问:是否存在一个35岁的门槛,一旦跨过,职业发展就会遭遇瓶颈?尤其是在技术迭代迅速的互联网行业,这样的担忧尤为明显。然而,对于有…

教你如何2小时从零开始搭建一套完整的性能测试环境

文章目录 一、前言1.1 准备工作1.2 最终目标 二、安装步骤2.1 购买云服务器和NAS存储2.2 kubekey方式安装k8s集群2.2.1 环境检查及安装基础包2.2.2 kubekey安装k8s集群2.2.3 压测机环境准备2.2.4 中间件部署2.2.4.1 部署NFS远程共享存储2.2.4.2 部署MySQL2.2.4.3 部署Redis2.2.…

spring 启动失败 active: @env@

参考:SpringBoot启动失败报错,spring.profiles.active:env中环境变量无法识别报错_active: env_profileactive启动报错 ine 3, column 13:-CSDN博客

通义灵码-----阿里巴巴推出的 AI 编程助手,一站式安装使用教程。 我自己就是在用,感觉写代码会高效很多

"通义灵码"(Tongyi Lingma),这是阿里巴巴推出的 AI 编程助手。通义灵码是基于阿里云的通义大模型,为开发者提供代码补全、代码生成等智能辅助功能。 启用和使用通义灵码 以下是如何在 IntelliJ IDEA 中安装和使用通义灵…

VSCode的常用插件(持续更新)

点击左边工具栏的“扩展”,在搜索栏中查找对应插件,点击“安装”,安装完成后右边界面的插件会显示“卸载”按钮。 1、中文(简体)语言包 2、Auto Rename Tag 修改开始标签,结束标签也会随之自动变化。 3、O…

ClickHouse的原理及使用,

1、前言 一款MPP查询分析型数据库——ClickHouse。它是一个开源的,面向列的分析数据库,由Yandex为OLAP和大数据用例创建。ClickHouse对实时查询处理的支持使其适用于需要亚秒级分析结果的应用程序。ClickHouse的查询语言是SQL的一种方言,它支…

Python 情感分析与词向量

Python 情感分析与词向量 在现代数据驱动的世界中,情感分析成为了一种重要的文本分析技术,它帮助我们理解和挖掘用户对产品、服务或事件的情感倾向。Python 作为一种强大的编程语言,提供了丰富的工具和库来支持情感分析的实现,其…

【Kubernets】配置类型资源 Etcd, Secret, ConfigMap

文章目录 所有资源概览Etcd详细说明一、基本概念二、主要功能三、架构与组件四、数据模型与操作五、安全与认证六、集群部署与管理 Secret详细说明一、Secret 的类型二、Secret 的创建三、Secret 的使用四、Secret 的更新与删除五、Secret 的安全性 ConfigMap详细说明一、Confi…

Web服务器小项目(Linux / C / epoll)

注意:前置知识: HTTP: https://xingzhu.top/archives/web-fu-wu-qi Linux 多线程: https://xingzhu.top/archives/duo-xian-cheng 源码放github上了,欢迎star: https://github.com/xingzhuz/webServer 思路 实现代码 server.h #pragma once #include &…

毕设成品 基于深度学习二维码检测识别系统

文章目录 0 简介1 二维码基础概念1.1 二维码介绍1.2 QRCode1.3 QRCode 特点 2 机器视觉二维码识别技术2.1 二维码的识别流程2.2 二维码定位2.3 常用的扫描方法 4 深度学习二维码识别4.1 部分关键代码 最后 0 简介 今天学长向大家分享一个毕业设计项目 **毕业设计 基于深度学习…

【最新华为OD机试E卷-支持在线评测】第K个排列(100分)多语言题解-(Python/C/JavaScript/Java/Cpp)

🍭 大家好这里是春秋招笔试突围 ,一枚热爱算法的程序员 💻 ACM金牌🏅️团队 | 大厂实习经历 | 多年算法竞赛经历 ✨ 本系列打算持续跟新华为OD-E/D卷的多语言AC题解 🧩 大部分包含 Python / C / Javascript / Java / Cpp 多语言代码 👏 感谢大家的订阅➕ 和 喜欢�…

【MySQL 保姆级教学】在Linux(CentoS 7)中安装MySQL(1)

目录 1. 卸载linux(Centos7) 中不要的环境2. 获取MySQL官方yum源2.1 获取yum源前先查看自己 linux(Centos)的版本2.2 获取官方yum源 3. 安装xftp和连接4. 开放连接端口5. 上传文件到Centos76. 安装MySQL6.1 顺利安装6.2 查询是否安…

Terminus ssh key 登陆

生成key 一、添加 KEY 配置 电脑: Terminus > Preferences,或 ⌘,。选择左侧 Keychain 标签。 手机: Terminus > Keychain 电脑: 点击右侧上方的 NEW KEY 按钮, 手机: 点加号 电脑: 在最右侧弹出的页面中填写 Label 和 Private key,Private ke…

电脑怎么录屏?探索屏幕捕捉的奥秘,新手也能成为录屏高手!

在数字时代,无论是制作教学视频、分享游戏精彩瞬间还是展示软件操作流程,屏幕录制都成了一项必不可少的技能。然而,对于许多初次接触录屏的新手来说,如何开始这一过程似乎充满了挑战。本文将为你揭开录屏的神秘面纱,带…

golang-基础知识(流程控制)

1 条件判断if和switch 所有的编程语言都有这个if,表示如果满足条件就做某事,不满足就做另一件事,go中的if判断和其它语言的区别主要有以下两点 1. go里面if条件判断不需要括号 2. go的条件判断语句中允许声明一个变量,这个变量…

MySQL8.0环境部署+Navicat17激活教程

安装MySQL 下载MySQL MySQL官网下载当前最新版本,当前是8.0.39。 选择No thanks, just start my download等待下载即可。 安装MySQL 下载完成后,双击安装进入安装引导页面。选择Custom自定义安装。 选择MySQL Server 8.0.39 - X64安装。 点击Execute执…