强化学习:实现了基于蒙特卡洛树和策略价值网络的深度强化学习五子棋(含码源)

news2024/11/18 22:50:01

在这里插入图片描述
【强化学习原理+项目专栏】必看系列:单智能体、多智能体算法原理+项目实战、相关技巧(调参、画图等、趣味项目实现、学术应用项目实现

在这里插入图片描述
专栏详细介绍:【强化学习原理+项目专栏】必看系列:单智能体、多智能体算法原理+项目实战、相关技巧(调参、画图等、趣味项目实现、学术应用项目实现

对于深度强化学习这块规划为:

  • 基础单智能算法教学(gym环境为主)
  • 主流多智能算法教学(gym环境为主)
    • 主流算法:DDPG、DQN、TD3、SAC、PPO、RainbowDQN、QLearning、A2C等算法项目实战
  • 一些趣味项目(超级玛丽、下五子棋、斗地主、各种游戏上应用)
  • 单智能多智能题实战(论文复现偏业务如:无人机优化调度、电力资源调度等项目应用)

本专栏主要方便入门同学快速掌握强化学习单智能体|多智能体算法原理+项目实战。后续会持续把深度学习涉及知识原理分析给大家,让大家在项目实操的同时也能知识储备,知其然、知其所以然、知何由以知其所以然。

声明:部分项目为网络经典项目方便大家快速学习,后续会不断增添实战环节(比赛、论文、现实应用等)

  • 专栏订阅(个性化选择):

    • 强化学习原理+项目专栏大合集-《推荐订阅☆☆☆☆☆》

    • 强化学习单智能体算法原理+项目实战《推荐订阅☆☆☆☆》

    • 强化学习多智能体原理+项目实战《推荐订阅☆☆☆☆☆》

    • 强化学习相关技巧(调参、画图等《推荐订阅☆☆☆》)

    • tensorflow_gym-强化学习:免费《推荐订阅☆☆☆☆》

    • 强化学习从基础到进阶-案例与实践:免费《推荐订阅☆☆☆☆☆》

实现了基于蒙特卡洛树和策略价值网络的深度强化学习五子棋(含码源)

  • 特点

    • 自我对弈
    • 详细注释
    • 流程简单
  • 代码结构

    • net:策略价值网络实现
    • mcts:蒙特卡洛树实现
    • server:前端界面代码
    • legacy:废弃代码
    • docs:其他文件
    • utils:工具代码
    • network.py:移植过来的网络结构代码
    • model_5400.pkl:移植过来的网络训练权重
    • train_agent.py:训练脚本
    • web_server.py:对弈服务脚本
    • web_server_demo.py:对弈服务脚本(移植网络)

1.1 流程

1.2策略价值网络

采用了类似ResNet的结构,加入了SPP模块。

(目前,由于训练太耗时间了,连续跑了三个多星期,才跑了2000多个自我对弈的棋谱,经过实验,这个策略网络的表现,目前还是不行,可能育有还没有训练充分)

同时移植了另一个开源的策略网络以及其训练权重(network.py、model_5400.pkl),用于进行仿真演示效果。

1.3 训练

根据注释调整train_agent.py文件,并运行该脚本

部分代码展示:


if __name__ == '__main__':

    conf = LinXiaoNetConfig()
    conf.set_cuda(True)
    conf.set_input_shape(8, 8)
    conf.set_train_info(5, 16, 1e-2)
    conf.set_checkpoint_config(5, 'checkpoints/v2train')
    conf.set_num_worker(0)
    conf.set_log('log/v2train.log')
    # conf.set_pretrained_path('checkpoints/v2m4000/epoch_15')

    init_logger(conf.log_file)
    logger()(conf)

    device = 'cuda' if conf.use_cuda else 'cpu'

    # 创建策略网络
    model = LinXiaoNet(3)
    model.to(device)

    loss_func = AlphaLoss()
    loss_func.to(device)

    optimizer = torch.optim.SGD(model.parameters(), conf.init_lr, 0.9, weight_decay=5e-4)
    lr_schedule = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.95)

    # initial config tree
    tree = MonteTree(model, device, chess_size=conf.input_shape[0], simulate_count=500)
    data_cache = TrainDataCache(num_worker=conf.num_worker)

    ep_num = 0
    chess_num = 0
    # config train interval
    train_every_chess = 18

    # 加载检查点
    if conf.pretrain_path is not None:
        model_data, optimizer_data, lr_schedule_data, data_cache, ep_num, chess_num = load_checkpoint(conf.pretrain_path)
        model.load_state_dict(model_data)
        optimizer.load_state_dict(optimizer_data)
        lr_schedule.load_state_dict(lr_schedule_data)
        logger()('successfully load pretrained : {}'.format(conf.pretrain_path))

    while True:
        logger()(f'self chess game no.{chess_num+1} start.')
        # 进行一次自我对弈,获取对弈记录
        chess_record = tree.self_game()
        logger()(f'self chess game no.{chess_num+1} end.')
        # 根据对弈记录生成训练数据
        train_data = generate_train_data(tree.chess_size, chess_record)
        # 将训练数据存入缓存
        for i in range(len(train_data)):
            data_cache.push(train_data[i])
        if chess_num % train_every_chess == 0:
            logger()(f'train start.')
            loader = data_cache.get_loader(conf.batch_size)
            model.train()
            for _ in range(conf.epoch_num):
                loss_record = []
                for bat_state, bat_dist, bat_winner in loader:
                    bat_state, bat_dist, bat_winner = bat_state.to(device), bat_dist.to(device), bat_winner.to(device)
                    optimizer.zero_grad()
                    prob, value = model(bat_state)
                    loss = loss_func(prob, value, bat_dist, bat_winner)
                    loss.backward()
                    optimizer.step()
                    loss_record.append(loss.item())
                logger()(f'train epoch {ep_num} loss: {sum(loss_record) / float(len(loss_record))}')
                ep_num += 1
                if ep_num % conf.checkpoint_save_every_num == 0:
                    save_checkpoint(
                        os.path.join(conf.checkpoint_save_dir, f'epoch_{ep_num}'),
                        ep_num, chess_num, model.state_dict(), optimizer.state_dict(), lr_schedule.state_dict(), data_cache
                    )
            lr_schedule.step()
            logger()(f'train end.')
        chess_num += 1
        save_chess_record(
            os.path.join(conf.checkpoint_save_dir, f'chess_record_{chess_num}.pkl'),
            chess_record
        )
        # break

    pass

1.4 仿真实验

根据注释调整web_server.py文件,加载所用的预训练权重,并运行该脚本

浏览器打开网址:http://127.0.0.1:8080/ 进行对弈

部分代码展示

# 用户查询机器落子状态
@app.route('/state/get/<state_id>', methods=['GET'])
def get_state(state_id):
    global state_result
    state_id = int(state_id)
    state = 0
    chess_state = None
    if state_id in state_result.keys() and state_result[state_id] is not None:
        state = 1
        chess_state = state_result[state_id]
        state_result[state_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'state': state,
            'chess_state': chess_state
        }
    }
    return jsonify(ret)


# 游戏开始,为这场游戏创建蒙特卡洛树
@app.route('/game/start', methods=['POST'])
def game_start():
    global trees
    global model, device, chess_size, simulate_count
    tree_id = random.randint(1000, 100000)
    trees[tree_id] = MonteTree(model, device, chess_size=chess_size, simulate_count=simulate_count)
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {
            'tree_id': tree_id
        }
    }
    return jsonify(ret)


# 游戏结束,销毁蒙特卡洛树
@app.route('/game/end/<tree_id>', methods=['POST'])
def game_end(tree_id):
    global trees
    tree_id = int(tree_id)
    trees[tree_id] = None
    ret = {
        'code': 0,
        'msg': 'OK',
        'data': {}
    }
    return ret


if __name__ == '__main__':
    app.run(
        '0.0.0.0',
        8080
    )

1.5 仿真实验(移植网络)

运行脚本:python web_server_demo.py

浏览器打开网址:http://127.0.0.1:8080/ 进行对弈

  • 参考文档
    • https://www.doc88.com/p-57516985322644.html
    • https://zhuanlan.zhihu.com/p/139539307
    • https://zhuanlan.zhihu.com/p/53948964

码源链接见文章顶部或者文末

https://download.csdn.net/download/sinat_39620217/88045879

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/756378.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

好用的门店信息管理系统推荐?门店信息系统系统应该注重什么?

传统门店的信息管理模式总是会存在人工成本高&#xff0c;效率低&#xff0c;流程麻烦、数据复盘繁琐等问题。围绕门店信息管理过程中面临的各类痛点&#xff0c;蚓链数字化门店信息管理系统可以帮助门店更好的管理门店经营&#xff0c;货品盘点&#xff0c;库存管理&#xff0…

「XKOI」Round 3 赛后题解

比赛链接&#xff1a;「XKOI」Round 3 本题解同步发表于 洛谷&#xff1a;传送门 CSDN&#xff1a;传送门 文章目录 比赛链接&#xff1a;[「XKOI」Round 3](https://www.luogu.com.cn/contest/117863)A [T343985 CRH的工作](https://www.luogu.com.cn/problem/T343985)1.1 …

CRC算法并行运算Verilog实现

因为CRC循环冗余校验码的算法和硬件电路结构比较简单&#xff0c;所以CRC是一种在工程中常用的数据校验方法。尽管CRC简单&#xff0c;但在工程应用中还是有些问题会对工程师产生困惑。这篇文章将介绍一下CRC&#xff0c;希望对大家有所帮助。 一、CRC算法介绍 CRC校验原理看起…

PPO(Proximal Policy Optimization Algorithms)论文解读及实现

论文标题&#xff1a;Proximal Policy Optimization Algorithms 核心思路&#xff1a;使用off policy 代替on policy,用一个策略网络来产生数据&#xff0c;用一个策略网络来更新参数&#xff0c;分别为policy_old和policy 0 摘要 Whereas standard policy gradient methods …

Python自动化办公:pptx篇

文章目录 简介能做什么PPT要素介绍官方demo高阶引申参考文献 202201笔记迁移 简介 python-pptx包是用来自动化处理ppt的。 使用的第一步是安装 pip install python-pptx相比python-docx&#xff0c;python-pptx的使用更为麻烦一些&#xff0c;原因有很多&#xff0c;比如说&…

波奇学Linux:make和Makefile

make和Makefile自动化构建并能决定源文件调用顺序&#xff0c;同时不必再写gcc命令 第一行依赖关系&#xff0c;第二行是tab键开头&#xff0c;是依赖方法 依赖关系&#xff1a;目标文件&#xff1a;依赖文件。 依赖方法&#xff1a;目标文件和依赖文件间的关系。 如果只有一条…

es下载历史的tar文件

第一步进入官网找到历史版本 第二步复制历史版本名称组合成下面的链接 直接get访问下载。如下链接所示只需要修改7.3.0这个版本号 https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.3.0-linux-x86_64.tar.gz

ChatGLM使用记录

ChatGLM ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型&#xff0c;基于 General Language Model (GLM) 架构&#xff0c;具有 62 亿参数。结合模型量化技术&#xff0c;用户可以在消费级的显卡上进行本地部署&#xff08;INT4 量化级别下最低只需 6GB 显存&#xff0…

opencv实战--角度测量和二维码条形码识别

文章目录 前言一、鼠标点击的角度测量二、二维码条形码识别 前言 一、鼠标点击的角度测量 首先导入一个带有角度的照片 然后下面的代码注册了一个鼠标按下的回调函数&#xff0c; 还有一个点的数列&#xff0c;鼠标事件为按下的时候就记录点&#xff0c;并画出点&#xff0c;…

uniapp微信小程序上传体积压缩包过大分包操作和上传时遇到代码质量未通过问题

1&#xff1a;首先我们得从项目最初阶段就得考虑项目是否要进行分包操作&#xff0c;如果得分包&#xff0c;我们应该创建一个与pages同级的文件夹&#xff0c;命名可以随意 2:第二部我们将需要分包的文件和页面放到分包文件夹里面subpage&#xff0c;这里我们得注意&#xff…

Python基础语法第三章之顺序循环条件

目录 一、顺序语句 二、条件语句 2.1什么是条件语句 2.2语法格式 2.2.1 if 2.2.2if - else 2.2.3if - elif - else 2.3缩进和代码块 2.4闰年的判断练习 2.5空语句 pass 三、循环语句 3.1while 循环 3.1.1代码示例练习 3.2 for 循环 ​3.3 continue 3.4 break 一…

给LLM装上知识:从LLM+LangChain的本地知识库问答到LLM与知识图谱的结合

前言 过去半年&#xff0c;随着ChatGPT的火爆&#xff0c;直接带火了整个LLM这个方向&#xff0c;然LLM毕竟更多是基于过去的经验数据预训练而来&#xff0c;没法获取最新的知识&#xff0c;以及各企业私有的知识 为了获取最新的知识&#xff0c;ChatGPT plus版集成了bing搜…

linux常用工具介绍

文章目录 前言目录文件查看ls1、查看详细信息&#xff08;文件大小用K、M等显示&#xff09;2、按照文件创建时间排序&#xff08;常在查看日志时使用&#xff09; sort1、排序数字 df 、du1、查看目录的大小2、查看目录 从大到小排序 显示前n个3、查看磁盘使用情况 tailf一些目…

银河麒麟高级服务器操作系统V10安装mysql数据库

一、安装前 1.检查是否已经安装mysql rpm -qa | grep mysql2.将查询出的包卸载掉 rpm -e --nodeps 文件名3.将/usr/lib64/libLLVM-7.so删除 rm -rf /usr/lib64/libLLVM-7.so4.检查删除结果 rpm -qa | grep mysql5.搜索残余文件 whereis mysql6.删除残余文件 rm -rf /usr/b…

怎么用二维码做企业介绍?企业宣传二维码2种制作方法

怎么做一个企业推广二维码呢&#xff1f;现在制作二维码来做宣传推广是常用的一种方式&#xff0c;一般需要包含企业介绍、工作环境、产品简介、宣传视频、公司地址等等方面内容&#xff0c;那么企业介绍二维码该如何制作&#xff1f;下面给大家分享一下使用二维码编辑器&#…

EventBus详解

目录 1 EventBus 简介简介角色关系图四种线程模型 2.EventBus使用步骤添加依赖注册解注册创建消息类发送消息接收消息粘性事件发送消息 使用postStick()接受消息 3 EventBus做通信优点4 源码getDefault()register()findSubscriberMethods方法findUsingReflection方法findUsingR…

前端部署项目,经常会出现下载完 node 或者 npm 运行时候发现,提示找不到

1. 首先要在下载时候选择要下载的路径&#xff0c;不能下载完后&#xff0c;再拖拽到其他文件夹&#xff0c;那样就会因为下载路径和当前路径不一致&#xff0c;导致找不到相关变量。 2. 所以一开始就要在下载时候确定要存放的路径&#xff0c;然后如果运行报错&#xff0c;就…

【Java基础教程】(十三)面向对象篇 · 第七讲:继承性详解——继承概念及其限制,方法覆写和属性覆盖,关键字super的魔力~

Java基础教程之面向对象 第七讲 本节学习目标1️⃣ 继承性1.1 继承的限制 2️⃣ 覆写2.1 方法的覆写2.2 属性的覆盖2.3 关键字 this与 super的区别 3️⃣ 继承案例3.1 开发数组的父类3.2 开发排序类3.3 开发反转类 &#x1f33e; 总结 本节学习目标 掌握继承性的主要作用、实…

git指令记录

参考博客&#xff08;侵权删&#xff09;&#xff1a;关于Git这一篇就够了_17岁boy想当攻城狮的博客-CSDN博客 Git工作区介绍_git 工作区_xyzso1z的博客-CSDN博客 git commit 命令详解_gitcommit_辰风沐阳的博客-CSDN博客 本博客只作为自己的学习记录&#xff0c;无商业用途&…

计算机存储设备

缓存为啥比内存快 内存使用 DRAM 来存储数据的、也就是动态随机存储器。内部使用 MOS 和一个电容来存储。 需要不停地给它刷新、保持它的状态、要是不刷新、数据就丢掉了、所以叫动态 、DRAM 缓存使用 SRAM 来存储数据、使用多个晶体管(比如6个)就是为了存储1比特 内存编码…