ActionNet
类是一个表示代理策略的神经网络模型。该模型使用多个图神经网络层来处理输入数据,主要用于强化学习或图数据中的任务。在前向传播过程中,模型会对输入特征和边的属性进行一系列的图卷积操作,并逐层对输出进行激活和 Dropout 操作,最终返回一个预测结果。
from models.action import ActionNet
import torch.nn as nn
from torch_geometric.typing import Adj, OptTensor
from torch import Tensor
from helpers.classes import ActionNetArgs
class ActionNet(nn.Module):
def __init__(self, action_args: ActionNetArgs):
"""
Create a model which represents the agent's policy.
"""
super().__init__()
self.num_layers = action_args.num_layers
self.net = action_args.load_net()
self.dropout = nn.Dropout(action_args.dropout)
self.act = action_args.act_type.get()
def forward(self, x: Tensor, edge_index: Adj, env_edge_attr: OptTensor, act_edge_attr: OptTensor) -> Tensor:
edge_attrs = [env_edge_attr] + (self.num_layers - 1) * [act_edge_attr]
for idx, (edge_attr, layer) in enumerate(zip(edge_attrs[:-1], self.net[:-1])):
x = layer(x=x, edge_index=edge_index, edge_attr=edge_attr)
x = self.dropout(x)
x =