Deep Deterministic Policy Gradient (DDPG)算法

news2025/2/8 15:18:05

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
import random
from collections import deque

# 定义 Actor 网络
class Actor(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(state_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, action_dim)
        self.max_action = max_action

    def forward(self, state):
        x = torch.relu(self.fc1(state))
        x = torch.relu(self.fc2(x))
        action = self.max_action * torch.tanh(self.fc3(x))  # 输出在 [-max_action, max_action] 范围内
        return action

# 定义 Critic 网络
class Critic(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(state_dim + action_dim, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, state, action):
        x = torch.cat([state, action], 1)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        q_value = self.fc3(x)
        return q_value

# DDPG 算法
class DDPG:
    def __init__(self, state_dim, action_dim, max_action):
        self.actor = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target = Actor(state_dim, action_dim, max_action).to(device)
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=1e-3)

        self.critic = Critic(state_dim, action_dim).to(device)
        self.critic_target = Critic(state_dim, action_dim).to(device)
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=1e-3)

        self.max_action = max_action
        self.replay_buffer = deque(maxlen=1000000)
        self.batch_size = 256
        self.gamma = 0.99
        self.tau = 0.005

    def select_action(self, state):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def add_to_buffer(self, state, action, reward, next_state, done):
        self.replay_buffer.append((state, action, reward, next_state, done))

    def train(self):
        if len(self.replay_buffer) < self.batch_size:
            return

        # 从 Replay Buffer 中采样
        batch = random.sample(self.replay_buffer, self.batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))

        state = torch.FloatTensor(state).to(device)
        action = torch.FloatTensor(action).to(device)
        reward = torch.FloatTensor(reward).unsqueeze(1).to(device)
        next_state = torch.FloatTensor(next_state).to(device)
        done = torch.FloatTensor(done).unsqueeze(1).to(device)

        # 更新 Critic 网络
        with torch.no_grad():
            next_action = self.actor_target(next_state)
            target_q = self.critic_target(next_state, next_action)
            target_q = reward + (1 - done) * self.gamma * target_q

        current_q = self.critic(state, action)
        critic_loss = nn.MSELoss()(current_q, target_q)

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 更新 Actor 网络
        actor_loss = -self.critic(state, self.actor(state)).mean()

        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # 软更新目标网络
        for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

        for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
            target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data)

# 训练 DDPG 算法
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('MountainCarContinuous-v0')  # 使用 MountainCarContinuous-v0 环境
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
max_action = float(env.action_space.high[0])

ddpg = DDPG(state_dim, action_dim, max_action)
max_episodes = 1000
max_steps = 200

for episode in range(max_episodes):
    state = env.reset()
    episode_reward = 0

    for step in range(max_steps):
        env.render()  # 渲染环境
        action = ddpg.select_action(state)
        next_state, reward, done, _ = env.step(action)
        ddpg.add_to_buffer(state, action, reward, next_state, done)
        state = next_state
        episode_reward += reward

        ddpg.train()

        if done:
            break

    print(f"Episode {episode + 1}, Reward: {episode_reward}")

env.close()

简介

在强化学习领域中,很多实际问题涉及到连续的动作空间,比如机器人的关节角度控制、自动驾驶中车辆的速度和转向控制等。传统的基于策略梯度的算法(如 A2C、A3C 等)以及基于值函数的算法(如 DQN 及其变体)在处理连续动作空间时往往面临诸多困难,DDPG 旨在有效地解决在连续动作空间下进行策略学习的问题。

图片

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2267013.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

C++ OCR 文字识别

一.引言 文字识别&#xff0c;也称为光学字符识别&#xff08;Optical Character Recognition, OCR&#xff09;&#xff0c;是一种将不同形式的文档&#xff08;如扫描的纸质文档、PDF文件或数字相机拍摄的图片&#xff09;中的文字转换成可编辑和可搜索的数据的技术。随着技…

【解决报错】AttributeError: ‘NoneType‘ object has no attribute ‘group‘

学习爬虫时&#xff0c;遇到如下报错&#xff1a; 报错原因&#xff1a; 正则表达式的 search 或 finditer 方法没有找到任何匹配项&#xff0c;可能是换行符处理不当等。 解决方法如下&#xff1a; 在正则表达式末尾加上re.S即可&#xff0c;re.S是一个编译标志&#xff0c…

JVM实战—3.JVM垃圾回收的算法和全流程

大纲 1.JVM内存中的对象何时会被垃圾回收 2.JVM中的垃圾回收算法及各算法的优劣 3.新生代和老年代的垃圾回收算法 4.避免本应进入S区的对象直接升入老年代 5.Stop the World问题分析 6.JVM垃圾回收的原理核心流程 7.问题汇总 1.JVM内存中的对象何时会被垃圾回收 (1)什么…

基于SpringBoot在线音乐系统平台功能实现十八

一、前言介绍&#xff1a; 1.1 项目摘要 随着互联网技术的迅猛发展和普及&#xff0c;人们对音乐的获取和欣赏方式发生了巨大改变。传统的音乐播放方式&#xff0c;如CD、磁带或本地下载的音乐文件&#xff0c;已经不能满足用户日益增长的需求。用户更希望通过网络直接获取各…

RouYi-Vue框架,环境搭建以及使用

使用若以框架需要配置node.js&#xff0c;如果不了解可以去看node.js安装&#xff0c;uni-app的配置使用_uniapp使用nodejs类库-CSDN博客 安装若依 首先是去若以官网下载自己所需要的框架类型 RuoYi-Vue: &#x1f389; 基于SpringBoot&#xff0c;Spring Security&#xff…

XL系列433芯片、2.4G收发芯片 通讯对码说明

XL系列433芯片对码说明&#xff1a; 发射芯片 XL4456 通过数据脚接收高低电平然后经过调制将波形发出&#xff0c;而接收芯片 XL520 通过接收波形后进行解调&#xff0c;数据脚输出高低电平。至于具体的通信协议&#xff0c;需要用户自定义&#xff0c;一般而言&#xff0c;使…

蓝牙BLE开发——解决iOS设备获取MAC方式

解决iOS设备获取MAC方式 uniapp 解决 iOS 获取 MAC地址&#xff0c;在Android、iOS不同端中互通&#xff0c;根据MAC 地址处理相关的业务场景&#xff1b; 文章目录 解决iOS设备获取MAC方式监听寻找到新设备的事件BLE工具效果图APP监听设备返回数据解决方式ArrayBuffer转16进制…

期权懂|如何计算期权卖方平仓后的盈利?

锦鲤三三每日分享期权知识&#xff0c;帮助期权新手及时有效地掌握即市趋势与新资讯&#xff01; 如何计算期权卖方平仓后的盈利&#xff1f; 期权卖方平仓后的盈利计算涉及多个因素&#xff0c;包括期权的交易价格、平仓价格以及权利金的变动等。 交易价格&#xff1a;期权卖…

QT:一个TCP客户端自动连接的测试模型

版本 1:没有取消按钮 测试效果&#xff1a; 缺陷&#xff1a; 无法手动停止 测试代码 CMakeLists.txt cmake_minimum_required(VERSION 3.19) project(AutoConnect LANGUAGES CXX)find_package(Qt6 6.5 REQUIRED COMPONENTS Core Widgets Network)qt_standard_project_setup(…

uniapp中wx.getFuzzyLocation报错如何解决

一、用wx.getLocation接口审核不通过 用uniapp开发小程序时难免需要获取当前地理位置。 代码如下&#xff1a; uni.getLocation({type: wgs84,success: function (res) {console.log(当前位置的经度&#xff1a; res.longitude);console.log(当前位置的纬度&#xff1a; r…

解决Ubuntu下无法装载 Windows D盘的问题

电脑安装了 Windows 和 Ubuntu 24.04 后&#xff0c;在Ubuntu系统上装载 D盘&#xff0c;发现无法装载错误如下&#xff1a; Error mounting /dev/nvme0n1p4 at /media/jackeysong/Data: wrong fs type, bad option, bad superblock on /dev/nvme0n1p4, missing codepage or h…

硬件设计-高速电路的过孔

目录 摘要 &#xff1a; 过孔的机械特性&#xff1a; 过孔直径&#xff1a; 过孔焊盘尺寸 摘要 &#xff1a; 过孔这个词指得是印刷电路板&#xff08; PCB &#xff09;上的孔。过孔可以用做焊接插装器件的焊&#xff08; Through hole) &#xff0c;也可用做连接层间走…

mysql索引的理解

1、索引是什么&#xff1f; 索引&#xff1a;简单理解就是我们字典的目录&#xff0c;一个索引可以找得到多个记录。 作用加快我们数据库的查询速度。索引本身较大&#xff0c;往往存储在磁盘的文件里。可能存储在单独的索引文件中&#xff0c;也可能和数据一起存储在数据文件…

【WRF模拟】如何得到更佳的WRF模拟效果?

【WRF模拟】如何得到更佳的WRF模拟效果&#xff1f; 模型配置&#xff08;The Model Configuration&#xff09;1.1 模拟区域domain设置1.2 分辨率Resolution (horizontal and vertical)案例&#xff1a;The Derecho of 29-30 June 2012 1.3 初始化和spin-up预热过程案例1-有无…

IOS safari 播放 mp4 遇到的坎儿

起因 事情的起因是调试 IOS 手机下播放服务器接口返回的 mp4 文件流失败。对于没调试过移动端和 Safari 的我来说着实费了些功夫&#xff0c;网上和AI也没有讲明白。好在最终大概理清楚了&#xff0c;在这里整理出来供有缘人参考。 问题 因为直接用 IOS 手机的浏览器打开页面…

单片机+人体红外感应的防盗系统设计(仿真+源码+PCB文件+报告)

资料下载地址&#xff1a;单片机人体红外感应的防盗系统设计(仿真源码PCB文件报告) 1、功能介绍 (1)该设计包括硬件和软件设计两个部分。 (2)本红外线防盗报警系统由热释电红外传感器、报警器、单片机控制电路、LED控制电路及相关的控制管理软件组成。用户终端完成信息采集、处…

网络攻防实践

1. 学习总结 黛蛇蠕虫案例&#xff1a; 原理&#xff1a;利用系统漏洞&#xff0c;并集成攻击代码。其中&#xff0c;通过蜜罐技术并进行数据分析所获取的攻击场景如下&#xff1a; 外部感染源攻陷蜜罐主机 执行Shellcode后获取主机权限后连接控制命令服务器&#xff0c;获取F…

寒假准备找实习复习java基础-day1

CMD常用命令&#xff1a; java跨平台原理&#xff1a; JRE和JVM Java基本数据类型

MacOS安装Xcode(非App Store)

文章目录 访问官网资源页面 访问官网资源页面 直接访问官网的历史版本下载资源页面地址&#xff1a;https://developer.apple.com/download/more/完成APP ID的登陆&#xff0c;直接找到需要的软件下载即可 解压后&#xff0c;安装将xcode.app移动到应用程序文件夹。

Docker 安装mysql ,redis,nacos

一、Mysql 一、Docker安装Mysql 1、启动Docker 启动&#xff1a;sudo systemctl start dockerservice docker start 停止&#xff1a;systemctl stop docker 重启&#xff1a;systemctl restart docker 2、查询mysql docker search mysql 3、安装mysql 3.1.默认拉取最新版…