强化学习-python案例

news2024/9/30 22:37:14

强化学习是一种机器学习方法,旨在通过与环境的交互来学习最优策略。它的核心概念是智能体(agent)在环境中采取动作,从而获得奖励或惩罚。智能体的目标是最大化长期奖励,通过试错的方式不断改进其决策策略。

在强化学习中,智能体观察当前状态,选择动作,并根据环境反馈(奖励和下一个状态)调整其策略。常见的强化学习算法包括Q-learning、策略梯度方法和深度强化学习等。强化学习广泛应用于游戏、机器人控制、推荐系统等领域。

  1. 奖励(Reward)
    r t = R ( s t , a t ) r_t = R(s_t, a_t) rt=R(st,at)
    其中 r t r_t rt 是在时间步 t t t 时,智能体在状态 s t s_t st 下采取动作 a t a_t at 所获得的奖励。

  2. 状态价值函数(State Value Function)
    V ( s ) = E [ ∑ t = 0 ∞ γ t r t ∣ s 0 = s ] V(s) = \mathbb{E} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s \right] V(s)=E[t=0γtrts0=s]
    其中 V ( s ) V(s) V(s) 是状态 s s s 的价值, γ \gamma γ 是折扣因子 ( 0 ≤ γ < 1 ( 0 \leq \gamma < 1 (0γ<1),表示未来奖励的重要性。

  3. 动作价值函数(Action Value Function)
    Q ( s , a ) = E [ ∑ t = 0 ∞ γ t r t ∣ s 0 = s , a 0 = a ] Q(s, a) = \mathbb{E} \left[ \sum_{t=0}^{\infty} \gamma^t r_t \mid s_0 = s, a_0 = a \right] Q(s,a)=E[t=0γtrts0=s,a0=a]
    其中 Q ( s , a ) Q(s, a) Q(s,a) 是在状态 s s s 下采取动作 a a a 的价值。

  4. 贝尔曼方程(Bellman Equation)

    • 状态价值函数的贝尔曼方程:
      V ( s ) = ∑ a π ( a ∣ s ) ∑ s ′ , r P ( s ′ , r ∣ s , a ) [ r + γ V ( s ′ ) ] V(s) = \sum_{a} \pi(a \mid s) \sum_{s', r} P(s', r \mid s, a) \left[ r + \gamma V(s') \right] V(s)=aπ(as)s,rP(s,rs,a)[r+γV(s)]
    • 动作价值函数的贝尔曼方程:
      Q ( s , a ) = ∑ s ′ , r P ( s ′ , r ∣ s , a ) [ r + γ max ⁡ a ′ Q ( s ′ , a ′ ) ] Q(s, a) = \sum_{s', r} P(s', r \mid s, a) \left[ r + \gamma \max_{a'} Q(s', a') \right] Q(s,a)=s,rP(s,rs,a)[r+γamaxQ(s,a)]
  5. 策略(Policy)
    π ( a ∣ s ) = P ( a ∣ s ) \pi(a \mid s) = P(a \mid s) π(as)=P(as)
    其中 π ( a ∣ s ) \pi(a \mid s) π(as) 是在状态 s s s 下选择动作 a a a 的概率。

目标函数

  1. 策略梯度目标函数
    J ( θ ) = E τ ∼ π θ [ ∑ t = 0 T r t ] J(\theta) = \mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} r_t \right] J(θ)=Eτπθ[t=0Trt]
    • 说明 J ( θ ) J(\theta) J(θ) 是关于策略参数 θ \theta θ 的目标函数,表示在策略 π θ \pi_\theta πθ 下,执行轨迹 τ \tau τ 的预期总奖励。目标是最大化该期望值,通常通过梯度上升方法进行优化。

损失函数

  1. 策略损失函数(使用REINFORCE算法):
    L ( θ ) = − E τ ∼ π θ [ ∑ t = 0 T r t log ⁡ π θ ( a t ∣ s t ) ] L(\theta) = -\mathbb{E}_{\tau \sim \pi_\theta} \left[ \sum_{t=0}^{T} r_t \log \pi_\theta(a_t \mid s_t) \right] L(θ)=Eτπθ[t=0Trtlogπθ(atst)]

    • 说明:这个损失函数的目的是最小化负的期望总奖励。通过优化该损失函数,可以最大化目标函数 J ( θ ) J(\theta) J(θ)。这里的 log ⁡ π θ ( a t ∣ s t ) \log \pi_\theta(a_t \mid s_t) logπθ(atst) 是对策略的对数概率,表示在状态 s t s_t st 下采取动作 a t a_t at 的可能性。
  2. 价值函数损失(对于Q-learning):
    L ( θ ) = E [ ( r t + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ ) − Q ( s , a ; θ ) ) 2 ] L(\theta) = \mathbb{E} \left[ \left( r_t + \gamma \max_{a'} Q(s', a'; \theta) - Q(s, a; \theta) \right)^2 \right] L(θ)=E[(rt+γamaxQ(s,a;θ)Q(s,a;θ))2]

    • 说明:该损失函数用于最小化当前动作价值函数 Q ( s , a ; θ ) Q(s, a; \theta) Q(s,a;θ) 和目标价值 r t + γ max ⁡ a ′ Q ( s ′ , a ′ ; θ ) r_t + \gamma \max_{a'} Q(s', a'; \theta) rt+γmaxaQ(s,a;θ) 之间的均方误差。通过最小化该损失,更新网络参数 θ \theta θ 以更准确地预测价值。

细节总结

  • 目标函数:用于衡量当前策略的性能,指导优化过程。强化学习的目标是通过更新策略来最大化期望奖励。
  • 损失函数:是优化过程中实际最小化的函数,直接反映模型的学习效果。损失函数的设计直接影响学习的效率和效果。

这些公式是强化学习中策略优化和价值评估的核心,理解它们有助于深入掌握强化学习的理论基础和应用。

代码

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# 环境假设
class SimpleEnv:
    def reset(self):
        return np.random.rand(4)  # 随机状态

    def step(self, action):
        next_state = np.random.rand(4)
        reward = np.random.rand()  # 随机奖励
        done = np.random.rand() > 0.9  # 随机结束
        return next_state, reward, done

# 策略网络
class PolicyNetwork(nn.Module):
    def __init__(self):
        super(PolicyNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 128),
            nn.ReLU(),
            nn.Linear(128, 2),  # 假设有两个动作
        )

    def forward(self, x):
        return torch.softmax(self.fc(x), dim=-1)

# 计算折扣奖励
def compute_discounted_rewards(rewards, discount_factor=0.99):
    discounted_rewards = []
    cumulative_reward = 0
    for r in reversed(rewards):
        cumulative_reward = r + cumulative_reward * discount_factor
        discounted_rewards.insert(0, cumulative_reward)
    return discounted_rewards

# 训练函数
def train(env, policy_net, optimizer, episodes=1000):
    for episode in range(episodes):
        state = env.reset()
        rewards = []
        log_probs = []
        
        while True:
            state_tensor = torch.FloatTensor(state)
            probs = policy_net(state_tensor)
            action = np.random.choice(len(probs), p=probs.detach().numpy())
            log_prob = torch.log(probs[action])
            
            next_state, reward, done = env.step(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state
            
            if done:
                break
        
        # 计算折扣奖励
        discounted_rewards = compute_discounted_rewards(rewards)
        
        # 更新策略
        optimizer.zero_grad()
        loss = -sum(log_prob * reward for log_prob, reward in zip(log_probs, discounted_rewards))
        loss.backward()
        optimizer.step()

        # 输出每个回合的总奖励
        total_reward = sum(rewards)
        print(f"Episode {episode + 1}, Total Reward: {total_reward:.2f}")

# 测试函数
def test(env, policy_net, episodes=10):
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0
        
        while True:
            state_tensor = torch.FloatTensor(state)
            with torch.no_grad():
                probs = policy_net(state_tensor)
            action = torch.argmax(probs).item()
            next_state, reward, done = env.step(action)
            total_reward += reward
            state = next_state
            
            if done:
                break
        
        print(f"Test Episode {episode + 1}, Total Reward: {total_reward:.2f}")

# 主程序
env = SimpleEnv()
policy_net = PolicyNetwork()
optimizer = optim.Adam(policy_net.parameters(), lr=0.01)

train(env, policy_net, optimizer)
test(env, policy_net)

在这里插入图片描述

训练奖励图:显示每个训练回合的总奖励变化,帮助评估模型在训练过程中的学习效果。
测试奖励图:展示在测试回合中模型的总奖励,反映训练后的表现。

代码结构

  1. 环境(Environment)

    • SimpleEnv 类:模拟一个简单的环境,包含 resetstep 方法。
      • reset():初始化并返回一个随机状态。
      • step(action):根据所采取的动作返回下一个状态、奖励和是否结束标志。
      • 奖励和结束状态是随机生成的,模拟了一个非常简化的环境。
  2. 策略网络(Policy Network)

    • PolicyNetwork 类:定义一个神经网络,用于近似策略。
      • 使用全连接层,输入状态维度为 4(环境状态的维度),输出动作概率的维度为 2(假设有两个可能的动作)。
      • forward 方法通过 softmax 函数输出每个动作的概率。
  3. 折扣奖励计算

    • compute_discounted_rewards(rewards, discount_factor=0.99):计算每个时间步的折扣奖励。
      • 从后往前遍历奖励列表,使用折扣因子更新累计奖励,生成折扣奖励列表。
  4. 训练函数(Training Function)

    • train(env, policy_net, optimizer, episodes=1000):进行训练的主函数。
      • 循环执行指定的回合数:
        • 重置环境,初始化奖励和日志概率列表。
        • 在回合中循环,使用当前状态选择动作并记录日志概率和奖励。
        • 计算并更新策略网络的损失,使用反向传播更新参数。
        • 每个回合结束后打印总奖励,帮助监控训练进度。
  5. 测试函数(Testing Function)

    • test(env, policy_net, episodes=10):用于评估训练后模型表现的函数。
      • 重置环境并执行多个测试回合,选择最大概率的动作。
      • 累计并打印每个测试回合的总奖励,评估训练的效果。
  6. 主程序

    • 创建环境和策略网络实例,定义优化器(Adam)。
    • 调用训练函数进行训练,然后调用测试函数进行评估。

整体逻辑

  1. 环境设置:定义了一个非常简单的环境,主要用于演示如何应用策略梯度方法。实际应用中,可以替换为更复杂的环境,比如OpenAI的Gym库中的环境。

  2. 策略学习:使用神经网络近似策略,通过与环境的交互收集状态、动作、奖励,并更新网络参数,以优化策略。

  3. 输出和评估:通过在训练过程中的总奖励输出和测试过程中的评估,可以观察到模型的学习进展。

小结

这段代码是一个简单的强化学习示例,展示了如何使用策略梯度方法和PyTorch进行训练和测试。虽然环境和任务是简化的,但它提供了一个良好的基础,便于理解强化学习的核心概念和实现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2180903.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

传统操作系统和分布式操作系统的区别

分布式操作系统和传统操作系统之间的区别&#xff0c;根植于它们各自的设计哲学和目标。要理解这些差异&#xff0c;需要从操作系统的基本定义、结构、功能以及它们在不同计算环境中的表现进行分析。每种系统都试图解决特定的计算挑战&#xff0c;因此在不同的使用场景下具有各…

互斥量mutex、锁、条件变量和信号量相关原语(函数)----很全

线程相关知识可以看这里: 线程控制原语(函数)的介绍-CSDN博客 进程组、会话、守护进程和线程的概念-CSDN博客 1.同步概念 所谓同步&#xff0c;即同时起步&#xff0c;协调一致。不同的对象&#xff0c;对“同步”的理解方式略有不同。如&#xff0c;设备同步&#xff0c;是…

前端——js补充

一、数学对象 1.随机数 // 0-1 console.log(Math.random()); // 0-9 console.log(Math.random() * 9); // 6-13 console.log(Math.random() * (13 - 6) 6); //n-m Math.random() * (m - n) n 2.取整 // 向下取整 console.log(Math.floor(1.9));//1 // 向上取整 console.log(…

解决端口被占用

当你被你的编译器提醒&#xff0c; 当前端口被占用&#xff0c; 但明明你的服务什么的都没有启用&#xff0c;这时有三种解决办法&#xff1a; 1 。 重启 。 重启解决80%的问题 2 。 修改你的端口号 。 3 。 去windows命令行中查看&#xff0c;端口占用情况 第一步 &#xf…

OpenStack Yoga版安装笔记(十五)Horizon安装

1、官方文档 OpenStack Installation Guidehttps://docs.openstack.org/install-guide/ 本次安装是在Ubuntu 22.04上进行&#xff0c;基本按照OpenStack Installation Guide顺序执行&#xff0c;主要内容包括&#xff1a; 环境安装 &#xff08;已完成&#xff09;OpenStack…

[卸载] 软件彻底卸载工具的下载及详细安装使用过程(附有下载文件)

一般软件安装的有问题&#xff0c;或者想重新安装其他版本就需要将原来的版本删除干净&#xff0c;但常常删不干净&#xff0c;本文分享一个软件彻底卸载工具&#xff0c;完成彻底卸载软件的工作 下载链接在文末 下载压缩包后解压 &#xff01;&#xff01;安装路径不要有中文…

激光切割机适用材质有哪些

激光切割机是一种利用激光束对各种材料进行高精度、高速度切割的机器设备。其适用材质广泛&#xff0c;包括但不限于以下两大类&#xff1a; 一、金属材料 不锈钢&#xff1a;激光切割机较容易切割不锈钢薄板&#xff0c;使用高功率YAG激光切割系统&#xff0c;切割不锈钢板的…

AMD Instinct™ MI200 GPU内存空间概述

AMD Instinct™ MI200 GPU memory space overview — ROCm Blogs 注意: 本博客之前是 AMD实验室笔记博客系列的一部分。 HIP API 支持在加速系统上为主机和设备内存提供多种分配方式。在本文中&#xff0c;我们将&#xff1a; 1. 介绍一组常用的内存空间 2. 识别每种内存空间的…

Kubernetes强制删除terminating状态的namespace

Kubernetes中的Namespace处于Terminating状态并且常规删除不起作用。 1.Namespace长时间处于Terminating状态往往是因为某些finalizers阻止了它的删除。 kubectl get namespace <namespace-name> -o json > namespace.json 2.编辑生成的 namespace.json文件&#xff…

今日指数项目A股大盘数据采集

1、A股大盘数据采集 1.1 A股大盘数据采集准备 1.1.1 配置ID生成器bean A股大盘数据采集入库时&#xff0c;主键ID保证唯一&#xff0c;所以在stock_job工程配置ID生成器&#xff1a; Configuration public class CommonConfig {/*** 配置基于雪花算法生成全局唯一id* 参与…

打点 - 泛微 E-Cology WorkflowServiceXml

请求路径 /services%20/WorkflowServiceXml显示如下&#xff0c;漏洞可能存在 利用&#xff1a; 根据提示在 CMD 处输入 Memshell 注入内存马&#xff0c;并点击执行&#xff0c;成功注入 冰蝎配置&#xff0c;输入内存马地址 成功连接 命令执行

2024/9/30 英语每日一段

The British Academy has created three high-profile awards to sit alongside the trophies it hands out to adult television shows--going some way, it is hoped, to replace Bafta’s abandoned children’s TV awards event. “Children’s programme-making has been …

【RockyLinux 9.4】安装 NVIDIA 驱动,改变分辨率,避坑版本。(CentOS 系列也能用)

总览 1.下载和解决依赖问题 2.修改相关参数 3.安装过程 一、下载和解决依赖问题 1.下载 去这里看看&#xff0c;填写相关的显卡参数&#xff0c;选择 linux 版本&#xff0c;然后开始下载。 https://www.nvidia.cn/drivers/lookup/ 进入这个选择界面&#xff1a; 开始下载&…

04-指向指针的指针

int main(int argc, const char* argv[]) {int x 5;int* p &x;*p 6;int** q &p;int*** r &q;printf("%d\n", *p);//指向p中的值 6printf("%d\n", *q);// 指向指针p的地址 printf("%d\n", *(*q));//指向p中地址中的值 6print…

3.1K Star,这款开源在线视频下载神器绝了,速度达 30M/S

Hi&#xff0c;骚年&#xff0c;我是大 G&#xff0c;公众号「GitHub 指北」会推荐 GitHub 上有趣有用的项目&#xff0c;一分钟 get 一个优秀的开源项目&#xff0c;挖掘开源的价值&#xff0c;欢迎关注。 在如今的数字时代&#xff0c;无论是个人用户还是企业&#xff0c;都…

学科竞赛管理平台:SpringBoot框架深度开发

摘 要 随着国家教育体制的改革&#xff0c;全国各地举办的竞赛活动数目也是逐年增加&#xff0c;面对如此大的数目的竞赛信息&#xff0c;传统竞赛管理方式已经无法满足需求&#xff0c;为了提高效率&#xff0c;竞赛管理系统应运而生。 本学科竞赛管理系统以实际运用为开发背景…

odoo中查找模型以及继承模型的全部字段

快捷键alt k呼出命令菜单&#xff0c;或者直接按alt h呼出界面如下&#xff1a; 输入模型 按模型的名称搜索 视图、字段在里面都能找到了 或者点击这里

TongESB7, TongGW, admin账号密码重置方式

停止控制台 修改系统库 identities 表 configuration字段中的password 重启manage

常用的英文文献数据库和资源平台

在学术研究中&#xff0c;获取和引用权威的英文文献资源是非常重要的。以下列举了几大最常用的英文文献数据库和资源平台&#xff0c;这些平台广泛收录了各类学术论文、期刊、会议论文、书籍等文献资料&#xff0c;是研究人员和学生常用的工具&#xff1a; 1. Google Scholar …

C嘎嘎入门篇:类和对象(2)

前言&#xff1a; 上一篇小编讲了类和对象&#xff08;1&#xff09;&#xff0c;当然&#xff0c;在看这篇文章之前&#xff0c;读者朋友们一定要掌握好前面的基础内容&#xff0c;因为这篇和前面息息相关&#xff0c;废话不多说&#xff0c;下面小编就加快步伐&#xff0c;开…