Visual Commonsense R-CNN 实现和代码

news2024/11/25 5:03:09

这篇文章比较早,但是对于因果介绍的比较详细,很值得学习。
代码:https://github.com/Wangt-CN/VC-R-CNN
代码花了挺长时间总算跑通了,在 3080 上调真是错误不断,后来换到 2080 又是一顿调才好。这里跑通的主要环境为 ubuntu,2080,cuda 11.3, torch ‘1.10.1+cu113’ 。一些配置如下

  • 安装 conda 后,conda create --name vc_rcnn python=3.7
  • conda activate vc_rcnn pip install ninja yacs cython matplotlib tqdm opencv-python h5py lmdb -i https://pypi.mirrors.ustc.edu.cn/simple/
  • pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 -i https://pypi.mirrors.ustc.edu.cn/simple/ (高于 1.10版本的话安装 vc-r-cnn 会有问题)
  • 参考代码里的 install.md 安装 coco 和 apex 以及 vc-rcnn 即可。

作者大大有个博客对于因果学习和这篇文章介绍的都很详细,见 https://zhuanlan.zhihu.com/p/111306353
因此这里具体的介绍就不多写了,可参考原始文章和博客。只说实现部分。
在这里插入图片描述

  • 首先根据数据集构建 confounder Z Z Z 和先验数据分布 P ( z ) P(z) P(z)。 具体而言,confunder 指的是特定类别的固定表征,利用的是 gt 获得的,将特定类别的所有目标加一起取平均,对于coco而言获得 80 × 1024 80 \times 1024 80×1024 的 confounder,每个行都对应一个类别。而先验分布 P ( z ) P(z) P(z) 维度为 1 × 80 1 \times 80 1×80 ,表示特定类别出现的频率(所有值加一起和为1),应该是特定类别出现的频数除以总的目标个数得到。文章定义如下,是预先获得直接加载的。

    self.dic = torch.tensor(np.load(cfg.DIC_FILE)[1:], dtype=torch.float) # [80,1024]
    self.prior = torch.tensor(np.load(cfg.PRIOR_PROB), dtype=torch.float) # [80]

  • 模型使用的是 ResNet+FPN, 也会生成不止一个 proposals。之后就是设计两个分类器。Self Predictor 和一般使用的基本没什么区别,重点关注 Context Predict。公式如下,代码看后边。 80 × 1024 80 \times 1024 80×1024 的 confounder 首先和 ROI 特征 y y y 计算attention,具体先经过全连接,然后点乘,再通过 softmax,获得 4 × 80 4 \times 80 4×80 的attention (4 表示该样本有 4 个 proposals)。按照我们一般的思路,这个 attention 可以直接聚合对应的 80 个特征向量。然而这里再和 confounder 的先验分布 P ( z ) P(z) P(z) 相乘,很有意思,在 attention 中强行加入 confunder(即各个类别)的出现频率(ps,感觉那么像解决长尾分布的问题)。最后对这些特征聚合得到 4 × 1024 4 \times 1024 4×1024 的特征。这些特征和模型获得的 ROI 特征 x x x (公式是 x x x,但是代码中表示的应该就是前面所说的 y y y,是一样的 ROI 特征,不知道是不是表示错误还是我理解错误)进行拼接(公式里是相加,可能也表示拼接)。这样就能获得一个包含真实特征和 confounter 的特征。最后用 Context Predict (它的输入是Self Predictor 输入维度的两倍)进行预测。
    q = W 3 y , K = W 4 Z T E z [ g y ( z ) ] = ∑ z [ Softmax ⁡ ( q T K / σ ) ⊙ Z ] P ( z ) E z [ f y ( x , z ) ] = W 1 x + W 2 ⋅ E z [ g y ( z ) ] \boldsymbol{q}=\boldsymbol{W}_3 \boldsymbol{y}, \boldsymbol{K}=\boldsymbol{W}_4 \boldsymbol{Z}^T \\ \mathbb{E}_{\boldsymbol{z}}\left[g_y(\boldsymbol{z})\right]=\sum_z\left[\operatorname{Softmax}\left(\boldsymbol{q}^T \boldsymbol{K} / \sqrt{\sigma}\right) \odot \boldsymbol{Z}\right] P(\boldsymbol{z}) \\ \mathbb{E}_{\boldsymbol{z}}\left[f_y(\boldsymbol{x}, \boldsymbol{z})\right]=\boldsymbol{W}_1 \boldsymbol{x}+\boldsymbol{W}_2 \cdot \mathbb{E}_{\boldsymbol{z}}\left[g_y(\boldsymbol{z})\right] \\ q=W3y,K=W4ZTEz[gy(z)]=z[Softmax(qTK/σ )Z]P(z)Ez[fy(x,z)]=W1x+W2Ez[gy(z)]

  • 最后能让训练整个的特征提取器减少模型受这种这种 bias (共现的bias,位置的bias等)的干扰。

def z_dic(self, y, dic_z, prior):
    """
    Please note that we computer the intervention in the whole batch rather than for one object in the main paper.
    """
    length = y.size(0) # proposals 的数量 torch.Size([4, 1024])
    if length == 1:
        print('debug')
    # torch.mm(self.Wy(y), self.Wz(dic_z).t()) --> torch.Size([4, 80])
    attention = torch.mm(self.Wy(y), self.Wz(dic_z).t()) / (self.embedding_size ** 0.5)
    attention = F.softmax(attention, 1)
    z_hat = attention.unsqueeze(2) * dic_z.unsqueeze(0) # torch.Size([4, 80, 1024])
    z = torch.matmul(prior.unsqueeze(0), z_hat).squeeze(1) # [1, 80], torch.Size([4, 80, 1024]) --> torch.Size([4, 1, 1024]) --> torch.Size([4, 1024]) 
    xz = torch.cat((y.unsqueeze(1).repeat(1, length, 1), z.unsqueeze(0).repeat(length, 1, 1)), 2).view(-1, 2*y.size(1)) # y [4,1024]->[4,4,1024], z [4,1024]->[4,4,1024] => cat [4,4, 2048] ==> [16, 2048]
    # detect if encounter nan
    if torch.isnan(xz).sum():
        print(xz)
    return xz

总的来说,实现上很有意思的,不知道是先有了因果的思考才有的实现思路,还是先有实现的方法再有的因果的角度hh。张老师组这几年在因果学习上发表了相当多的文章,很多领域都有涉及,很有启发性,值得学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/336832.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

代理模式详解

本文首更于《从零开始手把手教你实现一个简单的RPC框架》 。 1. 代理模式2. 静态代理3. 动态代理 3.1. JDK 动态代理机制 3.1.1. 介绍3.1.2. JDK 动态代理类使用步骤3.1.3. 代码示例 3.2. CGLIB 动态代理机制 3.2.1. 介绍3.2.2. CGLIB 动态代理类使用步骤3.2.3. 代码示例 3.3. …

win10系统安装Nginx

Nginx是一款自由的、开源的、高性能的HTTP服务器和反向代理服务器,同时也提供了IMAP/POP3/SMTP服务。 Nginx可以进行反向代理、负载均衡、HTTP服务器(动静分离)、正向代理等操作。因为最近在公司使用到了Nginx 第一步:下载Nginx …

想找大数据工作需要学些什么

大数据开发做什么? 大数据开发分两类,编写Hadoop、Spark的应用程序和对大数据处理系统本身进行开发。大数据开发工程师主要负责公司大数据平台的开发和维护、相关工具平台的架构设计与产品开发、网络日志大数据分析、实时计算和流式计算以及数据可视化等…

12_FreeRTOS任务相关API函数

目录 FreeRTOS任务相关API函数介绍 获取任务优先级函数 设置任务优先级函数 获取任务数量函数 获取所有任务状态信息 获取指定的单个任务的状态信息 获取当前任务句柄 通过任务名获取任务句柄 获取任务栈历史最小剩余推栈 以表格的形式获取系统中任务信息 实验源码 …

【虹科】防止PCB组装过程出现质量错误的5种方法

质量问题和错误时有发生,尤其是在涉及PCB和电子产品制造的复杂人为操作任务中。通常情况下,企业可能会配备自动光学检测(AOI)等系统,这些系统通常用于制造过程中“中间”阶段的检测。尽管AOI系统为质量控制创造价值&am…

Jmeter in Linux - 如何在Linux系统使用Jmeter压测?

Jmeter in Linux - 如何在Linux系统使用Jmeter压测?Jmeter in Linux系列目录:1. 在windows创建好一个测试计划:2. 保存后,将jmx后缀的文件上传至Linux服务器3. 执行jmeter命令4. 根据执行日志分析压测报告5. 解析压测报告Jmeter i…

有效的括号-力扣20-java

一、题目描述给定一个只包括 (,),{,},[,] 的字符串 s ,判断字符串是否有效。有效字符串需满足:左括号必须用相同类型的右括号闭合。左括号必须以正确的顺序闭合。每个右括号都有一个对应的相同类…

【huggingface系列学习】Using Transformers

文章目录前言Using Transformers使用tokenizer预处理Tokenizer详解Loading and saving加载保存EncodingDecodingModel创建一个Transformer不同的加载方法模型保存使用模型进行推理前言 因实验中遇到很多 huggingface-transformers 模型和操作,因此打算随着 course …

剖析字节案例,火山引擎 A/B 测试 DataTester 如何“嵌入”技术研发流程

更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 日前,在 WOT 全球创新技术大会上,火山引擎 DataTester 技术负责人韩云飞做了关于字节跳动 A/B 测试平台的分享。DataTester 是字节跳动内部应…

Roboguide与TIA V16通讯

软件需求:1. roboguide;2. TIA V16;3. KEPServer; 在之前的文章中介绍过KEPServer与TIA V16的通讯,此处不再介绍。接下来,介绍roboguide与KEPServer的仿真通讯。 创建一个roboguide项目。选择【外部设备】➡【添加外部设备】 选择【OPC Server】➡【OK】 OPC服务器名称命…

linux安装并配置nginx

菜鸟教程 一 . Nginx安装和部署 1.输入指令,下载相关的依赖包 yum -y install gcc zlib zlib-devel pcre-devel openssl openssl-develYUM(Yellow dog Updater, Modified)为多个Linux发行版的前端软件包管理器 -y 是参数,默认不要确认, rp…

对话 ChatGPT:现象级 AI 应用,将如何阐释「研发效能管理」?

ChatGPT 已然是 2023 开年至今,互联网上最热的话题没有之一。从去年的 AI 图片生成,到 ChatGPT,再到现在各种基于大模型的应用如雨后春笋般出现……在人们探讨技术无限可能的同时,另一个更深刻的命题也不可回避地浮现出来&#xf…

汽摩仪表快检盒

不怕失业 ​ ​最近大火的ChatGPT说要取代程序员,老婆子惊慌失措,跟着糟老头憋屈,咸鱼想靠软件翻身,这下白瞎了。 ​温州寄来了汽车燃油预热控制板,绍兴又寄来了发动机仪表,昆山的尾门在路上,都…

如何成为java架构师?2023版Java架构师学习路线总结完成,真实系统有效,一切尽在其中

导读 从初级Java工程师成长为Java架构师,你需要走很长的路,很多有计划的人在学习之初就在做准备。你知道Java架构师学习路线该怎么走吗?成为一个优秀的Java架构师究竟需要学什么?接下来就跟小编一起揭晓答案。 架构师是一个充满挑战的职业&#xff0…

Python自定义模块

到目前为止,读者已经掌握了导入 Python 标准库并使用其成员(主要是函数)的方法,接下来要解决的问题是,怎样自定义一个模块呢?Python 模块就是 Python 程序,换句话说,只要是 Python 程…

Swagger自动生成api文档

Swagger自动生成api文档Swagger是什么Swagger底层原理使用方式1修改pom文件2启动类中加入注解EnableSwagger23加入SpringFoxConfig.java4加入WebMvcConfig.java文件5 给Web 服务的接口加注解访问可视化页面Swagger是什么 Swagger 是一个规范和完整的框架,用于生成、…

C经典小游戏之扫雷

编译环境:VS022 目录 1.算法思路 2.代码模块 2.1 game.h 2.2 game.cpp 2.3 test.cpp 3.重点分析 4.金句省身 1.算法思路 主要采用二维数组进行实现,设置两个二维数组,一个打印结果,即为游戏界面显示的效果,一个用…

值类型和引用类型

一、值类型和引用类型示例: 值类型:基本数据类型系列,如:int,float,bool,string,数组和结构体等。 引用类型:如:指针,slice切片,map&a…

windows wireshark抓到未加入组的组播消息

现象 在Windows上开启wireshark,抓到了大量地址为239.255.255.251的组播包。 同时,根据组播相关命令,调用netsh interface ipv4 show joins,显示当前并没加入 239.255.255.251 组播组。 解决 根据IGMP Snooping,I…

《机器学习》学习笔记

第 2 章 模型评估与选择 2.1 经验误差与过拟合 精度:精度1-错误率。如果在 mmm 个样本中有 aaa 个样本分类错误,则错误率 Ea/mEa/mEa/m,精度 1−a/m1-a/m1−a/m。误差:一般我们把学习器的实际预测输出与样本的真实输出之间的差…