基于DQN的强化学习 快速浏览(基础知识+示例代码)

news2024/9/27 21:25:00

一、强化学习的基础概念

强化学习中有2个主要的实体,一个是智能体(agent),另一个是环境(environment)。在强化学习过程中,智能体能够得到的是环境当前的状态(State),即环境智能体所处环境当前的情况。另一个是上一步获得的环境的奖励(Reward),即环境给予智能体动作的一个反馈。智能体根据这两个信息,决定在环境中采取的动作(Action),以及环境接收智能体的动作,返回下一步的状态和对智能体的奖励。整个过程可以归纳为:在t时刻,给定该时刻的状态s_t和获得的奖励r_t,根据这些值来决定当前步骤的动作a_t,将动作转递给环境,得到下一个时刻的状态s_{t+1}和获得的奖励r_{t+1}

在强化学习中需要解决的问题是如何训练一个智能体,使得智能体能够在合适的状态下产生合适的动作,使后续的奖励总和最大。我们称智能体根据环境状态产生动作的方法为策略(policy),在这种情况下,强化学习可以归结为寻找一个最优策略,使得未来能够获得的奖励最大。

强化学习有很多分类,其中根据深度学习模型描述的是策略本身还是在当前状态下未来能够获得的奖励,可分为两种,前者称为策略优化(policy optimization),后者称为质量函数学习(Q-learning,这里Q即为Quality,质量)。假设强化学习的过程是一个马尔可夫过程,即未来获得的奖励和过去的历史无关,则质量函数是当前状态的一个函数,可根据Bellman方程对其进行描述。假设奖励的折扣率为\gamma (0<\gamma<1),设置这个参数的目的是因为我们需要对未来的奖励进行求和,在时间跨度上无限大,如果没有折扣率,未来总的奖励可能是无穷大。在这种情况下,我们可以计算总的折扣后的奖励,如式(1)所示。

(1)

由于式(1)中的R_t,即未来奖励加权求和的结果和当前的状态及采取的动作有关,当引入t时刻的状态s_t和采取的动作a_t后,可以根据Bellman方程求得对应的折扣奖励关于状态和动作的函数,称之为质量函数Q,如式(2)。

(2)

二、强化学习的软件环境Gym安装

式(2)是进行强化学习的基础,用深度学习来学习质量函数的算法被称为DQN,这里以一个简单的强化学习例子来阐述如何使用DQN进行强化学习的任务。首先使用OpenAI的Gym工具集来构造强化学习环境。Gym工具集有很多场景,包括经典控制环境、雅达利游戏、二维和三维机器人环境等。

具体安装方式如下:

# 方式1:
pip install gym
# 方式2:
git clone https://github.com/openai/gym
pip install -e .

三、车杆环境介绍

本博文以gym的车杆环境为例,进行DQN的模型搭建和训练。

车杆(Cartpole)环境介绍如下:

车杆环境由一个可以自由转动的杆子连接一个可以水平运动的小车构成。通过向环境发送“左移”或者“右移”控制小车移动。每次发生移动指令之后,环境会返回一个数组来表示小车当前的运动状态,这个数组包括:小车当前的位置(-4.8~4.8)、小车当前的速度(负无穷到正无穷)、杆子当前的角度(-24~24)、杆子顶端的速度(负无穷到正无穷)来表示。环境还会返回一个奖励值,该值在杆子的角度在-15~15之间且小车位置在-2.4~2.4之间时为1(最多持续200步),其他状态下为0,且强化学习段落(episode)会在下一步终止。

我们需要训练的模型就是根据的状态数组做动态调整,让小车上的杆子能够保持在一定的角度范围内,且小车的位置也能保持在一定距离范围内,从而最终达到奖励值最大的目的。

 【注】:这里提到的强化学习段落episode指的是:All states that come in between an initial-state and a terminal-state; for example: one game of Chess. The agents goal it to maximize the total reward it reward it receives during an episode. In situations where there is no terminal-state, we consider an infinite episode. It is important to remember that different episodes are completely independent of one another. 大概意思是:一个episode即为一轮博弈,智能体从最开始的状态到某一个终止状态为一个episode,这个过程是有状态集、行为集、奖励等组成的一个完成序列。且各个episode是完全独立的。

四、车杆(Cartpole)环境使用

首先介绍如何使用Cartpole环境。代码及对应的解释如下:

import gym
env = gym.make('CartPole-v0') # 通过gym.make创建了一个CartPole环境env
for i_episode in range(20): # 共运行了20个强化学习段落
    observation = env.reset()
    for t in range(100): # 每个段落最大100步
        env.render()
        print(observation)
        action = env.action_space.sample() # 对动作进行随机采样(在CartPole环境下只有0和1两种动作)
        state, reward, terminated, truncated, info = env.step(action) # 执行动作,获取环境的反馈。state指执行了该动作后环境的状态,reward指当前动作获取的奖励,terminated表示当前段落是否结束;truncated通常指是否超出时间限制,info指环境的其他信息
        if terminated:
            print("Episode finished after {} timesteps".format(t+1))
            break
env.close()

这里再强调一遍gym中常用的代码接口及其含义:

代码接口含义
env = gym.make("XXX").env进入指定的实验环境
env.render()渲染环境,即可视化看看环境的样子
env.reset重置环境,返回一个随机的初始状态
env.step(action)将选择的action输入给env,env 按照这个动作走一步进入下一个状态
env.step(action)的返回值

state:执行action后的状态

reward:执行action的奖励

terminated:whether a terminal state (as defined under the MDP of the task) is reached. In this case further step() calls could return undefined results.

truncated:whether a truncation condition outside the scope of the MDP is satisfied. Typically a timelimit, but could also be used to indicate agent physically going out of bounds. Can be used to end the episode prematurely before a terminal state is reached.

info:其他信息

env.render()渲染出当前的智能体以及环境的状态,用于可视化

以上代码会让杆子很快偏离平衡位置,导致强化学习段落结束,为了能让杆子稳定,需要DQN对每一个env.step选择具体的动作。

五、QDN模型的搭建

先构建质量函数的深度学习模型,具体代码及解释如下:

import torch
import torch.nn as nn
class DQN(nn.Module):
    def __init__(self, naction, nstate, nhidden):
        super(DQN, self).__init__()
        self.naction = naction # 候选的动作总数。在CartPole中,该值为2
        self.nstate = nstate # 状态的维度数。在CartPole中,该值为4
        self.linear1 = nn.Linear(naction + nstate, nhidden)
        self.linear2 = nn.Linear(nhidden, nhidden)
        self.linear3 = nn.Linear(nhidden, 1)

    def forward(self, state, action):
        action_enc = torch.zeros(action.size(0), self.naction)
        action_enc.scatter_(1, action.unsqueeze(-1),1)
        output = torch.cat((state, action_enc), dim=-1)
        output = torch.relu(self.linear1(output))
        output = torch.relu(self.linear2(output))
        output = self.linear3(output)
        return output.squeeze(-1)

为了加强模型训练的收敛,在DQN算法的训练中需要用到重放(replay)技巧,通过反复播放强化学习的历史记录来加强模型的训练。这里用一个记忆类类记录训练历史,代码及解释如下:

import random
class Memory(object):
    def __init__(self,capacity=1000):
        self.capacity = capacity # 记忆的长短
        self.size = 0
        self.data = []

    def __len__(self):
        return self.size

    def push(self, state, action, state_next, reward, is_ended): # 向记忆类中放入单步训练的记录
        # state当前状态,action采取的动作,state_next下一步的状态,reward状态的奖励,is_ended下一步状态是否为最终状态
        if len(self) > self.capacity:
            k = random.randint(self.capacity)
            self.data.pop(k)
            self.size -= 1
        self.data.append((state, action, state_next, reward, is_ended))

    def sample(self, bs): # 获取一定迷你批次bs的历史数据进行重放,通过重放数据进行学习。
        data = random.choices(self.data, k = bs)
        states, actions, states_next, rewards, is_ended = zip(*data)

        states = torch.tensor(states, dtype=torch.float32)
        actions = torch.tensor(actions)
        states_next = torch.tensor(states_next, dtype=torch.float32)
        rewards =  torch.tensor(rewards, dtype=torch.float32)
        is_ended =  torch.tensor(is_ended, dtype=torch.float32)

        return states, actions, states_next, rewards, is_ended

六、DQN模型的训练

有了基础模型和重放类后,接下来对模型进行训练,代码及解释如下:

import copy
import torch.nn
# 定义2个网络,用于加速模型收敛
dqn = DQN(2,4,8) # 主要优化它
dqn_t = DQN(2,4,8) # 用来辅助dqn的模型优化,增强dqn的数值稳定性,加速模型收敛
dqn_t.load_state_dict(copy.deepcopy(dqn.state_dict()))
eps = 0.1 # 定义强化学习模型的探索比例。探索:对动作空间的随机采样达到遍历动作空间的目的,保证探索数量;利用:使用模型并选择模型的最优动作和环境进行交互,防止由重复探索出现。
# 折扣系数
gamma = 0.999
optim = torch.optim.Adam(dqn.parameters(), lr=1e-3)
criterion = torch.nn.HuberLoss() # 需要torch1.9.0及以上的版本
step_cnt= 0
mem = Memory()

for episode in range(300):
    state = env.reset()
    while True:
        action_t = torch.tensor([0,1])
        state_t = torch.tensor([state, state], dtype=torch.float32)

        # 计算最优策略
        torch.set_grad_enabled(False)
        q_t = dqn(state_t, action_t)
        max_t = q_t.argmax()
        torch.set_grad_enabled(True)

        # 探索和利用的平衡
        if random.random() < eps:
            max_t = random.choice([0,1])
        else:
            max_t = max_t.iem()

        state_next, reward, done, truncated, info = env.step(max_t)

        mem.push(state, max_t, state_next, reward, done)
        state = state_next

        if done:
            break

        # 重放训练
        for _ in range(10):
            state_t, action_t, state_next_t, reward_t, is_ended_t = mem.sample(32)
            q1 = dqn(state_t, action_t)
            torch.set_grad_enabled(False)
            q2_0 = dqn_t(state_next_t, torch.zeros(state_t.size(0), dtype=torch.long))
            q2_1 = dqn_t(state_next_t, torch.ones(state_t.size(0), dtype=torch.long))

            # 利用Bellman方程进行迭代
            q2_max = reward_t + gamma*(1-is_ended_t)*(torch.stack((q2_0, q2_1), dim=1).max(1)[0])
            torch.set_grad_enabled(True)

            # 优化损失函数
            delta = q2_max - q1
            loss = criterion(delta)
            optim.zero_grad()
            loss.backward()
            for p in dqn.parameters():
                p.grad.data.clamp_(-1,1)
            optim.step()
            step_cnt += 1

            # 同步2个网络的参数
            if step_cnt % 1000 == 0:
                dqn_t.load_state_dict(copy.deepcopy((dqn.state_dict())))

    env.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/31433.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

NePTuNe 论文笔记

NePTuNe:Neural Powered Tucker Network for Knowledge Graph Completion- Introduction- Background- Algorithm- Experiment- Conclusion- CodeShashank Sonkar, Arzoo Katiyar, Richard G.Baraniuk - Introduction 目前的链接预测方法&#xff1a; 张量因式分解&#xff1…

【服务器数据恢复】raidz多块硬盘离线的数据恢复案例

服务器数据恢复环境&#xff1a; 一台采用zfs文件系统的服务器&#xff0c;配备32块硬盘。 服务器故障&#xff1a; 服务器在运行过程中崩溃&#xff0c;经过初步检测没有发现服务器有物理故障&#xff0c;重启服务器后故障依旧&#xff0c;用户联系我们中心要求恢复服务器数据…

SpringBoot: Controller层的优雅实现

目录1. 实现目标2. 统一状态码3. 统一响应体4. 统一异常5. 统一入参校验6. 统一返回结果7. 统一异常处理8. 验证1. 实现目标 优雅校验接口入参响应体格式统一处理异常统一处理 2. 统一状态码 创建状态码接口&#xff0c;所有状态码必须实现这个接口&#xff0c;统一标准 pa…

Eolink 征文活动- -专为开发者设计的一款国产免费 API 协作平台

&#x1f497;wei_shuo的个人主页 &#x1f4ab;wei_shuo的学习社区 &#x1f310;Hello World &#xff01; ▌背景 后端开发的程序员都需要有一个用得顺手的接口测试工具&#xff1b;以前&#xff0c;大家都喜欢用Google开发的一款接口测试工具postman来进行测试&#xff0c;…

Java面向对象三大基本特征之多态

多态性是面向对象编程的又一个重要特征&#xff0c;那么多态是什么呢&#xff1f; 一、多态的概念 1.概念&#xff1a;多态是指在父类中定义的属性和方法被子类继承之后&#xff0c;可以具有不同的数据类型或表现出不同的行为&#xff0c;这使得同一个属性或方法在父类及其各…

Linux 文件操作(一) —— 遍历指定目录下的所有文件

目录 一、访问目录相关函数 1、打开/访问目录 (opendir / fdopendir) 2、读取目录内容 (readdir) 3、关闭目录 (closedir) 二、遍历指定目录下的所有文件 一、访问目录相关函数 1、打开/访问目录 (opendir / fdopendir) opendir / fdopendir 函数的作用是访问指定路径的…

工程基建--前端基建

序&#xff1a; 工程基建 &#xff1a; 编码规范、api规范、前后端协作、环境部署、微服务、微前端、性能、安全防御、统计监控、可视化 等等的建设&#xff1b; 后端基建&#xff1a; 后端规范文档、后端模板、安全、日志、微服务、RESTful API、中间件、数据库、分布式、权…

新手怎么做微信商城小程序_微信商城小程序模版哪里找

微信小程序已经在我们的生活中随处可见&#xff0c;甚至是抖音头条等其它的平台也开始做起了小程序&#xff0c;在这种情况下&#xff0c;微信小程序势必会成为未来商城的主战场之一。闻风而来想做小程序的人不少&#xff0c;而其中新手零基础也能做的小程序商城模板类工具&…

C++入门教程2||C++ 数据类型

C 数据类型 使用编程语言进行编程时&#xff0c;需要用到各种变量来存储各种信息。变量保留的是它所存储的值的内存位置。这意味着&#xff0c;当您创建一个变量时&#xff0c;就会在内存中保留一些空间。 您可能需要存储各种数据类型&#xff08;比如字符型、宽字符型、整型…

【Leetcode】15. 三数之和

一、题目 难度不小 注意是不能重复 Python提交格式&#xff0c;返回一个list 二、暴力解法 排序 三重循环 有没有像我这样的傻子&#xff0c;三重循环&#xff0c;还没去重 后来发现要去重&#xff0c;必须要先排序&#xff0c;然后判断一下当前的数是否跟前面那个数相同&am…

SpringBoot项目如何优雅的实现操作日志记录

前言 大家好&#xff0c;我是希留。 在实际开发当中&#xff0c;对于某些关键业务&#xff0c;我们通常需要记录该操作的内容&#xff0c;一个操作调一次记录方法&#xff0c;每次还得去收集参数等等&#xff0c;会造成大量代码重复。 我们希望代码中只有业务相关的操作&…

html5期末大作业 基于HTML+CSS制作dr钻戒官网5个页面 企业网站制作

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

Java中的多重继承问题

继承是面向对象编程 &#xff08;OOP&#xff09; 语言&#xff08;如Java&#xff09;的主要功能之一。它是一种以增强软件设计中类重用能力的方式组织类的基本技术。多重继承是众多继承类型之一&#xff0c;是继承机制的重要原则。但是&#xff0c;它因在类之间建立模棱两可的…

使用HTML实现一个静态页面(含源码)

&#x1f389;精彩专栏推荐 &#x1f4ad;文末获取联系 ✍️ 作者简介: 一个热爱把逻辑思维转变为代码的技术博主 &#x1f482; 作者主页: 【主页——&#x1f680;获取更多优质源码】 &#x1f393; web前端期末大作业&#xff1a; 【&#x1f4da;毕设项目精品实战案例 (10…

什么密码,永远无法被黑客攻破?

在开始本文前&#xff0c;先给大家出个解谜题&#xff0c;密码是一句英文&#xff0c;开动你的脑筋吧&#xff0c;我们在本文结尾会揭晓答案&#xff1a; 密文&#xff1a;Cigumpz yin hvq se 提示&#xff1a;和身份有关的一切 说起破译密码&#xff0c;就不得不提一个人&a…

Vue3中vite.config.js文件相关配置和mock数据配置

文章目录1. vite.config.js文件相关配置2. 路径别名3. mock数据配置1. vite.config.js文件相关配置 import { defineConfig } from vite import vue from vitejs/plugin-vue import vueJsx from vitejs/plugin-vue-jsx import path from path// https://vitejs.dev/config/ ex…

简单的股票行情演示(二) - AKShare

一、概述二、环境搭建三、使用总结 1、API文档2、数据字典3、效果截图4、后台服务四、相关文章原文链接&#xff1a;简单的股票行情演示&#xff08;二&#xff09; - akshare 一、概述 上一篇文章简单的股票行情演示&#xff08;一&#xff09; - 实时标的数据中讲述了从新浪…

web前端期末大作业 HTML+CSS+JavaScript仿安踏

⛵ 源码获取 文末联系 ✈ Web前端开发技术 描述 网页设计题材&#xff0c;DIVCSS 布局制作,HTMLCSS网页设计期末课程大作业 | 在线商城购物 | 水果商城 | 商城系统建设 | 多平台移动商城 | H5微商城购物商城项目 | HTML期末大学生网页设计作业&#xff0c;Web大学生网页 HTML&a…

连续仨月霸占牛客榜首,京东T8呕心巨作:700页JVM虚拟机实战手册

什么是Java虚拟机 虚拟机是一种抽象化的计算机&#xff0c;通过在实际的计算机上仿真模拟各种计算机功能来实现的。Java虚拟机有自己完善的硬体架构&#xff0c;如处理器、堆栈、寄存器等&#xff0c;还具有相应的指令系统。JVM屏蔽了与具体操作系统平台相关的信息&#xff0c…

Linux下 生成coredump文件

一. coredump文件路径 网上很多博文说到 coredump 文件默认会在默认的目录下生成。 按照网上很多的说法&#xff0c;再运行程序就会生成core文件&#xff0c;一般路径和可执行程序一个路径。 但是&#xff0c;我尝试在 ubuntu20.04系统下&#xff0c;怎么也找不到去哪里了&a…