多分类交叉熵理解

news2025/1/16 17:54:47

多分类交叉熵有多种不同的表示形式,如下图所示:

 但是,有时候我们读论文会深陷其中不能自拔。

也有很多读者、观众会纠正其他作者的文章、视频的交叉熵形式。

实际上,上述三种形式都是没有问题的。

这里,我们就要了解交叉熵的本质。交叉熵,可以用一句话来形容:

正确标签的自然对数。

 我们看个例子:

具体情况如下:

——三幅图像,其分别对应着数字3、数字6、数字8

如果将其标识位one-hot形式,则其对应的值为:

也就是对应位置上的值为1,其余位置为0. 

——三幅图像的识别结果为y_pred如下图所示:

计算交叉熵,就是计算 【正确标签的自然对数】

具体来说,计算的是

-t*log(y_pred)

 当然,根据需要后续需要计算和和均值。

所以,我们通常看到的表达式是:

或者:

 

都是可以的,因为代入数据,计算结果是一致的。

下面,我们分别通过三种不同的方式计算交叉熵损失函数

  • 方式0:手动计算
  • 方式1:调用sklearn函数计算
  • 方式2:自定义函数计算 

全部代码如下:

# -*- coding: utf-8 -*-
"""
Created on Wed Nov  9 12:22:39 2022

@author: Administrator
"""
import numpy as np
from sklearn.metrics import log_loss
from sklearn.preprocessing import LabelBinarizer
from math import log


# 三个样本对应的标签
# 这里指三幅图像分别对应着数值3、数字6、数字8
y_true = ['3', '6', '8']  
# 预测值,这里使用softmax处理过了。所有概率和为0
# 例如:第一行对应着第一张图像的识别结果:0.1+0.3+0.6+一堆0=1
# 第2行:0.2+0.5+0.3+一堆0=1
# 第3行:0.3+0.5+0.2+一堆0=1
y_pred = [[0.1, 0, 0.3, 0.6, 0, 0, 0, 0, 0, 0],
          [0, 0, 0.2, 0, 0, 0, 0.5, 0, 0.3, 0],
          [0, 0.3, 0, 0, 0.5, 0, 0, 0, 0.2, 0]
          ]               
# 标签对应的值(识别数字)
labels = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']  


# =============================================================================
# 方法0:手算
# =============================================================================
# 真实值one-hot编码
t = [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
     [0, 0, 0, 0, 0, 0, 0, 0, 1, 0]]
y = np.array(y_pred)               # 样本预测值
# 面临问题:log(0)为负无穷大,导致后续无法计算。
# 解决方案:将log(0)处理为一个log(接近零)
y[y < 1e-15] = 1e-15
Loss = 0
for i in range(3):  # 逐个遍历样本
    for k in range(10):  # 逐个遍历标签
        # 计算对应位置上的真实值*log(预测值)
        Loss -= t[i][k] * log(y[i][k])
Loss /= 3
print("自定义的交叉熵:", Loss)



# =============================================================================
# 不用循环的可以改进计算:
# a = -np.multiply(t, np.log(y))
# print("自定义的交叉熵:",sum(map(sum, a))/3)
# =============================================================================

# =============================================================================
# 逐个处理也可以
#
# for i in range(3):  # 逐个遍历样本
#     for k in range(10):  # 逐个遍历标签
#         delta = 1e-15      # 控制值
#         if y[i][k] < delta:
#             y[i][k] = delta
# =============================================================================


# =============================================================================
# 方法1:使用sklearn计算
# =============================================================================
# 说明:
# 传递给sklearn的y_true = ['3', '6', '8']
# 会根据参数【labels】被识别为one-hot形式。
# [[0 0 0 1 0 0 0 0 0 0]
#  [0 0 0 0 0 0 1 0 0 0]
#  [0 0 0 0 0 0 0 0 1 0]]

sk_log_loss = log_loss(y_true, y_pred, labels=labels)
print("sklearn交叉熵:", sk_log_loss)


# =============================================================================
# 方法2:自定义函数形式
# =============================================================================
def lilizong():
    # 将样本标签处理为one-hot形式
    lb = LabelBinarizer()
    lb.fit(labels)
    transformed_labels = lb.transform(y_true)
    # transformed_labels值为:
    # [[0 0 0 1 0 0 0 0 0 0]
    #  [0 0 0 0 0 0 1 0 0 0]
    #  [0 0 0 0 0 0 0 0 1 0]]
    # 计算样本个数、标签个数
    sn = len(y_true)  # 样本个数
    ln = len(labels)  # 标签个数
    # 初始化值
    # log(0)为无穷大,这样一来,后续无法计算
    # 保护性对策:添加一个极小值δ,防止负无限大的发生
    delta = 1e-15      # 控制值
    Loss = 0         # 损失值初始化
    # 循环遍历
    for i in range(sn):  # 逐个遍历样本
        for k in range(ln):  # 逐个遍历标签
            if y_pred[i][k] < delta:
                y_pred[i][k] = delta
            if y_pred[i][k] > 1-delta:
                y_pred[i][k] = 1-delta
            # 计算对应位置上的真实值*log(预测值)
            Loss -= transformed_labels[i][k]*log(y_pred[i][k])
    Loss /= sn
    return Loss
#  调用自定义函数
print("自定义的交叉熵:", lilizong())


# =============================================================================
# 参考资料1:sklearn官网说明
# =============================================================================
# https://scikit-learn.org/stable/modules/model_evaluation.html#log-loss

# =============================================================================
# 参考资料2:sklearn中log_loss源代码
# =============================================================================
# https://github.com/scikit-learn/scikit-learn/blob/ed5e127b/sklearn/metrics/classification.py#L1576


# def log_loss(y_true, y_pred, eps=1e-15, normalize=True, sample_weight=None,
#              labels=None):
#     """Log loss, aka logistic loss or cross-entropy loss.
#     This is the loss function used in (multinomial) logistic regression
#     and extensions of it such as neural networks, defined as the negative
#     log-likelihood of the true labels given a probabilistic classifier's
#     predictions. The log loss is only defined for two or more labels.
#     For a single sample with true label yt in {0,1} and
#     estimated probability yp that yt = 1, the log loss is
#         -log P(yt|yp) = -(yt log(yp) + (1 - yt) log(1 - yp))
#     Read more in the :ref:`User Guide <log_loss>`.
#     Parameters
#     ----------
#     y_true : array-like or label indicator matrix
#         Ground truth (correct) labels for n_samples samples.
#     y_pred : array-like of float, shape = (n_samples, n_classes) or (n_samples,)
#         Predicted probabilities, as returned by a classifier's
#         predict_proba method. If ``y_pred.shape = (n_samples,)``
#         the probabilities provided are assumed to be that of the
#         positive class. The labels in ``y_pred`` are assumed to be
#         ordered alphabetically, as done by
#         :class:`preprocessing.LabelBinarizer`.
#     eps : float
#         Log loss is undefined for p=0 or p=1, so probabilities are
#         clipped to max(eps, min(1 - eps, p)).
#     normalize : bool, optional (default=True)
#         If true, return the mean loss per sample.
#         Otherwise, return the sum of the per-sample losses.
#     sample_weight : array-like of shape = [n_samples], optional
#         Sample weights.
#     labels : array-like, optional (default=None)
#         If not provided, labels will be inferred from y_true. If ``labels``
#         is ``None`` and ``y_pred`` has shape (n_samples,) the labels are
#         assumed to be binary and are inferred from ``y_true``.
#         .. versionadded:: 0.18
#     Returns
#     -------
#     loss : float
#     Examples
#     --------
#     >>> log_loss(["spam", "ham", "ham", "spam"],  # doctest: +ELLIPSIS
#     ...          [[.1, .9], [.9, .1], [.8, .2], [.35, .65]])
#     0.21616...
#     References
#     ----------
#     C.M. Bishop (2006). Pattern Recognition and Machine Learning. Springer,
#     p. 209.
#     Notes
#     -----
#     The logarithm used is the natural logarithm (base-e).
#     """
#     y_pred = check_array(y_pred, ensure_2d=False)
#     check_consistent_length(y_pred, y_true)

#     lb = LabelBinarizer()

#     if labels is not None:
#         lb.fit(labels)
#     else:
#         lb.fit(y_true)

#     if len(lb.classes_) == 1:
#         if labels is None:
#             raise ValueError('y_true contains only one label ({0}). Please '
#                              'provide the true labels explicitly through the '
#                              'labels argument.'.format(lb.classes_[0]))
#         else:
#             raise ValueError('The labels array needs to contain at least two '
#                              'labels for log_loss, '
#                              'got {0}.'.format(lb.classes_))

#     transformed_labels = lb.transform(y_true)

#     if transformed_labels.shape[1] == 1:
#         transformed_labels = np.append(1 - transformed_labels,
#                                        transformed_labels, axis=1)

#     # Clipping
#     y_pred = np.clip(y_pred, eps, 1 - eps)

#     # If y_pred is of single dimension, assume y_true to be binary
#     # and then check.
#     if y_pred.ndim == 1:
#         y_pred = y_pred[:, np.newaxis]
#     if y_pred.shape[1] == 1:
#         y_pred = np.append(1 - y_pred, y_pred, axis=1)

#     # Check if dimensions are consistent.
#     transformed_labels = check_array(transformed_labels)
#     if len(lb.classes_) != y_pred.shape[1]:
#         if labels is None:
#             raise ValueError("y_true and y_pred contain different number of "
#                              "classes {0}, {1}. Please provide the true "
#                              "labels explicitly through the labels argument. "
#                              "Classes found in "
#                              "y_true: {2}".format(transformed_labels.shape[1],
#                                                   y_pred.shape[1],
#                                                   lb.classes_))
#         else:
#             raise ValueError('The number of classes in labels is different '
#                              'from that in y_pred. Classes found in '
#                              'labels: {0}'.format(lb.classes_))

#     # Renormalize
#     y_pred /= y_pred.sum(axis=1)[:, np.newaxis]
#     loss = -(transformed_labels * np.log(y_pred)).sum(axis=1)

#     return _weighted_sum(loss, sample_weight, normalize)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/12333.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

多媒体内容理解在美图社区的应用实践

导读&#xff1a;移动互联网时代&#xff0c;图像和短视频等多媒体内容爆发&#xff0c;基于计算机视觉的AI算法是多媒体内容分析的基础。在美图社区智能化发展的过程中&#xff0c;视频和图像分类打标、去重以及质量评估的结果&#xff0c;在推荐、搜索以及人工审核等多个场景…

【R语言数据科学】:变量选择(三)主成分回归和偏最小二乘回归

变量选择(三)主成分回归和偏最小二乘回归 🌸个人主页:JOJO数据科学📝个人介绍:统计学top3高校统计学硕士在读💌如果文章对你有帮助,欢迎✌关注、👍点赞、✌收藏、👍订阅专栏✨本文收录于【R语言数据科学】本系列主要介绍R语言在数据科学领域的应用包括: R语言编…

多分类问题的precision和recall以及F1 scores的计算

对于多分类问题&#xff0c;首先&#xff0c;对于每一个类的精准率&#xff08;Precision&#xff09;和召回率&#xff08;Recall&#xff09;&#xff0c;定义和二分类问题一致&#xff0c;但是计算上不再需要TP,FP,FN等量了&#xff1a;&#xff09; 比如对A, B, C三类有如…

SpringBoot中如何集成ThymeLeaf呢?

转自: SpringBoot中如何集成ThymeLeaf呢&#xff1f; 下文笔者将讲述SpringBoot集成ThymeLeaf的方法&#xff0c;如下所示: 实现思路:1.在pom.xml中引入ThymeLeaf的相关依赖2.在Templates文件夹下编写相应的模板文件例: 1.pom.xml 添加ThymeLeaf依赖<!-- ThymeLeaf 依赖…

河南某商务楼BA系统设计

目 录 第一章 概述 3 第二章 设计任务与要求 4 第三章 设计依据和规范 4 第四章 系统设计 5 4.1系统选型 5 4.2 I/O点位设计 7 4.2.1暖通空调系统 11 4.2.2给排水系统 13 4.2.3电气系统 15 4.3线缆选型设计 17 4.4供电接地设计 17 4.5中央控制室设计 18 第五章 设备清单配置 18…

ASEMI代理力特LSIC1MO120E0080碳化硅MOSFET

编辑-Z 力特碳化硅MOS管LSIC1MO120E0080参数&#xff1a; 型号&#xff1a;LSIC1MO120E0080 漏极-源极电压&#xff08;VDS&#xff09;&#xff1a;1200V 连续漏电流&#xff08;ID&#xff09;&#xff1a;25A 功耗&#xff08;PD&#xff09;&#xff1a;214W 工作结温…

mysql数据库日志

1、日志类型 mysql日志在mysql事务章有事务日志相关的记录。初次之外&#xff0c;MySQL有不同类型的日志文件&#xff0c;用来存储不同类型的日志&#xff0c;分为二进制日志 、 错误日志 通用查询日志和 查询日志 &#xff0c;这也是常用的4种。MySQL 8又新增两种支持的日志&…

关于HTTPDNS,你知道多少?

导读&#xff1a; 全网域名劫持率高&#xff0c;域名解析失败、解析超时&#xff0c;IP调度不精准&#xff0c;域名解析变更生效不实时&#xff0c;这些问题是否一直困扰着你&#xff1f;作为网络请求最前置的环节&#xff0c;域名解析的稳定与精准程度直接决定了APP的访问体验…

实战讲解SpringCloud网关接口限流SpringCloudGateway+Redis(图+文)

1 缘起 最近补充微服务网关相关知识&#xff0c;学习了网关相关概念&#xff0c; 了解网关在微服务中存在的意义及其使命&#xff0c;如统一用户认证、接口权限控制、接口限流、接口熔断、黑白名单机制等&#xff0c; 打算通过实践的方式逐步学习网关的相关功能&#xff0c;同…

从零到一落地接口自动化测试

前段时间写了一系列自动化测试的文章&#xff0c;更多是从方法和解决问题思路角度阐述我的观点。 昨天花了几个小时看完了陈磊老师的《接口测试入门课》&#xff0c;有一些新的收获&#xff0c;结合我自己实践自动化测试的一些经验以及个人理解&#xff0c;这篇文章来聊聊新手…

主要控制系统之间的逻辑关系

电力行业 工控安全解决思路保障框架从电力行业对工控安全需求看&#xff0c;电力企业在主要是以合规性建设为主&#xff0c;在 2004 年原电监会 5 号令颁布开始&#xff0c;大部 分的电厂控制系统安全 建设已经按照 5 号令的要求进行了整改&#xff0c;形成“安全分区、网络专…

【2022硬件设计开源盛宴】一年一度的hackaday大赛结束,冠军便携式风力涡轮机,共提交326个电子作品,奖金池15万美元

https://hackaday.com/2022/11/05/ ... -years-competition/ &#xff08;1&#xff09;一年一度的Hackaday大赛结束&#xff0c;今年是第9届了&#xff0c;总奖金池是15万美元&#xff0c;冠军5万美元。前6届&#xff0c;冠军奖金非常高&#xff0c;像第3届冠军是最厉害的&am…

java设计模式之装饰者模式

一&#xff1a;装饰者模式 1.什么是装饰者模式? 装饰模式是一种结构型设计模式&#xff0c; 允许你通过将对象放入包含行为的特殊封装对象中来为原对象绑定新的行为。 装饰者模式的基本介绍 1.装饰者模式&#xff1a;动态的将新功能附加到对象上。在对象功能扩展方面&#xf…

Jasper 中如何将数据拆成多行并跨行累计

【问题】 I have a query that returns some summary records. For instance, loan amount, loan term, interest rate. Then I want to have a second row that builds out the detailed payment schedule. so the report would look like this: Loan Amt Term …

SpringBoot2

文章目录1.简介1.1 SpringBoot优缺点1.2 官方文档结构2. SpringBoot入门2.1 HelloWord2.2 依赖管理2.3 自动配置2.4 容器功能组件添加原生配置文件引入2.5 配置绑定ConfigurationPropertiesEnableConfigurationProperties2.6 自动配置原理底层总结最佳实践2.7 开发小技巧Lombok…

UML类图简单认识

类 类图包括类、接口和关系。类中包含三元素&#xff0c;第一行是类名&#xff0c;如果是虚类则为斜体。第二行包括属性&#xff0c;如果是public则为&#xff0c;如果是private则为-&#xff0c;如果是protected则为#。第三行包括方法&#xff0c;方法前面的符号表示与属性的…

QSS的应用

盒子模型&#xff1a; margin 边距border 边框padding 内边距content 内容常用的一些属性&#xff1a; background背景background-color背景颜色background-image背景图片background-position对齐方式border-&#xff08;top、left、bottom、right&#xff09;边界border-…

单调栈问题---(每日温度,下一个更大元素Ⅰ)

代码随想录day 58 单调栈问题— 每日温度,下一个更大元素Ⅰ 文章目录1.leetcode 739. 每日温度1.1 详细思路及解题步骤1.2Java版代码示例2.leetcode 496. 下一个更大元素 I2.1 详细思路及解题步骤2.2Java版代码示例1.leetcode 739. 每日温度 1.1 详细思路及解题步骤 这题会用到…

Spark RDD编程模型及算子介绍(一)

文章目录RDD编程模型介绍RDD的两种算子及延迟计算常见的Transformation算子RDD编程模型介绍 RDD是Spark 对于分布式数据集的抽象&#xff0c;它用于囊括所有内存中和磁盘中的分布式数据实体。每一个RDD都代表着一种分布式数据形态。在RDD的编程模型中&#xff0c;一共有两种算…

Linux-服务管理

服务介绍 服务本质就是进程&#xff0c;但是是运行在后台的&#xff0c;通常都会监听某个端口&#xff0c;等待其他程序的ing求&#xff0c;比如mysqld&#xff0c;sshd&#xff0c;防火墙灯&#xff0c;因为又称为守护进程 如何管理服务 CentOS7.0前使用service命令 servi…