Pytorch实用教程:nn.CrossEntropyLoss()的用法

news2024/12/26 10:41:20

在 PyTorch 中,nn.CrossEntropyLoss() 是一个非常常用且功能强大的损失函数,特别适合用于多类分类问题。这个损失函数结合了 nn.LogSoftmax()nn.NLLLoss() (Negative Log Likelihood Loss) 两个操作,从而在一个模块中提供完整的交叉熵损失计算功能。这不仅方便使用,也提高了数值稳定性。

功能说明

nn.CrossEntropyLoss() 计算模型输出实际标签之间的交叉熵损失。它自动完成softmax 概率分布的计算和对数似然损失的计算,这意味着你应该直接将网络的原始输出(logits,即未经 softmax 层处理的输出)作为 CrossEntropyLoss 的输入。

上面这句话非常重要,这就是为什么在用交叉熵损失函数的时候,在模型的输出部分见不到softmax的原因。

参数详解

nn.CrossEntropyLoss 主要有以下几个参数:

  • weight (Tensor, optional): 一个手动指定的权重,用于平衡类别间的损失贡献。这在类别不平衡的情况下非常有用。
  • size_average (bool, deprecated): 这个参数已经被弃用,用 reduction 参数代替。
  • ignore_index (int, optional): 指定一个类别索引,对于这个类别的目标(target),损失将不会被计算。这常用于忽略特定的类别。
  • reduce (bool, deprecated): 这个参数也已经被弃用,用 reduction 参数代替。
  • reduction (str, optional): 指定损失的计算模式。可以是 ‘none’(无操作),‘mean’(计算损失的均值,是默认设置)或 ‘sum’(计算损失的总和)。

使用示例

下面是一个使用 nn.CrossEntropyLoss 的简单例子。假设我们有一个分类问题,目标是将输入分类到三个类别中的一个:

import torch
import torch.nn as nn

# 假设我们有3个类别,batch_size为4
data = torch.randn(4, 3)  # 输入,来自某个神经网络的原始输出,形状为(batch_size, num_classes)
targets = torch.tensor([0, 2, 1, 0])  # 实际的标签,形状为(batch_size,)

# 创建交叉熵损失函数实例
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(data, targets)
print(loss) # 输出:tensor(1.6401)

数学原理

对于每个样本 (i),假设 (C) 是类别总数,交叉熵损失定义为:

在这里插入图片描述

这里 (x[class_i]) 是模型输出的第 (i) 个样本对应其真实类别 (class_i) 的 logit。交叉熵损失将这些 logits 转换为正规化的概率分布,然后计算其对数似然。

应用场景

这个损失函数是处理多类分类问题的标准选择之一,特别是当你有一个多类的标签目标时。由于其数学上的稳定性,它在训练深度学习模型时非常受欢迎。使用它可以直接处理 logits,无需单独计算 softmax,从而在实际应用中减少计算量和增加数值稳定性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1607710.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Spring Boot与Vue联手打造前沿智能学习平台

作者介绍:✌️大厂全栈码农|毕设实战开发,专注于大学生项目实战开发、讲解和毕业答疑辅导。 🍅获取源码联系方式请查看文末🍅 推荐订阅精彩专栏 👇🏻 避免错过下次更新 Springboot项目精选实战案例 更多项目…

windows下vscode调试虚拟机linux c++工程的三种方法

vscode去远程调试方法有很多种,不同的插件对应了不同的调试方法,比如: 1.C/C插件进行GDB调试(编写launch.json文件) 2.C/C Runner插件 3.CMake Tools插件(只针对CMake工程,需要搭配C/C插件一起使用,但无…

【系统分析师】系统测试与维护

文章目录 1、测试方法2、测试阶段3、面向对象的测试4、测试自动化5、软件调试6、软件评审7、软件改进过程8、软件开发环境与工具9、系统转换计划10、系统的运行与维护11、系统审计 1、测试方法 例题 2、测试阶段 注意区分:每个阶段都做了什么事情3、面向对象的测试 4、测试自动…

设计模式之模板方法模式详解(下)

3)钩子方法的使用 1.概述 钩子方法的引入使得子类可以控制父类的行为。 2.结构图 3.代码实现 将公共方法和框架代码放在抽象父类中 abstract class DataViewer {//抽象方法:获取数据public abstract void GetData();//具体方法:转换数据…

【传输层】

文章目录 传输层传输服务和协议传输层 vs. 网络层Internet传输层协议多路复用/解复用在发送方主机多路复用在接收方主机多路解复用 多路解复用工作原理无连接(UDP)多路解复用无连接传输:UDPUDP:用户数据报协议UDP校验和 传输层 目…

华为服务Fellow、首席项目管理专家,华为H5M项目管理标准制定主导者孙虎受邀为PMO大会演讲嘉宾

全国PMO专业人士年度盛会 华为服务Fellow、首席项目管理专家,华为H5M项目管理标准制定主导者孙虎先生受邀为PMO评论主办的2024第十三届中国PMO大会演讲嘉宾,演讲议题为“落地项目管理标准,打赢班长的战争”。大会将于5月25-26日在北京举办&am…

excel中vlookup查找值必须在table_array的第一列,有其他办法吗有XLOOKUP

vlookup查找值必须在table_array的第一列,有其他办法吗?有XLOOKUP。 vlookup 查找如下,查找值必须在table_array的第一列 如果下面,编码和名称交换位置,就不能使用vlookup查找了。 XLOOKUP 查找如下

电脑桌面便签软件哪个好?好用的电脑桌面便签

电脑作为我们日常工作的重要工具,承载着大量的任务和项目。当工作任务繁重时,如何在电脑桌面上高效管理这些任务就显得尤为重要。这时,选择一款优秀的桌面便签软件,无疑会给我们带来极大的便利。 一款好的桌面便签软件&#xff0…

注意力机制基本思想(一)

​🌈 个人主页:十二月的猫-CSDN博客🔥 系列专栏: 🏀《深度学习基础知识》 相关专栏: ⚽《机器学习基础知识》 🏐《机器学习项目实战》 🥎《深度学习项目实战…

2024年华中杯B题论文发布+数据预处理问题一代码免费分享

【腾讯文档】2024年华中杯B题资料汇总 https://docs.qq.com/doc/DSExMdnNsamxCVUJt 行车轨迹估计交通信号灯周期问题 摘要 在城市化迅速发展的今天,交通管理和优化已成为关键的城市运营问题之一。本文将基于题目给出的数据,对行车轨迹估计交通信号灯…

记录Python链接mysql的数据库的2种操作方式

一、使用pymysql库方式 import pymysqldb pymysql.connect(hostlocalhost,userroot,password123456) #创建链接,在3.8以后好像已经不支持这个种链接方式了, #db pymysql.connect(localhost,root,123456) cursor db.cursor()#拿到游标这样我们就拿到了…

Rust 语言 GUI 用户界面框架汇总(持续更新)

拜登:“一切非 Rust 项目均为非法”😎 什么是 GUI 图形用户界面(Graphical User Interface,简称 GUI,又称图形用户接口)是指采用图形方式显示的计算机操作用户界面。 现在的应用开发,是既要功…

云从科技AI智能体云月亮相中国铁建GSF项目展示中心

近日,中国铁建大湾区科学论坛永久会址项目综合展示体验中心(以下简称“中国铁建GSF项目展示中心”)迎来了一位特别的客服——云月数智人。云月是云从从容多模态大模型的融合承载体——AI智能体(AI-Agent),她…

设计模式系列:适配器模式

简介 适配器模式(Adapter Pattern)又称为变压器模式,它是一种结构型设计模式。适配器模式的目的是将一个类的接口转换成客户端所期望的另一种接口,从而使原本因接口不匹配而不能一起工作的两个类能够一起工作。 适配器模式有两种…

免费的浏览器翻译插件——easypubmed

支持谷歌和edge浏览器,应用商店直接检索安装就可。 非常方便,无论是打算文字还是查单词,只要选中按D,就可以一键翻译啦。 最重要是免费,而且添加了小牛翻译引擎哦。 当然了,此插件本身是给医学生准备的。Pu…

深度学习--CNN应用--VGG16网络和ResNet18网络

前言 我们在学习这两个网络时,应先了解CNN网络的相关知识 深度学习--CNN卷积神经网络(附图)-CSDN博客 这篇博客能够帮我们更好的理解VGG16和RetNet18 1.VGG16 1.1 VGG简介 VGG论文网址:VGG论文 大家有兴趣的可以去研读一下…

JAVA面向对象(下 )(一、继承和方法重写)

一、继承 1.1 什么是继承 生活中继承是指: 继承财产>延续财产 继承/遗传 父母的长相,基因 > 延续上一代的基因 继承技能、专业、职位 >延续 继承中华民族的传统文化 > 延续 青出于蓝而胜于蓝 或 长江后浪推前浪,前浪被拍在…

es安装中文分词器

下载地址,尽量选择和自己本地es差不多的版本 https://github.com/infinilabs/analysis-ik/releases 下载好,解压,把里面的文件放到es的plugins/ik目录下 把plugin-descriptor.properties文件里的es版本改成自己对应的 再启动es,能…

十、OOP面向对象程序设计(五)

1、什么是接口以及接口的运用 1)接口定义 Java接口(Interface),是一些列方法的声明,是一些方法特征的集合,一个接口只有方法的特征没有方法的实现,因此这些方法可以在不同的地方被不同的类实现,而这些实现可以具有不同的行为(功能。) 2)接口定义的一般形式 修饰符:…

抖音小店怎么选品?这些超级容易爆单的选品方法,很少有人告诉你!

哈喽~我是电商月月 抖音小店的运营过程中,选品是非常重要的,好的商品不用宣传,就有人看 今天我就来给大家分享几个选品技巧,学会后商品一上架就有流量! 利用数据选品 1.“蝉妈妈”的数据排行榜选品 “蝉妈妈”能看…