24/8/17算法笔记 DDPG算法

news2024/9/23 13:28:00

深度确定性策略梯度(DDPG)算法是一种用于解决连续动作空间强化学习问题的算法。它结合了确定性策略梯度(DPG)和深度学习技术的优点,通过Actor-Critic框架进行策略和价值函数的近似表示。DDPG算法的关键组成部分包括经验回放缓冲区、Actor-Critic神经网络、探索噪声、目标网络以及软目标更新。

DDPG算法使用两个神经网络,分别作为Actor和Critic。Actor网络负责生成策略,即在给定状态下选择最佳动作,而Critic网络评估当前策略的表现,通过Q值来衡量。经验回放缓冲区存储了与环境交互过程中产生的转换数据,这些数据用于训练网络,打破样本之间的时间相关性,提高学习效率。

DDPG算法的一个关键特性是目标网络的使用,它通过缓慢更新目标网络的参数来增加学习过程的稳定性。软更新是通过将目标网络参数设置为目标网络参数加上一小部分主网络参数的变化来实现的。

探索噪声是DDPG算法中用于平衡探索与利用的另一个重要组成部分。通过在Actor网络输出的动作上添加噪声,鼓励智能体探索环境,这有助于发现更好的策略。

在实现DDPG算法时,需要定义Actor和Critic网络,初始化目标网络,并设置优化器。训练过程中,通过从经验回放缓冲区中采样数据来更新网络参数。更新过程包括计算目标Q值、当前Q值,并分别更新Critic和Actor网络。

import gym
from matplotlib import pyplot as plt
%matplotlib inline
#创建环境
env = gym.make('Pendulum-v1')
env.reset()

#打印游戏
def show():
    plt.imshow(env.render(mode='rgb_array'))
    plt.show()

action网络模型

import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.sequential = torch.nn.Sequential(
            torch.nn.Linear(3,64),#第一层全连接层,将输入特征从3维映射到64维。
            torch.nn.ReLU(),     #ReLU激活函数,用于引入非线性。
            torch.nn.Linear(64,64), #第二层全连接层,将64维特征再次映射到64维。
            torch.nn.ReLU(),    #ReLU激活函数。
            torch.nn.Linear(64,1), #第三层全连接层,将64维特征映射到1维,即输出一个值。
            torch.nn.Tanh(), #Tanh激活函数,将输出值映射到-1到1之间。
        )
    def forward(self,state):  #是前向传播函数,它定义了如何计算网络的输出。在这里,它将输入 state 通过 self.sequential 进行处理,然后乘以2.0。
        return self.sequential(state)*2.0

model_action = Model()
model_action_next = Model()

model_action_next.load_state_dict(model_action.state_dict())
model_action(torch.randn(1,3))

value网络模型

model_value = torch.nn.Sequential(
    torch.nn.Linear(4,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,1),
)
model_value_next = torch.nn.Sequential(
    torch.nn.Linear(4,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,64),
    torch.nn.ReLU(),
    torch.nn.Linear(64,1),
)
model_value_next.load_state_dict(model_value.state_dict())
model_value(torch.randn(1,4))

动作函数

import random
import numpy as np
def get_action(state):
    state = torch.FloatTensor(state).reshape(1,3)
    action = model_action(state).item() #.item(): 这个方法通常用于将一个张量(tensor)转换成一个标准的Python数值。在PyTorch中,模型的输出通常是一个张量。如果你想要获取这个张量中的单个值,可以使用.item()方法。
    #给动作添加噪声,增加探索
    action +=random.normalvariate(mu=0,sigma=0.01)#高斯随机噪声
    return action

更新样本池函数,准备离线学习

#样本池
datas = []

#向样本池中添加N条数据,删除M条最古老的数据
def update_data():
    #初始化游戏
    state = env.reset()
    
    #玩到游戏结束为止
    over = False
    while not over:
        #根据当前状态得到一个动作
        action = get_action(state)
        
        #执行动作,得到反馈
        next_state,reward,over,_ = env.step([action])
        
        #记录数据样本
        datas.append((states,action,reward,next_state,over))
        
        #更新游戏状态,开始下一个当作
        state = next_state
    #数据上限,超出时从最古老的开始删除
    while len(datas)>10000:
        datas.pop(0)
        

env.step(action) 是一个常用方法,用于执行给定的动作 action 并与之环境交互。

数据采样函数

#获取一批数据样本
def get_sample():
    samples = random.sample(datas,64)
    #[b,4]
    state = torch.FloatTensor([i[0]for i in samples]).reshape(-1,3)
    #[b,1]
    action = torch.LongTensor([i[1]for i in samples]).reshape(-1,1)
    #[b,1]
    reward = torch.FloatTensor([i[2]for i in samples]).reshape(-1,1)
    #[b,4]
    next_state = torch.FloatTensor([i[3]for i in samples]).reshape(-1,3)
    #[b,1]
    over = torch.LongTensor([i[4]for i in samples]).reshape(-1,1)

    return state,action,reward,next_state,over

state,action,reward,next_state,over=get_sample()

state[:5],action[:5],reward[:5],next_state[:5],over[:5]

这些数据通常用于训练强化学习模型,其中状态 statenext_state 被用来输入到价值函数或策略网络中,action 是模型选择的动作,reward 是环境对动作的反馈,over 表示游戏是否结束,通常用于确定奖励的折扣因子。

测试函数

from IPython import display

def test(play):
    #初始化游戏
    state = env.reset()#重置环境状态
    
    #记录反馈值的和,这个值越大越好
    reward_sum= 0
    
    #玩到游戏结束为止
    over =False
    while not over:
        #根据当前状态得到一个动作
        action = get_action(state)
        
        #执行动作,得到反馈
        staet,reward,over,_=env.step(action)
        reward_sum+=reward
        
        #打印动画
        if play and random.random()<0.2: #用于清除先前在输出区域的显示内容
            display.clear_output(wait=True)  #wait:设置为 True 时,clear_output 将等待所有异步输出完成之后再清除输出区域。这可以确保在清除之前所有输出都已经显示。    
            show()
    return  reward_sum
def get_value(state,action):
    #直接评估综合了state和action的value
    input = torch.cat([state,action],dim=1) #torch.cat 函数用于连接多个张量,dim 参数指定了沿着哪个维度进行连接。
    
    return model_value(input)

def get_target(next_state,reward,over):
    #对next_state评估需要先把它对应的当作计算出来
    action = model_action_next(next_state)
    
    #和value的计算一样,action拼合进next_state里综合计算
    input = torch.act([next_state,action],dim=1)
    target = model_value_next(input)*0.98
    target *=(1-over)
    target +=reward
    return target

action模型的loss

def get_loss_action(state):
    #首先把动作计算出来
    action  = model_action(state)
    
    #像value计算那里一样,拼合state和action综合计算
    input = torch.cat([state,action],dim = 1)
    
    #使用value网络评估动作的价值,价值越高越好
    #因为这里是在计算loss,loss是越小越好,所以符号取反
    loss =-model_value(input).mean()
    
    return loss

软更新函数,DQN使用硬更新

软更新(Soft Update)是深度强化学习中用于更新目标网络参数的一种技术。在某些强化学习算法,如深度确定性策略梯度(DDPG)算法中,会使用两个相似的网络:一个用于生成当前策略或价值函数的“主网络”(online network),以及一个“目标网络”(target network)。

目标网络的参数是主网络参数的慢速更新版本,这样做的目的是增加训练过程的稳定性。软更新的具体步骤如下:

1. **初始化**:开始时,目标网络的参数被复制或初始化为主网络参数的副本。

2. **慢速更新**:在每次训练迭代中,以一个小的比例(通常是一个小于1的因子,如0.001或0.005)更新目标网络的参数。这个更新过程可以表示为:
 
   其中, 是目标网络的参数, 是主网络的参数,而是更新比例(tau 系数)。

3. **逐步逼近**:通过这种方式,目标网络的参数会逐步逼近主网络的参数,但不会立即完全同步。这有助于减少训练过程中的震荡。

软更新的优点包括:

- **稳定性**:由于目标网络参数更新得更慢,它为训练过程提供了一定程度的稳定性。
- **减少震荡**:软更新减少了目标值的突然变化,这有助于避免训练中的大振荡。
- **平滑学习**:它允许模型在更新过程中保持平滑的学习曲线。

软更新通常在算法的每次迭代或每隔几个步骤执行一次,具体取决于算法的设计和所需的更新频率。
 

def train():
    model_action.train()  #设置模型为训练模式
    model_value.train()
    optimizer_action = torch.optim.Adan(model.parameters(),lr =5e-4) #创建优化器
    optimizer_value = torch.optim.Adam(model_td.parameters(),lr=5e-3)
    loss_fn = torch.nn.MSELoss()
    
    #玩N局游戏,每局训练一次
    for epoch in range(200):
        #更新N条数据
        update_data()
        
        for i in range(200):
            #玩一局游戏,得到数据
            states,rewards,actions,next_states,overs = get_sample()

            #计算values 和targets
            values= get_value(state,action)
            targets = get_target(next_states,reward,over)


            #两者求差,计算loss,更新参数
            loss_value= loss_fn(values,targets)

            #更新参数
            optimizer.zero_grad()   #作用是清除(重置)模型参数的梯度
            loss.backward()       #反向传播计算梯度的标准方法
            optimizer.step()     #更新模型的参数

            #使用value网络评估action网络的loss,更新参数
            loss_action = get_loss_action(state)

            optimizer_td.zero_grad()  
            loss_td.backward()       
            optimizer_td.step()  

            #以一个小的比例更新
            soft_update(model_action,model_action_next)
            soft_update(model_value,model_value_next)

        if i %20 ==0:
            test_result = sum([test(play=False)for _ in range(10)])/10
            print(epoch,len(datas),test_result)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2050697.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【RAG综述】北京大学检索增强技术综述

RAG for AIGC ​ 图 1 描述了一个典型的 RAG 过程。给定一个输入查询&#xff0c;检索器识别相关的数据源&#xff0c;检索到的信息与生成器交互以改进生成过程。根据检索结果如何增强生成&#xff0c;有几种基础范式&#xff08;简称基础&#xff09;&#xff1a;它们可以作为…

STM32的蜂鸣器

蜂鸣器分为有源蜂鸣器和无源蜂鸣器。 有源蜂鸣器&#xff1a;内部有震荡源&#xff0c;只要通电即可自动发出固定频率的声音。&#xff08;频率固定无 法控制音色&#xff09; 。 无源蜂鸣器&#xff1a;内部无震荡源&#xff0c;需要外部脉冲信号驱动发声&#xff0c;声音频…

《机器学习》 线性回归 一元、多元 推导 No.3

一、什么是线性回归 线性回归是一种用于预测连续数值的机器学习算法。它基于输入特征与目标变量之间的线性关系建立了一个线性模型。线性回归的目标是找到最佳拟合直线&#xff0c;以最小化预测值与实际值之间的误差。这个线性模型可以用来进行预测和推断。 线性回归的模型可以…

SpringBoot Profile多环境配置及配置优先级

【SpringBoot学习笔记 三】Profile多环境配置及配置优先级_profiles队列中的优先值-CSDN博客 Profile激活方式 但是我们发现一个问题&#xff0c;就是每次切换环境还需要去配置里指定&#xff0c;然后通过修改dev为test或prod来切换项目环境 , 这样做的话每次切换环境都要重新改…

前端面试——如何判断对象和数组

给你一个值&#xff0c;如何判断其是对象还是数组&#xff1f;&#xff1f;&#xff1f; 我们先给出数据 var lists [1,2,3,4,5]var objs {length:5 } 我们分别尝试如下五种方法 console.log((✘)使用length,lists.length,objs.length); console.log((✔)使用isArray,Arr…

【已成功EI检索】第三届机电一体化技术与航空航天工程国际学术会议(ICMTAE 2023)

重要信息 大会官网&#xff1a;www.icmtae.org 大会时间&#xff1a;2023年9月15-17日 大会地点&#xff1a;中国-江西南昌理工学院&#xff08;南昌市青山湖区经济技术开发区英雄大道901号&#xff09; 接受/拒稿通知&#xff1a;投稿后1周内 收录检索&#xff1a;EI 和 …

Vulkan 学习(4)---- Vulkan 逻辑设备

目录 Vulkan Logical Device OverView逻辑设备创建VkDeviceQueueCreateInfoDeviceExtension获取DeviceQueue参考代码 Vulkan Logical Device OverView 在 Vulkan 中&#xff0c;逻辑设备(Logical Device)是与物理设备(Physical Device)交互的接口,它抽象了对特定GPU(物理设备)…

CDD数据库文件制作(八)——服务配置(0x85)

目录 1.子功能创建2.会话切换配置/安全等级配置2.1.根据诊断调查表进行信息提取2.2.会话转换配置/安全等级配置3.寻址方式信息提取/禁止肯定响应位(SPRMIB)信息3.1.寻址方式/禁止肯定响应位(SPRMIB)配置4.否定响应码信息提取4.1.否定响应码配置按照诊断调查表中对0x85服务的…

PX30 Android8.1适配AIC8800 wifi

wifi驱动生成ko文件 生成后 通过wpa_supplicant加载参数 external/wpa_supplicant_8/wpa_supplicant/main.c int main(int argc, char *argv[]) {int ret -1;char module_type[20]{0};wpa_printf(MSG_INFO,"argc %d\n",argc);if(argc 2) {if (wifi_type[0] 0) …

【MySQL】数据库基础(表的操作)

目录 一、创建表 二、查看表结构 三、修改表 3.1 添加新列 3.2 修改列属性 3.3 删除列属性 3.4 修改表名 3.5 向表中插入 3.6 修改列名 四、删除表 一、创建表 语法&#xff1a; CREATE TABLE table_name ( field1 datatype, field2 datatype, field3 datatype ) …

docker容器安全加固参考建议——筑梦之路

这里主要是rootless的方案。 在以 root 用户身份运行 Docker 会带来一些潜在的危害和安全风险&#xff0c;这些风险包括&#xff1a; 容器逃逸&#xff1a;如果一个容器以 root 权限运行&#xff0c;并且它包含了漏洞或者被攻击者滥用&#xff0c;那么攻击者可能会成功逃出容器…

车载camera avm框图

一、关键词介绍: POC: power on coax LVDS: Low-Voltage Differential Signaling GMSL:Gigabit Multimedia Serial Link AVM: Around View Monitor Serdes:DeSerializer、Serializer DVP:Interface with ISP and Sensor: DVP(Digital Video Port) 二、车载camera avm…

书籍推荐:大数据之路 阿里巴巴大数据实践

书籍推荐&#xff1a;大数据之路 阿里巴巴大数据实践 这本书侧重于理论知识&#xff0c;并结合了阿里大数据发展的过程&#xff0c;将知识总结起来。总的来所&#xff0c;书中的有些章节个人感觉非常不错&#xff0c;比如&#xff1a;数据仓库建模&#xff1b;但是大部分章节都…

性能优化理论篇 | 如何保证数据安全落盘,5分钟彻底弄懂 一次write中的各种缓冲区 !

性能优化系列目录&#xff1a; 性能优化理论篇 | 彻底弄懂系统平均负载 性能优化理论篇 | swap area是个什么东西 性能优化理论篇 | Cache VS Buffer&#xff0c;傻傻分不清 &#xff1f; 在很多IO场景中&#xff0c;我们经常需要确保数据已经安全的写到磁盘上&#xff0c;以便…

xss之DOM破坏

文章目录 DOM破坏漏洞的复现https://xss.pwnfunction.com/基于bp学院DOM破坏漏洞复现思路分析实现 常见的xss触发的标签没有过滤的情况存在过滤的情况 DOM破坏 DOM破坏就是⼀种将 HTML 代码注⼊⻚⾯中以操纵 DOM 并最终更改⻚⾯上 JavaScript ⾏为的技术。 在⽆法直接 XSS的情…

Linux·权限与工具-make

1. Makefile/makefile工具 首先展示一下&#xff0c;makefile工具如何使用。我们先写一个C语言程序 然后我们建立一个Makefile/makefile文件&#xff0c;m大小写均可。我们在文件中写入这样两行 wq保存退出后&#xff0c;我们使用 make 命令 可以看到生成了可执行程序&#xff…

无人机模拟训练室技术详解

无人机模拟训练室作为现代无人机技术培训的重要组成部分&#xff0c;集成了高精度模拟技术、先进的数据处理能力及高度交互的操作界面&#xff0c;为无人机操作员提供了一个安全、高效、接近实战的训练环境。以下是对无人机模拟训练室技术的详细解析&#xff0c;涵盖系统基础概…

为TI的 AM355移植uboot和linux内核

一、uboot移植 在移植之前要先对uboot的源码结构有一定熟悉 1.uboot源码顶层目录下各源码文件夹的作用 2.编译后生成的uboot.xxx 各文件后缀含义 关于以上两点社区已经有很多前辈总结的很详细&#xff0c;这里不做赘述。 对于uboot源码分析韦东山老师b站上有免费的课程&#x…

QT中Charts基本用法

QT中Charts基本用法 第一步:创建工程,添加Charts库 第二步:添加charts视图 注意要打上对钩 第三步:添加所需成员 第四步:编写初始化函数 第五步:添加测试数据

C++学习笔记之算法模板

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、双指针1.1 有序数组的合并1.2 快慢指针/删除有序数组中的重复项1.3 求和 二、动态规划2.1 自底向上和自顶向下&#xff08;带备忘录&#xff09;2.2 带有当前状…