AIGC实战——像素卷积神经网络(PixelCNN)

news2024/10/5 13:51:36

AIGC实战——像素卷积神经网络

    • 0. 前言
    • 1. PixelCNN 工作原理
      • 1.1 掩码卷积层
    • 1.2 残差块
    • 2. 训练 PixelCNN
    • 3. PixelCNN 分析
    • 4. 使用混合分布改进 PixelCNN
    • 小结
    • 系列链接

0. 前言

像素卷积神经网络 (Pixel Convolutional Neural Network, PixelCNN) 是于 2016 年提出的一种图像生成模型,其根据前面的像素预测下一个像素的概率来逐像素地生成图像,模型可以通过自回归的方式进行训练以生成图像。在本节中,将使用 Keras 实现 PixelCNN 模型并将其应用于图像数据生成中。

1. PixelCNN 工作原理

为了理解 PixelCNN,我们需要介绍两个关键技术:掩码卷积层 (Masked Convolutional Layer) 和残差块 (Residual Block)。

1.1 掩码卷积层

我们已经知道,卷积层可以通过应用一系列卷积核从图像中提取特征。在特定像素点处,卷积层的输出是卷积核权重与以该像素为中心的区域上的值的加权和。通过应用一系列卷积层可以检测到图像中的边缘、纹理以及在更深层的形状和高级特征。
虽然卷积层在特征检测中十分有效,但无法直接用于自回归模型,因为像素之间没有明确的顺序关系。在图像中所有像素均会被平等对待,没有像素会被视为图像的起始或结束点,这与文本数据不同,文本数据中的符号具有明确的顺序性,因此可以方便地应用循环模型,如长短期记忆网络 (Long Short-Term Memory Network, LSTM)。
为了能够以自回归的方式下将卷积层应用于图像生成,我们首先必须将像素进行排序,并确保卷积核只能看到前面的像素。然后,通过将由 10 组成的掩码与卷积核权重矩阵相乘,使得在每个像素处,层的输出仅受到前面像素值的影响,从而逐像素地生成图像,通过将卷积卷积核应用于当前图像来预测下一个像素的值。
首先,需要选择像素的排序方式,一种可行的方法是从左上到右下对像素进行排序,首先沿行移动,然后沿列移动。
然后,我们对卷积核进行掩码处理,以使得每个像素处的层的输出仅受到前面的像素值的影响。为此,我们将由 10 组成的掩码与卷积核权重矩阵相乘,将目标像素后面的其余像素的值置零。
PixelCNN 中实际上有两种不同类型的掩码:

  • A 型,中心像素的值为掩码像素
  • B 型,中心像素的值不为掩码像素

掩码卷积层

初始的掩码卷积层(即直接应用于输入图像的层)不能使用中心像素,因为这恰是我们希望网络预测的像素,而后续的层可以使用中心像素,因为它已经由初始输入图像之前的像素信息计算出来。
使用 Keras 构建掩码卷积层 (MaskedConvLayer):

class MaskedConv2D(layers.Layer):
    def __init__(self, mask_type, **kwargs):
        super(MaskedConv2D, self).__init__()
        self.mask_type = mask_type
        # 掩码卷积层 (MaskedConvLayer) 基于普通的 Conv2D 层
        self.conv = layers.Conv2D(**kwargs)

    def build(self, input_shape):
        # 初始化卷积核
        self.conv.build(input_shape)
        # 创建掩码
        kernel_shape = self.conv.kernel.get_shape()
        # 掩码初始化为全零向量
        self.mask = np.zeros(shape=kernel_shape)
        # 前面行中的像素使用 1 解除掩码
        self.mask[: kernel_shape[0] // 2, ...] = 1.0
        # 同一行中前面列中的像素使用 1 解除掩码
        self.mask[kernel_shape[0] // 2, : kernel_shape[1] // 2, ...] = 1.0
        # 如果掩码类型为 B,则中心像素使用 1 解除掩码
        if self.mask_type == "B":
            self.mask[kernel_shape[0] // 2, kernel_shape[1] // 2, ...] = 1.0

    def call(self, inputs):
        # 掩码与卷积核权重相乘
        self.conv.kernel.assign(self.conv.kernel * self.mask)
        return self.conv(inputs)

需要注意的是,我们假设使用灰度图像(即只有一个通道)。如果我们使用彩色图像,则可以对三个颜色通道进行排序,例如红色通道在蓝色通道之前,蓝色通道在绿色通道之前。

1.2 残差块

我们已经学习了如何对卷积层进行掩码处理,接下来开始构建 PixelCNN,我们将使用残差块 (Residual Block) 作为核心构建块。

残差块是一组网络层,包含两个主要部分:

  • 主路径 (Main Path):由一系列卷积层和激活函数构成,用于学习特征表示
  • 跳跃连接 (Skip Connection):直接将输入信息绕过一部分主路径,与输出相加。这样可以确保输入信息更容易传播到后续层,并且有助于避免梯度消失问题

也就是说,在残差块中,输入有一条捷径连接到输出,而无需经过中间层。跳跃连接的理论基础可以描述为,如果最优的变换是保持输入不变,那么通过简单地将中间层的权重置零就可以实现;如果没有跳跃连接,网络就必须通过中间层找到一个恒等映射,这显然更加困难。PixelCNN 中的残差块的结构如下图所示。

残差块

使用 Keras 构建残差块 (ResidualBlock):

class ResidualBlock(layers.Layer):
    def __init__(self, filters, **kwargs):
        super(ResidualBlock, self).__init__(**kwargs)
        # 初始的 Conv2D 层将通道数量减半
        self.conv1 = layers.Conv2D(filters=filters // 2, kernel_size=1, activation="relu")
        # 类型 B 的 MaskedConv2D 层,使用尺寸为 3 的卷积核,仅利用五个像素的信息,即上方行的三个像素、左边一个像素以及中心像素本身
        self.pixel_conv = MaskedConv2D(
            mask_type="B",
            filters=filters // 2,
            kernel_size=3,
            activation="relu",
            padding="same",
        )
        # 最后的 Conv2D 层将通道数量加倍,以再次匹配输入形状
        self.conv2 = layers.Conv2D(filters=filters, kernel_size=1, activation="relu")

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pixel_conv(x)
        x = self.conv2(x)
        # 将卷积层的输出与输入相加,即跳跃连接
        return layers.add([inputs, x])

2. 训练 PixelCNN

PixelCNN 中,输出层是一个具有 256 个卷积核的 Conv2D 层,使用 softmax 激活函数。换句话说,网络通过预测正确的像素值来尝试重新创建其输入,类似于编码器。不同之处在于,网络采用了 MaskedConv2D 层,像素预测使用的像素信息并不相同。
使用这种方法,PixelCNN 必须独立学习每个像素的输出值,但像素值 220221 的差异并不明显,这意味着即使对于最简单的数据集,训练速度也可能非常慢。因此,我们需要简化输入,使得每个像素值只有四种输出类型。这样,我们可以使用一个具有 4 个卷积核的 Conv2D 输出层,而不是 256 个卷积核。

# 模型的输入是一个尺寸为 16×16×1 的灰度图像,输入值缩放为 0 到1之间
inputs = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))
# 首先使用 A 类型的 MaskedConv2D 层,卷积核大小为 7,使用了 24 个像素的信息——中点像素上方三行中的 21 个像素以及左边3个像素(中心像素本身没有使用)
x = MaskedConv2D(
    mask_type="A",
    filters=N_FILTERS,
    kernel_size=7,
    activation="relu",
    padding="same",
)(inputs)
# 连续堆叠五个 ResidualBlock 块
for _ in range(RESIDUAL_BLOCKS):
    x = ResidualBlock(filters=N_FILTERS)(x)
# 使用两个 B 类型的 MaskedConv2D 层,卷积核大小为 1
for _ in range(2):
    x = MaskedConv2D(
        mask_type="B",
        filters=N_FILTERS,
        kernel_size=1,
        strides=1,
        activation="relu",
        padding="valid",
    )(x)
# 最后的 Conv2D 层将通道数量减少为 4,即像素级别数
out = layers.Conv2D(
    filters=PIXEL_LEVELS,
    kernel_size=1,
    strides=1,
    activation="softmax",
    padding="valid",
)(x)
# 构建模型,其接收一张图像并输出相同尺寸的图像
pixel_cnn = models.Model(inputs, out)
print(pixel_cnn.summary())

adam = optimizers.Adam(learning_rate=0.0005)
pixel_cnn.compile(optimizer=adam, loss="sparse_categorical_crossentropy")
# 训练模型时,输入数据被缩放到 [0, 1] 的范围内(浮点数);输出数据为 [0, 3] 的范围内的整数
pixel_cnn.fit(input_data, output_data, batch_size=BATCH_SIZE, epochs=EPOCHS)

3. PixelCNN 分析

使用 Fashion-MNIST 数据集来训练 PixelCNN 模型,为了生成新图像,我们需要让模型根据前面的所有像素逐个像素地预测下一个像素。与变分自编码器等模型相比,生成过程非常缓慢。对于一个尺寸 32×32 的灰度图像,我们需要使用模型进行 1024 次连续预测,而对于变分自编码器 (Variational Autoencoder, VAE),我们只需要进行一次预测。这是自回归模型(如 PixelCNN )的一个主要缺点,由于采样过程的顺序性,进行采样的速度较慢。
因此,为了加快新图像的生成速度,我们将图像尺寸缩放为 16×16

class ImageGenerator(callbacks.Callback):
    def __init__(self, num_img):
        self.num_img = num_img

    def sample_from(self, probs, temperature):  # <2>
        probs = probs ** (1 / temperature)
        probs = probs / np.sum(probs)
        return np.random.choice(len(probs), p=probs)

    def generate(self, temperature):
        generated_images = np.zeros(shape=(self.num_img,) + (pixel_cnn.input_shape)[1:])
        # 从空白图像(全零)开始
        batch, rows, cols, channels = generated_images.shape

        for row in range(rows):
            for col in range(cols):
                for channel in range(channels):
                    # 遍历当前图像的行、列和通道,预测下一个像素值的分布
                    probs = self.model.predict(generated_images, verbose=0)[:, row, col, :]
                    # 从预测的分布中抽样一个像素级别(在本例中像素级别范围为 [0, 3 ])
                    generated_images[:, row, col, channel] = [self.sample_from(x, temperature) for x in probs]
                    # 将像素级别转换到[0, 1]范围内,并覆盖当前图像中的像素值,准备进行下一次循环迭代
                    generated_images[:, row, col, channel] /= PIXEL_LEVELS

        return generated_images

    def on_epoch_end(self, epoch, logs=None):
        generated_images = self.generate(temperature=1.0)
        for i in range(10):
            plt.subplot(2,5,i+1)
            plt.imshow(generated_images[i], cmap='gray')
        plt.tight_layout()
        plt.savefig("./output/generated_img_{}.png".format(epoch))

img_generator_callback = ImageGenerator(num_img=10)

在下图中,我们对比了来自原始训练集和由 PixelCN 生成的几张图像。可以看到,PixelCNN 模型能够成功的学习到原始图像的整体形状和风格。因此,我们可以将图像视为一系列符号(像素值),并应用如 PixelCNN 之类的自回归模型生成逼真的样本。

生成结果

自回归模型的一个缺点是在进行采样时速度较慢,这就是为什么本节我们对输入和输出进行缩放。然而,但我们也可以使用更复杂形式的自回归模型以生成逼真图像,在这种情况下,生成速度缓慢是为了获得卓越质量输出所必须进行的取舍。
接下来,使用混合分布对 PixelCNN 的架构和训练过程进行改进,并使用内置的 TensorFlow 函数来训练改进的 PixelCNN 模型。

4. 使用混合分布改进 PixelCNN

在上一节构建的 PixelCNN 模型中,为了避免训练速度过慢的问题,令网络不必学习 256 个独立像素值上的分布,我们将 PixelCNN 的输出减少到 4 个像素级别。但是,这种方法并非最佳解决方案,对于彩色图像,我们不希望仅使用少数几种颜色。
为了解决这一问题,我们可以将网络的输出设为混合分布 (Mixture Distribution),而不是对 256 个离散像素值使用 softmax,混合分布简单来说就是两个或多个其他概率分布的混合。例如,我们可以使用一个由五个逻辑分布组成的混合分布,每个分布都有不同的参数。混合分布还需要离散分类分布,用于指示选择混合中包含的每个分布的概率。
要从混合分布中进行采样,我们首先从分类分布中进行采样,选择一个特定的子分布,然后按照正常的方式从该子分布中进行采样。这样,我们可以用相对较少的参数创建复杂的分布。例如,上图中的混合分布仅需要八个参数——两个用于分类分布,其余六个参数用于每个正态分布都有均值和方差,这与在整个像素范围内定义分类分布所需的 255 个参数相比要少得多。
TensorFlow Probability 库提供了一个能够创建具有混合分布输出的 PixelCNN 函数,使用此函数构建 PixelCNN

# 定义 PixelCNN 模型
dist = tfp.distributions.PixelCNN(
    image_shape=(IMAGE_SIZE, IMAGE_SIZE, 1),
    num_resnet=1,
    num_hierarchies=2,
    num_filters=32,
    num_logistic_mix=N_COMPONENTS,
    dropout_p=0.3,
)
# 输入是尺寸为 32×32×1 的灰度图像
image_input = layers.Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 1))
# 模型以灰度图像作为输入,并输出在 PixelCNN 计算的混合分布下图像的对数似然
log_prob = dist.log_prob(image_input)
pixelcnn = models.Model(inputs=image_input, outputs=log_prob)
# 损失函数是批输入图像的平均负对数似然
pixelcnn.add_loss(-tf.reduce_mean(log_prob))

模型的训练方式与普通 PixelCNN 相同,但其接受整数像素值作为输入,取值范围为 [0, 255]。可以使用 sample 函数从分布中生成输出:

dist.sample(10).numpy()

生成的示例图像如下图所示,与普通 PixelCNN 不同的是,此模型利用了完整的像素值范围。

生成结果

小结

在本节中,介绍了如何使用 PixelCNN 以自回归的方式生成图像,使用 Keras 构建 PixelCNN 模型,实现掩码卷积层和残差块,以便信息可以在网络中传递,只有前面的像素可以用于生成当前的像素。最后,使用 TensorFlow Probability 库提供的 PixelCNN 函数,该函数使用混合分布作为输出层,从而能够进一步改善学习过程。

系列链接

AIGC实战——生成模型简介
AIGC实战——深度学习 (Deep Learning, DL)
AIGC实战——卷积神经网络(Convolutional Neural Network, CNN)
AIGC实战——自编码器(Autoencoder)
AIGC实战——变分自编码器(Variational Autoencoder, VAE)
AIGC实战——使用变分自编码器生成面部图像
AIGC实战——生成对抗网络(Generative Adversarial Network, GAN)
AIGC实战——WGAN(Wasserstein GAN)
AIGC实战——条件生成对抗网络(Conditional Generative Adversarial Net, CGAN)
AIGC实战——自回归模型(Autoregressive Model)
AIGC实战——改进循环神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1391261.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

礼贺新春,徐坊大曲新品【中国红】

梁山徐坊大曲新推出中国风礼盒&#xff0c;以中国红为主题&#xff0c;为即将到来的新春佳节增添了浓厚的节日气氛。为您呈现一场视觉与味觉的盛宴。从礼盒的颜色到图案设计&#xff0c;无不体现出中国红的热情与活力&#xff0c;象征着吉祥、喜庆与团圆。梁山徐坊大曲&#xf…

ubuntu qt 运行命令行

文章目录 1.C实现2.python实现 1.C实现 下面是封装好的C头文件&#xff0c;直接调用run_cmd_fun()即可。 #ifndef GET_CMD_H #define GET_CMD_H#endif // GET_CMD_H #include <iostream> #include<QString> using namespace std;//system("gnome-terminal -…

USB8814动态信号采集卡——声音振动类信号处理的理想之选!

背景介绍&#xff1a; 科技的发展在一定程度上依赖于对信号的处理&#xff0c;信号处理技术的先进性在很大程度上决定了科技发展的速度和方向。数字信号处理技术的崛起&#xff0c;彻底改变了传统的信息与信号处理方式&#xff0c;使得数据采集这一前期工作在数字系统中发挥着…

FTP文件传输协议 、多种方式安装yum仓库

一、网络文件共享服务 1.存储类型分三种&#xff1a; 直连式存储&#xff1a;Direct-Attached Storage&#xff0c;简称DAS 存储区域网络&#xff1a;Storage Area Network&#xff0c;简称SAN&#xff08;可以使用空间&#xff0c;管理也是你来管理&#xff09; 网络附加存储…

ImageNet Classification with Deep Convolutional 论文笔记

✅作者简介&#xff1a;人工智能专业本科在读&#xff0c;喜欢计算机与编程&#xff0c;写博客记录自己的学习历程。 &#x1f34e;个人主页&#xff1a;小嗷犬的个人主页 &#x1f34a;个人网站&#xff1a;小嗷犬的技术小站 &#x1f96d;个人信条&#xff1a;为天地立心&…

Leetcode23-数组能形成多少数对(2341)

1、题目 给你一个下标从 0 开始的整数数组 nums 。在一步操作中&#xff0c;你可以执行以下步骤&#xff1a; 从 nums 选出 两个 相等的 整数 从 nums 中移除这两个整数&#xff0c;形成一个 数对 请你在 nums 上多次执行此操作直到无法继续执行。 返回一个下标从 0 开始、长…

SpringMVC参数接收见解4

# 4.参数接收Springmvc中&#xff0c;接收页面提交的数据是通过方法形参来接收&#xff1a; 处理器适配器调用springmvc使用反射将前端提交的参数传递给controller方法的形参 springmvc接收的参数都是String类型&#xff0c;所以spirngmvc提供了很多converter&#xff08;转换…

第二证券:大盘探底回升走出底部还看成交量配合

持续震动数日后&#xff0c;大盘再现探底上升走势。 上证指数周二小幅低开后窄幅震动&#xff0c;午后快速回落改写本轮回调新低后&#xff0c;有资金开始出手介入&#xff0c;尾盘指数翻红。深证成指同样是在午后呈现探底上升走势&#xff0c;最终重回5日均线上方。截至收盘&…

mysql 下载和安装和修改MYSQL8.0 数据库存储文件的路径

一、第一步:下载步骤 下载链接&#xff1a;MySQL :: Download MySQL Installer 选择版本8.0.35&#xff0c;社区版&#xff0c; 点击 Download 下载 安装包 二、第二步:安装步骤 添加环境变量&#xff0c;C:\Program Files\MySQL\MySQL Server 8.0\bin 可以点开MySQL 8.0 Co…

如何用AI提高论文阅读效率?

已经2024年了&#xff0c;该出现一个写论文解读AI Agent了。 大家肯定也在经常刷论文吧。 但真正尝试过用GPT去刷论文、写论文解读的小伙伴&#xff0c;一定深有体验——费劲。其他agents也没有能搞定的&#xff0c;今天我发现了一个超级厉害的写论文解读的agent &#xff0c…

HNU-模式识别-作业2-面向应用分类系统

模式识别-作业2 计科210X 甘晴void 202108010XXX 【具体实现思路是按照去年数学建模国赛题来做的&#xff0c;就放个思路&#xff0c;完整不放全了】 题目&#xff1a; 查阅文献资料&#xff0c;构建一个面向应用的分类系统。 要求&#xff1a; 至少3页A4纸&#xff0c;文…

机器人制作开源方案 | AI校园服务机器人

作者&#xff1a;李强、李振宁、毛维雷、李文文、张奥 单位&#xff1a;山西能源学院 指导老师&#xff1a;姚志广、程晟 在这个科技飞速发展的时代&#xff0c;在工业智造、人工智能的飞速发展中&#xff0c;出现了越来越多的智能化机械装置&#xff0c;也有许多创新类的比赛…

STM32F103标准外设库—— 新建工程与库函数(四)

个人名片&#xff1a; &#x1f981;作者简介&#xff1a;一名喜欢分享和记录学习的在校大学生 &#x1f42f;个人主页&#xff1a;妄北y &#x1f427;个人QQ&#xff1a;2061314755 &#x1f43b;个人邮箱&#xff1a;2061314755qq.com &#x1f989;个人WeChat&#xff1a;V…

力扣刷MySQL-第二弹(详细解析)

&#x1f389;欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克&#x1f379; ✨博客主页&#xff1a;小小恶斯法克的博客 &#x1f388;该系列文章专栏&#xff1a;力扣刷题讲解-MySQL &#x1f379;文章作者技术和水平很有限&#xff0c;如果文中出…

基于Yolov5+Deepsort+SlowFast算法实现视频目标识别、追踪与行为实时检测

前言 前段时间打算做一个目标行为检测的项目&#xff0c;翻阅了大量资料&#xff0c;也借鉴了不少项目&#xff0c;最终感觉Yolov5DeepsortSlowfast实现实时动作检测这个项目不错&#xff0c;因此进行了实现。 一、核心功能设计 总的来说&#xff0c;我们需要能够实现实时检测视…

SAP PI之Rest adapter

一&#xff0c;简介 REST风格接口是以http为传输协议&#xff0c;以xml或json或text为有效负载。下图展示了REST到XI再返回的一个过程&#xff0c;一个REST接口包含的信息有&#xff1a;服务URL、URL中带的参数、http方法(post/get/put等)、http头部、body部分的有效载荷。而X…

debian12部署Gitea服务之二——部署git-lfs

Debian安装gitlfs: 先更新下软件包版本 sudo apt update 安装 sudo apt install git-lfs 验证是否安装成功 git lfs version cd到Gitea仓库目录下 cd /mnt/HuHDD/Git/Gitea/Repo/hu/testrepo.git 执行lfs的初始化命令 git lfs install客户机Windows端在官网下载并安装Git-Lfs 再…

编译原理1.3习题 程序设计语言的发展历程

图源&#xff1a;文心一言 编译原理习题整理~&#x1f95d;&#x1f95d; 作为初学者的我&#xff0c;这些习题主要用于自我巩固。由于是自学&#xff0c;答案难免有误&#xff0c;非常欢迎各位小伙伴指正与讨论&#xff01;&#x1f44f;&#x1f4a1; 第1版&#xff1a;自…

内存操作函数

一、memcpy函数 memcpy(void * destination, const void * source, num)表示将source中的前num个字符复制到destination中&#xff0c;不允许source和destination的内存区域重叠memcpy(a, bn1, n2)表示从b中第n11个字符开始复制&#xff0c;复制n2个字符到a中为了防止溢出&…

框架基础-Maven+SpringBoot入门

框架基础 Maven基础 Maven概述 Maven是为Java项目提供项目构建和依赖管理的工具 Maven三大功能 - 项目构建构建&#xff1a;是一个将代码从开发阶段到生产阶段的一个过程&#xff1a;清理&#xff0c;编译&#xff0c;测试&#xff0c;打包&#xff0c;安装&#xff0c;部署…