一、Gaps to fill, skills to practice:
The main thing I am not yet fluent with is handling the mean mu, the log-variance logvar, and the standard deviation std (note: logvar is the log of the variance, not of the standard deviation).
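As a quick self-check on how the three relate: the encoder predicts logvar = log(sigma^2), so the standard deviation is recovered as std = exp(0.5 * logvar). A minimal sketch (not from the sample code):

import torch

logvar = torch.tensor([0.0, 2.0])   # log of the variance, as an encoder would predict it
std = (0.5 * logvar).exp()          # sigma = exp(0.5 * log(sigma^2))
print(std)                          # tensor([1.0000, 2.7183])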
二、Code walkthrough:
1. The sample_code provides three models: fcn_model, conv_model, and vae_model:
(1) The structure of fcn_model is very easy to understand:
It uses fully connected layers to reduce the dimensionality, then maps back up to the original dimension (the image data of course has to be flattened first):
# Define the module structure
import torch
import torch.nn as nn

class fcn_autoencoder(nn.Module):
    def __init__(self):
        super(fcn_autoencoder, self).__init__()
        # Encoder: compress the flattened 64x64x3 image down to a 3-dim code
        self.encoder = nn.Sequential(
            nn.Linear(64 * 64 * 3, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 12),
            nn.ReLU(),
            nn.Linear(12, 3)
        )
        # Decoder: expand the 3-dim code back to the original dimension
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(),
            nn.Linear(12, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 64 * 64 * 3),
            nn.Tanh()   # outputs in [-1, 1], matching images normalized to that range
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
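To make the flatten step concrete, here is a minimal usage sketch (the batch size, the [-1, 1] normalization, and the anomaly-score idea are my assumptions, not part of the sample code): flatten each image into a vector, reconstruct it, and take the per-sample reconstruction error as an anomaly score:

import torch

model = fcn_autoencoder()
imgs = torch.rand(8, 3, 64, 64) * 2 - 1   # assumed: batch of 8 images scaled to [-1, 1]
flat = imgs.view(imgs.size(0), -1)        # flatten to shape (8, 64*64*3)
recon = model(flat)
score = ((recon - flat) ** 2).mean(dim=1) # per-sample MSE; larger = more anomalous
print(score.shape)                        # torch.Size([8])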
(2) conv_model:
Same idea as fcn_model, just using convolutions instead:
class conv_autoencoder(nn.Module):
    def __init__(self):
        super(conv_autoencoder, self).__init__()
        # Each stride-2 conv halves the spatial size: 64 -> 32 -> 16 -> 8
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Each stride-2 transposed conv doubles it back: 8 -> 16 -> 32 -> 64
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
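Each stride-2 convolution halves the spatial resolution and each transposed convolution doubles it back. A quick shape check (a sketch; the 64x64 input size is carried over from fcn_model above):

import torch

model = conv_autoencoder()
x = torch.rand(8, 3, 64, 64) * 2 - 1  # assumed batch of images in [-1, 1]
code = model.encoder(x)
print(code.shape)                     # torch.Size([8, 48, 8, 8]): 64 -> 32 -> 16 -> 8
recon = model(x)
print(recon.shape)                    # torch.Size([8, 3, 64, 64])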
(3) VAE (Variational Autoencoder): this one is not as easy to understand.
It involves the concepts of the mean mu and the log-variance logvar:
from torch.autograd import Variable

class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        # Shared encoder trunk: 64x64x3 -> 16x16x24
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(12, 24, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        # Two heads on the same trunk: one predicts mu, the other logvar
        self.enc_out_1 = nn.Sequential(
            nn.Conv2d(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.enc_out_2 = nn.Sequential(
            nn.Conv2d(24, 48, 4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(48, 24, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(24, 12, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(12, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        )

    def encode(self, x):
        h1 = self.encoder(x)
        return self.enc_out_1(h1), self.enc_out_2(h1)

    def reparametrize(self, mu, logvar):
        # The reparameterization trick: z = mu + std * eps with eps ~ N(0, I),
        # so gradients can flow through mu and logvar despite the sampling step
        std = logvar.mul(0.5).exp_()   # std = exp(0.5 * logvar), since logvar = log(sigma^2)
        if torch.cuda.is_available():
            eps = torch.cuda.FloatTensor(std.size()).normal_()
        else:
            eps = torch.FloatTensor(std.size()).normal_()
        eps = Variable(eps)  # this line can actually be omitted (Variable is a no-op in modern PyTorch)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        mu, logvar = self.encode(x)          # mu and logvar both come from the encoder heads
        z = self.reparametrize(mu, logvar)   # sample the latent code z from N(mu, std^2)
        return self.decode(z), mu, logvar
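Training a VAE pairs the reconstruction error with a KL-divergence term that pulls N(mu, sigma^2) toward the standard normal N(0, I); this is exactly why forward returns mu and logvar alongside the reconstruction. A minimal sketch of the standard loss (the sum reduction and 1:1 weighting are assumptions; the sample code's own loss may differ):

import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    # reconstruction term: how closely the decoded image matches the input
    recon = F.mse_loss(recon_x, x, reduction='sum')
    # KL(N(mu, sigma^2) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld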