基于卷积变分自动编码器的3D数据处理与重建【CVAE】

news2024/11/24 4:37:16

在这个项目中,我们将学习如何使用卷积变分自动编码器 (CVAE) 来处理和重建 3D 湍流数据。

我们使用计算流体动力学 (CFD) 方法生成 3D 湍流立方体,每个 3D 立方体沿着三个速度分量携带物理信息(与图像数据类似,被视为单独的通道)。

在这里插入图片描述

推荐:用 NSDT设计器 快速搭建可编程3D场景。

作为 3D CFD 数据预处理的一部分,我们编写了一个自定义 pytorch 数据加载器,用于对数据集执行标准化和批处理操作。

CVAE 对预处理后的数据实施 3D 卷积 (3DConvs) 来执行重建。

通过微调超参数和操纵我们的模型架构,我们在 3D 重建方面取得了显着的进步。

可以在这个 Github 存储库中找到项目代码。

1、数据说明

我们的数据集是使用 CFD 模拟方法生成的,它包含从供暖通风和空调 (HVAC) 管道中提取的立方体。

每个立方体代表特定时间携带物理信息的湍流的三维时间快照。 从模拟中提取的信息基于两个流动分量:速度场 U 和静压 p。 U 场 (x, y, z) 和标量 p 基于流的方向(立方体的法线方向)。

我们使用体素将 3D 立方体表示为尺寸为 21 × 21 × 21 x 100(x_coord、y_coord、z_coord、timestep)的数组。 下图显示了一个立方体数据样本,我们使用热图可视化每个速度分量。

总的来说,该数据集由 96 个模拟组成,每个模拟有 100 个时间步长,总共 9600 个立方体(对于每个速度分量)。
在这里插入图片描述

注意:由于保密限制,我们不会公开我们的数据,你可以使用脚本并将其改编为自己的 3D 数据。

2、数据预处理

下面的脚本显示了为预处理 3D 数据而编写的自定义 pytorch 数据加载器。 以下是一些亮点:

  • 立方体速度通道的加载和串联
  • 数据标准化
  • 数据缩放

请参阅存储库中的 dataloader.py 脚本以了解完整的实现。

3、模型架构

下图显示了已实现的 CVAE 架构。 在本例中,为了清晰起见,显示了 2DConv,但实现的架构使用 3DConv。

CVAE 由编码器网络(上部)、变分层(mu 和 sigma)(中右部分)和解码器网络(底部)组成。

编码器对输入立方体执行下采样操作,解码器对它们进行上采样以恢复原始形状。 变分层尝试学习数据集的分布,该层稍后可用于生成。

在这里插入图片描述

编码器网络由四个 3D 卷积层组成,每层的卷积滤波器数量是前一层的两倍(分别为 32、64、128 和 256),这使得模型能够学习更复杂的流特征。

密集层用于组合从最后一个编码器层获得的所有特征图,该层连接到计算后验流数据分布的参数(mu和sigma,这些参数使用重新定义概率分布)的变分层。 -[1]中描述的参数化技巧。这种概率分布允许我们从中采样,以生成尺寸为 8 × 8 x 8 的合成 3D 立方体。

解码器网络采用潜在向量并应用四个 3D 转置卷积层来恢复(重建)原始数据维度,每一层的卷积滤波器数量是前一层的一半(分别为 256、128、64 和 32)。

CVAE 使用两个损失函数进行训练:用于重建的均方误差 (MSE) 和用于潜在空间正则化的 Kullback-Leibler 散度 (KLB)。

我们将[2]中提出的架构和[3]中的超参数作为基线架构。

下面的脚本显示了 pytorch 中的一个示例,其中编码器和解码器都是使用 3D 卷积层 (Conv3d) 定义的:

self.encoder = nn.Sequential(
            nn.Conv3d(in_channels=image_channels, out_channels=16, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=16),
            nn.ReLU(),
            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=32),
            nn.ReLU(),
            nn.Conv3d(in_channels=32, out_channels=64, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=64),
            nn.ReLU(),
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU(),
            nn.Conv3d(in_channels=128, out_channels=128, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU(),
            Flatten()
        )

self.decoder = nn.Sequential(
            UnFlatten(),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU(),
            nn.ConvTranspose3d(in_channels=128, out_channels=128, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=128),
            nn.ReLU(),
            nn.ConvTranspose3d(in_channels=128, out_channels=64, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=64),
            nn.ReLU(),
            nn.ConvTranspose3d(in_channels=64, out_channels=32, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=32),
            nn.ReLU(),
            nn.ConvTranspose3d(in_channels=32, out_channels=16, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm3d(num_features=16),
            nn.ReLU(),
            nn.ConvTranspose3d(in_channels=16, out_channels=image_channels, kernel_size=4, stride=1, padding=0), # dimensions should be as original
            nn.BatchNorm3d(num_features=3))

4、设置环境

克隆此存储库:

git clone git@github.com:agrija9/Convolutional-VAE-for-3D-Turbulence-Data

建议使用虚拟环境来运行本项目:

  • 可以安装Anaconda并在系统中创建环境
  • 可以使用 pip venv 创建环境

在 pip/conda 环境中安装以下依赖项:

  • NumPy (>= 1.19.2)
  • Matplotlib (>= 3.3.2)
  • PyTorch (>= 1.7.0)
  • Torchvision (>= 0.8.1)
  • scikit-learn (>= 0.23.2)
  • tqdm
  • tensorboardX
  • torchsummary
  • PIL
  • collections

5、模型训练

要训练模型,请打开终端,激活 pip/conda 环境并输入:

cd /path-to-repo/Convolutional-VAE-for-3D-Turbulence-Data
python main.py --test_every_epochs 3 --batch_size 32 --epochs 40 --h_dim 128 --z_dim 64

以下是一些可以修改的超参数来训练模型

  • –batch_size 每个补丁要处理的立方体数量
  • –epochs 训练纪元数
  • –h_dim 隐藏密集层的维度(连接到变分层)
  • –z_dim 潜在空间维度

main.py 脚本调用模型并根据 3D CFD 数据对其进行训练。 使用 NVIDIA Tesla V100 GPU 训练 100 个 epoch 大约需要 2 小时。 在本例中,模型训练了 170 个 epoch。

请注意,在训练 3DConvs 模型时,与 2DConvs 模型相比,学习参数的数量呈指数级增长,因此,3D 数据的训练时间要长得多。

6、模型输出

训练完pytorch模型后,会生成一个包含训练后的权重的文件checkpoint.pkl。

在训练过程中,每隔 n 个时期根据测试数据对模型进行评估,脚本将重建的立方体与原始立方体进行比较,并将它们保存为图像。 此外,损失值被记录并放置在运行文件夹中,可以通过输入以下内容使用张量板可视化损失曲线:

cd /path-to-repo/Convolutional-VAE-for-3D-Turbulence-Data
tensorboard --logdir=runs/

如果没有创建,文件夹runs/会自动生成。

7、3D 重建结果

下图中,我们展示了同一立方体样本每 n 个 epoch 的重建结果。

顶行包含原始立方体样本(对于每个速度通道)。 底行包含每 n 个时期的重建输出。

对于这个例子,我们展示了从 0 到 355 个时期的重建,间隔为 15 个时期。 请注意作为历元函数的重建的改进。
在这里插入图片描述


原文链接:3D VAE神经网络实战 — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/715505.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

elk中kibana使用

1.前言 kibana是一款作为elasticsearch可视化的一款软件,将elasticsearch中的数据以可视化的状态展现出来,kibana也提供了查询、统计、修改索引等功能 2.kibana使用 索引管理 在索引管理中,可以看到所有索引的状态、运行状况、主分片、副本…

76-基于51单片机家庭红外人体检测震动报警系统(程序+原理图+元件清单全套资料)...

资料编号:076 功能介绍:采用51单片机作为主控CPU,采用红外接触传感器采集当前是否有人,采用震动传感器采集当前是否有震动,起到家庭防盗效果,采用按键设置当前布防/撤防状态,布防状态下&#xf…

Binder系列--获取ServiceManager

获取ServiceManager hongxi.zhu 2023-7-1 以SurfaceFlinger为例&#xff0c;分析客户端进程如何获取ServiceManager代理服务对象 主要流程 SurfaceFlinger中获取SM服务 frameworks/native/services/surfaceflinger/main_surfaceflinger.cpp // publish surface flingersp<…

适合初中生用的台灯有哪些?这样的台灯最适合学生!

对于学生而言台灯主要的点就是能够护眼、缓解眼睛疲劳&#xff0c;因为学生需要长时间的学习和用眼而且可以休息放松的时间比较少&#xff0c;导致眼睛过度疲劳&#xff0c;这也是为什么这么多中小学生近视的原因。那么我们应该怎么样选好一款台灯呢&#xff1f; 要想台灯能护眼…

Linux系统之dnf包管理器的基本使用

Linux系统之dnf包管理器的基本使用 一、dnf工具介绍1. dnf工具简介2. dnf的功能 二、DNF的安装1. 检查本地操作系统版本2. 安装epel3. 检查本地yum仓库状态4. 安装dnf包 三、dnf的使用帮助1. 查看dnf版本2. 查看dnf命令的帮助信息3. dnf命令的选项解释 四、dnf命令的基本使用1.…

浅谈Unicode与UTF-8

我们都知道&#xff0c;在Golang中字符都是以UTF-8编码的形式存储&#xff0c;当我们使用range遍历字符串的时候&#xff0c;go会为我们取出一个字符(rune)而不是一个byte&#xff0c;例如以下例子&#xff0c;我们使用range迭代取出第一个字符“你”&#xff0c;并且打印输出取…

TechSmith Camtasia for Mac 2023.0.3 中文破解版 Win/Mac上强大的屏幕录像工具

Camtasia 是Win/Mac上最强大的屏幕录像工具之一&#xff0c;该软件集成了视频录制、编辑、导出等一系列功能&#xff0c;支持鼠标光标样式、草绘示意插图、冰冻区域等实用的功能&#xff0c;还具有移动客户端让你录制视频&#xff0c;然后通过无线传输到 Camtasia 中进行编辑&a…

判断数组中所有元素是否均为实数对象 numpy.isrealobj()

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 判断数组中所有元素 是否均为实数对象 numpy.isrealobj() [太阳]选择题 请问关于以下代码的说法错误的是&#xff1f; import numpy as np a np.array([1, 2, 3]) b np.array([1, 1 2j, …

剑指 Offer 19: 正则表达式匹配

可能存在一个现象&#xff0c;就是aaab&#xff0c;然后a*ab&#xff0c;那么这样*只能代表一个a。 这道题可以使用动态规划的方式来解决。 这道题就是状态的判断&#xff1a;是否两个都为0&#xff1f;只有两个都为0才为true&#xff0c;并且判断*&#xff0c;有两个情况&…

Docker WebRTC容器部署方案

文章目录 WebRTC简介WebRTC Docker容器部署优势方案&#xff08;mpromonet/webrtc-streamer&#xff09;步骤 WebRTC简介 WebRTC&#xff08;Web Real-Time Communication&#xff09;是一种开放的实时通信技术&#xff0c;它允许浏览器之间进行音频、视频和数据的实时传输。W…

从古代八卦探究计算机的八进制

八进制&#xff0c;即八卦&#xff0c;是中国古代哲学体系中非常重要的一个概念&#xff0c;它被广泛应用于易经、道家、儒家等诸多领域。随着计算机科学的快速发展&#xff0c;人们开始思考&#xff1a;八进制是否可以应用到计算机上&#xff1f; 一、什么是八进制&#xff1…

Javaee技术目的总结

一.前节回顾 在前一节中&#xff0c;我们了解了&#xff1a; 1.将中央控制器中的Action容器&#xff0c;变成可控制! 2.针对于反射调用业务代码&#xff0c;最终页面跳转 3.jsp页面参数传递后台的代码优化字段太多有影响&#xff01; 二.项目部署前期准备工作 1.项目运行环境…

c语言 va_start/va_end函数

c语言 va_satrt和va_end函数介绍 头文件&#xff1a;#include <stdarg.h> 函数原型&#xff1a;void va_start(va_list ap, last) 和 void va_end(va_list ap); 可以被参数数量和类型可变的函数调用。 可变参数用…&#xff08;3个省略号表示可变参数列表&#xff09; …

深入理解 http 反向代理

要理解什么是 反向代理(reverse proxy) , 自然你得先知道什么是 正向代理(forward proxy). 另外需要说的是, 一般提到反向代理, 通常是指 http 反向代理, 但反向代理的范围可以更大, 比如 tcp 反向代理, 在这里, 不打算讨论 tcp 之类的反向代理, 当文中说到反向代理时, 指的就是…

C++day5

2、 #include <iostream> using namespace std; static int blood 10000; class hero { protected:string name;int hp;int attack; public:hero(){}//无参构造hero(string name,int hp,int attack):name(name),hp(hp),attack(attack){}//有参构造virtual void Atk(){b…

使用supervisor管理进程

写目录 一、supervisor简介二 、supervisor安装2.1下载supervisor2.2配置文件详解2.3把squid服务加入到supervisor管理当中 一、supervisor简介 supervisor是Python开发的c/s服务&#xff0c;是Linux系统下的进程管理工具。可以监听、启动、停止、重启一个或多个进程用supervi…

卡尔曼滤波原理和使用

随着传感技术&#xff0c;机器人&#xff0c;自动驾驶等不断发展&#xff0c;对控制系统的精度以及稳定性要求越来越高。卡尔曼滤波作为一种状态最优估计方法&#xff0c;应用也越来越普遍、 对于Kalman Filter的理解&#xff0c;用过都知道“黄金五条”公式&#xff0c;且通过…

pytorch实战13:基于pytorch实现YOLOv1(长长文)

基于pytorch实现YOLOv1&#xff08;长长文&#xff09; 前言 ​ 本篇文章的目的是记录自己实现yolo v1的过程&#xff0c;在此过程中&#xff0c;参考了许多开源的代码和博客&#xff0c;赞美大佬们。 参考文献和代码 YOLO v1代码参考&#xff1a;&#xff08;读书人的事情&…

华为OD机试真题 Python 实现【猜字谜】【2023Q1 100分】,附详细解题思路

目录 一、题目描述二、输入描述三、输出描述补充说明:四、解题思路五、Python算法源码六、效果展示1、输入2、输出3、说明一、题目描述 小王设计了一人简单的清字谈游戏,游戏的迷面是一人错误的单词,比如nesw,玩家需要猜出谈底库中正确的单词。猜中的要求如 对于某个谜面和…

武汉抖音seo,抖音关键词排名

抖音seo怎么做 抖音作为一款热门的社交娱乐应用&#xff0c;其SEO关键词排名对于提升内容曝光和用户流量非常重要。 1. 关键词研究&#xff1a;在进行SEO关键词排名时&#xff0c;首先需要进行关键词研究&#xff0c;了解用户在抖音上搜索的热门关键词。可以通过使用相关的关…