Glow: Generative Flow with Invertible 1×1 Convolutions
代码 GitHub: https://github.com/rosinality/glow-pytorch
1 模型架构如下
1.1 左图：单步 Flow 模型（one step of flow）
Flow model
① ActNorm
② InvConv2dLU
③ AffineCoupling
1.2 右图：Glow 多尺度模型结构
Glow Model
Block × (L−1)（带 split），最后再接 1 个不带 split 的 Block
每个 Block 内含 K 个 Flow
2 Flow层
2.1 ActNorm
1) ActNorm 是一个逐通道的仿射变换（每个通道一组可学习的 a、b）:
# y = a * (x + b)
self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1)) # b
self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1)) # a
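2) 参数初始化（data-dependent initialization）:
用第一个 batch 的数据计算每个通道的均值 mean 和标准差 std，初始化 b = -mean、a = 1 / (std + 1e-6)，使初始输出在每个通道上近似为零均值、单位方差。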
flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
mean = ( flatten.mean(1) .unsqueeze(1) .unsqueeze(2).unsqueeze(3)
.permute(1, 0, 2, 3) )
std = ( flatten.std(1) .unsqueeze(1) .unsqueeze(2) .unsqueeze(3)
.permute(1, 0, 2, 3))
self.loc.data.copy_(-mean)
self.scale.data.copy_(1 / (std + 1e-6))
3) log 行列式（log-determinant）：同一组逐通道缩放 a 作用于全部 H×W 个空间位置，因此 log|det| = H · W · Σ log|a|
log_abs = logabs(self.scale)
logdet = height * width * torch.sum(log_abs)
4) 反函数：x = y / a - b
output / self.scale - self.loc
2.2 Invertible 1*1 convolution
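2.2.1 使用 LU 分解计算行列式
将 1×1 卷积的 c×c 权重参数化为 W = P · L · (U + diag(s))：
- P 为固定的置换矩阵；
- L 为对角线全为 1 的下三角矩阵；
- U 为对角线为 0 的上三角矩阵；
- s = s_sign · exp(w_s)。
于是 log|det W| = Σ w_s，计算代价从 O(c³) 降为 O(c)。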
1)权重计算
weight = (
self.w_p
@ (self.w_l * self.l_mask + self.l_eye)
@ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))
)
2)前向计算
out = F.conv2d(input, weight)
3) 行列式计算（每个空间位置都乘以同一个 W，故乘以 H · W）
logdet = height * width * torch.sum(self.w_s)
4)反函数计算
F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
2.2.2 行列式直接计算方法
不使用 LU 分解时，直接对 c×c 权重矩阵求 log|det W|（例如用 torch.slogdet），再乘以 H · W；代价为 O(c³)。
2.3 AffineCoupling
1)网络结构
self.net = nn.Sequential(
nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(filter_size, filter_size, 1),
nn.ReLU(inplace=True),
ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),
)
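2) 前向公式：将输入沿通道一分为二。x_a 保持不变，由 x_a 经网络得到 log_s 和 t，用来变换 x_b。
论文中为 y_b = exp(log_s) ⊙ x_b + t；该实现改用 s = sigmoid(log_s + 2) ∈ (0, 1)，y_b = (x_b + t) ⊙ s。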
in_a, in_b = input.chunk(2, 1)
log_s, t = self.net(in_a).chunk(2, 1)
# s = torch.exp(log_s)
s = F.sigmoid(log_s + 2)
# out_b = s * in_b + t
out_b = (in_b + t) * s
torch.cat([in_a, out_b], 1)
3)log 行列式
logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
4) 反函数
out_a, out_b = output.chunk(2, 1)
log_s, t = self.net(out_a).chunk(2, 1)
# s = torch.exp(log_s)
s = F.sigmoid(log_s + 2)
# in_b = (out_b - t) / s
in_b = out_b / s - t
torch.cat([out_a, in_b], 1)
2.4 Flow 模型
1)前向计算及log行列式
out, logdet = self.actnorm(input)
out, det1 = self.invconv(out)
out, det2 = self.coupling(out)
logdet = logdet + det1+det2
2) 反函数（按与前向相反的顺序依次调用各层的 reverse）
input = self.coupling.reverse(output)
input = self.invconv.reverse(input)
input = self.actnorm.reverse(input)
3 Glow 层
Flow × K
Block × (L−1)（带 split），最后再接 1 个不带 split 的 Block
3.1 Block
3.1.1 squeeze：把每个 2×2 空间邻域折叠到通道维，(B, C, H, W) → (B, 4C, H/2, W/2)
squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)
3.1.2 step of flow：在 squeeze 之后的通道数（squeeze_dim）上堆叠 n_flow（即 K）个 Flow
for i in range(n_flow):
self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))
3.1.3 split：把通道一分为二。out 继续送入下一个 Block，z_new 作为本尺度的隐变量；z_new 的先验是高斯分布，其均值和 log 标准差由 prior 网络根据 out 预测。
out, z_new = out.chunk(2, 1)
mean, log_sd = self.prior(out).chunk(2, 1)
log_p = gaussian_log_p(z_new, mean, log_sd)
log_p = log_p.view(b_size, -1).sum(1)
3.1.4 Block 层
1)反函数
input = output
if reconstruct:
if self.split:
input = torch.cat([output, eps], 1)
else:
input = eps
else:
if self.split:
mean, log_sd = self.prior(input).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = torch.cat([output, z], 1)
else:
zero = torch.zeros_like(input)
# zero = F.pad(zero, [1, 1, 1, 1], value=1)
mean, log_sd = self.prior(zero).chunk(2, 1)
z = gaussian_sample(eps, mean, log_sd)
input = z
for flow in self.flows[::-1]:
input = flow.reverse(input)
b_size, n_channel, height, width = input.shape
unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
unsqueezed = unsqueezed.contiguous().view(
b_size, n_channel // 4, height * 2, width * 2
)
return unsqueezed
3.2 Glow 层
1) 构建各个 Block：每经过一个带 split 的 Block，通道数翻倍（squeeze ×4，split ÷2）
self.blocks = nn.ModuleList()
n_channel = in_channel
for i in range(n_block - 1):
self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
n_channel *= 2
self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))
2)反函数
for i, block in enumerate(self.blocks[::-1]):
if i == 0:
input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
else:
input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)
return input