一、VAE
背景:VAE什么变分自编码器,听起来起名都头大,用大白话告诉你。
把一个复杂图片压缩成两个参数,用这个参数采样再复原。
这个简单的东西是两个参数,均值和方差,用(0,1)随机取个值a z = μ + σ * a 这个Z可以是变化的,代码你的图片
二、代码
极简代码和公式最靠近本质,VAE如下
class Encoder(nn.Module):
def __init__(self, input_dim, hidden_dim, latent_dim):
super(Encoder, self).__init__()
self.fc1 = nn.Linear(input_dim, hidden_dim)
self.fc_mu = nn.Linear(hidden_dim, latent_dim)
self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
def forward(self, x):
x = torch.relu(self.fc1(x))
mu = self.fc_mu(x)
logvar = self.fc_logvar(x)
return mu, logvar
class Decoder(nn.Module):
def __init__(self, latent_dim, hidden_dim, output_dim):
super(Decoder, self).__init__()
self.fc1 = nn.Linear(latent_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, output_dim)
def forward(self, z):
z = torch.relu(self.fc1(z))
x_hat = torch.sigmoid(self.fc2(z))
return x_hat
三、训练和损失
训练就是用GPU给你干活。你告诉GPU方向,就是用损失告诉GPU,给我降低点损失
损失就是给点数据,让数据沿着路(神经网络)往前跑的结果,这个结果输入一个损失函数,得到就是损失。
VAE的损失 包括 重建损失,KL损失(也就是分布之间的距离)
衡量原始输入 x 和重构输出 x' 之间的差异。通常使用均方误差(MSE)或交叉熵损失。
VAE在文生图例如 Stable disfussion 1.5 XL 、DIT等起的左右是把图片压缩到一个隐向量。