PyTorch中的交叉熵函数 CrossEntropyLoss的计算过程

news2025/1/11 20:43:28

CrossEntropyLoss() 函数联合调用了 nn.LogSoftmax() 和 nn.NLLLoss()。

关于交叉熵函数的公式详见:
交叉熵损失函数原理详解
CrossEntropyLoss() 函数的计算过程可以拆解为如下四个步骤:
1、对输出的结果进行softmax操作,因为softmax操作可以将所有输入值都归为[0,1]之间,且所有值之和为1,符合概率分布的特性。

2、对softmax结果进行log运算,求出都是小于0的值
3、对真实概率值进行one-hot编码
4、利用下面的公式求出最终的loss值
C r o s s E n t r o p y L o s s ( x ) = − ∑ i = 1 n O n e H o t ( t a r g e t i ) ∗ l o g s o f t m a x ( i n p u t ) i CrossEntropyLoss(x) = - \sum_{i=1}^{n} OneHot(target_i) * log^{softmax(input)_i} CrossEntropyLoss(x)=i=1nOneHot(targeti)logsoftmax(input)i
不难看出NLLloss+log+softmax就是CrossEntropyLoss(softmax版的交叉熵损失函数),而其中的NLLloss就是在做交叉熵损失函数的最后一步:预测结果的取负求和。

一段代码带你上高速:

原生手写CrossEntropyLoss()函数与PyTorch里的CrossEntropyLoss():

import torch
import torch.nn as nn
import torch.nn.functional as F


def softmax(x):
    r"""
    这里的公式是:

    .. math::
        softmax(x_i) = \frac{\exp^{x_i}}{\sum_{i=0}^N (\exp^{x_i})}
    :param x:
    :return:
    """

    # return torch.exp(x) / torch.unsqueeze(torch.sum(torch.exp(x), dim=1), dim=1)
    t = torch.zeros_like(x)
    sum_array = []
    for idx in range(len(x)):
        t[idx] = torch.exp(x[idx]) / torch.sum(torch.exp(x[idx]))
        sum_array.append(torch.sum(torch.exp(x[idx])))
    print(sum_array)
    print(torch.unsqueeze(torch.sum(torch.exp(x), dim=1), dim=1))
    return t

def manual_cross_entropy_loss(inputs, target):
    one_hot_target = F.one_hot(target, 10)
    print(one_hot_target)
    # print(softmax_)
    # print(F.softmax(inputs, dim=-1))
    # cross_entropy_loss = -torch.sum(one_hot_target * torch.log(F.softmax(inputs, dim=-1))) / len(inputs)
    cross_entropy_loss = -torch.sum(one_hot_target * torch.log(softmax(inputs))) / len(inputs)
    print(cross_entropy_loss)


def call_cross_entropy_loss(inputs, target):
    """
    调用真实的CrossEntropyLoss()
    :return:
    """
    loss = nn.CrossEntropyLoss()
    loss = loss(inputs, target)
    print(loss)


if __name__ == '__main__':
    torch.random.manual_seed(0)
    inputs = torch.randn(3, 10)
    # 获取inputs的最大值的索引
    target = torch.argmax(inputs, dim=-1)
    print(target)
    manual_cross_entropy_loss(inputs=inputs, target=target)
    call_cross_entropy_loss(inputs=inputs, target=target)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/475497.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【Java基础教程】初识Java

作者简介: 辭七七,目前大一,正在学习C/C,Java,Python等 作者主页: 七七的个人主页 **文章收录专栏:Java.SE,本专栏主要讲解运算符,程序逻辑控制,方法的使用&a…

Java实现数据压缩所有方式性能测试

目录 1 BZip方式1.1 引入依赖1.2 BZip工具类代码1.3 BZip2工具类代码 2 Deflater方式3 Gzip方式4 Lz4方式4.1 简介4.2 算法思想4.3 算法实现4.3.1 lz4数据格式2、lz4压缩过程3、lz4解压过程 4.4 Lz4-Java4.4.1 简介4.4.2 类库 5 SevenZ方式5.1 引入依赖5.2 工具类代码 6 Zip方式…

C++(继承和组合)

继承:public继承是一种 is-a 的关系,也就是每一个派生类对象都有一个基类对象 这些关系都适合用继承来表达 ----> 继承了之后父类的成员就变成了子类的一部分,子类对象可以直接用 组合: 是一种has -a(有一个&…

GraphSAGE聚合流程计算实例

本篇中我们只讨论聚合流程,不考虑GraphSAGE的小批量训练等内容。 我们先来看一下GraphSAGE的聚合流程伪代码,之后会给出两个具体的计算例子进行说明: 11行中, N ( k ) ( u ) N^{(k)}(u) N(k)(u)表示节点u的邻居节点采样函数&…

力扣杯2023春·个人赛

文章目录 力扣杯2023春-个人赛[LCP 72. 补给马车](https://leetcode.cn/problems/hqCnmP/)模拟 [LCP 73. 探险营地](https://leetcode.cn/problems/0Zeoeg/)模拟 哈希 [LCP 74. 最强祝福力场](https://leetcode.cn/problems/xepqZ5/)二维差分 离散化扫描线 [LCP 75. 传送卷轴…

CANOE入门到精通——CANOE系列教程记录1 第一个仿真工程

本系列以初学者角度记录学习CANOE,以《CANoe开发从入门到精通》参考学习,CANoe16 demo版就可以进行学习 概念 CANoe是一种用于开发、测试和分析汽车电子系统的软件工具。它通过在不同层次上模拟汽车电子系统中的不同部件,如ECU、总线和传感…

自动化运维工具Ansible之playbook剧本

目录 一、playbook 1、playbook简述 2、playbook剧本格式 3、playbook组成部分 4、playbook启动及检测 5、playbook模块实战实例1 6、vars模块实战实例2 7、when模块实战实例3 8、with_items循环模块实战实例4 9、template模块实战实例5 10、tags模块实战实例6 一、…

VM中kali虚拟机创建docker部署WebGoat

这里选择在docker中配置(因为方便) 首先下载docker sudo apt-get install docker.io 然后从Docker Hub下载WebGoat 8.0 的docker镜像 使用命令 docker pull webgoat/webgoat-8.0 完成后查看现在kali虚拟机中的docker镜像列表 输入命令 docker images …

0704一阶线性微分方程-微分方程

文章目录 1 线性方程1.1 定义1.2 解法(常数变易法)1.3 例题 2伯努利方程3 简单变量替换解方程结语 1 线性方程 1.1 定义 一阶微分方程:形式上能化成 d y d x P ( x ) y Q ( x ) \frac{dy}{dx}P(x)yQ(x) dxdy​P(x)yQ(x)的方程,…

树莓派CSI摄像头使用python调用opencv库函数进行运动检测识别

目录 一、完成摄像头的调用 二、利用python调用opencv库函数对图像进行处理 2.1 图像处理大体流程 2.2 opencv调用函数的参数以及含义 2.2.1 ret, img cap.read() 读取帧图像 2.2.2 cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 灰度图像 2.2.3 gray_diff_img cv2.absdiff(g…

详解子网划分练习题(32道)

目录 1 子网划分概念: 2 划分方法: 子网划分方法:段,块,数的计算三步。 段就是确定ip地址段中既有网络地址,又有主机地址的那一段是四段中的那一段? 块就确定上一步中确定的那一段中的主机…

【Linux】网络配置详细步骤及其相关基础知识介绍

一、Linux网络配置步骤 1、登录root账户 进行网络配置需要使用root权限,因此需要先登录root用户 2、输入ip addr查看网络信息 只有一个本机地址127.0.0.1,因为Linux操作系统的网卡开关还没有打开。 3、输入cd /etc/sysconfig/network-scripts/进入目录…

R语言 | 列表

目录 一、建立列表 1.1 建立列表对象——对象元素不含名称 1.2 建立列表对象——对象元素含名称 1.3 处理列表内对象的元素名称 1.4 获得列表的对象元素个数 二、获取列表内对象的元素内容 2.1 使用"$"符号取得列表对象的元素内容 2.2 使用"[[ ]]"符…

关于GeoServer发布服务时数据源设置的避坑指南

题外话 时光任然,一年一度的五一劳动节已然来到。作为疫情之后迎来的第一个五一,不知道各位小伙伴们怎么度过这个劳动节呢?是决定去另一个城市,观察体验一下不一样的风景,或者去旅游,给自己放假。昨天被123…

three.js进阶之动画系统

我曾在three.js进阶之骨骼绑定文章中提到了AnimationMixer、AnimationAction等内容,其实这些应该属于Three.js的动画系统,本文就系统的介绍一下动画系统(Animation System)。 前言 一般情况下,我们很少会使用three.j…

【学习视频】阅读开源工业软件和工业智能实战上线B站

图片来源:https://metrology.news/a-i-for-smarter-factories-the-world-of-industrial-artificial-intelligence/ 为了帮助大家做好工业软件以及用人工智能解决工业领域现实问题,我在B站上开了两个视频系列,一个是“一起来读开源工业软件”…

STM32 基础知识入门 (C语言基础巩固)

1、在不改变其他位的值的状况下,对某几个位进行设值 这个场景在单片机开发中经常使用,方法就是先对需要设置的位用&操作符进行清零操作, 然后用|操作符设值。 比如我要改变 GPIOA 的 CRL 寄存器 bit6(第 6 位)的…

MiNiGPT4安装记录

装conda wget https://repo.anaconda.com/archive/Anaconda3-5.3.0-Linux-x86_64.sh chmod x Anaconda3-5.3.0-Linux-x86_64.sh ./Anaconda3-5.3.0-Linux-x86_64.sh export PATH~/anaconda3/bin:$PATH # 或者写到环境保护变量 # 不会弄看这吧 https://blog.csdn.net/wyf2017/a…

fork()创建进程原理

目录 一、写时复制技术写时复制的优点:vfork()和fork() 二、fork()原理初步再理解下页表与多进程在内存中的图像创建进程和创建线程的区别 三、fork()的具体过程 一、写时复制技术 fork()生成子进程时,只是把虚拟地址拷贝给子进程,也就是父进…

( 字符串) 205. 同构字符串 ——【Leetcode每日一题】

❓205. 同构字符串 难度:简单 给定两个字符串 s 和 t ,判断它们是否是同构的。 如果 s 中的字符可以按某种映射关系替换得到 t ,那么这两个字符串是同构的。 每个出现的字符都应当映射到另一个字符,同时不改变字符的顺序。不同…