[图神经网络]空间关系感知关系网络(SGRN)-代码解析

news2024/10/6 12:27:40

!!!这篇不涉及实现,仅从官方代码了解一下输出处理的思路,有机会的话会做实现,照例放出官方代码地址和之前写的论文解读:

SGRN网络github项目地址icon-default.png?t=N2N8https://github.com/simblah/SGRN_torch[图神经网络]空间关系感知关系网络(SGRN)-论文解读icon-default.png?t=N2N8https://blog.csdn.net/weixin_37878740/article/details/129837774?spm=1001.2014.3001.5501

一、网络框架回顾

         可以很明显的看到,论文提及的SGRN网络是在Faster R-CNN中嵌入了关系图学习器空间感知推理模块,且将RoI Pooling改为RoI Align,故我们重点来看这几个部分的代码。

二、代码解读

        所解读的代码位于项目的:lib->nets->network_gcn中。前向传递函数的结构如下:

    def forward(self, image, im_info, gt_boxes=None, mode='TRAIN'):
        #.....
        rois, cls_prob, bbox_pred = self._predict()
        #....

        前向传递函数中其余代码均为数据预处理和损失函数计算,其次rois,cls_prob和bbox_pred均由self._predict( )函数计算得出。故跳转到_predict( )函数

# This is just _build_network in tf-faster-rcnn
net_conv = self._image_to_head()          #骨干网络提取

# build the anchors for the image
self._anchor_component(net_conv.size(2), net_conv.size(3))

rois = self._region_proposal(net_conv)    #获取候选区域
if cfg.POOLING_MODE == 'align':           #兴趣域池化
    pool5 = self._roi_align_layer(net_conv, rois)
else:
    pool5 = self._roi_pool_layer(net_conv, rois)

fc7 = self._head_to_tail(pool5)

cls_prob, bbox_pred = self._region_classification(fc7)

         这是原属于Faster R-CNN的代码,修改了RoI Pooling,对应图示的这个部分:

        1.RoI Align

def _roi_align_layer(self, bottom, rois):
    return RoIAlign((cfg.POOLING_SIZE, cfg.POOLING_SIZE), 1.0 / 16.0,0)(bottom, rois)
def _roi_pool_layer(self, bottom, rois):
    return RoIPool((cfg.POOLING_SIZE, cfg.POOLING_SIZE),1.0 / 16.0)(bottom, rois)

        相较于RoI Pool,RoI Align最后多了一个参数,作用是控制双线性插值中采样点的个数,默认值为-1,代码中将其置为0。

        2.关系学习器

        用于从权重中构建图(邻接矩阵),对应图示这个部分:

num_rois = rois.shape[0]    #通道数
z = self.relation_fc_1(fc7)
z = F.relu(self.relation_fc_2(z))
eps = torch.mm(z, z.t())
_, indices = torch.topk(eps, k=32, dim=0)

        其中relation_fc_1()relation_fc_2()均为全连接层,两个全连接层叠加后跟一个relu激活函数,fc7(即建议区域)经过模块处理后得到长度为256的序列z

self.relation_fc_1 = nn.Linear(self._fc7_channels, 256)
self.relation_fc_2 = nn.Linear(256, 256)

        随后经过torch.mm( )<用途是将两个矩阵相乘>,将z与其转置相乘得到邻接矩阵eps

        邻接矩阵eps经过torch.topk( ),torch.topk的作用是找到序列中前k个元素进行排序,其返回值有两个:第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。这个函数的用途是用来构建稀疏图(对应原文取最大的3个值)

        排序规则为:该列最大的数据,以此类推:

tensor1=torch.tensor([  [9,1,2,1,9,1],
                        [3,4,5,1,1,1],
                        [7,8,9,1,1,1],
                        [1,4,7,1,1,2]])

values,indices=torch.topk(tensor1, k=3, dim=0)

print(values)
print(indices)

        实验结果为,可以看到每一排最大的k个数据被提到了前面。

tensor([[9, 8, 9, 1, 9, 2],
        [7, 4, 7, 1, 1, 1],
        [3, 4, 5, 1, 1, 1]])

tensor([[0, 2, 2, 0, 0, 3],
        [2, 1, 3, 1, 1, 0],
        [1, 3, 1, 2, 2, 1]])

        3.空间感知推理模块

                ①得到嵌入嵌入向量

                用于将图中的特征嵌入到图(邻接矩阵)中,对应图示这个部分:

                具体结构如下,由三个分支组成(图/邻接矩阵分类权重预测框,分别代表:连接关系,类别关系,位置关系)

cls_w = self.cls_score_net.weight
represent = torch.mm(cls_prob, cls_w)

                将cls_prob(分类器得出的类型预测权重)和cls_w(fc7经过一个线性分类器)相乘

self.cls_score_net = nn.Linear(self._fc7_channels, self._num_classes)

                 ②得到距离函数

cls_pred = torch.max(cls_prob, 1)[1]
bbox_pred_reshape = bbox_pred.view(-1, 1001, 4)
bbox_pred_cls = torch.zeros(num_rois, 4)
for i, cls in enumerate(cls_pred):
    bbox_pred_cls[i] = bbox_pred_reshape[i][cls]
bbox_pred_ctr = bbox_pred_cls[:, 0:2] + bbox_pred_cls[:, 2:4]

        torch.max()[1]返回的是最大值的索引,实验如下:

x = torch.max(tensor1,1)[1]    #tensor1同上例
print(x)
tensor([0, 2, 2, 2])    #得到每一列最大值的索引

        bbox_pred是由关系分类器处理fc7后得到的目标框预测,使用view对其进行重构,然后遍历cls_pred,取出每列最大的元素(四个坐标);再按照公式(如下)进行计算

                d=\sqrt{(c_a-c_b)^2+(y_a-y_b)^2} ,\theta=arctan(\frac{y_b-y_a}{c_b-c_a})

relation = torch.empty(2, 32 * num_rois, dtype=torch.long).to(self._device)
# U = torch.empty(32*128, 2).to(self._device)

relation[0] = torch.Tensor(list(range(num_rois)) * 32)  # , type=torch.long)
relation[1] = indices.view(-1)

coord_i = bbox_pred_ctr[relation[0]]
coord_j = bbox_pred_ctr[relation[1]]

d = torch.sqrt((coord_i[:, 0] - coord_j[:, 0]) ** 2 + (coord_i[:, 1] - coord_j[:, 1]) ** 2)
#距离
theta = torch.atan2((coord_j[:, 1] - coord_i[:, 1]), (coord_j[:, 0] - coord_i[:, 0]))
#角度

U = torch.stack([d, theta], dim=1).to(self._device)
#位置嵌入

                ③图卷积

                图卷积的定义为,参数分别为:输入图尺寸,输出图尺寸,深度,卷积核尺寸

self.gaussian = GMMConv(self._fc7_channels, self._fc7_channels, dim=2, kernel_size = 25)

                使用此图卷积处理: 图/邻接矩阵分类权重预测框(三个输入实际分别为:图(graph)、特征(feat )、伪坐标(pseudo),其中,特征/feat的尺寸为(N,D_{in})<分别表示:节点数量,输入特征的尺寸>;伪坐标/pseudo的尺寸为(E,D_u)<分别表示:边的条数,伪坐标的维数>)

f = self.gaussian(represent, relation, U)

                使用全连接和激活函数处理图嵌入:

f2 = F.relu(self.sg_conv_1(f))
h = F.relu(self.sg_conv_2(f2))

                其中2个全连接sg_conv为:

self.sg_conv_1 = nn.Linear(self._fc7_channels, 512)
self.sg_conv_2 = nn.Linear(512, 256)

        4.将图与候选区域融合

                将图得到的嵌入向量和候选区域进行融合后得到新的预测框和分类数据

new_f = torch.cat([fc7, h], dim=1)
new_cls_prob, new_bbox_pred = self._new_region_classification(new_f)

for k in self._predictions.keys():
    self._score_summaries[k] = self._predictions[k]

return rois, new_cls_prob, new_bbox_pred

                上面所用到的_new_region_classification( )的实现为:

def _new_region_classification(self, f):
    cls_score = self.new_cls_score_net(f)
    cls_pred = torch.max(cls_score, 1)[1]
    cls_prob = F.softmax(cls_score, dim=1)
    bbox_pred = self.new_bbox_pred_net(f)

    self._predictions["cls_score"] = cls_score
    self._predictions["cls_pred"] = cls_pred
    self._predictions["cls_prob"] = cls_prob
    self._predictions["bbox_pred"] = bbox_pred

    return cls_prob, bbox_pred

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/410951.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用三个云服务器,搭建MongoDB副本集模式(主从模式)

1. 下载安装mongoDB 首先我们需要在三台服务器上分别下载和安装mongoDB。 1.1. 打开服务器&#xff0c;创建目录 创建目录结构如下图所示&#xff1a;&#xff08;日志文件会自动创建&#xff09; 1.2. 下载mongoDB压缩包 把压缩包下载到指定目录&#xff08;便于后期维护…

ChatGPT大规模封号+停止注册?最火概念会凉吗?

一、背景 这个周末&#xff0c;先是意大利暂时封杀ChatGPT&#xff0c;限制OpenAI处理本国用户信息。 接着&#xff0c;据韩国媒体报道&#xff0c;三星导入ChatGPT不到20天&#xff0c;便曝出机密资料外泄&#xff08;涉及半导体设备测量资料、产品良率等内容&#xff0c;已…

微信小程序 | 秋招颗粒无收 ?快用ChatGPT做一款模拟面试小程序

Pre&#xff1a;效果预览 ① 选择职位进行面试 ② 根据岗位职责进行回答 一、需求背景 这两年IT互联网行业进入寒冬期&#xff0c;降本增效、互联网毕业、暂停校招岗位的招聘&#xff0c;各类裁员、缩招的情况层出不穷&#xff01;对于这个市场来说&#xff0c;在经历了互联网…

阿里云的客服 锻炼你心性的 一种方式 !!!

阿里云的产品&#xff0c;非常棒&#xff0c;开发的同学非常棒&#xff0c;专家们更棒&#xff0c;但&#xff0c;一切的开始就怕一个但字&#xff0c;但我还的说&#xff0c;但&#xff0c;阿里云的客服&#xff0c;OMG &#xff0c;我已经忍耐了 1年了&#xff0c;是在忍不住…

手麻系统源码,手术麻醉管理系统源码,维护方便,功能强大

手术麻醉管理系统源码&#xff0c;手麻系统源码&#xff0c;C# .net 桌面软件 C/S版 文末获取联系&#xff01; 手术麻醉管理系统采用下拉式汉化菜单&#xff0c;界面友好&#xff0c;实用性强&#xff0c;设有与住院、病区、药房等系统的软件接口。 开发语言&#xff1a;C# …

代码随想录算法训练营第五十三天 | 1143. 最长公共子序列、1035. 不相交的线、53. 最大子数组和

1143. 最长公共子序列 动规五部曲 1、确定dp数组&#xff08;dp table&#xff09;以及下标的含义 dp[i][j]&#xff1a;长度为[0, i - 1]的字符串text1与长度为[0, j - 1]的字符串text2的最长公共子序列为dp[i][j] 2、确定递推公式 主要就是两大情况&#xff1a; text1[i…

vue+ts+vite+pinia+element plus+i18n国际化

第一步&#xff0c;安装vue-i18n&#xff08;我这里版本是9.2.2&#xff09; npm install vue-i18n element-plus --save 第二步&#xff0c;src文件夹下创建language文件夹&#xff0c;目录如下 第三步&#xff0c;定义本地中文英文 en.ts // en.ts export default {message…

UE DTCmd 插件说明

Exec CMD Exec CMD (Have Process) 在蓝图非阻塞的运行Windows命令行并输出返回值&#xff0c;而且可以时时监听输出内容。 可以直接运行某个程序&#xff08;输入程序完整路径&#xff09; 可以直接运行bat脚本&#xff0c;并在bat脚本里面运行你任何想做的操作。 Cmd : 需要…

花30分钟,我用ChatGPT写了一篇2000字文章(内附实操过程)

有了ChatGPT之后&#xff0c;于我来说&#xff0c;有两个十分明显的变化&#xff1a; 1. 人变的更懒 因为生活、工作中遇到大大小小的事情&#xff0c;都可以直接找ChatGPT来寻求答案。 2. 工作产出量更大 之前花一天&#xff0c;甚至更久才能写一篇原创内容&#xff0c;现…

【MySQL--04】数据类型

文章目录1.数据类型1.1数据类型分类1.2数值类型1.2.1tinyint类型1.2.2bit类型1.2.3小数类型1.2.3.1 float1.2.3.2 decimal1.3字符串类型1.3.1 char1.3.2 varchar1.3.3char和varchar的比较1.4日期和时间类型1.5 enum和set1.5.1 enum1.5.2 set1.5.3 示例1.数据类型 1.1数据类型分…

试题E:蜂巢 ——蓝桥杯第十三届省赛Java 大学A组

试题E&#xff1a;蜂巢 解析 很明显的一道坐标计算问题&#xff0c;只是通过看似比较复杂的描述而已。 题目定义了一种行走方向&#xff0c;大概就是一共六种行走方向&#xff0c;如果以o为原点&#xff0c;建立坐标系&#xff0c;那么方向0和3就是x轴。其他方向为分力即可&am…

【微信小程序】免费的高德地图api——获取天气(全过程)

介绍 这里是小编成长之路的历程&#xff0c;也是小编的学习之路。希望和各位大佬们一起成长&#xff01; 以下为小编最喜欢的两句话&#xff1a; 要有最朴素的生活和最遥远的梦想&#xff0c;即使明天天寒地冻&#xff0c;山高水远&#xff0c;路远马亡。 一个人为什么要努力&a…

硬件工程师需要掌握的PCB设计常用知识点

一个优秀的硬件工程师设计的产品一定是既满足设计需求又满足生产工艺的&#xff0c;某个方面有瑕疵都不能算是一次完美的产品设计。规范产品的电路设计&#xff0c;工艺设计&#xff0c;PCB设计的相关工艺参数&#xff0c;使得生产出来的实物产品满足可生产性、可测试性、可维修…

Windows 安装 Go1.20.3 顺便了解 go env 环境变量

文章目录1.下载与安装2.GOROOT3.Go 的包管理3.1 GOPATH 模式3.2 Go Modules 模式4.GOPATH5.GO111MODULE6.GOPROXY7.GOSUMDB8.GONOPROXY/GONOSUMDB/GOPRIVATE9.GOMODCACHE10.GOCACHE11.GOENV12.GOBIN13.参考资料1.下载与安装 参考文章&#xff1a;Golang V1.19.1 安装配置 (win…

Vue3带来了什么

目录性能方面的优化更好的TypeScript集成用于处理大规模用例的新API分层内部模块CompositionAPI更多RFC提供的两个新功能proxy代替defineProperty双向绑定性能方面的优化 首先是相对Vue2的一些性能改进&#xff1a; 通过摇树(减轻了多达41&#xff05;的资源大小)初始渲染&am…

Hadoop安装Hbase启动失败报错解决方法

先进入hbase文件目录里看日志文件看看报什么错再具体解决&#xff1a; vim /opt/module/hbase-1.3.3/logs/hbase-root-master-hadoop-single.log 1.报错org.apache.hadoop.security.AccessControlException: Permission denied: user异常解决方法 1、第一种 在hdfs的配置文件…

3.2 二维随机变量的边缘分布

思维导图&#xff1a; 学习目标&#xff1a; 要学习二维随机变量的边缘分布&#xff0c;我可能会按照以下步骤进行学习&#xff1a; 理解二维随机变量的概念和表示方法&#xff0c;包括联合分布函数和联合分布律等概念。理解二维随机变量的边缘分布的概念和意义&#xff0c;即…

2023年4月份北京/广州/深圳DAMA-CDGP数据治理专家证书收益

DAMA认证为数据管理专业人士提供职业目标晋升规划&#xff0c;彰显了职业发展里程碑及发展阶梯定义&#xff0c;帮助数据管理从业人士获得企业数字化转型战略下的必备职业能力&#xff0c;促进开展工作实践应用及实际问题解决&#xff0c;形成企业所需的新数字经济下的核心职业…

Mysql的学习与巩固:一条SQL查询语句是如何执行的?

前提 我们经常说&#xff0c;看一个事儿千万不要直接陷入细节里&#xff0c;你应该先鸟瞰其全貌&#xff0c;这样能够帮助你从高维度理解问题。同样&#xff0c;对于MySQL的学习也是这样。平时我们使用数据库&#xff0c;看到的通常都是一个整体。比如&#xff0c;你有个最简单…

【华为机试真题详解JAVA实现】—MP3光标位置

目录 一、题目描述 二、解题代码 一、题目描述 MP3 Player因为屏幕较小,显示歌曲列表的时候每屏只能显示几首歌曲,用户要通过上下键才能浏览所有的歌曲。为了简化处理,假设每屏只能显示4首歌曲,光标初始的位置为第1首歌。 现在要实现通过上下键控制光标移动来浏览歌曲列…