pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度?

news2025/3/29 19:50:05

pytorch小记(十五):pytorch中 交叉熵损失详解:为什么logits比targets多一个维度?

  • PyTorch交叉熵损失详解:为什么logits比targets多一个维度?
    • 一、前言:新手常见困惑
    • 二、核心概念:从考试得分到概率分布
      • 1. logits:原始得分矩阵
      • 2. targets:正确答案索引
    • 三、维度差异的本质原因
      • 1. 分类任务的数学需求
      • 2. 维度对照表
      • 3. 错误用法解析
    • 四、手把手计算交叉熵损失
      • 1. 输入数据
      • 2. 计算步骤
        • 步骤1:Softmax归一化
        • 步骤2:提取正确类别的概率
        • 步骤3:计算交叉熵
    • 五、设计哲学深度解析
      • 1. 为何不直接使用概率?
      • 2. 多任务场景对照表
    • 六、常见问题解答
      • Q1:二分类能否用形状[N]的logits?
      • Q2:如何处理多标签分类?
      • Q3:为什么我的loss计算很慢?
    • 七、总结


PyTorch交叉熵损失详解:为什么logits比targets多一个维度?

关键词:PyTorch交叉熵损失、logits维度、分类任务原理、深度学习基础


一、前言:新手常见困惑

许多初学PyTorch的朋友在使用交叉熵损失函数时,都会对logitstargets的维度关系感到困惑。典型的报错场景如下:

# 正确用法
logits = torch.tensor([[1.2, -0.5], [0.3, 2.1]])  # 形状 [2, 2]
targets = torch.tensor([0, 1])                     # 形状 [2]

# 错误用法(触发维度错误)
logits_error = torch.tensor([0.5, 1.2])            # 形状 [2]
targets_error = torch.tensor([0, 1])               # 形状 [2]
loss = F.cross_entropy(logits_error, targets_error)  # 报错!

本文将用生活实例+手把手计算的方式,带你彻底理解交叉熵损失的维度设计逻辑。


二、核心概念:从考试得分到概率分布

1. logits:原始得分矩阵

想象你正在参加一场有2道选择题的考试,每道题有A、B两个选项。模型对每个选项给出原始得分:

logits = torch.tensor([
    [-1.0, 1.0],   # 第1题:A得-1分,B得1分
    [-0.5, 1.5],   # 第2题:A得-0.5分,B得1.5分
    [-0.5, 1.5]    # 第3题(新增):同上
])
  • 形状[3, 2]:3个样本(题目),每个样本2个类别(选项)
  • 物理意义:未经归一化的"信心分数",数值越大表示模型越倾向该选项

2. targets:正确答案索引

targets = torch.tensor([0, 1, 1]) 
# 含义:第1题正确答案是A(索引0),第2、3题是B(索引1)
  • 形状[3]:3个样本各对应一个正确答案位置

三、维度差异的本质原因

1. 分类任务的数学需求

  • 模型需要为每个可能的类别提供判断依据
  • 即使正确答案只有一个,也必须比较所有选项的"证据强度"

2. 维度对照表

张量形状物理意义
logits[N, C]N个样本,每个样本C个类别的得分
targets[N]N个样本的正确类别索引(n在0~c-1之间)

3. 错误用法解析

logitstargets同维度:

logits_error = torch.tensor([0.2, 0.7, 0.5])  # 形状[3]
targets = torch.tensor([0, 1, 1])              # 形状[3]

此时模型无法判断:

  • 每个数值对应哪个类别?
  • 如何进行多类别比较?

四、手把手计算交叉熵损失

以具体例子演示计算全过程:

1. 输入数据

logits = torch.tensor([
    [-1.0, 1.0], 
    [-0.5, 1.5],
    [-0.5, 1.5]
])  # 形状[3,2]
targets = torch.tensor([0, 1, 1])  # 形状[3]

2. 计算步骤

步骤1:Softmax归一化

将原始得分转换为概率分布(每行和为1):

第1个样本([-1.0, 1.0]):

exp(-1.0) = 0.3679  
exp(1.0) = 2.7183
总合 = 0.3679 + 2.7183 = 3.0862
概率 = [0.3679/3.08630.1192, 2.7183/3.08630.8808]

第2个样本([-0.5, 1.5]):

exp(-0.5)0.6065  
exp(1.5)4.4817
总合 = 0.6065 + 4.48175.0882
概率 = [0.6065/5.08820.1192, 4.4817/5.08820.8808]
步骤2:提取正确类别的概率

根据targets索引:

样本1:取索引00.1192  
样本2:取索引10.8808  
样本3:取索引10.8808
步骤3:计算交叉熵

公式:loss = -平均(ln(正确概率))

loss = -(ln(0.1192) + ln(0.8808) + ln(0.8808)) / 3
     = -[(-2.127) + (-0.127) + (-0.127)] / 30.7937

验证PyTorch计算结果:

print(loss.item())  # 输出 0.7937

五、设计哲学深度解析

1. 为何不直接使用概率?

  • 数值稳定性:直接处理指数运算易导致溢出
  • 梯度优化:logits的线性特性更利于反向传播

2. 多任务场景对照表

任务类型logits形状targets形状损失函数
二分类(2个选项)[N,2][N]CrossEntropyLoss
多标签分类[N,C][N,C]BCEWithLogitsLoss
回归任务[N][N]MSELoss

六、常见问题解答

Q1:二分类能否用形状[N]的logits?

可以,但需配合sigmoid

# 二分类特例
logits = torch.tensor([0.8, -0.3])  # 形状[2]
prob = torch.sigmoid(logits)        # 转换为概率
loss = F.binary_cross_entropy(prob, targets)

Q2:如何处理多标签分类?

当每个样本可能有多个正确标签时:

logits = torch.tensor([[1.2, -0.5], [0.3, 2.1]])  # 形状[2,2]
targets = torch.tensor([[1, 0], [0, 1]])          # 形状[2,2] (one-hot)
loss = F.binary_cross_entropy_with_logits(logits, targets)

Q3:为什么我的loss计算很慢?

  • 检查是否误用了for循环逐个样本计算
  • 正确的向量化计算可加速百倍以上

七、总结

理解logits与targets的维度差异,关键在于把握分类任务的本质需求:

  1. logits提供全类别的判断依据 → 需要二维结构
  2. targets只需指出正确位置 → 一维索引足矣

掌握这一设计哲学后,你就能:
✅ 正确构建分类模型的输出层
✅ 快速调试维度相关的错误
✅ 深入理解损失函数的工作原理

练习建议:在Jupyter Notebook中复现本文的计算示例,尝试修改logits值观察loss变化。


相关阅读

  • PyTorch官方文档:CrossEntropyLoss

如有疑问欢迎留言讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2322171.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

利用zabbix自带key获取数据

获取数据的三种方法 1、链接模版 服务器系统自身的监控 CPU CPU使用率、CPU负载 内存 内存剩余量 硬盘 关键性硬盘的剩余量、IO 网卡 流量/IO(流入流量、流出流量、总流量、错误数据包流量) 进程数 用户数 2、利用zabbix自带的键值key 1)监…

无人机数据处理系统设计要点与难点!

一、系统设计要点 无人机数据处理系统需要高效、可靠、低延迟地处理多源异构数据(如影像、传感器数据、位置信息等),同时支持实时分析和长期存储。以下是核心设计要点: 1.数据采集与预处理 多传感器融合:集成摄像头…

最大异或对 The XOR Largest Pair

题目来自洛谷网站: 思路: 两个循环时间复杂度太高了,会超时。 我们可以先将读入的数字,插入到字典树中,从高位到低位。对每个数查询的时候,题目要求是最大的异或对,所以我们选择相反的路径&am…

基于SpringBoot + Vue 的汽车租赁管理系统

技术介绍: ①:架构: B/S、MVC ②:系统环境:Windows/Mac ③:开发环境:IDEA、JDK1.8、Maven、Mysql ④:技术栈:Java、Mysql、SpringBoot、Mybatis、Vue 项目功能: 角色&am…

基于DrissionPage的TB商品信息采集与可视化分析

一、项目背景 随着电子商务的快速发展,淘宝作为中国最大的电商平台之一,拥有海量的商品信息。这些数据对于市场分析、用户行为研究以及竞争情报收集具有重要意义。然而,由于淘宝的反爬虫机制和复杂的页面结构,直接获取商品信息并不容易。尤其是在电商行业高速发展的今天,商…

电气、电子信息与通信工程的探索与应用

从传统定义来看,电气工程是现代科技领域的核心学科和关键学科。它涵盖了创造产生电气与电子系统的有关学科的总和。然而,随着科学技术的飞速发展,电气工程的概念已经远超出这一范畴。 电子信息工程则是将电子技术、通信技术、计算机技术等应…

Python备赛笔记2

1.区间求和 题目描述 给定a1……an一共N个整数,有M次查询,每次需要查询区间【L,R】的和。 输入描述: 第一行包含两个数:N,M 第二行输入N个整数 接下来的M行,每行有两个整数,L R,中间用空格隔开&…

Unity2022发布Webgl2微信小游戏部分真机黑屏

复现规律: Unity PlayerSetting中取消勾选ShowSplashScreen 分析: 在Unity中,Splash Screen(启动画面) 不仅是视觉上的加载动画,还承担了关键的引擎初始化、资源预加载和渲染环境准备等底层逻辑。禁用后导…

记一次线上SQL死锁事故

一、 引言 SQL死锁是一个常见且复杂的并发控制问题。当多个事务在数据库中互相等待对方释放锁时,就会形成死锁,从而导致事务无法继续执行,影响系统的性能和可用性。死锁不仅会导致数据库操作的阻塞,增加延迟,还可能对…

Axure项目实战:智慧城市APP(六)市民互动(动态面板、显示与隐藏)

亲爱的小伙伴,在您浏览之前,烦请关注一下,在此深表感谢! 课程主题:市民互动 主要内容:动态面板、显示与隐藏交互应用 应用场景:AI产品交互、互动类应用 案例展示: 案例视频&am…

为何服务器监听异常?

报错: 执行./RCF后出现监听异常--在切换网络后,由于前面没有退出./RCF执行状态;重新连接后,会出现服务器监听异常 原因如下: 由于刚开始登录内网,切换之后再重新登录内网,并且切换网络的过程中…

1.认识Excel

一 Excel 可以用来做什么 二 提升技巧 1.数据太多 2.计算太累 3.提升数据的价值和意义 4.团队协作 三 学习目标 学习目标不是为了掌握所有的技能,追逐新功能。而是学知识来解决需求,如果之前的技能和新出的技能都可以解决问题,那不学新技能也…

光谱范围与颜色感知的关系

光谱范围与颜色感知是光学、生理学及技术应用交叉的核心课题,两者通过波长分布、人眼响应及技术处理共同决定人类对色彩的认知。以下是其关系的系统解析: ‌1.基础原理:光谱范围与可见光‌ ‌光谱范围定义‌: 电磁波谱中能被特定…

网络地址转换技术(2)

NAT的配置方法: (一)静态NAT的配置方法 进入接口视图配置NAT转换规则 Nat static global 公网地址 inside 私网地址 内网终端PC2(192.168.20.2/24)与公网路由器AR1的G0/0/1(11.22.33.1/24)做…

Python正则表达式(一)

目录 一、正则表达式的基本概念 1、基本概念 2、正则表达式的特殊字符 二、范围符号和量词 1、范围符号 2、匹配汉字 3、量词 三、正则表达式函数 1、使用正则表达式: 2、re.match()函数 3、re.search()函数 4、findall()函数 5、re.finditer()函数 6…

【TI MSPM0】PWM学习

一、样例展示 #include "ti_msp_dl_config.h"int main(void) {SYSCFG_DL_init();DL_TimerG_startCounter(PWM_0_INST);while (1) {__WFI();} } TimerG0输出一对边缘对齐的PWM信号 TimerG0会输出一对62.5Hz的边缘对齐的PWM信号在PA12和PA13引脚上,PA12被…

MySQL: 创建两个关联的表,用联表sql创建一个新表

MySQL: 创建两个关联的表 建表思路 USERS 表:包含用户的基本信息,像 ID、NAME、EMAIL 等。v_card 表:存有虚拟卡的相关信息,如 type 和 amount。关联字段:USERS 表的 V_CARD 字段和 v_card 表的 v_card 字段用于建立…

更改 vscode ! + table 默认生成的 html 初始化模板

vscode ! 快速成的 html 代码默认为&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>D…

使用LVS的 NAT 模式实现 3 台RS的轮询访问

节点规划 1、配置RS RS的网络配置为NAT模式&#xff0c;三台RS的网关配置为192.168.10.8 1.1配置RS1 1.1.1修改主机名和IP地址 [rootlocalhost ~]# hostnamectl hostname rs1 [rootlocalhost ~]# nmcli c modify ens160 ipv4.method manual ipv4.addresses 192.168.10.7/24…

MySQL实战(尚硅谷)

要求 代码 # 准备数据 CREATE DATABASE IF NOT EXISTS company;USE company;CREATE TABLE IF NOT EXISTS employees(employee_id INT PRIMARY KEY,first_name VARCHAR(50),last_name VARCHAR(50),department_id INT );DESC employees;CREATE TABLE IF NOT EXISTS departments…