吴恩达教授深度学习--神经风格转换算法

news2025/2/4 23:58:55

什么是神经风格迁移?

假设你有一张内容图片C(Content)和一张具有独特风格S(Style)的图片,神经风格迁移可以让这两张图片结合,让原始图片具有图片S的风格。所以神经风格迁移可以解决的问题是:生成一张同时具有图片C的内容和图片S的风格的新图片G。比如下图:

要学习神经风格迁移算法,首先我们要知道什么是迁移学习。深度学习最强大的理念之一就是,神经网络可以从一个认为中习得知识,并将其应用或迁移到另一个独立的任务中,这就是迁移学习。神经风格迁移算法是在预训练好的模型上完成的,本文使用的是VGG-19网络结构,是一个已经训练好的图像分类网络,已经能够很好的识别图像的低级特征和高级特征。

代价函数

定义代价函数J(G)可以用来评定生成图片的好坏。代价函数由两部分组成,一个是内容代价函数J_{content}(C, G),用于衡量生成图片和原始图片的相似程度,另一个是风格代价函数J_{style}(S, G),用于衡量生成图片的风格和风格图片的风格相似程度,使用​超参数来确定内容代价和风格代价的权重,代价函数的公式如下:

J(G) = \alpha J_{content}(C, G) + \beta J_{style}(S, G)

神经风格迁移算法的步骤:

  • 随机初始化特定尺寸的生成图像G

  • 使用梯度下降算法最小化J(G),以找到最合适的图像G

  • 更新图像G:G = G - \frac{\partial}{\partial(G)}J(G)

内容代价函数

确定一个隐层l,使用预训练好的卷积神经网络,分别输入内容图像和生成图像,得到两个第隐层的激活块,然后使用L2范数的平方计算两个激活块中对应项的距离并求和。对使用梯度下降法时,会激励这个算法找到合适的图像G,使得图像G在第​l层的激活值和内容图像相似。

风格代价函数

图片风格定义为:第​l层各通道间激活项的相关系数,而相关系数是用于测量不同特征在图片中各位置同时出现或不同时出现的概率。用相关系数描述激活块中不同通道的风格

如何计算图像各通道的相关系数?

  • 取第l隐层的激活矩阵,有如下定义:a_{i,j,k}^{[l]}表示第k个通道的(i,j)激活项,G^{[l]}n_c^{[l]} * n_c^{[l]}大小的风格矩阵
  • 求风格图像的风格矩阵: G_{kk^\prime}^{(S)} = \sum_{i=1}^{n_H^{[l]}}\sum_{j=1}^{n_W^{[l]}}a_{i,j,k}^{(S)}a_{i,j,k^{\prime}}^{(S)}
  • 求生成图像的风格矩阵:G_{kk^\prime}^{(G)} = \sum_{i=1}^{n_H^{[l]}}\sum_{j=1}^{n_W^{[l]}}a_{i,j,k}^{(G)}a_{i,j,k^{\prime}}^{(G)}

风格代价函数,使用L2范数的平方计算两个风格矩阵对应项的距离:

J_{style}^{[l]} = ||G^{(S)}-G^{(G)}||_2^{2}

代码实现

请先导入一下模块或函数:

from keras.models import Model
from keras.applications import VGG19
from keras.preprocessing import image
from keras.optimizers import Adam

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import time

确定内容输出和风格输出的中间层

首先导入VGG-19网络:

from keras.applications import VGG19

vgg = VGG19(include_top=False)

查看网络所有层的名字

for layer in vgg.layers:
    print(layer.name)


输出内容:
input_2
block1_conv1
block1_conv2
block1_pool
block2_conv1
block2_conv2
block2_pool
block3_conv1
block3_conv2
block3_conv3
block3_conv4
block3_pool
block4_conv1
block4_conv2
block4_conv3
block4_conv4
block4_pool
block5_conv1
block5_conv2
block5_conv3
block5_conv4
block5_pool

选择网络某些中间层作为内容输出和风格输出

content_layers = ['block5_conv2'] # 被选择作为内容输出的中间层

style_layers = ['block1_conv1',
                'block2_conv1',
                'block3_conv1',
                'block4_conv1',
                'block5_conv1'] # 被选择为风格输出的中间层

构建模型

创建VGG19模型,返回以中间层作为输出的模型

def vgg_layers(layer_names):
    '''
        创建VGG19模型,返回以中间层作为输出的模型
    :param layer_names: 层名字列表
    :return:
    '''
    vgg = VGG19(include_top=False)

    # 中间层的输出结构列表
    outputs = [vgg.get_layer(name).output for name in layer_names]

    model = Model([vgg.input], outputs)

    return model

提取内容和特征

先定义一个函数,用于求特征矩阵的风格矩阵:

def gram_matrix(input):
    '''

    :param input: 特征矩阵
    :return: 返回风格矩阵
    '''
    result = tf.linalg.einsum("bijc,bijd -> bcd", input, input)
    return result

构建内容和风格提取模型,调用该模型的实例,可以返回图像在VGG19网络中间层的风格矩阵和特征矩阵: 

class StyleContentModel(Model):
    def __init__(self, style_layers, content_layers):
        '''
        :param style_layers: 被选择的中间层,作为风格图像输出
        :param content_layers: 被选择的中间层,作为内容图像的输出
        '''
        super(StyleContentModel, self).__init__()
        self.vgg = vgg_layers(style_layers + content_layers) # 这样模型的输出成为包含所有被选择的中间层输出的列表。
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)

    def call(self, inputs):
        '''
            在类实例被调用时,该函数会执行
        :param inputs: 输入模型的图像数据
        :return:  返回字典,包含图像内容输出的列表和风格输出的列表
        '''
        # 输入图像给模型,得到被选择的中间层输出
        outputs = self.vgg(inputs)

        # 分割输出值,得到风格输出值和内容输出值
        style_outputs, content_outputs = outputs[:self.num_style_layers], outputs[self.num_style_layers:]

        # 根据风格输出值计算其风格矩阵
        style_outputs = [gram_matrix(output) for output in style_outputs]

        # 对于内容输出,将层的名字及其对应的输出封装成字典
        content_dict = {content_name: value for (content_name, value) in zip(self.content_layers, content_outputs)}

        # 对于风格输出,将层的名字及其对应的输出封装成字典
        style_dict = {style_name: value for (style_name, value) in zip(self.style_layers, style_outputs)}

        return {"style": style_dict, "content": content_dict}

计算总代价

计算总代价包括计算内容代价和风格代价

def style_content_loss(outputs, style_target, content_target, style_weight, content_weight):
    style_outputs = outputs["style"]
    content_outputs = outputs["content"]

    # 计算风格损失
    style_loss = tf.add_n(tf.reduce_sum(tf.square(style_outputs[name] - style_target[name])) * style_weight for name in style_layers)
    style_loss = style_loss / len(style_layers)

    # 计算内容损失
    content_loss = tf.add_n(tf.reduce_sum(tf.square(content_outputs[name] - content_target[name])) * content_weight for name in content_layers)
    content_loss = content_loss / len(content_layers)

    # 计算总体损失
    loss = style_loss + content_loss

    return loss

使用优化算法寻找最优图像G

定义模型,并获取内容图像的输出值和风格图像的风格矩阵:

# 定义模型
    model = StyleContentModel(style_layers, content_layers)

    # 获取内容图像的内容输出和风格图像的风格输出
    content_target = model(content)["content"]
    style_target = model(style)["style"]

初始化生成图像,尺寸要和原图像一致:

# 初始化生成图像,尺寸和原图像要一致
generate = tf.Variable(content)

前向传播计算输出值,反向传播计算梯度值:

# 进行epochs次训练
    for epoch in range(epochs):

        with tf.GradientTape() as tape:
            outputs = model(generate)
            loss = style_content_loss(outputs, style_target, content_target, style_weight, content_weight)
        print("第{}代, 损失值为:{}".format(str(epoch), str(loss)))
        # 定义Adam优化算法
        optimizer = Adam(learning_rate, beta1, beta2, epsilon)

        # 使用优化算法最小化成本函数,以找到最优生成图像G
        grad = tape.gradient(loss, generate)
        optimizer.apply_gradients([(grad, generate)])

        # 限制图像的像素值在[0,1]
        generate.assign(tf.clip_by_value(generate, clip_value_min=0.0, clip_value_max=1.0))

将上诉操作封装成一个函数,方便同时执行:

def model(content, style, style_weight=1e-2, content_weight=1e4, learning_rate=0.05, beta1=0.9, beta2=0.999, epsilon=1e-8, epochs=10):
    '''

    :param content: 内容图像
    :param style: 风格图像
    :param style_weight: 风格损失函数权重
    :param content_weight: 内容损失函数权重
    :param learning_rate: 学习率
    :param beta1: 梯度的加权平均
    :param beta2: 梯度的平方加权平均
    :param epsilon: 防止除0而加上的参数
    :param epochs: 迭代次数
    :return:  返回生成图像 generate
    '''
    # 定义模型
    model = StyleContentModel(style_layers, content_layers)

    # 获取内容图像的内容输出和风格图像的风格输出
    content_target = model(content)["content"]
    style_target = model(style)["style"]

    # 初始化生成图像,尺寸和原图像要一致
    generate = tf.Variable(content)

    # 进行epochs次训练
    for epoch in range(epochs):

        with tf.GradientTape() as tape:
            outputs = model(generate)
            loss = style_content_loss(outputs, style_target, content_target, style_weight, content_weight)
        print("第{}代, 损失值为:{}".format(str(epoch), str(loss)))
        # 定义Adam优化算法
        optimizer = Adam(learning_rate, beta1, beta2, epsilon)

        # 使用优化算法最小化成本函数,以找到最优生成图像G
        grad = tape.gradient(loss, generate)
        optimizer.apply_gradients([(grad, generate)])

        # 限制图像的像素值在[0,1]
        generate.assign(tf.clip_by_value(generate, clip_value_min=0.0, clip_value_max=1.0))

    return generate

参考:

  1. 【中英】【吴恩达课后编程作业】Course 4 -卷积神经网络 - 第四周作业_hsck.cc_何宽的博客-CSDN博客
  2. 神经风格迁移  |  TensorFlow Core

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/761727.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kafka 入门到起飞系列 - 磁盘存储 -零拷贝

Redis 是 在内存存储数据的,数据读取时不要经过磁盘的IO,只需要内存的操作,这也是redis访问速度快的原因 Kafka背道而驰,Kafka 是在磁盘存储数据的,发送过来的数据交给Kafka后会落盘,消费者读取数据时&…

【C++11】function包装器和bind包装器的简单使用

function function 包装器一些场景下模板的低效性包装器 function 修复问题包装成员函数的注意事项一道例题function包装器的意义 bind 包装器bind 包装器介绍bind 包装器可调整传参顺序bind 包装器可绑定固定参数bind 包装器的意义 C11提供了多个包装器(wrapper,也…

BYOVD!干掉EDR/XDR/AVs进程工具

工具介绍 利用gmer驱动程序有效地禁止使用或杀死EDR和AV,它可以流畅地绕过HVCI;该样本来自 loldrivers:https://www.loldrivers.io/drivers/7ce8fb06-46eb-4f4f-90d5-5518a6561f15/ 关注【Hack分享吧】公众号,回复关键字【230614…

docker安装mariadb,并在宿主机连接docker中启动的mariadb

这篇文章主要介绍怎么在docker中安装一个mariadb数据库,然后在我们的电脑本机上连接虚拟机上docker运行的mariadb数据库。 首先,需要安装一个虚拟机软件,通过虚拟机软件安装一个linux操作系统,本篇文章安装的是ubuntu&#xff0c…

一、基础-3、MySQL卸载

1.、停止MySQL服务 winR 打开运行,输入 services.msc 点击 "确定" 调出系统服务。 2. 卸载MySQL相关组件 打开控制面板 ---> 卸载程序 ---> 卸载MySQL相关所有组件。 3. 删除MySQL安装目录 4. 删除MySQL数据目录 数据存放目录是在 C:\ProgramDat…

No.185# 技术管理框架知识点随记

引言 陆续参加了公司组织的两场关于技术管理的培训,时间一长也快忘的七七八八了。本文以刘建国《执行技术人管理之路》为基础框架,将知识点做了整理,在需要的时候翻翻。本文主要内容有: 技术管理之角色认知技术管理之管理规划技术…

【技能实训】DMS数据挖掘项目-Day11

文章目录 任务12【任务12.1】创建用户信息表【任务12.2】在com.qst.dms.entity下创建用户实体类User,以便封装用户数据【任务12.3】在com.qst.dms.service下创建用户业务类UserService【任务12.4】在项目根目录下创建图片文件夹images,存储dms.png【任务…

了解数据科学中的异常检测

大家好,本文将简要介绍一下异常检测,并指导通过不同的技术来识别异常。 如果你正在处理数据,那么无论是现在还是将来,都可能会遇到一项非常重要的任务 —— 异常检测。它在许多领域中都有很大的应用,如制造业、金融和…

visual studio 2017直接打开文件夹时,选择当前项目或者整个解决方案时,按快捷键查找时显示未找到以下指定文本

有的时候只想要打开一整个文件夹来看里面的代码,平时一般用Qt,但是感觉在打开整个文件夹看代码方面,Qt没有VS方便,于是选择了VS,安装的是VS2017,然后发现有个问题,CtrlF查找时,如果选…

报错:Invalid bound statement (not found): com.web.sysmgr.mapper.UserMapper.login

报错:Invalid bound statement (not found): com.web.sysmgr.mapper.UserMapper.login 原因: 确认是否在扫描Mapper接口时指定了正确的包路径。检查 MapperScan 注解中的包路径是否正确,确保只扫描到需要的Mapper接口。 如果在配置类中去配置…

JQuery 实现点击按钮添加 input 框

前言 用于记录开发中常用到的&#xff0c;快捷开发 需求 比如说&#xff0c;我台设备可以设置一个或多个秘钥&#xff0c;有时候我配置一个秘钥时&#xff0c;就不需要多个输入框&#xff0c;当我想配置多个秘钥时&#xff0c;就需要添加多个输入框。 实现 HTML <div…

Hadoop 之 HDFS 伪集群模式配置与使用(二)

HDFS 配置与使用 一.HDFS配置二.HDFS Shell1.默认配置说明2.shell 命令 三.Java 读写 HDFS1.Java 工程配置2.测试 一.HDFS配置 ## 基于上一篇文章进入 HADOOP_HOME 目录 cd $HADOOP_HOME/etc/hadoop ## 修改文件权限 chown -R root:root /usr/local/hadoop/hadoop-3.3.6/* ## …

JVM 运行流程、类加载、垃圾回收

一、JVM 简介 1、JVM JVM 是 Java Virtual Machine 的简称&#xff0c;意为 Java 虚拟机。 虚拟机是指通过软件模拟的具有完整硬件功能的、运行在一个完全隔离的环境中的完整计算机系统。 常见的虚拟机&#xff1a;JVM、VMwave、Virtual Box。 JVM 和其他两个虚拟机的区别…

Android Profiler 内存分析器使用

Android Profiler是Android Studio的一部分&#xff0c;提供了一个集成的性能分析工具套件&#xff0c;包括内存分析。Android Profiler 工具可提供实时数据&#xff0c;帮助您了解应用的 CPU、内存、网络和电池资源使用情况。 在Android Profiler中&#xff0c;您可以查看内存…

赋能安防“新视界”!智汇云舟亮相中国安防工程商集成商大会

7月14日&#xff0c;备受业界关注的中国安防工程商&#xff08;系统集成商&#xff09;大会暨第67届中国安防新产品、新技术成果展示在上海盛大开幕。来自上海、苏州、南京、无锡等城市的200余位行业领导、嘉宾莅临参会&#xff0c;智汇云舟副总裁陈虹旭受邀出席活动并发表《视…

2、Redis高级特性和应用(发布 订阅、Stream)

Redis高级特性和应用(发布 订阅、Stream) 发布和订阅 Redis提供了基于“发布/订阅”模式的消息机制&#xff0c;此种模式下&#xff0c;消息发布者和订阅者不进行直接通信,发布者客户端向指定的频道( channel)发布消息&#xff0c;订阅该频道的每个客户端都可以收到该消息。 …

【云原生|Docker系列第3篇】Docker镜像的入门实践

欢迎来到Docker入门系列的第三篇博客&#xff01;在前两篇博客中&#xff0c;我们已经了解了什么是Docker以及如何安装和配置它。本篇博客将重点介绍Docker镜像的概念&#xff0c;以及它们之间的关系。我们还将学习如何拉取、创建、管理和分享Docker镜像&#xff0c;这是使用Do…

链表OJ(LeetCode)

文章目录 1.移除链表元素2.反转链表3.链表的中间结点4.倒数第k个结点5.合并两个有序链表6.链表分割7.链表的回文结构8.相交链表9.环形链表10.环形链表Ⅱ1.常规思路2.新型思路【无码】 1.移除链表元素 法一&#xff1a;遍历删除 struct ListNode {int val;struct ListNode* nex…

采集极验4滑块验证码图片数据

在网络安全领域&#xff0c;验证码是一种常见的用于验证用户身份或防止恶意机器人攻击的技术。而极验4滑块验证码作为一种广泛应用的验证码形式&#xff0c;其具有较高的安全性和防御能力。本文将以获取极验4滑块验证码图片数据为主题&#xff0c;介绍相关技术和方法。 一、极…

【Jenkins入门到实战】忽如一夜春风来,千树万树梨花开

自动化运维之Jenkins 前提条件&#xff1a;安装好jdk &#xff08;版本要求11-17&#xff09;并配置好环境变量 一、Jenkins 1、Jenkins是什么 Jenkins是一个开源的持续集成服务&#xff0c;用于实施软件开发和发布流程。它帮助软件开发和运维团队在构建、测试和部署软件上实…