[图神经网络]ViG(Vision GNN)网络代码实现

news2025/1/31 8:04:46

论文解读:

[图神经网络]视觉图神经网络ViG(Vision GNN)--论文阅读https://blog.csdn.net/weixin_37878740/article/details/130124772?spm=1001.2014.3001.5501代码地址:

ViGhttps://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch

一、网络结构

        ViG可堆叠为各向同性结构(isotropic architecture)(类似于ViT)和金字塔结构(pyramid architecture)(类似于ResNet)。本文主要解析金字塔结构PyramidViG-B为例。涉及的代码是git中的pyramid.py和gcn_lib文件夹在的三个文件。

         如上图所示,通过不同规格的ViG Block的堆叠,可以构造出具有4个Stage的金字塔行网络。经过移植,可以取代Resnet50在Faster RCNN中担任主干网络(但直接移植效果并不理想)。

        网络定义代码:

def pvig_b_224_gelu(num_classes =1000,pretrained=False, **kwargs):
    class OptInit:
        # 参数列表
        def __init__(self, num_classes=1000, drop_path_rate=0.0, **kwargs):
            self.k = 9 # 邻居节点数,默认为9
            self.conv = 'mr' # 图卷积层类型,可选 {edge, mr}
            self.act = 'gelu' # 激活层类型,可选 {relu, prelu, leakyrelu, gelu, hswish}
            self.norm = 'batch' # 归一化方式,可选 {batch, instance}
            self.bias = True # 卷积层是否使用偏置
            self.dropout = 0.0 # dropout率
            self.use_dilation = True # 是否使用扩张knn
            self.epsilon = 0.2 # gcn的随机采样率
            self.use_stochastic = False # gcn的随机性
            self.drop_path = drop_path_rate
            self.blocks = [2,2,18,2] # 各层的block个数
            self.channels = [128, 256, 512, 1024] # 各层的通道数
            self.n_classes = num_classes # 分类器输出通道数
            self.emb_dims = 1024 # 嵌入尺寸

    opt = OptInit(**kwargs)
    model = DeepGCN(opt)    #构造gcn
    model.default_cfg = default_cfgs['vig_b_224_gelu']    #注入参数
    return model
#  网络参数计算代码
class DeepGCN(torch.nn.Module):
    def __init__(self, opt):
        super(DeepGCN, self).__init__()
        # ...
        #  参数赋值省略
        # ...

        blocks = opt.blocks            # 获取各层block个数列表[2,2,18,2]
        self.n_blocks = sum(blocks)    # 获取block层数总数
        channels = opt.channels        # 获取输出通道数(用于分类器赋值)
        reduce_ratios = [4, 2, 1, 1]   # 下采样率
        #  获取FFN的随机深度衰减规律
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]
        # 获取各层knn的数量
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]
        max_dilation = 49 // max(num_knn)    #最大相关数目
        HW = 224 // 4 * 224 // 4

 二、ViG模块

      实际网络构造时使用ViG Block进行堆叠,ViG Block由GCN模块和FFN模块个组成,构造使用代码循环堆叠ViG Block

# 构造骨干网络
self.backbone = nn.ModuleList([])
        idx = 0
        for i in range(len(blocks)):
            if i > 0:
                #  如果不是第一层需要额外在层间添加下采样
                self.backbone.append(Downsample(channels[i-1], channels[i]))
                HW = HW // 4
            for j in range(blocks[i]):
                self.backbone += [
                    # 构造GCN
                    Seq(Grapher(channels[i], num_knn[idx], min(idx // 4 + 1, max_dilation), conv, act, norm,
                                    bias, stochastic, epsilon, reduce_ratios[i], n=HW, drop_path=dpr[idx],
                                    relative_pos=True),
                    # 构造FFN
                          FFN(channels[i], channels[i] * 4, act=act, drop_path=dpr[idx])
                         )]
                idx += 1
        self.backbone = Seq(*self.backbone)
        # 构造分类器
        self.prediction = Seq(nn.Conv2d(channels[-1], 1024, 1, bias=True),
                              nn.BatchNorm2d(1024),
                              act_layer(act),
                              nn.Dropout(opt.dropout),
                              nn.Conv2d(1024, opt.n_classes, 1, bias=True))
        self.model_init()

        网络的前向传递函数,可以看到图片在进入图网络之前先进行了stem(就是ViT里的切patch操作)和位置编码(位置对应的矩阵)

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed    #patch分割和位置嵌入
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)

        x = F.adaptive_avg_pool2d(x, 1)
        return self.prediction(x).squeeze(-1).squeeze(-1)

        stem操作和位置嵌入如下:

self.stem = Stem(out_dim=channels[0], act=act)
#返回整数部分
self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224//4, 224//4))

        1.Grapher模块

                首先看Grapher的前向传递函数

def forward(self, x):
        _tmp = x
        x = self.fc1(x)
        B, C, H, W = x.shape
        relative_pos = self._get_relative_pos(self.relative_pos, H, W)
        x = self.graph_conv(x, relative_pos)
        x = self.fc2(x)
        x = self.drop_path(x) + _tmp
        return x

                可以看到,对于每个Grapher模块而言,基本的处理流程是:

                ①全连接层fc1

# 由一个1x1Conv和一个BatchNorm组成
self.fc1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )

                ②由_get_relative_pos(.)函数更新关联位置

                其实从代码来看是用来匹配下采样带来的尺寸变化(调整尺寸)

    def _get_relative_pos(self, relative_pos, H, W):
        if relative_pos is None or H * W == self.n:
            return relative_pos
        else:
            N = H * W
            N_reduced = N // (self.r * self.r)
            return F.interpolate(relative_pos.unsqueeze(0), size=(N, N_reduced), mode="bicubic").squeeze(0)

                 在block初始化时,由get_2d_relative_pos_embed(.)函数赋予初值(如不启用的话会直接置None);

# 获取位置嵌入
relative_pos_tensor = torch.from_numpy(np.float32(
        get_2d_relative_pos_embed(in_channels,int(n**0.5)))).unsqueeze(0).unsqueeze(1)
# 进行双线性插值
relative_pos_tensor = F.interpolate(relative_pos_tensor, size=(n, n//(r*r)), 
        mode='bicubic', align_corners=False)
# 转换为nn参数
self.relative_pos = nn.Parameter(-relative_pos_tensor.squeeze(1), requires_grad=False)

                        get_2d_relative_pos_embed(.)位置嵌入函数,位于gcn_lib/pos_embed.py。作用是构建一个grid,并获取位置嵌入(包含cls_token)

                ③图卷积(graph_conv )

self.graph_conv = DyGraphConv2d(in_channels, in_channels * 2, kernel_size,
                     dilation, conv, act, norm, bias, stochastic, epsilon, r)

                转到graph_conv ,查看其前向传递函数:

def forward(self, x, relative_pos=None):
    B, C, H, W = x.shape
    y = None
    if self.r > 1:    #  此参数为下采样率,金字塔池化情况下默认开启(始终大于1)
        y = F.avg_pool2d(x, self.r, self.r)
        y = y.reshape(B, C, -1, 1).contiguous()            
    x = x.reshape(B, C, -1, 1).contiguous()

    # 获取邻居节点的聚合信息(基于knn)
    edge_index = self.dilated_knn_graph(x, y, relative_pos)
    # 图卷积
    x = super(DyGraphConv2d, self).forward(x, edge_index, y)
    # 将tensor变形为四维并输出
    return x.reshape(B, -1, H, W).contiguous()

                其中self.dilated_knn_graph为DenseDilatedKnnGraph,来自gcn_lib/torch_edge.py,和大部分图网络算法一样采用torch.topk(.)来进行邻接矩阵稀疏。同时使用part_pairwise_distance函数从特征中提取x_square_part、x_inner、x_square三个值。

                ④全连接层fc2

        self.fc2 = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, 1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )

                这个和前一个全连接层一样,只不过输入通道翻倍了而已。

                ⑤DropPath随机删除

self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

                用来防止过拟合,同时该网络中还具备类似残差的结构

x = self.drop_path(x) + _tmp

        2.FNN模块

                FNN模块是一个多层感知机,由两层全连接实现,同样具备残差结构

shortcut = x
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
x = self.drop_path(x) + shortcut
return x
        self.fc1 = nn.Sequential(
            nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(hidden_features),
        )
        self.act = act_layer(act)
        self.fc2 = nn.Sequential(
            nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0),
            nn.BatchNorm2d(out_features),
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

                这里的激活层act默认为relu激活函数。

三、网络的迁移

        得益于金字塔结构带来的多尺度特征,ViG可以像Swin Transfomer一样作为骨干网络用来特征提取,这里将其作为骨干网络移植到Faster RCNN中代替原本的ResNet50。卸掉prediction预测头和平均池化adaptive_avg_pool2d后,可以由一个224x224x3的输入得到一个7x7x1024的特征。

    def forward(self, inputs):
        x = self.stem(inputs) + self.pos_embed
        B, C, H, W = x.shape
        for i in range(len(self.backbone)):
            x = self.backbone[i](x)

        # x = F.adaptive_avg_pool2d(x, 1)
        return x

        经过测试,ViG可以在数据集上获得越70%的mAP,但是效果劣于resnet50和mobilenetv3,具体原因不明。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/521225.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Hive之DDL

目录 对数据库操作: 创建数据库: 查看数据库信息: 1.查看基本信息: 2.查看详尽信息: 删除数据库: 1.简单语法: 2.复杂语法: 对表操作: 创建表: 1.普…

JVM-内存结构

✅作者简介:热爱Java后端开发的一名学习者,大家可以跟我一起讨论各种问题喔。 🍎个人主页:Hhzzy99 🍊个人信条:坚持就是胜利! 💞当前专栏:JVM 🥭本文内容&…

《程序员的底层思维》读书笔记

人是能够习惯于任何环境的生物,之前你认为自己难以克服的困难,慢慢都会适应了。 维克多弗兰克《活出生命的意义》 文章目录 人是能够习惯于任何环境的生物,之前你认为自己难以克服的困难,慢慢都会适应了。 基础思维能力逻辑思维批…

每日学术速递5.12

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.CV 1.ImageBind: Holistic AI learning across six modalities 标题:ImageBind:跨六种模式的整体人工智能学习 作者:Mengyuan Yan Jessica Lin Mont…

支付系统设计三:渠道网关设计04-渠道数据补全

文章目录 前言一、交易信息准备1. MessageDescription内容2. 交易信息填充3. 开户机构信息填充4. 省市区域信息填充5. 银行信息填充 二、路由处理三、支付渠道数据补全1.服务端支付渠道获取2. 支付渠道通用数据补全2.1 支付渠道账户信息补全2.1 商户信息补全结束 3. 支付渠道差…

具有噪声标签的鲁棒医学图像分割的点类仿射损失校正

文章目录 Joint Class-Affinity Loss Correction for Robust Medical Image Segmentation with Noisy Labels摘要本文方法Differentiated Affinity Reasoning (DAR)Class-Affinity Loss Correction (CALC)Class-Level Loss CorrectionAffinity-Level Loss CorrectionClass-Affi…

AcWing算法提高课-1.3.4数字组合

宣传一下算法提高课整理 <— CSDN个人主页&#xff1a;更好的阅读体验 <— 本题链接&#xff08;AcWing&#xff09; 点这里 题目描述 给定 N N N 个正整数 A 1 , A 2 , … , A N A_1,A_2,…,A_N A1​,A2​,…,AN​&#xff0c;从中选出若干个数&#xff0c;使它们…

轻松搭建冒险岛服务器-冒险岛私服搭建详细教程

想要拥有一个属于自己的冒险岛世界吗&#xff1f;想要一步步学习如何架设冒险岛服务器吗&#xff1f;本文将从如何选择服务器、安装系统、配置环境、搭建数据库、部署网站、上传文件、启动服务等8个方面&#xff0c;一步步为大家详细讲解冒险岛架设教程。让你轻松打造属于自己的…

sql 性能优化基于explain调优

文章目录 Explain分析&#xff1f;问题描述解决方案 Explain分析&#xff1f; 关于Explain具体可以干什么&#xff0c;有哪些优缺点&#xff0c;本博主的文章有写到&#xff0c;这是链接地址: 点击这里查看. 下面来说下Explain在项目实战中&#xff0c;如何去进行优化。 问题…

7年老人,30岁的测试说辞就辞,“人员优化”4个字,泰裤辣...

前几天&#xff0c;一个认识了好几年在大厂做测试的程序员朋友&#xff0c;年近30了&#xff0c;在公司做了7年了&#xff0c;一直兢兢业业&#xff0c;最后还是却被大厂以“人员优化”的名义无情被辞&#xff0c;据他说&#xff0c;有一个月散伙饭都吃了好几顿…… 在很多企业…

【ChatGPT】国内免费使用ChatGPT镜像

Yan-英杰的主页 悟已往之不谏 知来者之可追 C程序员&#xff0c;2024届电子信息研究生 目录 什么是ChatGPT镜像&#xff1f; 亲测&#xff1a; 一、二狗问答(AI对话) 二、AiDuTu 三、WOChat 四、ChatGPT(个人感觉最好用) 我们可以利用ChatGPT干什么&#xff1f; 一、三分…

薪人薪事 java开发实习一面

目录 1.常用数据结构&#xff0c;区别及使用场景2.数组和链表在内存中数据的分布情况3.HashMap底层数据结构4.put操作5.JVM内存区域6.各个区域存放什么东西7.创建一个对象&#xff0c;内存怎么分配的8.堆中内存怎么划分&#xff0c;gc怎么回收9.IOC 原理10.Bean存放在哪里11.AO…

支付系统设计三:渠道网关设计05-交易持久化

文章目录 前言一、领域模型持久化服务工厂二、聚合创建工厂1. 模型创建1.1 获取域模型Class1.2 新建模型1.3 数据填充 2. 模型持久化2.1 获取域模型对应的仓储2.2 调用域模型仓储进行持久化 总结 前言 本篇将解析交易信息入库&#xff0c;即对上送的参数&#xff0c;在进行校验…

关于ASA广告归因接入方法

投放苹果ASA广告&#xff0c;提高 app 曝光率、下载量的增长&#xff0c;那么我们该如何从后台看到投放广告的效果呢&#xff1f; 我们可以借助Apple Ads归因API。那什么是归因&#xff1f;什么又是API呢&#xff1f; 归因&#xff1a;可以给用户打标签&#xff0c;然后看他在…

[GUET-CTF2019]encrypt 题解

本题是输入了一个字符串&#xff0c;进行了rc4加密&#xff0c;和魔改的base64加密 RC4算法初始化函数 RC4加密过程 魔改的base64加密 最后加密的字符串是byte_602080 我们可以将byte_602080提取出来&#xff0c;下面是提取数据的IDC脚本&#xff0c;得到了密文 #include<…

赫夫曼树和赫夫曼编码详解

目录 何为赫夫曼树&#xff1f; 赫夫曼树算法 赫夫曼编码 编程实现赫夫曼树 编程实现赫夫曼编码 编程实现WPL 总代码及分析 何为赫夫曼树&#xff1f; 树的路径长度&#xff1a;从树根到每一结点的路径长度之和 结点的带权路径长度&#xff1a;从树根到该结点的路径长度…

2023网络安全十大顶级工具

从事网络安全工作&#xff0c;手上自然离不开一些重要的网络安全工具。今天&#xff0c;分享10大网络安全工具。 一、Kali Linux Kali 是一个基于 Debian 的 Linux 发行版。它的目标就是为了简单&#xff1a;在一个实用的工具包里尽可能多的包含渗透和审计工具。Kali 实现了这…

【AI面试】CNN 和 transformer 的问题汇总

​ CNN卷积神经网络和transformer相关的知识&#xff0c;是AI的一个基础的&#xff0c;也是前言的知识点。一般面试官会从以下这些问题入手&#xff1a; 卷积神经网络&#xff0c;有什么特点&#xff1f;1*1卷积核有什么作用&#xff1f;计算经过卷积的输出尺寸大小空洞卷积你…

机器学习之朴素贝叶斯三、拉普拉斯平滑技术、优化改进情感分析

文章目录 一、前文问题1. 先看下改进前我们的代码计算部分2. 问题分析&#xff1a; 二、针对问题进行解决1. 什么是拉普拉斯平滑技术2. 拉普拉斯优化-下溢上溢问题3. 改进地方分析&#xff1a;4.改进优化1.优化一&#xff0c;对条件概率计算进行优化2.优化二&#xff0c;对后延…

从小白到黑客高手:一份全面详细的学习路线指南

前言 黑客从入门到精通需要经过深入的学习和实践&#xff0c;这是一个需要长时间投入和大量精力的过程。在这份学习路线中&#xff0c;我将为你介绍黑客学习的基本知识和技能&#xff0c;帮助你逐步掌握黑客技能。 黑客 一、入门阶段 1.了解计算机基础知识 学习计算机基础知…