【PyTorch】神经风格迁移项目

news2024/9/21 11:03:31

神经风格迁移中,取一个内容图像和一个风格图像,综合内容图像的内容和风格图像的艺术风格生成新的图像。

 

目录

准备数据

处理数据 

神经风格迁移模型

加载预训练模型 

定义损失函数

定义优化器

运行模型 


准备数据

创建data文件夹,放入一张内容图片(左),一张风格图片(右),分别命名为content和style

from PIL import Image
path2content= "./data/content.jpg"
path2style= "./data/style.jpg"
content_img = Image.open(path2content)
style_img = Image.open(path2style)

 

 

 

处理数据 

调用torchvision.transforms包中Resize、ToTensor和Normalize对图像进行预处理

import torchvision.transforms as transforms

h, w = 256, 384 
mean_rgb = (0.485, 0.456, 0.406)
std_rgb = (0.229, 0.224, 0.225)
transformer = transforms.Compose([
                    # 将图像缩放到指定大小
                    transforms.Resize((h,w)),  
                    # 将图像转换为张量
                    transforms.ToTensor(),
                    # 对图像进行标准化处理
                    transforms.Normalize(mean_rgb, std_rgb)])  

content_tensor = transformer(content_img)
print(content_tensor.shape, content_tensor.requires_grad)

style_tensor = transformer(style_img)
print(style_tensor.shape, style_tensor.requires_grad)

 

# 克隆content_tensor作为输入图像,并设置requires_grad为True,表示需要计算梯度
input_tensor = content_tensor.clone().requires_grad_(True)
print(input_tensor.shape, input_tensor.requires_grad)

import torch
from torchvision.transforms.functional import to_pil_image
# 将图像张量转换为所需PIL图像
def imgtensor2pil(img_tensor):
    # 克隆并分离图像张量
    img_tensor_c = img_tensor.clone().detach()
    # 将图像张量乘以标准RGB值
    img_tensor_c*=torch.tensor(std_rgb).view(3,1,1)
    # 将图像张量加上均值RGB值
    img_tensor_c+=torch.tensor(mean_rgb).view(3,1,1)
    # 将图像张量限制在0到1之间
    img_tensor_c = img_tensor_c.clamp(0,1)
    # 将图像张量转换为PIL图像
    img_pil=to_pil_image(img_tensor_c)
    # 返回PIL图像
    return img_pil

import matplotlib.pylab as plt
%matplotlib inline

plt.imshow(imgtensor2pil(content_tensor))
plt.title("content image");
plt.imshow(imgtensor2pil(style_tensor))
plt.title("style image");

 

 

神经风格迁移模型

 保持模型参数不变,更新模型的输入

加载预训练模型 

import torchvision.models as models
# 检查是否有可用的GPU,如果没有则使用CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 加载预训练的VGG19模型,并将其特征提取部分移动到指定的设备上,并将其设置为评估模式
model_vgg = models.vgg19(pretrained=True).features.to(device).eval()
# 将模型的所有参数设置为不需要梯度,即不进行反向传播
for param in model_vgg.parameters():
    param.requires_grad_(False)   
print(model_vgg)

 

定义损失函数

# 定义函数,获取模型中指定层的特征
def get_features(x, model, layers):
    # 创建一个空字典,用于存储特征
    features = {}
    # 遍历模型的所有子层
    for name, layer in enumerate(model.children()):
        # 将输入数据传入子层,得到输出数据
        x = layer(x)
        # 如果子层的名称在指定的层列表中
        if str(name) in layers:
            # 将输出数据存储到字典中,键为子层的名称
            features[layers[str(name)]] = x
    # 返回字典
    return features

# 定义函数,于计算gram矩阵
def gram_matrix(x):
    # 获取输入张量的维度
    n, c, h, w = x.size()
    # 将输入张量展平
    x = x.view(n*c, h * w)
    # 计算gram矩阵
    gram = torch.mm(x, x.t())
    return gram

import torch.nn.functional as F

# 定义函数,获取内容损失
def get_content_loss(pred_features, target_features, layer):
    # 获取目标特征
    target= target_features[layer]
    # 获取预测特征
    pred = pred_features [layer]
    # 计算均方误差损失
    loss = F.mse_loss(pred, target)
    return loss

# 定义函数,获取风格损失
def get_style_loss(pred_features, target_features, style_layers_dict):  
    # 初始化损失为0
    loss = 0
    # 遍历style_layers_dict中的每一层
    for layer in style_layers_dict:
        # 获取预测特征
        pred_fea = pred_features[layer]
        # 计算预测特征的gram矩阵
        pred_gram = gram_matrix(pred_fea)
        # 获取预测特征的shape
        n, c, h, w = pred_fea.shape
        # 获取目标特征的gram矩阵
        target_gram = gram_matrix (target_features[layer])
        # 计算当前层的损失
        layer_loss = style_layers_dict[layer] *  F.mse_loss(pred_gram, target_gram)
        # 将当前层的损失加到总损失中
        loss += layer_loss/ (n* c * h * w)
    # 返回总损失
    return loss
# 定义特征层字典,用于存储不同层的特征
feature_layers = {'0': 'conv1_1',
                  '5': 'conv2_1',
                  '10': 'conv3_1',
                  '19': 'conv4_1',
                  '21': 'conv4_2',  
                  '28': 'conv5_1'}

# 将内容张量增加一个维度,并将其移动到指定设备上
con_tensor = content_tensor.unsqueeze(0).to(device)

sty_tensor = style_tensor.unsqueeze(0).to(device)

# 获取内容张量的特征
content_features = get_features(con_tensor, model_vgg, feature_layers)

style_features = get_features(sty_tensor, model_vgg, feature_layers)
# 遍历content_features字典中的所有key
for key in content_features.keys():
    # 打印每个key对应的值的形状
    print(content_features[key].shape)

 

定义优化器

from torch import optim

# 克隆con_tensor,并设置requires_grad_为True,表示需要计算梯度
input_tensor = con_tensor.clone().requires_grad_(True)
# 使用Adam优化器,优化input_tensor,学习率为0.01
optimizer = optim.Adam([input_tensor], lr=0.01)

运行模型 

# 定义训练的轮数
num_epochs = 300
# 定义内容损失的权重
content_weight = 1e1
# 定义风格损失的权重
style_weight = 1e4
# 定义内容层
content_layer = "conv5_1"
# 定义风格层及其权重
style_layers_dict = { 'conv1_1': 0.75,
                      'conv2_1': 0.5,
                      'conv3_1': 0.25,
                      'conv4_1': 0.25,
                      'conv5_1': 0.25}

# 遍历每一轮
for epoch in range(num_epochs+1):
    # 梯度清零
    optimizer.zero_grad()
    # 获取输入特征
    input_features = get_features(input_tensor, model_vgg, feature_layers)
    # 获取内容损失
    content_loss = get_content_loss (input_features, content_features, content_layer)
    # 获取风格损失
    style_loss = get_style_loss(input_features, style_features, style_layers_dict)
    # 计算神经损失
    neural_loss = content_weight * content_loss + style_weight * style_loss
    # 反向传播
    neural_loss.backward(retain_graph=True)
    # 更新参数
    optimizer.step()
    
    # 每隔100轮打印一次损失
    if epoch % 100 == 0:
        print('epoch {}, content loss: {:.2}, style loss {:.2}'.format(
          epoch,content_loss, style_loss))

打印输出图片(左),对比原始内容图片(右)

plt.imshow(imgtensor2pil(input_tensor[0].cpu()));

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1977406.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

人工智能与大数据的融合:驱动未来的力量

人工智能与大数据的融合:驱动未来的力量 一、人工智能与大数据的概述二、人工智能与大数据在数据库中的融合三、实际应用案例四、未来发展方向总结 【纪录片】中国数据库前世今生 在数字化潮流席卷全球的今天,数据库作为IT技术领域的“活化石”&#xff…

16进制转换-系统架构师(三十九)

1、(软件架构设计->构件与中间件技术->构件标准)对象管理组织(OMG)基于CORBA基础设施定义了四种构件标准。其中,()状态信息是构件自身而不是由容器维护的。 A实体构件 B加工构件 C服务…

C++中lambda使用mutable关键字详解

C中lambda使用mutable关键字详解 在《C初学者指南-5.标准库(第二部分)–更改元素算法》中&#xff0c;讲“generate”算法时有下面这段代码&#xff1a; auto gen [i0]() mutable { i 2; return i; }; std::vector<int> v; v.resize(7,0); generate(begin(v)1, begin…

(南京观海微电子)——LCD OTP(烧录)介绍

OTP OTP只是一种存储数据的器件&#xff0c;全写:ONETIMEPROGRAM。 OTP目的&#xff1a;提高产品的一致性 客户端的接口不支持和我们自己的产品IC之间通信&#xff0c;即不支持写初始化&#xff0c;所以产品的电学功能以及光学特性需要固化在IC中&#xff0c;所以需要我们来进行…

青甘环线游记|day(1)|兰州

出发 下午1点&#xff0c;登机。航班经停万州&#xff0c;再到兰州。下图为飞机上拍的照片&#xff0c;不知道为什么窗户上有结晶的东西&#xff08;&#xff1f;&#xff09; 在飞机上拍的航线图&#xff0c;但是有点模糊。飞机上有提供午餐。4点左右到达万州 在飞机上好像…

08 Redis Set类型操作与使用场景

Redis Set类型操作与使用场景 一、Set类型操作 ​ Redis的Set结构与Java中的HashSet类似&#xff0c;可以看做是一个value为null的HashMap。因为也是一个hash表&#xff0c;因此具备与HashSet类似的特征&#xff1a; ​ 无序 ​ 元素不可重复 ​ 查找快 ​ 支持交集、并集…

Tomcat 8.5 下载、安装、启动及各种问题

&#x1f970;&#x1f970;&#x1f970;来都来了&#xff0c;不妨点个关注叭&#xff01; &#x1f449;博客主页&#xff1a;欢迎各位大佬!&#x1f448; 本期内容主要介绍 Tomcat 8 的安装&#xff0c;以及可能会遇到的问题 文章目录 1. Tomcat 安装2. 可能会遇到的问题2.…

pip‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件。

重新设置一下环境变量。 注意&#xff0c;这里后面没有斜杠 我之前就是因为环境变量中&#xff0c;这两行最后都有斜杠&#xff0c;导致提示pip‘ 不是内部或外部命令,也不是可运行的程序 或批处理文件。

Multi-AP

1. Multiple-BSSID 和Multi-VAP Multiple-BSSID 和Multi-VAP差异&#xff1a; Multi-VAP&#xff1a; 每个AP独自发送beacon帧&#xff1b; Multiple-BSSID&#xff1a; 所有AP公用一个beacon帧。 1.1 Multi-VAP 如果您使用过 Wi-Fi &#xff08;2.4/5.0GHz&#xff09;&am…

著名ROM修改社区停止运营 管理员与继任者互相指责

运营近二十年的知名ROM修改社区网站Romhacking.net即将关闭新内容的提交和更新。网站创始人Nightcrawler表示&#xff0c;网站已经“几乎完成了最初设定的所有目标&#xff0c;并且远远超出了预期。”然而&#xff0c;根据其他网站工作人员的说法&#xff0c;事情似乎没那么简单…

C++ 重要特性探究

shared_from_this 使用分析 场景 类的成员函数需要获取指向自身的shared_ptr的时候类成员函数传递shared_ptr给其他函数或者对象的时候&#xff0c;目的是为了管理对象生命周期使用方法 首先类必须继承 std::enable_shared_from_this<T>必须使用 shared_from_this 获取指…

Arduino PID库 (2) –微分导致的过冲

Arduino PID库 &#xff08;2&#xff09; – Derivative Kick 参考&#xff1a;手把手教你看懂并理解Arduino PID控制库——微分冲击 pid内容索引-CSDN博客 Arduino PID库 &#xff08;1&#xff09;– 简介 问题 此修改将稍微调整derivative term。目标是消除一种称为“…

RocketMQ消息汇总

当物理文件删除了 队列中的下标的消息也被删除了 但是即使物理删除了 队列中的偏移量还是会持续上升每天凌晨4点 定时清理 在 RocketMQ 中&#xff0c;消息的物理删除是通过定期清理 CommitLog 文件来实现的。CommitLog 文件中存储的是所有主题和队列的消息&#xff0c;一旦这…

关于图片导入Eagle弹出“抱歉,eagle发生了一些问题”的解决办法 | 如何查看Eagle调试报告查询错误文件方法

教程不易&#xff0c;希望得到关注 先说解决办法 使用格式工厂将所有图片或报错图片文件再次转为JPG文件&#xff0c;即可正常导入。 官网入口 http://www.pcgeshi.com/ 吐槽一下现在搜索软件搜“格式工厂官网”第一页全是盗版软件和流氓网页&#xff0c;什么什么金X 风X格式…

使用 Streamlit 和 Python 构建 Web 应用程序

一.介绍 在本文中&#xff0c;我们将探讨如何使用 Streamlit 构建一个简单的 Web 应用程序。Streamlit 是一个功能强大的 Python 库&#xff0c;允许开发人员快速轻松地创建交互式 Web 应用程序。Streamlit 旨在让 Python 开发人员尽可能轻松地创建 Web 应用程序。以下是一些主…

TCP/UDP Socket 测试小工具,作为网工不可以不知道

背景 阿祥今天推荐一款TCP/UDP Socket 测试工具&#xff0c;所谓TCP/IP调试工具是用于在TCP/UDP的应用层上进行通信连接、数据传输的Windows工具。所谓应用层上就是说&#xff0c;TCP调试工具是不涉及TCP/IP协议层实现的问题&#xff0c;而只是利用TCP/IP进行数据传输的工具。 …

建模杂谈系列246 数据模型

说明 如果说微服务化(API接口、Web页面、Docker镜像)是架构方面的基准&#xff0c;那么数据模型就是逻辑处理方面的基准 内容 以下是一个样例&#xff1a; import redef extract_utf8_chars(input_string None):# 定义一个正则表达式&#xff0c;用于匹配所有的UTF-8字符utf…

OpenStack Yoga版安装笔记(十一)nova安装(上)

1、官方文档 OpenStack Installation Guidehttps://docs.openstack.org/install-guide/ 本次安装是在Ubuntu 22.04上进行&#xff0c;基本按照OpenStack Installation Guide顺序执行&#xff0c;主要内容包括&#xff1a; 环境安装 &#xff08;已完成&#xff09;OpenStack…

一文详解大模型蒸馏工具TextBrewer

原文&#xff1a;https://zhuanlan.zhihu.com/p/648674584 本文分享自华为云社区《TextBrewer&#xff1a;融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度&#xff0c;减少内存占用》&#xff0c;作者&#xff1a;汀丶。 TextBre…

谷粒商城实战笔记-122~124-全文检索-ElasticSearch-分词

文章目录 一&#xff0c;122-全文检索-ElasticSearch-分词-分词&安装ik分词二&#xff0c;124-全文检索-ElasticSearch-分词-自定义扩展词库1&#xff0c;创建nginx容器1.1 创建nginx文件夹1.2 创建nginx容器获取nginx配置1.3 复制nginx容器配置文件1.4 删除临时的nginx容器…