【PyTorch】第九节:Softmax 函数与交叉熵函数

news2025/1/10 16:51:48

作者🕵️‍♂️:让机器理解语言か

专栏🎇:PyTorch

描述🎨:PyTorch 是一个基于 Torch 的 Python 开源机器学习库。

寄语💓:🐾没有白走的路,每一步都算数!🐾 

介绍💬

        本实验主要讲解了分类问题中的二分类问题和多分类问题之间的区别,以及每种问题下的交叉熵损失的定义方法。由于多分类问题的输出为属于每个类别的概率,要求概率和为 1 。因此,我们还介绍了如何利用 Softmax 函数,处理神经网络的输出,使其满足损失函数的格式要求。

本文参考蓝桥云课:PyTorch 基础入门实战_机器学习 - 蓝桥云课 (lanqiao.cn) 

知识点🍀

  • 👉二分类和多分类

  • 👉交叉熵损失

  • 👉PyTorch 中的 Softmax 和交叉熵


二分类问题和多分类问题

        二分类问题:表示分类任务有两个类别。比如我们想要识别一副图是否是猫,我们一般会训练出一个分类器,输入一副图片(用向量 x 表示),输出该图片是猫的概率 p。我们可以使用 max 函数判断 p 和 0.5 的大小。如果 p 大,则输出 1(表示该图像为猫)。如果 0.5 大,则输出 0 (表示该图像不为猫)。这就是二分类问题,即输出只为 0 或 1 的分类问题。

        多分类问题:表示分类任务有多个类别。比如我们需要建立一个分类器,用于分辨一堆水果图片中,哪些是橘子、哪些是苹果还有哪些是香蕉。

        在二分类问题中,我们可以使用 max(代码描述为 if a > b return a; else b) 来判断结果,就是非黑即白。但是在多分类问题中,我们就不能这样做。我们需要引入 Softmax 的概念

🌐Softmax

        在机器学习尤其是深度学习中,Softmax 是个非常常用的函数,尤其在多分类的场景中使用广泛。Softmax把输入映射为 0-1 之间的实数,并且通过归一化保证和为 1

        在多分类问题中,我们需要分类器输出每种分类的概率,且为了能够比较概率之间的大小,我们还希望概率之和能够为 1。因此,我们就需要使用 Softmax 函数。 特别是在利用神经网络解决多分类问题时,我们一般都会将输出的最后一层,加上 Softmax 函数,用于规则化输出

假设一个数组为 V,v_i表示 V 中的第 i 个元素,那么这个元素经历了 Softmax 函数后的输出为:

S_i = \frac{e^{v_i}}{\sum_{i=0}^Ne^{v_i}}

        这个定义其实很简单,也就是对输入的值进行了指数化,我认为这里进行指数化的目的是为了扩大任意两个输入之间的差距。将指数化后的值除以总的值,目的是将所有的值缩放到 0-1 之间,并保证所有的输出值相加的和为 1 。

我们可以先利用 NumPy 对其进行实现:

import numpy as np


def softmax(x):
    return np.exp(x) / np.sum(np.exp(x), axis=0)    # 按行相加


x = np.array([2.0, 1.0, 0.1])
outputs = softmax(x)

print('numpy 版 softmax 的输入 :', x)
print('numpy 版 softmax 的输出 :', outputs)
print('numpy 版 softmax 的输出之和:', outputs.sum())

        当然,我们也可以使用 PyTorch 中自带的 Softmax 函数,完成数据的处理:

import torch
import torch.nn as nn
x = torch.tensor([2.0, 1.0, 0.1])
outputs = torch.softmax(x, dim=0)  # dim=0,表示处理的是第 1 维的数据
print('torch 版 softmax 的输入 :', x)
print('torch 版 softmax 的输出 :', outputs)
print('torch 版 softmax 的输出之和:', outputs.sum())

🌐损失函数

        损失函数反映的是预测结果和实际结果之间的差距,即从预测结果到实际结果需要走的距离,即所需要消耗的成本,故称之为损失函数。

        这里让我们介绍一种常用的损失函数:交叉熵损失。那么为什么我们需要使用交叉熵损失作为我们的损失函数呢?为什么我们不直接用错误率来作为损失函数,用梯度下降算法得到错误率最低时的模型呢?为了能够更好的阐述上面这个问题,让我们来举个例子。

葡萄酒的种类预测

        我们希望通过葡萄酒的酒精浓度 、苹果酸浓度 、灰分浓度等独立特征,来预测该葡萄酒源产地。假设数据集中有三种源产地:英国、法国和美国。

        这里我们建立了两个模型用以预测葡萄酒的种类。每个模型都会输出三个值,即输入的葡萄酒来源于英国、法国和美国的概率。

        这里我们对两个模型输入了三条相同的数据,得到的结果如下:

模型 1 :

        从结果可以看出,模型 1 对于样本 1 和样本 2 以非常微弱的优势(概率只比其他结果高 0.1)判断正确,对于样本 3 的判断则彻底错误。

模型 2 :

        模型 2 对于样本 1 和样本 2 的判断非常准确(概率比其他结果高很多)。模型 2 对于样本 3 的判断错误,但是相对来说没有错得太离谱(概率只比其他结果高 0.1)。好了,有了模型之后,我们需要通过定义损失函数来判断模型在样本上的表现,那么我们可以定义哪些损失函数呢?

        如果使用简单的分类错误率作为损失函数,那么两个模型的分类错误率为:

模型 1:

classification\_error = \frac{1}{3}

模型 2:

classification\_error = \frac{1}{3}

        从结果可以看出,如果使用分类错误率来衡量两个模型的好坏,那么这两个模型的好坏程度相同。我们从上面的结果可以看出,虽然模型 1 和模型 2 都预测错了 1 个,但是相对来说,模型 2 的预测效果更好,损失函数照理来说应该更小。因此,我们使用分类错误率不能很好的描述模型的优劣。

        为此,我们引入了交叉熵损失函数用以描述模型的优劣。

🌐交叉熵损失函数

交叉熵损失函数有两种形式:二分类形式多分类形式

  • 🍒二分类的任务 

在 二分类的任务 中,模型最后需要预测的结果只有两种情况,对于每个类别,我们的预测得到的概率为 p 和 1−p 。此时的交叉熵损失(又叫二进制交叉熵)为:

L = \frac{1}{N}\sum_iL_i=\frac{1}{N}\sum_i-[y_i\cdot log(p_i)+(1-y_i)\cdot log(1-p_i)]

其中:

  • y_i:表示样本 i 的真实标签值,正类为 1,负类为 0。
  • p_i:表示样本 i 的预测为正的概率。

        当然,我们不必手动实现上面的损失函数。我们可以利用 nn.BCELoss() 定义二值交叉熵损失。如下:

loss = nn.BCELoss()
# 假设两个模型最后的预测结果相同,但是概率不同
model_one_pred = torch.tensor([0.9, 0.6])
model_two_pred = torch.tensor([0.6, 0.9])
# 真实结果
target = torch.FloatTensor([1, 0])

# 计算两种模型的损失
l1 = loss(model_one_pred, target)
l2 = loss(model_two_pred, target)
l1, l2
# (tensor(0.5108), tensor(1.4067))

        从上面代码中可以看出,其实模型 1 和模型 2 都预测对了一条数据,预测错了一条数据。但是,由于模型 1 是在概率差距很大的情况下,预测正确的。因此,模型 1 的预测效果比模型 2 好,即模型 1 的损失应当比模型 2 的损失小。

  • 🍒 多分类任务 

        在 多分类任务 中,交叉熵损失的函数形式会发生一定的改变(其实就是二进制交叉熵损失的扩展):

L=\frac{1}{M}\sum_iL_i=-\frac{1}{M}\sum_{c=1}^My_{ic}log(p_{ic})

其中:

  • M: 类别的数量。
  • y_{ic}​: 标签的 one-hot 编码,如果该类别和样本 i 的类别相同,就是 1,否则为 0。
  • p_{ic}​: 对于观察样本 i 属于类别 c 的预测概率。

        利用 PyTorch 中的 nn.CrossEntropyLoss() 定义多分类任务的交叉熵损失函数。但是,实际上 nn.CrossEntropyLoss() 是包含了 nn.LogSoftmax()  nn.NLLLoss() 。 因为这里已经是概率值了,所以我们使用 nn.NLLLoss() 来计算交叉熵损失。

        接下来让我们利用交叉熵损失来评估一下上面建立的两个葡萄酒预测模型的好坏:

loss = nn.NLLLoss()
# 三条数据的真实结果:法国、美国、英国
Y = torch.tensor([2, 1, 0])

# 模型一对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
model_one_pred = torch.tensor(
    [[0.3, 0.3, 0.4],  # predict class 2
     [0.3, 0.4, 0.3],  # predict class 1
     [0.1, 0.2, 0.7]])  # predict class 0

# 模型 2 对每条数据的预测,每条数据对应三个概率,表示该条数据属于第 i 类的概率值
model_two_pred = torch.tensor(
    [[0.1, 0.2, 0.7],  # predict class 2
     [0.1, 0.7, 0.2],  # predict class 1
     [0.4, 0.3, 0.3]])  # predict class 0
l1 = loss(torch.log(model_one_pred), Y)
l2 = loss(torch.log(model_two_pred), Y)
l1, l2
# (tensor(1.3784), tensor(0.5432))

        从上面的损失大小可以看出,通过交叉熵损失对模型进行评估的话,模型 2 的效果要优于模型 1 的效果,这是符合我们的客观想法的。

        综上,这就是我们为什么使用交叉熵损失的原因。交叉熵损失函数除了考虑模型的准确率之外,还将模型的鲁棒性等因素考虑了进去,能够更好的评价模型的好坏,使训练出来的模型具有更加稳定的预测准确率。

        从上面的输入可以看出,交叉熵损失需要的的输入每条数据所属种类的概率,且这些概率之和为 1。说到这里,我想你已经联想到了本实验开始学到的 Softmax 函数了吧。

        实际上,我们一般无法严格的将神经网络的输出控制在 0 - 1 之间,更无法使这些值之和等于 1。因此我们一般会在神经网络的最后一层,加上 softmax 函数,得到每种种类的概率值。然后将概率值放入交叉熵损失函数之中,得到预测结果和真实结果之间的距离

实验总结📌

        本实验以葡萄酒的类别为预测的例,以一种简单的方式,理解了交叉熵损失与其他损失的不同,以及引入交叉熵损失的原因。当然,从模型的训练角度来讲,引入交叉熵损失函数也是为了能够加快模型的收敛速度。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/424274.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

低延迟流式语音识别技术在人机语音交互场景中的实践

美团语音交互部针对交互场景下的低延迟语音识别需求,提出了一种全新的低出字延迟流式语音识别方案。本方法将降低延迟问题转换成一个知识蒸馏过程,极大地简化了延迟优化的难度,仅通过一个正则项损失函数就使得模型在训练过程中自动降低出字延…

靶机精讲之Holynix

找不到ip 就设置两个网络适配器 再添加一个NAT 主机发现 nmap扫描 端口扫描 UDP扫描 服务扫描 脚本扫描 拒绝服务攻击 sql注入 枚举 web渗透 sql注入 证明有注入 sql注入语句 语句 ‘ or 11 --(空格) 目录结构像有文件包含 有报错但无法利用 调用系统…

从零开始学架构-计算高性能

一、概述 高性能是每个程序员的追求,无论做一个系统、还是写一组代码,都希望能够达到高性能的效果。而高性能又是最复杂的一环,磁盘、操作系统、CPU、内存、缓存、网络、编程语言、数据库、架构等,每个都可能影响系统的高性能&…

ChatGPT API接口使用+fine tune微调+prompt介绍

目录1 接口调用1.1 生成key1.2 接口功能1.2.1 图片生成 (image generation)1.2.2 对话(chat)1.2.3 中文纠错 (Chinese Spelling Correct)1.2.4 关键词提取 (keyword extract)1.2.5 抽取文本向量 (Embedding)1.2.6 微调 (fine tune)2 如何写好prompt2.1分类任务2.2 归…

工业智能网关应用场景:高层楼宇智慧消防解决方案

随着城市化建设的飞速发展,人员聚集与土地资源稀缺的矛盾越来越明显。为了让有限的空间满足更多人的居住需求,高层楼宇越来越多,对于安全消防形成更大的挑战。 基于物联网和云计算平台的智慧消防在消防管理、火灾报警和实时监管方面发挥越来…

java内部类入门(接口)

我有一个玩具狗,有一个接口用于启动它,按照传统方法就是写一个类并实现该接口,且该类只使用一次(在启动时使用,后面再不使用) 但是如果我有一堆玩具,我每个玩具都要去写一个类来实现start这个接…

GPT-3.5还没研究明白,GPT-4又来了,chatGPT会进化成什么样?

基于GPT-3.5的chatGPT热度才稍稍减退没多久,GPT-4又来了,文新一言的发布会也槽点满满,差距似乎越来越大了。 chatGPT到底厉害在哪?为什么突然就爆火了呢? 它的爆火,一方面,和它的出现形态有关…

代码随想录第18天 | 530.二叉搜索树的最小绝对差 501.二叉搜索树中的众数 236. 二叉树的最近公共祖先

530.二叉搜索树的最小绝对差 var getMinimumDifference function (root) {//中序遍历法:左中右let res []if (!root) return res;const st [root] //栈,pop(),push()while (st.length) {let x st.pop()if (!x) {res.push(st.pop().val)continue}if (…

Linux环境下搭建composer私服及memory_limit问题

Composer是 PHP项目中用来管理依赖(dependency)关系的工具,允许声明项目所依赖的代码库 ,然后在项目的某个目录中(默认是vendor目录) 中安装相关的依赖包。 在介绍如何安装私服之前,我们先熟悉下 composer 相关 compo…

对话框与子窗口控件(写给大忙人看的快速复习掌握)

对话框与子窗口控件(写给大忙人看的快速复习掌握)1、对话框的概念2、控件的概念我更喜欢称控件为预定义的窗口类3、我们一步一步写代码熟悉常用的预定义的窗口类3.1 什么叫模板呢?3.2 什么是资源文件4、消息处理函数(有这么几个消…

护眼灯哪些牌子好?2023护眼灯品牌推荐

护眼灯就是保护眼睛的,很多人长时间工作和学习,主要还是光的刺激和错误的坐姿,会引起眼睛的近视,导致视觉疲劳的主要原因就是灯光的频闪,而护眼灯就能很好减少频闪。 特别是青少年们的视力发育为成熟,视力…

使用Sentieon加速甲基化WGBS数据分析

全基因组甲基化测序(WGBS)是一种研究DNA甲基化的方法,以全面了解在基因组水平上的表观遗传变化。在进行WGBS数据分析时,通常需要使用专门的比对工具,因为这些工具需要能够处理亚硫酸盐转化后的数据。 以下是四个不同的WGBS比对分析流程&…

ADIDAS阿里纳斯励志广告语

系列文章目录 精选优美英文短文1——Dear Basketball(亲爱的篮球)精选优美英文短文2——Here’s to the Crazy Ones(致疯狂的人)“我祝你不幸并痛苦”——约翰罗伯茨毕业致辞“亲爱的波特兰——CJ麦科勒姆告别信” Hi, I’m Gilb…

七、Django进阶:第三方库Django-extensions的开发使用技巧详解(附源码)

Django-extensions是 Django 的扩展应用,给django开发者提供了许多便捷的扩展工具(extensions),它提供了许多有用的工具和命令行工具,帮助 Django 开发者更高效地进行开发和调试。它的作用包括: - 提供了更多的Django命令&#x…

循环依赖详解及解决方案

介绍 上图就是循环依赖的三种情况,虽然方式不同,但是循环依赖的本质是一样的,就A的完整创建要依赖与B,B的完整创建要依赖于A,相互依赖导致没办法完整创建造成失败. 循环依赖代码演示 public class Demo {public static void main(String[] args) {new Demo1();} }class Demo1…

电子信息工程有哪些SCI期刊推荐? - 易智编译EaseEditing

以下是电子信息工程领域的一些SCI期刊推荐: IEEE Transactions on Information Theory: 该期刊由IEEE出版,专注于信息理论领域的研究,包括编码理论、信道编码、信息传输、信息论应用等方面的研究。 IEEE Transactions on Signal…

Apache网页与安全优化

系列文章目录 文章目录系列文章目录一、1.构建虚拟web主机2.一、基于域名的虚拟主机二、Apache 日志分割1.三、Apache的网页优化总结一、 1.构建虚拟web主机 虚拟Web主机指的是在同一台服务器中运行多个Web站点,其中每一个站点实际上并不独立占用整个服务器&#…

天选姬 - 桌面宠物

天选姬 - 桌面宠物前言下载使用更新设置右键菜单人机交互系统状态闹钟壁纸前言 桌面宠物顾名思义指在电脑桌面的宠物,可以是各种动物或Q版人物。可以进行交互并拥有各种各样的功能,本文介绍一款适用于各种电脑的桌面宠物,天选姬,…

【Redis-面试题及持久化方案】Redis相关面试题(缓存穿透、缓存击穿、缓存血崩) Redis两种持久化方案详情对比(RDB、AOF)

【Redis-面试题及持久化方案】Redis相关面试题(缓存穿透、缓存击穿、缓存血崩) & Redis两种持久化方案详情对比(RDB、AOF)1)Redis 面试题1.1.高频面试题:缓存穿透、缓存击穿、缓存血崩1.2.低频面试题&a…

电脑0X000000D1蓝屏错误U盘重新安装系统教学

电脑0X000000D1蓝屏错误U盘重新安装系统教学。最近有用户遇到了电脑桌面变成了0X000000D1错误代码的蓝屏界面了,无法继续操作使用。那么这个问题怎么去进行系统U盘重装呢?来看看以下的详细解决方法吧。 准备工作: 1、U盘一个(尽量…