【交叉熵损失torch.nn.CrossEntropyLoss详解-附代码实现】

news2025/1/20 2:02:48

CrossEntropyLoss

  • 什么是交叉熵
  • softmax
  • 损失计算
  • 验证
    • CrossEntropyLoss 输入输出介绍
    • 验证代码

什么是交叉熵

交叉熵有很多文章介绍,此处不赘述。只需要知道它是可以衡量真实值和预测值之间的差距的,因而用交叉熵来计算损失的时候,损失是越小越好,它用数学公式表示是:

-P(x) log Q(x)

其中P(x)是真实值,Q(x)是预测值
当p(x)和Q(x)是矩阵的时候,就分别对其计算,然后求和即可

在pytorch中的交叉熵损失CrossEntropyLoss 包含了 两部分,softmax和交叉熵计算,下面分别介绍这两部分

softmax

一句话理解,是将预测值转成概率。通常经过神经网络计算出来的预测数据不是一个,举个例子:
比如一个二分类问题,一个输入计算出来的结果总是两个值(a, b)其中 a 表示1分类的得分,b 表示2分类的得分,多分类同样
比如一个翻译模型,每个时间步的输出是词表大小(a, b,…) 其中每个值表示词表中每个词的得分

而我们需要的是概率,不是分数,因此需要一个转换,要保证所有分类的概率和为1 softmax的做法:
在这里插入图片描述

即:exp(某分数)/所有分类的exp后的分数

损失计算

计算完softmax,就可以用文中刚开始的 -P(x) log Q(x) 计算损失了,通常情况下,我们的真实值 p(x),也就是target 通常是one-hot编码的,举个例子:
比如二分类类的时候,target通常是(0,1)(1,0)
比如翻译模型,target通常是(0,…1…0)等
我们计算的时候不难发现target中为0经过乘法都是0了,因此最后只剩下正确类型的这个损失差距 最后公式可以演变成 - log Q(x)

一句话来说,交叉熵的损失值只关注了正确分类的差距

验证

自己实现了一下softmax和cross_loss,验证下上述理论的正确性,那就要介绍下torch.nn.CrossEntropyLoss

CrossEntropyLoss 输入输出介绍

可以翻看官网介绍

CLASStorch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=- 100, reduce=None, reduction=‘mean’, label_smoothing=0.0)

reduction是指损失计算方式,默认取平均mean,同时支持none,sum ,分别表示每一个损失不做其他操作、所有损失求求和

计算是target 的shape支持直接输入具体值,或者是索引形式,举个例子:
预测值: [0.8, 0.5, 0.2, 0.5]
target可以是 [1, 0, 0, 0] 或者索引形式 0

多样本也同样:
预测值:
[[0.8, 0.5, 0.2, 0.5],
[0.2, 0.9, 0.3, 0.2],
[0.4, 0.3, 0.7, 0.1],
[0.1, 0.2, 0.4, 0.8]]
target 可以是:

  • 列表形式 torch.tensor([[1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1]], dtype=torch.float)
  • 索引形式: torch.tensor([0,1, 1, 3], dtype=torch.long)

验证代码

def soft_max(x):
    x_exp = torch.exp(x)
    partition = x_exp.sum(1, keepdim=True)
   
    # 广播partition
    return x_exp / partition
def cross_entropy(y, y_hat):
    x = y_hat[range(len(y_hat)), y]
    print("取出对应元素:", x, '真实label:', y)

    return -torch.log(x)

	y = torch.tensor([0, 2])
    y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
    y_hat_softmax = soft_max(y_hat)
    print(y_hat_softmax)
    out = cross_entropy(y, y_hat_softmax)
    print('手动计算的损失', out)

    cr_loss = torch.nn.CrossEntropyLoss(reduction="none")
    out = cr_loss(y_hat, y)
    print('公式计算的损失', out)

输出如下:

手动计算的损失 tensor([1.3533, 0.9398])
公式计算的损失 tensor([1.3533, 0.9398])

结果一致,可以验证无问题

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/997213.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaScript手撕代码】new

目录 手写 手写 /* * param {Function} fn 构造函数 * return {*} **/ function myNew(fn, ...args){if(typeof fn ! function){return new TypeError(fn must be a function)}// 先创建一个对象let obj Object.create(fn.prototype)// 通过apply让this指向obj, 并调用执行构…

SHIB去零计划:创新金融未来,打造稳定数字资产新范式

SHIB去零计划,由星火有限公司发起,以区块链去中心化手段解决信任危机,对抗垄断与不公平问题,破解经济制裁,实现稳定数字资产的快速有效、平等互利交易。星火有限公司,一家跨国运营集团,主营业务…

UIStackView入门使用两个问题

项目中横向一排元素,竖向一排元素,可以使用UIStackView。UIStackView的原理不做介绍,这里主要讲两个初次使用容易出现的两个问题。 首先创建一个stackview -(UIStackView*)titleStackView{if(_titleStackView nil){_titleStackView [UISta…

时序分解 | MATLAB实现北方苍鹰优化算法NGO优化VMD信号分量可视化

时序分解 | MATLAB实现北方苍鹰优化算法NGO优化VMD信号分量可视化 目录 时序分解 | MATLAB实现北方苍鹰优化算法NGO优化VMD信号分量可视化效果一览基本介绍程序设计参考资料 效果一览 基本介绍 北方苍鹰优化算法NGO优化VMD,对其分解层数,惩罚因子数做优化…

绝对的搜索利器

苏生不惑第450 篇原创文章,将本公众号设为星标,第一时间看最新文章。 今天分享几个文件搜索利器,下载地址在公众号苏生不惑后台回复2023909,你的小电影要藏不住了。 首先自然是Everything https://www.voidtools.com/zh-cn/&#…

python DVWAXSSPOC练习

XSS反射性低难度 数据包 GET /dv/vulnerabilities/xss_r/?name%3Cscript%3Ealert%28%27xss%27%29%3C%2Fscript%3E HTTP/1.1Host: 10.9.75.161Upgrade-Insecure-Requests: 1User-Agent: Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Ch…

数据结构与算法-BtreeB+Tree

一:引入 作为一个IT从业者大家对数据库肯定是都知道的,大家应该知道在数据库中有个索引,在一张表中用了索引与不用索引那查找效率简直就是天壤之别,但是大家有没思考过,你经常用的索引是什么样的数据结构呢&#xff1f…

Unity中Shader抓取屏幕并实现扭曲效果

文章目录 前言一、屏幕抓取,在上一篇文章已经写了二、实现抓取后的屏幕扭曲实现思路:1、屏幕扭曲要借助传入 UV 贴图进行扭曲2、传入贴图后在顶点着色器的输入参数处,传入一个 float2 uv : TEXCOORD,用于之后对扭曲贴图进行采样3、…

SAP 创建动态内表

创建动态内表 一、根据表名创建内表 程序代码: "复杂方式 SELECTION-SCREEN BEGIN OF BLOCK b1 WITH FRAME TITLE TEXT-001. PARAMETERS:p_tab TYPE string. SELECTION-SCREEN END OF BLOCK b1.DATA:lr_struct TYPE REF TO data,lr_table TYPE REF TO data. …

【云原生系列】Docker学习

目录 一、Docker常用命令 1 基础命令 2 镜像命令 2.1 docker images 查看本地主机的所有镜像 2.2 docker search 搜索镜像 2.3 docker pull 镜像名[:tag] 下载镜像 2.4 docker rmi 删除镜像 2.5 docker build 构建镜像 3 容器命令 3.1 如拉取一个centos镜像 3.2 运行…

.env文件详解

.env配置文件 vue会根据 process.env.NODE_ENV 的值,自动加载对应的环境配置文件 .env 全局默认配置文件,在所有的环境中被载入;.env.production 生产环境文件 production;.env.development 开发环境文件 development;.env.test/.env.stagi…

从零开始完整实现-循环神经网络RNN

一 简介 使用 pytorch 搭建循环神经网络RNN,循环神经网络(Recurrent Neural Network,RNN)是一类用于 处理序列数据的神经网络架构。与传统神经网络不同,RNN 具有内部循环结构,可以在处理序列数据时保持状态…

MySQL基础篇:掌握数据库基本操作,轻松上手

查看和指定现有的数据库 mysql> show databases; -------------------- | Database | -------------------- | information_schema | | bjpowernode | | eladmin | | mysql | | performance_schema | | sqlalchemy | | s…

makefile之使用函数wildcard和patsubst

Makefile之调用函数 调用makefile机制实现的一些函数 $(function arguments) : function是函数名,arguments是该函数的参数 参数和函数名用空格或Tab分隔,如果有多个参数,之间用逗号隔开. wildcard函数:让通配符在makefile文件中使用有效果 $(wildcard pattern) 输入只有一个参…

Qt串口基本设置与协议收发

前言 1.一直都想要做一个Qt上位机,趁着这个周末有时间,动手写一下 2.comboBox没有点击的信号,所以做了一个触发的功能 3.Qt的数据类型很奇怪,转来转去的我也搞得很迷糊 4.给自己挖个坑,下一期做一个查看波形的上位…

Java输入-a,-b,geek,-c,888,-d,[hello,world]字符之后,如何将[hello,world]这个不分开

Java输入-a,-b,geek,-c,888,-d,[hello,world]字符之后,如何将[hello,world]这个不分开? 你可以使用命令行参数解析库来处理Java输入中的各个参数。在这种情况下,你可以使用Apache Commons CLI库来解析命令行参数。以下是一个示例代码片段&am…

MATLAB遗传算法求解生鲜货损制冷时间窗碳排放多成本车辆路径规划问题

MATLAB遗传算法求解生鲜货损制冷时间窗碳排放多成本车辆路径规划问题实例 1、问题描述 已知配送中心和需求门店的地理位置,并且已经获得各个门店的需求量。关于送货时间的要求,门店都有规定的时间窗,对于超过规定时间窗外的配送时间会产生相应的惩罚成本。为保持生鲜农产品的…

2023.09.10 学习周报

文章目录 摘要文献阅读1-1 题目1-2 创新点1-3 本文工作2-1 题目2-2 什么是图2-3 图神经网络2-4 信息传递3-1 题目3-2 创新点3-3 本文工作 深度学习1.GNN的构建步骤2.构建图的方法3.GNN的简单样例 总结 摘要 本周阅读了三篇文章,第一篇是基于物理信息深度学习和激光…

【LeetCode题目详解】第九章 动态规划part11 ● 123.买卖股票的最佳时机III ● 188.买卖股票的最佳时机IV (day50补)

本文章代码以c为例! 一、力扣第123题:买卖股票的最佳时机 III 题目: 给定一个数组,它的第 i 个元素是一支给定的股票在第 i 天的价格。 设计一个算法来计算你所能获取的最大利润。你最多可以完成 两笔 交易。 注意&#xff1…

基于springboot+vue的在线课程学习网站(前后端分离)

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容:毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目介绍…