《Probing the 3D Awareness of Visual Foundation Models》论文解析——单图像表面重建

news2024/11/19 12:11:15

一、论文简介

        论文讨论了大规模预训练产生的视觉基础模型在处理任意图像时的强大能力,这些模型不仅能够完成训练任务,其中间表示还对其他视觉任务(如检测和分割)有用。研究者们提出了一个问题:这些模型是否能够表示物体的三维结构。他们通过一系列实验,使用特定任务的探针和零样本推理程序来分析这些模型的3D感知能力,并发现当前模型存在一些限制。这个实验旨在评估模型对图像中可见表面表示的能力,具体包括两个任务:深度估计(Monocular Depth Estimation)和表面法线估计(Surface Normal Estimation)。

二、深度估计(Monocular Depth Estimation)

        任务:预测图像中每个像素点的深度。

        数据集:使用NYUv2数据集评估场景级性能,NAVI数据集评估对象级性能。

        输入:单张RGB图像;输出:图像中每个像素点的深度

        网络结构:使用AdaBins的二进制预测结果,在模型的多层特征图基础上,构建一个类似于DPT解码器的多尺度探测器,用于密集预测。通过训练密集探针来预测每个像素点的深度。

        1.预训练模型特征提取

        对于一张待估计深度的图像,使用一个预训练的视觉模型(例如,一个视觉变换器或卷积神经网络)来提取图像的特征。这些特征通常在模型的中间层获得,以捕捉到图像的高层语义信息。

        2.深度预测网络(Dense Probe)

        设计一个密集探针(dense probe)网络,这个网络将从预训练模型中提取的特征映射到深度图。这个探针网络可以是一个简单的全连接层,或者是一个更复杂的网络结构,如多层感知机(MLP)或卷积层。使用AdaBins方法来训练这个探针网络(AdaBins是一种基于分箱的深度预测技术,它将深度范围划分为一系列离散的“bins”,并学习将图像特征映射到这些bins的概率分布),度量预测深度和真实深度之间的差距。

        3. 损失函数和优化

        AdaBins方法使用特定的损失函数来训练网络,这个损失函数同时考虑了深度值的回归和分类任务;使用AdamW优化器进行训练,这是一种带有权重衰减的随机梯度下降变体,有助于防止过拟合并提高训练稳定性;采用线性预热和余弦衰减学习率调度器进行学习率调度,这意味着在训练初期逐步增加学习率,然后在训练后期逐渐减小学习率,以促进模型收敛。

        4.深度图生成

        对于输入图像中的每个像素,探针网络预测一个深度值或一个深度bins的概率分布。并根据预测的概率分布,为每个像素选择最有可能的深度值,或者通过某种方式(如取期望值)从概率分布中得到一个单一的深度估计值。

        5.评估方法

        使用均方根预测误差(RMSE)和不同阈值下的召回率来评估深度估计的准确性,将预测的深度图与真实深度图(第二列)进行比较,以验证模型的性能。

三、表面法线估计(Surface Normal Estimation)

        任务:预测每个像素点的表面法线方向。

        数据集:NYUv2数据集:该数据集提供了与表面法线相关的注释,用于评估室内场景的表面法线估计性能。

        NAVI数据集:该数据集包含了对象实例在多种场景和方向中的表面法线注释,用于评估对象级别的表面法线估计性能。

        输入:单张RGB图像;输出:图像中每个像素点表面法线方向

        1.同深度估计进行预训练模型特征提取

        2.表面法线预测网络(Surface Normal Prediction Network)

        设计一个网络结构,将从预训练模型中提取的特征映射到表面法线的预测。这个网络可以是一个简单的全连接层,或者是一个更复杂的网络结构,如多层感知机(MLP)或卷积层。使用Bae等人提出的不确定性感知的角度损失函数来训练网络,以预测法线的方向。

        3.表面法线图生成

        对于输入图像中的每个像素,网络预测一个表面法线的方向向量。并将预测的法线向量归一化,以确保它们具有单位长度。

        4.评估方法

        使用均方根角度预测误差(RMSE)和不同角度阈值下的召回率来评估表面法线估计的准确性。将预测的表面法线图与真实表面法线图(如果有的话)进行比较,以验证模型的性能。

四、相关代码解析

        1.深度估计

        深度估计是一个复杂的计算机视觉任务,通常涉及到机器学习或深度学习技术。以下是一个简单的示例,使用Python和OpenCV库来从单个RGB图像中估计深度。

        

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

# 定义一个自定义的数据集
class DepthDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        
        if self.transform:
            image = self.transform(image)
        
        # 假设我们有一个对应的深度图,这里我们随机生成一个作为示例
        depth = torch.rand(1, 1, image.size[1], image.size[0])
        
        return image, depth

# 定义CNN模型
class DepthCNN(nn.Module):
    def __init__(self):
        super(DepthCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(64*16*16, 1024)
        self.fc2 = nn.Linear(1024, 1)  # 假设深度图是单通道的

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化数据集和数据加载器
image_paths = ['path_to_your_image1.jpg', 'path_to_your_image2.jpg']  # 替换为实际图像路径
transform = transforms.Compose([transforms.ToTensor()])
dataset = DepthDataset(image_paths, transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 实例化模型
model = DepthCNN()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(10):  # 迭代10个epoch
    for images, depths in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, depths)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')

        在这个示例中,我们定义了一个DepthDataset类来加载图像和对应的深度图。然后,我们定义了一个DepthCNN类来构建CNN模型。模型包含三个卷积层和两个全连接层。我们使用均方误差损失(MSELoss)作为损失函数,并使用Adam优化器来更新模型权重。

        2.表面法线估计

        表面法线分析是计算机视觉中的一个高级任务,通常涉及到从RGB图像中估计表面的法线向量。这通常需要复杂的深度学习模型,比如卷积神经网络(CNN)。以下是一个使用PyTorch框架的简化示例,展示了如何构建一个CNN模型来进行表面法线分析。

        

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image

# 定义一个自定义的数据集
class NormalDataset(Dataset):
    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        image = Image.open(image_path)
        
        if self.transform:
            image = self.transform(image)
        
        # 假设我们有一个对应的法线图,这里我们随机生成一个作为示例
        # 法线图通常有三个通道,分别对应x, y, z坐标
        normal = torch.rand(3, image.size[1], image.size[0])
        
        return image, normal

# 定义CNN模型
class NormalCNN(nn.Module):
    def __init__(self):
        super(NormalCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.fc1 = nn.Linear(64*16*16, 256)
        self.fc2 = nn.Linear(256, 3)  # 法线有三个分量

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = x.view(x.size(0), -1)  # Flatten the tensor
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 实例化数据集和数据加载器
image_paths = ['path_to_your_image1.jpg', 'path_to_your_image2.jpg']  # 替换为实际图像路径
transform = transforms.Compose([transforms.ToTensor()])
dataset = NormalDataset(image_paths, transform)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# 实例化模型
model = NormalCNN()

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
for epoch in range(10):  # 迭代10个epoch
    for images, normals in dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, normals)
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')

        在这个示例中,我们定义了一个NormalDataset类来加载图像和对应的法线图。然后,我们定义了一个NormalCNN类来构建CNN模型。模型包含三个卷积层和两个全连接层。我们使用均方误差损失(MSELoss)作为损失函数,并使用Adam优化器来更新模型权重。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2243376.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

泷羽sec学习打卡-云技术基础1-docker

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负 关于云技术基础的那些事儿-Base1 一、云技术基础什么是云架构?什么是云服务?什么…

03-axios常用的请求方法、axios错误处理

欢迎来到“雪碧聊技术”CSDN博客! 在这里,您将踏入一个专注于Java开发技术的知识殿堂。无论您是Java编程的初学者,还是具有一定经验的开发者,相信我的博客都能为您提供宝贵的学习资源和实用技巧。作为您的技术向导,我将…

Spring Boot 与腾讯云 MySQL 监听 Binlog 数据变化,并使用 UI 展示页面效果

引言 在现代的分布式系统和微服务架构中,数据同步和变更监控是保证系统一致性和实时性的核心问题之一。MySQL 数据库的 binlog(二进制日志)功能能够记录所有对数据库的修改操作,如插入(INSERT)、更新&…

Spring Boot汽车资讯:科技与速度的新纪元

摘要 随着信息技术在管理上越来越深入而广泛的应用,管理信息系统的实施在技术上已逐步成熟。本文介绍了汽车资讯网站的开发全过程。通过分析汽车资讯网站管理的不足,创建了一个计算机管理汽车资讯网站的方案。文章介绍了汽车资讯网站的系统分析部分&…

thinkphp6模板调用URL方法生成的链接异常

var uul params.url ;console.log(params.url);console.log("{:Url(UserLog/index)}");console.log("{:Url("uul")}"); 生成的链接地址 UserLog/index /jjg/index.php/Home/UserLog/index.html /jjg/index.php/Home/Index/UserLog/index.html…

NodeJS 百度智能云文本转语音(实测)

现在文本转语音的技术已经非常完善了,尽管网络上有许多免费的工具,还是测试了专业的服务,选择了百度的TTS服务。 于是,在百度智能云注册和开通了文本转语音的服务,尝试使用NodeJS 实现文本转语音服务。但是百度的文档实…

UML 类图讲解

UML 类图符号含义 在 UML 类图中,每个符号都有其特定的含义。以下是常见符号的解释: : Public(公共访问权限)-: Private(私有访问权限)#: Protected(受保护访问权限)~: Package&…

【GAT】 代码详解 (1) 运行方法【pytorch】可运行版本

GRAPH ATTENTION NETWORKS 代码详解 前言0.引言1. 环境配置2. 代码的运行2.1 报错处理2.2 运行结果展示 3.总结 前言 在前文中,我们已经深入探讨了图卷积神经网络和图注意力网络的理论基础。还没看的同学点这里补习下。接下来,将开启一个新的阶段&#…

远程控制步骤

当远在千里之外的朋友想求助你帮他找到他电脑上的文件、或者是给他安装软件时。但是你给他说了他又找不到,那么这时你就可以通过控制对方的电脑去做一系列的操作。 如何远程控制对方的电脑非常关键。 方法一(Windows自带远程桌面功能)&#…

C指针之舞——指针探秘之旅

❤博客主页:折枝寄北-CSDN博客 ❤专栏内容:C语言学习专栏https://blog.csdn.net/2303_80170533/category_12794764.html?spm1001.2014.3001.5482 指针基础学习 在之前的博客文章中,简单总结了指针的基础概念 我们知道了指针的概念&#xf…

前端 JS 浅拷贝与深拷贝

目录 一、问题引出 二、浅拷贝 1、通过解构重构实现浅拷贝 三、深拷贝 1、自定义实现深拷贝 2、JSON实现深拷贝 四、总结 一、问题引出 基础类型的数据存放: let a 100let b aconsole.log("a:" a, "b:" b)a 50console.log("a…

72项!湖北省2024年度第二批省级科技计划项目拟立项项目公示!

本期精选 SCI&EI ●IEEE 1区TOP 计算机类(含CCF); ●EI快刊:最快1周录用! 知网(CNKI)、谷歌学术期刊 ●7天录用-检索(100%录用),1周上线; 免费稿件评估 免费匹配…

uniapp微信小程序转发跳转指定页面

onShareAppMessage 是微信小程序中的一个重要函数,用于自定义转发内容。当用户点击右上角的菜单按钮,并选择“转发”时,会触发这个函数。开发者可以在这个函数中返回一个对象,用于定义分享卡片的标题、图片、路径等信息。 使用场…

[N1CTF 2018]eating_cms

打开题目 只有个登录框,其他什么都没有,尝试了一下弱口令,没能成功 尝试访问一下register.php,看看能不能注册个账号 注册页面,随便注册个账号登陆一下 url中感觉是个注入点,尝试使用file伪协议读取一下us…

PMP–一、二、三模、冲刺–分类–5.范围管理–技巧–引导

文章目录 技巧一模5.范围管理--3.定义范围--工具与技术--引导--在研讨会和座谈会中使用引导技能来协调具有不同期望或不同专业知识的关键干系人,使他们就项目可交付成果以及项目和产品边界达成跨职能的共识。引导:题干关键词 “需求不同、需求差异、需求…

C语言-字符串指针及多变的访问方式

1、字符串指针 示例;输出字符串数组 1. #include <stdio.h> 2. #include <string.h> 3. 4. int main(){ 5. char str[] "<http://baidu.com>"; 6. int len strlen(str), i; 7. //直接输出字符串 8. printf("%s\\n", str); 9. //每次…

Linux之vim模式下全选命令

在Linux系统中&#xff0c;使用Vim编辑器进行全选操作可以通过以下几种方式实现&#xff1a; 1.使用键盘快捷键 按下 ”ggVG”&#xff08;先按下”g”&#xff0c;再按下”g”&#xff0c;再按下”V”&#xff0c;最后按下”G”&#xff09;可以全选当前文件内容。其中 ”g…

解决虚拟机未被自动分配ip

文章目录 1. 背景2. 解决步骤 1. 背景 从vulnhub下载的靶场文件&#xff0c;网络适配器模式设置为nat模式之后&#xff0c;启动虚拟机之后发现没有成功分配动态ip。推测是虚拟机分配的网卡名称和原先靶机作者设置网络配置文件 网络接口名称不一致导致。 2. 解决步骤 解决办法就…

【数据结构与算法】排序

文章目录 排序1.基本概念2.分类2.存储结构 一.插入排序1.1直接插入排序1.2折半插入排序1.3希尔排序 二.选择排序2.1简单选择排序2.2堆排序 三.交换排序3.1冒泡排序3.2快速排序 四.归并排序五.基数排序**总结** 排序 1.基本概念 排序&#xff08;sorting&#xff09;又称分类&…

5. ARM_指令集

概述 分类 汇编中的符号&#xff1a; 指令&#xff1a;能够编译生成一条32位机器码&#xff0c;并且能被处理器识别和执行伪指令&#xff1a;本身不是指令&#xff0c;编译器可以将其替换成若干条指令伪操作&#xff1a;不会生成指令&#xff0c;只是在编译阶段告诉编译器怎…