Pytorch Advanced(三) Neural Style Transfer

news2024/11/16 18:28:10

神经风格迁移在之前的博客中已经用keras实现过了,比较复杂,keras版本。

这里用pytorch重新实现一次,原理图如下:


from __future__ import division
from torchvision import models
from torchvision import transforms
from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

加载图像

def load_image(image_path, transform=None, max_size=None, shape=None):
    """Load an image and convert it to a torch tensor."""
    image = Image.open(image_path)
    
    if max_size:
        scale = max_size / max(image.size)
        size = np.array(image.size) * scale
        image = image.resize(size.astype(int), Image.ANTIALIAS)
    
    if shape:
        image = image.resize(shape, Image.LANCZOS)
    
    if transform:
        image = transform(image).unsqueeze(0)
    
    return image.to(device)

这里用的模型是 VGG-19,所要用的是网络中的5个卷积层

class VGGNet(nn.Module):
    def __init__(self):
        """Select conv1_1 ~ conv5_1 activation maps."""
        super(VGGNet, self).__init__()
        self.select = ['0', '5', '10', '19', '28'] 
        self.vgg = models.vgg19(pretrained=True).features
        
    def forward(self, x):
        """Extract multiple convolutional feature maps."""
        features = []
        for name, layer in self.vgg._modules.items():
            x = layer(x)
            if name in self.select:
                features.append(x)
        return features

 模型结构如下,可以看到使用序列模型来写的VGG-NET,所以标号即层号,我们要保存的是['0', '5', '10', '19', '28'] 的输出结果。

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU(inplace)
    (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace)
    (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (24): ReLU(inplace)
    (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (26): ReLU(inplace)
    (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace)
    (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (31): ReLU(inplace)
    (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (33): ReLU(inplace)
    (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (35): ReLU(inplace)
    (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace)
    (2): Dropout(p=0.5)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace)
    (5): Dropout(p=0.5)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)

 训练:

接下来对训练过程进行解释:

1、加载风格图像和内容图像,我们在之前的博客中使用的一幅加噪图进行训练,这里是用的内容图像的拷贝。

2、我们需要优化的就是作为目标的内容图像拷贝,可以看到target需要求导。

3、VGGnet参数是不需要优化的,所以设置为验证状态。

4、将3幅图像输入网络,得到总共15个输出(每个图像有5层的输出)

5、内容损失:这里是遍历5个层的输出来计算损失,而在keras版本中只用了第4层的输出计算损失

6、风格损失:同样计算格拉姆风格矩阵,将每一层的风格损失叠加,得到总的风格损失,计算公式同样和keras版本有所不一样

7、反向传播

def main(config):
    
    # Image preprocessing
    # VGGNet was trained on ImageNet where images are normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
    # We use the same normalization statistics here.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), 
                             std=(0.229, 0.224, 0.225))])
    
    # Load content and style images
    # Make the style image same size as the content image
    content = load_image(config.content, transform, max_size=config.max_size)
    style = load_image(config.style, transform, shape=[content.size(2), content.size(3)])
    
    # Initialize a target image with the content image
    target = content.clone().requires_grad_(True)
    
    optimizer = torch.optim.Adam([target], lr=config.lr, betas=[0.5, 0.999])
    vgg = VGGNet().to(device).eval()
    
    for step in range(config.total_step):
        
        # Extract multiple(5) conv feature vectors
        target_features = vgg(target)
        content_features = vgg(content)
        style_features = vgg(style)

        style_loss = 0
        content_loss = 0
        for f1, f2, f3 in zip(target_features, content_features, style_features):
            # Compute content loss with target and content images
            content_loss += torch.mean((f1 - f2)**2)

            # Reshape convolutional feature maps
            _, c, h, w = f1.size()
            f1 = f1.view(c, h * w)
            f3 = f3.view(c, h * w)

            # Compute gram matrix
            f1 = torch.mm(f1, f1.t())
            f3 = torch.mm(f3, f3.t())

            # Compute style loss with target and style images
            style_loss += torch.mean((f1 - f3)**2) / (c * h * w) 
        
        # Compute total loss, backprop and optimize
        loss = content_loss + config.style_weight * style_loss 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (step+1) % config.log_step == 0:
            print ('Step [{}/{}], Content Loss: {:.4f}, Style Loss: {:.4f}' 
                   .format(step+1, config.total_step, content_loss.item(), style_loss.item()))

        if (step+1) % config.sample_step == 0:
            # Save the generated image
            denorm = transforms.Normalize((-2.12, -2.04, -1.80), (4.37, 4.46, 4.44))
            img = target.clone().squeeze()
            img = denorm(img).clamp_(0, 1)
            torchvision.utils.save_image(img, 'output-{}.png'.format(step+1))

写在if __name__=="__main__"后面的语句只会在本脚本中才能被执行,被调用时是不会被执行的。 

python的命令行工具:argparse,很优雅的添加参数

但是由于jupyter不支持添加外部参数,所以使用了外部博客的方法来支持(记住更改读取图片的位置)

import sys
if __name__ == "__main__":
    
    #解决方案来自于博客
    if '-f' in sys.argv:
        sys.argv.remove('-f')
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--content', type=str, default='png/content.png')
    parser.add_argument('--style', type=str, default='png/style.png')
    parser.add_argument('--max_size', type=int, default=400)
    parser.add_argument('--total_step', type=int, default=2000)
    parser.add_argument('--log_step', type=int, default=10)
    parser.add_argument('--sample_step', type=int, default=500)
    parser.add_argument('--style_weight', type=float, default=100)
    parser.add_argument('--lr', type=float, default=0.003)
    #config = parser.parse_args()
    config = parser.parse_known_args()[0]   #参考博客 https://blog.csdn.net/ken_for_learning/article/details/89675904
    print(config)
    main(config)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1002354.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

金蝶云星空和四化智造MES(WEB)单据接口对接

金蝶云星空和四化智造MES(WEB)单据接口对接 接入系统:四化智造MES(WEB) MES建立统一平台上通过物料防错防错、流程防错、生产统计、异常处理、信息采集和全流程追溯等精益生产和精细化管理,帮助企业合理安排…

Linux中安装MySQL_图解_2023新

1.卸载 为了避免不必要的错误发生,先将原有的文件包进行查询并卸载 // 查询 rpm -qa | grep mysql rpm -qa | grep mari// 卸载 rpm -e 文件名 --nodeps2.将安装包上传到指定文件夹中 这里采用的是Xftp 3.将安装包进行解压 tar -zxvf 文件名 -C 解压路径4.获取解压的全路…

春秋云镜 CVE-2015-9331

春秋云镜 CVE-2015-9331 wordpress插件 WordPress WP All Import plugin v3.2.3 任意文件上传 靶标介绍 wordpress插件 WordPress WP All Import plugin v3.2.3 存在任意文件上传,可以上传shell。 启动场景 漏洞利用 exp #/usr/local/bin/python3 # -*-coding:…

基础设施SIG月度动态:「龙蜥大讲堂」基础设施系列专题分享完美收官,容器镜像构建 2.0 版本上线

基础设施 SIG(OpenAnolis Infra SIG)目标:负责 OpenAnolis 社区基础设施工程平台的建设,包括官网、Bugzilla、Maillist、ABS、ANAS、CI 门禁以及社区 DevOps 相关的研发工程系统。 01 SIG 整体进展 1. 龙蜥大讲堂 - 基础设施系…

mac 本地运行 http-proxy-middleware ,请求超时

const http require(http)"/customer": {target: "http://10.10.111.192:8080/",// target: "http://user.jinfu.baohan.com/",changeOrigin: true, // 是否启用跨域// 解决mac 代理超时问题headers: {Connection: "keep-alive"},// …

机器学习(10)---特征选择

文章目录 一、概述二、Filter过滤法2.1 过滤法说明2.2 方差过滤2.3 方差过滤对模型影响 三、相关性过滤3.1 卡方过滤3.2 F检验3.3 互信息法3.4 过滤法总结 四、Embedded嵌入法4.1 嵌入法说明4.2 以随机森林为例的嵌入法 五、Wrapper包装法5.1 包装法说明5.2 以随机森林为例的包…

事件处理机制

前面介绍了如何放置各种组件,从而得到了丰富多彩的图形界面,但这些界面还不能响应用户的任何操作。比如单击前面所有窗口右上角的“X”按钮,但窗口依然不会关闭。因为在 AWT 编程中 ,所有用户的操作,都必须都需要经过一…

025-从零搭建微服务-文件服务(一)

写在最前 如果这个项目让你有所收获,记得 Star 关注哦,这对我是非常不错的鼓励与支持。 源码地址(后端):https://gitee.com/csps/mingyue 源码地址(前端):https://gitee.com/csps…

thinkphp5.0 composer 安装oss提示php版本异常

场景复现: 本地 phpstudy 环境,安装的有7.0到7.3三个版本,首先确认composer已经安装 composer安装阿里云oss的命令为:composer require aliyuncs/oss-sdk-php 运行报错: Problem 1- Root composer.json requires php…

电机故障数据集

1.电机常见的故障类型有以下几种: 轴承故障:轴承是电机运转时最容易受损的部件之一。常见故障包括磨损、疲劳、过热和润滑不良,这些问题可能导致噪音增加和电机性能下降。 绝缘老化:电机绝缘材料随着使用时间的增加会老化&#x…

微服务·数据一致-seata

微服务数据一致-seata 概述 Seata(Simple Extensible Autonomous Transaction Architecture)是一个开源的分布式事务解决方案,旨在帮助应用程序分布式事务管理的挑战。Seata提供了一套全面的工具和框架,可用于实现跨多个数据库和…

Nginx+Tomcat(多实例)实现动静分离和负载均衡

一、Tomcat 多实例部署 1.在安装好jdk环境后,添加两例tomcat服务 #解压安装包 cd /opt tar zxvf apache-tomcat-9.0.16.tar.gz#移动并复制一例 mkdir /usr/local/tomcat mv apache-tomcat-9.0.16 /usr/local/tomcat/tomcat1 cp -a /usr/local/tomcat/tomcat1 /usr…

常用JVM配置参数

在IDE的后台打印GC日志: 既然学习JVM,阅读GC日志是处理Java虚拟机内存问题的基础技能,它只是一些人为确定的规则,没有太多技术含量。 既然如此,那么在IDE的控制台打印GC日志是必不可少的了。现在就告诉你怎么打印。 …

Django03_Django基本配置

Django03_Django基本配置 3.1 整体概述 django项目创建后,在主应用中,会有一个settings.py文件,这个就是该项目的配置文件 settings文件包含Django安装的所有配置settings文件是一个包含模块级变量的python模块,所以该模块本身必…

解决nbsp;不生效的问题

代码块 {{title}} title:附 \xa0\xa0\xa0件,//或者 <span v-html"title"></span> title:附 件&#xff1a;,效果图

青骨申报|CSC管理信息平台使用指南

2023年青年骨干教师出国研修项目于9月10-25日网上报名&#xff0c;为此知识人网小编特转载最新版本的国家留学基金委&#xff08;CSC&#xff09;国家公派留学管理信息平台使用指南&#xff08;国内申请访学类&#xff09;&#xff0c;以方便申报者查阅。 提示&#xff1a;国家…

静态代理和动态代理笔记

总体分为: 1.静态代理: 代理类和被代理类需要实现同一个接口.在代理类中初始化被代理类对象.在代理类的方法中调 用被代理类的方法.可以选择性的在该方法执行前后增加功能或者控制访问 2.动态代理: 在程序执行过程中,实用JDK的反射机制,创建代理对象,并动态的指定要…

数字档案管理系统单机版功能

nhdeep数字档案管理系统&#xff0c;简化了档案库配置过程&#xff0c;内置标准著录项&#xff0c;点击创建新档案库后选择档案库类型为案卷库或一文一件库后&#xff0c;可立即使用此档案库&#xff1b; 支持添加额外的自定义著录项&#xff0c;支持批量数据导入&#xff0c;…

读高性能MySQL(第4版)笔记06_优化数据类型(上)

1. 良好的逻辑设计和物理设计是高性能的基石 1.1. 反范式的schema可以加速某些类型的查询&#xff0c;但同时可能减慢其他类型的查询 1.2. 添加计数器和汇总表是一个优化查询的好方法&#xff0c;但它们的维护成本可能很 1.3. 将修改schema作为一个常见事件来规划 2. 让事情…

JVM GC垃圾回收

一、GC垃圾回收算法 标记-清除算法 算法分为“标记”和“清除”阶段&#xff1a;标记存活的对象&#xff0c; 统一回收所有未被标记的对象(一般选择这种)&#xff1b;也可以反过来&#xff0c;标记出所有需要回收的对象&#xff0c;在标记完成后统一回收所有被标记的对象 。它…