【深度学习笔记】计算机视觉——风格迁移

news2024/11/26 2:44:34

风格迁移

摄影爱好者也许接触过滤波器。它能改变照片的颜色风格,从而使风景照更加锐利或者令人像更加美白。但一个滤波器通常只能改变照片的某个方面。如果要照片达到理想中的风格,可能需要尝试大量不同的组合。这个过程的复杂程度不亚于模型调参。

本节将介绍如何使用卷积神经网络,自动将一个图像中的风格应用在另一图像之上,即风格迁移(style transfer) :cite:Gatys.Ecker.Bethge.2016
这里我们需要两张输入图像:一张是内容图像,另一张是风格图像
我们将使用神经网络修改内容图像,使其在风格上接近风格图像。
例如, :numref:fig_style_transfer中的内容图像为本书作者在西雅图郊区的雷尼尔山国家公园拍摄的风景照,而风格图像则是一幅主题为秋天橡树的油画。
最终输出的合成图像应用了风格图像的油画笔触让整体颜色更加鲜艳,同时保留了内容图像中物体主体的形状。

在这里插入图片描述

🏷fig_style_transfer

方法

:numref:fig_style_transfer_model用简单的例子阐述了基于卷积神经网络的风格迁移方法。
首先,我们初始化合成图像,例如将其初始化为内容图像。
该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数。
然后,我们选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新。
这个深度卷积神经网络凭借多个层逐级抽取图像的特征,我们可以选择其中某些层的输出作为内容特征或风格特征。
以 :numref:fig_style_transfer_model为例,这里选取的预训练的神经网络含有3个卷积层,其中第二层输出内容特征,第一层和第三层输出风格特征。

在这里插入图片描述

🏷fig_style_transfer_model

接下来,我们通过前向传播(实线箭头方向)计算风格迁移的损失函数,并通过反向传播(虚线箭头方向)迭代模型参数,即不断更新合成图像。
风格迁移常用的损失函数由3部分组成:

  1. 内容损失使合成图像与内容图像在内容特征上接近;
  2. 风格损失使合成图像与风格图像在风格特征上接近;
  3. 全变分损失则有助于减少合成图像中的噪点。

最后,当模型训练结束时,我们输出风格迁移的模型参数,即得到最终的合成图像。

在下面,我们将通过代码来进一步了解风格迁移的技术细节。

[阅读内容和风格图像]

首先,我们读取内容和风格图像。
从打印出的图像坐标轴可以看出,它们的尺寸并不一样。

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
content_img = d2l.Image.open('../img/rainier.jpg')
d2l.plt.imshow(content_img);


在这里插入图片描述

style_img = d2l.Image.open('../img/autumn-oak.jpg')
d2l.plt.imshow(style_img);


在这里插入图片描述

[预处理和后处理]

下面,定义图像的预处理函数和后处理函数。
预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。
后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。
由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

[抽取图像特征]

我们使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征 :cite:Gatys.Ecker.Bethge.2016

pretrained_net = torchvision.models.vgg19(pretrained=True)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /home/ci/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth



  0%|          | 0.00/548M [00:00<?, ?B/s]

为了抽取图像的内容特征和风格特征,我们可以选择VGG网络中某些层的输出。
一般来说,越靠近输入层,越容易抽取图像的细节信息;反之,则越容易抽取图像的全局信息。
为了避免合成图像过多保留内容图像的细节,我们选择VGG较靠近输出的层,即内容层,来输出图像的内容特征。
我们还从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层
正如 :numref:sec_vgg中所介绍的,VGG网络使用了5个卷积块。
实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。
这些层的索引可以通过打印pretrained_net实例获取。

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或风格层之间的所有层。
下面构建一个新的网络net,它只保留需要用到的VGG的所有层。

net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])

给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。
由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出。

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

下面定义两个函数:get_contents函数对内容图像抽取内容特征;
get_styles函数对风格图像抽取风格特征。
因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。
由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和风格特征。

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

[定义损失函数]

下面我们来描述风格迁移的损失函数。
它由内容损失、风格损失和全变分损失3部分组成。

内容损失

与线性回归中的损失函数类似,内容损失通过平方误差函数衡量合成图像与内容图像在内容特征上的差异。
平方误差函数的两个输入均为extract_features函数计算所得到的内容层的输出。

def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离目标:
    # 这是一个规定的值,而不是一个变量。
    return torch.square(Y_hat - Y.detach()).mean()

风格损失

风格损失与内容损失类似,也通过平方误差函数衡量合成图像与风格图像在风格上的差异。
为了表达风格层输出的风格,我们先通过extract_features函数计算风格层的输出。
假设该输出的样本数为1,通道数为 c c c,高和宽分别为 h h h w w w,我们可以将此输出转换为矩阵 X \mathbf{X} X,其有 c c c行和 h w hw hw列。
这个矩阵可以被看作由 c c c个长度为 h w hw hw的向量 x 1 , … , x c \mathbf{x}_1, \ldots, \mathbf{x}_c x1,,xc组合而成的。其中向量 x i \mathbf{x}_i xi代表了通道 i i i上的风格特征。

在这些向量的格拉姆矩阵 X X ⊤ ∈ R c × c \mathbf{X}\mathbf{X}^\top \in \mathbb{R}^{c \times c} XXRc×c中, i i i j j j列的元素 x i j x_{ij} xij即向量 x i \mathbf{x}_i xi x j \mathbf{x}_j xj的内积。它表达了通道 i i i和通道 j j j上风格特征的相关性。我们用这样的格拉姆矩阵来表达风格层输出的风格。
需要注意的是,当 h w hw hw的值较大时,格拉姆矩阵中的元素容易出现较大的值。
此外,格拉姆矩阵的高和宽皆为通道数 c c c
为了让风格损失不受这些值的大小影响,下面定义的gram函数将格拉姆矩阵除以了矩阵中元素的个数,即 c h w chw chw

def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)

自然地,风格损失的平方误差函数的两个格拉姆矩阵输入分别基于合成图像与风格图像的风格层输出。这里假设基于风格图像的格拉姆矩阵gram_Y已经预先计算好了。

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

全变分损失

有时候,我们学到的合成图像里面有大量高频噪点,即有特别亮或者特别暗的颗粒像素。
一种常见的去噪方法是全变分去噪(total variation denoising):
假设 x i , j x_{i, j} xi,j表示坐标 ( i , j ) (i, j) (i,j)处的像素值,降低全变分损失

∑ i , j ∣ x i , j − x i + 1 , j ∣ + ∣ x i , j − x i , j + 1 ∣ \sum_{i, j} \left|x_{i, j} - x_{i+1, j}\right| + \left|x_{i, j} - x_{i, j+1}\right| i,jxi,jxi+1,j+xi,jxi,j+1

能够尽可能使邻近的像素值相似。

def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

损失函数

[风格转移的损失函数是内容损失、风格损失和总变化损失的加权和]。
通过调节这些权重超参数,我们可以权衡合成图像在保留内容、迁移风格以及去噪三方面的相对重要性。

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、风格损失和全变分损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

[初始化合成图像]

在风格迁移中,合成的图像是训练期间唯一需要更新的变量。因此,我们可以定义一个简单的模型SynthesizedImage,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可。

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight

下面,我们定义get_inits函数。该函数创建了合成图像的模型实例,并将其初始化为图像X。风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。

def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

[训练模型]

在训练模型进行风格迁移时,我们不断抽取合成图像的内容特征和风格特征,然后计算损失函数。下面定义了训练循环。

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

现在我们[训练模型]:
首先将内容图像和风格图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)


在这里插入图片描述

我们可以看到,合成图像保留了内容图像的风景和物体,并同时迁移了风格图像的色彩。例如,合成图像具有与风格图像中一样的色彩块,其中一些甚至具有画笔笔触的细微纹理。

小结

  • 风格迁移常用的损失函数由3部分组成:(1)内容损失使合成图像与内容图像在内容特征上接近;(2)风格损失令合成图像与风格图像在风格特征上接近;(3)全变分损失则有助于减少合成图像中的噪点。
  • 我们可以通过预训练的卷积神经网络来抽取图像的特征,并通过最小化损失函数来不断更新合成图像来作为模型参数。
  • 我们使用格拉姆矩阵表达风格层输出的风格。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1493147.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

2024-阿里巴巴灵犀互娱校招内推

灵犀互娱是阿里集团旗下研运一体游戏品牌&#xff0c;在业务模式上&#xff0c;灵犀互娱面向全球&#xff0c;研运一体&#xff0c;坚持精品&#xff0c;打造爆款&#xff0c;重视服务玩家。 访问链接即刻开启内推&#xff1a;https://talent.lingxigames.com/campus/qrcode/…

第十二篇:学习python数据清洗

文章目录 一、啥是数据清洗二、将表格数据导入pandas中1. 准备工作2. 引入csv文件2.1 引入pandas库2.2 读取文件/修改名称3.2 快速浏览数据2.4 修改名字2.5 查找缺失值2.6 删除缺失值 3. 引入Excel文件3.1 引入pandas库3.2 读取Excel文件的人均GDP数据3.3 查看数据类型和non-nu…

速卖通关键字搜索API接口实战:Python代码与搜索策略解析

一、速卖通关键字搜索API简介 速卖通&#xff08;AliExpress&#xff09;作为阿里巴巴旗下的国际电商平台&#xff0c;为卖家和买家提供了便捷的交易渠道。其开放平台提供的API接口允许开发者集成速卖通的各种功能&#xff0c;其中之一就是关键字搜索API。通过这个API&#xf…

备考2024年北京高考数学:20114~2023十年选择题练习和解析

距离2024年高考还有三个月的时间&#xff0c;如何用三个月的时间再提高北京数学高考的成绩&#xff1f;吃透历年真题以及背后的知识点是行之有效的方法 之一。 今天我们来看一下2014-2023年的北京市高考数学的选择题&#xff0c;从过去十年&#xff08;2014-2023&#xff09;的…

【JSON2WEB】09 Amis-editor的代码移植到json2web

【JSON2WEB】01 WEB管理信息系统架构设计 【JSON2WEB】02 JSON2WEB初步UI设计 【JSON2WEB】03 go的模板包html/template的使用 【JSON2WEB】04 amis低代码前端框架介绍 【JSON2WEB】05 前端开发三件套 HTML CSS JavaScript 速成 【JSON2WEB】06 JSON2WEB前端框架搭建 【J…

大语言模型的Scaling laws(尺度定律)的正确认识

源自&#xff1a;人工智能前沿讲习 “人工智能技术与咨询” 发布 实验一 声明:公众号转载的文章及图片出于非商业性的教育和科研目的供大家参考和探讨&#xff0c;并不意味着支持其观点或证实其内容的真实性。版权归原作者所有&#xff0c;如转载稿涉及版权等问题&#xff0c;…

数据分析案例-二手车用户数据可视化分析(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

为什么被蜜蜂蛰了会肿得像馒头

有的人却只是一点点小鼓包。 病情分析&#xff1a;蜜蜂体内存在一种有毒物质&#xff0c;其主要成分是蚁酸&#xff0c;这种成分进入人体后&#xff0c;会和血液发生反应&#xff0c;导致皮肤表现出红肿和瘙痒的症状。一些人群还会对蜜蜂表现出过敏反应&#xff0c;此类人群在…

活动策划整体流程需要考虑哪些要素

传媒如春雨&#xff0c;润物细无声&#xff0c;大家好&#xff0c;我是51媒体网胡老师。 活动策划整体流程中需要考虑的要素非常多&#xff0c;这些要素通常涵盖从策划前的准备到活动结束后的总结&#xff0c;以下是一些关键的考虑要素&#xff1a; 活动目标&#xff1a;确定活…

单片机为什么需要时钟?2种时钟电路对比?

目录 一、晶体振荡器&#xff08;Crystal Oscillator&#xff09;的核心知识 二、单片机为什么需要时钟电路&#xff1f; 三、单片机的时钟电路方案 01、外部晶振方案 02、内部晶振方案 四、总结 单片机研发设计的项目中&#xff0c;它的最小电路系统包含 电源电路复位…

QT:颜色选择器

普通 Qt提供了一个现成的QColorDialog类。 用法: #include <QColorDialog>QColor color QColorDialog::getColor(Qt::white, this); if(!color.isValid()){//点击 关闭 或 cancel 颜色无效 }else {ui->text->setText(color.name());//类似##ffffQRgb rgb colo…

Android9-W517-使用NotificationListenerService监听通知

目录 一、前言 二、前提 三、方案 方案一 方案二 方案三 方案四 方案五 方案六 方案七 四、关于NotificationListenerService类头注释 五、结论 一、前言 NotificationListenerService可以让应用监听所有通知&#xff0c;但是无法获得监听通知的权限&#xff0c;如…

x86 Ubuntu上编译eudev给龙芯loongarch64架构主机使用

1、下载eudev库eudev-master.zip&#xff0c;链接&#xff1a;eudev库官方地址 2、下载龙芯的交叉编译工具&#xff1a;loongson-gnu-toolchain-8.3-x86_64-loongarch64-linux-gnu-rc1.2.tar.xz&#xff0c;链接&#xff1a;龙芯交叉编译官方地址 3、交叉编译器环境搭建 (1)、…

Spring-Cloud中服务发现是什么?干什么的?怎么用?

&#x1f413; 是什么 Spring Cloud通过Eureka或Consul等服务注册与发现组件来实现微服务间的相互感知。服务提供者将自己的服务信息注册到注册中心&#xff0c;服务消费者从注册中心获取服务提供者的信息&#xff0c;从而进行服务调用。 &#x1f413; 干什么 在Spring Cloud…

nodejs安装教程(及过程中的易错)

nodejs&#xff1a;Nodejs 是基于 Chrome 的 V8 引擎开发的一个 C 程序&#xff0c;目的是提供一个 JS 的运行环境。 npm&#xff1a;npm 是 Node Package Manager 的缩写&#xff0c;意思是 Node 的包管理系统&#xff0c;是最大的软件包仓库 下载nodejs 首先我们需要在node…

VNC 与 虚拟机 保姆级 快速入门图文指导

Time: 2024年3月5日22:31:49 By[ V ]: MemoryErHero 重要的事情先说三遍: 1 虚拟机内无需安装 VNC-Viewer-7.0.1-Windows 2 虚拟机内无需安装 VNC-Viewer-7.0.1-Windows 3 虚拟机内无需安装 VNC-Viewer-7.0.1-Windows 1 VNC 图文安装 流程 ① VNC-Viewer-7.0.1-Windows.e…

【Python】Python注册微服务到nacos

Python注册微服务到Nacos 1.Nacos部署 github 的nacos项目的发布页&#xff08;Releases alibaba/nacos GitHub &#xff09;&#xff0c;选择所要下载的nacos版本&#xff0c;在nacos下方的assets中选择安装包进行下载。 解压nacos安装包到指定目录。 tar -zxvf nacos-ser…

SpringCloud-MQ消息队列

一、消息队列介绍 MQ (MessageQueue) &#xff0c;中文是消息队列&#xff0c;字面来看就是存放消息的队列。也就是事件驱动架构中的Broker。消息队列是一种基于生产者-消费者模型的通信方式&#xff0c;通过在消息队列中存放和传递消息&#xff0c;实现了不同组件、服务或系统…

微信小程序触屏事件_上划下划事件

一、微信小程序触屏事件 bindtouchstart&#xff1a;手指触摸动作开始 bindtouchmove&#xff1a;手指触摸后移动 bindend&#xff1a;手指触摸动作结束 属性类型说明touchesArray触摸事件&#xff0c;当前停留在屏幕中的触摸点信息的数组 Touch 对象 属性类型说明identi…

【数据结构和算法初阶(C语言)】顺序表+单链表经典例题图文详解(题解大合集,搭配图文演示详解,一次吃饱吃好)

目录 1.移除链表元素 1.1思路1&#xff1a;遍历删除 1. 2 思路2&#xff1a;尾插法 2.反转链表 3.链表的中间节点 3.1解题思想及过程 3.2快慢指针思想解题---变式&#xff1a;返回链表的倒数第K个节点 4.合并两个有序链表 4.1解题思想 1取小的尾插 5.反转链表 6…