样式迁移及代码

news2024/9/21 22:44:11

一、定义

1、使用卷积神经网络,自动将一个图像中的风格应用在另一图像之上,即风格迁移;两张输入图像:一张是内容图像,另一张是风格图像

2、训练一些样本使得样本在一些cnn的特征上跟样式图片很相近,在一些cnn的特征上跟内容图片很像,关键在于该怎么样定义内容是一样的,样式是一样的

3、方法:

        (1)我们初始化合成图像,例如将其初始化为内容图像

        (2)该合成图像是风格迁移过程中唯一需要更新的变量,即风格迁移所需迭代的模型参数

        (3)选择一个预训练的卷积神经网络来抽取图像的特征,其中的模型参数在训练中无须更新(VGG系列对于抽取特征效果不错)

4、风格迁移常用的损失函数组成:

        (1)内容损失使合成图像与内容图像在内容特征上接近;

        (2)风格损失使合成图像与风格图像在风格特征上接近;

        (3)全变分损失则有助于减少合成图像中的噪点。

5、样式就是通道里面的像素的统计信息和通道之间的统计信息。假设两张图片它的样式是一样的,那么他们卷积层的一些输出,他们通道之间的统计分布和通道里面统计分布是差不多匹配的;

6、样式损失是指匹配通道内的统计信息和通道之间的统计信息(可以理解为匹配一阶二阶三阶的统计信息,一阶可以是均值,二阶可以是协方差),把feature map里通道作为一维,然后将通道里面的像素拉成一个向量,然后做协方差矩阵,去计算一个多维度的随机变量的二阶信息去匹配分布。

二、代码

1、内容及样式读取

d2l.set_figsize()
#读取内容图片(风景照片)
content_img = d2l.Image.open('../img/rainier.jpg')
#读取样式图片(油画)
style_img = d2l.Image.open('../img/autumn-oak.jpg')

2、预处理与后处理

        预处理函数preprocess对输入图像在RGB三个通道分别做标准化,并将结果变换成卷积神经网络接受的输入格式。 后处理函数postprocess则将输出图像中的像素值还原回标准化之前的值。

#将一张图片转化为可以训练的tensor
def preprocess(img, image_shape):
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])
    #对输入图像应用之前定义的变换序列,得到一个张量,并在0维度上添加一个新的维度(通道)
    return transforms(img).unsqueeze(0)

#将一个训练好的tensor转化为图片
def postprocess(img):
    img = img[0].to(rgb_std.device)
    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)
    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))

3、抽取图像特征

#使用基于ImageNet数据集预训练的VGG-19模型来抽取图像特征
pretrained_net = torchvision.models.vgg19(pretrained=True)

#浅层是局部,越接近输出层越全局;样式既想要一些局部的信息,也想要一些全局的信息
style_layers, content_layers = [0, 5, 10, 19, 28], [25]

#要的是28层以内的层(为啥老师写的是(max(content_layers + style_layers))
net = nn.Sequential(*[pretrained_net.features[i] for i in
                      range(max(content_layers , style_layers) + 1)])

def extract_features(X, content_layers, style_layers):
    contents = []
    styles = []
    for i in range(len(net)):
        X = net[i](X)
        if i in style_layers:
            styles.append(X)
        if i in content_layers:
            contents.append(X)
    return contents, styles

#对内容图像抽取内容特征
def get_contents(image_shape, device):
    content_X = preprocess(content_img, image_shape).to(device)
    contents_Y, _ = extract_features(content_X, content_layers, style_layers)
    return content_X, contents_Y

#对风格图像抽取风格特征
def get_styles(image_shape, device):
    style_X = preprocess(style_img, image_shape).to(device)
    _, styles_Y = extract_features(style_X, content_layers, style_layers)
    return style_X, styles_Y

4、定义损失函数

# 内容损失
def content_loss(Y_hat, Y):
    # 我们从动态计算梯度的树中分离目标:
    # 这是一个规定的值,而不是一个变量。
    return torch.square(Y_hat - Y.detach()).mean()

#风格损失
def gram(X):
    # X图像的特征,形状为(batch_size, num_channels, height, width);n特征图的元素总数除以通道数(高度和宽度的乘积)
    num_channels, n = X.shape[1], X.numel() // X.shape[1]
    #将特征图的每个通道展开成一个一维的向量
    X = X.reshape((num_channels, n))
    #通过计算特征图中所有通道对的内积,每个元素表示两个通道的相关性
    return torch.matmul(X, X.T) / (num_channels * n)

def style_loss(Y_hat, gram_Y):
    return torch.square(gram(Y_hat) - gram_Y.detach()).mean()

# 全变分损失
def tv_loss(Y_hat):
    return 0.5 * (torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean() +
                  torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean())

#损失函数
content_weight, style_weight, tv_weight = 1, 1e3, 10

def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):
    # 分别计算内容损失、风格损失和全变分损失
    contents_l = [content_loss(Y_hat, Y) * content_weight for Y_hat, Y in zip(
        contents_Y_hat, contents_Y)]
    styles_l = [style_loss(Y_hat, Y) * style_weight for Y_hat, Y in zip(
        styles_Y_hat, styles_Y_gram)]
    tv_l = tv_loss(X) * tv_weight
    # 对所有损失求和
    l = sum(10 * styles_l + contents_l + [tv_l])
    return contents_l, styles_l, tv_l, l

5、初始化合成图像

        将合成的图像视为模型参数。模型的前向传播只需返回模型参数即可

class SynthesizedImage(nn.Module):
    def __init__(self, img_shape, **kwargs):
        super(SynthesizedImage, self).__init__(**kwargs)
        self.weight = nn.Parameter(torch.rand(*img_shape))

    def forward(self):
        return self.weight

        创建了合成图像的模型实例,并将其初始化为图像X,风格图像在各个风格层的格拉姆矩阵styles_Y_gram将在训练前预先计算好。

def get_inits(X, device, lr, styles_Y):
    #生成图像的起始状态就是 X 的值,这样训练快一点
    gen_img = SynthesizedImage(X.shape).to(device)
    gen_img.weight.data.copy_(X.data)
    #Adam 优化器,用于优化生成图像的参数,目的是最小化损失函数,从而生成符合目标风格和内容的图像
    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)
    #对每张样式图像 Y 计算 Gram 矩阵,用于衡量风格特征
    styles_Y_gram = [gram(Y) for Y in styles_Y]
    return gen_img(), styles_Y_gram, trainer

6、训练模型

def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch):
    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)
    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch, 0.8)
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs],
                            legend=['content', 'style', 'TV'],
                            ncols=2, figsize=(7, 2.5))
    for epoch in range(num_epochs):
        trainer.zero_grad()
        contents_Y_hat, styles_Y_hat = extract_features(
            X, content_layers, style_layers)
        contents_l, styles_l, tv_l, l = compute_loss(
            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)
        l.backward()
        trainer.step()
        scheduler.step()
        if (epoch + 1) % 10 == 0:
            animator.axes[1].imshow(postprocess(X))
            animator.add(epoch + 1, [float(sum(contents_l)),
                                     float(sum(styles_l)), float(tv_l)])
    return X

device, image_shape = d2l.try_gpu(), (300, 450)
net = net.to(device)
content_X, contents_Y = get_contents(image_shape, device)
_, styles_Y = get_styles(image_shape, device)
output = train(content_X, contents_Y, styles_Y, device, 0.3, 500, 50)

三、总结

1、风格迁移常用的损失函数由3部分组成:

        (1)内容损失使合成图像与内容图像在内容特征上接近;

        (2)风格损失令合成图像与风格图像在风格特征上接近;

        (3)全变分损失则有助于减少合成图像中的噪点。

2、我们可以通过预训练的卷积神经网络来抽取图像的特征,并通过最小化损失函数来不断更新合成图像来作为模型参数。

3、我们使用格拉姆矩阵表达风格层输出的风格。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1945114.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

PHP教程002:PHP变量介绍

文章目录 一、PHP程序1、PHP标记2、PHP代码3、语句结束符;4、注释 二、PHP变量2.1 声明变量2.2 赋值运算符3、变量命名规则 一、PHP程序 PHP文件的默认扩展名是".php"PHP文件可以包含html、css、js 序号组成描述1<?php ... ?>PHP标记2PHP代码函数、数组、流…

二、原型模式

文章目录 1 基本介绍2 实现方式深浅拷贝目标2.1 使用 Object 的 clone() 方法2.1.1 代码2.1.2 特性2.1.3 实现深拷贝 2.2 在 clone() 方法中使用序列化2.2.1 代码 2.2.2 特性 3 实现的要点4 Spring 中的原型模式5 原型模式的类图及角色5.1 类图5.1.1 不限制语言5.1.2 在 Java 中…

Java之集合底层-数据结构

Java集合之数据结构 1 概述 数据结构是计算机科学中研究数据组织、存储和操作的一门学科。它涉及了如何组织和存储数据以及如何设计和实现不同的数据操作算法和技术。常见的据结构有线性数据结构&#xff08;含数组、链表、栈和队列等&#xff09;&#xff0c;非线性数据结构…

四、GD32 MCU 常见外设介绍(1)RCU 时钟介绍

系统架构 1.RCU 时钟介绍 众所周知&#xff0c;时钟是MCU能正常运行的基本条件&#xff0c;就好比心跳或脉搏&#xff0c;为所有的工作单元提供时间 基数。时钟控制单元提供了一系列频率的时钟功能&#xff0c;包括多个内部RC振荡器时钟(IRC)、一个外部 高速晶体振荡器时钟(H…

Meta发布最强AI模型,扎克伯格公开信解释为何支持开源?

凤凰网科技讯 北京时间7月24日&#xff0c;脸书母公司Meta周二发布了最新大语言模型Llama 3.1&#xff0c;这是该公司目前为止推出的最强大开源模型&#xff0c;号称能够比肩OpenAI等公司的私有大模型。与此同时&#xff0c;Meta CEO马克扎克伯格(Mark Zuckerberg)发表公开信&a…

力扣1792.最大平均通过率

力扣1792.最大平均通过率 每个班级加上一个人以后得通过率增量不同 将优先级最高的班级放队列顶&#xff0c;每次操作即可 class Solution {public:struct Radio{int pass;int total;//满足该条件 oth的优先级更高bool operator < (const Radio& oth)const{return (l…

【中项】系统集成项目管理工程师-第4章 信息系统架构-4.1架构基础

前言&#xff1a;系统集成项目管理工程师专业&#xff0c;现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试&#xff0c;全称为“全国计算机与软件专业技术资格&#xff08;水平&#xff09;考试”&…

Java习题二

一题目要求&#xff1a; 二具体代码&#xff1a; package three;import sun.util.resources.LocaleData;import java.time.LocalDate; import java.time.format.DateTimeFormatter; import java.util.*;public class test {public static void main(String[] args) {String us…

php--高级文件绕过

&#x1f3bc;个人主页&#xff1a;金灰 &#x1f60e;作者简介:一名简单的大一学生;易编橙终身成长社群的嘉宾.✨ 专注网络空间安全服务,期待与您的交流分享~ 感谢您的点赞、关注、评论、收藏、是对我最大的认可和支持&#xff01;❤️ &#x1f34a;易编橙终身成长社群&#…

UCOS-III 任务调度接口(OSSched)详解

在实时操作系统uC/OS-III中&#xff0c;调度器是核心组件之一&#xff0c;它负责管理任务的执行顺序和优先级。本文将详细解析uC/OS-III内核中的调度函数OSSched。 OSSched函数简介 OSSched函数用于检查并确定是否有更高优先级的任务需要运行。该函数通常在任务级别代码中调用…

【快速逆向四/无过程/有源码】浙江工商职业技术学院 统一身份认证

逆向日期&#xff1a;2024.07.23 使用工具&#xff1a;Node.js 加密方法&#xff1a;RSAUtils 文章全程已做去敏处理&#xff01;&#xff01;&#xff01; 【需要做的可联系我】 AES解密处理&#xff08;直接解密即可&#xff09;&#xff08;crypto-js.js 标准算法&#xf…

使用 Socket和动态代理以及反射 实现一个简易的 RPC 调用

使用 Socket、动态代理、反射 实现一个简易的 RPC 调用 我们前面有一篇 socket 的文章&#xff0c;再之前&#xff0c;还有一篇 java动态代理的文章&#xff0c;本文用到了那两篇文章中的知识点&#xff0c;需要的话可以回顾一下。 下面正文开始&#xff1a; 我们的背景是一个…

掌握Rust:函数、闭包与迭代器的综合运用

掌握Rust&#xff1a;函数、闭包与迭代器的综合运用 引言&#xff1a;解锁 Rust 高效编程的钥匙函数定义与模式匹配&#xff1a;构建逻辑的基石高阶函数与闭包&#xff1a;代码复用的艺术迭代器与 for 循环&#xff1a;高效数据处理的引擎综合应用案例&#xff1a;构建一个简易…

最新App崩溃率出炉!这样的行业均值水平如何?

前不久发布的《2024 Q1 移动应用性能体验报告》(以下简称报告),公布了最新的App崩溃率行业均值。基于友盟覆盖的终端设备,观测启动次数和崩溃次数,《报告》综合计算得出iOS APP崩溃率0.21%,Android Java崩溃率0.22%、native 0.16%、ANR 0.53%。 作为国内领先的第三方全域数据智…

PyMol在Windows系统上的免费安装指南

PyMOL是一个强大的分子可视化工具&#xff0c;广泛应用于生物化学、分子生物学和材料科学等领域。对于需要在Windows系统上进行分子结构分析和可视化的用户来说&#xff0c;安装一个免费版本的PyMol至关重要。本文将提供详细的步骤&#xff0c;指导如何在Windows系统上免费安装…

有哪些好用的 AI 学术研究工具和科研工具?

AI视频生成&#xff1a;小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量https://aitools.jurilu.com/ AI 应用其实分两个层面&#xff0c;第一是模型&#xff0c;第二是应用。现在很多模型厂家都是既做 toC 的对话应用&#xff0c;也做 t…

Jmeter性能测试进行参数化操作

在使用Jmeter进行性能测试中,Jmeter的基本操作是肯定要会的,除此之外,还需要会多并发压测配置线程,多用户并发参数化的设置等等技术.下面就给大家介绍一下这个方面的内容: 一.Jmeter单个请求部署 1.设置线程组 注意:如果要使用调度器,那么循环次数的”永远”的选项一定要记得…

MySQL-视图、存储过程和触发器

一、视图的定义和使用 视图是从一个或者几个基本表&#xff08;或视图&#xff09;导出的表。它与基本表不同&#xff0c;是一个虚表,视图只能用来查询。不能做增删改查(虚拟的表) 1.视图的作用 简化查询重写格式化数据频繁访问数据库过滤数据 2.创建视图 -- 创建视图 -- 语法…

计算机三级嵌入式笔记(二)——嵌入式处理器

目录 考点1 嵌入式处理器的结构类型 考点2 嵌入式处理器简介 考点3 ARM处理器概述 考点4 处理器和处理器核 考点5 ARM 处理器的分类 考点6 经典 ARM 处理器 考点7 ARM Cortex 嵌入式处理器 考点8 ARM Cortex实时嵌入式处理器 考点9 ARM Cortex 应用处理器 考点10 AR…

Python群体趋向性潜关联有向无向多图层算法

&#x1f3af;要点 &#x1f3af;算法模型图层节点和边数学定义 | &#x1f3af;算法应用于贝叶斯推理或最大似然优化概率建模的多图层生成模型 | &#x1f3af;算法结合图结构边和节点属性 | &#x1f3af;对比群体关联预测推理生成式期望最大化多图层算法 | &#x1f3af;使…