强化学习--多维动作状态空间的设计

news2025/1/17 5:59:33

目录

  • 一、离散动作
  • 二、连续动作
    • 1、例子1
    • 2、知乎给出的示例
    • 2、github里面的代码

免责声明:以下代码部分来自网络,部分来自ChatGPT,部分来自个人的理解。如有其他观点,欢迎讨论!

一、离散动作

注意:本文均以PPO算法为例。

# time: 2023/11/22 21:04
# author: YanJP


import torch
import torch
import torch.nn as nn
from torch.distributions import Categorical

class MultiDimensionalActor(nn.Module):
    def __init__(self, input_dim, output_dims):
        super(MultiDimensionalActor, self).__init__()

        # Define a shared feature extraction network
        self.feature_extractor = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # Define individual output layers for each action dimension
        self.output_layers = nn.ModuleList([
            nn.Linear(64, num_actions) for num_actions in output_dims
        ])

    def forward(self, state):
        # Feature extraction
        features = self.feature_extractor(state)

        # Generate Categorical objects for each action dimension
        categorical_objects = [Categorical(logits=output_layer(features)) for output_layer in self.output_layers]

        return categorical_objects

# 定义主函数
def main():
    # 定义输入状态维度和每个动作维度的动作数
    input_dim = 10
    output_dims = [5, 8]  # 两个动作维度,分别有 3 和 4 个可能的动作

    # 创建 MultiDimensionalActor 实例
    actor_network = MultiDimensionalActor(input_dim, output_dims)

    # 生成输入状态(这里使用随机数据作为示例)
    state = torch.randn(1, input_dim)

    # 调用 actor 网络
    categorical_objects = actor_network(state)

    # 输出每个动作维度的采样动作和对应的对数概率
    for i, categorical in enumerate(categorical_objects):
        sampled_action = categorical.sample()
        log_prob = categorical.log_prob(sampled_action)
        print(f"Sampled action for dimension {i+1}: {sampled_action.item()}, Log probability: {log_prob.item()}")

if __name__ == "__main__":
    main()

#Sampled action for dimension 1: 1, Log probability: -1.4930928945541382
#Sampled action for dimension 2: 3, Log probability: -2.1875085830688477

注意代码中categorical函数的两个不同传入参数的区别:参考链接
简单来说,logits是计算softmax的,probs直接就是已知概率的时候传进去就行。

二、连续动作

参考链接:github、知乎
为什么取对数概率?参考回答
在这里插入图片描述

1、例子1

先看如下的代码:

# time: 2023/11/21 21:33
# author: YanJP
#这是对应多维连续变量的例子:
# 参考链接:https://github.com/XinJingHao/PPO-Continuous-Pytorch/blob/main/utils.py
# https://www.zhihu.com/question/417161289
import torch.nn as nn
import torch
class Policy(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, num_outputs):
        super(Policy, self).__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, num_outputs)
        )

class Normal(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()
        self.stds = nn.Parameter(torch.zeros(num_outputs))  #创建一个可学习的参数 
    def forward(self, x):
        dist = torch.distributions.Normal(loc=x, scale=self.stds.exp())
        action = dist.sample((every_dimention_output,))  #这里我觉得是最重要的,不填sample的参数的话,默认每个分布只采样一个值!!!!!!!!
        return action

if __name__ == '__main__':
    policy = Policy(4,20,20,5)
    normal = Normal(5) #设置5个维度
    every_dimention_output=10  #每个维度10个输出
    observation = torch.Tensor(4)
    action = normal.forward(policy.layer( observation))
    print("action: ",action)
  • self.stds.exp(),表示求指数,因为正态分布的标准差都是正数。
  • action = dist.sample((every_dimention_output,))这里最重要!!!

2、知乎给出的示例


class Agent(nn.Module):
    def __init__(self, envs):
        super(Agent, self).__init__()
        self.actor_mean = nn.Sequential(
            layer_init(nn.Linear(np.array(envs.single_observation_space.shape).prod(), 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, np.prod(envs.single_action_space.shape)), std=0.01),
        )
        self.actor_logstd = nn.Parameter(torch.zeros(1, np.prod(envs.single_action_space.shape)))

    def get_action_and_value(self, x, action=None):
        action_mean = self.actor_mean(x)
        action_logstd = self.actor_logstd.expand_as(action_mean)
        action_std = torch.exp(action_logstd)
        probs = Normal(action_mean, action_std)
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action).sum(1), probs.entropy().sum(1), self.critic(x)

这里的np.prod(envs.single_action_space.shape),表示每个维度的动作数相乘,然后初始化这么多个actor网络的标准差和均值,最后action里面的sample就是采样这么多个数据。(感觉还是拉成了一维计算)

2、github里面的代码

github

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Beta,Normal

class GaussianActor_musigma(nn.Module):
	def __init__(self, state_dim, action_dim, net_width):
		super(GaussianActor_musigma, self).__init__()

		self.l1 = nn.Linear(state_dim, net_width)
		self.l2 = nn.Linear(net_width, net_width)
		self.mu_head = nn.Linear(net_width, action_dim)
		self.sigma_head = nn.Linear(net_width, action_dim)

	def forward(self, state):
		a = torch.tanh(self.l1(state))
		a = torch.tanh(self.l2(a))
		mu = torch.sigmoid(self.mu_head(a))
		sigma = F.softplus( self.sigma_head(a) )
		return mu,sigma

	def get_dist(self, state):
		mu,sigma = self.forward(state)
		dist = Normal(mu,sigma)
		return dist

	def deterministic_act(self, state):
		mu, sigma = self.forward(state)
		return mu

上述代码主要是通过设置mu_head 和sigma_head的个数,来实现多维动作。

class GaussianActor_mu(nn.Module):
	def __init__(self, state_dim, action_dim, net_width, log_std=0):
		super(GaussianActor_mu, self).__init__()

		self.l1 = nn.Linear(state_dim, net_width)
		self.l2 = nn.Linear(net_width, net_width)
		self.mu_head = nn.Linear(net_width, action_dim)
		self.mu_head.weight.data.mul_(0.1)
		self.mu_head.bias.data.mul_(0.0)

		self.action_log_std = nn.Parameter(torch.ones(1, action_dim) * log_std)

	def forward(self, state):
		a = torch.relu(self.l1(state))
		a = torch.relu(self.l2(a))
		mu = torch.sigmoid(self.mu_head(a))
		return mu

	def get_dist(self,state):
		mu = self.forward(state)
		action_log_std = self.action_log_std.expand_as(mu)
		action_std = torch.exp(action_log_std)

		dist = Normal(mu, action_std)
		return dist

	def deterministic_act(self, state):
		return self.forward(state)
class Critic(nn.Module):
	def __init__(self, state_dim,net_width):
		super(Critic, self).__init__()

		self.C1 = nn.Linear(state_dim, net_width)
		self.C2 = nn.Linear(net_width, net_width)
		self.C3 = nn.Linear(net_width, 1)

	def forward(self, state):
		v = torch.tanh(self.C1(state))
		v = torch.tanh(self.C2(v))
		v = self.C3(v)
		return v

上述代码只定义了mu的个数与维度数一样,std作为可学习的参数之一。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1239513.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

93.STL-系统内置仿函数

目录 算术仿函数 关系仿函数 逻辑仿函数 C 标准库中提供了一些内置的函数对象&#xff0c;也称为仿函数&#xff0c;它们通常位于 <functional> 头文件中。以下是一些常见的系统内置仿函数&#xff1a; 算术仿函数 功能描述&#xff1a; 实现四则运算其中negate是一元…

基于向量加权平均算法优化概率神经网络PNN的分类预测 - 附代码

基于向量加权平均算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于向量加权平均算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于向量加权平均优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xf…

Unity开发之C#基础-File文件读取

前言 今天我们将要讲解到c#中 对于文件的读写是怎样的 那么没接触过特别系统编程小伙伴们应该会有一个疑问 这跟文件有什么关系呢&#xff1f; 我们这样来理解 首先 大家对电脑或多或少都应该有不少的了解吧 那么我们这些软件 都是通过变成一个一个文件保存在电脑中 我们才可以…

vue根据接口数据配置动态路由(动态配置后台管理系统路由权限)

文章目录 前言一、什么是动态路由二、以 后台管理系统路由权限配置为例静态路由配置动态路由配置 总结如有启发&#xff0c;可点赞收藏哟~ 前言 其几天记录了根据目录接口动态配置vue的静态路由 本文结合addRoute记录下配置动态路由 一、什么是动态路由 动态路由是根据实际配…

基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码

基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码 文章目录 基于材料生成算法优化概率神经网络PNN的分类预测 - 附代码1.PNN网络概述2.变压器故障诊街系统相关背景2.1 模型建立 3.基于材料生成优化的PNN网络5.测试结果6.参考文献7.Matlab代码 摘要&#xff1a;针对PNN神…

Gradle常用命令与参数依赖管理和版本决议

一、Gradle 常用命令与参数 本课程全程基于 Gradle8.0 环境 1、Gradle 命令 介绍 gradle 命令之前我们先来了解下 gradle 命令怎么在项目中执行。 1.1、gradlew gradlew 即 Gradle Wrapper&#xff0c;在学习小组的第一课时已经介绍过了这里就不多赘述。提一下执行命令&am…

msvcp140.dll是什么?msvcp140.dll丢失的有哪些解决方法

在计算机使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中之一就是“msvcp140.dll丢失”。这个错误通常会导致某些应用程序无法正常运行。为了解决这个问题&#xff0c;我们需要采取一些措施来修复丢失的msvcp140.dll文件。本文将详细介绍5个解决msvcp140.dl…

所有产品都值得用AI再做一遍,让AGI与品牌营销双向奔赴

微软 CEO Satya Nadella 曾经说过&#xff1a;“所有的产品都值得用 AI 重做一遍。” AI 大模型的出现&#xff0c;开启了一个全新的智能化时代&#xff0c;重新定义了人机交互。这让生成式 AI 技术变得「触手可得」&#xff0c;也让各行业看到 AGI 驱动商业增长的更大可能性。…

【开源】基于Vue和SpringBoot的高校宿舍调配管理系统

项目编号&#xff1a; S 051 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S051&#xff0c;文末获取源码。} 项目编号&#xff1a;S051&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能需求2.1 学生端2.2 宿管2.3 老师端 三、系统…

Axios 请求响应结果的结构

发送请求 this.$axios.get(https://apis.jxcxin.cn/api/title?urlhttps://apis.jxcxin.cn/,{params: {id: 10}}).then(res > {console.log(res)})输出返回结果 confing 请求时的配置对象&#xff0c;如果请求的url&#xff0c;请求的方法&#xff0c;请求的参数&#xff0c…

如何解决msvcp110.dll丢失问题,分享5个有效的解决方法

最近&#xff0c;我在使用电脑时遇到了一个令人头疼的问题——msvcp110.dll丢失。这个错误通常会导致某些应用程序无法正常运行。为了解决这个问题&#xff0c;我们需要采取一些有效的方法来修复丢失的msvcp110.dll文件。那么&#xff0c;msvcp110.dll到底是什么呢&#xff1f;…

【python基础(三)】操作列表:for循环、正确缩进、切片的使用、元组

文章目录 一. 遍历整个列表1. 在for循环中执行更多操作2. 在for循环结束后执行一些操作 二. 避免缩进错误三. 创建数值列表1. 使用函数range()2. 使用range()创建数字列表3. 指定步长。4. 对数字列表执行简单的统计计算5. 列表解析 五. 使用列表的一部分-切片1. 切片2. 遍历切片…

一文搞懂什么是 GNU/Linux 操作系统

Author&#xff1a;rab 目录 前言一、UNIX二、Linux三、GNU 前言 你是否经常看见或听说过这么一句话&#xff1a;这是一个类 Unix 的 GNU/Linux 操作系统&#xff0c;你是怎么理解这句话的呢&#xff1f;想要搞懂这句话的含义&#xff0c;你需要了解以下三点基本常识。 一、U…

RedisTemplate使用详解

RedisTemplate介绍StringRedisTemplate介绍RedisConnectionFactory介绍RedisConnectionFactory源码解析 RedisOperations介绍RedisOperations源码解析 RedisTemplate使用连接池配置RedisTemplate连接池连接池配置 RedisTemplate应用场景RedisTemplate主要特点RedisTemplate使用…

Nuxt.js Next.js Nest.js

Nuxt.js和Next.js都是服务端渲染框架(SSR)&#xff0c;属于前端框架,Nest.js则是node框架,属于后端框架。 其中Nuxt.js是vue的ssr框架&#xff0c;Next.js是react的ssr框架。 都是比vue和react更上层的前端框架。 文章目录 1.SSR2.Nuxt2.1 Nuxt的下载2.2 Nuxt的集成2.3 Nuxt…

【tomcat】java.lang.Exception: Socket bind failed: [730048

项目中一些旧工程运行情况处理 问题 1、启动端口占用 2、打印编码乱码 ʮһ&#xfffd;&#xfffd; 13, 2023 9:33:26 &#xfffd;&#xfffd;&#xfffd;&#xfffd; org.apache.coyote.AbstractProtocol init &#xfffd;&#xfffd;&#xfffd;&#xfffd;: Fa…

【DevOps】Git 图文详解(八):后悔药 - 撤销变更

Git 图文详解&#xff08;八&#xff09;&#xff1a;后悔药 - 撤销变更 1.后悔指令 &#x1f525;2.回退版本 reset3.撤销提交 revert4.checkout / reset / revert 总结 发现写错了要回退怎么办&#xff1f;看看下面几种后悔指令吧&#xff01; ❓ 还没提交的怎么撤销&#x…

人工智能基础_机器学习047_用逻辑回归实现二分类以上的多分类_手写代码实现逻辑回归OVR概率计算---人工智能工作笔记0087

然后我们再来看一下如何我们自己使用代码实现逻辑回归的,对二分类以上,比如三分类的概率计算 我们还是使用莺尾花数据 首先我们把公式写出来 def sigmoid(z): 定义出来这个函数 可以看看到这需要我们理解OVR是如何进行多分类的,我们先来看这个 OVR分类器 思想 OVR(One-vs-…

如何用cmd命令快速搭建FTP服务

环境&#xff1a; Win10专业版 问题描述&#xff1a; 如何用cmd命令快速搭建FTP服务 解决方案&#xff1a; 1.输入以下命令来安装IIS&#xff08;Internet Information Services&#xff09;&#xff1a; dism /online /enable-feature /featurename:IIS-FTPServer /all …

好用的博客评论系统 Valine 使用及避坑指南

评论系统&#xff0c;即网站的一个小功能&#xff0c;展示评论内容和用户输入框。开源免费的评论系统可不多&#xff0c;原来很火的"多说"评论系统都关闭了&#xff0c;而Disqus又是国外的访问受限。无意间发现了Valine&#xff0c;挺不错的&#xff0c;分享给大家。…