《昇思25天学习打卡营第24天》

news2024/9/19 10:52:57

接续上一天的学习任务,我们要继续进行下一步的操作

构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

接下来了解一下其他内容

生成器

生成器G的功能是将隐向量z映射到数据空间。实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN论文生成图像如下所示:

通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

代码实现

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal

weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)

class Generator(nn.Cell):
    """DCGAN网络生成器"""

    def __init__(self):
        super(Generator, self).__init__()
        self.generator = nn.SequentialCell(
            nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf, gamma_init=gamma_init),
            nn.ReLU(),
            nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.Tanh()
            )

    def construct(self, x):
        return self.generator(x)

generator = Generator()

判别器

判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。

代码实现

class Discriminator(nn.Cell):
    """DCGAN网络判别器"""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.discriminator = nn.SequentialCell(
            nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),
            nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),
            nn.LeakyReLU(0.2),
            nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),
            )
        self.adv_layer = nn.Sigmoid()

    def construct(self, x):
        out = self.discriminator(x)
        out = out.reshape(out.shape[0], -1)
        return self.adv_layer(out)

discriminator = Discriminator()

接下来进入模型训练阶段

模型训练

其中分为几个要素:

损失函数

当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss。

优化器

训练模型:训练判别器和训练生成器。

实现模型训练正向逻辑:

def generator_forward(real_imgs, valid):
    # 将噪声采样为发生器的输入
    z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))

    # 生成一批图像
    gen_imgs = generator(z)

    # 损失衡量发生器绕过判别器的能力
    g_loss = adversarial_loss(discriminator(gen_imgs), valid)

    return g_loss, gen_imgs

def discriminator_forward(real_imgs, gen_imgs, valid, fake):
    # 衡量鉴别器从生成的样本中对真实样本进行分类的能力
    real_loss = adversarial_loss(discriminator(real_imgs), valid)
    fake_loss = adversarial_loss(discriminator(gen_imgs), fake)
    d_loss = (real_loss + fake_loss) / 2
    return d_loss

grad_generator_fn = ms.value_and_grad(generator_forward, None,
                                      optimizer_G.parameters,
                                      has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,
                                          optimizer_D.parameters)

@ms.jit
def train_step(imgs):
    valid = ops.ones((imgs.shape[0], 1), mindspore.float32)
    fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)

    (g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)
    optimizer_G(g_grads)
    d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)
    optimizer_D(d_grads)

    return g_loss, d_loss, gen_imgs

代码训练

结果展示就不多说了看成品

文末附上打卡时间

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1962663.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

程序一调用这个接口就会崩溃, 因为他的静态库添加是放在release文件下,而我用的debug模式

程序一调用这个接口就会崩溃 因为他的静态库添加是放在release文件下 而我用的debug模式 DESTDIR ../x64/ReleaseINCLUDEPATH ./../3rdparty/ZZDecode/include LIBS -lopengl32 \-lglu32 \-luser32 \./../3rdparty/ZZDecode/x64/release/ZZDecodeInterface.lib

华为仓颉语言测试申请

1. 申请网址 HarmonyOS NEXT仓颉语言开发者预览版 Beta招募- 华为开发者联盟 点击立即报名登录华为账号 勾选选项 , 点击同意 按要求填写信息即可 2. 申请通过后官方会通过邮件的方式发送相关下载途径 , 根据文档进行下载即可 package Cangmain(): Int64 {println("你…

Springboot使用Redis实现分布式锁

1、使用场景和实现方案: 使用场景:本地锁如Lock和Syncronized只能锁住本地进程,在分布式应用中,需要使用分布式锁来更好实现特定的业务。 实现方案:有多种,比如使用mysql、zookeeper、redis,各…

nginx续1:

八、虚拟主机配置 基于域名的虚拟主机 [rootserver2 ~]# ps -au|grep nginx //查看进程 修改Nginx服务配置,添加相关虚拟主机配置如下 1. [rootproxy ~]# vim /usr/local/nginx/conf/nginx.conf 2. .. .. 3. server { 4. listen …

基于Material studio拉伸-断裂过程的Perl脚本

在材料科学的研究中,拉伸-断裂过程一直是科学家们探索的焦点。这一过程涉及复杂的力学行为和材料内部微观结构的变化,对于理解材料的性能至关重要。然而,传统的实验方法不仅耗时耗力,而且难以捕捉到微观尺度上的所有细节。 为了满…

jQuery前端网页制作

1、Jquery的概述 1.1JavaScript库 JavaScript 高级程序设计(特别是对浏览器差异的复杂处理),通常很困难也很耗时。 为了应对这些调整,许多的 JavaScript (helper) 库应运而生。 这些 JavaScript 库常被称为 JavaScript 框架。 市面上一些广受欢迎的 JavaScript 框架:…

visual studio性能探测器使用案列

visual studio性能探测器使用案列 在visual studio中,我们可以使用自带的工具对项目进行性能探测,具体如下 1.选择性能探查器 Vs2022/Vs2019中打开方式: Vs2017打开方式: 注意最好将解决方案配置为:Release Debu…

【未来餐饮】 配送设置

一、创建门店 关键信息 1. 门店名字要有辨识度,尽量不和其他客户重名 2. 地址要具体到门牌号 3. 定位要和上面的地址一致 可以复制地址搜索地图,然后选择位置 二、创建配送模板 新建模板 填写模板 命名模板,勾上真省钱,然后保…

Meta再下一城:SAM 2

--->更多内容&#xff0c;请移步“鲁班秘笈”&#xff01;&#xff01;<--- “继用于图像的Meta Segment Anything Model &#xff08;SAM&#xff09;取得成功之后&#xff0c;我们发布了SAM 2&#xff0c;这是一种用于在图像和视频中实时进行对象分割的统一模型&#…

npm创建vue的ts项目

一、进入项目文件夹 使用cmd进入你想要创建项目的文件夹&#xff0c;此处为 E盘的test文件夹 cd E:\testE:二、创建项目 此处项目名为 MyTestProject npm create vitelatest输入上述代码&#xff0c;回车后会出现灰色的虚拟名称&#xff0c;此处输入你自己的名称即可&#…

软件平台化开发项目实践

汉捷咨询有40多位来自多家著名企业&#xff08;华为、中兴、三星等&#xff09;的咨询顾问和讲师&#xff0c;资深顾问/项目经理均有华为、中兴等领先企业高管及咨询实践15年以上经验&#xff0c;本文为汉捷一IPD资深顾问的行业实践总结&#xff0c;与各位同仁分享&#xff01;…

WPF用户登录界面设计-使用SQLite数据库进行存储

一、SQLite数据库介绍 SQLite是一款轻量级的关系型数据库&#xff0c;它小巧高效&#xff0c;无需服务器配置&#xff0c;仅需单一文件即可存储数据。SQLite跨平台支持&#xff0c;易于集成到各种应用程序中&#xff0c;并支持SQL语言进行数据操作。它保证了数据的完整性、一致…

Java数据结构和算法中文版(第2版)详细教程

前言 数据结构是指数据在计算机存储空间中(或磁盘中)的安排方式。算法是指软件程序用来操作这些结构中的数据的过程。几乎所有的计算机程序都使用数据结构和算法&#xff0c;即使最简单的程序也不例外。比如设想一个打印地址标签的程序&#xff0c;这个程序使用一个数组来存储…

整理几个常用的Linux命令(Centos发行版)

如果工作中需要经常整理一些文档&#xff0c;需要汇总一下&#xff0c;现有的服务器资源信息&#xff0c;那么这篇文章适合你&#xff1b; 如果你是一名开发者&#xff0c;需要经常登录服务器&#xff0c;排查应用的出现的一些问题&#xff0c;那么这篇文章适合你&#xff1b;…

使用java判断字符串中是否包含中文汉字

1.导入huool工具的maven依赖 <dependency><groupId>cn.hutool</groupId><artifactId>hutool-all</artifactId><version>5.8.16</version></dependency>2.复制一下代码直接运行 import cn.hutool.core.lang.Validator;public …

面向对象 - 概述、类的创建、 实例化与内存解析

一、学习面向对象的三条主线 Java类及类的成员&#xff1a;&#xff08;重点&#xff09;属性、方法、构造器&#xff1b;&#xff08;熟悉&#xff09;代码块、内部类面向对象的特征&#xff1a;封装、继承、多态、&#xff08;抽象&#xff09;其他关键字的使用&#xff1a;…

机械学习—零基础学习日志(高数16——函数极限性质)

零基础为了学人工智能&#xff0c;真的开始复习高数 这里我们继续学习函数极限的性质。 局部有界性 充分条件与必要条件 极限存在是函数局部有界的充分条件。什么是充分条件&#xff0c;什么是必要条件呢&#xff1f;我这里做了一点小思考&#xff0c;和大家分享&#xff0c…

alibaba cloud linux+JDK+TOMCAT+NGINX+PHP+MYSQL配置实践

CentOs要停止维护了&#xff0c;一直在服务器上用的CentOs7也最迟到2024年6月了&#xff0c;这次给公司新购一台备用服务器&#xff0c;在选择操作系统的时候&#xff0c;考虑了一下&#xff0c;决定试用一下阿里云的alibaba cloud linux。 alibaba cloud linux分为2和3版本&am…

创客项目秀 | 基于xiao的光剑

在《星球大战》宇宙中&#xff0c;光剑不仅仅是武器;它们是持有者与原力的桥梁&#xff0c;制造一把光剑几乎是每个创客的梦想&#xff0c;今天给大家带来的是国外大学生团队制作的可伸缩光剑项目。 材料清单&#xff1a; 电机驱动模块1:90减速电机套装MP3模块、喇叭Xiao RP2…

ingress使用HostNetwork部署

1.三种常用的部署模式 1.1 DeploymentLoadBalancer模式的service 用Deployment部署igress-controller&#xff0c;创建一个type为LoadBalancer的service关联这组pod。大部分公有云&#xff0c;都会为LoadBalancer的service自动创建一个负载均衡器&#xff0c;通常还绑定了公网…