经典网络解(三) 生成模型VAE | 自编码器、变分自编码器|有监督,无监督

news2024/11/30 8:36:20

文章目录

  • 1 有监督与无监督
  • 2 生成模型
    • 2.1 重要思路
  • 3 VAE
    • 编码器怎么单独用?
    • 解码器怎么单独用?
    • 为什么要用变分
    • 变分自编码器推导
      • 高斯混合模型
  • 4 代码实现

之前我们的很多网络都是有监督的

生成网络都是无监督的(本质就是密度估计),我们首先来讲有监督学习,无监督学习

1 有监督与无监督

有监督学习

目标学习X到Y的映射,有正确答案标注

示例

分类回归

目标检测

语义分割

无监督学习

没有标记,找出隐含在数据里的模型或者结构

示例

聚类

降维 1 线性降维:PCA主成分分析 2 非线性降维:特征学习(自编码)

密度估计

当谈论有监督学习和无监督学习时,你可以将其比喻为烹饪和探险两种不同的方式:

有监督学习就像是在烹饪中的烹饪食谱。你有一本详细的烹饪书(类似于带标签的训练数据),书中告诉你每一步应该怎么做,包括每个食材的量和准备方式(就像标签指导模型的输出)。你只需按照指示的步骤执行,最终会得到一道美味的菜肴。在这个过程中,你不需要创造新的食谱,只需遵循已有的指导。

无监督学习则类似于一场探险,你被带到一个未知的地方,没有地图或导航,只有一堆不同的植物和动物(类似于未标记的数据)。你的任务是探索并发现任何可能的规律、相似性或特征,以确定它们之间的关系(就像从未标记的数据中发现模式)。在这个过程中,你可能会发现新的物种或新的地理特征,而无需事先知道要找什么。

2 生成模型

学习训练模型的分布,然后产生自己的模型!给定训练集,产生与训练集同分布的新样本!

生成模型应用

​ 图像合成 图像属性编辑 图片风格转移等

2.1 重要思路

显示密度估计

​ 显示定义并求解分布

​ 又可以分为

1 可以求解的

PixelRNN

2 不可以求解的

VAE

隐示密度估计:学习一个模型,而无需定义它

GAN

3 VAE

变分自编码器

我们先介绍自编码器和解码器

编码器

编码器的作用一般都是提取压缩特征,降低维度,保证数据里最核心最重要的信息被保留

解码器

但是只有编码器是不行的,我不知道编码器提取的特征怎么样,所以我们需要加上解码器,解码器就可以利用提取到的特征进行重构原始数据,这样的话重构出来的图像越像原图说明编码器越好

编码器怎么单独用?

做分类或其他有监督任务

对于输入数据,利用编码器提取特征,然后输出预测标签,根据真实标签进行计算损失函数,微调网络,这种情况可以适用于少量的数据标记情况

外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传

但是这样的效果往往不如在有监督网络微调的方式

解码器怎么单独用?

做图像生成

比如给一个随机的二维编码,我就可以生成一个真实图像样本

讲完了编码器解码器,我们想问为什么要用变分?变分是什么

为什么要用变分

上面的我的自编码器的思想太死板了!他只能学到一些离散的编码,学到自己见过的内容,无法组合创新

VAE引入了概率分布的概念,它假设数据的潜在表示(潜在空间)是连续的,并使用概率分布来建模这个潜在空间。具体来说,VAE假设潜在表示服从一个潜在空间的高斯分布,其中编码器学习生成均值和方差,而解码器从这个分布中采样。这种建模方式允许VAE学习数据的连续、平滑的表示,而不仅仅是对数据的离散编码。

在这里插入图片描述

自编码器输入图像后,编码器会生成一个编码

而如图所示变分自编码器是输入图片,编码器输出一个分布(均值和方差)

生成图像的时候,从这个分布中采样送入解码器即可

两个损失函数

一个最小化重构误差

一个尽可能使得潜在表示的概率分布接近标准正态分布使得0均值1方差(损失函数一方面可以避免退化成自编码器,另一方面保证采样简单)

变分自编码器推导

高斯混合模型

用很多个简单高斯逼近最后的比较复杂的分布

优化解码器参数使得似然函数L最大, 但是实际中由于有隐变量的存在而无法积分,所以我们只能通过近似的方式

具体推导可以查看如下这篇博客

从零推导:变分自编码器(VAE) - 知乎 (zhihu.com)

4 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)
        self.fc22 = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)
        
    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        mu = self.fc21(h1)
        log_var = self.fc22(h1)
        return mu, log_var
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        z = mu + eps*std
        return z
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        recon_x = torch.sigmoid(self.fc4(h3))
        return recon_x
    
    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        recon_x = self.decode(z)
        return recon_x, mu, log_var

# 定义损失函数,通常使用重建损失和KL散度
def loss_function(recon_x, x, mu, log_var):
    BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
    KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return BCE + KLD

# 创建VAE实例并进行训练
input_dim = 784  # 用于示例的MNIST数据集
hidden_dim = 400
latent_dim = 20
vae = VAE(input_dim, hidden_dim, latent_dim)

# 定义优化器
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# 训练VAE
def train_vae(train_loader, vae, optimizer, num_epochs):
    vae.train()
    for epoch in range(num_epochs):
        for batch_idx, data in enumerate(train_loader):
            data = data.view(-1, input_dim)
            recon_batch, mu, log_var = vae(data)
            loss = loss_function(recon_batch, data, mu, log_var)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_loader)}], Loss: {loss.item()}')

# 使用MNIST数据集示例
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

num_epochs = 10
train_vae(train_loader, vae, optimizer, num_epochs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1051883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

5、Linux驱动开发:设备-设备注册

目录 🍅点击这里查看所有博文 随着自己工作的进行,接触到的技术栈也越来越多。给我一个很直观的感受就是,某一项技术/经验在刚开始接触的时候都记得很清楚。往往过了几个月都会忘记的差不多了,只有经常会用到的东西才有可能真正记…

12、Kubernetes中KubeProxy实现之iptables和ipvs

目录 一、概述 二、iptables 代理模式 三、iptables案例分析 四、ipvs案例分析 一、概述 iptables和ipvs其实都是依赖的一个共同的Linux内核模块:Netfilter。Netfilter是Linux 2.4.x引入的一个子系统,它作为一个通用的、抽象的框架,提供…

华为智能高校出口安全解决方案(3)

本文承接: https://qiuhualin.blog.csdn.net/article/details/133267254?spm1001.2014.3001.5502 重点讲解华为智能高校出口安全解决方案的攻击防御&安全运维&日志审计的部署流程。 华为智能高校出口安全解决方案(3) 课程地址攻击防…

git报错:Failed to connect to 127.0.0.1 port 1080

Bug描述 由于在试了网上的这条命令 git config --global http.proxy socks5 127.0.0.1:1080 git config --global https.proxy socks5 127.0.0.1:1080git config --global http.proxy 127.0.0.1:1080 git config --global https.proxy 127.0.0.1:1080Bug描述:Faile…

对负采样(negative sampling)的一些理解

负采样(negative sampling)通常用于解决在训练神经网络模型时计算softmax的分母过大、难以计算的问题。但在LightGCN模型论文的BPR LOSS中,负采样的概念可能与传统的softmax分母问题不完全一样。 在LightGCN模型中,不同于传统的协…

Spring结合自定义注解实现 AOP 切面功能【详解】

Spring结合自定义注解实现 AOP 切面功能 Spring AOP 注解概述Aspect 快速入门execution 切点表达式 拦截指定类的方法Pointcut("annotation(xx)") 拦截拥有指定注解的方法常用注解1.Before:在切点方法前执行2.After:在切点方法后执行3.Around&…

Python爬虫获取百度图片+重命名+帧差法获取关键帧

(清库存) 获取图片 重命名 帧差法 爬虫获取图片文件重命名帧差法获取关键帧 爬虫获取图片 # 图片在当前目录下生成import requests import renum 0 numPicture 0 file List []def dowmloadPicture(html, keyword):global num# t 0pic_url re.fin…

【JVM】运行时数据区之 堆——自问自答

Q:堆和栈,在设计上有何用义? 此处我们不说数据结构的概念。 堆本身是一种存储结构,在代码的内存层面来看,无论是c 操作的原生内存,还是Java 背后的JVM,堆的作用都是进行持久存储的。 这个持久存储并不是…

集合-Collection

系列文章目录 1.集合-Collection-CSDN博客 文章目录 目录 系列文章目录 文章目录 前言 一 . 集合的继承体系 二 . 什么是Collection? 三 . 常用方法 1.add(Object element): 将指定的元素添加到集合中。 2. remove(Object element): 从集合中移除指定的元素。 3. bo…

国庆day1---消息队列实现进程之间通信方式代码,现象

snd&#xff1a; #include <myhead.h>#define ERR_MSG(msg) do{\fprintf(stderr,"__%d__:",__LINE__);\perror(msg);\ }while(0)typedef struct{ long msgtype; //消息类型char data[1024]; //消息正文 }Msg;#define SIZE sizeof(Msg)-sizeof(long)int main…

HP E1740A 模拟量输入模块

HP&#xff08;惠普&#xff09;E1740A 模拟量输入模块是一种用于数据采集和测量的工控模块&#xff0c;通常用于各种自动化和监测应用中。以下是该模拟量输入模块的一些可能特点和功能&#xff1a; 多通道输入&#xff1a; E1740A 模块通常具有多个模拟量输入通道&#xff0c;…

windows的arp响应

1.原理‘ 2.场景 3.步骤

YOLOv8+swin_transfomerv2

测试环境&#xff1a;cuda11.3 pytorch1.11 rtx3090 wsl2 ubuntu20.04 踩了很多坑&#xff0c;网上很多博主的代码根本跑不通&#xff0c;自己去github仓库复现修改的 网上博主的代码日常出现cpu,gpu混合&#xff0c;或许是人家分布式训练了&#xff0c;哈哈哈 下面上干货…

Android回收视图

本文所有代码均存放于https://github.com/MADMAX110/BitsandPizzas 回收视图是列表视图的一个更高级也更灵活的版本。 回收视图比列表视图更加灵活&#xff0c;所以需要更多设置&#xff0c;回收视图使用一个适配器访问它的数据&#xff0c;不过与列表视图不同&#xff0c;回收…

[RCTF2015]EasySQL 二次注入 regexp指定字段 reverse逆序输出

第一眼没看出来 我以为是伪造管理员 就先去测试管理员账号 去register.php 注册 首先先注册一个自己的账号 我喜欢用admin123 发现里面存在修改密码的内容 那么肯定链接到数据库了 题目又提示是sql 那我们看看能不能修改管理员密码 首先我们猜测闭合 通过用户名 admin…

HTML,CSS,JavaScript知识点

HTML&#xff0c;CSS&#xff0c;JavaScript知识点 HTML篇 HTML是超文本标记语言。文件以.html结尾。 Hello,HTML。常用的工具: 标题: <h1>一级标题</h1><h2>二级标题</h2><h3>三级标题</h3><h4>四级标题</h4>无序列表和有…

YOLOv8+swin_transfomer

测试环境&#xff1a;cuda11.3 pytorch1.11 rtx3090 wsl2 ubuntu20.04 本科在读&#xff0c;中九以上老师或者课题组捞捞我&#xff0c;孩子想读书&#xff0c;求课题组师兄内推qaq 踩了很多坑&#xff0c;网上很多博主的代码根本跑不通&#xff0c;自己去github仓库复现修…

PHP免登录积分商城系统/动力商城/积分商城兑换系统源码Tinkphp

介绍&#xff1a; PHP免登录积分商城系统/动力商城/积分商城兑换系统源码Tinkphp&#xff0c;这个免登录积分商城系统是一种新型的电子商务模式&#xff0c;它通过省去麻烦的注册步骤&#xff0c;让用户能够很快又方便去积分兑换。这种商城系统具有UI干净整洁大方、运行顺畅的…

正点原子嵌入式linux驱动开发——STM32MP1启动详解

STM32单片机是直接将程序下载到内部 Flash中&#xff0c;上电以后直接运行内部 Flash中的程序。 STM32MP157内部没有供用户使用的 Flash&#xff0c;系统都是存放在外部 Flash里面的&#xff0c;比如 EMMC、NAND等&#xff0c;因此 STM32MP157上电以后需要从外部 Flash加载程序…

Mendix中的依赖管理:npm和Maven的应用

序言 在传统java开发项目中&#xff0c;我们可以利用maven来管理jar包依赖&#xff0c;但在mendix项目开发Custom Java Action时&#xff0c;由于目录结构有一些差异&#xff0c;我们需要自行配置。同样的&#xff0c;在mendix项目开发Custom JavaScript Action时&#xff0c;…