关于nn.CrossEntropyLoss交叉熵损失中weight和ignore_index参数

news2024/11/19 4:31:42

目录

1. 交叉熵损失 CrossEntropyLoss

2. ignore_index 参数

3. weight 参数

4. 例子


1. 交叉熵损失 CrossEntropyLoss

CrossEntropyLoss 交叉熵损失可函数以用于分类或者分割任务中,这里主要介绍分割任务

建立如下的数据,pred是预测样本,label是真实标签

分割中,使用交叉熵损失的话,需要保证label的维度比pred维度少1,也就是没有channel维度。并且,label的类型是int

正常计算损失结果为:

手动计算一下,pred的softmax为

所以,loss = -(ln0.69+ln0.3543+ln0.5987)/3 = -(ln0.1464) / 3 = 0.6406 

后面的是计算产生的误差,这里用数学方法简化计算了

one-hot 编码,只计算label的 ln 预测值

2. ignore_index 参数

在分割任务中,经常有像素点是认为不感兴趣的,所以这里ignore_index可以将那些不感兴趣的像素点排除

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]])     # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0])                         # 真实值 size = 3 , dtype = torch.int64
loss = nn.CrossEntropyLoss(ignore_index=1)
out = loss(pred,label)
print(out)      # tensor(0.4421)

这里将label = 1的像素点排除,手动计算一下

loss = (-ln0.69-ln0.5987) / 2 = 0.4421 

这里将label = 1的忽略了,下面是pred的softmax值

3. weight 参数

当涉及到样本的个数不平衡的时候,可以将样本少的label,w加大点

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1],[0.8, 0.2],[0.7, 0.3]])     # 预测值 size = 3*2, dtype = torch.float32
label = torch.LongTensor([0, 1, 0])                         # 真实值 size = 3 , dtype = torch.int64
w = torch.FloatTensor([1,2])
loss = nn.CrossEntropyLoss(weight=w)
out = loss(pred,label)
print(out)      # tensor(0.7398)

计算方法是:

loss =- ( 1*ln0.69 + 2*ln0.3543+1*ln0.5987) / 4 = (0.3711 + 2.0741+ 0.5130) / 4= 0.7396

可以发现答案是类似的,这里保留了四位小数进行计算,所以有误差

因为,label = 1有一个,label = 0 有两个,所以1的样本较少,这里就对label = 1设置权重大点。可以发现,计算出来的loss确实比不加loss的大,下图为不加w的

如果将w改成[2,1]的话,loss会更低,不利于loss的下降

 

所以,在样本不均衡的情况下,加label少的样本,w加大,可以将loss变大,从而梯度下降的时候可以更好的弥补样本不平衡的问题

注意:w的类型是float

4. 例子

测试代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F


pred = torch.Tensor([[0.9, 0.1,0.2],[0.8, 0.2,0.1],[0.7, 0.3,0.5],[0.1,0.5,0.6]])
label = torch.LongTensor([2, 1, 0,1])

s = F.softmax(pred,dim=1)
print(s)

w = torch.FloatTensor([2,1,2])
loss = nn.CrossEntropyLoss(weight=w,ignore_index=2)
out = loss(pred,label)
print(out)      # tensor(1.0401)

其中,pred的softmax如下:

label 为:2 1 0 1

可以发现,label 是 0 1 2 三类,这里将label = 2的忽略,并且对0 1 2施加的权重为 2 1 2

所以手动计算的公式为,这里精确到六位小数

label = 0 的损失 = - ln0.4018 = 0.911801

label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234

label = 2 的损失 = - ln0.2552 = 1.365708

这里忽略了label = 2,所以还剩:

label = 0 的损失 = - ln0.4018 = 0.911801

label = 1 的损失 = (- ln0.2683 - ln0.3603 ) / 2 = (1.315650 + 1.020818)/2 = 1.168234

并且对0 1 进行加权2 1

所以总的loss = (0.911801 *2 + 1.315650*1+1.020818*1) /(2+1+1) = 4.16007/4=1.0400175

可以发现结果是一样的,这里最后是精度问题

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/419126.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

MongoDB 聚合管道的字段投影($addFields,$set,$unset,$project)

上一篇我们介绍了MongoDB 聚合管道的文档筛选及分组统计: $match:文档过滤 $group:文档分组,并介绍了分组中的常用操作:$addToSet,$avg,$sum,$min,$max等。 如果需要进一…

Ceph集群修复 osd 为 down 的问题

问题描述 由于突然断电了,导致 ceph 服务出现了问题,osd.1 无法起来 ceph osd tree解决方案 尝试重启 systemctl list-units |grep ceph systemctl restart ceph-f0e59898-71d4-11ec-924c-000c290a1a98osd.1.service发现重启无望,可采用…

国内企业使用敏捷开发的多吗?《2022中国企业敏捷实践白皮书》发布(附完整版下载)

通过2021-2022调研数据对比发现,受访者所在企业的敏捷团队占比从2021年的55%提升至2022年的63%,说明越来越多的中国企业正在从传统研发模式转变为敏捷研发模式,并不断扩大敏捷适用范围来促进企业整体敏捷转型; 与2021年的白皮书相…

android jetpack Navigation的使用(java)

简介 Navigation通过图形化的方式管理配置页面的切换。 基本使用 添加依赖 implementation androidx.navigation:navigation-fragment:2.5.3implementation androidx.navigation:navigation-ui:2.5.3创建xml文件(添加导航图)——nav_graph.xml nav_…

Java奠基】Java经典案例讲解

目录 卖飞机票 找质数 开发验证码 数组元素的复制 评委打分 数字加密 数字解密 抢红包 模拟双色球 二维数组 卖飞机票 需求:机票价格按照淡季旺季、头等舱和经济舱收费、输入机票原价、月份和头等舱或经济舱。按照如下规则计算机票价格: 旺季&…

ROS实践14 分布式通信

文章目录运行环境:思路:1.1 设置固定IP2.1 修改hosts文件3.1 检查是否成功通信4.1 修改bashrc5.1 演示运行环境: ubuntu20.04 noetic 宏基暗影骑士笔记本 思路: 主机启动roscore和乌龟速度订阅节点,从机启动乌龟键盘…

大模型时代下做科研的思路

总结zhu老师观点 Efficient 1.这篇论文是真的好orz,总结了目前的视频类模型 修改周边的一些参数,来训练,不改基础的模型(太大了。。。没资源没卡) 引申: prompt 是你想模型干什么你就给提示&#xff08…

python win环境 pip setuptools wheel安装

2023年。 今年的测试小学弟问我python这个安装怎么这么啥b。没有安装pip时 python setup.py install时需要setuptools,安装setuptools需要安装pip。 我看了看他的python是官网下的压缩包解压来的,内部非常干净。python-3.10.11 1. 安装pip 遇到这种情况…

“智慧赟”平台型经济引领行业新标杆

​  2021年,国家高度重视区块链行业发展,各部委发布的区块链相关政策已超60项,区块链不仅被写入“十四五”规划纲要中,各部门更是积极探索区块链发展方向,全方位推动区块链技术赋能各领域发展。在区块链产业具体内容…

【JavaEE】Spring中存储和获取Bean(使用注解)

目录 存储Bean 配置文件中设置扫描路径 使用注解存储Bean 五大类注解存储Bean 五大类注解之间的关系 为什么要有五大类注解 Bean方法注解存储方法返回值 注入Bean 属性注入 Setter方法注入 构造方法注入 Resource注解 存储Bean 上篇文章的存储Bean是在Spring的配置…

16.网络爬虫—字体反爬(实战演示)

网络爬虫—字体反爬一字体反爬原理二字体反爬模块FonttoolsTTF文件三FontCreator 14.0.0.2790FontCreatorPortable下载与安装四实战演示五后记前言: 🏘️🏘️个人简介:以山河作礼。 🎖️🎖️:Python领域新星…

一天吃透MySQL面试八股文

什么是MySQL MySQL是一个关系型数据库,它采用表的形式来存储数据。你可以理解成是Excel表格,既然是表的形式存储数据,就有表结构(行和列)。行代表每一行数据,列代表该行中的每个值。列上的值是有数据类型的…

python调用matlab源码函数

Background 关于在python中调用matlab函数,我之前已经写过两篇文章了,非常详细,且之前的方法可以不用安装matlab程序,只需要按照mcr运行环境就行了。具体可以参考:【java和python调用matlab程序详细记录】【Python 高效…

一文解析为什么进程地址空间中包括操作系统?

今天聊聊进程地址空间这点小事。说到进程的地址空间,大家可能都知道这样一张图: 这张图就是Linux程序运行起来后所谓的进程地址空间,这里包括我们熟悉的代码区、数据区、以及堆区和栈区,今天我们不讲解这些区域,而是重…

elementui的el-message重复点击,提示会一直叠加

1.问题: elementui的el-message连续点击按钮会出现一排提示,注意体验很不友好,而且也不好看 如下: 这种问题如何解决呢 ? 2.参考api elementui的官网有这个api,也就是说通过close这个方法可以解决 3.附上代码&a…

设计模式之美-结构型模式-装饰器模式

装饰器模式主要解决继承关系过于复杂的问题,通过组合来替代继承。指在不改变现有对象结构的情况下,动态地给该对象增加一些职责(即增加其额外功能)的模式,装饰器模式提供了比继承更有弹性的替代方案将功能附加到对象上…

4月20日专家谈:内网突遭攻击,安全人员一招有效处理

随着网络威胁的愈加频繁,企业面临的安全问题也越来越多,传统的安全能力在面对日益增长的安全问题时显得捉襟见肘。 SOAR借助安全编排和自动化技术,将人工操作和技术集成在一起,自动化完成安全处置,帮助企业更快地响应…

JavaScript【九】JavaScript BOM(浏览器对象模型)

文章目录🌟前言🌟 Bom(浏览器对象模型)🌟window对象:🌟属性:🌟 方法:🌟 获取元素:🌟 添加点击事件:🌟 获取表单…

大数据Flink进阶(二十):Flink细粒度资源管理

文章目录 Flink细粒度资源管理 一、细粒度资源管理介绍 二、细粒度资源适用场景

关于合金电阻

合金电阻是一种具有高精度、高稳定性和高温度特性的电阻器件,广泛应用于各种电子设备中。选型合适的合金电阻并进行合理的设计,可以有效地提高电路的性能和可靠性。本文将从合金电阻的基本原理、选型方法及设计要点等方面进行详细介绍。 一、合金电阻的基…