变分自编码器(VAE, Variational Autoencoder)

news2024/11/18 3:27:22

代码说明

VAE 模型结构:
编码器将输入数据(如 MNIST 图像)映射到潜在空间,生成均值 (mu) 和对数方差 (logvar)。
通过重新参数化技巧 (reparameterize) 从正态分布中采样潜在向量 z。
解码器将潜在向量 z 映射回原始空间,生成重构数据。
损失函数:
重构误差(BCE):衡量重构数据和原始数据的差异。
KL 散度(KLD):衡量潜在向量分布与标准正态分布的接近程度。
数据加载:
MNIST 数据集被用作示例,图像被标准化为 [0, 1] 范围。
生成结果:
测试阶段通过潜在空间随机采样生成新样本,并用 Matplotlib 可视化。
在这里插入图片描述

代码

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 超参数
input_dim = 784  # 输入维度 (28x28 图像展开为向量)
hidden_dim = 400  # 隐藏层维度
latent_dim = 20   # 潜在空间维度
batch_size = 128  # 批量大小
num_epochs = 20   # 训练轮数
learning_rate = 1e-3  # 学习率

# 数据加载
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
transform = transforms.Compose([
    transforms.ToTensor()  # 将像素值直接归一化到 [0, 1]
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# VAE 模型定义
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        # 编码器
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        # 解码器
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        h1 = torch.relu(self.fc1(x))
        mu = self.fc_mu(h1)
        logvar = self.fc_logvar(h1)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        h2 = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h2))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# 构造模型、损失函数和优化器
model = VAE(input_dim, hidden_dim, latent_dim)
criterion = nn.BCELoss(reduction='sum')  # 二元交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练
def loss_function(recon_x, x, mu, logvar):
    BCE = criterion(recon_x, x)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for data, _ in train_loader:
        data = data.view(-1, input_dim)  # 展平输入图像
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {train_loss / len(train_loader.dataset):.4f}")

# 测试(生成样本)
model.eval()
with torch.no_grad():
    z = torch.randn(16, latent_dim)  # 随机采样潜在向量
    samples = model.decode(z).view(-1, 1, 28, 28)  # 生成样本

# 可视化生成结果
plt.figure(figsize=(8, 8))
for i in range(16):
    plt.subplot(4, 4, i + 1)
    plt.imshow(samples[i][0].numpy(), cmap='gray')
    plt.axis('off')
plt.suptitle('Generated Samples from VAE')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2242556.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

DataWorks on EMR StarRocks,打造标准湖仓新范式

在大数据领域,数据仓库和实时分析系统扮演着至关重要的角色。DataWorks 基于大数据引擎,为数据仓库/数据湖/湖仓一体等解决方案提供统一的全链路大数据开发治理平台,为用户带来智能化的数据开发和分析体验。而阿里云提供的 EMR Serverless St…

Redis系列之底层数据结构ZipList

Redis系列之底层数据结构ZipList 实验环境 Redis 6.0 什么是Ziplist? Ziplist,压缩列表,这种数据结构会根据存入数据的类型和大小,分配大小不同的空间,所以是为了节省内存而采用的。因为这种数据结构是一种完整连续…

【FPGA开发】AXI-Stream总线协议解读

文章目录 AXI-Stream概述协议中一些定义字节定义流的定义 数据流类别字节流连续对齐流连续不对齐流稀疏流 协议的信号信号列表 文章为个人理解整理,如有错误,欢迎指正! 参考文献 ARM官方手册 《IHI0051B》 AXI-Stream概述 协议中一些定义 A…

c# 调用c++ 的dll 出现找不到函数入口点

今天在调用一个设备的dll文件时遇到了一点波折,因为多c 不熟悉,调用过程张出现了找不到函数入口点,一般我们使用c# 调用c 文件,还是比较简单。 [DllImport("AtnDll2.dll",CharSet CharSet.Ansi)]public static extern …

L11.【LeetCode笔记】有效的括号

目录 1.题目 2.分析 理解题意 解决方法 草稿代码 ​编辑 逐一排错 1.当字符串为"["时,分析代码 2.当字符串为"()]"时,分析代码 正确代码(isValid函数部分) 提交结果 3.代码优化 1.题目 https://leetcode.cn/problems/valid-parentheses/descri…

Unity类银河战士恶魔城学习总结(P129 Craft UI 合成面板UI)

【Unity教程】从0编程制作类银河恶魔城游戏_哔哩哔哩_bilibili 教程源地址:https://www.udemy.com/course/2d-rpg-alexdev/ 本章节实现了合成面板的UI设置 UI_CraftWindow.cs 字段作用: UI 组件: itemName / itemDescription / icon&#…

Python酷库之旅-第三方库Pandas(221)

目录 一、用法精讲 1036、pandas.DatetimeIndex.to_pydatetime方法 1036-1、语法 1036-2、参数 1036-3、功能 1036-4、返回值 1036-5、说明 1036-6、用法 1036-6-1、数据准备 1036-6-2、代码示例 1036-6-3、结果输出 1037、pandas.DatetimeIndex.to_series方法 10…

通过JS实现下载图片到本地教程分享

今天分享的这个方法我之前自己试了一下&#xff0c;感觉还行&#xff0c;原理就是通过<a>标签的新增属性实现的&#xff0c;然后可以强制触发下载功能&#xff0c;废话不多说&#xff0c;直接上教程。 首先在HTML写下面的代码: <a href"img.jpg" download…

二、神经网络基础与搭建

神经网络基础 前言一、神经网络1.1 基本概念1.2 工作原理 二、激活函数2.1 sigmoid激活函数2.1.1 公式2.1.2 注意事项 2.2 tanh激活函数2.2.1 公式2.2.2 注意事项 2.3 ReLU激活函数2.3.1 公式2.3.2 注意事项 2.4 SoftMax激活函数2.4.1 公式2.4.2 Softmax的性质2.4.3 Softmax的应…

Unreal engine5实现类似鬼泣5维吉尔二段跳

系列文章目录 文章目录 系列文章目录前言一、实现思路二、具体使用蓝图状态机蓝图接口三、中间遇到的问题 前言 先看下使用Unreal engine5实现二段跳的效果 一、实现思路 在Unreal Engine 5 (UE5) 中使用蓝图系统实现类似于《鬼泣5》中维吉尔的二段跳效果&#xff0c;可以通…

RAG经验论文《FACTS About Building Retrieval Augmented Generation-based Chatbots》笔记

《FACTS About Building Retrieval Augmented Generation-based Chatbots》是2024年7月英伟达的团队发表的基于RAG的聊天机器人构建的文章。 这篇论文在待读列表很长时间了&#xff0c;一直没有读&#xff0c;看题目以为FACTS是总结的一些事实经验&#xff0c;阅读过才发现FAC…

Java 简单家居开关系统

1.需求&#xff1a; 面向对象编程实现智能家居控制系统&#xff08;简单的开关&#xff09; 2.实现思路 1.定义设备类&#xff1a;创建设备对象代表家里的设备 JD类&#xff1a; import lombok.AllArgsConstructor; import lombok.Data; import lombok.NoArgsConstructor;D…

任务调度工具Spring Test

Spring Task 是Spring框架提供的任务调度工具&#xff0c;可以按照约定的时间自动执行某个代码逻辑。 作用&#xff1a;定时自动执行某段Java代码 应用场景&#xff1a; 信用卡每月还款提醒 银行贷款每月还款提醒 火车票售票系统处理未支付订单 入职纪念日为用户发送通知 一.…

SAFETY LAYERS IN ALIGNED LARGE LANGUAGEMODELS: THE KEY TO LLM SECURITY

目录 概要 背景 大语言模型对齐 对齐大语言模型中的过度拒绝 微调攻击 研究设置 问题定义 对齐的大语言模型 大语言模型的提示模板 安全层的存在和定位 安全层的存在性 1.从余弦相似度说明 2.从向量之间角度差异说明 3.与预训练LLM对比说明 安全层的定位 1.推理…

netcore Kafka

一、新建项目KafakDemo <ItemGroup><PackageReference Include"Confluent.Kafka" Version"2.6.0" /></ItemGroup> 二、Program.cs using Confluent.Kafka; using System; using System.Threading; using System.Threading.Tasks;names…

在k8s上部署Crunchy Postgres for Kubernetes

目录 一、前言二、安装Crunchy Postgres for Kubernetes三、部署一个简单的postgres集群四、增加pgbouncer五、数据备份六、备份恢复七、postgres配置参数八、数据导入九、权限管理 一、前言 Crunchy Postgres可以帮助我们在k8s上快速部署一个高可用、具有自动备份和恢复功能的…

函数指针示例

目录&#xff1a; 代码&#xff1a; main.c #include <stdio.h> #include <stdlib.h>int Max(int x, int y); int Min(int x, int y);int main(int argc, char**argv) {int x,y;scanf("%d",&x);scanf("%d",&y);int select;printf(&q…

计算机网络:运输层 —— 运输层端口号

文章目录 运输层端口号的分类端口号与应用程序的关联应用举例发送方的复用和接收方的分用 运输层端口号的分类 端口号只具有本地意义&#xff0c;即端口号只是为了标识本计算机网络协议栈应用层中的各应用进程。在因特网中不同计算机中的相同端口号是没有关系的&#xff0c;即…

牛客挑战赛77

#include <iostream>// 函数 kXOR&#xff1a;计算两个数在 k 进制下的异或和 // 参数&#xff1a; // a: 第一个正整数 // b: 第二个正整数 // k: 进制基数 // 返回值&#xff1a; // 两数在 k 进制下的异或和&#xff08;十进制表示&#xff09; long long kXO…

大数据-225 离线数仓 - 目前需求分析 指标口径 日志数据采集 taildir source HDFS Sink Agent Flume 优化配置

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; 目前已经更新到了&#xff1a; Hadoop&#xff08;已更完&#xff09;HDFS&#xff08;已更完&#xff09;MapReduce&#xff08;已更完&am…