使用PyTorch实现简单的AlphaZero的算法(3):神经网络架构和自学习

news2025/1/12 19:39:13

神经网络架构和训练、自学习、棋盘对称性、Playout Cap Randomization,结果可视化

从我们之前的文章中,介绍了蒙特卡洛树搜索 (MCTS) 的工作原理以及如何使用它来获得给定棋盘状态的输出策略。我们也理解神经网络在 MCTS 中的两个主要作用;通过神经网络的策略输出来指导探索,并使用其价值输出代替传统的蒙特卡洛rollout算法。

在这一部分中,我们将从这个神经网络的架构开始,检查它的不同层、输入和输出。然后了解如何使用自我对弈训练网络和研究用于训练神经网络的损失函数。本文还将仔细研究训练的细节,包括特定于 Chain Reaction 游戏的数据增强技术和称为 Playout Cap Randomization 的方法提高训练效率。最后我们将可视化查看我们的工作成果。

神经网络架构

神经网络模型的输入形状为MxNx7,其中M和N分别为Chain Reaction游戏的行数和列数。图形中的数字“7”表示有7个通道,每个通道以二进制数据的形式存储的某些特定信息,如下面所示:

 Description of the encoded state
 
 Size of the state: M*N*7
 
 channel 1 : stores the MxN map where red orbs are 1 in number
 
 channel 2 : stores the MxN map where red orbs are 2 in number
 
 channel 3 : stores the MxN map where red orbs are 3 in number
 
 channel 4 : stores the MxN map where green orbs are 1 in number
 
 channel 5 : stores the MxN map where green orbs are 2 in number
 
 channel 6 : stores the MxN map where green orbs are 3 in number
 
 channel 7 : MxN map of ones if it is red player's turn otherwise 
 a map of zeroes

下面的图片展示了神经网络的架构。

我们的神经网络结构是一个resnet结构-它有conv2d, batchnorm2d和relu层,dropout层和两个任务头。

输出值的头有一个tanh激活函数,产生一个介于-1和+1之间的数字。策略头有一个softmax函数,它帮助我们得到板子上所有动作的概率分布。

基本块(small block)如上图所示与resnet类似,我们会将这些基本块进行组合。

输入通过的第一个块由conv2d、batchnorm2d和relu层组成

由基本块(conv2d和batchnorm2d层)和relu层组成我们上图所示的残差块(resnet)

五个resnet块组成了我们神经网络的中间块

dropout块接收前一个块的输出,其中的linear层起到控制维度数的作用

我们的网络结构中使用了两个dropout块

值头输出的是-1,1之间的动作价值(value)

策略头输出被用作棋盘上所有动作的概率分布(0,1之间)

下图显示了使用PyTorch在Python中实现该体系结构的代码。

完成了我们模型架构,下面就要看下如何进行训练了

自我对局

上图显示了在游戏中如何进行任何单个操作的流程。在自我对局框架中,我们有两个玩家(都是AI),红色和绿色。每个玩家使用上述步骤进行操作。如果红色赢了游戏。对于所有的红色移动,目标值+1,对于所有的绿色移动,目标值是-1。

我们获得策略目标将是使用蒙特卡洛搜索树获得的策略。

损失函数

因为我们有2个任务头,所以损失函数需要包含自价值损失和策略损失

AlphaZero的损失函数如下:

  • 价值损失:在游戏结束时使用价值分配获得的预测值和目标价值之间的均方损失。
  • 策略损失:在预测的策略和从MCTS演习中获得的策略目标之间计算交叉熵损失。

在AlphaZero中训练神经网络的损失就是这两个损失的总和。我们称之为“AlphaLoss”。

数据增强

为了提高训练效率,我们可以这样操作:如果我们知道一个棋盘状态的正确策略,那么我们就知道通过旋转、翻转或转置棋盘矩阵获得的其他七个棋盘状态的正确策略,这就是我们所说的棋盘的对称性。

通过翻转、旋转和换位可以产生7种以上的棋盘状态。对于所有这些状态,我们可以很容易地获正确的政策。

为了在代码中实现这一点,我们需要一个其中存储了棋盘状态和策略目标的缓存区,在游戏结束获得实际奖励值时,目标值分配给临时缓冲区中的每个元素。

下图是构造这个缓存的代码

Playout Cap Randomization

我们还可以引入了Playout Cap Randomization,因为它有助于提高培训效率。

AlphaZero的自我游戏训练过程,它得到的唯一真正奖励是在游戏结束时,所以获得的奖励是非常少的,而价值头专注于预测这个奖励,如果我们想改善价值训练,就需要增加AlphaZero的游戏的次数。

如果我们想提高策略训练,我们则可以关注更多的蒙特卡洛回放。

这里我们可以只增加一些随机选择的动作而不是增加游戏中所有动作的使用次数,只使用一些特定的动作的数据进行训练。在其他动作中,我们可以减少其选择次数。这种技术被称为Playout Cap Randomization。

结果展示

最后让我们看看我们的训练成果

对阵一个随机的代理

随机的代理没有任何策略,只是在棋盘上随机的进行可用的操作。以下是在3x3, 4x4和5x5棋盘上对随机代理的胜率。

可以看到对于一个3 x 3的棋盘,即使没有MCTS,在80个回合后至少可以达到75%的胜率

对于一个4 x 4的棋盘,训练在500个回合后就会饱和,然后就会变成振荡,但在1300回合附近,没有MCTS的代理的胜率超过80%

对于一块5 x 5的棋盘来说,训练在1000个周期左右就饱和了

可视化

每一场比赛都包括棋盘上的一系列动作。对于一块5x5的棋盘,第一步有25种可能。随着训练的进行,神经网络的值头输出不断提高,从而改进了蒙特卡罗搜索。以下是这些动作的可视化。可视化是针对一个5 x 5的棋盘,所以有25种可能性。这25种可能被映射到一个圆(在开始)或一个弧(后面经过训练)。

在1000次蒙特卡洛演练中使用未经训练的值网络所采取的行动。(5 × 5-> 25动作)。25个动作被映射到圆/圆弧中的角度。搜索最多只能到达4步的深度。

25个动作被映射到圆/圆弧中的角度。由于价值网络启发式的存在,搜索甚至深入到20步。

未来的发展方向

Chain Reaction的游戏有一个人类精心设计的启发式策略[2]。训练一个简单AlphaZero代理并试着让它与这样的策略竞争是很有趣的。

有一种称为hidden queen chess”/ “secret queen chess”的国际象棋变体,其中每个玩家在游戏开始时选择他们的一个棋子作为皇后,并且该选择不会向对手透露。但是 AlphaZero 适用于完美信息博弈和实施训练代理在信息不完善的状态下策论的论文会很有趣 [3]。

如果能够超越离散动作空间[4]将是有趣的。连续动作空间将在机器人或自动驾驶汽车应用中更为常见。[4]接受beta分布并学习它的参数。然后使用这个分布的一个缩放版本来近似有界连续空间。

我们有一个在3x3的Chain Reaction棋盘上训练一个效果非常好的代理。如果我们能将这些知识转移到4x4大小或其他大小的棋盘上,那就太好了。这项工作的重点也是一个方向[5]。如果没有这样的传输机制,在更大的棋盘上上进行训练在计算上是非常昂贵的,例如:15x15或20x20棋盘。

本文代码:https://github.com/BentouAI/AlphaZero-Chain-Reaction

引用参考

  1. Wu, D.J. (2020) Accelerating self-play learning in go, arXiv.org. Available at: https://arxiv.org/abs/1902.10565v5.
  2. Chain Reaction (Game). Brilliant.org. Retrieved 17:59, November 27, 2022, from https://brilliant.org/wiki/chain-reaction-game/
  3. https://www.deepmind.com/blog/alphastar-grandmaster-level-in-starcraft-ii-using-multi-agent-reinforcement-learning
  4. Moerland, T.M., Broekens, J., Plaat, A. and Jonker, C.M., 2018. A0c: Alpha zero in continuous action space. arXiv preprint arXiv:1805.09613.
  5. Ben-Assayag, S. and El-Yaniv, R., 2021. Train on small, play the large: Scaling up board games with alphazero and gnn. arXiv preprint arXiv:2107.08387.

https://avoid.overfit.cn/post/d2e6352cf0104473ba896d198f8277bc

作者:Bentou

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/45407.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

程序员真的有必要把GC算法好好过一遍,因为它是进大厂必备的

GC算法概述 最早的GC算法可以追溯到20世纪60年代,但到目前为止,GC的基本算法没有太多的创新,可以分为复制算法(Copying GC)、标记清除(MarkSweep GC)和标记压缩(Mark-Compact GC&am…

pte学习_SQL注入1

一、phpstudy使用及mysql数据库基础 ①进入mysql安装路径的/bin中打开cmd mysql -u root -p //登录MYSQL数据库 show databases; // 查看数据库 drop database mysql; //删除mysql数据库 create database pte; //创建pte数据库 use pte; //进入数据库 show tables; //查…

如果把网络原理倒过来看,从无到有,一切都清晰了(上)

长歌吟松风,曲尽河星稀。 前言 我发现绝大数人和我一样对网络原理充满困惑,不是因为不好理解,而是它往往都是直接告诉你它是什么,但它并不告诉你为什么要这样。 而我让离网络最近的一次距离是在一个偶然停电的深夜,周…

实现响应式布局有几种方法

目录 🔽 什么是响应式布局 响应式与自适应区别 🔽 响应式布局方法总结 响应式布局方法一:CSS3媒体查询 响应式布局方法二:百分比% 响应式布局方法三:vw/vh 响应式布局方法四:rem 响应式布局方法五&…

IPv6进阶:OSPFv3 路由汇总实验配置

实验拓扑 实验需求 R1、R2完成接口IPv6地址的配置;R1、R2按图示运行OSPF。R2的三个Loopback接口并不直接激活OSPFv3,而是以重发布的形式注入;在R1、R2上分别执行OSPF路由汇总,使得双方的路由表中关于对方的Loopback只学习到一条汇…

CANoe-vTESTstudio之State Diagram编辑器(入门介绍)

1. 什么是State Diagram编辑器 Test Diagram编辑器是使用具有各种功能的图形元素对测试用例的测试步骤的测试顺序进行建模。而State Diagram Editor,状态图表编辑器,是针对被测系统基于状态的系统行为,在状态图表编辑器中以图形方式建模,从而可以自动生成要测试的SUT(sys…

代码随想录算法训练营第四十八天| LeetCode198. 打家劫舍、LeetCode213. 打家劫舍 II、LeetCode337. 打家劫舍 III

一、LeetCode198. 打家劫舍 1:题目描述(198. 打家劫舍) 你是一个专业的小偷,计划偷窃沿街的房屋。每间房内都藏有一定的现金,影响你偷窃的唯一制约因素就是相邻的房屋装有相互连通的防盗系统,如果两间相邻的…

移动设备软件开发-广播机制

广播机制 1.广播机制概述 1.1生活中的广播机制 1.显示生活中的广播就比如说村里的大喇叭,车上的收音机接收的广播FM广播,学校里的校园广播都是常见的广播,安卓中的广播和生活中的广播是十分类似的。 1.2广播特点 发送者 多种广播方式实…

群晖外网访问终极解决方法:IPV6+阿里云ddns+ddnsto

写在前面的话 受够了群晖的quickconnet的小水管了,急需一个新的解决方法,这是后发现移动没有公网IP,只有ipv6(公网的),时候有小伙伴要问,要是没有ipv6就没办法访问群晖了吗? 不&…

吉时利KEITHELY2612B源表技术参数

作为2600B系列源表SMU系列产品的一部分,2612B源表SMU是全新改良版双通道SMU,具有紧密集成的4象限设计,能同步源和测量电压/电流以提高研发到自动生产测试等应用的生产率。除保留了2612A的全部产品特点外,2612B还具有6位半分辨率、…

Spring基础篇:高级注解编程

文章内容来自于B站孙哥说Spring第一章:Configuration一:配置Bean替换XML细节二:应用配置Bean工厂对象三:配置Bean细节分析1:整合Logback三:Component第二章:Bean一:Bean的使用1&…

Prometheus+Grafana部署

一 、Prometheus 源码安装和启动配置 普罗米修斯下载网址:https://prometheus.io/download/ 监控集成器下载地址:http://www.coderdocument.com/docs/prometheus/v2.14/instrumenting/exporters_and_integrations.html 1.实验环境 IP角色系统172.16.1…

理解浅拷贝和深拷贝以及实现方法

一、数据类型 数据分为基本数据类型(String, Number, Boolean, Null, Undefined,Symbol)和引用数据类型Object,包含(function,Array,Date)。 1、基本数据类型的特点:直接存储在栈内存中的数据 …

品牌投资与形象全面升级 | 快来认识全新的 Go 旅城通票

近日,Go 旅城通票(Go City)品牌全面升级,旨在提高旅游爱好者对品牌的认知。从新冠疫情大流行中阴霾中走出来的 Go 旅城通票复苏势头强劲,专注于技术提升,使命是协助旅游爱好者无论到世界各地的哪一个城市畅…

在线分析网站日志软件-免费分析网站蜘蛛的软件

搜索引擎蜘蛛的作用是什么?我们网站上的内容如果要想被搜索引擎收录并且给予排名,就必须要经过搜索引擎蜘蛛的爬取并且建立索引。所以让搜索引擎蜘蛛更好的了解我们的网站是很重要的一步!搜索引擎蜘蛛在爬取某个网站,是通过网站的…

浅谈虚拟地址转换成物理地址(值得收藏)

这里,我们讲解一下Linux是如何将虚拟地址转换成物理地址的 一、地址转换 在进程中,我们不直接对物理地址进行操作,CPU在运行时,指定的地址要经过MMU转换后才能访问到真正的物理内存。 地址转换的过程分为两部分,分段…

Linux systemctl 详解自定义 systemd unit

Linux systemctl 详解&自定义 systemd unit systemctl 序 大家都知道,我们安装了很多服务之后,使用 systemctl 来管理这些服务,比如开启、重启、关闭等等,所以 systemctl 是一个 systemd 系统。centos 使用 systemctl 来代…

9.8 段错误,虚拟内存,内存映射 CSAPP

相信写代码的或多或少都会遇到段错误,segmentation fault. 今天终于看到这里面的底层原理 参考: https://greenhathg.github.io/2022/05/18/CMU213-CSAPP-Virtual-Memory-Systems/18-Virtual-Memory-SystemsSimple memory system exampleAddress Trans…

(转)CSS结合伪类实现icon

老规矩,还是先说说业务场景:有一个图片列表,可以添加、删除和更改,其中呢删除时设计给的设计稿时悬浮(hover)在图片上时显示删除的图标,所以就有了这个用before实现icon的场景 进入正文&#xf…

嵌入式系统开发笔记108:IO的使用方法与面向对象程序设计

文章目录前言一、IO引脚的基本概念二、映射层的设置1、映射层是原理图的直译层2、IO引脚的设置在hal.h 和 hal.cpp文件中完成(1)在hal.h中进行类定义(2)在hal.cpp中完成引脚映射三、面向对象程序设计思想1、程序设计分类2、举例3、…