输入image,label分别经过生成器和判别器。
经过生成器计算的是损失和产生的图片。并且在内部损失进行反向传播,优化器进行更新。
在pix2pix_model内部:首先对输入数据进行预处理。
def preprocess_input(self, data):
# move to GPU and change data types
data['label'] = data['label'].long()
if self.use_gpu():
data['label'] = data['label'].cuda(non_blocking=True)
data['instance'] = data['instance'].cuda(non_blocking=True)
data['image'] = data['image'].cuda(non_blocking=True)
# create one-hot label map
label_map = data['label']
bs, _, h, w = label_map.size()
nc = self.opt.label_nc + 1 if self.opt.contain_dontcare_label \
else self.opt.label_nc
input_label = self.FloatTensor(bs, nc, h, w).zero_()
input_semantics = input_label.scatter_(1, label_map, 1.0)
# concatenate instance map if it exists
if not self.opt.no_instance:
inst_map = data['instance']
instance_edge_map = self.get_edges(inst_map)
input_semantics = torch.cat((input_semantics, instance_edge_map), dim=1)
return input_semantics, data['image']
在预处理代码中首先将image,label传到cuda上,接着将label转换为one-hot编码。原本的(4,1,256,256)转变成了(4,19,256,256)。最后输出编码后的label和原始的RGB。
接着计算生成器损失:
在函数内部,将RGB和编码后label输入到生成器中,产生fake image。
在SPADEGenerator函数中,首先将标签图下采样到(1,19,8,8),然后经过一个卷积输出为(1,1024,8,8)。
接着将输入的RGB和编码后标签输入到Zencoder产生ST。
RGB经过model,在代码(1)中讲述了Zencoder的组成,最终输出大小为(1,512,128,128)。将语义图也下采样到(1,19,128,128)。
首先生成一个大小为(4,19,512)向量,4是batchsize。
向code_vector里面添值。
首先b_size为4,s_size为19。
下面计算分割图中不为0的像素总数。
接着使用segmap的掩膜选取RGB经过模型的输出,再求均值,填入到code_vector直至遍历结束。
具体来说:
1:i=0,表示第一个batch,j=0表示第一个通道。segmap.bool()表示对segmap所有值进行布尔操作。
2:segmap.bool()[0,0],表示取第一个batch的第一个通道值。然后将所有为Ture的像素个数汇总,长宽为128的图片,为true的共有4387个像素。
3:code[i]等于选择code第几个batch的特征。
4:假设选择第一个batch的特征,将segmap.bool()[i, j]即segmap第一个batch的第一个通道值作为mask放在code[i]上。输出全为true的值。
5:将为true的值reshape为(512,xxxx)。在沿着行维度求均值。最终大小就是512.
6:将512个值填入codes_vector中。
6:首先内层遍历19次将第一个通道填满,外层再遍历4次,将4个通道填满。得到Zencoder最终输出style_codes。
接着输入到SPADEResnetBlock中:
在ACE中首先添加噪声:
噪声大小为(4,8,8,1),热后和噪声的方差(1024)相乘,将1024广播到1024x1024,(4,8,8,1)广播到(4,8,8,1024),相乘后为(4,8,8,1024)经过转换为(4,1024,8,8)。
将segmap下采样。
生成一个全零矩阵。
这一部分的for循环和前面的一样,首先遍历batch,再遍历通道。
1:首先获得segmap中不等于0的像素总数。
2:self.getattr(‘fc_mu’ + str(j))是实例的fc_mu0属性,对应于:
3:这里是求的style_codes[i][j],即列对应的值。
4:将512rshape为(512,1)。在扩充到(self.style_length, component_mask_area),这里的component_mask_area是下采样后的segmap。
5:将segmap对应false的值即0用来替换掉component_mu的值。
就这样执行19次,再执行外部的batch循环。
最后的SEAN可以表示为:
用公式表示为:
将输出的结果经过leakrelu和卷积并再执行一次。
最后执行一个跳连接操作。对应于文中:
head0执行完毕。将输出结果上采样并在执行一次SAPDERES。
最后输出一个大小为(4,3,256,256)大小的RGB图。
这就是生成器的全部代码。