经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

news2024/10/5 15:57:48

经典神经网络(10)PixelCNN模型、Gated PixelCNN模型及其在MNIST数据集上的应用

1 PixelCNN

  • PixelCNN是DeepMind团队在论文Pixel Recurrent Neural Networks (16.01)提出的一种生成模型,实际上这篇论文共提出了两种架构:PixelRNNPixelCNN,两者的主要区别是前者用LSTM来建模,而PixelCNN是基于CNN的,相比RNN,CNN计算更高效,我们这里只讨论PixelCNN。

  • PixelCNN借用了NLP里的方法来生成图像。对于自然图像,每个像素值的取值范围为0~255,共256个离散值。PixelCNN模型会根据前i - 1个像素输出第i个像素的概率分布。

  • 训练时,和多分类任务一样,要根据第i个像素的真值和预测的概率分布求交叉熵损失函数

  • 采样时(图像生成时),会根据前i - 1个像素直接从预测的概率分布(多项分布)里采样出第i个像素。

1.1 单通道PixelCNN

1.1.1 掩码卷积

我们现在知道了PixelCNN的大体思路,就是根据前i - 1个像素输出第i个像素的概率分布。我们现在只考虑单通道图像,每个像素的颜色取值只有256种,那么很容易想到下面的实现方式:

在这里插入图片描述

但是只输出一个像素的概率分布,这样训练效率太低了。

  • 在训练时,我们可以输入一幅图像,同时让模型输出图像每一点像素的概率分布(如下图所示),这样就能通过每个像素的真值和模型预测的概率分布求交叉熵损失函数,进行并行训练。
  • 我们能这么做的原因是:在训练时,整幅训练图像是已知的,因此我们可以在一次前向传播后得到图像每一处的概率分布。
  • 当然,我们需要找到每个像素都忽略后续像素的信息的方法,即论文中提出的掩码卷积机制,我们后面再讲。

在这里插入图片描述

但是在生成图像(采样)时,还是要一个像素一个像素的生成(如下所示)

  • 在采样时,我们会先根据前i - 1个像素输出第i个像素的概率分布。
  • 然后,我们会从第i个像素的概率分布中进行采样(如下面代码所示)
# 假设颜色取值范围为[0, 7],下面为概率分布
prob_dist = torch.tensor([[0.1347, 0.1356, 0.1048, 0.1314, 0.1329, 0.1256, 0.1326, 0.1025]])

# 我们并不是取概率最大的像素,而是从概率分布中采样(例如下面取像素值6)
# torch.multinomial会从input这个概率分布中,取num_samples个值
pixel = torch.multinomial(input=prob_dist, num_samples=1).float() # tensor([[6.]])

在这里插入图片描述

我们现在已经知道了训练及采样的大体过程。但是,我们现在还是有一个疑问,如何保证训练时候,每个像素都忽略后续像素的信息?

PixelCNN论文里提出了一种掩码卷积机制,这种机制可以巧妙地掩盖住每个像素右侧和下侧的信息。

  • 具体来说,PixelCNN使用了两类掩码卷积:
    • 我们把两类掩码卷积分别称为「A类」和「B类」。
    • 二者都是对卷积操作的卷积核做了掩码处理,使得卷积核的右下部分不产生贡献。
    • A类和B类的唯一区别在于:卷积核的中心像素是否产生贡献
    • CNN的第一个的卷积层使用A类掩码卷积,之后每一层的都使用B类掩码卷积

在这里插入图片描述

我们来分析下这样设计的优点:

  • 对于一个7x7的图像,我们先用1次3x3 A类掩码卷积,再用若干次3x3 B类掩码卷积。我们观察图像中心处的像素在每次卷积后的感受野(即输入图像中哪些像素的信息能够传递到中心像素上)
    • 经过了第一个A类掩码卷积后,每个像素就已经看不到自己位置上的输入信息了。
    • 再经过两次B类掩码卷积后,中心像素能够看到左上角大部分像素的信息(如下图所示,我们发现还是会看漏少部分的信息,后面的Gated PixelCNN对此进行了改进)。
    • 这满足PixelCNN的约束。

在这里插入图片描述

  • 如果一直使用A类掩码卷积,每次卷积后中心像素都会看漏一些信息,最终就会导致看漏很多信息

在这里插入图片描述

  • 如果第一层就使用B类卷积,中心像素还是能看到自己位置的输入信息。这打破了PixelCNN的约束。

总结如下:

  • 逐像素预测只依赖于前面的像素,因此在选择卷积核时要进行掩码操作避免看到未来的值,因此,在第一层预测时可采用掩码卷积A
  • 由于CNN的逐像素预测是多层卷积,所以当第一层结束后,图像缺失部分已经有了预测值,因此在进行下一次/层卷积操作时可以利用当前像素的预测值,因此采用下列掩码卷积B
  • 需要注意的是,这里只考虑了单通道,如果扩展到RGB三个通道时,该如何进行mask呢?

1.1.2 PixelCNN的网络架构

  • 利用两类掩码卷积,PixelCNN满足了每个像素只能接受之前像素的信息这一约束。
  • 我们可以用任意一种CNN架构来实现PixelCNN。
  • 下图红色框所示部分是PixelCNN的网络结构,其中,第一个7x7卷积层用了A类掩码卷积,之后所有3x3卷积都是B类掩码卷积。

在这里插入图片描述

1.1.3 PixelCNN在MNIST数据集上的应用

1.1.3.1 模型

实现PixelCNN,最重要的是实现掩码卷积。

  • 掩码卷积的实现思路就是在卷积核组上设置一个mask。在前向传播的时候,先让卷积核组乘mask,再做普通的卷积。
  • 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码。
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import os


class MaskConv2d(nn.Module):
    """
        掩码卷积的实现思路:
            在卷积核组上设置一个mask,在前向传播的时候,先让卷积核组乘mask,再做普通的卷积
    """
    def __init__(self, conv_type, *args, **kwags):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        # 由于输入输出都是单通道图像,我们只需要在卷积核的h, w两个维度设置掩码
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2] = 1
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        # 为了保证掩码能正确广播到4维的卷积核组上,我们做一个reshape操作
        mask = mask.reshape((1, 1, H, W))
        # register_buffer可以把一个变量加入成员变量的同时,记录到PyTorch的Module中
        # 每当执行model.to(device)把模型中所有参数转到某个设备上时,被注册的变量会跟着转。
        # 第三个参数表示被注册的变量是否要加入state_dict中以保存下来
        self.register_buffer(name='mask', tensor=mask, persistent=False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res

有了最核心的掩码卷积,我们来根据论文中的模型结构图把模型搭起来

在这里插入图片描述

  • 我们先实现残差块上图右部分的ResidualBlock,这里添加归一化
class ResidualBlock(nn.Module):
    """
        残差块ResidualBlock
    """
    def __init__(self, h, bn=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(2 * h, h, 1)
        self.bn1 = nn.BatchNorm2d(h) if bn else nn.Identity()
        self.conv2 = MaskConv2d('B', h, h, 3, 1, 1)
        self.bn2 = nn.BatchNorm2d(h) if bn else nn.Identity()
        self.conv3 = nn.Conv2d(h, 2 * h, 1)
        self.bn3 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()

    def forward(self, x):
        # 1、ReLU + 1×1 Conv + bn
        y = self.relu(x)
        y = self.conv1(y)
        y = self.bn1(y)
        # 2、ReLU + 3×3 Conv(mask B) + bn
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)
        # 3、ReLU + 1×1 Conv + bn
        y = self.relu(y)
        y = self.conv3(y)
        y = self.bn3(y)
        # 4、残差连接
        y = y + x
        return y
  • 有了所有这些基础模块后,我们就可以拼出最终的PixelCNN了。
  • 注意,我们可以自己决定颜色有几个亮度级别。要修改亮度级别的数量,只需要修改softmax输出的通道数color_level。
class PixelCNN(nn.Module):
    def __init__(self, n_blocks, h, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.conv1 = MaskConv2d('A', 1, 2 * h, 7, 1, 3)
        self.bn1 = nn.BatchNorm2d(2 * h) if bn else nn.Identity()
        self.residual_blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.residual_blocks.append(ResidualBlock(h, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(2 * h, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        # 1、7 × 7 conv(mask A)
        x = self.conv1(x)
        x = self.bn1(x)
        # 2、Multiple residual blocks
        for block in self.residual_blocks:
            x = block(x)
        x = self.relu(x)
        # 3、1 × 1 conv
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x
1.1.3.2 数据集及训练

准备好了模型代码,我们可以编写训练脚本了:

  • PixelCNN有15个残差块,中间特征的通道数为128,输出前线性层的通道数为32
def get_dataloader(batch_size: int):
    dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
                                         train=True,
                                         transform=ToTensor()
                                         )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
    """训练过程"""
    dataloader = get_dataloader(batch_size)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    loss_fn = nn.CrossEntropyLoss()

    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            # 把训练集的浮点颜色值转换成[0, color_level-1]之间的整型标签
            y = torch.ceil(x * (color_level - 1)).long()
            y = y.squeeze(1)
            predict_y = model(x)
            loss = loss_fn(predict_y, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), model_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')

if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # 需要注意的是:MNIST数据集的大部分像素都是0和255
    color_level = 8  # or 256
    # 1、创建PixelCNN模型
    model = PixelCNN(n_blocks=15, h=128, linear_dim=32, bn=True, color_level=color_level)
    # 2、模型训练
    model_path = f'work_dirs/model_pixelcnn_{color_level}.pth'
    train(model, device, model_path)
    # 3、采样
    sample(model, device, model_path, f'work_dirs/pixelcnn_{color_level}.jpg')        
1.1.3.3 采样
  • 在采样时,我们把x初始化成一个0张量。
  • 之后,循环遍历每一个像素,输入x,把预测出的下一个像素填入x.
def sample(model, device, model_path, output_path, n_sample=1):
    """
        把x初始化成一个0张量。
        循环遍历每一个像素,输入x,把预测出的下一个像素填入x
    """
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    C, H, W = get_img_shape()  # (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                # 我们先获取模型的输出,再用softmax转换成概率分布
                output = model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 再用torch.multinomial从概率分布里采样出【1】个[0, color_level-1]的离散颜色值
                # 再除以(color_level - 1)把离散颜色转换成浮点[0, 1]
                pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
                # 最后把新像素填入到生成图像中
                x[:, :, i, j] = pixel
    # 乘255变成一个用8位字节表示的图像
    imgs = x * 255
    imgs = imgs.clamp(0, 255)
    imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(output_path, imgs)

1.2 多通道PixelCNN

如下图所示,作者假设RGB三个通道之间存在相互影响

  • 其中红色预测不受蓝色和绿色通道的影响,只受上下文影响
  • 绿色红色通道和上下文影响,但不受蓝色通道影响;
  • 蓝色通道受上下文、红色通道、绿色通道影响

在这里插入图片描述

更具体地,我们规定一个子像素只由它之前的子像素决定,生成图像时,我们一个子像素一个子像素地生成

  • 如下图所示,对于RGB图像,R子像素由它之前所有像素决定
  • G子像素由它的R子像素和之前所有像素决定,
  • B子像素由它的R、G子像素和它之前所有像素决定。

在这里插入图片描述

如下图所示,由于现在要预测三个颜色通道,网络的输出应该是一个[256x3, H, W]形状的张量

  • 即每个像素输出三个概率分布,分别表示R、G、B取某种颜色的概率。
  • 同时,本质上来讲,网络是在并行地为每个像素计算3组结果。因此,为了达到同样的性能,网络所有的特征图的通道数也要乘3。

在这里插入图片描述

图像变为多通道后,A类卷积和B类卷积的定义也需要做出一些调整。我们不仅要考虑像素在空间上的约束,还要考虑一个像素内子像素间的约束。为此,我们要用不同的策略实现约束。为了方便描述,我们设卷积核组的形状为[o, i, h, w],其中o为输出通道数,i为输入通道数,h, w为卷积核的高和宽。

  • 对于通道间的约束,我们要在o, i两个维度上设置掩码,如下图左边所示。
    • 设输出通道可以被拆成三组o1, o2, o3,输入通道可以被拆成三组i1, i2, i3
      • o1 = 0:o/3, o2 = o/3:o*2/3, o3 = o*2/3:o
      • i1 = 0:i/3, i2 = i/3:i*2/3, i3 = i*2/3:i
      • 序号1, 2, 3分别表示这组通道是在维护R, G, B的计算。
    • 我们对输入通道组和输出通道组之间进行约束。
    • 对于A类卷积,我们令o1看不到i1, i2, i3o2看不到i2, i3o3看不到i3
    • 对于B类卷积,我们取消每个通道看不到自己的限制,即在A类卷积的基础上令o1看到i1o2看到i2o3看到i3
  • 如下图右边所示,对于空间上的约束,我们还是和之前一样,在h, w两个维度上设置掩码。由于「是否看到自己」的处理已经在o, i两个维度里做好了,我们直接在空间上用原来的B类卷积就行。

在这里插入图片描述

  • 下面给出三维掩码示意图方便理解:

在这里插入图片描述

2 Gated PixelCNN

2.1 Gated PixelCNN简述

  • 可以参考大神讲解:Gated PixelCNN (sergeiturukin.com)

  • PixelCNN的掩码卷积其实有一个重大漏洞:像素存在视野盲区。如下图所示,中心像素看不到右上角三个本应该能看到的像素。

在这里插入图片描述

  • 为此,PixelCNN论文的作者又发表了Conditional Image Generation with PixelCNN Decoders(16.06)。这篇论文提出了一种叫做Gated PixelCNN的改进架构。Gated PixelCNN使用了一种更好的掩码卷积机制,消除了原PixelCNN里的视野盲区。

在这里插入图片描述

  • 如下图所示,Gated PixelCNN使用了两种卷积,即垂直卷积和水平卷积,来分别维护一个像素上侧的信息和左侧的信息
    • 垂直卷积的结果只是一些临时量
    • 而水平卷积的结果最终会被网络输出
    • 使用这种新的掩码卷积机制后,每个像素能正确地收到之前所有像素的信息了。

在这里插入图片描述

  • Gated PixelCNN用下图的模块代替了原PixelCNN的普通残差模块。
  • 模块的输入输出都是两个量,左边的量是垂直卷积中间结果,右边的量是最后用来计算输出的量。
  • 垂直卷积的结果会经过偏移和一个1x1卷积,再加到水平卷积的结果上。
  • 两条计算路线在输出前都会经过门激活单元。所谓门激活单元,就是输入两个形状相同的量,一个做tanh,一个做sigmoid,两个结果相乘再输出。
  • 此外,模块右侧还有一个残差连接。

在这里插入图片描述

2.2 Gated PixelCNN在MNIST数据集上的应用

2.2.1 创建模型

  • 首先,实现垂直卷积和水平卷积
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import ToTensor
import time
import einops
import cv2
import numpy as np
import os


class VerticalMaskConv2d(nn.Module):
    """
        垂直卷积
    """
    def __init__(self, *args, **kwags):
        super().__init__()
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[0:H // 2 + 1] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res


class HorizontalMaskConv2d(nn.Module):
    """
        水平卷积
    """
    def __init__(self, conv_type, *args, **kwags):
        super().__init__()
        assert conv_type in ('A', 'B')
        self.conv = nn.Conv2d(*args, **kwags)
        H, W = self.conv.weight.shape[-2:]
        mask = torch.zeros((H, W), dtype=torch.float32)
        mask[H // 2, 0:W // 2] = 1
        if conv_type == 'B':
            mask[H // 2, W // 2] = 1
        mask = mask.reshape((1, 1, H, W))
        self.register_buffer('mask', mask, False)

    def forward(self, x):
        self.conv.weight.data *= self.mask
        conv_res = self.conv(x)
        return conv_res
# 垂直卷积
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [0., 0., 0.]]]])
# A类水平卷积
tensor([[[[0., 0., 0.],
          [1., 0., 0.],
          [0., 0., 0.]]]])
# B类水平卷积
tensor([[[[0., 0., 0.],
          [1., 1., 0.],
          [0., 0., 0.]]]])
  • 我们现在搭建Gated Block模块,这也是最难理解的一部分。
  • 可以参考的解释:https://segmentfault.com/a/1190000041189859?utm_source=sf-similar-article

在这里插入图片描述

  • # 这里比较难理解,通过对图像进行零填充并裁剪图像底部,可以确保垂直和水平堆栈之间的因果关系
    v_to_h = v[:, :, 0:-1]
    v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
    # 注意到,v和i相加的位置只差了一个单位。
    # 为了把相加的位置对齐,我们要把v往下移一个单位,把原来在i-1处的信息移到i上。
    # 这样,移动过后的v_to_h就能和h直接用向量加法并行地加到一起了。
    

在这里插入图片描述

  • 维护两个v, h两个变量,分别表示垂直卷积部分的结果和水平卷积部分的结果。
    • v会经过一个垂直掩码卷积和一个门激活函数。
    • h会经过一个类似于残差块的结构,只不过第一个卷积是水平掩码卷积、激活函数是门激活函数、进入激活函数之前会和垂直卷积的信息融合。
class GatedBlock(nn.Module):

    def __init__(self, conv_type, in_channels, p, bn=True):
        super().__init__()
        self.conv_type = conv_type
        self.p = p
        self.v_conv = VerticalMaskConv2d(in_channels, 2 * p, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.v_to_h_conv = nn.Conv2d(2 * p, 2 * p, kernel_size=1)
        self.bn2 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_conv = HorizontalMaskConv2d(conv_type, in_channels, 2 * p, 3, 1,
                                           1)
        self.bn3 = nn.BatchNorm2d(2 * p) if bn else nn.Identity()
        self.h_output_conv = nn.Conv2d(p, p, 1)
        self.bn4 = nn.BatchNorm2d(p) if bn else nn.Identity()

    def forward(self, v_input, h_input):
        # v代表垂直卷积部分的结果
        v = self.v_conv(v_input)
        v = self.bn1(v)
        # Note: 重点代码
        # 为了把v的信息贴到h上,我们并不是像前面的示意图所写的令v上移一个单位
        # 而是用下面的代码令v下移了一个单位(下移即去掉最下面一行,往最上面一行填0)
        v_to_h = v[:, :, 0:-1]
        v_to_h = F.pad(v_to_h, (0, 0, 1, 0))
        # 和h相加前,先经过 1×1 conv
        v_to_h = self.v_to_h_conv(v_to_h)
        v_to_h = self.bn2(v_to_h)
        # 分为两份,经过tanh 和 sigmoid
        v1, v2 = v[:, :self.p], v[:, self.p:]
        v1 = torch.tanh(v1)
        v2 = torch.sigmoid(v2)
        v = v1 * v2

        # h代表水平卷积部分的结果
        h = self.h_conv(h_input)
        h = self.bn3(h)
        h = h + v_to_h
        # 分为两份,经过tanh 和 sigmoid
        h1, h2 = h[:, :self.p], h[:, self.p:]
        h1 = torch.tanh(h1)
        h2 = torch.sigmoid(h2)
        h = h1 * h2
        h = self.h_output_conv(h)
        h = self.bn4(h)
        # 在网络的第一层,每个数据是不能看到自己的。
        # 所以,当GatedBlock发现卷积类型为A类时,不应该对h做残差连接。
        if self.conv_type == 'B':
            h = h + h_input
        return v, h
  • 最后,我们来用GatedBlock搭出Gated PixelCNN
  • Gated PixelCNN和PixelCNN的结构非常相似,只是把ResidualBlock替换成了GatedBlock而已。
class GatedPixelCNN(nn.Module):

    def __init__(self, n_blocks, p, linear_dim, bn=True, color_level=256):
        super().__init__()
        self.block1 = GatedBlock('A', 1, p, bn)
        self.blocks = nn.ModuleList()
        for _ in range(n_blocks):
            self.blocks.append(GatedBlock('B', p, p, bn))
        self.relu = nn.ReLU()
        self.linear1 = nn.Conv2d(p, linear_dim, 1)
        self.linear2 = nn.Conv2d(linear_dim, linear_dim, 1)
        self.out = nn.Conv2d(linear_dim, color_level, 1)

    def forward(self, x):
        v, h = self.block1(x, x)
        for block in self.blocks:
            v, h = block(v, h)
        x = self.relu(h)
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.out(x)
        return x

2.2.2 数据集、训练及采样

  • 数据集、训练及采样和PixelCNN一模一样,不再赘述。
def get_dataloader(batch_size: int):
    dataset = torchvision.datasets.MNIST(root='/root/autodl-fs/data/minist',
                                         train=True,
                                         transform=ToTensor()
                                         )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train(model, device, model_path, batch_size=128, color_level=8, n_epochs=40):
    """训练过程"""
    dataloader = get_dataloader(batch_size)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    loss_fn = nn.CrossEntropyLoss()

    tic = time.time()
    for e in range(n_epochs):
        total_loss = 0
        for x, _ in dataloader:
            current_batch_size = x.shape[0]
            x = x.to(device)
            # 把训练集的浮点颜色值转换成0~color_level-1之间的整型标签的
            y = torch.ceil(x * (color_level - 1)).long()
            y = y.squeeze(1)
            predict_y = model(x)
            loss = loss_fn(predict_y, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item() * current_batch_size
        total_loss /= len(dataloader.dataset)
        toc = time.time()
        torch.save(model.state_dict(), model_path)
        print(f'epoch {e} loss: {total_loss} elapsed {(toc - tic):.2f}s')


def get_img_shape():
    return (1, 28, 28)


def sample(model, device, model_path, output_path, n_sample=1):
    """
        把x初始化成一个0张量。
        循环遍历每一个像素,输入x,把预测出的下一个像素填入x
    """
    model.eval()
    model.load_state_dict(torch.load(model_path))
    model = model.to(device)
    C, H, W = get_img_shape()  # (1, 28, 28)
    x = torch.zeros((n_sample, C, H, W)).to(device)
    with torch.no_grad():
        for i in range(H):
            for j in range(W):
                # 我们先获取模型的输出,再用softmax转换成概率分布
                output = model(x)
                prob_dist = F.softmax(output[:, :, i, j], -1)
                # 再用torch.multinomial从概率分布里采样出【1个】0~(color_level-1)的离散颜色值
                # 再除以(color_level - 1)把离散颜色转换成浮点颜色(因为网络是输入是浮点颜色)
                pixel = torch.multinomial(input=prob_dist, num_samples=1).float() / (color_level - 1)
                # 最后把新像素填入生成图像
                x[:, :, i, j] = pixel

    imgs = x * 255
    imgs = imgs.clamp(0, 255)
    imgs = einops.rearrange(imgs, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=int(n_sample**0.5))

    imgs = imgs.detach().cpu().numpy().astype(np.uint8)
    cv2.imwrite(output_path, imgs)


if __name__ == '__main__':
    os.makedirs('work_dirs', exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    color_level = 8  # or 256
    # 1、创建GatedPixelCNN模型
    model = GatedPixelCNN(n_blocks=15, p=128, linear_dim=32, bn=True, color_level=color_level)
    # 2、模型训练
    model_path = f'work_dirs/model_gatedpixelcnn_{color_level}.pth'
    train(model, device, model_path, batch_size=1)
    # 3、采样
    sample(model, device, model_path, f'work_dirs/gatedpixelcnn_{color_level}.jpg')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1798705.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

鸿蒙开发的南向开发和北向开发

鸿蒙开发主要分为设备开发和应用开发两个方向,也叫南向开发和北向开发: 鸿蒙设备开发(南向开发),要侧重于硬件层面的开发,涉及硬件接口控制、设备驱动开发、鸿蒙系统内核开发等,目的是使硬件设备能够兼容并…

Linux环境---在线安装MYSQL数据库

Linux环境—在线安装MYSQL数据库 一、使用步骤 1.安装环境 Mysql 驱动 8.0 需要 jdk1.8 才行。 JDK版本:1.8 参考文档 MYSQL版本:8.0.2 下载链接: https://pan.baidu.com/s/1MwXIilSL6EY3OuS7WtpySA?pwdg263 操作系统:CentOS 1.1 建立存…

Python数据分析II

目录 1.HS-排序返回前n行 2.HS-相关性 3.缺失值处理 4.时间 5.时间索引 6.分组聚合 7.离散分箱 8.Concat关联(索引关联) 9.Merge关联(字段关联) 10.join合并(左字段,右索引) 11.行列转置及透视表 12.数据可视化-面向过程 13.数据可视化-面向对象 14.快速生成柱状…

github有趣项目:Verilog在线仿真( DigitalJS+edaplayground)

DigitalJS https://github.com/tilk/digitaljs这个项目是一个用Javascript实现的数字电路模拟器。 它旨在模拟由硬件设计工具合成的电路 像 Yosys(这里是 Github 存储库),它有一个配套项目 yosys2digitaljs,它可以转换 Yosys 将文…

STCunio数字电源带PID数字闭环(带详细的代码说明文档)

STCunio,即 system on chip unusual i/o,采用类似 arduino 构架设计,即使没有单片机知 识的设计师和艺术家们能够很快地通过它学习电子和传感器的基础知识,并应用到他们的设 计当中。设计中所要表现的想法和创意才是最主要的,至于…

创新指南 | 5个行之有效的初创企业增长策略

本文探讨了五种初创企业实现快速增长的有效策略:利用网络效应通过激励和资本化用户增长;通过持续提供高质量内容建立信任和权威的内容营销;利用简单有效的推荐计划扩展用户群;采用敏捷开发方法快速适应市场变化和客户反馈&#xf…

基于springboot实现社区养老服务系统项目【项目源码+论文说明】计算机毕业设计

基于springboot实现社区养老服务系统演示 摘要 现代经济快节奏发展以及不断完善升级的信息化技术,让传统数据信息的管理升级为软件存储,归纳,集中处理数据信息的管理方式。本社区养老服务系统就是在这样的大环境下诞生,其可以帮助…

签名安全规范:解决【请求对象json序列化时,时间字段被强制转换成时间戳的问题】

文章目录 引言I 签名安全规范1.1 签名生成的通用步骤1.2 签名运算(加密规则)1.3 对所有传入参数按照字段名的 ASCII 码从小到大排序(字典序)1.4 允许的请求头字段1.5 签名校验工具II 注解校验签名2.1 获取请求数据,并校验签名数据2.2 解决时间格式被强制转换成时间戳的问题…

2024年数据防泄密软件精选,五款热门防泄密软件集锦

在信息爆炸的今天,企业数据的安全性已成为不可忽视的关键问题。 随着数字化转型的加速,数据泄露的风险也随之增加,这对企业的核心竞争力构成了严重威胁。 为了构建坚不可摧的数据防线,选择高效可靠的数据防泄密软件显得尤为重要…

爬取基金收盘价并用pyecharts进行展现

爬取基金收盘价并用pyecharts进行展现 一、用到的第三方包 因为使用到了一些第三方的包,包还是比较大的如果直接从社区下载比较费劲,所以建议配置国内镜像源,这里以清华的镜像源为例。 pip config set global.index-url https://pypi.tuna…

FastAdmin自定义滚动条

效果 实现过程 HTML代码 <style>.custom-scrollbar {position: fixed;/*bottom: 0px;*/height: 20px;width: 97.5%;overflow-y: scroll;overflow-x: scroll;z-index: 100;}#scrollDivTable{height: 20px;}/*原滚动条不显示*//*.fixed-table-body::-webkit-scrollbar {*/…

电脑知识 如何看懂串口通信协议(程序员视角)

目录 前言 一、串口文档 二、明确身份 三、串口设置 四、看懂命令格式 五、看懂发送命令的格式 1.帧头和帧尾 2.帧内数据长度 3.帧内数据/具体命令 4.整体命令 5.真正的命令字和命令值 六、第一个案例 1.发送命令 2.雷达的回答 七、作者的话 前言 用一个案例&#…

C++基础与深度解析 | 类与面向对象编程 | 数据成员 | 成员函数 | 访问限定符与友元 | 构造、析构成员函数 | 字面值类、成员指针与bind交互

文章目录 一、结构体与对象聚合二、成员函数&#xff08;方法&#xff09;三、访问限定符与友元1.访问限定符2.友元&#xff08;慎用&#xff09; 四、构造、析构与复制成员函数1.构造函数2.析构函数3.补充 五、字面值类&#xff0c;成员指针与bind交互1.字面值类2.成员指针3.b…

无线麦克风哪个牌子性价比高?一文告诉你无线领夹麦克风怎么挑选

​当我们谈论到演讲、表演或者录制视频时&#xff0c;一个高质量的无线麦克风能够使得整个体验提升至一个全新的水平。它不仅能够保证声音的清晰度和真实度&#xff0c;还能够让使用者在演讲或者表演时更加自信和舒适。基于对市场的深入研究和用户体验的考量&#xff0c;我挑选…

【css3】png图片实现动态动画

.border_style {width: 400px;height: 400px;background-color: black;margin: auto;}keyframes sprite-animation {0% {background-position: 0 0;}100% {background-position: 0 -2064px;/* 假设每个图像的宽度为100px */}}.wrj_box {width: 86px;height: 86px;background-im…

S3Dlib | 太炫酷!所有3D图形它都可以绘制...

前言 一、「s3dlib」-Python中王炸3D绘图神器 二、可视化学习圈子是干什么的&#xff1f; 三、系统学习可视化 四、猜你喜欢 前言 我们的数据可视化课程已经上线啦&#xff01;&#xff01;目前课程的主要方向是 科研、统计、地理相关的学术性图形绘制方法&#xff0c;后续…

Patchwork++:基于点云的快速、稳健的地面分割方法

1. 背景 论文发表在2022IROS&#xff0c;是Patchwork的改进版本。算法通过数学方法进行快速而鲁棒性很强的地面分割&#xff0c;在智能机器人上的可操作性非常强。通过微调算法&#xff0c;可以应用于16-beams等多种规格的激光雷达。由于激光雷达点云数据标注的难度非常大&…

数据泄露防护(DLP)系统有哪些?2024年数据泄露防护系统TOP5排名

数据泄露防护&#xff08;DLP&#xff09;系统是企业为确保敏感信息不被非法访问、使用或泄露而采用的重要安全策略。以下是一些常见的数据泄露防护系统&#xff0c;以及它们的功能和优点。 1、安企神 DLP 安企神 DLP是一款为企业研发的数据防泄漏系统&#xff0c;以强大的功能…

pxe自动装机

概念 pxe是c/s模式。允许客户端通过网络从远程服务器&#xff08;服务端&#xff09;下载引导镜像&#xff0c;加载安装文件&#xff0c;实现自动化安装操作系统。 无人值守&#xff1a;安装选项不需要人为干预&#xff0c;可以自动化实现。 pxe的优点&#xff1a;1.规模化&…

美琳莱卡:创新消费模式引领新零售时代

公司成立时间与定位 美琳莱卡自创立之初,便以独特的视角和前瞻性的战略定位,立足于消费市场的变革前沿。公司成立于2024年,正值全球数字化浪潮蓬勃兴起,消费升级趋势日益明显之际。美琳莱卡敏锐地捕捉到这一时代机遇,将自身定位为创新消费模式的引领者,致力于通过线上线下高度…