基于pytorch使用仿真数据集来训练一个深度学习模型进行相位解包裹

news2024/11/29 13:42:04

使用 PyTorch 来训练一个深度学习模型进行相位解包裹是一种常见的方法。下面是一个详细的示例,展示如何生成仿真数据集并在 PyTorch 中训练模型。

1. 生成仿真数据集

首先,我们生成一些仿真数据集,包含多个包裹相位图和对应的解包裹相位图。

import numpy as np
import matplotlib.pyplot as plt

# 参数设置
nx, ny = 128, 128  # 图像尺寸
num_samples = 1000  # 生成样本数量

# 生成仿真数据集
def generate_dataset(num_samples, nx, ny):
    X, Y = np.meshgrid(np.linspace(-1, 1, nx), np.linspace(-1, 1, ny))
    true_phases = []
    wrapped_phases = []
    for _ in range(num_samples):
        # 生成真实相位分布
        phi_true = 3 * np.exp(-(X**2 + Y**2) / 0.2**2) + 2 * np.random.randn(nx, ny)
        # 生成包裹相位分布
        phi_wrapped = np.angle(np.exp(1j * phi_true))
        true_phases.append(phi_true)
        wrapped_phases.append(phi_wrapped)
    return np.array(true_phases), np.array(wrapped_phases)

# 生成数据集
true_phases, wrapped_phases = generate_dataset(num_samples, nx, ny)

# 显示仿真数据
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(true_phases[0], cmap='viridis')
plt.title('真实相位分布')
plt.colorbar()

plt.subplot(1, 2, 2)
plt.imshow(wrapped_phases[0], cmap='viridis')
plt.title('包裹相位分布')
plt.colorbar()

plt.show()

2. 数据预处理

将生成的数据集转换为 PyTorch 的 Tensor 格式,并创建数据加载器(DataLoader)。

import torch
from torch.utils.data import Dataset, DataLoader

# 自定义数据集类
class PhaseUnwrappingDataset(Dataset):
    def __init__(self, wrapped_phases, true_phases):
        self.wrapped_phases = wrapped_phases
        self.true_phases = true_phases

    def __len__(self):
        return len(self.wrapped_phases)

    def __getitem__(self, idx):
        wrapped_phase = self.wrapped_phases[idx]
        true_phase = self.true_phases[idx]
        wrapped_phase = torch.tensor(wrapped_phase, dtype=torch.float32).unsqueeze(0)  # (1, nx, ny)
        true_phase = torch.tensor(true_phase, dtype=torch.float32).unsqueeze(0)  # (1, nx, ny)
        return wrapped_phase, true_phase

# 创建数据集和数据加载器
dataset = PhaseUnwrappingDataset(wrapped_phases, true_phases)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

3. 构建深度学习模型

使用 PyTorch 构建一个卷积神经网络(CNN)模型,包含几个卷积层和反卷积层。

import torch.nn as nn

class PhaseUnwrappingNet(nn.Module):
    def __init__(self):
        super(PhaseUnwrappingNet, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 1, kernel_size=3, padding=1),
            nn.Tanh()  # 使用 Tanh 激活函数将输出限制在 [-1, 1] 范围内
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

4. 训练模型

编写训练代码,使用生成的数据集训练模型。

import torch.optim as optim

# 创建模型实例
model = PhaseUnwrappingNet()
model = model.to('cuda' if torch.cuda.is_available() else 'cpu')

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for batch in dataloader:
        wrapped_phase, true_phase = batch
        wrapped_phase = wrapped_phase.to(device)
        true_phase = true_phase.to(device)

        # 前向传播
        outputs = model(wrapped_phase)
        loss = criterion(outputs, true_phase)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}')

# 保存模型
torch.save(model.state_dict(), 'phase_unwrapping_model.pth')

5. 测试模型

从数据集中选择一个包裹相位图进行测试,预测其解包裹相位图。

# 加载模型
model.load_state_dict(torch.load('phase_unwrapping_model.pth'))
model.to(device)
model.eval()

# 选择一个测试样本
with torch.no_grad():
    test_wrapped_phase = wrapped_phases[0].unsqueeze(0).to(device)
    test_true_phase = true_phases[0].to(device)
    test_unwrapped_phase = model(test_wrapped_phase).squeeze(0).cpu().numpy()

# 显示结果
plt.figure(figsize=(12, 6))
plt.subplot(1, 3, 1)
plt.imshow(true_phases[0], cmap='viridis')
plt.title('真实相位分布')
plt.colorbar()

plt.subplot(1, 3, 2)
plt.imshow(wrapped_phases[0], cmap='viridis')
plt.title('包裹相位分布')
plt.colorbar()

plt.subplot(1, 3, 3)
plt.imshow(test_unwrapped_phase, cmap='viridis')
plt.title('恢复的相位分布')
plt.colorbar()

plt.show()

详细步骤解释

  1. 生成仿真数据集

    • 使用 generate_dataset 函数生成真实相位分布和对应的包裹相位分布。
    • phi_true 是真实相位分布,phi_wrapped 是包裹相位分布。
  2. 自定义数据集类

    • PhaseUnwrappingDataset 类继承自 PyTorch 的 Dataset 类,用于加载和预处理数据。
    • __getitem__ 方法返回一个包裹相位图和对应的真实相位图,并将它们转换为 PyTorch 的 Tensor 格式。
  3. 创建数据加载器

    • DataLoader 用于批量加载数据,并在训练过程中对数据进行随机打乱。
  4. 构建模型

    • PhaseUnwrappingNet 是一个包含编码器和解码器的卷积神经网络模型。
    • 编码器和解码器分别使用卷积层和反卷积层,中间使用 ReLU 激活函数。
    • 最后一层使用 Tanh 激活函数,将输出限制在 [-1, 1] 范围内。
  5. 训练模型

    • 使用均方误差(MSE)作为损失函数,Adam 优化器进行优化。
    • 训练过程中,模型在每个 epoch 的损失会被记录并打印出来。
  6. 测试模型

    • 从数据集中选择一个包裹相位图进行测试,预测其解包裹相位图。
    • 使用 imshow 函数显示真实相位分布、包裹相位分布和恢复的相位分布。

注意事项

  1. 数据集大小:生成的数据集大小需要根据你的硬件资源进行调整。如果内存不足,可以减少 num_samples 或使用更小的图像尺寸。
  2. 模型架构:上述模型是一个简单的 CNN 模型,你可以根据具体任务的复杂性调整模型的层数和参数。
  3. 损失函数:除了 MSE 损失函数,还可以尝试其他损失函数,如 L1 损失或自定义的损失函数,以提高模型的性能。
  4. 数据增强:为了提高模型的泛化能力,可以在训练数据集上应用数据增强技术,如旋转、平移等。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2249783.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

OGRE 3D----2. QGRE + QQuickView

将 OGRE(面向对象图形渲染引擎)集成到使用 QQuickView 的 Qt Quick 应用程序中,可以在现代灵活的 UI 框架中提供强大的 3D 渲染功能。本文将指导您如何在 QQuickView 环境中设置 OGRE。 前提条件 在开始之前,请确保您已安装以下内容: Qt(版本 5.15 )OGRE(版本14.2.5)…

丹摩 | 利用 CogVideoX 生成视频

声明:非广告,纯用户体验 1. CogVideoX CogVideoX 是智谱 AI 推出的一款极具创新性与突破性的视频生成产品。它在技术层面展现出诸多卓越特性,例如其采用的 Diffusion Transformer(DiT)架构奠定了强大的生成能力基础…

本地化部署 私有化大语言模型

本地化部署 私有化大语言模型 本地化部署 私有化大语言模型Anaconda 环境搭建运行 代码概述环境配置安装依赖CUDA 环境配置 系统设计与实现文件处理与加载文档索引构建模型加载与推理文件上传与索引更新实时对话与文档检索Gradio 前端设计 主要功能完整代码功能说明运行示例文件…

05_JavaScript注释与常见输出方式

JavaScript注释与常见输出方式 JavaScript注释 源码中注释是不被引擎所解释的,它的作用是对代码进行解释。lavascript 提供两种注释的写法:一种是单行注释,用//起头:另一种是多行注释,放在/*和*/之间。 //这是单行注释/* 这是 多行 注释 *…

python常见问题-pycharm无法导入三方库

1.运行环境 python版本:Python 3.9.6 需导入的greenlet版本:greenlet 3.1.1 2.当前的问题 由于需要使用到greenlet三方库,所以进行了导入,以下是我个人导入时的全过程 ①首先尝试了第1种导入方式:使用pycharm进行…

vue3实现自定义导航菜单

一、创建项目 1. 打开HBuilder X 图1 2. 新建一个空项目 文件->新建->项目->uni-app 填写项目名称:vue3demo 选择项目存放目录:D:/HBuilderProjects 一定要注意vue的版本,当前选择的版本为vue3 图2 点击“创建”之后进入项目界面 图…

多模态图像生成模型Qwen2vl-Flux,利用Qwen2VL的视觉语言理解能力增强FLUX,可集成ControlNet

Qwen2vl-Flux 是一种先进的多模态图像生成模型,它利用 Qwen2VL 的视觉语言理解能力增强了 FLUX。该模型擅长根据文本提示和视觉参考生成高质量图像,提供卓越的多模态理解和控制。让 FLUX 的多模态图像理解和提示词理解变得很强。 Qwen2vl-Flux有以下特点…

原生html+css+ajax+php图片压缩后替换原input=file上传

当前大部分照片尺寸大于5MB&#xff0c;而50MB限制的PHP通常上传4MB左右 于是就需要压缩后上传&#xff0c;上5代码使用后筛选的代码 <?php if ($_SERVER[REQUEST_METHOD] POST) { $uploadDir uploads/ . date(Ymd) . /; if (!is_dir($uploadDir)) { mkdir($uploadDir, …

1 ISP一键下载

BOOT0BOOT1启动模式说明0X用户Flash用户闪存存储器&#xff0c;也就是Flash启动10系统存储器系统存储器启动&#xff0c;串口下载11SRAM启动SRAM启动&#xff0c;用于在SRAM中调试代码 闪存存储器 是STM32 的内置FLASH,一般使用JTAG或者SWD模式下载程序时&#xff0c;就是下载…

泷羽sec学习打卡-shell命令4

声明 学习视频来自B站UP主 泷羽sec,如涉及侵权马上删除文章 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都 与本人无关,切莫逾越法律红线,否则后果自负 关于shell的那些事儿-shell4 如何在shell脚本与用户进行交互&#xff1f;如何限制用户输入的字符个数呢…

电子应用设计方案-27:智能淋浴系统方案设计

智能淋浴系统方案设计 一、系统概述 本智能淋浴系统旨在为用户提供舒适、便捷、个性化的淋浴体验&#xff0c;通过集成多种智能技术&#xff0c;实现水温、水流、淋浴模式的精准控制以及与其他智能家居设备的联动。 二、系统组成 1. 喷头及淋浴杆 - 采用可调节角度和高度的设计…

Spring系列之批处理Spring Batch介绍

概述 官网&#xff0c;GitHub A lightweight, comprehensive batch framework designed to enable the development of robust batch applications vital for the daily operations of enterprise systems. 执行流程 实战 假设有个待处理的任务&#xff0c;如文件batch-tes…

内存共享模型和Actor 模型

内存共享模型&#xff1a; 典型代表&#xff1a;java Actor 模型&#xff1a; 典型代表&#xff1a;HamnoyOS API 13

机器学习期末复习笔记

markdown文件下载&#xff1a;https://github.com/1037827920/SCUT-Notes/tree/main/%E6%9C%BA%E5%99%A8%E5%AD%A6%E4%B9%A0 机器学习期末复习笔记 1. 机器学习简介 1.1 什么是机器学习 如图所示&#xff1a; 几乎所有的机器学习都包括以下三个部分&#xff1a;数据、模型和…

Spring Boot拦截器(Interceptor)详解

拦截器Interceptor 拦截器我们主要分为三个方面进行讲解&#xff1a; 介绍下什么是拦截器&#xff0c;并通过快速入门程序上手拦截器拦截器的使用细节通过拦截器Interceptor完成登录校验功能 1. 快速入门 什么是拦截器&#xff1f; 是一种动态拦截方法调用的机制&#xff…

Python基础学习-12匿名函数lambda和map、filter

目录 1、匿名函数&#xff1a; lambda 2、Lambda的参数类型 3、map、 filter 4、本节总结 1、匿名函数&#xff1a; lambda 1&#xff09;语法&#xff1a; lambda arg1, arg2, …, argN : expression using arg 2&#xff09; lambda是一个表达式&#xff0c;而不是一个语…

【SpringBoot】Spring Data Redis的环境搭建(win10)

启动redis服务 进入redis安装目录&#xff0c;启动cmd Redis客户端连接redis服务 我用的redis客户端是github上一个大佬写的&#xff0c;叫 Another Redis Desktop Manager Java框架操作Redis 框架有很多&#xff0c;比如Jedis&#xff0c;Spring Data Redis&#xff0c;Let…

联想品牌的电脑 Bios 快捷键是什么?如何进入 Bios 设置?

在某些情况下&#xff0c;您可能需要通过U盘来安装操作系统或进行系统修复。对于联想电脑用户来说&#xff0c;了解如何设置U盘作为启动设备是非常有用的技能之一。本文简鹿办公将指导您如何使用联想电脑的 U 盘启动快捷键来实现这一目标。 联想笔记本 对于大多数联想笔记本电…

51单片机教程(九)- 数码管的动态显示

1、项目分析 通过演示数码管动态显示的操作过程。 2、技术准备 1、 数码管动态显示 4个1位数码管和单片机如何连接 a、静态显示的连接方式 优点&#xff1a;不需要动态刷新&#xff1b;缺点&#xff1a;占用IO口线多。 b、动态显示的连接方式 连接&#xff1a;所有位数码…

windows下安装node.js和pnpm

首先&#xff0c;一定要powershell右键选择管理员身份运行&#xff0c;否则第三个命令报错。 # 安装 fnm (快速 Node 管理器) winget install Schniz.fnm# 配置 fnm 环境 fnm env --use-on-cd | Out-String | Invoke-Expression# 下载并安装 Node.js fnm use --install-if-mis…