【强化学习】常用算法之一 “SAC”

news2024/12/28 19:32:13

 

作者主页:爱笑的男孩。的博客_CSDN博客-深度学习,活动,python领域博主爱笑的男孩。擅长深度学习,活动,python,等方面的知识,爱笑的男孩。关注算法,python,计算机视觉,图像处理,深度学习,pytorch,神经网络,opencv领域.https://blog.csdn.net/Code_and516?type=blog个人简介:打工人。

持续分享:机器学习、深度学习、python相关内容、日常BUG解决方法及Windows&Linux实践小技巧。

如发现文章有误,麻烦请指出,我会及时去纠正。有其他需要可以私信我或者发我邮箱:zhilong666@foxmail.com 

        强化学习(Reinforcement Learning)是一种机器学习的方法,通过让智能体在与环境的交互中学习来制定策略,以最大化预期的累积奖励。SAC(Soft Actor-Critic)算法是一种强化学习算法,它结合了策略优化和价值函数学习,实现对连续动作空间的鲁棒性采样优化。

本文将详细讲解强化学习常用算法之一“SAC”


目录

一、简介

二、发展史

三、算法公式

策略更新公式:

Q值函数更新公式:

值函数更新公式:

四、算法原理

        1. 策略优化

        2. 价值函数学习

        3. 熵优化

        4. 自适应温度参数

五、算法功能

六、示例代码

七、总结


一、简介

        强化学习(Reinforcement Learning,RL)是一种机器学习的分支,其目标是让智能体(agent)通过与环境的交互学习到最优的行为策略。SAC(Soft Actor-Critic)算法是近年来在强化学习领域取得重要突破的算法之一,它是一种基于策略优化和价值函数学习的算法。相对于传统的强化学习算法,SAC算法在优化过程中引入了熵正则化和软化策略更新的概念,使得智能体能够更好地探索未知的状态,提高学习效率。

二、发展史

        SAC算法的发展离不开前人的工作。在介绍SAC算法之前,我们先了解一些相关的算法。

        1. DQN(Deep Q-Networks) DQN是由DeepMind提出的一种强化学习算法,它首次将深度神经网络与Q-Learning相结合。

        通过使用经验回放和目标网络来提升学习的稳定性,DQN算法在很多基准测试中都取得了优异的结果。

        2. DDPG(Deep Deterministic Policy Gradient) DDPG是一种用于连续动作空间的深度强化学习算法,它结合了深度神经网络和确定性策略梯度。

        DDPG算法在连续控制问题上取得了很好的表现,并被广泛用于实际应用中。

       3. SAC算法的前身 SAC算法的前身包括TD3(Twin Delayed DDPG)和DDPG算法。

        TD3算法在DDPG算法的基础上引入了双网络和延迟更新,进一步提升了算法的性能。SAC算法在TD3算法的基础上进一步拓展,引入了熵优化和自适应温度参数等技术,以适应更复杂的任务。

        SAC算法最早由Haarnoja等人于2018年提出,并发表在期刊《Journal of Machine Learning Research》中。该算法结合了Actor-Critic方法和强化学习中的熵概念,为强化学习中的连续控制任务提供了一种更高效、更稳定的解决方案。

三、算法公式

        SAC算法主要由以下几个核心公式组成:

  • 策略更新公式:

        其中,∇θpolicy​​J(θpolicy​)表示策略的梯度,π(a∣s)表示策略在状态s下采取动作a的概率,Qπ(s,a)表示状态动作对(s,a)的值函数,α表示熵调节系数,V~(s)表示柔化的值函数。 

  • Q值函数更新公式:

        其中,Q(s,a)表示状态动作对(s,a)的值函数,r(s,a)表示在状态s采取动作a时获得的即时奖励,γ表示折扣因子,V(s′)表示状态s's′的值函数,p(s′∣s,a)表示在状态s采取动作a后转移到状态s′的转移概率。 

  • 值函数更新公式:

        其中,V(s)表示状态ss的值函数,a ∼π表示从策略π中采样得到动作a。

四、算法原理

        SAC算法采用了一系列技术来实现在连续动作空间的鲁棒性采样优化。下面介绍SAC算法的主要原理:

        1. 策略优化

        SAC算法使用策略梯度方法来进行优化。通过最大化软Q值的目标函数,SAC算法能够有效地在连续动作空间进行采样,以提高采样效率和优化性能。

        2. 价值函数学习

        SAC算法引入了值函数的学习,通过学习值函数,可以更准确地估计状态-动作对的价值。值函数的学习可以通过最小化Bellman误差来实现,进一步提高算法的性能。

        3. 熵优化

        SAC算法通过最小化策略的熵来优化策略。熵是一个度量策略的不确定性的指标,通过最小化策略的熵,可以使策略更加均衡和多样化。这有助于提高算法对于不同环境和任务的适应性。

        4. 自适应温度参数

        SAC算法引入了自适应温度参数α,通过优化温度参数的选择,可以在最大化预期累积奖励和最小化策略熵之间取得平衡。自适应温度参数能够更好地适应不同任务和环境,提高算法的性能。

五、算法功能

        SAC算法在强化学习任务中具有以下主要功能:

  1. 支持连续动作空间:SAC算法适用于处理连续动作空间的任务,如机器人控制、无人驾驶等。
  2. 高效稳定的策略更新:通过引入熵调节项和柔化的值函数,SAC算法能够在不降低效率的情况下提高策略的探索性和稳定性。
  3. 较好的学习性能:相对于传统的强化学习算法,SAC算法在连续控制任务中通常能够获得更好的学习性能。
  4. 灵活的参数设置:SAC算法的熵调节系数和柔化值函数等参数可以根据任务的需求进行灵活调整,以得到最佳的性能。

六、示例代码

        以下是一个使用强化学习SAC算法训练和测试倒立摆环境的示例代码。在OpenAI Gym中安装倒立摆环境(Pendulum-v0)或其他适合SAC算法的环境后,可以执行该代码。

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
from torch.utils.data import DataLoader, Dataset


class ReplayBuffer(Dataset):
    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity

    def __len__(self):
        return len(self.buffer)

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = zip(*random.sample(self.buffer, batch_size))
        return [torch.tensor(i) for i in batch]


class ValueNetwork(nn.Module):
    def __init__(self, state_dim):
        super(ValueNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class PolicyNetwork(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.mean = nn.Linear(256, action_dim)
        self.log_std = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        mean = self.mean(x)
        log_std = self.log_std(x).clamp(-20, 2)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        action = normal.rsample()
        return action.clamp(-self.max_action, self.max_action), normal.log_prob(action).sum(1)


class SACAgent:
    def __init__(self, state_dim, action_dim, max_action):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.value_net = ValueNetwork(state_dim).to(self.device)
        self.target_value_net = ValueNetwork(state_dim).to(self.device)
        self.target_value_net.load_state_dict(self.value_net.state_dict())
        self.policy_net = PolicyNetwork(state_dim, action_dim, max_action).to(self.device)
        self.replay_buffer = ReplayBuffer(capacity=1000000)
        self.value_optimizer = optim.Adam(self.value_net.parameters(), lr=3e-4)
        self.policy_optimizer = optim.Adam(self.policy_net.parameters(), lr=3e-4)
        self.value_criterion = nn.MSELoss()

    def update_value_network(self, states, actions, rewards, next_states, masks):
        next_actions, next_log_probs = self.policy_net.sample(next_states)
        next_values = self.target_value_net(next_states)
        q_targets = rewards + masks * (next_values - next_log_probs.exp())
        values = self.value_net(states)
        loss = self.value_criterion(values, q_targets.detach())
        self.value_optimizer.zero_grad()
        loss.backward()
        self.value_optimizer.step()

    def update_policy_network(self, states):
        actions, log_probs = self.policy_net.sample(states)
        values = self.value_net(states)
        q_values = values - log_probs.exp()
        policy_loss = (log_probs.exp() * (log_probs - q_values).detach()).mean()
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

    def update_target_network(self):
        self.target_value_net.load_state_dict(self.value_net.state_dict())

    def train(self, env, num_episodes, batch_size, update_interval):
        state = env.reset()
        episode_rewards = []
        for episode in range(num_episodes):
            episode_reward = 0
            done = False
            while not done:
                action, _ = self.policy_net.sample(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
                next_state, reward, done, _ = env.step(action.cpu().detach().numpy()[0])
                self.replay_buffer.push(state, action.cpu().detach().numpy()[0], reward, next_state, float(done))
                state = next_state
                episode_reward += reward
                if len(self.replay_buffer) > batch_size:
                    states, actions, rewards, next_states, masks = self.replay_buffer.sample(batch_size)
                    self.update_value_network(states.float().to(self.device),
                                              actions.float().to(self.device),
                                              rewards.float().unsqueeze(1).to(self.device),
                                              next_states.float().to(self.device),
                                              masks.float().unsqueeze(1).to(self.device))
                    if episode % update_interval == 0:
                        self.update_policy_network(states.float().to(self.device))
                        self.update_target_network()
            episode_rewards.append(episode_reward)
        return episode_rewards

    def test(self, env):
        state = env.reset()
        done = False
        episode_reward = 0
        while not done:
            action, _ = self.policy_net.sample(torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(self.device))
            state, reward, done, _ = env.step(action.cpu().detach().numpy()[0])
            episode_reward += reward
        return episode_reward


if __name__ == "__main__":
    env_name = "Pendulum-v0"
    env = gym.make(env_name)
    env.seed(0)
    torch.manual_seed(0)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])

    agent = SACAgent(state_dim, action_dim, max_action)

    num_episodes = 100
    batch_size = 128
    update_interval = 10

    episode_rewards = agent.train(env, num_episodes, batch_size, update_interval)
    test_reward = agent.test(env)

    print("Training rewards:", episode_rewards)
    print("Test reward:", test_reward)

        这段代码首先定义了一个重放缓冲区(Replay Buffer)类,用于存储和采样经验。之后定义了值函数网络(Value Network)和策略网络(Policy Network)的类,分别用于估计值函数和策略。接下来是一个SACAgent类,其中包含了更新值函数网络、更新策略网络和更新目标值函数网络等方法。然后定义了训练和测试方法,其中训练方法会在环境中进行多个episode,并返回每个episode的累计奖励;测试方法用于评估训练好的策略在环境中的表现。

        在运行示例代码时,需要首先安装OpenAI Gym和PyTorch库。根据实际环境,可以修改env_name、num_episodes、batch_size和update_interval等参数。运行结果中,"Training rewards"是每个训练episode的累计奖励,"Test reward"是对训练好的策略在测试环境中运行的累计奖励。

运行结果

Training rewards: [-1653.2674326245028, -4.761833854079121, -5.92794045978663, -7.101895383837817, -8.203949829019429, -9.320596188422504, -10.398472530688595, -11.046385714744188, -10.069612464051666, -9.028488437838597, -7.656846467478978, -6.751302759291316, -5.892224950031628, -4.932040818022195, -4.404335946243107, -3.9543475318455914, -3.8004235924909593, -3.8954312087615484, -4.121609662371389, -4.645552707416158, -5.194625548020546, -6.270942803647476, -7.722571387132912, -9.49117141815922, -10.748767915311705, -11.837523420567333, -10.51287854951289, -8.911409206767225, -7.3159242910765805, -6.26554445728115, -5.318318410816599, -4.47352859150234, -3.705487907578077, -3.155346120863036, -2.6655070443703384, -2.5458468110930834, -2.7734881221702694, -3.1021955848735714, -3.9183340756372385, -4.677010046229791, -5.57281093988401, -6.5885638098856845, -7.691982718183524, -9.510764926014309, -10.809366474064687, -11.987368416541688, -10.57040679863866, -9.250008035195474, -7.908586504443504, -6.220578988348704, -4.8460643338024765, -4.060980241950622, -3.405435529895923, -2.767329044940599, -2.511189533487366, -2.4275672225189084, -2.454642944293755, -2.5937254351057217, -3.1160835151897968, -4.058114436352538, -5.445887904623622, -6.620130141605474, -7.949470581770992, -9.310201166829376, -11.434984365444118, -12.219258381790816, -10.891645129483637, -9.486480025372442, -8.059946018495705, -6.6809631024851495, -4.991482855801217, -3.7126215715421353, -3.031910380007442, -2.374267357519335, -2.0286805007142283, -2.0474943313467784, -2.4227627809752352, -3.191653624713721, -4.2051864164440875, -5.190187599031304, -6.332166895519481, -7.600904756549318, -8.942357396564006, -10.40428240474939, -11.714269430490143, -10.826518362820941, -9.66884676107395, -8.464936630889763, -6.899476506182678, -5.903640338789183, -4.751731696723347, -4.017007527711459, -3.5796759436048413, -3.328303909157216, -3.4151482609755326, -3.8343294110510615, -4.676829653734708, -5.442567944257961, -6.859903604078736, -8.648312542545764]
Test reward: -1207.2040024164842
 

七、总结

        本篇文章详细介绍了强化学习中的SAC(Soft Actor-Critic)算法,包括其发展史、算法公式、原理、功能以及示例代码。SAC算法是一种基于策略优化和价值函数学习的强化学习算法,通过引入熵调节和柔化值函数的概念,使得智能体能够更好地探索未知状态和优化策略。示例代码展示了使用SAC算法解决CartPole问题的过程,通过训练智能体与环境的交互,逐步提高智能体的控制性能。SAC算法在连续控制任务中具有较好的性能和稳定性,在未来有望应用于更多复杂的强化学习任务中。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/701938.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Ameya360:广和通发布新一代5G FWA解决方案

为满足日益增长的5G宽带连接需求,提升FWA部署的经济效益和技术可行性,广和通在MWCS 2023期间发布了基于新一代5G模组FG190&FG180的5G FWA整体解决方案,为FWA等移动终端提供了灵活、便捷、高效、可靠的联网方案,促进FWA快速迭代…

SC2161旋变数字转换器可pin对pin兼容AD2S1210

SC2161 是一款 10 位至 16 位分辨率旋变数字转换器,集成片上可编程正弦波振荡器,为旋变器提供正弦波激励。可pin对pin兼容AD2S1210。 转换器的正弦和余弦输入端允许输入 3.15 Vp−p 27%、频率为 2 kHz 至 20 kHz 范围内的信号。Type II 伺服环路用于跟踪…

基于51单片机的智能照明系统

目录 基于51单片机的智能照明系统一、原理图二、部分代码三、视频演示 基于51单片机的智能照明系统 功能: 1.通过LCD屏幕显示实时时间、光强和物体等 2.通过DS1302获取实时时间 3.通过按键调整灯的开关时间和手自动设置手动模式下手动开灯 4.蜂鸣器报警功能 5.上位…

“因构建 而可见”,亚马逊云科技中国峰会助力企业数字化转型升级

过去十年,数字化转型的浪潮携带着机遇和挑战席卷而来,几乎每个企业都在做数字化转型,开始向大数据、人工智能等新技术寻求生产力的突破。但随着数字化转型深入,很多企业开始感受到数字化投入的成本压力,加之新技术正带…

ML算法——Support Vector Machine随笔【机器学习】

文章目录 4、Support Vector Machine (SVM)4.1、理论部分4.1.1、更优的决策边界4.1.2、解决低维不可分问题 4.2、sklearn 实现4.2.1、SVM 分类(SVC)4.2.2、SVM回归(SVR)4.2.3、网格调参 4.3、案例 4、Support Vector Machine (SVM…

用Java编写Groovy脚本,然后用命令行执行该脚本

1、Groovy 语言简介 Groovy 是 Apache 旗下的一门基于 JVM 平台的动态/敏捷编程语言Groovy 可以与 Java 语言无缝对接,在写 Groovy 的时候如果忘记了语法可以直接按Java的语法继续写,也可以在 Java 中调用 Groovy 脚本,都可以很好的工作Groo…

Pytorch常用的函数(三)深度学习中常见的卷积操作详细总结

Pytorch常用的函数(三)深度学习中常见的卷积操作 1、标准卷积(Standard Convolution) 1.1 标准卷积的理解 我们直接来看二维卷积,这在实际应用中是最常见的。 上图中Conv 2D其实就是卷积核,也叫做滤波器。滤波器的值决定了输出的情况,模型…

【Java】Java核心 86:Git 教程(9)GIT远程仓库操作

文章目录 14.GIT远程仓库操作-关联、拉取、推送、克隆目标内容小结 Git提供了一系列命令来进行远程仓库的操作。 下面是一些常用的Git远程仓库操作&#xff1a; 克隆远程仓库到本地&#xff1a; git clone <远程仓库URL>查看远程仓库信息&#xff1a; git remote -v添…

功能键F4在Microsoft Excel中有什么用

的确,许多 Excel 用户发现使用键盘快捷键对他们来说更有效。事实上,键盘快捷键可能是使用鼠标的最佳选择,因为使用 Excel 时使用触摸屏可能不是视力障碍者的最佳选择。 使用功能键,如 Excel 中的 F4 以及 F2 可能是非常必要的。在这篇文章中,我们将研究功能键 F4 及其在 …

【Java】Java核心 85:Git 教程(8)GIT远程仓库介绍与码云仓库注册创建

文章目录 13.GIT远程仓库介绍与码云仓库注册创建目标小结 Git是一个分布式版本控制系统&#xff0c;它允许多个开发者协同工作并管理代码的版本。远程仓库是存放在网络上的Git仓库&#xff0c;可以用于团队成员之间的代码共享和协作。 常见的远程仓库托管服务提供商有GitHub、…

XShell、XFtp、Linux上MySQL的远程连接及使用

下载资源包&#xff0c;请于文章顶部下载即可 XShell的使用 1. 打开安装好的XShell 2. 点击左上角新建连接 3. 填写相应连接服务器信息 4. 输入需要连接到Linux操作系统哪个用户的用户名 5. 输入连接到用户的密码 6. 远程登录Linux成功 7. 此时可以正常使用Linux指令操作Linu…

chatgpt赋能python:隐藏鼠标:Python实现隐藏鼠标的应用

隐藏鼠标&#xff1a;Python实现隐藏鼠标的应用 作为一名有10年Python编程经验的工程师&#xff0c;我深知Python在图形用户界面(GUI)开发上的优势&#xff0c;其中一个有趣而且有用的应用就是隐藏鼠标。 在某些情况下&#xff0c;用户可能希望隐藏鼠标&#xff0c;这可以用于…

ARM-异常与中断(四)

文章目录 中断中断请求、中断源中断服务程序保存现场、恢复现场中断仲裁、中断优先级中断嵌套 异常广义上的异常同步异常异步异常精确异步异常&#xff08;Precise Asynchronous Exception&#xff09;非精确异步异常&#xff08;Imprecise Asynchronous Exception&#xff09;…

【DBA专属】MHA高可用数据库集群-----------一主一备两从一管理,一个VIP客户端

MHA高可用数据库集群 目录 环境配置&#xff1a; 所有操作系统均为centos 7.x 64bit 1、关闭防火墙&#xff1a;&#xff08;所有服务器&#xff09; 2.配置所有主机名映射&#xff08;所有服务器&#xff09; 3、同步时区 4.安装MHA node及相关perl依赖包&#xff08;所有…

AutoSAR系列讲解(入门篇)4.1-BSW概述

BSW概述 一、什么是BSW 二、BSW的结构 1、微控制器硬件抽象层&#xff08;MCAL&#xff09; 2、ECU抽象层 3、服务层 4、复杂驱动 三、再将结构细分 一、什么是BSW 中文翻译就是基础软件层&#xff08;Basic Software&#xff09;。这个基础软件层实质上就是将整个ECU分…

【斯坦福】FrugalGPT: 如何使用大型语言模型,同时降低成本并提高性能

FrugalGPT: 如何使用大型语言模型&#xff0c;同时降低成本并提高性能 作者&#xff1a;Lingjiao Chen, Matei Zaharia, James Zou 引言 本文介绍了一种新颖的方法&#xff0c;旨在解决使用大型语言模型&#xff08;LLM&#xff09;时面临的成本和性能挑战。随着GPT-4和Chat…

链路聚合综合实战

拓扑 需求 -PC1和PC3属于vlan 10、PC2和PC4属于vlan 20 -设备之间配置lacp模式的链路聚合&#xff0c;并确保同vlan之间的主机可以互通 配置步骤 1&#xff09;PC配置IP地址 2&#xff09;所有交换机创建vlan10 和vlan20 3&#xff09;交换机和PC互联的接口设置为access &am…

python数据分析之连接MySQL数据库并进行数据可视化

大家好&#xff0c;我是带我去滑雪&#xff01; 本期将熟悉MySQL数据库以及管理和操作MySQL数据库的数据库管理工具Navicat Premium&#xff0c;然后在python中调用MySQL数据库进行数据分析和数据可视化。 目录 1、MySQL数据库与数据库管理工具Navicat Premium 2、调用MySQL…

EasyCVR如何实现国标级联无人机推送的RTMP推流通道?

EasyCVR视频融合平台基于云边端一体化架构&#xff0c;可支持多协议、多类型设备接入&#xff0c;包括&#xff1a;NVR、IPC、视频编码器、无人机、车载设备、智能手持终端、移动执法仪等。平台具有强大的数据接入、处理及分发能力&#xff0c;可在复杂的网络环境中&#xff0c…

el-date-picker禁用指定日期之前或之后的日期

一、elementUI中el-date-picker禁用指定日期之前或之后的日期 通过配置picker-options配置指定禁用日期&#xff08;pickerOptions写到data里面&#xff09; <el-date-pickerv-model"date"type"date"size"small"value-format"yyyy-MM-d…