智能交通(6)——DQN代码复现

news2025/1/24 17:43:19

伪代码

如算法描述,dqn即深度q网络和记忆池

初始化记忆池和可以容纳的数量N

动作价值函数Q使用随机权重进行初始化。

目标动作价值函数Q′也使用相同的权重进行初始化,即Q′=Q。

循环训练M局

初始化和预处理观察到的状态

每局循环训练T步

采用e的概率随机选取动作

其他情况选择q值最大的动作

执行动作并且得到下一步的状态和奖励

更新下一步的状态,并进行预处理

将其s,r,a,s_存入记忆池

从记忆池随机选取minbatch的数据

对于每个转换,计算目标值yj​。如果第j+1步是终止状态,则yj​=rj​;否则,使用贝尔曼方程计算yj​=rj​+γmaxa​Q(sj+1​,a;θ′),其中γ是折扣因子。

计算损失函数loss=loss=(yj​−Q(sj​,aj​;θ))2,并通过反向传播更新网络参数θ。

每进行c步更新策略网络

游戏环境介绍

环境描述:CartPole-v1是一个物理模拟环境,其中有一个水平杆(pole)固定在一个移动的推车(cart)上。目标是保持杆子竖直,防止它倒向任何一侧或推车移动到轨道的尽头。

  1. 终止条件:环境会在以下情况下终止:

    • 杆子的角度超过15度(±15°)。
    • 推车的位置超出轨道的中心线一定距离(通常是2.4个单位)
  2. 动作空间(Action Space):在这个环境中,智能体可以选择两个离散动作之一:

    • 向左推车(0)
    • 向右推车(1)

代码复现

# describe : dqn算法训练流程

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from agent import QNetwork
from agent import ReplayBuffer

# 创建CartPole环境
env = gym.make('CartPole-v1')

# 超参数
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
learning_rate = 0.001 # 神经网络优化算法中使用的学习能力,用于调整网络权重的步长
gamma = 0.99 # reward discount
epsilon = 1.0  # 探索率
epsilon_decay = 0.995 # epsilon衰减率
epsilon_min = 0.01 # 最小探索率
episodes = 300
steps = 300
batch_size = 64
memory_size = 2000

# 初始化记忆池
replay_buffer = ReplayBuffer(memory_size)
# 初始化Q网络和优化器
q_network = QNetwork(state_size, action_size)
optimizer = optim.Adam(q_network.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
# 初始化奖励数组
total_rewards = []
for episode in range(episodes):
    state, info = env.reset()
    state = torch.FloatTensor(state).unsqueeze(0) # 预处理,将 state 转换为 PyTorch 的 FloatTensor,这是神经网络处理所需的数据类型
    total_reward = 0
    for step in range(steps):
        # epsilon-greedy策略选择动作
        if np.random.rand() <= epsilon:
            action = np.random.randint(action_size)
        else:
            with torch.no_grad():
                q_values = q_network(state)
                action = q_values.max(1)[1].item()
        # 执行动作
        next_state, reward, done, _, _ = env.step(action)
        next_state = torch.FloatTensor(next_state).unsqueeze(0)
        # 存储经验
        replay_buffer.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        if done:
            total_rewards.append(total_reward)
            print(f"Episode: {episode}, Total Reward: {total_reward}")
            break
        # 经验回放
        if len(replay_buffer) >= batch_size:
            states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
            # 将采样的状态序列 states 合并成一个张量,以便于批量处理。
            states = torch.cat(states)
            actions = torch.LongTensor(actions).unsqueeze(1)
            rewards = torch.FloatTensor(rewards).unsqueeze(1)
            next_states = torch.cat(next_states)
            dones = torch.FloatTensor(dones).unsqueeze(1)
            # 批量N,计算当前状态下不同动作的Q值,取选择的动作对应的Q值
            current_q_values = q_network(states).gather(1, actions)
            # 批量N,计算下一个状态取得的最大的Q值
            next_q_values = q_network(next_states).max(1)[0].unsqueeze(1)
            target_q_values = rewards + (gamma * next_q_values * (1 - dones))
            # 使用损失函数计算当前 Q 值和目标 Q 值之间的差异。
            loss = criterion(current_q_values, target_q_values)
            # 清除之前的梯度,为新的梯度更新做准备。
            optimizer.zero_grad()
            # 计算损失函数关于网络参数的梯度。
            loss.backward()
            # 根据计算出的梯度更新网络的权重。
            optimizer.step()
    # epsilon衰减
    if epsilon > epsilon_min:
        epsilon *= epsilon_decay

# 绘制奖励图
env.close()
plt.plot(total_rewards)
plt.xlabel('Episode')
plt.ylabel('Total Reward')
plt.title('Total Rewards per Episode in CartPole-v1')
plt.show()
# describe : 定义经验回放缓冲区和Q网络

import random
import torch
import torch.nn as nn

# 经验回放缓冲区
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = []
        self.capacity = capacity
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    # 用于从缓冲区中随机采样一批经验。
    def sample(self, batch_size):
        return zip(*random.sample(self.buffer, batch_size))

    def __len__(self):
        return len(self.buffer)

# Q网络
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size):
        super(QNetwork, self).__init__()
        # 定义了第一个全连接层,将输入状态的特征从 state_size 映射到 24 个特征。
        self.fc1 = nn.Linear(state_size, 24)
        # 定义了第二个全连接层,将 24 个特征再次映射到 24 个特征。
        self.fc2 = nn.Linear(24, 24)
        # 定义了第三个全连接层,将 24 个特征映射到输出层,其大小等于可能的动作数量。
        self.fc3 = nn.Linear(24, action_size)

    def forward(self, x):
        # 应用第一个全连接层并使用 ReLU 激活函数。
        x = torch.relu(self.fc1(x))
        # 应用第二个全连接层并使用 ReLU 激活函数。
        x = torch.relu(self.fc2(x))
        # 应用第三个全连接层,不使用激活函数,因为 Q 网络的输出通常不需要激活函数。
        x = self.fc3(x)
        # 返回q值
        return x

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1977832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Scene Transformer】scene transformer论文阅读笔记

文章目录 序言(Abstract)(Introduction)(Related Work)(Methods)(Scene-centric Representation for Agents and Road Graphs)(Encoding Transformer)(Predicting Probabilities for Each Futures)(Joint and Marginal Loss Formulation) (Results)(Discussion)(Questions) sce…

Linux|最佳命令行下载加速器

引言 无论是远程工作还是本地工作&#xff0c;我们经常需要从外部获取信息。在没有其他选择的情况下&#xff0c;使用命令行工具来获取这些信息是一个不错的选择。 本文[1]将介绍一些通过命令行下载内容时最常使用的工具。 Wget 我们首先介绍一个广受欢迎的工具 wget。它是一个…

使用Qt编译modbus

一.编译库文件 1. 创建library项目 2. 选择要配置的编译器 3. 把自动生成的源码都移除&#xff1a;&#xff08;右键单击&#xff0c;选择 remove&#xff09; 4 4. 导入库源码 把源码拷贝到项目目录下&#xff08;.pro 文件所在的目录&#xff09; 5. 修改 configure.js 文…

(计算机网络)物理层

目录 一.基本概念 二.基本术语 三.码元 四.多路复用技术 一.基本概念 1. 2. 3. 4. 5. 6. 7. 8. 9. 二.基本术语 1. 2. 3.早期--公用的电话网传输数据&#xff0c;网络上传的是模拟信号&#xff0c;调制解调器--将数字信号转化成模拟信号&#xff0c;最后&#xff0c;调制解…

Java: 线程安全问题的解决方案(synchronized)

发生原因 要想解决线程安全问题,那么我们首先得知道线程安全问题为什么会发生. 发生原因: 线程在操作系统中是"随机调度,抢占式执行的"[根本原因].多个线程,同时修改同一个变量修改操作不是"原子"的内存可见性问题指令重排序 解决方案 原因1和2,我们很…

htsjdk库FeatureCodec和Feature接口介绍

在 HTSJDK 库中,FeatureCodec 接口和 Feature 接口分别扮演不同的角色,用于处理基因组数据的不同方面。下面是这两个接口的区别和各自的功能: FeatureCodec 接口 主要功能 编码和解码:FeatureCodec 接口的主要职责是定义如何将数据从文件格式解码为 Java 对象(即 Featur…

【C语言】分支与循环(循环篇)——结尾猜数字游戏实现

前言 C语言是一种结构化的计算机语言&#xff0c;这里指的通常是顺序结构、选择结构、循环结构&#xff0c;掌握这三种结构之后我们就可以解决大多数问题。 分支结构可以使用if、switch来实现&#xff0c;而循环可以使用for、while、do while来实现。 1. while循环 C语言中…

[CP_AUTOSAR]_系统服务_DEM模块(三)功能规范之诊断事件定义

目录 1、诊断事件定义1.1、Event priority&#xff08;事件优先级&#xff09;1.2、Event occurrence&#xff08;事件发生计数器&#xff09;1.3、Event kind&#xff08;事件类别&#xff09;1.4、Event destination&#xff08;故障内存&#xff09;1.5、Diagnostic monitor…

2.MonggoDB是什么?

1. 不是什么&#xff1f; 要想知道MongoDB是什么&#xff0c;我们得先搞清楚它不是什么&#xff0c;首先它不是关系数据&#xff0c;不是像下面这样这种格式存储数据。 这个图展示了关系型数据库的常用存储方式&#xff0c;一个表格&#xff0c;里面存储了多行记录&#xff0…

Linux系统中的两个核心进程:`init`和`kthreadd`

文章目录 1 init 进程1.1 基本信息1.2 主要功能1.3 示例 2 kthreadd 进程2.1 基本信息2.2 主要功能2.3 示例 3 对比总结4 用户空间进程与内核线程4.1 用户空间进程特点 4.2 内核线程特点 5 对比总结6 结论参考链接封面 本文详细对比了Linux系统中的两个核心进程&#xff1a; i…

nvm 对node版本的控制

使用nvm切换Node.js版本的步骤如下 nvm list available // 显示可以安装的所有node.js的版本 如果出现空白 问题解决 经过查找nvm的文档&#xff0c;发现&#xff0c;对于中国用户而言&#xff0c;可以切换nodejs或npm的镜像地址来访问&#xff1a; nvm node_mirror https:…

软raid - - udev规则

一、什么是udev FROM AI: udev是Linux 2.6内核及以后版本中引入的一个设备管理框架&#xff0c;它取代了之前的devfs系统。udev以守护进程的形式运行&#xff0c;并且工作在用户空间而不是内核空间。它的主要功能是动态地创建和管理/dev目录下的设备节点&#xff0c;并且能够根…

SAP支出管理,企业成本控制的智能钥匙

在企业运营中&#xff0c;有效的支出管理是确保财务健康和提升竞争力的关键。SAP支出管理系统作为企业资源规划的核心组成部分&#xff0c;提供了一套全面的解决方案&#xff0c;帮助企业实现成本控制、风险管理和合规性监督。实现支出管理流程自动化&#xff0c;并主动管理更多…

蚂蚁笔试0511-编程题

解题思路&#xff1a; 记录0、正数、负数的个数&#xff0c;分类讨论。 解题思路&#xff1a; 有n个位置&#xff0c;每个位置有m个数&#xff0c;所以一共有m^n种情况&#xff0c;每种情况至少包含权值1&#xff0c;也就是全相等是一段&#xff0c;或者说是一个数形成的 从第二…

黑马Java零基础视频教程精华部分_11_面向对象进阶(3)

系列文章目录 文章目录 系列文章目录一、抽象类1、为什么要有抽象类&#xff1f;2、抽象方法3、抽象类4、抽象类和抽象方法定义格式5、注意事项 二、接口1、为什么会有接口&#xff1f;2、接口和抽象类的异同3、接口的定义和使用4、接口中成员的特点5、接口和类之间的关系6、实…

机器学习第五十周周报 CGNN

文章目录 week50 CGNN摘要Abstract0. 概述1. 题目2. Abstract3. 网络结构3.1 状态更新网络3.2 method 4. 文献解读4.1 Introduction4.2 创新点4.3 实验过程 5. 结论6.相关代码CompositeLGNNCompositeGNNcomposite-graph-class小结参考文献 week50 CGNN 摘要 本周阅读了题为Co…

【题解】189. 轮转数组(数组、数学、双指针)

https://leetcode.cn/problems/rotate-array/description/?envTypestudy-plan-v2&envIdtop-interview-150 class Solution { public:void rotate(vector<int>& nums, int k) {k k % nums.size(); // 注意k要取一下模reverse(nums.begin(), nums.end()-k);rev…

常见cms漏洞之ASPCMS

项目地址&#xff0c;自行百度 漏洞复现&#xff0c;使用后台配置文件拿shell 访问后台 http://192.168.177.153/admin_aspcms/login.asp 账号密码已设置成&#xff0c;登录即可 admin 123456

关于Win11无法自动进入休眠问题

最近从win10升到11了 感觉还挺好用的 就是右键有点逆天 主要是为了3k屏连接1080显示器更新的 但是发现彻底无法休眠了&#xff0c;这个还要说的是以前win10睡眠一次就算一次硬盘通电&#xff0c;现在更新后不会了 下面说回休眠 b站浏览一番后发现是22H2隐藏了休眠时间设置&a…

哈希 || unordered系列的关联式容器底层 | 哈希模拟实现 | HashTable代码实现

底层结构 unordered系统的关联式容器之所以效率比较高&#xff0c;是因为其底层使用了哈希结构。 哈希概念 顺序结构以及平衡树中&#xff0c;元素关键码与其存储位置之间没有对应的关系&#xff0c;因此在查找一个元素时&#xff0c;必须经过关键码的多次比较。 顺序查找的…