[oneAPI] Neural Style Transfer

news2024/12/24 10:14:05

[oneAPI] Neural Style Transfer

  • oneAPI
  • Neural Style Transfer
    • 特殊环境
    • 定义使用包
    • 加载数据
    • Neural Style Transfer模型与介绍
    • 训练过程
    • 结果

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

oneAPI

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
vgg = VGGNet().to(device).eval()

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
使用Intel Extension for PyTorch中, 实际上在推理模式下是不需要优化器的
'''
vgg = ipex.optimize(vgg)

Neural Style Transfer

Neural Style Transfer是一种使用 CNN 将一幅图像的内容与另一幅图像的风格相结合的算法。给定内容图像和风格图像,目标是生成最小化与内容图像的内容差异和与风格图像的风格差异的目标图像。
在这里插入图片描述
内容丢失

为了最小化内容差异,我们将内容图像和目标图像分别前向传播到预训练的VGGNet,并从多个卷积层中提取特征图。然后,更新目标图像以最小化内容图像的特征图与其特征图之间的均方误差。

风格丧失

与计算内容损失一样,我们将风格图像和目标图像前向传播到 VGGNet 并提取卷积特征图。为了生成与风格图像的风格相匹配的纹理,我们通过最小化风格图像的 Gram 矩阵和目标图像的 Gram 矩阵之间的均方误差来更新目标图像(特征相关性最小化)。请参阅此处了解如何计算风格损失。

特殊环境

本实验:借助PyTorch以及Intel® Optimization for PyTorch,对PyTorch进行了精心的优化与扩展,极大地提升了其性能,特别是在英特尔硬件上的表现更加卓越。这一优化策略使得我们的模型在训练和推断过程中变得更加迅捷高效,显著缩短了计算时间,提升了整体效率。并通过深度融合硬件与软件的精巧设计,有效地解锁了硬件潜力,让模型的训练和应用变得更加快速高效,为人工智能应用带来了全新的可能性。
在这里插入图片描述

数据集使用自己收集的一些数据

在这里插入图片描述
在这里插入图片描述

content.png 表面原始的图片
在这里插入图片描述
style.png表示需要将原始图片转化为的风格

定义使用包

from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

import intel_extension_for_pytorch as ipex

# Device configuration
device = torch.device('xpu' if torch.cuda.is_available() else 'cpu')

加载数据

def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image and convert it to a torch tensor."""
    image = Image.open(image_path)

    if max_size:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.LANCZOS)

    if shape:
        image = image.resize(shape, Image.LANCZOS)

    if transform:
        image = transform(image).unsqueeze(0)

    return image.to(device)

Neural Style Transfer模型与介绍

VGGNet是一个经典的深度卷积神经网络架构,由牛津大学的研究团队提出,用于图像分类和识别任务。VGGNet以其简单而有效的结构在计算机视觉领域取得了显著的成就,成为了深度学习研究的重要里程碑之一。

VGGNet的特点在于其深层的网络结构,通过多个小尺寸的卷积核和池化层的堆叠,达到了很强的特征提取能力。其标准结构包括数个卷积层,之后是池化层,最后是全连接层。

VGGNet在图像分类竞赛中取得了优异的表现,其简单的结构和深层次的特征提取使得它成为了其他网络架构的基础。然而,由于其深层次的结构,VGGNet在计算资源和训练时间上需要较大代价,后续的研究逐渐提出了更加高效的网络架构,如ResNet和Inception等。

而对于本任务,我们使用原图,目标图,风格图的’0’, ‘5’, ‘10’, ‘19’, '28’等层的特征进行对比,最后让原图和风格图内容对应,而目标图与风格图的风格相似

class VGGNet(nn.Module):
    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28']
        self.vgg = models.vgg19(pretrained=True).features

    def forward(self, x):
        """Extract multiple convolutional feature maps."""
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

训练过程

def main(config):
    # Image preprocessing
    # VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    # We use the same normalization statistics here.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406),
                             std=(0.229, 0.224, 0.225))])

    # Load content and style images
    # Make the style image same size as the content image
    content = load_image(config.content, transform, max_size=config.max_size)
    style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])

    # Initialize a target image with the content image
    target = content.clone().requires_grad_(True)

    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
    vgg = VGGNet().to(device).eval()

    '''
    Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
    使用Intel Extension for PyTorch中, 实际上在推理模式下是不需要优化器的
    '''
    vgg = ipex.optimize(vgg)

    for step in range(config.total_step):

        # Extract multiple(5) conv feature vectors
        target_features = vgg(target)
        content_features = vgg(content)
        style_features = vgg(style)

        style_loss = 0
        content_loss = 0
        for f1, f2, f3 in zip(target_features, content_features, style_features):
            # Compute content loss with target and content images
            content_loss += torch.mean((f1 - f2) ** 2)

            # Reshape convolutional feature maps
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            f3 = f3.view(c, h * w)

            # Compute gram matrix
            f1 = torch.mm(f1, f1.t())
            f3 = torch.mm(f3, f3.t())

            # Compute style loss with target and style images
            style_loss += torch.mean((f1 - f3) ** 2) / (c * h * w)

            # Compute total loss, backprop and optimize
        loss = content_loss + config.style_weight * style_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step + 1) % config.log_step == 0:
            print('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}'
                  .format(step + 1, config.total_step, content_loss.item(), style_loss.item()))

        if (step + 1) % config.sample_step == 0:
            # Save the generated image
            denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
            img = target.clone().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-{}.png'.format(step + 1))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='png/content.png')
    parser.add_argument('--style', type=str, default='png/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=2000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    parser.add_argument('--style_weight', type=float, default=100)
    parser.add_argument('--lr', type=float, default=0.003)
    config = parser.parse_args()
    print(config)
    main(config)

结果

在这里插入图片描述

迭代结果图

在这里插入图片描述

训练过程图

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/890768.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

1609.奇偶数

目录 一、题目 二、代码 三、完整测试代码 一、题目 1609. 奇偶树 - 力扣(LeetCode) 二、代码 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0),…

【大数据Hive】hive 事务表使用详解

目录 一、前言 二、Hive事务背景知识 hive事务实现原理 hive事务原理之 —— delta文件夹命名格式 _orc_acid_version 说明 bucket_00000 合并器(Compactor) 二、Hive事务使用限制 参数设置 客户端参数设置 客户端参数设置 三、Hive事务使用操作演示 操作步骤 客…

电脑msvcr120.dll丢失怎么修复,msvcr120.dll怎么安装?

msvcr120.dll是Microsoft Visual C Redistributable的一部分,它是Windows操作系统中的一个动态链接库文件。这个文件包含了一些用于C编程的函数和资源,它们被许多应用程序用于提供特定的功能和服务。如果你在运行某个程序时遇到了缺少msvcr120.dll的错误…

AlexNet中文翻译

ImageNet classification with deep convolutional neural networks 原文链接:https://dl.acm.org/doi/abs/10.1145/3065386 目录 使用深度卷积神经网络进行 ImageNet 分类 摘要 1 简介 2 数据集 3 架构 3.1 ReLU非线性 3.2 多GPU上的训练 3.3 局部响应标准化 3.4 重…

centos安装elasticsearch7.9

安装es 下载elasticsearch安装包解压安装包,并修改配置文件解压进入目录修改配置文件 添加用户,并修改所有者切换用户,运行es如何迁移旧版本的数据 下载elasticsearch安装包 下载地址如下,版本号可以替换成自己想要的。 这里需要注意一点&am…

讯飞星火、文心一言和通义千问同时编“贪吃蛇”游戏,谁会胜出?

同时向讯飞星火、文心一言和通义千问三个国产AI模型提个相同的问题: “python 写一个贪吃蛇的游戏代码” 看哪一家AI写的程序直接能用,谁就胜出! 讯飞星火 讯飞星火给出的代码: import pygame import sys import random# 初…

上海亚商投顾盘:沪指震荡反弹 机器人概念股掀涨停潮

上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。 市场情绪 三大指数今日震荡反弹,科创50盘中涨超1%。机器人概念股掀涨停潮,通力科技、昊志机电、哈焊…

java接口导出csv

1、背景介绍 项目中需要导出数据质检结果,本来使用Excel,但是质检结果数据行数过多,导致用hutool报错,因此转为导出csv格式数据。 2、参考文档 https://blog.csdn.net/ityqing/article/details/127879556 工程环境:…

Spring Clould 网关 - Gateway

视频地址:微服务(SpringCloudRabbitMQDockerRedis搜索分布式) Gateway网关-网关作用介绍(P35) Spring Cloud Gateway 是 Spring Cloud 的一个全新项目,该项目是基于 Spring 5.0,Spring Boot 2…

The coming up production issues

Introduction Ladies and gentlemen, give it up for the wonderful world of software production ! Now, I know that what youre thinking. "Software production!?" That sounds exciting, well, let me tell you, its a rollercoaster(过山车、剧烈起伏的事物…

【第六讲---非线性优化】

优化与优化库 优化问题 👉优化问题组成 优化对象目标函数/损失函数/评价函数约束条件 👉分类 可以分为凸优化和非凸优化 什么是凸优化呢? 目标函数是凸的(有单一极值点称为是凸的)不等式约束是凸的所在的空间是凸…

容器和云原生(二):Docker容器化技术

目录 Docker容器的使用 Docker容器关键技术 Namespace Cgroups UnionFS Docker容器的使用 首先直观地了解docker如何安装使用,并快速启动mysql服务的,启动时候绑定主机上的3306端口,查找mysql容器的ip,使用mysql -h contain…

【脚踢数据结构】图(纯享版)

(꒪ꇴ꒪ ),Hello我是祐言QAQ我的博客主页:C/C语言,Linux基础,ARM开发板,软件配置等领域博主🌍快上🚘,一起学习,让我们成为一个强大的攻城狮!送给自己和读者的…

Pytest和Unittest测试框架的区别?

如何区分这两者,很简单unittest作为官方的测试框架,在测试方面更加基础,并且可以再次基础上进行二次开发,同时在用法上格式会更加复杂;而pytest框架作为第三方框架,方便的地方就在于使用更加灵活&#xff0…

Python Django 模型概述与应用

今天来为大家介绍 Django 框架的模型部分,模型是真实数据的简单明确的描述,它包含了储存的数据所必要的字段和行为,Django 遵循 DRY Principle 。它的目标是你只需要定义数据模型,然后其它的杂七杂八代码你都不用关心,…

龙迅LT9711 2PORT MIPI或者LVDS转TYPE-C

LT9711 1.描述: Lontium LT9711是双端口MIPI/LVDS到DP1.2转换器,内部有c型替代模式开关和PD控制器。MIPI DSI/CSI输入具有可配置的单端口或双端口,具有1个时钟通道,1个~4个数据通道,最大运行2Gbps/通道,可…

2023 年值得关注的 8 个最佳免费开发者工具

开发者工具对开发人员的重要性不言而喻,保持最新工具的更新可以显著提高你的工作效率并简化您的工作流程。随着技术的快速发展,新的开发工具不断被引入市场。今天,我们将分享 2023 年你值得关注的最新开发者工具。 1.Plaky Plaky 是一种基于…

JVM——JDK 监控和故障处理工具总结

文章目录 JDK 命令行工具jps:查看所有 Java 进程jstat: 监视虚拟机各种运行状态信息 jinfo: 实时地查看和调整虚拟机各项参数jmap:生成堆转储快照**jhat**: 分析 heapdump 文件**jstack** :生成虚拟机当前时刻的线程快照 JDK 可视化分析工具JConsole:Java 监视与管理控制台连接…

司徒理财:8.17黄金反弹遇阻,1900现价空!

黄金趋势下跌,现在反弹遇阻,继续空!除非行情强势站上1907位置,否则还是空头下跌走势,反弹遇阻直接空!      黄金从走势上看,一直阴跌,并且在昨日加速下行!现在黄金受…

ubuntu18.04从0到1在ros上跑yolo5

ubuntu篇---ubuntu20.04安装cuda和cudnn_ubuntu安装cudann_心惠天意的博客-CSDN博客操作系统环境:Ubuntu 20.041. 安装N卡驱动首先我们需要添加源,sudo add-apt-repository ppa:graphics-drivers/ppasudo apt update然后检查可以安装的驱动版本&#xff…