在浏览器中进行深度学习:TensorFlow.js (八)生成对抗网络 (GAN)

news2025/1/16 10:59:23

Generative Adversarial Network 是深度学习中非常有趣的一种方法。GAN 最早源自 Ian Goodfellow 的这篇论文。LeCun 对 GAN 给出了极高的评价:

“There are many interesting recent development in deep learning…The most important one, in my opinion, is adversarial training (also called GAN for Generative Adversarial Networks). This, and the variations that are now being proposed is the most interesting idea in the last 10 years in ML, in my opinion.” – Yann LeCun

那么我们就看看 GAN 究竟是怎么回事吧:

GAN 包含两个互相对抗的网络:G(Generator)和 D(Discriminator)。正如它的名字所暗示的那样,它们的功能分别是:

  • Generator 是一个生成器的网络,它接收一个随机的噪声,通过这个噪声生成图片,记做 G (z)。
  • Discriminator 是一个鉴别器网络,判别一张图片或者一个输入是不是 “真实的”。它的输入 x 是数据或者图片,输出 D(x)代表 x 为真实图片的概率,如果为 1,就代表 100% 是真实的图片,而输出为 0,就代表不可能是真实的图片。

在训练过程中,生成网络 G 的目标就是尽量生成真实的图片去欺骗判别网络 D。而 D 的目标就是尽量把 G 生成的图片和真实的图片分别开来。这样,G 和 D 构成了一个动态的 “博弈过程”。在最理想的状态下,G 可以生成足以 “以假乱真” 的图片 G (z)。对于 D 来说,它难以判定 G 生成的图片究竟是不是真实的,因此 D (G (z)) = 0.5。

最后,我们就可以使用生成器和随机输入来生成不同的数据或者图片了。

上面的描述大家可能都能理解,但是把它变成数学语言,可能你就蒙 B 了。

如上图所示,x 是输入,z 是随机噪声。D (x) 是鉴别器的判定数据为真的概率,D (G (z)) 是判定生成数据为真的概率。生成器希望这个 D (G (z)) 越大越好,这个时候整个表达式的值应该变小。而鉴别器的目的是能够有效区分真实数据和假数据,所以 D (x) 应该趋向于变大,D (G (z)) 趋向于变小,整个表达式就变大。也就是说训练过程,生成器和辨别器互相对抗,一个使上述表达式变小,另一个使其变大,最后训练趋向于平衡,而生成器这时候应该生成真假难辨的数据,这就是我们的最终目的。

上图是 GAN 算法训练的具体过程,这里我们不做过多的解释,直接运行一个例子。

我们用 MINST 数据集来看看如何使用 TensorflowJS 来训练一个 GAN,模拟生成手写数字。

function gen(xs) {
  const l1 = tf.leakyRelu(xs.matMul(G1w).add(G1b));
  const l2 = tf.leakyRelu(l1.matMul(G2w).add(G2b));
  const l3 = tf.tanh(l2.matMul(G3w).add(G3b));
  return l3;
}

function disReal(xs) {
  const l1 = tf.leakyRelu(xs.matMul(D1w).add(D1b));
  const l2 = tf.leakyRelu(l1.matMul(D2w).add(D2b));
  const logits = l2.matMul(D3w).add(D3b);
  const output = tf.sigmoid(logits);
  return [logits, output];
}

function disFake(xs) {
  return disReal(gen(xs));
}

GAN 的两个网络分别用 gen 和 disReal 创建。gen 是生成器网络,disReal 是辨别器的网络。disFake 是把生成数据用辨别器来辨别。这里的网络使用 leakyrelu。使得输出在 - inf 到 + inf,利用 sigmoid 映射到【0,1】,这是辨别器模型输出一个 0-1 之间的概率。

âleaky reluâçå¾çæç´¢ç»æ

通常我们会创建一个比生成器更复杂的鉴别器网络使得鉴别器有足够的分辨能力。但在这个例子里,两个网络的复杂程度类似。

计算损失的函数使用 tf.sigmoidCrossEntropyWithLogits,值得注意的是,在最新的 0.13 版本中,这个交叉熵被移除了,你需要自己实现该方法。

训练过程如下:

async function trainBatch(realBatch, fakeBatch) {
  const dcost = dOptimizer.minimize(
    () => {
      const [logitsReal, outputReal] = disReal(realBatch);
      const [logitsFake, outputFake] = disFake(fakeBatch);

      const lossReal = sigmoidCrossEntropyWithLogits(ONES_PRIME, logitsReal);
      const lossFake = sigmoidCrossEntropyWithLogits(ZEROS, logitsFake);
      return lossReal.add(lossFake).mean();
    },
    true,
    [D1w, D1b, D2w, D2b, D3w, D3b]
  );
  await tf.nextFrame();
  const gcost = gOptimizer.minimize(
    () => {
      const [logitsFake, outputFake] = disFake(fakeBatch);

      const lossFake = sigmoidCrossEntropyWithLogits(ONES, logitsFake);
      return lossFake.mean();
    },
    true,
    [G1w, G1b, G2w, G2b, G3w, G3b]
  );
  await tf.nextFrame();

  return [dcost, gcost];
}

训练使用了两个 optimizer,

  1. 第一步,计算实际数据的辨别结果和 1 的交叉熵,以及生成器生成数据的辨别结果和 0 的交叉熵。也就是说,我们希望辨别器尽可能的判断出生成数据都是假的而实际数据都是真的。使得这两个交叉熵的均值最小。
  2. 第二步开始对抗,要让生成数据尽可能被判别为真。

下图是某个训练过程的损失:

这个是经过 1000 个迭代后的生成图:

大家可以尝试调整学习率,增加网络复杂度,加大迭代次数来获得更好的生成模型。

GAN 的学习其实还是比较复杂的,参数和损失选择都不容易,好在有一些现成的工具可以使用,另外推荐大家去 https://poloclub.github.io/ganlab/,提供了很直观的 GAN 学习的过程。这个也是用 TensorflowJS 来实现的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1331640.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

华为设备文件系统基础

华为网络设备的配置文件和VRP系统文件都保存在物理存储介质中,所以文件系统是VRP正常运行的基础。只有掌握了对文件系统的基本操作,网络工程师才能对设备的配置文件和VRP系统文件进行高效的管理。 基本查询命令 VRP基于文件系统来管理设备上的文件和目录…

NTFS权限与文件系统:深入解析与实践指南

在当今的信息时代,数据安全和管理成为了每个组织和个人的重要议题。NTFS权限作为Windows操作系统中的一个核心功能,为文件和文件夹的安全管理提供了强大的支持。本文将深入解析NTFS权限的基本概念,并通过实际操作指导如何有效地利用这些权限来…

HTML5文档

目录 HTML5文档结构1.HTML5页面结构2.HTML5新增结构元素 HTML5新增页面元素1.hgroup标记2.figure标记与figcaption标记3.mark标记与time标记4.details标记与summary标记5.progress标记与meter标记6.input标记与datalist标记 HTML5文档结构 HTML5文档结构同样是由头部和主体两部…

R语言ggplot2可视化:分组堆叠条形图,展示不同分组的多个处理数据特征,动态交互式条形图

在实验数据可视化过程中,经常需要对多个样本在多个处理条件下多种指标进行比较,使用下面这种分组堆叠条形图能从多个角度同时展示数据特征。 备注:图中横轴以“0”为界左右分为两部分,可以用来表示处理A和处理B,纵轴表…

AOSP源码下载方法,解决repo sync错误:android-13.0.0_r82

篇头 最近写文章,反复多次折腾AOSP代码,因通过网络repo sync aosp代码,能一次顺利下载的概率很低,以前就经常遇到,但从未总结,导致自己也要回头检索方法,所以觉得可以总结一下,涉及…

python+django疾病健康知识科普推荐系统

基于智能推荐的卫生健康系统通过信息化技术,研究健康管理倌息的获取、传输、处理和反馈,实现区域一体化协同医疗健康服务,建立高品质与高效率的健康监测、疾病防治服务体系、健康生活方式与健康风险评价体系,达到改善健康状况、防治常见和慢性疾病的发生和发展、提高生命质量、…

docker笔记1-安装与基础命令

docker的用途: 可以把应用程序代码及运行依赖环境打包成镜像,作为交付介质,在各种环境部署。可以将镜像(image)启动成容器(container),并提供多容器的生命周期进行管理(…

Netty-2-数据编解码

解析编解码支持的原理 以编码为例,要将对象序列化成字节流,你可以使用MessageToByteEncoder或MessageToMessageEncoder类。 这两个类都继承自ChannelOutboundHandlerAdapter适配器类,用于进行数据的转换。 其中,对于MessageToMe…

数据结构-如何巧妙实现一个栈?逐步解析与代码示例

文章目录 引言1.栈的基本概念2.选择数组还是链表?3. 定义栈结构4.初始化栈5.压栈操作6.弹栈操作7.查看栈顶和判断栈空9.销毁栈操作10.测试并且打印栈内容栈的实际应用结论 引言 栈是一种基本但强大的数据结构,它在许多算法和系统功能中扮演着关键角色。…

智能优化算法应用:基于天鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用:基于天鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用:基于天鹰算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.天鹰算法4.实验参数设定5.算法结果6.参考文献7.MA…

oracle恢复分片和非分片备份?

分片备份命令参考:适合大数据库并行备份提高备份速度 对于超大数据库,混合有小文件和大文件表空间,section size 表示分片,大小一般大于32G,可结合通道数量设置最佳值。 run { allocate channel t1 type disk; alloc…

PostGreSQL:货币类型

货币类型:money money类型存储固定小数精度的货币数字,小数的精度由数据库的lc_monetary设置决定。windows系统下,该配置项位于/data/postgresql.conf文件中,默认配置如下, lc_monetary Chinese (Simplified)_Chi…

redis基本用法学习(C#调用CSRedisCore操作redis)

除了NRedisStack包,csredis也是常用的redis操作模块(从EasyCaching提供的常用redis操作包来看,CSRedis、freeredis、StackExchange.Redis应该都属于常用redis操作模块),本文学习使用C#调用CSRedis包操作redis的基本方式…

【Spring Security】打造安全无忧的Web应用--使用篇

🥳🥳Welcome Huihuis Code World ! !🥳🥳 接下来看看由辉辉所写的关于Spring Security的相关操作吧 目录 🥳🥳Welcome Huihuis Code World ! !🥳🥳 一.Spring Security中的授权是…

阿贝云云服务器

最近,我有幸获得了阿贝云提供的免费云服务器,阿贝云_免费云服务器、高防服务器、虚拟主机、免费空间、免费vps主机服务商!并在使用过程中有了一些深刻的体验和感受。在这篇博客中,我将分享我对阿贝云免费云服务器的使用感受和评价。 首先&am…

【iOS】UICollectionView

文章目录 前言一、实现简单九宫格布局二、UICollectionView中的常用方法和属性1.UICollectionViewFlowLayout相关属性2.UICollectionView相关属性 三、协议和代理方法:四、九宫格式的布局进行升级五、实现瀑布流布局实现思路实现原理代码调用顺序实现步骤实现效果 总…

需求分析工程师岗位的职责描述(合集)

需求分析工程师岗位的职责描述1 职责: 1,负责需求调研,对需求进行分析,编写解决方案、需求规格说明书等 2,根据需求制作原型,并负责原型展示以及客户沟通等工作 3,负责向技术团队精确地传达业务…

排序算法——桶排序

把数据放进若干个桶,然后在桶里用其他排序,近乎分治思想。从数值的低位到高位依次排序,有几位就排序几次。例如二位数就排两次,三位数就排三次,依次按照个十百...的顺序来排序。 第一次排序:50 12 …

Unity手机移动设备重力感应

Unity手机移动设备重力感应 一、引入二、介绍三、测试成果X Y轴Z轴横屏的手机,如下图竖屏的手机,如下图 一、引入 大家对重力感应应该都不陌生,之前玩过的王者荣耀的资源更新界面就是使用了重力感应的概念,根据手机的晃动来给实体…

EPROM 作为存储器的 8 位单片机

一、基本概述 TX-P01I83 是以 EPROM 作为存储器的 8 位单片机,专为多 IO 产品的应用而设计,例如遥控器、风扇/灯光控制或是 玩具周边等等。采用 CMOS 制程并同时提供客户低成本、高性能等显着优势。TX-P01I83 核心建立在 RISC 精简指 令集架构可以很容易…