Progressive Feature Fusion Framework Based on Graph Convolutional Network

news2024/11/28 8:23:10

以Resnet50作为主干网络,然后使用GCN逐层聚合多级特征,逐级聚合这种模型架构早已不新鲜,这篇文章使用GCN的方式对特征进行聚合,没有代码。这篇文章没有过多的介绍如何构造的节点特征和邻接矩阵,我觉得对于图卷积来说,最重要的一点就是确定那些特征作为图节点以及节点直接的连接关系。

很多方法是直接将特征图的每个像素作为一个节点,那这样的话怎么确定每个像素之间的连接关系呢?

对于邻接矩阵来说,两个节点相连置为一,两个节点不相连置为零,通过将节点矩阵和邻接矩阵进行相乘来进行节点之间的信息交互。这种交互是只要两个节点之间相连就将两个节点的特征值进行相加。

这种直接相加的方式忽略了节点与节点之间的重要程度,可以使用图注意力来给图的节点与节点之间施加一个权重,这个权重可以通过自注意力的方式得到,也可以通过图注意力网络中的计算方式得到节点与节点之间的权重关系。图注意力网络的代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nx
 
 
def get_weights(size, gain=1.414):
    weights = nn.Parameter(torch.zeros(size=size))
    nn.init.xavier_uniform_(weights, gain=gain)
    return weights
 
class GraphAttentionLayer(nn.Module):
    '''
    Simple GAT layer 图注意力层 (inductive graph)
    '''
    def __init__(self, in_features, out_features, dropout, alpha, concat = True, head_id = 0):
        ''' One head GAT '''
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features  #节点表示向量的输入特征维度
        self.out_features = out_features    #节点表示向量的输出特征维度
        self.dropout = dropout  #dropout参数
        self.alpha = alpha  #leakyrelu激活的参数
        self.concat = concat    #如果为true,再进行elu激活
        self.head_id = head_id  #表示多头注意力的编号
 
        self.W_type = nn.ParameterList()
        self.a_type = nn.ParameterList()
        self.n_type = 1 #表示边的种类
        for i in range(self.n_type):
            self.W_type.append(get_weights((in_features, out_features)))
            self.a_type.append(get_weights((out_features * 2, 1)))
 
        #定义可训练参数,即论文中的W和a
        self.W = nn.Parameter(torch.zeros(size = (in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain = 1.414)  #xavier初始化
        self.a = nn.Parameter(torch.zeros(size = (2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain = 1.414)  #xavier初始化
 
        #定义dropout函数防止过拟合
        self.dropout_attn = nn.Dropout(self.dropout)
        #定义leakyrelu激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)
 
    def forward(self, node_input, adj, node_mask = None):
        '''
        node_input: [batch_size, node_num, feature_size] feature_size 表示节点的输入特征向量维度
        adj: [batch_size, node_num, node_num] 图的邻接矩阵
        node_mask:  [batch_size, node_mask]
        '''
 
        zero_vec = torch.zeros_like(adj)
        scores = torch.zeros_like(adj)
 
        for i in range(self.n_type):
            h = torch.matmul(node_input, self.W_type[i])
            h = self.dropout_attn(h)
            N, E, d = h.shape   # N == batch_size, E == node_num, d == feature_size
 
            a_input = torch.cat([h.repeat(1, 1, E).view(N, E * E, -1), h.repeat(1, E, 1)], dim = -1)
            a_input = a_input.view(-1, E, E, 2 * d)     #([batch_size, E, E, out_features])
 
            score = self.leakyrelu(torch.matmul(a_input, self.a_type[i]).squeeze(-1))   #([batch_size, E, E, 1]) => ([batch_size, E, E])
            #图注意力相关系数(未归一化)
 
            zero_vec = zero_vec.to(score.dtype)
            scores = scores.to(score.dtype)
            scores += torch.where(adj == i+1, score, zero_vec.to(score.dtype))
 
        zero_vec = -1*30 * torch.ones_like(scores)  #将没有连接的边置为负无穷
        attention = torch.where(adj > 0, scores, zero_vec.to(scores.dtype))    #([batch_size, E, E])
        # 表示如果邻接矩阵元素大于0时,则两个节点有连接,则该位置的注意力系数保留;否则需要mask并置为非常小的值,softmax的时候最小值不会被考虑
 
        if node_mask is not None:
            node_mask = node_mask.unsqueeze(-1)
            h = h * node_mask   #对结点进行mask
 
        attention = F.softmax(attention, dim = 2)   #[batch_size, E, E], softmax之后形状保持不变,得到归一化的注意力权重
        h = attention.unsqueeze(3) * h.unsqueeze(2) #[batch_size, E, E, d]
        h_prime = torch.sum(h, dim = 1)             #[batch_size, E, d]
 
        # h_prime = torch.matmul(attention, h)    #[batch_size, E, E] * [batch_size, E, d] => [batch_size, N, d]
 
        #得到由周围节点通过注意力权重进行更新的表示
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
 
class GAT(nn.Module):
    def __init__(self, in_dim, hid_dim, dropout, alpha, n_heads, concat = True):
        '''
        Dense version of GAT
        in_dim输入表示的特征维度、hid_dim输出表示的特征维度
        n_heads 表示有几个GAL层,最后进行拼接在一起,类似于self-attention从不同的子空间进行抽取特征
        '''
        super(GAT, self).__init__()
        assert hid_dim % n_heads == 0
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat
 
        self.attn_funcs = nn.ModuleList()
        for i in range(n_heads):
            self.attn_funcs.append(
                #定义multi-head的图注意力层
                GraphAttentionLayer(in_features = in_dim, out_features = hid_dim // n_heads,
                                    dropout = dropout, alpha = alpha, concat = concat, head_id = i)
            )
 
        self.dropout = nn.Dropout(self.dropout)
 
    def forward(self, node_input, adj, node_mask = None):
        '''
        node_input: [batch_size, node_num, feature_size]    输入图中结点的特征
        adj:    [batch_size, node_num, node_num]    图邻接矩阵
        node_mask:  [batch_size, node_num]  表示输入节点是否被mask
        '''
        hidden_list = []
        for attn in self.attn_funcs:
            h = attn(node_input, adj, node_mask = node_mask)
            hidden_list.append(h)
 
        h = torch.cat(hidden_list, dim = -1)
        h = self.dropout(h) #dropout函数防止过拟合
        x = F.elu(h)     #激活函数
        return x
 
 
#特征矩阵
x = torch.randn((2, 4, 8))
#邻接矩阵
adj = torch.tensor([[[0, 1, 0, 1],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [1, 0, 1, 0]]])
adj = adj.repeat(2, 1, 1)
#mask矩阵
node_mask = torch.Tensor([[1, 0, 0, 1],
                          [0, 1, 1, 1]])
 
 
gat_layer = GraphAttentionLayer(in_features = 8, out_features = 8, dropout = 0.1, alpha = 0.2, concat = True)  #输入特征维度8, 输出特征维度8, 使用多头注意力机制
gat_ = GAT(in_dim = 8, hid_dim = 8, dropout = 0.1, alpha = 0.2, n_heads = 2, concat = True)    #输入特征维度8, 输出特征维度8, 使用多头注意力机制
 
output_ = gat_(x, adj, node_mask)
print(output_.shape)  
 
output_ = gat_(x, adj, node_mask)
print(output_.shape)
 
 
#输出:
torch.Size([2, 4, 8])
torch.Size([2, 4, 8])

自注意力和图注意力在计算节点之间权重的方式稍有不同,在自注意力的计算方式中之进行了矩阵相乘并没有可训练的参数。在图注意力计算节点之间权重时,采用了线性映射的方式,这两种权重计算方式那个更好一点还要通过实验来进行验证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1802626.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Swift 序列(Sequence)排序面面俱到 - 从过去到现在(三)

概述 在上一篇 Swift 序列(Sequence)排序面面俱到 - 从过去到现在(二) 博文中,我们介绍了如何构建一个自定义类型中“多属性”排序的通用实现。 而在本课中我们将再接再厉介绍 iOS 15+ 中新的排序机制,并简要剖析就地排序(In-place sorting)对运行性能有着怎样的显著影…

数据挖掘--数据预处理

数据挖掘--引论 数据挖掘--认识数据 数据挖掘--数据预处理 数据挖掘--数据仓库与联机分析处理 数据挖掘--挖掘频繁模式、关联和相关性:基本概念和方法 数据挖掘--分类 数据挖掘--聚类分析:基本概念和方法 数据清理 缺失值 如果数据集含有分类属性…

阿里云(域名解析) certbot 证书配置

1、安装 certbot ubuntu 系统: sudo apt install certbot 2、申请certbot 域名证书,如申请二级域名aa.example.com 的ssl证书,同时需要让 bb.aa.example.com 也可以使用此证书 1、命令:sudo certbot certonly -d “域名” -d “…

聪明人社交的基本顺序:千万别搞反了,越早明白越好

聪明人社交的基本顺序:千万别搞反了,越早明白越好 国学文化 德鲁克博雅管理 2024-03-27 17:00 作者:方小格 来源:国学文化(gxwh001) 导语 比一个好的圈子更重要的,是自己优质的能力。 唐诗宋…

c++编译器在什么情况下会提供类的默认构造函数等,与析构函数

我们都知道,在 c 里,编写的简单类,若没有自己编写构造析构函数与 copy 构造函数 与 赋值运算符函数,那么编译器会提供这些函数,并实现简单的语义,比如成员赋值。看 源码时,出现了下图类似的情形…

《C++避坑神器·二十七》VS中release打断点方法,#undef作用

1、release打断点方式 2、#undef作用 #undef指令用于”取消“已定义的#define指令 案例:

小主机折腾记24

好久不更新,最近折腾的事如下 1.10块钱自提买了个半高机箱,15086140,把之前拆机的H61m-A/M32AA/DP_MB与200w航嘉电源装了进去,额外买了半高pcie转接了个m2位,江波龙64g安装了win10专业版,最后卖了176块钱&a…

连山露【诗词】

连山露 雾隐黄山路,十步一松树。 树上惊松鼠,松子衔木屋。 松子青嫩芽,尖尖头探出。 卷挂白露珠,装映黄山雾。

UML实战-BUG管理系统

概述 根据 UML建模的过程来进行一个完整系统的设计–Bug 管理系统。下面是一个标注 UML 设计过程的参考。 需求分析:用例图。系统分析:分析业务规则–状态图。系统分析:分析业务流程–活动图。系统设计:设计静态结构–类图和包图。系统设计:Action类被调用关系–序列图。…

检测五个数是否一样的算法

目录 算法算法的输出与打印效果输出输入1输入2 打印打印1打印2 算法的流程图总结 算法 int main() {int arr[5] { 0 };int i 0;int ia 0;for (i 0; i < 5; i) { scanf("%d", &arr[i]); }for (i 1; i < 5; i) {if (arr[0] ! arr[i]) {ia 1;break;} }…

Linux-常用命令-常用设置

1.帮助类命令 1.man命令-获得帮助信息 man [命令或配置文件]例&#xff1a;查看ls命令的帮助信息 man ls输入 ZZ 退出帮助2.服务管理类命令 1.centos7语法 1.1 临时开关服务命令 开启服务&#xff1a; systemctl start 服务名 关闭服务&#xff1a; systemctl stop 服务…

Javaweb---HTTPS

题记 为了保护数据的隐私性我们引入了HTTPS 加密的方式都有那些呢? 1.对称加密: 加密和解密使用的密钥是同一个密钥 2.非对称加密:有两个密钥(一对),分为公钥和私钥(公钥是公开的,私钥是要藏好的) HTTPS的工作过程(旨在对body和header进行加密) 1.对称加密 上述引出的…

两张图片进行分析

两张图片进行分析&#xff0c;可以拖动左边图片进行放大、缩小查看图片差异 底图 <template><div class"box_container"><section><div class"" v-for"item in imgData.imgDataVal" :key"item.id"><img :s…

Kafka监控系统efak的安装

下载地址Kafka Eaglehttp://download.kafka-eagle.org/下载地址连接不稳定&#xff0c;可以多次尝试直到成功连接下载 1.解压安装包并重命名 tar -zxvf kafka-eagle-bin-3.0.1.tar.gz 查看到解压后包含一个安装包&#xff0c;再解压 tar -zxvf efak-web-3.0.1-bin.tar.gz 移…

小程序简单版录音机

先来看看效果 结构 先来看看页面结构 <!-- wxml --><view class"wx-container"><view id"title">录音机</view><view id"time">{{hours}}:{{minute}}:{{second}}</view><view class"btngroup"…

【JavaSE】面向对象---多态

前言 本篇以Java初学者视角写下&#xff0c;难免有不足&#xff0c;或者术语不严谨之处。如有错误&#xff0c;欢迎评论区指正。本篇说明多态相关的知识。若本文无法解决您的问题&#xff0c;可以去最下方的参考文献出&#xff0c;找出想要的答案。 多态概念 多态&#xff08…

【Ardiuno】实验使用ESP32连接Wifi(图文)

ESP32最为精华和有特色的地方当然是wifi连接&#xff0c;这里我们就写程序实验一下适使用ESP32主板连接wifi&#xff0c;为了简化实验我们这里只做了连接部分&#xff0c;其他实验在后续再继续。 由于本实验只要在串口监视器中查看结果状态即可&#xff0c;因此电路板上无需连…

最短路径——迪杰斯特拉与弗洛伊德算法

一.迪杰斯特拉算法 首先对于最短路径来说&#xff1a;从vi-vj的最短路径&#xff0c;不用非要经过所有的顶点&#xff0c;只需要找到路径最短的路径即可&#xff1b; 那么迪杰斯特拉的算法&#xff1a;其实也就与最小生成树的思想类似&#xff0c;找到较小的&#xff0c;然后…

在网上赚钱,可以自由掌控时间,灵活的兼职副业选择

朋友们看着周围的人在网上赚钱&#xff0c;自己也会为之心动&#xff0c;随着电子设备的普及&#xff0c;带动了很多的工作、创业以及兼职副业选择的机会&#xff0c;作为普通人的我们&#xff0c;如果厌倦了世俗的朝九晚五&#xff0c;想着改变一下自己的生活&#xff0c;可以…

STM32 printf 重定向到CAN

最近在调试一款电机驱动板 使用的是CAN总线而且板子上只有一个CAN 想移植Easylogger到上面试试easylogger的效果&#xff0c;先实现pritnf的重定向功能来打印输出 只需要添加以下代码即可实现 代码 #include <stdarg.h> uint8_t FDCAN_UserTxBuffer[512]; void FDCAN_p…