73. 风格迁移以及代码实现

news2025/1/20 15:45:31

摄影爱好者也许接触过滤波器。它能改变照片的颜色风格,从而使风景照更加锐利或者令人像更加美白。但一个滤波器通常只能改变照片的某个方面。如果要照片达到理想中的风格,可能需要尝试大量不同的组合。这个过程的复杂程度不亚于模型调参。

本节将介绍如何使用卷积神经网络,自动将一个图像中的风格应用在另一图像之上,即风格迁移(style transfer)

这里我们需要两张输入图像:一张是内容图像,另一张是风格图像。 我们将使用神经网络修改内容图像,使其在风格上接近风格图像。

例如, 图13.12.1中的内容图像是在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而风格图像则是一幅主题为秋天橡树的油画。 最终输出的合成图像应用了风格图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。

在这里插入图片描述

1. 基于CNN的风格迁移

图13.12.2 用简单的例子阐述了基于卷积神经网络的风格迁移方法。

首先,我们初始化合成图像,例如将其初始化为内容图像。 该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数

然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新

这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或风格特征。 以下图为例,这里选取的预训练的神经网络含有3个卷积层,其中第二层输出内容特征,第一层和第三层输出风格特征。

这个模型特别之处在于,不是训练卷积网络的权重,而是训练合成图片

在这里插入图片描述

接下来,我们通过前向传播(实线箭头方向)计算风格迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。 风格迁移常用的损失函数由3部分组成:

  1. 内容损失使合成图像与内容图像在内容特征上接近;

  2. 风格损失使合成图像与风格图像在风格特征上接近;

  3. 全变分损失则有助于减少合成图像中的噪点

最后,当模型训练结束时,我们输出风格迁移的模型参数,即得到最终的合成图像。

在下面,我们将通过代码来进一步了解风格迁移的技术细节。

2. 阅读内容和风格图像

首先,我们读取内容和风格图像。 从打印出的图像坐标轴可以看出,它们的尺寸并不一样。

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
content_img = d2l.Image.open('drive/MyDrive/chapter13/img/rainier.jpg')
d2l.plt.imshow(content_img);

运行结果:

在这里插入图片描述

style_img = d2l.Image.open('drive/MyDrive/chapter13/img/autumn-oak.jpg')
d2l.plt.imshow(style_img);

运行结果:

在这里插入图片描述

3. 预处理和后处理

下面,定义图像的预处理函数和后处理函数。

预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。

后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。 由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape): # 将图片变成一个能训练的tensor
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape), # resize成我们要的样子
        torchvision.transforms.ToTensor(), # 转成tensor
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)]) # 做标准化
    return transforms(img).unsqueeze(0)

def postprocess(img): # 将tensor变回到图片
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

4. 抽取图像特征

我们使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征。

VGG系列对于抽取特征效果很好。

pretrained_net = torchvision.models.vgg19(pretrained=True)

为了抽取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。

一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。

为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。

VGG网络使用了5个卷积块。 实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。 这些层的索引可以通过打印pretrained_net实例获取。

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或风格层之间的所有层,因此28层以后的层丢掉。下面构建一个新的网络net,它只保留需要用到的VGG的所有层。

net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])

给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。 由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X) # 计算每一层的输出
        if i in style_layers: # 如果是样式层,就add到styles
            styles.append(X)
        if i in content_layers: # 如果是内容层,就add到contents
            contents.append(X)
    return contents, styles

下面定义两个函数:get_contents函数对内容图像抽取内容特征; get_styles函数对风格图像抽取风格特征。 因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。 由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和风格特征。

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device) # 将内容图像转为tensor形式
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device) # 将风格图像转为tensor形式
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

5. 定义损失函数

下面我们来描述风格迁移的损失函数。 它由内容损失、风格损失和全变分损失3部分组成。

5.1 内容损失

与线性回归中的损失函数类似,内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异。 平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。

def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离目标:
    # 这是一个规定的值,而不是一个变量。
    return torch.square(Y_hat - Y.detach()).mean()

5.2 风格损失

风格损失与内容损失类似,也通过平方误差函数衡量合成图像与风格图像在风格上的差异。 为了表达风格层输出的风格,我们先通过extract_features函数计算风格层的输出。

假设该输出的样本数为1,通道数为 𝑐 ,高和宽分别为 ℎ 和 𝑤 ,我们可以将此输出转换为矩阵 𝐗 ,其有 𝑐 行和 ℎ𝑤 列。 这个矩阵可以被看作由 𝑐 个长度为 ℎ𝑤 的向量 𝐱1,…,𝐱𝑐 组合而成的。其中向量 𝐱𝑖 代表了通道 𝑖 上的风格特征

在这些向量的格拉姆矩阵 𝐗𝐗⊤∈ℝ𝑐×𝑐 中, 𝑖 行 𝑗 列的元素 𝑥𝑖𝑗 即向量 𝐱𝑖 和 𝐱𝑗 的内积。它表达了通道 𝑖 和通道 𝑗 上风格特征的相关性。我们用这样的格拉姆矩阵来表达风格层输出的风格。 需要注意的是,当 ℎ𝑤 的值较大时,格拉姆矩阵中的元素容易出现较大的值。 此外,格拉姆矩阵的高和宽皆为通道数 𝑐 。 为了让风格损失不受这些值的大小影响,下面定义的gram函数将格拉姆矩阵除以了矩阵中元素的个数,即 𝑐ℎ𝑤 。

def gram(X):
  # n等于高宽相乘,num_channels是通道数
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n) # 做normalization

自然地,风格损失的平方误差函数的两个格拉姆矩阵输入分别基于合成图像与风格图像的风格层输出。这里假设基于风格图像的格拉姆矩阵gram_Y已经预先计算好了。

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

5.3 全变分损失

有时候,我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。 一种常见的去噪方法是全变分去噪(total variation denoising): 假设 𝑥𝑖,𝑗 表示坐标 (𝑖,𝑗) 处的像素值,降低全变分损失

在这里插入图片描述
能够尽可能使邻近的像素值相似。

def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

5.4 损失函数

风格转移的损失函数是内容损失、风格损失和总变化损失加权和。 通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、风格损失和全变分损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

6. 初始化合成图像

在风格迁移中,合成的图像是训练期间唯一需要更新的变量。因此,我们可以定义一个简单的模型SynthesizedImage,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可。

# 这么定义之后,就可以对weight算梯度,对其进行更新
class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self): # 返回模型参数
        return self.weight

下面,我们定义get_inits函数。该函数创建了合成图像的模型实例,并将其初始化为图像X。风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。

def get_inits(X, device, lr, styles_Y):
    # 获取到图片的形状,但是权重参数暂时是随机初始化的,再将整体移动到gpu上
    gen_img = SynthesizedImage(X.shape).to(device)
    # 再用X图片本身的模型参数去重写
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

7. 训练模型

在训练模型进行风格迁移时,我们不断抽取合成图像的内容特征和风格特征,然后计算损失函数。下面定义了训练循环。

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

现在我们训练模型: 首先将内容图像和风格图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
# 预先抽取内容特征
content_X, contents_Y = get_contents(image_shape, device)
# 预先抽取风格特征
_, styles_Y = get_styles(image_shape, device)
# 第一个参数是content_X,可知是把内容图片作为初始图片传入
# (当然,也能用风格图片做初始化)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

运行结果:

在这里插入图片描述

我们可以看到,合成图像保留了内容图像的风景和物体,并同时迁移了风格图像的色彩。例如,合成图像具有与风格图像中一样的色彩块,其中一些甚至具有画笔笔触的细微纹理。

8. Q&A

Q1: 越靠近输出端内容还原度越高还是越靠近输入端内容还原度越高?

A1: 越靠近输出端内容还原度越高。

Q2: TV损失可以理解为图像平滑技术吗?

A2:可以

Q3:为什么卷积层的kernel不需要参加训练?

A3: 因为不需要更新它,卷积层在这里只是被用来抽取特征的,我只需要更新合成图片。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/161351.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

数据导入导出(POI以及easyExcel)

一.概念: 1.场景需求 将一些数据库信息导出为Excel表格 将Excel表格数据导入数据库 大量数据的导入导出操作 常⽤的解决⽅案为:Apache POI与阿⾥巴巴easyExcel2.Apache POI介绍 Apache POI 是基于Office Open XML 标准(OOXML)和M…

96、【树与二叉树】leetcode ——404. 左叶子之和:递归法[先序+后序]+迭代法[先序+层次](C++版本)

题目描述 原题链接:404. 左叶子之和 解题思路 一、递归法 (1)设置一个布尔变量判定(先序遍历) 左孩子一定在每个子树的最左侧,因此设置一个局部参数flag,当每次遍历的是左子树时&#xff0c…

记录一次Oracle Linux7上安装RDBMS 11.2.0.4的问题

参考文档: 文档1 OHASD fails to start on SuSE 11 SP2 on IBM: Linux on System z (Doc ID 1476511.1) As part of the root.sh, roothas.pl or rootcrs.pl is called and an entry is placed in /etc/inittab h1:35:respawn:/etc/init.d/init.ohasd run >/de…

cmake入门之二:调用外部共享库和头文件

cmake入门之二:调用外部共享库和头文件1.创建外部共享库1.1 创建相关文件或文件夹1.1.1 ext目录下的CMakeLists.txt1.1.2 ext目录lib文件夹下1.1.2.1 hello.h1.1.2.2 hello.c1.1.2.3 CMakeLists.txt1.2 编译、安装2.创建当前项目2.1 创建相关文件或文件夹2.1.1 proj…

为什么这么努力,还是赚不到钱?你不能不知道赚钱的三个模型

大部分人赚钱,都是通过能力努力运气,这种需要付出自己的大量时间和精力,并且赚到的钱也是有上限的。小部分人赚钱是通过,个人IP的商业模式来赚钱,并且跟我们传统的生意不一样的是,这个亏钱是有下限的&#…

OpenHarmony轻量级设备xts认证大致流程

因为最近公司在做openharmony开发板xts认证,这里对认证过程和过程中遇到的坑做下记录,也给大家探探路。 1. 开发板适配 OpenHarmony轻量系统的移植比较简单,代码中解耦做得非常好。从代码的设计理念上来看,移植主要是3部分的内容…

.vscode/extensions.json和setting.json 是项目用到的插件推荐列表和设置

文章目录前言一、extensions.json安装推荐插件编辑推荐插件二、setting.json总结前言 在前端项目,文件目录中存在.vscode文件夹,文件夹下一般存在两个文件extensions.json和setting.json。作用是保持所有开发者安装了相同的插件和相同的配置&#xff0c…

访问学者J1签证通常准备那些材料?

访问学者J1签证通常准备那些材料?知识人网小编马上整理一下分享出来作为参考:材料准备1、VISA部分:护照,护照照片,160确认页,签证费收据两联都带,DS2019,D7002,sevis费收…

Linux应用基础——监控与管理进程

目录 一、进程 1.定义 2.组成 3.进程环境包括 4.进程的生命周期 二、描述进程状态 三、相关命令 1.top命令 2.ps命令 二、中断进程 1.基本进程管理信号 2.每个信号的默认操作 3.相关命令 (1)kill命令 (2)killall命…

助力不文明行文识别,基于YOLOv7融合RepVGG的遛狗牵绳行为检测识别分析系统

不知道大家平时在路上走的时候或者在小区的时候有没有遇上过遛狗不牵绳子的行为,我在实际生活里面可是没少遇到过,有时候特别大的一只狗就这么冲过来,主人却还无动于衷,揍他的心都有了,这种行为的确是很不文明&#xf…

Java项目:仓库管理系统设计和实现(java+ssm+springboot+layui)

源码获取:博客首页 "资源" 里下载! 主要功能模块 1.用户模块管理:用户登录、用户注册、用户的查询、添加、删除操作、 2.客户信息管理:.客户列表的展示、添加、修改、删除操作、 3.供应商管理:供应商详情…

Android MVVM之ViewModel的详解与使用

一、介绍 ViewModel 类是一种业务逻辑或屏幕级状态容器。它用于将状态公开给界面,以及封装相关的业务逻辑。 它的主要优点是,它可以缓存状态,并可在配置更改后持久保留相应状态。这意味着在 activity 之间导航时或进行配置更改后(…

【UE4 第一人称射击游戏】45-使用线追踪进行破坏

上一篇:【UE4 第一人称射击游戏】44-瞄准时的武器线追踪步骤:打开“Weapon_Base”删除打印节点添加如下节点,表示追踪线命中目标时执行的逻辑对上面逻辑的解释:首先追踪线命中目标后,显示红色的那个准心然后让目标的健…

阿里云 - MaxCompute研究

一、官方介绍MaxCompute是适用于数据分析场景的企业级SaaS(Software as a Service)模式云数据仓库,提供离线和流式数据的接入,支持大规模数据计算及查询加速能力。MaxCompute适用于100 GB以上规模的存储及计算需求,最大…

全国青少年软件编程(Scratch)等级考试一级考试真题2022年12月——持续更新.....

1.小明想在开始表演之前向大家问好并做自我介绍,应运行下列哪个程序?( ) A. B. C. D. 正确答案:D 答案解析: 外观积木配合显示时间,才能看清楚内容。 2.舞台有两个不同的背景,小猫角色的哪个积木能够切换舞台背景?( ) A.<

UVC静态杀菌模组的工作原理及应用

现代紫外线消毒技术是基于现代防疫学、光学、生物学和物理化学的基础上&#xff0c;利用特殊设计的高效率&#xff0c;高强度和长寿命的C波段紫外光发生装置&#xff0c;产生的强紫外C光照射空气或物体表面&#xff0c;当空气或固体表面中的各种细菌、病毒、寄生虫、水藻以及其…

C/C++ 三维数组和二维数组指针的结合

示例程序&#xff1a;#include <iostream> #include <stdio.h> int main() {int a[3][4] {{1,2,3,4},{2,3,4,5},{3,4,5,6}};int b[3][4] {{10,11,12,13},{11,12,13,14},{12,13,14,15}};int(*aa[2])[4] { a,b };int* p1[3] {a[0],a[1],a[2]};int* p2[3] {b[0],…

小学三年级奥数(和差倍问题)

例题5&#xff1a;学校合唱团成员中,女生人数是男生的3倍,而且女生比男生多80人&#xff0c;合唱团里男生有多少人&#xff1f;女生有多少人&#xff1f;思路分析&#x1f604;&#xff1a;抓住关键语句&#xff0c;女生人数是男生的3倍&#xff0c;那么把男生看成1份&#xff…

《图机器学习》-Graph as Matrix:Page Rnak,

Graph as Matrix一、Graph as Matrix二、PageRank三、PageRank&#xff1a;How to solve&#xff1f;四、Random Walk with Restarts and Personalized PageRank五、Matrix Factorization and Node Embedding一、Graph as Matrix 本小节将从矩阵的角度研究图形分析和学习。 把…

centos 一个ip绑定双网卡

nmcli con show (绿正常&#xff0c;黄白不正常) nmcli con del uuid &#xff08;eg&#xff1a;nmcli con del 585bdacc-314f-423e-a935-18295d0fb48b&#xff09; nmcli con add type bond ifname bond0 mode active-backup &#xff08;bond0只是一个名称&#xff0c;可以…