【强化学习】——Q-learning算法为例入门Pytorch强化学习

news2024/9/23 7:33:02

🤵‍♂️ 个人主页:@Lingxw_w的个人主页

✍🏻作者简介:计算机研究生在读,研究方向复杂网络和数据挖掘,阿里云专家博主,华为云云享专家,CSDN专家博主、人工智能领域优质创作者,安徽省优秀毕业生
🐋 希望大家多多支持,我们一起进步!😄
如果文章对你有帮助的话,
欢迎评论 💬点赞👍🏻 收藏 📂加关注+ 

目录

1、强化学习是什么

1.1 定义

1.2 基本组成

1.3 马尔可夫决策过程

2、强化学习的应用

3、常见的强化学习算法

3.1 Q-learning算法

3.2 Q-learning的算法步骤

3.3 Pytorch代码实现


1、强化学习是什么

1.1 定义

强化学习(Reinforcement Learning,RL)是一种机器学习方法,其目标是通过智能体(Agent)与环境的交互学习最优行为策略,以使得智能体能够在给定环境中获得最大的累积奖励。

强化学习在许多领域都有应用,例如机器人控制、游戏智能、自动驾驶、资源管理等。通过与环境的交互和试错学习,强化学习使得智能体能够在复杂、不确定的环境中做出优化的决策,并逐步提升性能。

1.2 基本组成

强化学习的基本组成部分包括:

  1. 智能体(Agent):在强化学习中,智能体是学习和决策的主体,它通过与环境的交互来获取知识和经验,并根据获得的奖励信号进行学习和优化。

  2. 环境(Environment):环境是智能体所处的外部世界,它可以是真实的物理环境,也可以是虚拟的模拟环境。智能体通过观察环境的状态,执行动作,并接收来自环境的奖励或惩罚信号。

  3. 状态(State):状态表示环境的某个特定时刻的观察或描述,它包含了智能体需要的所有信息来做出决策。

  4. 动作(Action):动作是智能体在某个状态下采取的行为,它会对环境产生影响并导致状态的转换。

  5. 奖励(Reward):奖励是环境根据智能体的行为给予的反馈信号,用于指导智能体学习合适的策略。奖励可以是正数(奖励)也可以是负数(惩罚),智能体的目标是最大化累积奖励。

1.3 马尔可夫决策过程

(Markov Decision Process,MDP)强化学习中常用的建模框架,用于描述具有马尔可夫性质的序贯决策问题。它是基于马尔可夫链(Markov Chain)和决策理论的组合。

在马尔可夫决策过程中,智能体与环境交互,通过采取一系列动作来影响环境的状态和获得奖励。MDP的关键特点是马尔可夫性质,即当前状态的信息足以决定未来状态的转移概率。这意味着在MDP中,未来的状态和奖励仅取决于当前状态和采取的动作,而与过去的状态和动作无关。 

2、强化学习的应用

强化学习旨在解决以下类型的问题:

  1. 决策问题:强化学习可以用于解决需要做出一系列决策的问题。例如,自动驾驶车辆需要在不同交通情况下选择合适的行驶策略,智能机器人需要学习在复杂环境中执行任务的最佳策略。

  2. 控制问题:强化学习可用于控制系统的优化。例如,通过学习最优策略来调整电力网格的能源分配,或者在金融投资中确定最佳的投资组合。

  3. 资源管理:强化学习可以应用于资源管理问题,如动态网络管理、数据中心的负载平衡、无线通信中的频谱分配等。智能体可以通过与环境的交互来学习如何最优地利用和分配有限的资源。

  4. 序列决策问题:强化学习适用于需要在连续时间步骤中做出决策的问题。例如,在自然语言处理中,可以使用强化学习来训练智能体生成合适的文本回复,或者在推荐系统中根据用户行为动态调整推荐策略。

  5. 探索与开发:强化学习可以用于探索未知环境和发现新知识。通过与环境的交互,智能体可以通过试错学习来积累经验并发现最优策略。

3、常见的强化学习算法

  • Q-learning:一种基于值函数(Q函数)的强化学习算法,通过迭代更新Q值来学习最优策略。
  • SARSA:另一种基于值函数的强化学习算法,与Q-learning类似,但在更新Q值时采用了一种“状态-动作-奖励-下一状态-下一动作(State-Action-Reward-State-Action)”的更新策略。
  • 策略梯度(Policy Gradient):一类直接学习策略函数的方法,通过优化策略函数的参数来提高智能体的性能。
  • 深度强化学习(Deep Reinforcement Learning):将深度学习方法与强化学习相结合,利用神经网络来表示值函数或策略函数,以解决具有高维状态空间的复杂任务。

3.1 Q-learning算法

Q-learning是一种经典的强化学习算法,用于解决马尔可夫决策过程(Markov Decision Process,MDP)的问题。它是基于值函数的强化学习算法,通过迭代地更新Q值来学习最优策略。

在Q-learning中,智能体与环境的交互过程由状态、动作、奖励下一个状态组成。智能体根据当前状态选择一个动作,与环境进行交互,接收到下一个状态和相应的奖励。Q-learning的目标是学习一个Q值函数,它估计在给定状态下采取特定动作所获得的长期累积奖励。

Q值函数表示为Q(s, a),其中s是状态,a是动作。初始时,Q值可以初始化为任意值。Q-learning使用贝尔曼方程(Bellman equation)来更新Q值,以逐步逼近最优的Q值函数:

Q(s, a) = Q(s, a) + α * (r + γ * max[Q(s', a')] - Q(s, a))

在上述方程中,α是学习率(learning rate),决定了每次更新的幅度;r是当前状态下执行动作a所获得的奖励;γ是折扣因子(discount factor),用于权衡即时奖励和未来奖励的重要性;s'是下一个状态,a'是在下一个状态下的最优动作。

3.2 Q-learning的算法步骤

  1. 初始化Q值函数。
  2. 在每个时间步骤中,根据当前状态选择一个动作。
  3. 执行动作,观察奖励和下一个状态。
  4. 根据贝尔曼方程更新Q值函数。
  5. 重复2-4步骤,直到达到预定的停止条件或收敛。

通过多次迭代更新Q值函数,Q-learning最终能够收敛到最优的Q值函数。智能体可以根据Q值函数选择具有最高Q值的动作作为策略,以实现最优的行为决策。

Q-learning是一种基于模型的强化学习方法,不需要对环境的模型进行显式建模,适用于离散状态空间和动作空间的问题。对于连续状态和动作空间的问题,可以通过函数逼近方法(如深度Q网络)来扩展Q-learning算法。

3.3 Pytorch代码实现

基于PyTorch的Q-learning算法来解决OpenAI Gym中的CartPole环境。

首先,导入所需的库,包括gym用于创建环境,random用于随机选择动作,以及torchtorch.nn用于构建和训练神经网络。

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

定义了一个Q网络(QNetwork)作为强化学习算法的近似函数。该网络具有三个全连接层,其中前两个层使用ReLU激活函数,最后一层输出动作值。forward方法用于定义网络的前向传播。

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

接下来,定义了一个QLearningAgent类。在初始化中,指定了状态维度、动作维度、折扣因子和探索率等超参数。同时创建了两个Q网络q_networktarget_networktarget_network用于计算目标Q值。还定义了优化器和损失函数。 

# Q-learning算法
class QLearningAgent:
    def __init__(self, state_dim, action_dim, gamma, epsilon):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率

        # 初始化Q网络和目标网络
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

        self.optimizer = optim.Adam(self.q_network.parameters())
        self.loss_fn = nn.MSELoss()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state = torch.FloatTensor(state)
            q_values = self.q_network(state)
            return torch.argmax(q_values).item()

    def train(self, replay_buffer, batch_size):
        if len(replay_buffer) < batch_size:
            return

        # 从回放缓存中采样一个小批量样本
        samples = random.sample(replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # 计算当前状态的Q值
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # 计算下一个状态的Q值
        next_q_values = self.target_network(next_states).max(1)[0]
        expected_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # 计算损失并更新Q网络
        loss = self.loss_fn(q_values, expected_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

select_action方法用于根据当前状态选择动作。以epsilon的概率选择随机动作,以探索环境;以1-epsilon的概率选择基于当前Q值的最优动作。

# 创建环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

train方法用于训练Q网络。它从回放缓存中采样一个小批量样本,并计算当前状态和下一个状态的Q值。然后计算损失并进行优化。

接下来,创建CartPole环境并获取状态和动作的维度。

然后,实例化一个QLearningAgent对象,并设置相关的超参数。

接下来,进行训练循环。在每个回合中,重置环境,然后在每个时间步中执行以下步骤:

  1. 根据当前状态选择一个动作。
  2. 执行动作,观察下一个状态、奖励和终止信号。
  3. 将状态、动作、奖励、下一个状态和终止信号存储在回放缓存中。
  4. 调用agent的train方法进行网络训练。

每隔一定的回合数,通过update_target_network方法更新目标网络的权重。

# 创建Q-learning智能体
agent = QLearningAgent(state_dim, action_dim, gamma=0.99, epsilon=0.2)

# 训练
replay_buffer = []
episodes = 1000
batch_size = 32

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.append((state, action, reward, next_state, done))

        state = next_state
        total_reward += reward

        agent.train(replay_buffer, batch_size)

    if episode % 10 == 0:
        agent.update_target_network()
        print(f"Episode: {episode}, Total Reward: {total_reward}")

最后,使用训练好的智能体进行测试。在测试过程中,根据当前状态选择动作,并执行动作,直到终止信号出现。同时可通过env.render()方法显示环境的图形界面。

# 使用训练好的智能体进行测试
state = env.reset()
done = False
total_reward = 0

while not done:
    env.render()
    action = agent.select_action(state)
    state, reward, done, _ = env.step(action)
    total_reward += reward

print(f"Test Total Reward: {total_reward}")

env.close()

代码执行完毕后,关闭环境并显示测试的总奖励。

总体而言,这段代码实现了基于PyTorch的Q-learning算法,并将其应用于CartPole环境。通过训练,智能体可以学习到一个最优策略,使得杆子保持平衡的时间尽可能长。

汇总的代码:

import gym
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# 定义Q网络
class QNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_dim, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Q-learning算法
class QLearningAgent:
    def __init__(self, state_dim, action_dim, gamma, epsilon):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.gamma = gamma  # 折扣因子
        self.epsilon = epsilon  # 探索率

        # 初始化Q网络和目标网络
        self.q_network = QNetwork(state_dim, action_dim)
        self.target_network = QNetwork(state_dim, action_dim)
        self.target_network.load_state_dict(self.q_network.state_dict())
        self.target_network.eval()

        self.optimizer = optim.Adam(self.q_network.parameters())
        self.loss_fn = nn.MSELoss()

    def update_target_network(self):
        self.target_network.load_state_dict(self.q_network.state_dict())

    def select_action(self, state):
        if random.random() < self.epsilon:
            return random.randint(0, self.action_dim - 1)
        else:
            state = torch.FloatTensor(state)
            q_values = self.q_network(state)
            return torch.argmax(q_values).item()

    def train(self, replay_buffer, batch_size):
        if len(replay_buffer) < batch_size:
            return

        # 从回放缓存中采样一个小批量样本
        samples = random.sample(replay_buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*samples)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions)
        rewards = torch.FloatTensor(rewards)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones)

        # 计算当前状态的Q值
        q_values = self.q_network(states)
        q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)

        # 计算下一个状态的Q值
        next_q_values = self.target_network(next_states).max(1)[0]
        expected_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        # 计算损失并更新Q网络
        loss = self.loss_fn(q_values, expected_q_values.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

# 创建环境
env = gym.make('CartPole-v1')
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n

# 创建Q-learning智能体
agent = QLearningAgent(state_dim, action_dim, gamma=0.99, epsilon=0.2)

# 训练
replay_buffer = []
episodes = 1000
batch_size = 32

for episode in range(episodes):
    state = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = agent.select_action(state)
        next_state, reward, done, _ = env.step(action)
        replay_buffer.append((state, action, reward, next_state, done))

        state = next_state
        total_reward += reward

        agent.train(replay_buffer, batch_size)

    if episode % 10 == 0:
        agent.update_target_network()
        print(f"Episode: {episode}, Total Reward: {total_reward}")

# 使用训练好的智能体进行测试
state = env.reset()
done = False
total_reward = 0

while not done:
    env.render()
    action = agent.select_action(state)
    state, reward, done, _ = env.step(action)
    total_reward += reward

print(f"Test Total Reward: {total_reward}")

env.close()

 相关博客专栏订阅链接

【机器学习】——房屋销售的探索性数据分析

【机器学习】——数据清理、数据变换、特征工程

【机器学习】——决策树、线性模型、随机梯度下降

【机器学习】——多层感知机、卷积神经网络、循环神经网络

【机器学习】——模型评估、过拟合和欠拟合、模型验证

【机器学习】——模型调参、超参数优化、网络架构搜索

【机器学习】——方差和偏差、Bagging、Boosting、Stacking

【机器学习】——模型调参、超参数优化、网络架构搜索

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/673276.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

神经网络:参数更新

在计算机视觉中&#xff0c;参数更新是指通过使用梯度信息来调整神经网络模型中的参数&#xff0c;从而逐步优化模型的性能。参数更新的作用、原理和意义如下&#xff1a; 1. 作用&#xff1a; 改进模型性能&#xff1a;参数更新可以使模型更好地适应训练数据&#xff0c;提高…

python学习——pandas统计分析基础

目录 pandas统计分析基础1. Series数据2.文件读取csv文件Excel文件 3.DataFrame连接数据库读取数据库存入数据库DataFrame的属性访问DataFrame中的数据【实例1】info详细信息和describe描述统计分析【实例2】 排序【实例3】 布尔索引&#xff0c;条件索引【案例】修改数据 3.描…

LIBSVM与LIBLINEAR支持向量机库对模式识别与回归的可视化代码实践

支持向量机(SVM)是一种流行的分类技术。虽然提出时间到现在有70来年了&#xff0c;但在90年代获得了很好的发展和扩展&#xff0c;在人像识别、文本分类、手写字符识别、生物信息学等模式识别问题中有得到应用。然而&#xff0c;对于不熟悉SVM的初学者来说&#xff0c;往往会因…

ThreadPoolExecutor解读

目录 线程池状态 构造方法 newFixedThreadPool newCachedThreadPool newSingleThreadExecutor 提交任务 关闭线程池 其它方法 线程池状态 ThreadPoolExecutor 使用 int 的高 3 位来表示线程池状态&#xff0c;低 29 位表示线程数量 状态名 高 3 位 接收新任务 处理…

JavaScript ES10新特性

文章目录 导文Array.prototype.flat()和Array.prototype.flatMap()Object.fromEntries()String.prototype.trimStart()和String.prototype.trimEnd()格式化数字动态导入可选的catch绑定BigIntglobalThis 导文 JavaScript ES10&#xff0c;也被称为ES2019&#xff0c;引入了一些…

javascript被禁用怎么办?怎么启用?||如何解决javascript:void(0)的问题?

javascript被禁用怎么办&#xff1f;怎么启用&#xff1f; 有些小伙伴可能因为浏览器弹窗的凌乱而感到烦恼&#xff0c;想要通过浏览器禁用JavaScript的方式来避免这些广告。有些小伙伴则是因为设置了不知名的设置导致JavaScript被禁用&#xff0c;影响日常的使用。接下来的这…

Vue3的计算属性和监听属性

目录 computed 语法介绍 简写版 完整版 watch 介绍 监听ref式数据代码示例 监听reactive式数据 watchEffect函数 computed 语法介绍 与Vue2.x中computed配置功能一致 import {computed} from vuesetup(){...//计算属性——简写let fullName computed(()>{return per…

【kubernetes】部署kubelet与kube-proxy

前言:二进制部署kubernetes集群在企业应用中扮演着非常重要的角色。无论是集群升级,还是证书设置有效期都非常方便,也是从事云原生相关工作从入门到精通不得不迈过的坎。通过本系列文章,你将从虚拟机准备开始,到使用二进制方式从零到一搭建起安全稳定的高可用kubernetes集…

ffmpeg调整音频音量踩坑

前一阵用Flutter结合ffmpeg做了一个音视频合并功能&#xff0c;记录一下遇到的问题。 合并方法 首先是音视频合并命令&#xff1a; ffmpeg -i input.mp4 -i input.mp3 -filter_complex "[1:a]adelay0s:all1[a1];[a1]amixinputs1[amixout]" -map 0:v:0 -map "…

Ts系列之条件类型

Ts系列之条件类型 Ts系列之条件类型前言一、初遇二、条件判断三、待补充 Ts系列之条件类型 前言 本片文章主要对ts条件类型的用法做一个讲解。 一、初遇 1、首先我们来看看一个小例子&#xff1a; interface Person {name: string;age: number; } interface Son extends P…

神经网络:梯度计算

在计算机视觉中&#xff0c;梯度计算是一项关键任务&#xff0c;它在优化算法中扮演着重要的角色。梯度表示函数在某一点上的变化率&#xff0c;可以指导模型参数的更新&#xff0c;使得模型逐步接近最优解。下面我将详细解释梯度计算的作用、原理和意义。 作用&#xff1a; 梯…

C++ 新的类型转换

文章目录 前言一、静态转换&#xff08;static_cast&#xff09;二、动态转换&#xff08;dynamic_cast&#xff09;&#xff1a;三、常量转换&#xff08;const_cast&#xff09;&#xff1a;四、重新解释转换&#xff08;reinterpret_cast&#xff09;&#xff1a;总结 前言 …

基于Java+Swing实现记事本-完美版

基于JavaSwing实现记事本-完美版 一、系统介绍二、功能展示1.主页2.文件功能3.编辑功能4.格式--功能5.查看功能 三、系统实现1. Fontv.java 四、其它1.其他系统实现2.获取源码 一、系统介绍 1.主页 2.文件功能 3.编辑功能 4.格式–功能 5.查看功能 二、功能展示 1.主页 2.文…

一文带你弄清Map集合及其实现类(适合小白秋招篇)

前言&#xff1a; 本篇文章主要讲解Java中的Map集合接口以及相关实现类的知识。该专栏比较适合刚入坑Java的小白以及准备秋招的大佬阅读。 如果文章有什么需要改进的地方欢迎大佬提出&#xff0c;对大佬有帮助希望可以支持下哦~ 小威在此先感谢各位小伙伴儿了&#x1f601; 以…

【Linux】冯诺依曼体系结构 操作系统 进程概念

目录 一、冯诺依曼体系结构 二、操作系统 1、概念 2、设计OS的目的 三、进程 1、基本概念 2、描述进程-PCB 3、组织进程 4、查看进程和终止 5、通过系统调用获取进程标识符 6、通过系统调用创建进程-fork 7、进程状态 8、特殊进程 8.1 僵尸进程 8.2 孤儿进程 一、冯诺依曼体…

【从零开始学习JAVA | 第八篇】String类

目录 前言&#xff1a; String类&#xff1a; 常见的认识误区&#xff1a; 创建String类&#xff1a; 注意点&#xff1a; 总结&#xff1a; 前言&#xff1a; String类是Java中最常见的一个类&#xff0c;本篇将对Stirng类的各种功能进行详细的介绍&#xff0c;各位小伙伴…

js:使用vue-codemirror实现一个语法高亮的网页代码编辑器

codemirror code editor component for vuejs 译文&#xff1a;vuejs的codemirror代码编辑器组件 文档 https://github.com/surmon-china/vue-codemirror 安装 # 依赖 pnpm install codemirror vue-codemirror --save# 语言 pnpm install codemirror/lang-json --save pnpm …

【VulnHub系列】MyFileServer

因为是从PDF转换过来偶尔可能会出现内容缺少&#xff0c;可以看原版PDF&#xff1a;有道云笔记 实验环境 Kali&#xff1a;192.168.10.102 MyFileServer&#xff1a;192.168.10.106 实验过程 通过arp-scan来发现靶机的IP地址 sudo arp-scan --interface eth0 192.168.10.1…

mediapipe 谷歌高效ML框架-图像识别、人脸检测、关键点检测

参考&#xff1a; https://github.com/google/mediapipe https://developers.google.com/mediapipe/solutions/guide 框架也支持cv、nlp、audio等项目&#xff0c;速度很快&#xff1a; 1、图形识别 参考&#xff1a;https://developers.google.com/mediapipe/solutions/vi…

【从零开始学习JAVA | 第九篇】字符串综合练习

前言&#xff1a; 在前一篇我们学习了String类以及两个接口函数&#xff0c;今天我们将利用昨天的知识以及讲解新的方法进行几个实战操作&#xff0c;以此来巩固我们的所学内容。 1.实现用户登录&#xff0c;对用户输入的密码进行验证 需求&#xff1a;已知正确的用户名和密码…