【深度强化学习】(8) iPPO 模型解析,附Pytorch完整代码

news2024/12/29 8:21:18

大家好,今天和各位分享一下多智能体深度强化学习算法 ippo,并基于 gym 环境完成一个小案例。完整代码可以从我的 GitHub 中获得:https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model


1. 算法原理

多智能体的情形相比于单智能体更加复杂,因为每个智能体在和环境交互的同时也在和其他智能体进行直接或者间接的交互。因此,多智能体强化学习要比单智能体更困难,其难点主要体现在以下几点:

(1)由于多个智能体在环境中进行实时动态交互,并且每个智能体在不断学习并更新自身策略,因此在每个智能体的视角下,环境是非稳态的,即对于一个智能体而言,即使在相同的状态下采取相同的动作,得到的状态转移和奖励信号的分布可能在不断改变;

(2)多个智能体的训练可能是多目标的,不同智能体需要最大化自己的利益;

(3)训练评估的复杂度会增加,可能需要大规模分布式训练来提高效率。

iPPO 算法的模型部分和 PPO 类似,可以看我下面这篇博文:

https://blog.csdn.net/dgvv4/article/details/129496576?spm=1001.2014.3001.5501


IPPO(Independent PPO)是一种完全去中心化的算法,此类算法被称为独立学习。由于对于每个智能体使用单智能体算法 PPO 进行训练,所因此这个算法叫作独立 PPO 算法。

这里使用的 PPO 算法版本为 PPO-截断,其算法流程如下:


2. 代码实现

代码和 ppo 离散模型基本相同

# 和PPO离散模型基本一致
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np

# ----------------------------------------- #
# 策略网络--actor
# ----------------------------------------- #

class PolicyNet(nn.Module):  # 输入当前状态,输出动作的概率分布
    def __init__(self, n_states, n_hiddens, n_actions):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(n_states, n_hiddens)
        self.fc2 = nn.Linear(n_hiddens, n_hiddens)
        self.fc3 = nn.Linear(n_hiddens, n_actions)
    def forward(self, x):  # [b,n_states]
        x = self.fc1(x)  # [b,n_states]-->[b,n_hiddens]
        x = F.relu(x)
        x = self.fc2(x)  # [b,n_hiddens]-->[b,n_hiddens]
        x = F.relu(x)
        x = self.fc3(x)  # [b,n_hiddens]-->[b,n_actions]
        x = F.softmax(x, dim=1)  # 每种动作选择的概率
        return x

# ----------------------------------------- #
# 价值网络--critic
# ----------------------------------------- #

class ValueNet(nn.Module):  # 评价当前状态的价值
    def __init__(self, n_states, n_hiddens):
        super(ValueNet, self).__init__()
        self.fc1 = nn.Linear(n_states, n_hiddens)
        self.fc2 = nn.Linear(n_hiddens, n_hiddens)
        self.fc3 = nn.Linear(n_hiddens, 1)
    def forward(self, x):  # [b,n_states]
        x = self.fc1(x)  # [b,n_states]-->[b,n_hiddens]
        x = F.relu(x)
        x = self.fc2(x)  # [b,n_hiddens]-->[b,n_hiddens]
        x = F.relu(x)
        x = self.fc3(x)  # [b,n_hiddens]-->[b,1]
        return x

# ----------------------------------------- #
# 模型构建
# ----------------------------------------- #

class PPO:
    def __init__(self, n_states, n_hiddens, n_actions,
                 actor_lr, critic_lr, 
                 lmbda, eps, gamma, device):
        # 属性分配
        self.n_hiddens = n_hiddens
        self.actor_lr = actor_lr  # 策略网络的学习率
        self.critic_lr = critic_lr  # 价值网络的学习率
        self.lmbda = lmbda  # 优势函数的缩放因子
        self.eps = eps  # ppo截断范围缩放因子
        self.gamma = gamma  # 折扣因子
        self.device = device
        # 网络实例化
        self.actor = PolicyNet(n_states, n_hiddens, n_actions).to(device)  # 策略网络
        self.critic = ValueNet(n_states, n_hiddens).to(device)  # 价值网络
        # 优化器
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)
    
    # 动作选择
    def take_action(self, state):  # [n_states]
        state = torch.tensor([state], dtype=torch.float).to(self.device)  # [1,n_states]
        probs = self.actor(state)  # 当前状态的动作概率 [b,n_actions]
        action_dist = torch.distributions.Categorical(probs)  # 构造概率分布
        action = action_dist.sample().item()  # 从概率分布中随机取样 int
        return action
    
    # 训练
    def update(self, transition_dict):
        # 取出数据集
        states = torch.tensor(transition_dict['states'], dtype=torch.float).to(self.device)  # [b,n_states]
        actions = torch.tensor(transition_dict['actions']).view(-1,1).to(self.device)  # [b,1]
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float).to(self.device)  # [b,n_states]
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1,1).to(self.device)  # [b,1]
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1,1).to(self.device)  # [b,1]

        # 价值网络
        next_state_value = self.critic(next_states)  # 下一时刻的state_value  [b,1]
        td_target = rewards + self.gamma * next_state_value * (1-dones)  # 目标--当前时刻的state_value  [b,1]
        td_value = self.critic(states)  # 预测--当前时刻的state_value  [b,1]
        td_delta = td_value - td_target  # 时序差分  # [b,1]

        # 计算GAE优势函数,当前状态下某动作相对于平均的优势
        advantage = 0  # 累计一个序列上的优势函数
        advantage_list = []  # 存放每个时序的优势函数值
        td_delta = td_delta.cpu().detach().numpy()  # gpu-->numpy
        for delta in td_delta[::-1]:  # 逆序取出时序差分值
            advantage = self.gamma * self.lmbda * advantage + delta
            advantage_list.append(advantage)  # 保存每个时刻的优势函数
        advantage_list.reverse()  # 正序
        advantage = torch.tensor(advantage_list, dtype=torch.float).to(self.device)

        # 计算当前策略下状态s的行为概率 / 在之前策略下状态s的行为概率
        old_log_probs = torch.log(self.actor(states).gather(1,actions))  # [b,1]
        log_probs = torch.log(self.actor(states).gather(1,actions))
        ratio = log_probs / old_log_probs

        # clip截断
        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, 1-self.eps, 1+self.eps) * advantage
        
        # 损失计算
        actor_loss = torch.mean(-torch.min(surr1, surr2))  # clip截断
        critic_loss = torch.mean(F.mse_loss(td_value, td_target))  # 
        # 梯度更新
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()

3. 案例演示

ma-gym 库中的 Combat 环境。Combat 是一个在二维的格子世界上进行的两个队伍的对战模拟游戏,每个智能体的动作集合为:向四周移动格,攻击周围格范围内其他敌对智能体,或者不采取任何行动。起初每个智能体有 3 点生命值,如果智能体在敌人的攻击范围内被攻击到了,则会扣 1 生命值,生命值掉为 0 则死亡,最后存活的队伍获胜。每个智能体的攻击有一轮的冷却时间。

IPPO 代码实践的最主要部分。值得注意的是,在训练时使用了参数共享(parameter sharing)的技巧,即对于所有智能体使用同一套策略参数,这样做的好处是能够使得模型训练数据更多,同时训练更稳定。能够这样做的前提是,两个智能体是同质的(homogeneous),即它们的状态空间和动作空间是完全一致的,并且它们的优化目标也完全一致。感兴趣的读者也可以自行实现非参数共享版本的 IPPO,此时每个智能体就是一个独立的 PPO 的实例。

import numpy as np
import matplotlib.pyplot as plt
import torch
from ma_gym.envs.combat.combat import Combat
from RL_brain import PPO
import time

# ----------------------------------------- #
# 参数设置
# ----------------------------------------- #

n_hiddens = 64  # 隐含层数量
actor_lr = 3e-4
critic_lr = 1e-3
gamma = 0.9
lmbda = 0.97
eps = 0.2
device = torch.device('cuda') if torch.cuda.is_available() \
                            else torch.device('cpu')
num_episodes = 10  # 回合数
team_size = 2  # 智能体数量
grid_size = (15, 15)

# ----------------------------------------- #
# 环境设置--onpolicy
# ----------------------------------------- #

# 创建Combat环境,格子世界的大小为15x15,己方智能体和敌方智能体数量都为2
env = Combat(grid_shape=grid_size, n_agents=team_size, n_opponents=team_size)
n_states = env.observation_space[0].shape[0]  # 状态数
n_actions = env.action_space[0].n  # 动作数

# 两个智能体共享同一个策略
agent = PPO(n_states = n_states,
            n_hiddens = n_hiddens,
            n_actions = n_actions,
            actor_lr = actor_lr,
            critic_lr = critic_lr,
            lmbda = lmbda,
            eps = eps,
            gamma = gamma,
            device = device,
            )

# ----------------------------------------- #
# 模型训练
# ----------------------------------------- #

for i in range(num_episodes):
    # 每回合开始前初始化两支队伍的数据集
    transition_dict_1 = {
        'states': [],
        'actions': [],
        'next_states': [],
        'rewards': [],
        'dones': [],
    }
    transition_dict_2 = {
        'states': [],
        'actions': [],
        'next_states': [],
        'rewards': [],
        'dones': [],
    }

    s = env.reset()  # 状态初始化
    terminal = False  # 结束标记

    while not terminal:

        env.render()

        # 动作选择
        a_1 = agent.take_action(s[0])
        a_2 = agent.take_action(s[1])

        # 环境更新
        next_s, r, done, info = env.step([a_1, a_2])

        # 构造数据集
        transition_dict_1['states'].append(s[0])
        transition_dict_1['actions'].append(a_1)
        transition_dict_1['next_states'].append(next_s[0])
        transition_dict_1['dones'].append(False)
        transition_dict_1['rewards'].append(r[0])

        transition_dict_2['states'].append(s[1])
        transition_dict_2['actions'].append(a_2)
        transition_dict_2['next_states'].append(next_s[1])
        transition_dict_2['dones'].append(False)
        transition_dict_2['rewards'].append(r[1])

        s = next_s  # 状态更新
        terminal = all(done)  # 判断当前回合是否都为True,是返回True,不是返回False

        time.sleep(0.1)
    
    print('epoch:', i)

    # 回合训练
    agent.update(transition_dict_1)
    agent.update(transition_dict_2)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/417893.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

SpringCloud GateWay与Nacos使用

网关就相当于一个内网与外网的出入口,起着 安全、验证的功能,如果没有网关,那么如果需要实现验证的功能,除非 SpringCloud GateWay 作为微服务的网关,起着如下作用 ① 作为所有API接口服务请求的接入点 ② 作为所有后端业务服务…

SpringBoot 整合 RabbitMQ (四十一)

二八佳人体似酥,腰间仗剑斩愚夫。虽然不见人头落,暗里教君骨髓枯。 上一章简单介绍了SpringBoot 实现 Web 版本控制 (四十),如果没有看过,请观看上一章 关于消息中间件 RabbitMQ, 可以看老蝴蝶之前的文章: https://blog.csdn.net/yjltx1234csdn/categor…

还不懂如何与AI高效交流?保姆级且全面的chatGPT提示词工程教程来啦!(一)基础篇

还不懂如何与chatGPT高效交流?保姆级且全面的chatGPT提示词工程教程来啦!(一)基础篇 文章目录还不懂如何与chatGPT高效交流?保姆级且全面的chatGPT提示词工程教程来啦!(一)基础篇一&…

CDH6.3.2大数据集群生产环境安装(七)之PHOENIX组件安装

添加phoenix组件 27.1. 准备安装资源包 27.2. 拷贝资源包到相应位置 拷贝PHOENIX-1.0.jar到/opt/cloudera/csd/ 拷贝PHOENIX-5.0.0-cdh6.2.0.p0.1308267-el7.parcel.sha、PHOENIX-5.0.0-cdh6.2.0.p0.1308267-el7.parcel到/opt/cloudera/parcel-repo 27.3. 进入cm页面进行分发、…

【AIGC】9、BLIP-2 | 使用 Q-Former 连接冻结的图像和语言模型 实现高效图文预训练

文章目录一、背景二、方法2.1 模型结构2.2 从 frozen image encoder 中自主学习 Vision-Language Representation2.3 使用 Frozen LLM 来自主学习 Vision-to-Language 生成2.4 Model pre-training三、效果四、局限性论文:BLIP-2: Bootstrapping Language-Image Pre-…

unity 序列化那些事,支持Dictionary序列化

目录 一、普通类型和UnityEngine空间类型序列化 二、数组、list的序列化 三、自定义类的序列化支持 四、自定义asset 五、在inspector面板中支持Dictionary序列化 1、在MonoBehaviour中实现Dictionary序列化 2、自定义property,让其在inpsector能够显示 3、Mo…

【从零开始学Skynet】实战篇《球球大作战》(七):gateway代码设计(下)

1、确认登录接口 在完成了登录流程后,login会通知gateway(第⑧阶段),让它把客户端连接和新agent(第⑨阶段)关联起来。 sure_agent代码如下所示: s.resp.sure_agent function(source, fd, play…

[Gitops--1]GitOps环境准备

GitOps环境准备 1. 主机规划 序号主机名主机ip主机功能软件1dev192.168.31.1开发者 项目代码 apidemogit,golang,goland2gitlab192.168.31.14代码仓库,CI操作git-lab,git,golang,docker,gitlab-runner3harbor192.168.31.104管理和存储镜像docker,docker-compose,harbor4k8s-m…

基础排序算法【计数排序】非比较排序

基础排序算法【计数排序】非比较排序⏰【计数排序】🕐计数🕦排序🕓测试⏰总结:⏰【计数排序】 计数排序又称为鸽巢原理,是对哈希直接定址法的变形应用 > 基本思路: 1.统计数据出现的次数 2.根据统计的结…

并行分布式计算 并行算法与并行计算模型

文章目录并行分布式计算 并行算法与并行计算模型基础知识定义与描述复杂性度量同步和通讯并行计算模型PRAM 模型异步 PRAM 模型 (APRAM)BSP 模型LogP 模型层次存储模型分层并行计算模型并行分布式计算 并行算法与并行计算模型 基础知识 定义与描述 并…

15个最适合初创公司创始人使用的生产力工具

创业是一段激动人心且收获颇丰的旅程,同时也伴随着一些挑战。创始人往往要面对长时间的工作、紧迫的期限和大量的压力时刻。因此,初创公司创始人必须最大限度地利用他们的时间并利用他们可用的生产力工具——不仅是为了发展他们的业务,而且是…

Cron表达式简单介绍 + Springboot定时任务的应用

前言 表达式是一个字符串,主要分成6或7个域,但至少需要6个域组成,且每个域之间以空格符隔开。 以7个域组成的,从右往左是【年 星期 月份 日期 小时 分钟 秒钟】 秒 分 时 日 月 星期 年 以6个域组成的,从右往左是【星…

【精华】表格识别技术-MI

表格识别是指将图片中的表格结构和文字信息识别成计算机可以理解的数据格式,在办公、商务、教育等场景中有着广泛的实用价值,也一直是文档分析研究中的热点问题。围绕这个问题,我们研发了一套表格识别算法,该算法高效准确地提取图…

RabbitMq 的消息可靠性问题(二)---MQ的消息丢失和consumer消费问题

前言 RabbitMq 消息可靠性问题(一) — publisher发送时丢失 前面我们从publisher的方向出发解决了发送时丢失的问题,那么我们在发送消息到exchange, 再由exchange转存到queue的过程中。如果MQ宕机了,那么我们的消息是如何确保可靠性的呢?当消…

SQL的函数

文章目录一、SQL MIN() Function二、SQL SUM() 函数三、SQL GROUP BY 语句四、SQL HAVING 子句五、SQL EXISTS 运算符六、SQL UCASE() 函数总结一、SQL MIN() Function MIN() 函数返回指定列的最小值。 SQL MIN() 语法 SELECT MIN(column_name) FROM table_name;演示数据库 …

Numba witch makes Python code fast

一. 前言:numba,让python速度提升百倍 python由于它动态解释性语言的特性,跑起代码来相比java、c要慢很多,尤其在做科学计算的时候,十亿百亿级别的运算,让python的这种劣势更加凸显。 办法永远比困难多&a…

ASP.NET Core MVC 从入门到精通之接化发(二)

随着技术的发展,ASP.NET Core MVC也推出了好长时间,经过不断的版本更新迭代,已经越来越完善,本系列文章主要讲解ASP.NET Core MVC开发B/S系统过程中所涉及到的相关内容,适用于初学者,在校毕业生&#xff0c…

4.13实验 加测试题目

今天是个好日子,要搞栈的实验 没啥就是链栈和顺序栈 和出栈入栈,强大都是从最基本开始的 来和我一起写写吧 //顺序栈 typedef struct node{int *base;int *top;int sizer; }shed;//链栈 typedef struct Node{ int data; struct Node* next; }*stact,link; //顺序栈的初始化…

《绝对坦率》速读笔记

文章目录书籍信息概览(第一部分 一种新的管理哲学)建立坦率的关系给予并鼓励指导了解团队中每个人的动机协同创造成果(第二部分 工具和技巧)关系指导团队结果书籍信息 书名:《绝对坦率:一种新的管理哲学》…

北邮22信通:(12)二叉树的遍历书上代码完整版

北邮22信通一枚~ 跟随课程进度每周更新数据结构与算法的代码和文章 持续关注作者 解锁更多邮苑信通专属代码~ 上一篇文章: 下一篇文章: 目录 一.储存最简单数据类型的二叉树 代码部分: 代码效果: 运行结果&#xff1a…