pytorch强化学习(1)——DQNSARSA

news2024/12/26 0:13:09

实验环境

python=3.10
torch=2.1.1
gym=0.26.2
gym[classic_control]
matplotlib=3.8.0
numpy=1.26.2

DQN代码

首先是module.py代码,在这里定义了网络模型和DQN模型

import torch
import torch.nn as nn
import numpy as np

class Net(nn.Module):
    # 构造只有一个隐含层的网络
    def __init__(self, n_states, n_hidden, n_actions):
        super(Net, self).__init__()
        # [b,n_states]-->[b,n_hidden]
        self.network = nn.Sequential(
            torch.nn.Linear(n_states, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_actions)
        )

    # 前传
    def forward(self, x):  # [b,n_states]
        return self.network(x)


class DQN:
    def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):
        # 属性分配
        self.n_states = n_states  # 状态的特征数
        self.n_hidden = n_hidden  # 隐含层个数
        self.n_actions = n_actions  # 动作数
        self.lr = lr  # 训练时的学习率
        self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放
        self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索
        # 计数器,记录迭代次数
        self.count = 0

        # 实例化训练网络
        self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)

        # 优化器,更新训练网络的参数
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.criterion = torch.nn.MSELoss()  # 损失函数

    def choose_action(self, gym_state):
        state = torch.Tensor(gym_state)
        if np.random.random() < self.epsilon:
            action_values = self.q_net(state)  # q_net(state)采取动作后的预测
            action = action_values.argmax().item()
        else:
            # 随机选择一个动作
            action = np.random.randint(self.n_actions)
        return action

    def update(self, gym_state, action, reward, next_gym_state, done):
        state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)
        q_value = self.q_net(state)[action]
        # 前千万不能缺少done,如果下一步游戏结束的花,那下一步的q值应该为0
        q_target = reward + self.gamma * self.q_net(next_state).max() * (1 - float(done))

        self.optimizer.zero_grad()
        dqn_loss = self.criterion(q_value, q_target)
        dqn_loss.backward()
        self.optimizer.step()

然后是train.py代码,在这里调用DQN模型和gym环境,来进行训练:

import gym
import torch
from module import DQN
import matplotlib.pyplot as plt

lr = 1e-3  # 学习率
gamma = 0.95  # 折扣因子
epsilon = 0.8  # 贪心系数
n_hidden = 200  # 隐含层神经元个数

env = gym.make("CartPole-v1")
n_states = env.observation_space.shape[0]  # 4
n_actions = env.action_space.n  # 2 动作的个数

dqn = DQN(n_states, n_hidden, n_actions, lr, gamma, epsilon)

if __name__ == '__main__':
    reward_list = []
    for i in range(500):
        state = env.reset()[0]  # len=4
        total_reward = 0
        done = False
        while True:

            # 获取当前状态下需要采取的动作
            action = dqn.choose_action(state)
            # 更新环境
            next_state, reward, done, _, _ = env.step(action)
            dqn.update(state, action, reward, next_state, done)
            state = next_state

            total_reward += reward

            if done:
                break
        print("第%d回合,total_reward=%f" % (i, total_reward))
        reward_list.append(total_reward)

    # 绘图
    episodes_list = list(range(len(reward_list)))
    plt.plot(episodes_list, reward_list)
    plt.xlabel('Episodes')
    plt.ylabel('Returns')
    plt.title('DQN Returns')
    plt.show()

SARSA代码

首先是module.py代码,在这里定义了网络模型和SARSA模型。
SARSA和DQN基本相同,只有在更新Q网络的时候略有不同,已在代码相应位置做出注释。

import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

class Net(nn.Module):
    # 构造只有一个隐含层的网络
    def __init__(self, n_states, n_hidden, n_actions):
        super(Net, self).__init__()
        # [b,n_states]-->[b,n_hidden]
        self.network = nn.Sequential(
            torch.nn.Linear(n_states, n_hidden),
            torch.nn.ReLU(),
            torch.nn.Linear(n_hidden, n_actions)
        )

    # 前传
    def forward(self, x):  # [b,n_states]
        return self.network(x)


class SARSA:
    def __init__(self, n_states, n_hidden, n_actions, lr, gamma, epsilon):
        # 属性分配
        self.n_states = n_states  # 状态的特征数
        self.n_hidden = n_hidden  # 隐含层个数
        self.n_actions = n_actions  # 动作数
        self.lr = lr  # 训练时的学习率
        self.gamma = gamma  # 折扣因子,对下一状态的回报的缩放
        self.epsilon = epsilon  # 贪婪策略,有1-epsilon的概率探索
        # 计数器,记录迭代次数
        self.count = 0

        # 实例化训练网络
        self.q_net = Net(self.n_states, self.n_hidden, self.n_actions)

        # 优化器,更新训练网络的参数
        self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=lr)

        self.criterion = torch.nn.MSELoss()  # 损失函数

    def choose_action(self, gym_state):
        state = torch.Tensor(gym_state)
        # 基于贪婪系数,有一定概率采取随机策略
        if np.random.random() < self.epsilon:
            action_values = self.q_net(state)  # q_net(state)是在当前状态采取各个动作后的预测
            action = action_values.argmax().item()
        else:
            # 随机选择一个动作
            action = np.random.randint(self.n_actions)
        return action

    def update(self, gym_state, action, reward, next_gym_state, done):
        state, next_state = torch.tensor(gym_state), torch.tensor(next_gym_state)
        q_value = self.q_net(state)[action]

        '''
        sarsa在更新网络时选择的是q_net(next_state)[next_action] 
        这是sarsa算法和dqn的唯一不同
        dqn是选择max(q_net(next))
        '''
        next_action = self.choose_action(next_state)
        # 千万不能缺少done,如果下一步游戏结束的话,那下一步的q值应该为0,而不是q网络输出的值
        q_target = reward + self.gamma * self.q_net(next_state)[next_action] * (1 - float(done))

        self.optimizer.zero_grad()
        dqn_loss = self.criterion(q_value, q_target)
        dqn_loss.backward()
        self.optimizer.step()

SARSA也有tarin.py文件,功能和上面DQN的一样,内容也几乎完全一样,只是把DQN的名字改成SARSA而已,所以在这里不再赘述。

运行结果

DQN的运行结果如下:
在这里插入图片描述

SARSA运行结果如下:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1306977.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Python】计算最少排班人数(1)

需求背景&#xff1a; 某工作24小时需要人员&#xff0c;单个班次为8或9小时&#xff0c;需要根据每小时需要的人数、不同的班次计算最低需要的人数。 每小时需要的人数举例&#xff1a; 班次类型&#xff1a; 问题分析&#xff1a; 1、需要各个时段的人员数大于等于需要的人…

CV中的Attention机制:SENet

paper: Squeeze-and-Excitation Networks paper link:https://arxiv.org/pdf/1709.01507.pdf repo link:GitHub - hujie-frank/SENet: Squeeze-and-Excitation Networks 摘要&#xff1a; 卷积神经网络&#xff08;CNNs&#xff09;的核心构建块是卷积算子&#xff0c;它使…

微信小程序map视野发生改变时切换定位点

<!--地图--> <view><map id"myMap" style"width: 100%; height: 300px;" latitude"{{latitude}}" longitude"{{longitude}}"scale"{{scale}}" markers"{{markers}}" controls"{{controls}}&q…

OpenCV-Python:图像平滑操作

目录 图像平滑基础 本文目标 2D卷积 图像模糊&#xff08;图像平滑&#xff09; 平均模糊 高斯模糊 中值模糊 双边滤波 图像平滑基础 在尽量保留图像原有信息的情况下&#xff0c;过滤掉图像内部的噪声&#xff0c;这一过程称为对图像的平滑处理&#xff0c;所得的图像…

排序算法之三:希尔排序

希尔排序基本思想 希尔排序法又称缩小增量法 希尔排序法的基本思想是&#xff1a;先选定一个整数&#xff0c;把待排序文件中所有记录分成个组&#xff0c;所有距离为的记录分在同一组内&#xff0c;并对每一组内的记录进行排序。然后&#xff0c;取&#xff0c;重复上述分组…

ThingsBoard 前端项目轮播图部件开发

前言 ThingsBoard 是目前 Github 上最流行的开源物联网平台&#xff08;14.6k Star&#xff09;&#xff0c;可以实现物联网项目的快速开发、管理和扩展, 是中小微企业物联网平台的不二之选。 本文介绍如何在 ThingsBoard 前端项目中开发轮播图部件。 产品需求 最近接到产品…

C++并查集

1.并查集概念 1.1.并查集定义 在一些应用问题中&#xff0c;需要&#xff1a; 将 n 个不同的元素划分成一些不相交的集合开始时&#xff0c;每个元素自成成为一个集合然后按一定的规律&#xff0c;将归于同一组元素的集合合并期间需要反复用到查询某个元素归属于那个集合的算…

python利用requests库进行接口测试的方法详解

前言 之前介绍了接口测试中需要关注得测试点&#xff0c;现在我们来看看如何进行接口测试&#xff0c;现在接口测试工具有很多种&#xff0c;例如&#xff1a;postman,soapui,jemter等等&#xff0c;对于简单接口而言&#xff0c;或者我们只想调试一下&#xff0c;使用工具是非…

迈入数据结构殿堂——时间复杂度和空间复杂度

目录 一&#xff0c;算法效率 1.如何衡量一个算法的好坏&#xff1f; 2.算法效率 二&#xff0c;时间复杂度 1.时间复杂度的概念 2.大O的渐进表示法 3.推导大O的渐进表示法 4.常见时间复杂度举例 三&#xff0c;空间复杂度 一&#xff0c;算法效率 数据结构和算法是密…

【产品】Axure的基本使用(二)

文章目录 一、元件基本介绍1.1 概述1.2 元件操作1.3 热区的使用 二、表单型元件的使用2.1 文本框2.2 文本域2.3 下拉列表2.4 列表框2.5 单选按钮2.6 复选框2.7 菜单与表格元件的使用 三、实例3.1 登录2.2 个人简历 一、元件基本介绍 1.1 概述 在Axure RP中&#xff0c;元件是…

如何使用透明显示屏

透明显示屏的使用主要取决于具体的应用场景和需求。以下是一些常见的使用透明显示屏的方法&#xff1a; 商业展示&#xff1a;透明显示屏可以作为商品展示柜&#xff0c;通过高透明度、高分辨率的屏幕展示商品细节&#xff0c;吸引顾客的注意力。同时&#xff0c;透明显示屏还可…

人工智能基本常识:让深度学习技术更加人性化

近年来&#xff0c;人工智能技术日臻成熟。现在&#xff0c;许多产品和服务都依靠人工智能技术实现自动化和智能化&#xff0c;因此它与我们的日常生活息息相关。无论是为我们带来各种便利的家用设备&#xff0c;还是我们一直在使用的产品制造方式&#xff0c;人工智能的影响无…

【算法题】智能成绩表(js)

总分相同按名字字典顺序。 解法&#xff1a; function solution(lines) {const [personNum, subjectNum] lines[0].split(" ").map((item) > parseInt(item));const subjects lines[1].split(" ");const classMates [];let results [];for (let i…

C++笔记:动态内存管理

文章目录 语言层面的内存划分C语言动态内存管理的缺陷new 和 delete 的使用了解语法new 和 delete 操作内置类型new 和 delete 操作自定义类型 new 和 delete 的细节探究new 和 delete 的底层探究operator new 和 operator new[]operator delete 和 operator delete[] 显式调用…

2023快速上手新红利项目:短剧分销推广CPS

短剧分销推广CPS是一个新红利项目&#xff0c;对于新手小白来说也可以快速上手。 以下是一些建议&#xff0c;帮助新手小白更好地进行短剧分销推广CPS&#xff1a; 学习基础知识&#xff1a;了解短剧的基本概念、制作流程和推广方式。了解短剧的市场需求和受众群体&#xff0c…

wpf devexpress如何使用AccordionControl

添加一个数据模型 AccordionControl可以被束缚到任何实现IEnumerable接口的对象或者它的派生类&#xff08;例如IList,ICollection&#xff09; 如下代码例子示范了一个简单的数据模型使用&#xff1a; using System.Collections.Generic;namespace DxAccordionGettingStart…

zabbix精简模板

一、监控项目介绍 linux自带得监控项目比较多&#xff0c;也不计较杂&#xff0c;很多监控项目用不到。所以这里要做一个比较精简得监控模版 二、监控模板克隆 1.搜索原模板 2.克隆模板 全克隆模板&#xff0c;这样就和原来原模板没有联系了&#xff0c;操作也不会影响原模…

软件测试基础知识+面试总结(超详细整理)

一、什么是软件&#xff1f; 软件是计算机系统中的程序和相关文件或文档的总称。 二、什么是软件测试&#xff1f; 说法一&#xff1a;使用人工或自动的手段来运行或测量软件系统的过程&#xff0c;以检验软件系统是否满足规定的要求&#xff0c;并找出与预期结果之间的差异…

UI设计中的肌理插画是什么样的?

肌理插画本质也和扁平插画差不多&#xff0c;相较扁平插画&#xff0c;肌理插画的层次感、细节更多&#xff0c;也会更立体生动。 肌理插画风格没有描边线&#xff0c;画面轻快&#xff0c;通过色块的明暗来区分每个元素&#xff0c;有点像色彩版的素描&#xff0c;但更简单&a…