李沐49_样式迁移——自学笔记

news2025/1/11 22:54:50

样式迁移

将样式图片中的样式迁移到内容图片上,合成图片,例如将照片转换成漫画形式或者是油画风。

基于CNN的样式迁移

读取图片和样式风格

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

d2l.set_figsize()
content_img = d2l.Image.open('rainier.jpg')
d2l.plt.imshow(content_img);

在这里插入图片描述

style_img = d2l.Image.open('autumn-oak.jpg')
d2l.plt.imshow(style_img);

在这里插入图片描述

预处理和后处理

预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。

后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。 由于图像打印函数要求每个像素的浮点数值在0~1之间,我们对小于0和大于1的值分别取0和1。

rgb_mean = torch.tensor([0.485, 0.456, 0.406])
rgb_std = torch.tensor([0.229, 0.224, 0.225])

def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    return transforms(img).unsqueeze(0)

def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

抽取图像特征

基于ImageNet数据集预训练的VGG-19模型来抽取图像特征

pretrained_net = torchvision.models.vgg19(pretrained=True)
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.
  warnings.warn(
/usr/local/lib/python3.10/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.
  warnings.warn(msg)
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [00:03<00:00, 155MB/s]

从VGG中选择不同层的输出来匹配局部和全局的风格,这些图层也称为风格层。VGG网络使用了5个卷积块。

实验中,我们选择第四卷积块的最后一个卷积层作为内容层,选择每个卷积块的第一个卷积层作为风格层。 这些层的索引可以通过打印pretrained_net实例获取。

style_layers, content_layers = [0, 5, 10, 19, 28], [25]

使用VGG层抽取特征时,我们只需要用到从输入层到最靠近输出层的内容层或风格层之间的所有层。 下面构建一个新的网络net,它只保留需要用到的VGG的所有层。

net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers + style_layers) + 1)])

给定输入X,如果我们简单地调用前向传播net(X),只能获得最后一层的输出。 由于我们还需要中间层的输出,因此这里我们逐层计算,并保留内容层和风格层的输出。

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

get_contents函数对内容图像抽取内容特征;

get_styles函数对风格图像抽取风格特征。

因为在训练时无须改变预训练的VGG的模型参数,所以我们可以在训练开始之前就提取出内容特征和风格特征。由于合成图像是风格迁移所需迭代的模型参数,我们只能在训练过程中通过调用extract_features函数来抽取合成图像的内容特征和风格特征。

def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

损失函数

内容损失、风格损失和全变分损失3部分组成

1.内容损失

def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离目标:
    # 这是一个规定的值,而不是一个变量。
    return torch.square(Y_hat - Y.detach()).mean()

2.风格损失

def gram(X):
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    X = X.reshape((num_channels, n))
    return torch.matmul(X, X.T) / (num_channels * n)
def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

3.全变分损失

def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

损失函数

风格转移的损失函数是内容损失、风格损失和总变化损失的加权和。

content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、风格损失和全变分损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

初始化合成图像

定义一个简单的模型SynthesizedImage,并将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight

定义get_inits函数。该函数创建了合成图像的模型实例,并将其初始化为图像X。风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。

def get_inits(X, device, lr, styles_Y):
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

训练模型

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

我们训练模型: 首先将内容图像和风格图像的高和宽分别调整为300和450像素,用内容图像来初始化合成图像。

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1623984.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Facebook的魅力魔法:探访数字社交的奇妙世界

1. 社交媒体的演变与Facebook的角色 在数字化时代&#xff0c;社交媒体已经成为我们日常生活中不可或缺的一部分。而在众多的社交媒体平台中&#xff0c;Facebook 以其深厚的历史和广泛的影响力&#xff0c;成为了全球数亿用户沟通、分享和互动的主要场所。从其初创之时起&…

【学习AI-相关路程-自我总结-相关入门-自我学习-NVIDIA-Jetson】

【学习AI-相关路程-自我总结-相关入门-自我学习】 1、前言2、思考前进方向3、学习路线1、基础知识阶段2、初级准备阶段3、中级学习阶段4、高级实战阶段 4、自我的努力5、学习平台6、自己总结 1、前言 最近AI相关比较火的&#xff0c;对于程序员&#xff0c;或者走这行的人来说…

Flutter开发好用插件url_launcher详解-启动 URL

文章目录 url_launcher介绍安装用法错误处理自定义行为其他功能 url_launcher介绍 url_launcher 是一个 Flutter 插件&#xff0c;用于启动 URL。它支持网络、电话、短信和电子邮件方案。您可以使用它从您的 Flutter 应用程序中打开网站、拨打号码、发送短信或撰写电子邮件。 …

群组分析方法

目录 1.什么是群组分析方法 2.基本原理 3.群组分析方法分类 3.1.层次方法 3.2.划分方法 3.3.密度基方法 ​​​​​​​3.4.模型基方法 4.群组评估 5.应用步骤 1.什么是群组分析方法 群组分析&#xff08;Cluster Analysis&#xff09;是数据分析中的一种重要方法&…

git lab 2.7版本修改密码命令

1.gitlab-rails console -e production Ruby: ruby 2.7.5p203 (2021-11-24 revision f69aeb8314) [x86_64-linux] GitLab: 14.9.0-jh (51fb4a823f6) EE GitLab Shell: 13.24.0 PostgreSQL: 12.7 2根据用户名修改密码 user User.find_by(username: ‘username’) # 替换’use…

ABAP 遗传算法求解

本文无文本解析&#xff0c;结尾处有简单装箱问题的示例&#xff0c;该算法收敛结果较慢&#xff0c;仅供ABAP爱好者参考&#xff0c;实践&#xff0c;实际应用建议使用线性规划。可直接复制后在系统中使用。 对象自定义逻辑版本-截图 对象自定义逻辑版本-对象描述 INIT I…

Navicat连接SQLSever报错:[08001] MicrosoftTCP Provider 远程主机强迫关闭了一个现有的连接

Navicat连接SQLSever报错&#xff1a;[08001] [Microsoft][SQL Server Native Client 10.0]TCP Provider: 远程主机强迫关闭了一个现有的连接 问题分析 旧版的MSSQL 如果不是最新版的&#xff0c;可以去这安装以下即可。 最新版的MSSQL 如果是安装最新版的MSSQL连接不上很正…

DFS和回溯专题:全排列 II

DFS和回溯专题&#xff1a;全排列 II 题目链接: 全排列 II 参考题解 代码随想录 题目描述 代码纯享版 class Solution {public List<List<Integer>> list_all new ArrayList();public List<Integer> list new ArrayList();public int[] res;public Lis…

探索 Python 的动态类型系统:变量引用、不可变性及高效内存管理与垃圾回收机制的深入分析

文章目录 1. 动态类型及其内存管理解析1.1 变量与对象的引用关系1.2 对象的不可变性和内存地址的变化 2. 垃圾回收与内存优化策略2.1 动态内存分配的基础2.2 Python 的垃圾回收 Python作为一种流行的高级编程语言&#xff0c;以其代码的易读性和简洁性著称。尤其是它的动态类型…

开源数据集分享———猫脸码客

猫脸码客作为一个专注于开源数据集分享的公众号&#xff0c;致力于为广大用户提供丰富、优质的数据资源。我们精心筛选和整理各类开源数据集&#xff0c;涵盖机器学习、深度学习、自然语言处理等多个领域&#xff0c;以满足不同用户的需求。 (https://img-blog.csdnimg.cn/d98…

找不到vcruntime140_1.dll,无法继续执行代码的多种解决方法

在启动电脑并着手进行日常工作的过程中&#xff0c;当我尝试运行一款至关重要的软件时&#xff0c;系统突然弹出一个令人困扰的错误提示&#xff1a;“由于找不到vcruntime140_1.dll&#xff0c;无法继续执行代码”&#xff0c;这个错误信息明确指出&#xff0c;由于缺失了vcru…

SpringCloud系列(16)--将服务提供者Provider注册进Zookeeper

前言&#xff1a;在上一章节中我们说明了一些关于Eureka自我保护模式&#xff0c;而且自上一章节起关于Eureka的知识已经讲的差不多了&#xff0c;不过因为Eureka已经停更了&#xff0c;为了安全考虑&#xff0c;我们要用还在更新维护的注册中心来取代Eureka&#xff0c;而本章…

【Java--数据结构】链表经典OJ题详解(上)

欢迎关注个人主页&#xff1a;逸狼 创造不易&#xff0c;可以点点赞吗~ 如有错误&#xff0c;欢迎指出~ 目录 谈谈头插、头删、尾插、头插的时间复杂度 反转一个单链表 链表的中间结点 返回倒数第k个结点 合并两个链表 谈谈头插、头删、尾插、头插的时间复杂度 头插和头删的时…

echerts横向一根柱子不同参数不同色的斜切式柱状图的做法

思路&#xff1a;网上搜寻了许久&#xff0c;最终参照官网的斜切样式制作出来了&#xff0c;这种方法采用的是遮盖原有柱状图顶端造成视觉上看起来是斜切的效果。目前这是我能想到最好的办法了&#xff0c;也欢迎大家提供其他的方法。 参照echerts图表网&#xff1a; echarts…

数据结构四:线性表之带头结点的单向循环循环链表的设计

前面两篇介绍了线性表的顺序和链式存储结构&#xff0c;其中链式存储结构为单向链表&#xff08;即一个方向的有限长度、不循环的链表&#xff09;&#xff0c;对于单链表&#xff0c;由于每个节点只存储了向后的结点的地址&#xff0c;到了尾巴结点就停止了向后链的操作。也就…

2023平航杯——手机取证复现

手机最近连接的wifi"只有红茶可以吗"的密码是&#xff1f;【标准格式&#xff1a;ABCabc123!#】 手机上安装了某个运动软件&#xff0c;它的包名是&#xff1f;【标准格式&#xff1a;com.baidu.gpt】 com.dizhisoft.changdongli 该运动软件中最近一次运动记录的起点…

【产品经理修炼之道】- 如何从0到1搭建B端产品

随着数字化转型的不断深化,B端产品也面临着升级。本文总结分析了如何从0到1搭建B端产品,希望对你有所帮助。 背景 随着公司数字化转型的不断的推进和实施,数字化转型成功越来越明显的体现在财务报上,这也增强了管理层对数字转型的信心,在推进中我们也发现几年建设的系统的…

【继承和多态】

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、pandas是什么&#xff1f;二、使用步骤 1.引入库2.读入数据总结 前言 提示&#xff1a;这里可以添加本文要记录的大概内容&#xff1a; 例如&#xff1a;…

ios微信小程序禁用下拉上拉

第一步&#xff1a; page.json配置页面的"navigationStyle":"custom"属性&#xff0c;禁止页面滑动 "navigationStyle":"custom" 第二步&#xff1a; 页面里面使用scroll-view包裹内容&#xff0c;内容可以内部滑动 <view class&…

Java中的 JDK环境变量配置详解

JDK安装后&#xff0c;接下我们来学习一个补充知识&#xff0c;叫做Path环境变量 什么是Path环境变量&#xff1f; Path环境变量是让系统程序的路径&#xff0c;方便程序员在命令行窗口的任意目录下启动程序&#xff1b; 如何配置环境变量呢&#xff1f; 比如把QQ的启动程序&a…