目录
写在前面
一、VAE结构
二、损失函数
三、代码实现
1.训练代码
2.推理生成图片
3.插值编辑图片
四、总结
写在前面
论文地址:https://arxiv.org/abs/1312.6114
大模型已经有了突破性的进展,图文的生成质量都越来越高,可控性也越来越强。很多阅读大模型源码的小伙伴会发现,大部分大模型,尤其是CV模型都会用到一个子模型:变分自编码器(VAE),这篇文章就以图像生成为例介绍一下VAE,并且解释它问什么天生适用于图像生成。配合代码尽量做到通俗易懂。
变分自编码器(VAE)是一种生成模型,旨在通过学习数据的潜在表示(Latent)来生成新样本。VAE 的训练目标是最大化变分下界,这意味着在学习潜在空间时,保持生成样本与真实数据的相似性,并尽量让潜在变量的分布接近标准正态分布。这样一来,模型就能有效地生成多样化的新图像。
上面那段话似乎不容易理解,我用白话解释一遍。VAE 的最大作用是尽量简单的生成“能看的”图片。现在达到的效果是输入一段标准高斯分布的Latent,就能生成自然连贯的图像。而且生成的图像有如下三个特点:
1.这个图像是全新的(也许跟某些训练数据相似);
2.通过编辑Latent可以一定程度上控制生成图像中的内容;
3.Latent空间中的结构化使得生成的图像自然且连贯,也就是说输入虽然是随机的,但输出是“能看的”,不是无意义的图像。
一、VAE结构
VAE由如下三块组成:
1.编码器(Encoder):输入数据通过编码器转换为潜在空间的分布。编码器通常由几层神经网络组成,输出潜在变量的均值和方差(其实是对数方差)。
2.重参数化层(Reparameterize):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。这一过程使得模型能够在训练时进行反向传播。
3.解码器(Decoder):解码器接收潜在变量并将其转换回原始数据的分布。解码器同样由神经网络组成,目的是重构输入数据。
可以看到和AE相比,VAE的结构差别主要集中在编码器和潜在空间的处理。编码器有两个输出均值和方差(其实是对数方差);中间的重参数化层根据均值和方差重采样得到Latent,我们一般管他叫做z。
下面我们使用MNIST数据集模拟一个VAE的结构,编码器和解码器使用最简单的全连接,Hidden维度400,Latent维度20,batch_size=128。
可以看到,编码器的输出是两个128x20的特征图,用于重参数化;重参数化的输出是128x20,也就是每一个点都根据对应的均值和方差采样得来。
二、损失函数
(VAE)的损失函数主要由两部分组成:
1.重构损失(Reconstruction Loss):衡量模型生成的样本与原始输入之间的差异,通常使用均方误差(MSE)或二元交叉熵(Binary Cross-Entropy)作为度量。这部分确保生成的样本尽量忠实于输入数据。
2.KL散度(Kullback-Leibler Divergence):衡量编码器输出的潜在分布与先验分布(通常是标准正态分布)之间的差异。目标是使得 逼近标准正态分布,使得采样变得更加合理。
重构损失没什么可说的,下面给出KL散度的公式:
KL散度代码实现:在代码实现的时候编码器的输出其实是均值mu和对数方差log_var,这一点在上图也能看出来:
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
其中log_var
对应对数方差,使用对数方差的形式可以保证数值稳定性、避免负值以及计算便利性,这种做法在许多深度学习模型中都得到了广泛应用,尤其是在处理概率分布时。;mu
是均值;,在代码中就是log_var.exp()。
KL散度会在下一篇文章详细介绍,这里到此为止。
三、代码实现
1.训练代码
下面是训练的全部代码,很简单,没什么可说的,重点是重参数化层和损失函数中的KL散度。
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义 VAE 模型
class VAE(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(VAE, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2) # 输出均值和对数方差
)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid() # 输出为 [0, 1]
)
def encode(self, x):
"""
编码器
:param x:
:return:
"""
h = self.encoder(x)
mu, log_var = h.chunk(2, dim=-1)
return mu, log_var
@staticmethod
def reparameterize(mu, log_var):
"""
重参数化
:param mu:
:param log_var:
:return:
"""
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
return mu + eps * std
def decode(self, z):
"""
解码器
:param z:
:return:
"""
return self.decoder(z)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
def loss_function(recon_x, x, mu, log_var):
"""
重构损失和 KL 散度
:param recon_x:
:param x:
:param mu:
:param log_var:
:return:
"""
BCE = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
return BCE + KLD
def train(model, train_loader, optimizer, epoch):
"""
训练模型
:param model:
:param train_loader:
:param optimizer:
:param epoch:
:return:
"""
model.train()
for batch_idx, (data, _) in enumerate(train_loader):
data = data.view(data.size(0), -1) # 展平输入
optimizer.zero_grad()
recon_batch, mu, log_var = model(data)
loss = loss_function(recon_batch, data, mu, log_var)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}]: Loss: {loss.item()}')
# 超参数
input_dim = 28 * 28 # MNIST
hidden_dim = 400
latent_dim = 20
batch_size = 128
learning_rate = 1e-3
num_epochs = 200
# 数据加载
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.view(-1)) # 展平
])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 初始化模型和优化器
model = VAE(input_dim, hidden_dim, latent_dim)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(1, num_epochs + 1):
train(model, train_loader, optimizer, epoch)
if epoch % 20 == 0:
# 保存模型
torch.save(model.state_dict(), 'model_data/vae_mnist_{}.pth'.format(epoch))
2.推理生成图片
下面是推理代码,理论上一个训练好的解码器,只需要标准高斯分布的随机噪声作为输入即可。我们来试一下,只使用解码器,输入是标准高斯分布的采样数据,输出是数字图片。
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 定义 VAE 模型
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid() # 输出为 [0, 1]
)
def decode(self, z):
return self.decoder(z)
def forward(self, x):
pass
def ran_demo():
with torch.no_grad():
z = torch.randn(64, latent_dim).to(device) # 随机采样
sample = model.decode(z).cpu()
# 绘制生成的样本
fig, axes = plt.subplots(8, 8, figsize=(8, 8))
for i in range(64):
axes[i // 8, i % 8].imshow(sample[i].view(28, 28), cmap='gray')
axes[i // 8, i % 8].axis('off')
plt.show()
if __name__ == '__main__':
# 超参数
input_dim = 28 * 28 # MNIST
hidden_dim = 400
latent_dim = 20
# hidden_dim = 1024
# latent_dim = 128
batch_size = 128
learning_rate = 1e-3
num_epochs = 500
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型和优化器
model = VAE()
# 加载模型并生成图像
model.load_state_dict(torch.load('model_data/vae_mnist_1000.pth', map_location=torch.device('cpu')))
# model.load_state_dict(torch.load('model_data/vae_mnist_200.pth', map_location=torch.device('cpu')))
model.eval()
# 随机输入
ran_demo()
输出结果如下:大部分是能看出来的数字的。毕竟只是一个简单的demo,就不要在意细节了。(#^.^#)
3.插值编辑图片
下面玩一个有意思的,既然不同的Latent分布控制着不同的图像特征,那么我们试试把一个数字的Latent通过插值慢慢混入另一个数字的Latent,看看会发生什么。我们在数字6的Latent中慢慢混入7.
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
# 定义 VAE 模型
class VAE(nn.Module):
def __init__(self):
super(VAE, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, latent_dim * 2) # 输出均值和对数方差
)
self.decoder = nn.Sequential(
nn.Linear(latent_dim, hidden_dim),
nn.ReLU(),
nn.Linear(hidden_dim, input_dim),
nn.Sigmoid() # 输出为 [0, 1]
)
def encode(self, x):
h = self.encoder(x)
mu, log_var = h.chunk(2, dim=-1)
return mu, log_var
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
print(eps)
return mu + eps * std
def decode(self, z):
return self.decoder(z)
def forward(self, x):
mu, log_var = self.encode(x)
z = self.reparameterize(mu, log_var)
return self.decode(z), mu, log_var
def interpolate_demo(from_num, to_num):
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.view(-1)) # 展平
])
dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)
data_loader = DataLoader(dataset, batch_size=1, shuffle=True)
def interpolate(z1, z2, num_steps=10):
return [(1 - alpha) * z1 + alpha * z2 for alpha in np.linspace(0, 1, num_steps)]
# 找到数字“1”和“7”的潜在向量
def get_latent_vector(digit):
model.eval()
with torch.no_grad():
for data, labels in data_loader:
if labels[0] == digit:
data = data.to(device)
mu, log_var = model.encode(data.view(-1, input_dim))
return mu.mean(0).cpu().numpy() # 返回均值作为潜在向量
# 获取两个数字的向量
latent_1 = get_latent_vector(from_num)
latent_7 = get_latent_vector(to_num)
# 计算插值向量
interpolated_latents = interpolate(latent_1, latent_7)
# 使用解码器生成图像
with torch.no_grad():
generated_images = [model.decode(torch.tensor(latent).float().to(device)).view(28, 28).cpu().numpy() for latent
in interpolated_latents]
# 可视化生成的图像
fig, axs = plt.subplots(1, len(generated_images), figsize=(15, 3))
for i, img in enumerate(generated_images):
axs[i].imshow(img, cmap='gray')
axs[i].axis('off')
plt.show()
if __name__ == '__main__':
# 超参数
input_dim = 28 * 28 # MNIST
hidden_dim = 400
latent_dim = 20
# hidden_dim = 1024
# latent_dim = 128
batch_size = 128
learning_rate = 1e-3
num_epochs = 500
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化模型和优化器
model = VAE()
# 加载模型并生成图像
model.load_state_dict(torch.load('model_data/vae_mnist_200.pth', map_location=torch.device('cpu')))
model.eval()
# 插值demo
interpolate_demo(6, 7)
可以看到数字6慢慢变成了数字7,中间的几张图既有6的特征又有7的特征。通过控制Latent确实可以控制输出图像的特征。那是不是也可以把一个人的脸慢慢变成另一个人的脸呢,我感觉可以试试。
四、总结
1.与AE模型相比,VAE主要有两处修改:
(1)编码器输出均值和方差(对数方差),经过重参数化层重采样后得到Latent,再进行解码;
(2)损失函数加入了KL散度,衡量编码器输出的Latent分布与先验分布(通常是标准正态分布)之间的差异,同时起到正则化的目的,使码器输出的Latent分布尽量符合标准高斯分布。
2.为什么VAE适合用在生成任务?
(1)容易生成的“能看的”图像:解码器只需接受标准高斯分布的采样数据就能生成自然连贯的图像,这意味着我们不再为生成的图像过于抽象而烦恼;
(2)生成图像的属性可以编辑:图像的各种属性特征都蕴含在Latent里,只要找到方法对齐并组合这些特征,我们就能控制输出图像的内容,比如:长着牛头的企鹅。这就是为什么当今很多生成模型吧VAE作为一个模块来使用,同时还需要配合其它模型来完成特定的生成任务,这点今天不做过多讨论。
总之VAE极大推动了生成任务,是很有研究价值的,小伙伴们快玩起来吧。
VAE就介绍到这,关注不迷路(*^__^*)
关注订阅号了解更多精品文章
交流探讨请加微信