【MARL】MADDPG + attention 实现(+论文解读)

news2024/12/26 9:27:47

文章目录

  • 前言
  • 注意力机制
  • 论文里的attention
    • 回顾知识-MADDPG
    • 讲解
      • 1.Q的定义
      • 2.Q的恒等式
      • 3.论文里的attention
      • 4.好处
  • 实现 和 修改
    • 结果展示
    • 原论文代码 翻改版
    • 修改后
    • 原maddpg代码


前言

导师让在MADDPG上加一个注意力机制,试了很多种,下面的参考的论文的效果最好,先把其思路记录下来。
之后有时间再试试自注意力机制。

参考论文:
Modelling the Dynamic Joint Policy of Teammates with Attention Multi-agent DDPG
论文代码:github

注意力机制

注意力机制是什么?最初是在NLP领域中为了解决lstm中序列的前后文本因为位置跨度大而不能理解正确输出后文的问题。
著名文章:Attention is all your need
讲的比较好的文章:
1.Transformer:注意力机制(attention)和自注意力机制(self-attention)的学习总结
2.详解Transformer中Self-Attention以及Multi-Head Attention

我这里的个人理解:

注意力机制:Attention(Q, K, V) = softmax(Q*K^T/sqrt(d_k)) * V (以缩放点积注意力机制举例)


Q :查询(自主性提示,有主观意识成分) K : 键(非自主性提示,客观存在的特征) V: 值 (K中实际信息,客观存在)
针对两条序列,Q 和 K (假设Q:5x1x6 ,K:5x3x6 ,V:5x3x6)
注:batch_size x len x hidden_dim
具象理解:其中Q中的1可以具象化为我要找的一件东西(这个东西有6个特征),K中的3为我已有的三件物品的线索(也有6个特征),我拿我心中这个东西的特征和线索特征去匹配。
5可以表示为我找了5次东西。


注意力分数 (打分器 ) attention_score:Q*K^T/sqrt(d_k) 对Q和K做一个加权求和,输出一个值,这个值可以说是Q和K的相似程度,两者相似程度越高,这个值就越大,并且其他不相似的元素也对其有贡献(即软注意力,硬注意力是0,1 关系)。
注:这里使用了缩放点积的方法:缩放:sqrt(d_k) (防止后续softmax过大以丧失梯度)点积:Q*K^T
关键疑问:为什么这里明明是矩阵乘法(行列元素对应相乘并相加:实际就是加权求和),确说成是点积(对应元素相乘)?
回答:矩阵乘法当的第一步是行列元素相乘,可以看作是对应行元素与对应列元素的点积。
此时:得到5x1x3
具象理解:其中1x3 可以看作我得到了这3个线索的相似度分数,假设分数[3,4,5]。


注意力权重 attention_weight :softmax(Q*K^T/sqrt(d_k)) 对分数进行一个softmax函数,实际上就是把各个分数转换成(0,1)之间的概率分布,且其概率分布和为1。
此时:得到5x1x3
具象理解:其中1x3为输出的是我每个线索我要给予多少关注程度。假设是[0.2,0.3,0.5]。


注意力值 atten_value :softmax(Q*K^T/sqrt(d_k)) * V 将关注度与V还是进行一个加权求和(矩阵乘法)
此时:5x1x6
具象理解:我第一个线索中对应的东西的6个特征均给予0.2程度的关心,第2个线索对应的东西的6个特征均给予0.3程度的关心,第3个是均给予0.5程度的关心。最后找到了这个‘东西’。
最后找到的东西和原来的东西不一样但相似度高,比方说:我想找青苹果,最后找到了红苹果(即本质一样)。
在翻译领域上的具体实现是:翻译:apple 为 苹果。


总结
输入:Q,K序列,返回:在不同K‘地址’的V‘物品’里‘合成的像Q的’物品’
该模型可以动态地决定在不同位置上分配多少注意力,从而关注更需要关注的特征。
K和Q的匹配过程决定了“在哪里看”,而V则决定了“看到什么”。


补充:K-head注意力机制:只是在注意力机制基础上,用到的hidden层变成了K个
比如:原来Q:5x1x6 -> Kx5x1x6 同理K,V也是。
最后输出前再堆叠(stack)一下->5x1x (6xK)
其他常见的做法:未来避免隐层变多造成的计算过慢,将原来的hidden切成(view)K个,比如说三个,5x1x6 -> 5x1x(2x3) 。
好处是让注意力不只是注意一个敌方,注意多个地方。


注:由于注意力机制的第一步和最后一步其实都是矩阵相乘(加权求和),也就是matmul()或者mul().sum(-1),所以会见到这两种代码的排列组合,不要怀疑,这两种写法都是对的。

自注意力机制:(施工中)
简单理解:
与注意力机制的区别在于:只输入一条序列,输入的Q,K一样,关注自己相似的地方。
具体:The animal didn’t cross the street,because it was too tired。这里animal 和 it 是同一个。

论文里的attention

参考论文:
Modelling the Dynamic Joint Policy of Teammates with Attention Multi-agent DDPG

回顾知识-MADDPG

1.MADDPG采用了集中式训练,分布式执行(centralized training with decentralized execution,CTDE)的框架 集中式含义:在训练时使用所有智能体的状态和动作集合,并不是指训练出的critic网络一样。
疑问:为何训练出的critic网络不一样? 答:每个智能体的奖励函数不一样,或者有可能done也不一样。
分布式含义:在执行时只使用当前智能体的状态,所以训练出的actor肯定不一样的。

2.采用Actor-Critic框架,如何训练critic和actor?
critic网络:用深度神经网络拟合动作价值函数。
与DDPG一样,使用TD(0)算法来迭代实现贝尔曼状态动作价值函数。
在这里插入图片描述
这里的TD目标实际上代替了总回报值Gt。
于是自然推出损失函数为当前动作价值和TD目标的均方差->以实现估计状态动作价值函数
在这里插入图片描述
具体理解:拟合的Q(s,a)为:当前长期回报总值,y为当前奖励值和下一步的未来长期回报值。
本质上是对未来回报的预测,告诉智能体在某个状态下采取某个动作的潜在好处。
TD目标值并不总是大于当前Q值估计,大于时,则说明某个状态好,小于时,则说明某个状态差。
TD目标的计算是为了提供一个更接近真实长期回报的估计,目标是正确认识到可以达到的Q值。
具象理解:相当于是训练一个评论家,告诉你,我最大可以拿到多少奖励。


actor网络:用深度神经网络拟合动作值/动作概率分布/动作均值和方差。
在这里插入图片描述
一般来说:即使用当前策略下的动作(替换经验池中抽取当前智能体的动作为当前状态下的动作),损失函数为-Q(s,a)->加上负号以使用梯度下降。实际实现Q值最大化,即回报值最大化。

具象理解:训练一个玩家,使奖励达到最大。

其他:使用A-C算法,取代了传统DQN算法中显示使用max a′Q(s t+1,a ′) 选取最大Q值的策略,而是学习了一个策略。
注:先更新critic,后更新actor有助于学习到更好的策略

讲解

没按文章顺序解读,按自己的理解解读

1.Q的定义

和实际上的attention 不一样,论文里并不是直接加进去,而是巧妙利用了注意力机制里的一些特性,重新定义了Q函数,达到了神奇的效果。

MADDPG论文中定义的Q 为
Q = Q i u ( s , a ∣ a i = u i ( o i ) ) Q = Q_i^u(s,a|a_i=u_i(o_i)) Q=Qiu(s,aai=ui(oi))
其中u为actor的策略,s为所有智能体的状态,a为所有智能体的动作,u_i为当前智能体的策略。
这里的Q即在更新actor时要最大化的Q。
更新critic时的Q为 Q = Q i u ( s , a ) Q = Q_i^u(s,a) Q=Qiu(s,a) ,即没有后续的条件( a i = u i ( o i ) a_i=u_i(o_i) ai=ui(oi))。

论文里定义的Q,将动作价值函数Q(s,a)定义为:(更新critic时的Q)
Q = Q i u i ∣ u − i ( s , a i ) Q = Q_i^{u_i|u_{-i}}(s,a_i) Q=Qiuiui(s,ai)
其中 u i u_i ui为当前actor的策略, u − i u_{-i} ui表示其他智能体策略。
更新actor时和上述MADDPG一样,使用 a i = u i ( o i ) a_i=u_i(o_i) ai=ui(oi)

论文是这样阐述的:
在这里插入图片描述
回想一下智能体的环境,都是被a(所有智能体的动作)来影响的,
那么,从智能体i的角度来看,即就是当前智能体的角度来看,我在s(所有智能体的状态)下做出自己动作(a_i)的结果,取决于其他智能体的动作。

论文作者意思就是说,这个环境都是被动作影响的,那么我的动作是在其他智能体的动作影响下的环境下做出的动作,那么我的动作实际上取决于其他智能体的动作。

因此将Q定义为如上形式。

2.Q的恒等式

此时,我们的目标是也就是最大化这个定义的Q值,也就是论文(下图)这个argmax的形式。
在这里插入图片描述
从数学上,因为在其他智能体采取其他动作的条件下 ∣ u − i |u_{-i} ui,意味着我们要考虑所有可能的动作组合的概率分布,故可以显式的写成上述(6)、(7)式。
Σ a ⃗ − i ∈ A ⃗ − i [ π ⃗ − i ( a ⃗ − i ∣ s ) Q i π i ( s , a i , a ⃗ − i ) ] \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}}[\vec{\pi}_{-i}(\vec{a}_{-i}|s)Q_{i}^{\pi_{i}}(s,a_{i},\vec{a}_{-i})] Σa iA i[π i(a is)Qiπi(s,ai,a i)]
A为动作空间。

由于要估计每个其他智能体的动作(即实现 Σ a ⃗ − i ∈ A ⃗ − i \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}} Σa iA i) 论文作者运用了K-head模块来估计。

即: Σ a ⃗ − i ∈ A ⃗ − i [ Q i π i ( s , a i , a ⃗ − i ) ] \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}}[Q_{i}^{\pi_{i}}(s,a_{i},\vec{a}_{-i})] Σa iA i[Qiπi(s,ai,a i)]
可以写成(约等于)
∑ k = 1 K Q i k ( s , a i ∣ a ⃗ − i ; w i ) \sum_{k=1}^KQ_i^k(s,a_i|\vec{a}_{-i};w_i) k=1KQik(s,aia i;wi)
wi为critic网络的参数。

作者这里说
这里由于在生成 Q i k ( s , a i ∣ a ⃗ − i ; w i ) Q_i^k(s,a_i|\vec{a}_{-i};w_i) Qik(s,aia i;wi)时输入只有s和a_i,而a_-i,用一个额外的隐层h来实现,所以并没有写成 Q i k ( s , a i , a ⃗ − i ; w i ) Q_i^k(s,a_i,\vec{a}_{-i};w_i) Qik(s,ai,a i;wi)

然而我查看了代码,感觉写成后者也没问题,因为在生成时,两者都输入了,且分别用了H和h的隐层。不过我支持作者这样的写法,原因稍后说。

而关于 Σ a ⃗ − i ∈ A ⃗ − i [ π ⃗ − i ( a ⃗ − i ∣ s ) ] \Sigma_{\vec{a}_{-i}\in\vec{A}_{-i}}[\vec{\pi}_{-i}(\vec{a}_{-i}|s)] Σa iA i[π i(a is)]的估计
则是作者富有想象力的体现了。

在这里插入图片描述
但是这里作者并没写为什么,所以一直很难搞懂。

作者使用 W i ( w i ) W_i(w_i) Wi(wi)来近似所有其他智能体的动作概率分布,即 π ⃗ − i ( A ⃗ − i ∣ s ) \vec{\pi}_{-i}(\vec{A}_{-i}|s) π i(A is),此时问题变为如何近似 W i ( w i ) W_i(w_i) Wi(wi)。而注意力机制天然适合生成概率分布。

最后Q估计为如下的形式:

在这里插入图片描述
等式左端的Q其就是在Q的定义里描述的Q。

3.论文里的attention

回顾注意力机制中,注意力权重的部分,即找出Q和K序列的相似的权重的序列。
其概率分布和为1,且能动态调整。

假设3个同质智能体,状态维度均为12,动作维度均为1。
此时作者让Q的查询的部分为其余智能体的动作组合(2),K的键的部分为所有智能体状态和当前智能体动作(37)( Q i k ( s , a i ∣ a ⃗ − i ; w i ) Q_i^k(s,a_i|\vec{a}_{-i};w_i) Qik(s,aia i;wi))。(这里键是客观实在,最后输出时也是只用到了键的s,a_i,所以我支持作者上述写法。)

按照本文上述注意力机制的描述,此时得到了一个2x37(即 W i ( w i ) W_i(w_i) Wi(wi))。即两个1x37,每个(1)其他智能体动作有37个线索,每个线索应该给多少关注[0.11,…,0.2]。

所以此时得到的权重,正好可以看成在维度(2x37)和后面 Q i k Q_i^k Qik(37xhidden)维度相同的 其他智能体动作的概率值。
在这里插入图片描述
注:这里的K-head 和k-head注意力机制不同,这里仅仅是为了拟合多种不同的动作,所以只在键上做了K-head。

至于为什么这里可以拟合成 其他智能体动作的概率(比较难理解,作者也没解释,只是说正好可以这样生成,生成了可以动态调整。)
个人理解:
当前智能体的动作 取决于其他智能体的动作(和环境状态)(1.Q的定义上讲到)。
反过来,
所以其他智能体的动作 取决于当前智能体的动作和所有状态,
也就是说和当前智能体的动作和所有状态相关,
而其他智能体的动作的概率值本来也是和当前智能体的动作和所有状态相关。
两者均相关于同一个东西,那么就可以理解成: A~C B~C => A~B
那么注意力权重就是给出的概率分布,就可以用来近似。
再进一步解释,这里作者在传入Q和K前,分别将数据先进行一个全连接,再进行一个激活函数(hi)处理,可能就是让神经网络尽可能的学习到这个近似。

最后一步,将两者加权求和,这一步和注意力机制的最后一步竟也是一样。
会有一种让人觉得这就是简单在用注意力机制的感觉,不过这里有理有据。

还有一点与attention不一样,注意力机制最后是权重与V加权求和,这里还是和键K相乘,因为这里的K的含义不一样,这里的论文里的K是 Q i k ( s , a i ∣ a ⃗ − i ; w i ) Q_i^k(s,a_i|\vec{a}_{-i};w_i) Qik(s,aia i;wi)

在这里插入图片描述

4.好处

至此,关键创新点,解读完了,其他基本和MADDPG一致,除了代码那里在 更新critic那块还用了和MAAC一样的方法,(此论文发表比MAAC早,应该是早已有的方法)即求损失函数之和来更新critic。

此外,论文还解释了K在动作是离散空间下,不必是|A_-i|,因为只有一小部分的动作是重要的。
在这里插入图片描述

还研究了在连续空间下,K-head也是可以有效作用的。
在这里插入图片描述
最终的好处。
1.关注了其他智能体的动作来更新critic,缓解了环境非稳态的问题。
理由是在原来的算法下,不管其他队友的情况下,总有确定性的概率会导致相同的奖励和相同的下一个状态。

在这里插入图片描述
2.可以动态调节,动作的概率分布,更容易适用于不同策略(由于attention的性质),
即使其他智能体的策略,已经改变,当前的动作价值也不需要改变(因为已经学到了),提供了一个稳定的良好的Q值。
在这里插入图片描述
3.这种方法的近似,相当于以往Q= V+A的近似,比单单的全连接要好。
在这里插入图片描述
之后就是训练曲线展示,在合作导航和捕食者两个基准环境测试(maddpg论文里用的环境)下,证明了比MADDPG好。

实现 和 修改

当然,肯定是好的,因为我也用过了。
两者除了critic网络架构上的区别,其他参数均一致的情况下,且都调整为个人认为理想的网络层数的情况下,效果如下:(论文里的critic网路结构被我修改后的结果,修改前实验效果稍许不如)

结果展示

在这里插入图片描述
黄色为maddpg+attention
红色为maddpg。
可以看出黄色明显优于红色。

原论文代码:github
为方便理解代码,我将其先修改为一般的attention作为block参数共享的代码。
我这里的critic是单个单个更新的。环境是参照prttingzoo搭建的。

注意力机制 基本上是可以即插即用的,x是cat(s,a),只要把自己的network替换成如下形式就行。

一些得根据自己修改的地方
智能体数目:3
状态维度:12
动作维度:1
且这里的三个都是同样的状态维度和动作维度

论文版本的代码,
至于agent_id,agents,只是为了得到当前智能体是第几个智能体agent_id_index,
我这里是agent_id 是智能体的名字’Red-0’
agents是keys是agent_id,values是agent类的字典。

原论文代码 翻改版

class Attention2(nn.Module):
    def __init__(self, encoder_input_dim, decoder_input_dim, hidden_dim, head_count):
        super(Attention2, self).__init__()
        self.fc_encoder_input = nn.Linear(encoder_input_dim, hidden_dim)
        self.fc_encoder_heads = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(head_count)])
        self.fc_decoder_input = nn.Linear(decoder_input_dim, hidden_dim)

    def forward(self, encoder_input, decoder_input):
        ''' encoder_input 由所有智能体的状态和当前智能体动作组成,decoder_input 由其余智能体的动作组成'''
        # encoder_input shape: (batch_size, input_dim)
        encoder_h = F.relu(self.fc_encoder_input(encoder_input))
        # encoder_h shape: (batch_size, hidden_dim)

        encoder_heads = torch.stack([F.relu(head(encoder_h)) for head in self.fc_encoder_heads], dim=0)
        # encoder_heads shape: (head_count, batch_size, hidden_dim)

        # decoder_input shape: (batch_size, input_dim)
        decoder_H = F.relu(self.fc_decoder_input(decoder_input))
        # decoder_H shape: (batch_size, hidden_dim)

        ''' enocde_heads 用作键值对 decoder_H 用作查询 '''
        scores = torch.sum(torch.mul(encoder_heads, decoder_H), dim=2)
        # scores shape: (head_count, batch_size)

        attention_weights = F.softmax(scores.permute(1, 0), dim=1).unsqueeze(2)
        # attention_weights shape: (batch_size, head_count, 1)

        contextual_vector = torch.matmul(encoder_heads.permute(1, 2, 0), attention_weights).squeeze()
        # contextual_vector shape: (batch_size, hidden_dim)

        return contextual_vector

class MLPNetworkWithAttention2(nn.Module):
    def __init__(self, in_dim, out_dim,hidden_dim = 128 ,head_count = 8 ):
        super(MLPNetworkWithAttention2, self).__init__()
        #self.args = args # 3为智能体个数 12为状态维度 1为动作维度 
        self.fc_obs = nn.Linear(12, hidden_dim) 
        self.fc_action = nn.Linear(1, hidden_dim)
        self.attention_modules = Attention2(hidden_dim * (3 + 1), hidden_dim * (3 - 1),hidden_dim, head_count)  #3为智能体数量
        self.fc_qvalue = nn.Linear(hidden_dim, out_dim) 

    def forward(self, x, agent_id, agents):
        agent_id_list = list(agents.keys())
        agent_id_index = agent_id_list.index(agent_id) #获取agent_id在agents中的索引 按照顺序排
        agent_n = len(agent_id_list) #智能体数量3 #12为state_dim #3*12=36
        
        out_obs_list = [F.relu(self.fc_obs(x[:,:12])) , F.relu(self.fc_obs(x[:,12:24])) , F.relu(self.fc_obs(x[:,24:36]))]               
        # out_obs_list shape: [(batch_size, hidden_dim), ...] #即 batch_size * hidden_dim * agent_count

        out_action_list = [F.relu(self.fc_action(x[:,36:37])) , F.relu(self.fc_action(x[:,37:38])) , F.relu(self.fc_action(x[:,38:39]))]
        # out_action_list shape: [(batch_size, hidden_dim), ...]

        encoder_input = torch.cat(out_obs_list + [out_action_list[agent_id_index]], dim=1)
        # encoder_input shape: (batch_size, hidden_dim * (agent_count + 1))

        decoder_input = torch.cat(out_action_list[:agent_id_index] + out_action_list[agent_id_index+1:], dim=1)
        # decoder_input shape: (batch_size, hidden_dim * (agent_count - 1))

        contextual_vector = self.attention_modules(encoder_input, decoder_input)
        # contextual_vector shape: (batch_size, hidden_dim)

        qvalue = self.fc_qvalue(contextual_vector)
        # qvalue shape: (batch_size, 1)

        return qvalue

修改后

我修改后的,
由于,我发现这里注意力机制中的隐层维度会随着 智能体数量的提高 而变高,可能会造成过拟合的现象,以及认为传入Q,K的数据,不需要进行relu的操作,因为在attention机制里已有一层relu,故修改如下:

## 注意力机制改2_ --Modelling the Dynamic Joint Policy of Teammates with Attention Multi-agent DDPG 论文 改版
class Attention2_(nn.Module):
    def __init__(self, encoder_input_dim, decoder_input_dim, hidden_dim, head_count):
        super(Attention2_, self).__init__()
        self.fc_encoder_input = nn.Linear(encoder_input_dim, hidden_dim)
        self.fc_encoder_heads = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(head_count)]) ##
        self.fc_decoder_input = nn.Linear(decoder_input_dim, hidden_dim)

    def forward(self, encoder_input, decoder_input):
        ''' encoder_input 由所有智能体的状态和当前智能体动作组成,decoder_input 由其余智能体的动作组成'''
        # encoder_input shape: (batch_size, input_dim)
        encoder_h = F.relu(self.fc_encoder_input(encoder_input))
        # encoder_h shape: (batch_size, hidden_dim)

        encoder_heads = torch.stack([F.relu(head(encoder_h)) for head in self.fc_encoder_heads], dim=0)
        # encoder_heads shape: (head_count, batch_size, hidden_dim)

        # decoder_input shape: (batch_size, input_dim)
        decoder_H = F.relu(self.fc_decoder_input(decoder_input))
        # decoder_H shape: (batch_size, hidden_dim)
        ''' enocde_heads 用作键值对 decoder_H 用作查询 '''
        scores = torch.sum(torch.mul(encoder_heads, decoder_H), dim=2)
        # scores shape: (head_count, batch_size) <- before sum (head_count, batch_size, hidden_dim) 

        attention_weights = F.softmax(scores.permute(1, 0), dim=1).unsqueeze(2)
        # attention_weights shape: (batch_size, head_count, 1)

        contextual_vector = torch.matmul(encoder_heads.permute(1, 2, 0), attention_weights).squeeze()
        # contextual_vector shape: (batch_size, hidden_dim)

        return contextual_vector
        
class MLPNetworkWithAttention2_(nn.Module):
    def __init__(self, in_dim, out_dim , hidden_dim = 128 ,head_count = 8 ):
        '''
        在Attention2中 hidden_dim = 128 ,head_count = 8  效果最好
        '''
        super(MLPNetworkWithAttention2_, self).__init__()
        '''
        #self.args = args # 3为智能体个数 12为状态维度 1为动作维度 
        self.fc_obs = nn.Linear(12, hidden_dim) 
        self.fc_action = nn.Linear(1, hidden_dim)
        '''
        self.attention_modules = Attention2_(hidden_dim , hidden_dim ,hidden_dim, head_count) 
        self.fc_qvalue = nn.Linear(hidden_dim, out_dim) 

        #所有智能体的状态和当前智能体动作 维度
        self.fc1 = torch.nn.Linear(37, hidden_dim)
        #其余智能体的动作 维度
        self.fc2 = torch.nn.Linear(2, hidden_dim)
        self.fc3 = torch.nn.Linear(hidden_dim, hidden_dim)
        

    def forward(self, x,agent_id,agents):
        agent_id_list = list(agents.keys())
        agent_id_index = agent_id_list.index(agent_id) #获取agent_id在agents中的索引 按照顺序排
        agent_n = len(agent_id_list) #智能体数量 #12为state_dim #3*12=36

        '''改
        out_obs_list = [F.relu(self.fc_obs(x[:,:12])) , F.relu(self.fc_obs(x[:,12:24])) , F.relu(self.fc_obs(x[:,24:36]))]               
        # out_obs_list shape: [(batch_size, hidden_dim), ...] #即 batch_size * hidden_dim * agent_count
        out_action_list = [F.relu(self.fc_action(x[:,36:37])) , F.relu(self.fc_action(x[:,37:38])) , F.relu(self.fc_action(x[:,38:39]))]
        # out_action_list shape: [(batch_size, hidden_dim), ...]
        encoder_input = torch.cat(out_obs_list + [out_action_list[agent_id_index]], dim=1)
        # encoder_input shape: (batch_size, hidden_dim * (agent_count + 1))
        decoder_input = torch.cat(out_action_list[:agent_id_index] + out_action_list[agent_id_index+1:], dim=1)
        # decoder_input shape: (batch_size, hidden_dim * (agent_count - 1))
        '''

        # 所有智能体的动作对应列
        action_list = [x[:,36:37],x[:,37:38],x[:,38:39]]     
        encoder_input = self.fc1(torch.cat((x[:,:self.all_obs_dim],action_list[agent_id_index]),1)) #batch_size * 37 -> batch_size * hidden_dim
        decoder_input = self.fc2(torch.cat((action_list[:agent_id_index]+action_list[agent_id_index+1:]),1)) #batch_size * 2 -> batch_size * hidden_dim

        # 要满足 encoder_input shape: (batch_size, hidden_dim) decoder_input shape: (batch_size, hidden_dim) 
        contextual_vector = self.attention_modules(encoder_input, decoder_input)
        # contextual_vector shape: (batch_size, hidden_dim)
        t1 = F.relu(self.fc3(contextual_vector))

        qvalue = self.fc_qvalue(t1)
        # qvalue shape: (batch_size, 1)

        return qvalue

原maddpg代码

class MLPNetwork(nn.Module):
    def __init__(self, in_dim, out_dim,hidden_dim_1=256, hidden_dim_2=128,non_linear=nn.ReLU()):
        super(MLPNetwork, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim_1),
            non_linear,
            nn.Linear(hidden_dim_1, hidden_dim_2),
            non_linear,
            nn.Linear(hidden_dim_2, out_dim),
        ).apply(self.init) #apply(self.init)是在初始化模块的权重和偏置时调用init方法

    @staticmethod
    def init(m):
        """init parameter of the module"""
        gain = nn.init.calculate_gain('relu') #zh-cn:计算增益
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight, gain=gain)#这行代码使用 Xavier 均匀分布初始化方法来初始化模块的权重(m.weight)。Xavier 初始化方法旨在使得网络各层的激活值和梯度的方差在传播过程中保持一致,有助于加速网络的收敛。gain 参数是根据 ReLU 激活函数的特性调整的。
            m.bias.data.fill_(0.01) #zh-cn:这行代码使用常数 0.01 来初始化模块的偏置(m.bias)。

    def forward(self, x):
        return self.net(x)

我是把上述代码替换为上述修改版,运行代码得到的结果展示。

效果确实是有的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1948446.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Maven的概念

1.什么是Maven 1.1.什么是Maven Maven是跨平台的项目管理工具&#xff0c;主要服务于基于Java平台的项目构建、依赖管理以及项目信息管理。 1.2.什么是理想的项目构建 高度自动化&#xff0c;标准化&#xff0c;跨平台&#xff0c;可重用的组件 1.3.什么是依赖&#xff0c…

wget下载github文件得到html文件

从github/gitee下载源文件&#xff0c;本来是22M下载下来只有11k 原因&#xff1a; Github会提供html页面&#xff0c;包括指定的文件、上下文与相关操作。通过wget或者curl下载时&#xff0c;会下载该页面 解决方式&#xff1a; github点击Code一栏的raw按钮&#xff0c;获得源…

HTTPS证书价格差异体现在哪?

HTTPS证书作为保障网站安全的重要工具&#xff0c;其类型、功能和费用差异成为用户选择时的关键考量因素。本文将深入探讨HTTPS证书的不同类型、功能以及费用差异&#xff0c;以帮助用户做出更合适的选择。 HTTPS证书的类型 HTTPS证书主要分为三种&#xff1a;DV&#xff08;D…

24证券从业考试报名『个人信息表』填写模板❗

24证券从业考试报名『个人信息表』填写模板❗ 1️⃣居住城市、通讯地址&#xff1a;写自己现居住的地址就可以。 2️⃣学历&#xff1a;需要注意的是学历填写的是考生已经取得的学历&#xff0c;在校大学生已经不具有报名资格&#xff0c;选择大专以上&#xff0c;或者是高中学…

【轨物方案】成套开关柜在线监测物联网解决方案

随着物联网技术的发展&#xff0c;电力设备状态监测技术也得到了迅速发展。传统的电力成套开关柜设备状态监测方法主要采用人工巡检和定期维护的方式&#xff0c;这种方法不仅效率低下&#xff0c;而且难以保证设备的实时性和安全性。因此&#xff0c;基于物联网技术的成套开关…

ARM架构(二)—— arm v7-a/v8/v9寄存器介绍

1、ARM v7-A寄存器 1.1 通用寄存器 V7 V8开始 FIQ个IRQ优先级一样&#xff0c; 通用寄存器&#xff1a;31个 1.2 程序状态寄存器 CPSR是程序状态毒存器&#xff0c;保存条件标志位&#xff0c;中断禁止位&#xff0c;当前处理器模式等控制和状态位。每种异常模式下还存在SPS…

MySQL之索引及简单运用

索引&#xff1a; 什么是索引 索引是数据库中一种非常重要的数据结构&#xff0c;用于帮助快速查询数据库表中的数据。它就像一本书的目录&#xff0c;能够让你快速定位到书中的某个具体章节或内容&#xff0c;而不需要一页一页地翻阅整本书。 在数据库管理系统中&#xff0c;…

Servlet 3.0的新特征

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhlServlet 3.0概述 Servlet 3.0规范是在2009年随着Java EE 6的发布而推出的。它引入了一系列新特性和改进,旨在简化Web应用的开发和部署过程,并提高Web应用的性能和可扩展性。Servlet 3.0的发布标…

C++ | Leetcode C++题解之第279题完全平方数

题目&#xff1a; 题解&#xff1a; class Solution { public:// 判断是否为完全平方数bool isPerfectSquare(int x) {int y sqrt(x);return y * y x;}// 判断是否能表示为 4^k*(8m7)bool checkAnswer4(int x) {while (x % 4 0) {x / 4;}return x % 8 7;}int numSquares(i…

河南萌新联赛2024第(二)场---G. lxy的通风报信

题目描述 在 nnn mmm 的平面星球里&#xff08; 5<n5 < n5<n&#xff0c; m<1000m < 1000m<1000 &#xff09;&#xff0c;存在着 aaa 支军队 ( 2<a<502 < a < 502<a<50 ) , 和 bbb 支驻扎在地图中的敌军 ( 0<b<n∗m−a0 < b &…

Python写UI自动化--playwright(通过UI文本匹配实现定位)

本篇简单拓展一下元素定位技巧&#xff0c;通过UI界面的文本去实现定位 目录 匹配XPath 匹配文本元素 .count()统计匹配数量 处理匹配文本返回多个元素 1、使用.nth(index)选择特定元素: 2、获取所有匹配的元素并遍历: 3、错误处理: 匹配XPath 比如我们要定位到下图的…

SpringBoot 项目配置文件注释乱码的问题解决方案

一、问题描述 在项目的配置文件中&#xff0c;我们写了一些注释&#xff0c;如下所示&#xff1a; 但是再次打开注释会变成乱码&#xff0c;如下所示&#xff1a; 那么如何解决呢&#xff1f; 二、解决方案 1. 点击” File→Setting" 2. 搜索“File Encodings”, 将框…

JCR一区级 | Matlab实现TTAO-Transformer-LSTM多变量回归预测

JCR一区级 | Matlab实现TTAO-Transformer-LSTM多变量回归预测 目录 JCR一区级 | Matlab实现TTAO-Transformer-LSTM多变量回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.【JCR一区级】Matlab实现TTAO-Transformer-LSTM多变量回归预测&#xff0c;三角拓扑聚合…

【C++进阶学习】第八弹——红黑树的原理与实现——探讨树形结构存储的最优解

二叉搜索树&#xff1a;【C进阶学习】第五弹——二叉搜索树——二叉树进阶及set和map的铺垫-CSDN博客 AVL树&#xff1a; ​​​​​​【C进阶学习】第七弹——AVL树——树形结构存储数据的经典模块-CSDN博客 前言&#xff1a; 在前面&#xff0c;我们已经学习了二叉搜索树和…

Sentinel限流规则详解

上一期教程讲解了 Sentinel 的快速入门&#xff1a;Sentinel快速入门&#xff0c;这一期主要讲述 Sentinel 的限流规则 簇点链路 簇点链路就是项目内的调用链路&#xff08;Controller -> Service -> Mapper&#xff09;&#xff0c;链路中被监控的每个接口就是一个资源…

论文阅读【检测】:商汤 ICLR2021 | Deformable DETR

文章目录 论文地址AbstractMotivation技术细节多尺度backbone特征MSDeformAttention 小结 论文地址 Deformable DETR 推荐视频&#xff1a;bilibili Abstract DETR消除对目标检测中许多手工设计的组件的需求&#xff0c;同时表现出良好的性能。然而&#xff0c;由于Transfor…

微服务案例搭建

案例搭建 使⽤微服务架构的分布式系统,微服务之间通过⽹络通信。我们通过服务提供者与服务消费者来描述微服 务间的调⽤关系。 服务提供者&#xff1a;服务的被调⽤⽅&#xff0c;提供调⽤接⼝的⼀⽅ 服务消费者&#xff1a;服务的调⽤⽅&#xff0c;依赖于其他服务的⼀⽅ 我…

pyenv-win | python版本管理,无需卸载当前版本

系统&#xff1a;windows&#xff0c;且已安装git。 使用 pyenv-win 在Windows中管理多个python版本&#xff0c;而无需卸载当前版本。安装步骤如下&#xff1a; 安装 pyenv-win 1. 安装 Git 和 pyenv-win: git clone https://github.com/pyenv-win/pyenv-win.git %USERPRO…

河南萌新联赛2024第(二)场:南阳理工学院

河南萌新联赛2024第&#xff08;二&#xff09;场&#xff1a;南阳理工学院 2024.7.24 13:00————15:00 过题数5/11 补题数6/11 国际旅行Ⅰ 国际旅行Ⅰ 小w和大W的决斗。 A*BBBB “好”字符 水灵灵的小学弟 lxy的通风报信 狼狼的备忘录 重生之zbk要拿回属于他的一切 这是签…

THS配置keepalive(yjm)

启动完THS管理控制台和THS后&#xff0c;登录控制台&#xff0c;进入实例管理》节点管理&#xff0c;可以分别使用界面配置和编辑配置设置长连接。 1、界面配置 点击界面配置》集群设置&#xff0c;启用长连接&#xff0c;设置长连接数、最大请求数和超时时间。 2、编辑配置 …