基础!!!吴恩达deeplearning.ai:神经网络中使用softmax

news2025/1/22 16:02:21

以下内容有任何不理解可以翻看我之前的博客哦:吴恩达deeplearning.ai

文章目录

  • softmax作为输出层的神经网络
  • Tensorflow的实现
  • softmax的改进实现
    • 数值舍入误差(Numerical Roundoff Errors)
    • sigmoid修改
    • 修改softmax


在上一篇博客中我们了解了有关softmax的原理相关内容,今天我们主要聚焦于如何修改之前的神经网络,从而搭建能够实现多分类问题的神经网络。

softmax作为输出层的神经网络

在这里插入图片描述
相比之前的二分类逻辑回归神经网络,我们主要的改变是将输出层替换为了具有十个神经元的,激活函数为softmax的输出层。整个神经网络的运行流程是接收特征输入X,并且传入隐藏层,两个隐藏层的激活函数均采用的是relu函数;再传入最终输出层,最终的输出 a [ 3 ] a^{[3]} a[3]是一个包含十个概率值的矩阵。
我们再回顾下softmax的公式(这里仅以a1为例):
z 1 = w 1 ⃗ ⋅ x ⃗ + b 1 a 1 = e z 1 e z 1 + e z 2 + e z 3 + e z 4 z_1=\vec{w_1}\cdot\vec{x}+b_1\\ a_1=\frac{e^{z_1}}{e^{z_1}+e^{z_2}+e^{z_3}+e^{z_4}} z1=w1 x +b1a1=ez1+ez2+ez3+ez4ez1
此外提一个定义,softmax层有时也被叫做softmax函数。与其它的激活函数相比不同的是,softmax中 a 1 a_1 a1仅仅和 z 1 z_1 z1有关, a 2 a_2 a2仅仅和 z 2 z_2 z2有关,而不像其它的激活函数最终的某个输出a和多个z有关。
让我们看看如何用代码实现这个神经网络

Tensorflow的实现

第一步,构建神经网络的结构框架:

import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense
model = Sequential([
  Dense(units=25, activation='relu')
  Dense(units=15, activation='relu')
  Dense(units=10, activation='softmax')
  ])

第二步,定义损失函数和价值函数

from tensorflow.keras.losses import
SparseCategoricalCrossentropy
model.compile(loss=SparseCategoricalCrossentropy())

这里出现了一个新的函数SparseCategoricalCrossentropy(),翻译成中文叫做稀疏分类交叉熵,名字超长,甚至超过了当初的二元交叉熵。稀疏(Sparse)的意思是值只能取1~10中的一个;分类(Categorical)指的是你仍然将y分类。
第三步,训练模型,预测代码和以前一样:

model.fit(X, Y, epochs=100)

以上代码是可以起作用的,符合我们之前的认知,但是还不够优化,在tensorflow中有更好的代码版本。下面我们看看如何优化softmax代码。

softmax的改进实现

数值舍入误差(Numerical Roundoff Errors)

让我先展示下在计算机设置数值的两种不同方法:
第一种,简单粗暴法:
x = 2 10 , 000 x=\frac{2}{10,000} x=10,0002
第二种,加加减减法:
x = ( 1 + 1 10 , 000 ) − ( 1 − 1 10 , 000 ) x=(1+\frac{1}{10,000})-(1-\frac{1}{10,000}) x=(1+10,0001)(110,0001)
虽然看上去相同,但是精确度是由差别的:
在这里插入图片描述
我们对softmax的改进,也主要聚焦在精确度上面,让我介绍一种更加精确的方法。

sigmoid修改

在逻辑回归中,我们的公式是这样的:
a = g ( z ) = 1 1 + e − z l o s s = − y l o g a − ( 1 − y ) l o g ( 1 − a ) a=g(z)=\frac{1}{1+e^-z}\\ loss=-yloga-(1-y)log(1-a) a=g(z)=1+ez1loss=yloga(1y)log(1a)
它的代码是:

model = Sequential([
  Dense(units=25, activation='relu')
  Dense(units=15, activation='relu')
  Dense(units=10, activation='sigmoid')
  ])
model.compile(loss=BinaryCrossEntropy())

如果我们要求tensorflow按照这个步骤,一步步算出a,然后带入到loss之中,那么结果就会如同上面的第二种方法一样产生误差,因为其进行了两步运算。但是tensorflow提供了另一种方法,大致意思就是我们先使用线性激活函数(也可以理解为没使用激活函数),最后在计算损失的时候再指定激活函数为sigmoid。如果我们使用了这个命令,这会为tensorflow提供更高的灵活性,从而可以减少误差,就如同上面的方法一;代码如下:

model = Sequential([
  Dense(units=25, activation='relu')
  Dense(units=15, activation='relu')
  Dense(units=10, activation='linear')
  ])
model.compile(loss=BinaryCrossEntropy(from_logits=True))

通俗点说from_ligits=True告诉了激活函数inaryCrossEntropy我没有用激活函数哦,所以你计算损失时内部记得调用下sigmoid哈。这里的logits可以理解为没有经过激活函数的z。

修改softmax

同样地,我们再看看稀疏分类交叉熵的损失函数,我就写出其中的一项:
L o s s = − l o g a   i f y = 1 Loss=-loga\:ify=1 Loss=logaify=1

由于在多分类问题之中,分类的选项很多,而各个选项的概率和是一定的为1,因此很多情况下正确的那个选项的概率依然很小,由于使用了log函数,在x接近于0的时候这个值会非常大,那么产生的误差也就会很大,而二分类问题由于选项仅有两个,因此这个问题不是很明显,便没在讲二分类的时候也进行这种优化。
一样地,我们代码也可以修改为:

model = Sequential([
  Dense(units=25, activation='relu')
  Dense(units=15, activation='relu')
  Dense(units=10, activation='linear')
  ])
  from tensorflow.keras.losses import
SparseCategoricalCrossentropy
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))

另外需要修改的地方是,我们在预测时,model(x)不再是概率a了,而是没经过激活函数的z,因此代码在最后需要添加:

model.fit(X, Y, epochs=100)
logits = model(X)
f_x = tf.nn.softmax(logits)

从而再加入了softmax,出来的才是0~1之间的概率a。
为了给读者你造成不必要的麻烦,博主的所有视频都没开仅粉丝可见,如果想要阅读我的其他博客,可以点个小小的关注哦。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1474883.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

基于华为atlas的分类模型实战

分类模型选用基于imagenet训练的MobileNetV3模型,分类类别为1000类。 pytorch模型导出为onnx: 修改mobilenetv3.py中网络结构,模型选用MobileNetV3_Small模型,网络输出节点增加softmax层,将原始的return self.linear4…

HarmonyOS—使用数据模型和连接器

Serverless低代码开发平台是一个可视化的平台, 打通了HarmonyOS云侧与端侧能力,能够轻松实现HMS Core、AGC Serverless能力调用。其中,数据模型和连接器是两大主要元素。开发者在使用DevEco Studio的低代码功能进行开发时,可以使用…

Unity开发一个FPS游戏

在之前的文章Unity 3D Input System的使用-CSDN博客中,我介绍了如何用Input System来实现一个FPS游戏的移动控制,这里将进一步完善这个游戏。 以下是游戏的演示效果: fps_demo 添加武器模型 首先是增加主角玩家的武器,我们可以在网上搜索到很多免费的3D资源,例如在以下网…

Vue概念详解【目录】

本专栏简介: 这个专栏是关于 Vue2 和 Vue3 各种概念的大集合!它深入挖掘原理,分析各种优势和劣势,适配各种应用场景,部分内容还列出了代码示例,以清晰地讲述原理。在这里,你将全面了解 Vue2 和…

在Web UI上提交Flink作业

1)任务打包完成后,我们打开Flink的WEB UI页面,在右侧导航栏点击“Submit New Job”,然后点击按钮“ Add New”,选择要上传运行的JAR包 JAR包上传完成,如下图所示 (2)点击该JAR包&…

10 Redis之SB整合Redis

7. SB整合Redis Spring Boot 中可以直接使用 Jedis 实现对 Redis 的操作,但一般不这样用,而是使用 Redis操作模板 RedisTemplate 类的实例来操作 Redis。 RedisTemplate 类是一个对 Redis 进行操作的模板类。该模板类中具有很多方法,这些方…

图论基础(一)

一、图论 图论是数学的一个分支,它以图为研究对象。图论中的图是若干给定的点(顶点)以及连接两点的线(边)构成的图像,这种图形通常用来描述某些事物之间的某种特定关系,用点代表事物&#xff0c…

MFC web文件 CHttpFile的使用初探

MFC CHttpFile的使用 两种方式,第一种OpenURL,第二种SendRequest,以前捣鼓过,今天再次整结果发现各种踩坑,好记性不如烂笔头,记录下来。 OpenURL 这种方式简单粗暴,用着舒服。 try {//OpenU…

JavaScript最新实现城市级联操作,json格式的数据

前置知识&#xff1a; <button onclick"doSelect()">操作下拉列表</button><hr>学历&#xff1a;<select id"degree"><option value"0">--请选择学历--</option><option value"1">专科<…

【玩转pandas系列】pandas数据结构—DataFrame

文章目录 前言一、DataFrame创建1.1 字典创建1.2 NumPy二维数组创建 二、DataFrame切片2.1 行切片2.2 列切片2.3 行列切片 三、DataFrame运算3.1 DataFrame和标量的运算3.2 DataFrame之间的运算3.3 Series和DataFrame之间的运算 四、DataFrame多层次索引4.1 多层次索引构造1.隐…

安全防御-第六次

内容安全 攻击可能只是一个点&#xff0c;防御需要全方面进行 DFI和DPI技术--- 深度检测技术 DPI --- 深度包检测技术--- 主要针对完整的数据包&#xff08;数据包分片&#xff0c;分段需要重组&#xff09;&#xff0c;之后对数据包的内容进行识别。&#xff08;应用层&…

音频混音算法的实现

最近项目有用到混音算法&#xff0c;这里用比较常见的一种&#xff0c;就是简单的加和之后做一下归一化。 是参考这个博主实现的&#xff1a; 音频混音的算法实现 下面直接贴代码&#xff1a; #include <stdio.h> #include <stdlib.h> #include <math.h&…

Mac 配置Clion Qt 调试显示变量值

背景 使用Clion开发Qt程序&#xff0c;在进行调试时&#xff0c;会看不到Qt类的变量值&#xff0c;只有指针形式&#xff0c;对于调试很不方便。 环境&#xff1a; Macbook ProCPU&#xff1a;M3Qt 5.15.13CLion 2023.3.4 解决方案 为了让Clion能显示Qt类的值&#xff0c;…

新加坡服务器托管:开启全球化发展之门

新加坡作为一个小国家&#xff0c;却在全球范围内享有极高的声誉。新加坡作为亚洲的科技中心&#xff0c;拥有先进的通信基础设施和成熟的机房托管市场。除了其独特的地理位置和发达的经济体系外&#xff0c;新加坡还以其开放的商业环境和便利的托管服务吸引着越来越多的国际公…

react 路由的基本原理及实现

1. react 路由原理 不同路径渲染不同的组件 有两种实现方式 ● HasRouter 利用hash实现路由切换 ● BrowserRouter 实现h5 API实现路由切换 1. 1 HasRouter 利用hash 实现路由切换 1.2 BrowserRouter 利用h5 Api实现路由的切换 1.2.1 history HTML5规范给我们提供了一个…

【Go 快速入门】协程 | 通道 | select 多路复用 | sync 包

文章目录 前言协程goroutine 调度使用 goroutine 通道无缓冲通道有缓冲通道单向通道 select 多路复用syncsync.WaitGroupsync.Mutexsync.RWMutexsync.Oncesync.Map 项目代码地址&#xff1a;05-GoroutineChannelSync 前言 Go 1.22 版本于不久前推出&#xff0c;更新的新特性可…

机器视觉运动控制一体机在光伏汇流焊机器人系统的解决方案

一、市场应用背景 汇流焊是光伏太阳能电池板中段加工工艺&#xff0c;其前道工序为串焊&#xff0c;在此环节流程中&#xff0c;需要在多个太阳能电池片表面以平行方式串焊多条焊带&#xff0c;形成电池串。串焊好的多组电池串被有序排列输送到汇流焊接工作台&#xff0c;通过…

MWC 2024丨美格智能推出5G RedCap系列FWA解决方案,开启5G轻量化新天地

2月27日&#xff0c;在MWC 2024世界移动通信大会上&#xff0c;美格智能正式推出5G RedCap系列FWA解决方案。此系列解决方案具有低功耗、低成本等优势&#xff0c;可以显著降低5G应用复杂度&#xff0c;快速实现5G网络接入&#xff0c;提升FWA部署的经济效益。 RedCap技术带来了…

Data Leakage and Evaluation Issues inMicro-Expression Analysis 阅读笔记

IEEE Transactions on Affective Computing上的一篇文章&#xff0c;做微表情识别&#xff0c;阅读完做个笔记。本文讨论了Data Leakage对模型准确度评估的影响&#xff0c;及如何融合多个微表情数据集&#xff0c;从而提升模型的准确度。工作量非常饱满&#xff0c;很认真&…

git忽略某些文件(夹)更改方法

概述 在项目中&#xff0c;常有需要忽略的文件、文件夹提交到代码仓库中&#xff0c;在此做个笔录。 一、在项目根目录内新建文本文件&#xff0c;并重命名为.gitignore&#xff0c;该文件语法如下 # 以#开始的行&#xff0c;被视为注释. # 忽略掉所有文件名是 a.txt的文件.…