Imitation Learning学习记录(理论例程)

news2024/11/15 12:07:03

前言

最近还是衔接着之前的学习记录,这次打算开始学习模仿学习的相关原理,参考的开源资料为

TeaPearce/Counter-Strike_Behavioural_Cloning: IEEE CoG & NeurIPS workshop paper ‘Counter-Strike Deathmatch with Large-Scale Behavioural Cloning’ (github.com)
[2104.04258] Counter-Strike Deathmatch with Large-Scale Behavioural Cloning (arxiv.org)

简单来说,行为克隆就是利用已有的人为示范数据作为输入来训练出一个策略,策略就会输出指定的动作,然而,行为克隆只能学习到专家的行为,而无法进行探索和自主学习。这意味着行为克隆的性能受限于专家的行为水平,并且可能无法适应新的、未在专家演示中出现过的情况。通过引入奖励函数,可以在行为克隆中加入一定的探索和自主学习能力。奖励函数可以根据当前状态和采取的动作来评估行为的好坏,并为模型提供反馈信号。通过优化奖励函数,可以使模型学习到更好的策略,并且能够适应新的情况和环境。奖励函数在行为克隆中起到了指导和调整模型学习的作用。其中,文章用到的例程的网络结构如下

请添加图片描述

本文打算从数据获取、模型训练、效果展示三个部分展开介绍

数据获取

在这个过程中作者使用了Game State Integration(GSI)技术来获取在线数据。通过GSI,作者可以从游戏中获取实时的游戏状态信息,包括玩家、队伍、武器、位置等各种数据。具体来说,作者可能使用了Valve提供的GSI接口来获取游戏状态信息。这些信息可以用于后续的行为克隆和分析工作。

这个过程中的核心代码如下

# now find the requried process and where two modules (dll files) are in RAM  
hwin_csgo = win32gui.FindWindow(0, ('counter-Strike: Global Offensive'))  
if(hwin_csgo):  
    pid=win32process.GetWindowThreadProcessId(hwin_csgo)  
    handle = pymem.Pymem()  
    handle.open_process_from_id(pid[1])  
    csgo_entry = handle.process_base  
else:  
    print('CSGO wasnt found')  
    os.system('pause')  
    sys.exit()  
  
# now find two dll files needed  
list_of_modules=handle.list_modules()  
while(list_of_modules!=None):  
    tmp=next(list_of_modules)  
    # used to be client_panorama.dll, moved to client.dll during 2020  
    if(tmp.name=="client.dll"):  
        print('found client.dll')  
        off_clientdll=tmp.lpBaseOfDll  
        break  
list_of_modules=handle.list_modules()  
while(list_of_modules!=None):  
    tmp=next(list_of_modules)  
    if(tmp.name=="engine.dll"):  
        print('found engine.dll')  
        off_enginedll=tmp.lpBaseOfDll  
        break

大致逻辑为:

  1. 查找CSGO进程:

    • 使用win32gui.FindWindow查找名为’counter-Strike: Global Offensive’的窗口句柄。
    • 如果找到窗口句柄,则通过win32process.GetWindowThreadProcessId获取与该窗口关联的进程ID。
    • 使用pymem.Pymem()创建一个进程内存访问对象,并通过open_process_from_id方法打开该进程。
    • 如果CSGO进程未找到,则打印消息并退出程序。
  2. 查找client.dll和engine.dll:

    • 使用handle.list_modules()获取进程中的所有模块列表。
    • 遍历模块列表,查找名为"client.dll"的模块,并获取动态链接库的基地址(lpBaseOfDll)。
    • 注意:这里使用了两次handle.list_modules()来分别查找两个DLL文件,但实际上你可以只调用一次并将结果存储在列表中,然后遍历这个列表来查找两个DLL。
    • 类似地,代码还查找名为"engine.dll"的模块,并获取其基地址。

找到窗口和动态链接库以后就可以开始录像并通过GSI或者RAM来访问键位等游戏信息,得到的数据类型大致为

  • frame_i_x: 图像信息
  • frame_i_xaux: 包含在前一个时间步骤中应用的动作,以及血量、弹药和团队。用于更好地帮助智能体寻找敌人以及适应当前情况
  • frame_i_y: 对应键盘以及鼠标的动作
  • frame_i_helperarr: 在格式kill_flag, death_flag中,每个变量都是二元变量,例如[[1,0]],意味着玩家击杀一次,但在该时间步内没有死亡

其中,具体的键位信息如下:

# how many slots were used for each action type?  
n_keys = 11 # number of keyboard outputs, w,s,a,d,space,ctrl,shift,1,2,3,r  
n_clicks = 2 # number of mouse buttons, left, right  
n_mouse_x = len(mouse_x_possibles) # number of outputs on mouse x axis  
n_mouse_y = len(mouse_y_possibles) # number of outputs on mouse y axis  
n_extras = 3 # number of extra aux inputs, eg health, ammo, team. others could be weapon, kills, deaths  
aux_input_length = n_keys+n_clicks+1+1+n_extras # aux uses continuous input for mouse this is multiplied by ACTIONS_PREV elsewhere

一个帧所包含的具体信息值如下:

请添加图片描述
请添加图片描述

模型训练

网络结构

输入先进入一个预训练好的EfficientNetB0模型,该模型在ImageNet数据集上进行了训练。并加上了时间序列信息,接下来将提取好的特征输入进一个带有时序信息的ConvLSTM网络

base_model = EfficientNetB0(weights='imagenet',input_shape=(input_shape[1:]),include_top=False,drop_connect_rate=0.2)
if 'drop' in model_name:  
    if 'big' in model_name:  
        x = ConvLSTM2D(filters=512,kernel_size=(3,3),stateful=False,return_sequences=True,dropout=0.5, recurrent_dropout=0.5)(x)  
    else:  
        x = ConvLSTM2D(filters=256,kernel_size=(3,3),stateful=False,return_sequences=True,dropout=0.5, recurrent_dropout=0.5)(x)

输出的信息为

# set up outputs, sepearate outputs will allow seperate losses to be applied  
output_1 = TimeDistributed(Dense(n_keys, activation='sigmoid'))(dense_5)  
output_2 = TimeDistributed(Dense(n_clicks, activation='sigmoid'))(dense_5)  
output_3 = TimeDistributed(Dense(n_mouse_x, activation='softmax'))(dense_5) # softmax since mouse is mutually exclusive  
output_4 = TimeDistributed(Dense(n_mouse_y, activation='softmax'))(dense_5)   
output_5 = TimeDistributed(Dense(1, activation='linear'))(dense_5)   
# output_all = concatenate([output_1,output_2,output_3,output_4], axis=-1)  
output_all = concatenate([output_1,output_2,output_3,output_4,output_5], axis=-1)

损失函数

  1. 键盘按键损失(loss1a, loss1b, loss1c, loss1d
    • loss1a:计算 WASD 键(通常用于游戏中的移动)的二进制交叉熵损失。
    • loss1b:计算空格键的二进制交叉熵损失。
    • loss1c:计算重新加载键(如游戏中的“R”键)的二进制交叉熵损失。
    • loss1d(注释掉的部分):原本可能用于计算其他键盘按键的损失,但在提供的代码中,它被重新定义为武器切换键(1, 2, 3)的损失。
  2. 鼠标点击损失(loss2a, loss2b
    • loss2a:计算鼠标左键点击的二进制交叉熵损失。
    • loss2b:计算鼠标右键点击的二进制交叉熵损失(如果n_clicks大于1的话)。
  3. 鼠标移动损失(loss3, loss4
    • loss3:计算鼠标在 X 轴上的移动损失。由于鼠标移动是互斥的(即鼠标不能同时处于多个位置),因此使用了分类交叉熵损失(categorical_crossentropy)。
    • loss4:计算鼠标在 Y 轴上的移动损失,同样使用了分类交叉熵损失。

除此之外,还有一个loss_crit损失函数,

loss_crit = 10*losses.MSE(y_true[:,:-1,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1]  
                   + GAMMA*y_pred[:,1:,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1]  
                   ,y_pred[:,:-1,n_keys+n_clicks+n_mouse_x+n_mouse_y:n_keys+n_clicks+n_mouse_x+n_mouse_y+1])

这是一个基于时序差分(Temporal Difference, TD)的均方误差(Mean Squared Error, MSE)损失函数,用于强化学习中的值函数逼近。它计算了当前时间步的奖励(或值)与下一个时间步的预测奖励(或值)之和(经过折扣因子 GAMMA 调整后)与当前时间步的预测奖励(或值)之间的均方误差。这种损失函数允许神经网络学习如何根据当前状态和环境信息来预测未来的奖励或值,从而优化策略或值函数。在这个特定的实现中,损失还乘以了一个系数(如10),可能是为了调整该损失在总损失中的相对权重。

奖励函数如下,奖励为 R(杀敌数,死亡数,子弹数)

reward_i = kill_i - 0.5*dead_i - 0.01*shoot_i # this is reward function  
y[i,j,-2:] = (reward_i,0.) # 0. is a placeholder for original advantage

效果展示

通过e2e.yml文件配置虚拟环境,更改了游戏内的窗口分辨率,设置了一些其他的参数,运行dm_run_agent.py以后在自己的电脑上成功复现

请添加图片描述

总结

本次基于Counter-Strike Deathmatch with Large-Scale Behavioural Cloning这个开源项目系统地学习了一下行为克隆的基本流程,从数据采集、模型训练以及损失函数的定义到最终复现,拓宽了我对RL的认知,在日后也能够更好地迁移到Robotic,逻辑如下:

  1. 数据收集:首先,需要收集人类专家在特定任务中的行为数据。这些数据通常包括机器人所处的状态(如位置、姿态、环境信息等)以及对应的人类专家在该状态下所采取的动作(如移动方向、操作指令等)。这些数据构成了行为克隆算法的训练集。
  2. 模型训练:使用收集到的数据训练一个模型,如神经网络模型。这个模型将学习从状态到动作的映射关系,即根据机器人当前的状态预测应该执行的动作。在训练过程中,模型会不断优化其参数,以最小化预测动作与真实动作之间的差异。
  3. 模型部署:训练好的模型可以部署到机器人上,用于指导机器人的行为。当机器人遇到新的状态时,它会将当前状态输入到模型中,模型会输出一个预测的动作。机器人将根据这个预测的动作来执行相应的操作。
  4. 反馈与调整:在机器人执行动作的过程中,可以通过收集反馈信息来进一步调整模型。例如,可以观察机器人执行动作后的效果,如果效果不理想,则可以收集新的数据并重新训练模型,以提高其性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1666352.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【源头开发】运营级竞拍商城源码/抢拍转拍/竞拍源码/转卖寄售/拆分/溢价商城转拍溢价php源码uniapp源码

大家好啊,欢迎来到web测评,我是年哥,我们有个小伙伴又开发了一款竞拍商城的源码,是此系统的源头开发者,本系统是前后端分离的架构,前端php,后端uniapp,系统现在是持续的在更新中&…

libcity笔记: HSTLSTMEncoder

1 __init__ 2 encode 得到的内容如下: data_feature的内容: 一共有多少个location1【包括pad的一个】最长的时间间隔(秒)最长的距离间隔(千米)多少个useer idpadding 的locationidpad_item的内容 location…

[单机]成吉思汗3_GM工具_VM虚拟机

稀有端游成吉思汗1,2,3单机版虚拟机一键端完整版 本教程仅限学习使用,禁止商用,一切后果与本人无关,此声明具有法律效应!!!! 教程是本人亲自搭建成功的,绝对是完整可运行的&#x…

React 第三十一章 前端框架的分类

现代前端框架,有一个非常重要的特点,那就是基于状态的声明式渲染。如果要概括的话,可以使用一个公式: UI f(state) state:当前视图的一个状态f:框架内部的一个运行机制UI&#xff1…

计算机视觉——基于改进UNet图像增强算法实现

1. 引言 在低光照条件下进行成像非常具有挑战性,因为光子计数低且存在噪声。高ISO可以用来增加亮度,但它也会放大噪声。后处理,如缩放或直方图拉伸可以应用,但这并不能解决由于光子计数低导致的低信噪比(SNR&#xff…

从头理解transformer,注意力机制(下)

交叉注意力 交叉注意力里面q和KV生成的数据不一样 自注意力机制就是闷头自学 解码器里面的每一层都会拿着编码器结果进行参考,然后比较相互之间的差异。每做一次注意力计算都需要校准一次 编码器和解码器是可以并行进行训练的 训练过程 好久不见输入到编码器&…

【图论 回溯 广度优先搜索】126. 单词接龙 II

本文涉及知识点 图论 回溯 深度优先搜索 广度优先搜索 图论知识汇总 LeetCode 126. 单词接龙 II 按字典 wordList 完成从单词 beginWord 到单词 endWord 转化,一个表示此过程的 转换序列 是形式上像 beginWord -> s1 -> s2 -> … -> sk 这样的单词序…

机器学习入门到放弃2:朴素贝叶斯

1. 算法介绍 1.1 算法定义 朴素贝叶斯分类(NBC)是以贝叶斯定理为基础并且假设特征条件之间相互独立的方法,先通过已给定的训练集,以特征词之间独立作为前提假设,学习从输入到输出的联合概率分布,再基于学习…

oracle 数据库与服务、实例与SID、表空间、用户与表模式

一、数据库与数据库服务: 概念:就是一个数据库的标识,在安装时就要想好,以后一般不修改,修改起来也麻烦,因为数据库一旦安装,数据库名就写进了控制文件,数据库表,很多地方都会用到这个数据库名。是数据库系统的入口,它会内置一些高级权限的用户如SYS,SYSTEM等。我们…

Xilinx 千兆以太网TEMAC IP核 MDIO 配置及物理接口

基于AXI4-Lite接口可以访问MDIO(Management Data Input/Output)接口,而MDIO接口连接MAC外部的PHY芯片,用户可通过AXI4-Lite接口实现对PHY芯片的配置。 1 MDIO接口简介 开放系统互连模型OSI的最低两层分别是数据链路层和物理层,数据链路层的…

探讨欧盟就人工智能监管达成协议

《人工智能法案》是一项具有里程碑意义的立法,它可以创造一个有利的环境,在这种环境中,人工智能的使用将成为一种更优秀的安全和信任的工具,确保整个欧盟的公共和私人机构利益相关者的参与。 历时3天的“马拉松式”谈判圆满结束&…

数据可视化训练第四天(模拟投掷筛子并且统计频次)

投掷一个筛子 import matplotlib.pyplot as plt from random import randint import numpy as npclass Die:"""模拟投掷筛子"""def __init__(self,num_sides6):self.num_sidesnum_sidesdef roll(self):return randint(1,self.num_sides)num1000…

vi\vim编辑器

root用户(超级管理员) 无论是Windows、MacOS、Linux均采用多用户的管理模式进行权限管理。 在Linux系统中,拥有最大权限的账户名为:root(超级管理员) root用户拥有最大的系统操作权限,而普通…

论文盲审吐槽多,谁给盲审不负责的老师买单?如何看待浙江大学「一刀切」的研究生学位论文双盲评审制度?

::: block-1 “时问桫椤”是一个致力于为本科生到研究生教育阶段提供帮助的不太正式的公众号。我们旨在在大家感到困惑、痛苦或面临困难时伸出援手。通过总结广大研究生的经验,帮助大家尽早适应研究生生活,尽快了解科研的本质。祝一切顺利!—…

二维数组 和 变长数组

在上一期的内容中,为诸君讲解到了一维数组,在一维数组的基础上,C语言中还有着多维数组,其中,比较典型且运用较为广泛的就是我们今天的主角——二维数组 一 . 二维数组的概念 我们把单个或者多个元素组成的数组定义为一…

DI-engine强化学习入门(七)如何自定义神经网络模型

在强化学习中,需要根据决策问题和策略选择合适的神经网络。DI-engine中,神经网络模型可以通过两种方式指定: 使用配置文件中的cfg.policy.model自动生成默认模型。这种方式下,可以在配置文件中指定神经网络的类型(MLP、CNN等)以及超参数(隐层大小、激活函数等),DI-engine会根据…

【漏洞复现】泛微OA E-Cology XmlRpcServlet文件读取漏洞

漏洞描述: 泛微OA E-Cology是一款面向中大型组织的数字化办公产品,它基于全新的设计理念和管理思想,旨在为中大型组织创建一个全新的高效协同办公环境。泛微OA E-Cology XmlRpcServlet存在任意文件读取漏洞,允许未经授权的用户读…

三星硬盘格式化后怎么恢复数据

在数字化时代,硬盘作为数据存储的核心部件,承载着我们的重要文件、照片、视频等资料。然而,不慎的格式化操作可能使我们失去宝贵的数据。面对这样的困境,许多用户可能会感到无助和焦虑。本文旨在为三星硬盘用户提供格式化后的数据…

计算机网络实验2:路由器常用协议配置

实验目的和要求 掌握路由器基本配置原理理解路由器路由算法原理理解路由器路由配置方法实验项目内容 路由器的基本配置 路由器单臂路由配置 路由器静态路由配置 路由器RIP动态路由配置 路由器OSPF动态路由配置实验环境 1. 硬件:PC机; 2. 软…

金三银四面试题(二十六):责任链模式知多少?

什么是责任链模式 责任链模式(Chain of Responsibility Pattern)是一种行为型设计模式,旨在通过将请求的处理分布在一系列对象上,从而使得多个对象可以尝试处理同一个请求。这些对象被链接成一条链,每个对象都可以对请…