损失函数-交叉熵 梯度下降

news2024/10/7 20:27:01

文章目录

  • 1、交叉熵的简单例子
    • 1.2、Classification Error(分类错误率)
    • 1.3、Mean Squared Error (均方误差)
    • 1.4、交叉熵损失函数
    • 1.5、二分类
  • 2、什么是梯度下降法?
    • 2.2、梯度下降法的运行过程
    • 2.3、二元函数的梯度下降

1、交叉熵的简单例子

参考文章例子

我们希望通过图像轮廓、颜色等特征,来预测动物的类别,有三种可能类别(猫、狗、猪)假设我们现在有两个模型,都是通过sigmoid/softmax的方式得到的对每个类别预测的概率 。

模型1:
预测 真实 是否正确
0.3 0.3 0.4 0 0 1 (猪) 正确
0.3 0.4 0.3 0 1 0 (狗) 正确
0.1 0.2 0.7 1 0 0 (猫) 错误
模型1对于样本1和样本2以非常微弱的优势判断正确,对于样本3的判断则彻底错误。

模型2:
预测 真实 是否正确
0.1 0.2 0.7 0 0 1 (猪) 正确
0.1 0.7 0.2 0 1 0 (狗) 正确
0.3 0.4 0.3 1 0 0 (猫) 错误
模型2对于样本1和样本2判断非常准确,对于样本3判断错误,但是相对来说没有错得太离谱。

有了模型之后,我们可以定义损失函数来对判断模型在当前样本表现

1.2、Classification Error(分类错误率)

在这里插入图片描述
从结果可知,模型1和模型2虽然都是预测错了1个,但是相对来说模型2表现得更好,损失函数值照理来说应该更小,但是,很遗憾的是,classification\ error 并不能判断出来,所以这种损失函数虽然好理解,但表现不太好。

1.3、Mean Squared Error (均方误差)

在这里插入图片描述
我们发现,MSE能够判断出来模型2优于模型1,那为什么不采样这种损失函数呢?主要原因是在分类问题中,使用sigmoid/softmx得到概率,配合MSE损失函数时,采用梯度下降法进行学习时,会出现模型一开始训练时,学习速率非常慢的情况。

1.4、交叉熵损失函数

交叉熵(Cross Entropy)用于衡量一个概率分布与另一个概率分布之间的距离。
交叉熵是机器学习和深度学习中常常用到的一个概念。在分类问题中,我们通常有一个真实的概率分布(通常是专家或训练数据的分布),以及一个模型生成的概率分布,交叉熵可以衡量这两个分布之间的距离。
模型训练时,通过最小化交叉熵损失函数,我们可以使模型预测值的概率分布逐步接近真实的概率分布。

信息熵
熵 是关于不确定性的数学描述。
信息的大小跟随机事件的概率有关。越小概率的事情发生了产生的信息量越大。
所以 信息的量度应该依赖于概率分布 p ( x ) p(x)p(x)
单调性:发生概率越高的事件,其携带的信息量越低;
非负性:信息熵可以看作为一种广度量,非负性是一种合理的必然;
累加性:即多随机事件同时发生存在的总不确定性的量度是可以表示为各事件不确定性的量度的和,这也是广度量的一种体现

在这里插入图片描述
在这里插入图片描述

使用python sklearn库实现

from sklearn.metrics import log_loss 
y_true = [[0, 0, 1], [0, 1, 0], [1, 0, 0]] 
y_pred_1 = [[0.3, 0.3, 0.4], [0.3, 0.4, 0.3], [0.1, 0.2, 0.7]] 
y_pred_2 = [[0.1, 0.2, 0.7], [0.1, 0.7, 0.2], [0.3, 0.4, 0.3]] 
print(log_loss(y_true, y_pred_1)) 
print(log_loss(y_true, y_pred_2)) 
____________ 
1.3783888522474517 
0.6391075640678003 

1.5、二分类

在这里插入图片描述
在这里插入图片描述

回归问题通常用均方差损失函数,可以保证损失函数是个凸函数,即可以得到最优解。而分类问题如果用均方差的话,损失函数的表现不是凸函数,就很难得到最优解。而交叉熵函数可以保证区间内单调。

2、什么是梯度下降法?

参考

梯度下降法在机器学习中常常用来优化损失函数,是一个非常重要的工具。说白了,就是在高中学习过的「极值」的概念,那么什么是极值呢?用非常形象的方式来说极值点,梯度下降法的作用就是寻找一个 「极小值点」,从而让函数的值尽可能地小。相信你也发现了,这么多个极值点,那么梯度下降法找到的是哪一个点呢?关于这个问题就要看运气了,算法的最开始会 「随机」 寻找一个位置然后开始搜索「局部」的最优解,如果运气好的话能够寻找到一个最小值的极值点,运气不好或许找到的就不是最小值的那个极小值点了。

比如说下边儿的这个函数图像 , 可以发现在当前这个区间范围内这个函数有两个极小值点,如果我们想寻找当前函数在这个区间内的最小值点,那么当然是第二个极小值点更合适一些,可是并不一定能够如我们所愿顺利地找到第二个极小值点,这时候只能够通过多次尝试。
在这里插入图片描述

2.2、梯度下降法的运行过程

接下来用一个简单的一元二次函数来阐述算法的运行过程。一元二次函数天生地就只有一个极小值点,方便用来说明问题。函数的表达式是 f(x) = (x-1)^2+1,通过表达式可以知道它的最小值点的坐标是(1,1)
在这里插入图片描述
「梯度的概念」:梯度就是函数对它的各个自变量求偏导后,由偏导数组成的一个向量。
「梯度下降法」:每个自变量方向上的偏导数构成的向量,函数f(x) 的梯度就是f(x)’ = 2x-2

图上有一个小红点,它的坐标是 (6, f(6)), 也就是当x=6的时候二次函数曲线上边儿的点。那么也可以得到f(6)'=10,也就是这个红点的位置的导数是等于10的。现在用导数值的 「正负来表示方向」,如果导数的值是正数,那么就代表x轴的正方向。如果导数的值是负数,那么就代表是x轴的负方向。那么你会发现,知道了这个方向之后也就知道了应该让f(x)往哪个方向变化x的值才会增大。如果想要让f(x)的值减小,那么就让x朝着导数告诉我们的方向的反方向变化就好啦。接下来我用「黄色」箭头代表导数告诉我们的方向,用「黑色」箭头代表导数指向的方向的反方向。
在这里插入图片描述

「梯度下降法」 的目标是搜索出来一个能够让函数的值尽可能小的位置,所以应该让x朝着黑色箭头的方向走,那么怎么完成这个操作呢?
在代码中有一个eta变量,它的专业名词叫 「学习率」。使用数学表达式来表示更新
在这里插入图片描述

表达式的意思就是让x减去eta乘以函数的导数。其中的eta是为了控制x更新的幅度,将eta设置的小一点,那么每一次更新的幅度就会小一点。

初始化

# 变量 x 表示当前所在的位置
# 就随机初始在 6 这个位置好了
x = 6
eta = 0.05

第一次更新x

# 变量 df 存储当前位置的导数值
df = 2*x-2
x = x - eta*df
# 更新后 x 由 6 变成 
# 6-0.05*10 = 5.5

第二次更新x

# 因为 x 的位置发生了变化
# 要重新计算当前位置的导数
df = 2*x-2
x = x - eta*df
# 更新后 x 由 5.5 变成
# 5.5-0.05*9 = 5.05

接下来只要重复刚才的过程就可以了。那么,把刚才的两个步骤画出来。其中红色箭头指向的点是初始时候的点,「黄色」 箭头指向的点是第 1 次更新后的位置,「蓝色」 箭头指向的点是第 2 次更新后的位置。
在这里插入图片描述

只要不停地重复刚才的这个过程,那么最终就会收敛到一个 「局部」 的极值点。那么可以想一下这是为什么?因为随着每一次的更新,曲线都会越来越平缓,相应的导数值也会越来越小,当我们接近极值点的时候导数的值会无限地靠近 「0」。由于导数的绝对值越来越小,那么随后更新的幅度也会越来越小,最终就会停留在极值点的位置了。接下来我将用 Python 来实现这个过程,并让刚才的步骤迭代 1000 次。

import matplotlib.pyplot as plt
import numpy as np

# 初始算法开始之前的坐标 
# cur_x 和 cur_y 
cur_x = 6
cur_y = (cur_x-1)**2 + 1
# 设置学习率 eta 为 0.05
eta = 0.05
# 变量 iter 用于存储迭代次数
# 这次我们迭代 1000 次
# 所以给它赋值 1000
iter = 1000
# 变量 cur_df 用于存储
# 当前位置的导数
# 一开始我给它赋值为 None
# 每一轮循环的时候为它更新值
cur_df = None

# all_x 用于存储
# 算法进行时所有点的横坐标
all_x = []
# all_y 用于存储
# 算法进行时所有点的纵坐标
all_y = []

# 把最一开始的坐标存储到
# all_x 和 all_y 中
all_x.append(cur_x)
all_y.append(cur_y)

# 循环结束也就意味着算法的结束
for i in range(iter):
    # 每一次迭代之前先计算
    # 当前位置的梯度 cur_df
    # cur 是英文单词 current 
    cur_df = 2*cur_x - 2
    # 更新 cur_x 到下一个位置
    cur_x = cur_x - eta*cur_df
    # 更新下一个 cur_x 对应的 cur_y
    cur_y = (cur_x-1)**2 + 1

    # 其实 cur_y 并没有起到实际的计算作用
    # 在这里计算 cur_y 只是为了将每一次的
    # 点的坐标存储到 all_x 和 all_y 中
    # all_x 存储了二维平面上所有点的横坐标
    # all_y 存储了二维平面上所欲点的纵坐标
    # 使用 list 的 append 方法添加元素
    all_x.append(cur_x)
    all_y.append(cur_y)

# 这里的 x, y 值为了绘制二次函数
# 的那根曲线用的,和算法没有关系
# linspace 将会从区间 [-5, 7] 中
# 等距离分割出 100 个点并返回一个
# np.array 类型的对象给 x
x = np.linspace(-5, 7, 100)
# 计算出 x 中每一个横坐标对应的纵坐标
y = (x-1)**2 + 1
# plot 函数会把传入的 x, y
# 组成的每一个点依次连接成一个平滑的曲线
# 这样就是我们看到的二次函数的曲线了
plt.plot(x, y)
# axis 函数用来指定坐标系的横轴纵轴的范围
# 这样就表示了 
# 横轴为 [-7, 9]
# 纵轴为 [0, 50]
plt.axis([-7, 9, 0, 50])
# scatter 函数是用来绘制散点图的
# scatter 和 plot 函数不同
# scatter 并不会将每个点依次连接
# 而是直接将它们以点的形式绘制出来
plt.scatter(np.array(all_x), np.array(all_y), color='red')
plt.show()

在这里插入图片描述

通过最终的效果图可以发现,「梯度下降」在一步一步地收敛到二次函数的极小值点,同时也是二次函数的最小值点。仔细看一下,可以发现,就像刚才提到的,随着算法的运行,红色的点和点之间的距离越来越小了,也就是每一次更新的幅度越来越小了。这可不关「学习率」 的事儿,因为在这里用到的学习率是一个固定的值,这只是因为随着接近极值点的过程中「导数的绝对值」越来越小。
同时,我们还可以把 all_x 和 all_y 中最后一个位置存储的值打印出来看一下,如果没错的话,all_x 的最后一个位置的值是接近于 1 的,同时对应的函数值 all_y 最后一个位置的值也应该是接近于 1 的。

2.3、二元函数的梯度下降

在多元函数中拥有多个自变量,在这里就用二元的函数来举例子好了,二元函数的图像是在三维坐标系中的。下边儿就是一个多元函数的图像例子。在这里纵坐标我使用 「y」 表示,底面的两个坐标分别使用 「x1」 和 「x2」 来表示。不用太纠结底面到底哪个轴是 x1 哪个轴是 x2,这里只是为了说明问题。

在这里插入图片描述
通过上面的二元函数凸显可以发现想要让 「y」 的值尽可能地小就要寻找极值点(同样地这个极值点也不一定是最小值点)。梯度是一个由各个自变量的偏导数所组成的一个「向量」。用数学表达式看上去就是下边儿这样的。
在这里插入图片描述
同样地,一开始随机初始一个位置,随后让当前位置的坐标值减去学习率乘以当前位置的偏导数。更新自变量 x1 就让 x1 减去学习率乘以 y 对 x1 的偏导数,更新自变量 x2 就让 x2 减去学习率乘以 y 对 x2 的偏导数。

发挥想象力再思考一下,既然梯度是一个向量,那么它就代表了一个方向。x1 的偏导数的反方向告诉我们在 x1 这个轴上朝向哪个方向变化可以使得函数值 y 减小,x2 的偏导数的反方向告诉我们在 x2 这个轴上朝向哪个方向变化可以使得函数值 y 减小。那么,由偏导数组成的向量就告诉了我们在底面朝向哪个方向走就可以使得函数 y 减小。
在下边儿的图中,我使用红色箭头表示我们当前所在的位置,随后我使用 「黑色」 箭头代表其中一个轴上的坐标朝向哪个方向变化可以使得函数值 「y」 减小,并使用 「黄色」 箭头代表另外一个轴上的坐标朝向哪个方向变化可以使得函数值 「y」 减小。根据平行色变形法则有了 「黑色」 向量和 「黄色」 向量,就可以知道这两个向量最终达到的效果就是 「蓝色」 向量所达到的效果。
在这里插入图片描述
最终,对于二元函数梯度下降的理解,就可以理解为求出的梯度的反方向在底面给了我们一个方向,只要朝着这个方向变化底面的坐标就可以使得函数值 「y」 变小。

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1583067.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

动力与智能的碰撞:高效控制下的Profinet与EtherCAT逆变器融合

在实施工业自动化解决方案时,特别是当涉及到西门子S7-1200/1500系列PLC的集成时,一个常见的问题就是是确保不同通信协议之间的兼容性。在这种情况下,我们面临的是将这些PLC与支持EtherCAT通信功能的逆变器设备相连接的需求。西门子PLC通常利用…

浏览器密码框明文密文兼容edge的问题

在网页中注册会员的时候,经常需要输入用户名(账号)和密码,在输入密码的时候,为了防止用户输错密码,经常会给密码框加一个小功能,就是点击密码框右侧闭着的小眼睛,可以让密文变成明文…

蓝桥杯真题 买不到的数目 结论题 数论

👨‍🏫 题目地址 👨‍🏫 数论:pxpy 不能表示的最大数为pq-p-q的证明 最大能表示的数为: p q − p − q ( p − 1 ) ( q − 1 ) pq-p-q(p-1)(q-1) pq−p−q(p−1)(q−1) 则最大不能表示的数为 ( p − …

Tubi 十岁啦!

Tubi 今年十岁了,这十年不可思议,充满奇迹! 从硅谷一个名不见经传的创业小作坊,转变成为四分之一美国电视家庭提供免费流媒体服务的北美领先的平台; 从费尽心力终于签下第一笔内容合作协议,到现在与 450 …

实验:基于Red Hat Enterprise Linux系统建立RAID磁盘阵列

目录 一. 实验目的 二. 实验内容 三. 实验设计描述及实验结果 什么是磁盘阵列(RAID) 1. 为虚拟机添加4块大小为20G的硬盘nvme0n【2-5】,将nvme0n【2、3、4】三块硬盘 建立为raid5并永久挂载,将RAID盘全部空间制作逻辑卷&#…

软件开发自媒体获客避坑:啥都能干=啥都抓不着=啥都干不了。

我就结合我的经验谈一下粗浅的看法,权当抛砖引玉了。 一、自媒体流量本质是一种泛流量 自媒体流量通常指的是通过自媒体平台(如微信公众号、微博、知乎等)获取的泛流量。泛流量是指广泛的、来自不同渠道的流量,包括通过搜索引擎…

面试官为什么喜欢考察Vue底层原理

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

C语言面试题之检查二叉树平衡性

检查二叉树平衡性 实例要求 1、实现一个函数,检查二叉树是否平衡;2、在这个问题中,平衡树的定义如下:任意一个节点,其两棵子树的高度差不超过 1; 示例 1: 给定二叉树 [3,9,20,null,null,15,7]3/ \9 20/…

OSPF数据报文格式

OSPF协议是跨层封装的协议,跨四层封装,直接将应用层的数据封装在网络层协议后面,IP协议包中协议号字段对应的数值为——89 OSPF的头部信息: ——所有数据包公有的信息 版本:OSPF版本 在IPV4中一般使用OSPFV2&#xf…

【3GPP】【核心网】核心网/蜂窝网络重点知识面试题二(超详细)

1. 欢迎大家订阅和关注,3GPP通信协议精讲(2G/3G/4G/5G/IMS)知识点,专栏会持续更新中.....敬请期待! 目录 1. 对于主要的LTE核心网接口,给出运行在该接口上数据的协议栈,并给出协议特征 2. 通常…

ShardingSphere-ShardingSphere读写分离和数据脱敏

文章目录 一、读写分离1.1 读写分离1.2 读写分离应用方案1.3 分表+读写分离1.4 分库分表+读写分离二、ShardingSphere-JDBC读写分离2.1 创建SpringBoot并添加依赖2.2 创建实体类2.3 创建mapper2.4 配置读写分离2.5 测试测试插入数据测试读测试事务一致性测试负载均衡一、读写分…

BP实战之猫狗分类数据集

目录 补充知识 python类里面的魔法方法 transforms.Resize() python里面的OS库 BP实战之猫狗分类数据集 猫狗数据集 注意事项 使用类创建自己的猫狗分类数据集 代码 实例化对象尝试 代码 结果 利用DataLoader加载数据集 BP神经网络的搭建以及对象的使用 运行结果…

【PyQt5篇】使用QtDesigner添加控件和槽

文章目录 &#x1f354;使用QtDesigner进行设计&#x1f6f8;在代码中添加信号和槽 &#x1f354;使用QtDesigner进行设计 我们首先使用QtDesigner设计界面 得到代码login.ui <?xml version"1.0" encoding"UTF-8"?> <ui version"4.0&q…

【Qt】:窗口

窗口 一.概述二.菜单栏1.一个简单的菜单2.添加快捷键3.嵌套子菜单4.添加下划线5.添加图标 三.工具栏1.创建一个简单的工具栏2.设置工具栏的停靠位置 四.状态栏五.浮动窗口 一.概述 Qt窗口是通过QMainWindow类来实现的。 QMainWindow是一个为用户提供主窗口程序的类&#xff0c…

经典本地影音播放器纯净无广告版

MPC-BE&#xff08;Media Player Classic Black Edition&#xff09;是来自 MPC-HC&#xff08;Media Player Classic Home Cinema&#xff09;的俄罗斯开发者重新编译优化后的一款经免费的经典全能影音播放器&#xff0c;纯净无广告&#xff0c;启动速度快&#xff0c;占用消耗…

ES7-10:async和await、异步迭代..

1-ES7新特性 indexof如果没有就返回-1&#xff0c;有就返回索引 如果仅仅查找数据是否在数组中,建议使用includes,如果是查找数据的索引位置,建议使用indexOf更好一些 2-ES8-async和await 所有的需要异步处理的Promise对象都写在async中await等待结果 async、await 使异步操…

HEC-HMS水文模型

HEC-HMS是美国陆军工程兵团水文工程中心开发的一款水文模型。HMS能够模拟各种类型的降雨事件对流域水文&#xff0c;河道水动力以及水利设施的影响&#xff0c;在世界范围内得到了广泛的应用。它有着完善的前后处理软件&#xff0c;能有效减轻建模的负担&#xff1b;能够与HEC开…

数据挖掘实战-基于机器学习的垃圾邮件检测模型(文末送书)

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

30.WEB渗透测试-数据传输与加解密(4)

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a; 易锦网校会员专享课 上一个内容&#xff1a;29.WEB渗透测试-数据传输与加解密&#xff08;3&#xff09;-CSDN博客 加解密需要用到的源…

【linux】sudo 与 su/su -之间的区别

一、区别 二、其他 大概是因为使用 su 命令或直接以 root 用户身份登录有风险&#xff0c;所以&#xff0c;一些 Linux 发行版&#xff08;如 Ubuntu&#xff09;默认禁用 root 用户帐户。鼓励用户在需要 root 权限时使用 sudo 命令。 然而&#xff0c;您还是可以成功执行 su…