mindspore框架下Pix2Pix模型实现真实图到线稿图的转换|(三)Pix2Pix模型训练与模型推理

news2024/9/24 19:19:01

mindspore框架下Pix2Pix模型实现真实图到线稿图的转换

  1. mindspore框架下Pix2Pix模型实现真实图到线稿图的转换|(一)dataset_pix2pix数据集准备
  2. mindspore框架下Pix2Pix模型实现真实图到线稿图的转换|(二)Pix2Pix模型构建
  3. mindspore框架下Pix2Pix模型实现真实图到线稿图的转换|(三)Pix2Pix模型训练与模型推理
  4. mindspore框架下Pix2Pix模型实现真实图到线稿图的转换|(四)模型应用实践

Pix2Pix模型训练

训练分为两个主要部分:

  1. 训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。
  2. 训练生成器。训练生成器是希望能产生更好的虚假图像。
    在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计。
import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensor

epoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100

def get_lr():
    lrs = [lr] * dataset_size * n_epochs
    lr_epoch = 0
    for epoch in range(n_epochs_decay):
        lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decay
        lrs += [lr_epoch] * dataset_size
    lrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)
    return Tensor(np.array(lrs).astype(np.float32))

dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()

def forword_dis(reala, realb):
    lambda_dis = 0.5
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    pred1 = net_discriminator(reala, realb)
    loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))
    loss_dis = loss_d * lambda_dis
    return loss_dis

def forword_gan(reala, realb):
    lambda_gan = 0.5
    lambda_l1 = 100
    fakeb = net_generator(reala)
    pred0 = net_discriminator(reala, fakeb)
    loss_1 = loss_f(pred0, ops.ones_like(pred0))
    loss_2 = l1_loss(fakeb, realb)
    loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1
    return loss_gan

d_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),
                beta1=0.5, beta2=0.999, loss_scale=1)

grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())

def train_step(reala, realb):
    loss_dis, d_grads = grad_d(reala, realb)
    loss_gan, g_grads = grad_g(reala, realb)
    d_opt(d_grads)
    g_opt(g_grads)
    return loss_dis, loss_gan

if not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)

for epoch in range(epoch_num):
    for i, data in enumerate(data_loader):
        start_time = datetime.datetime.now()
        input_image = Tensor(data["input_images"])
        target_image = Tensor(data["target_images"])
        dis_loss, gen_loss = train_step(input_image, target_image)
        end_time = datetime.datetime.now()
        delta = (end_time - start_time).microseconds
        if i % 2 == 0:
            print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))
        d_losses.append(dis_loss.asnumpy())
        g_losses.append(gen_loss.asnumpy())
    if (epoch + 1) == epoch_num:
        mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

Pix2Pix模型加载与推理

  1. 加载训练过程完成后的ckpt文件;
  2. 通过load_checkpoint和load_param_into_net将ckpt中的权重参数导入到模型中;
  3. 获取数据进行推理并对推理的效果图进行演示。
from mindspore import load_checkpoint, load_param_into_net

param_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):
    plt.subplot(2, 10, i + 1)
    plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
    plt.subplot(2, 10, i + 11)
    plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis("off")
    plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1971489.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Google Gemma2 2B:语言模型的“小时代”到来?

北京时间8月1日凌晨(当地时间7月31日下午),Google发布了其Gemma系列开源语言模型的更新,在AI领域引发了巨大的震动。Google Developer的官方博客宣布,与6月发布的27B和9B参数版本相比,新的2B参数模型在保持…

python实现consul的服务注册与注销

我在使用consul的时候主要用于prometheus的consul服务发现,把数据库、虚拟机信息发布到consul,prometheus通过consul拿到数据库、虚拟机信息去采集指标信息。 此篇文章前提是已经安装好consul服务以后,安装consul请参考二进制方式部署consul…

Nat网络地址转换实验

一、实验拓扑 二、实验要求 三、实验思路 四、实验展示 1.接口IP配置 telnet路由器 r1 r2 r3 pc2 2.全网可达(给边界路由器,私家路由器写上缺省 ,还要用到nat地址转换,多对多一对多,端口映射)因为左右…

第22集《大佛顶首楞严经》

请大家打开讲义第四十九页,“寅三、大众茫然”。 我们要是读《金刚经》,就知道整个修学的方向。《金刚经》就是讲到,一个菩萨发了菩提心,心中有目标,要能够上求佛道,下化众生,但是他不知道“云…

探索味蕾新境界:嘴尚绝卤味,一口难忘的美味传奇

在美食的浩瀚星空中,总有一些味道能够穿越时光的长河,直击人心最柔软的部分,让人回味无穷。今天,就让我们一起走进“嘴尚绝”卤味的世界,感受那份独特而令人难以忘怀的口感之美。 一、卤味之魂,匠心独运 “…

CTF web bibibi题型

CTF web bibibi题型 1.进入网站 在kali中使用Dirsearch对地址进行目录扫描,发现robots.txt 网址内加入 /robots.txt 进入网址 /fl4gi5Here.php 找到flag

未来五年,网络安全有没有发展前途,零基础转行难不难?

在被新冠疫情常态化影响的今天,职场当中呈现出了严重的两极分化现象,具体的表现形式为: 一些人薪资翻倍、愈加繁忙,另一些人则加入了失业大军、不知所措;一些行业实现了井喷式增长,一些行业却不断裁员、随…

Apache解析漏洞

一、apache_parsing 在Apache1.x/2.x中Apache 解析文件的规则是从右到左开始判断解析,如果后缀名为不可识别文件解析,就再往左判断。如1.php.xxxxx 1、进入Vulhub靶场并执行以下命令启动靶场 2、只要一个文件含有.php后缀的文件即将被识别成PHP文件&am…

即时通讯和即时通信,即时通讯和实时通信

在当今数字化时代,即时通讯和实时通信已成为人们日常生活和工作中不可或缺的一部分。尽管这两个概念经常被混淆使用,但它们在本质和应用上存在一些区别和联系。同时,企业级即时通讯平台WorkPlus对于提升企业内部沟通和协作也有着重要的作用。…

Java面试八股之简述spring boot的目录结构

简述spring boot的目录结构 Spring Boot 项目遵循标准的 Maven 或 Gradle 项目布局,并且有一些约定的目录用于组织不同的项目组件。下面是一个典型的 Spring Boot 项目目录结构: src/main/java:包含所有的 Java 源代码,通常按包组…

8个高质量PPT模板网站,免费下载

演示文稿已经成为交流和展示想法的重要工具。而一个引人注目、内容精彩的PPT演示,不仅可以让观众留下深刻的印象,还能有效地传达信息和观点。分享八个备受推崇的高质量PPT模板网站,这些网站提供各种各样的模板,涵盖了不同主题、风…

史上最快在IDEA中创建类,只需要 ctrl + 鼠标左键双击 就可以调出创建类的窗口(全网首创)

文章目录 1、正常创建类的步骤2、改进方案 1、正常创建类的步骤 需要首先鼠标右键一次,点击新建,再点击 Java类,过于麻烦 2、改进方案 我们只需要自定义设置创建类的快捷键即可 找到设置、按键映射、主菜单、文件、文件打开操作、打开项目…

有哪些因素会影响谷歌ASO优化效果呢

目前在Google Play上,已超过5.3亿的移动应用。未来还会有更多的移动应用涌入。开发者都希望自己的应用,最具有竞争力,并且可以获得大量免费流量。ASO是Google Play最重要的策略之一,而影响谷歌ASO优化效果的因素有很多&#xff0c…

欧拉系统如果数据库忘记密码的解决办法

如果数据库忘记密码,该怎么办 systemctl stop mariadb #先关闭数据库 mysqld_safe --skip-grant-tables& #跳过权限表的检查 mysql #现在可以不通过密码就能进入mysql了 flush privileges; #刷新权限 alter user rootlocalhost ide…

【IEEE出版 | 连续五届稳定EI检索】第六届机器学习、大数据与商务智能国际会议(MLBDBI 2024)

IEEE出版 | MLBDBI 2023会后4个半月内完成EI检索 第六届机器学习、大数据与商务智能国际会议(MLBDBI 2024) 2024 6th International Conference on Machine Learning, Big Data and Business Intelligence 重要信息 大会官网: 会议时间&a…

二叉树链式结构的实现(递归的暴力美学!!)

前言 Hello,小伙伴们。你们的作者菌又回来了,前些时间我们刚学习完二叉树的顺序结构,今天我们就趁热打铁,继续我们二叉树链式结构的学习。我们上期有提到,二叉树的的底层结构可以选为数组和链表,顺序结构我们选用的数…

大数据Flink(一百零六):什么是阿里云实时计算Flink版

文章目录 什么是阿里云实时计算Flink版 一、产品概述 二、产品架构 三、产品优势 什么是阿里云实时计算Flink版 阿里云实时计算Flink版是一套基于Apache Flink构建的⼀站式实时大数据分析平台,提供端到端亚秒级实时数据分析能力,并通过标准SQL降低业…

openEuler 自定义ISO制作(logo,名称,ISO)

前言 oecustom (openEuler customize) 是一套关于 openEuler iso 格式光盘映像的定制工具集。 工具用途iso_custom用于定制 openEuler iso 镜像,可以定制 openEuler iso 镜像的系统名称和安装界面图标等iso_cut用于裁剪 openEuler iso 镜像,参考 oemak…

暴食之史莱姆(河南萌新2024)

思路&#xff1a;单调栈&#xff08;分别统计左边小于等于当前大小的数量&#xff09; #include <bits/stdc.h>using namespace std; typedef long long ll; typedef double db; typedef long double ldb; typedef pair<int, int> pii; typedef pair<ll, ll>…

【超强论文干货】教大家一个水论文最快的方法,一天能找20个创新点!本科生_研究生_博士生一定要收藏!

前言 都这个时候了&#xff0c;别告诉我你还没找到论文创新点。哈哈哈&#xff0c;如果你真的卡在这里了&#xff0c;那么这个文章绝对是你的救星&#xff0c;今天我就给大家分享找论文创新点这一块的方法&#xff0c;让你一天之内至少找到十几个。这个方法不仅适合那些真心想…