强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN

news2025/1/12 0:58:54

强化学习代码实践1.DDQN:在CartPole游戏中实现 Double DQN

      • 1. 导入依赖
      • 2. 定义 Q 网络
      • 3. 创建 Agent
      • 4. 训练过程
      • 5. 解释
      • 6. 调整超参数

CartPole 游戏中实现 Double DQN(DDQN)训练网络时,我们需要构建一个使用两个 Q 网络(一个用于选择动作,另一个用于更新目标)的方法。Double DQN 通过引入目标网络来减少 Q-learning 中过度估计的偏差。

下面是一个基于 PyTorch 的 Double DQN 实现:

1. 导入依赖

import random
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import gym
from collections import deque

2. 定义 Q 网络

我们需要定义一个 Q 网络,用于计算 Q 值。这里使用简单的全连接网络。

class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, action_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

3. 创建 Agent

class DoubleDQNAgent:
    def __init__(self, state_dim, action_dim, gamma=0.99, epsilon=0.1, epsilon_decay=0.995, epsilon_min=0.01, lr=0.0005):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.lr = lr

        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=self.lr)

        self.memory = deque(maxlen=10000)
        self.batch_size = 64

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.choice(range(self.action_dim))  # Explore
        else:
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.q_network(state)
            return torch.argmax(q_values).item()  # Exploit

    def store_experience(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def sample_batch(self):
        return random.sample(self.memory, self.batch_size)

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def train(self):
        if len(self.memory) < self.batch_size:
            return

        batch = self.sample_batch()
        states, actions, rewards, next_states, dones = zip(*batch)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # Q values for current states
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Next Q values using target network
        next_q_values = self.target_network(next_states)
        next_actions = self.q_network(next_states).argmax(1)
        next_q_values = next_q_values.gather(1, next_actions.unsqueeze(1)).squeeze(1)

        # Double DQN update
        target = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute loss
        loss = nn.MSELoss()(q_values, target)

        # Optimize the Q-network
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

4. 训练过程

def train_cartpole():
    env = gym.make('CartPole-v1')
    agent = DoubleDQNAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.n)

    episodes = 1000
    for episode in range(episodes):
        state, info = env.reset()
        done = False
        total_reward = 0

        while not done:
            action = agent.select_action(state)
            next_state, reward, done, truncated, info = env.step(action)
            agent.store_experience(state, action, reward, next_state, done)
            state = next_state

            agent.train()
            total_reward += reward

        agent.update_target_network()

        print(f"Episode {episode}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.4f}")

    env.close()

if __name__ == '__main__':
    train_cartpole()

5. 解释

  • QNetwork: 使用一个简单的 3 层全连接神经网络来近似 Q 函数。
  • DoubleDQNAgent:
    • select_action: 根据 ε-greedy 策略选择动作。
    • store_experience: 存储经验回放。
    • sample_batch: 从记忆中随机采样批次。
    • train: 更新 Q 网络的权重,使用 Double DQN 的目标计算方法。
    • update_target_network: 每一定步数更新目标网络。
  • 训练过程: 在每一回合中,代理与环境互动并更新 Q 网络,通过经验回放机制逐步学习。

6. 调整超参数

  • gamma: 折扣因子,控制未来奖励的影响。
  • epsilon: 初始的探索率,随着训练的进行逐渐减小。
  • lr: 学习率,控制权重更新的步伐。
  • batch_size: 每次更新时,从记忆库中采样的批量大小。

这个代码可以直接用于训练一个 CartPole 的 Double DQN 代理,逐步优化 Q 网络来完成游戏任务。如果你有更复杂的需求,像更深的网络结构或其他改进,可以在此基础上进一步扩展。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2275211.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Go】:图片上添加水印的全面指南——从基础到高级特性

前言 在数字内容日益重要的今天&#xff0c;保护版权和标识来源变得关键。为图片添加水印有助于声明所有权、提升品牌认知度&#xff0c;并防止未经授权的使用。本文将介绍如何用Go语言实现图片水印&#xff0c;包括静态图片和带旋转、倾斜效果的文字水印&#xff0c;帮助您有…

PyQt5 UI混合开发,控件的提升

PromoteLabelTest.py 提升的类 import sys from PyQt5.QtWidgets import QApplication, QWidget,QVBoxLayout,QTextEdit,QPushButton,QHBoxLayout,QFileDialog,QLabelclass PromoteLabel(QLabel):def __init__(self,parent None):super().__init__(parent)self.setText("…

CI/CD 流水线

CI/CD 流水线 CI 与 CD 的边界CI 持续集成CD&#xff08;持续交付/持续部署&#xff09;自动化流程示例&#xff1a; Jenkins 引入到 CI/CD 流程在本地或服务器上安装 Jenkins。配置 Jenkins 环境流程设计CI 阶段&#xff1a;Jenkins 流水线实现CD 阶段&#xff1a;Jenkins 流水…

ROS核心概念解析:从Node到Master,再到roslaunch的全面指南

Node 在ROS中&#xff0c;最小的进程单元就是节点&#xff08;node&#xff09;。一个软件包里可以有多个可执行文件&#xff0c;可执行文件在运行之后就成了一个进程(process)&#xff0c;这个进程在ROS中就叫做节点。 从程序角度来说&#xff0c;node就是一个可执行文件&…

深入Android架构(从线程到AIDL)_22 IPC的Proxy-Stub设计模式04

目录 5、 谁来写Proxy及Stub类呢? 如何考虑人的分工 IA接口知识取得的难题 在编程上&#xff0c;有什么技术可以实现这个方法&#xff1f; 范例 5、 谁来写Proxy及Stub类呢? -- 强龙提供AIDL工具&#xff0c;给地头蛇产出Proxy和Stub类 如何考虑人的分工 由框架开发者…

风水算命系统架构与功能分析

系统架构 服务端&#xff1a;Java&#xff08;最低JDK1.8&#xff0c;支持JDK11以及JDK17&#xff09;数据库&#xff1a;MySQL数据库&#xff08;标配5.7版本&#xff0c;支持MySQL8&#xff09;ORM框架&#xff1a;Mybatis&#xff08;集成通用tk-mapper&#xff0c;支持myb…

551 灌溉

常规解法&#xff1a; #include<bits/stdc.h> using namespace std; int n,m,k,t; const int N105; bool a[N][N],b[N][N]; int cnt; //设置滚动数组来存贮当前和下一状态的条件 //处理传播扩散问题非常有效int main() {cin>>n>>m>>t;for(int i1;i&l…

HDFS编程 - 使用HDFS Java API进行文件操作

文章目录 前言一、创建hdfs-demo项目1. 在idea上创建maven项目2. 导入hadoop相关依赖 二、常用 HDFS Java API1. 简介2. 获取文件系统实例3. 创建目录4. 创建文件4.1 创建文件并写入数据4.2 创建新空白文件 5. 查看文件内容6. 查看目录下的文件或目录信息6.1 查看指定目录下的文…

Java面试题~~

深拷贝和浅拷贝区别了解吗?什么是引用拷贝? 关于深拷贝和浅拷贝区别&#xff0c;我这里先给结论&#xff1a; 浅拷贝&#xff1a;浅拷贝会在堆上创建一个新的对象&#xff08;区别于引用拷贝的一点&#xff09;&#xff0c;不过&#xff0c;如果原对象内部的属性是引用类型的…

el-table 自定义表头颜色

第一种方法&#xff1a;计算属性 <template><div><el-table:data"formData.detail"border stripehighlight-current-row:cell-style"{ text-align: center }":header-cell-style"headerCellStyle"><el-table-column fixed…

MySQL笔记大总结20250108

Day2 1.where (1)关系运算符 select * from info where id>1; select * from info where id1; select * from info where id>1; select * from info where id!1;(2)逻辑运算符 select * from info where name"吴佩奇" and age19; select * from info wh…

精选2款.NET开源的博客系统

前言 博客系统是一个便于用户创建、管理和分享博客内容的在线平台&#xff0c;今天大姚给大家分享2款.NET开源的博客系统。 StarBlog StarBlog是一个支持Markdown导入的开源博客系统&#xff0c;后端基于最新的.Net6和Asp.Net Core框架&#xff0c;遵循RESTFul接口规范&…

SEO内容优化:如何通过用户需求赢得搜索引擎青睐?

在谷歌SEO优化中&#xff0c;内容一直是最重要的因素之一。但要想让内容真正发挥作用&#xff0c;关键在于满足用户需求&#xff0c;而不是简单地堆砌关键词。谷歌的算法越来越智能化&#xff0c;更注重用户体验和内容的实用性。 了解目标用户的需求。通过工具如Google Trends…

Spring——自动装配

假设一个场景&#xff1a; 一个人&#xff08;Person&#xff09;有一条狗&#xff08;Dog&#xff09;和一只猫(Cat)&#xff0c;狗和猫都会叫&#xff0c;狗叫是“汪汪”&#xff0c;猫叫是“喵喵”&#xff0c;同时人还有一个自己的名字。 将上述场景 抽象出三个实体类&…

计算机网络(三)——局域网和广域网

一、局域网 特点&#xff1a;覆盖较小的地理范围&#xff1b;具有较低的时延和误码率&#xff1b;使用双绞线、同轴电缆、光纤传输&#xff0c;传输效率高&#xff1b;局域网内各节点之间采用以帧为单位的数据传输&#xff1b;支持单播、广播和多播&#xff08;单播指点对点通信…

错误的类文件: *** 类文件具有错误的版本 61.0, 应为 52.0 请删除该文件或确保该文件位于正确的类路径子目录中

一、问题 用maven对一个开源项目打包时&#xff0c;遇到了“错误的类文件: *** 类文件具有错误的版本 61.0, 应为 52.0 请删除该文件或确保该文件位于正确的类路径子目录中。”&#xff1a; 二、原因 原因是当前java环境是Java 8&#xff08;版本52.0&#xff09;&#xff0c;但…

【大模型入门指南 07】量化技术浅析

【大模型入门指南】系列文章&#xff1a; 【大模型入门指南 01】深度学习入门【大模型入门指南 02】LLM大模型基础知识【大模型入门指南 03】提示词工程【大模型入门指南 04】Transformer结构【大模型入门指南 05】LLM技术选型【大模型入门指南 06】LLM数据预处理【大模型入门…

在线工具箱源码优化版

在线工具箱 前言效果图部分源码源码下载部署教程下期更新 前言 来自缤纷彩虹天地优化后的我爱工具网源码&#xff0c;百度基本全站收录&#xff0c;更能基本都比较全&#xff0c;个人使用或是建站都不错&#xff0c;挑过很多工具箱&#xff0c;这个比较简洁&#xff0c;非常实…

@LocalBuilder装饰器: 维持组件父子关系

一、前言 当开发者使用Builder做引用数据传递时&#xff0c;会考虑组件的父子关系&#xff0c;使用了bind(this)之后&#xff0c;组件的父子关系和状态管理的父子关系并不一致。为了解决组件的父子关系和状态管理的父子关系保持一致的问题&#xff0c;引入LocalBuilder装饰器。…

C 语言内存探秘:数据存储的字节密码

文章目录 一、数据在内存中的存储1、基本数据类型存储2、数组存储3、结构体存储1、基本存储规则2、举例说明3、查看结构体大小和成员偏移量的方法 二、大小端字节序三、字节序的判断 一、数据在内存中的存储 1、基本数据类型存储 整型&#xff1a;如int类型&#xff0c;通常在…