pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

news2024/11/28 12:40:44

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np

class CrossEntropyLoss2d(nn.Module):
    def __init__(self, weight=None):
      super(CrossEntropyLoss2d, self).__init__()
       self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')
    def forward(self, preds, targets):
        return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):
    def __init__(self,ignore_index=255):
        super(CE_w_loss, self).__init__()
        self.ignore_index = ignore_index
        # self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)
    def forward(self, outputs, targets):
        class_num = outputs.shape[1]
        # print("class_num :",class_num )
        # # 计算每个类别在整个 batch 中的像素数占比
        class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别
        class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)
        # # 根据类别占比计算权重
        class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重
        # # print("class_weights :",class_weights)
        #
        # 定义交叉熵损失函数,并使用动态计算的类别权重
        criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)

        # 计算损失
        loss = criterion(outputs, targets)
        print(loss.item())  # 打印损失值
        return loss

    np.random.seed(666)
    pred = np.ones((2, 5, 256,256))
    seg = np.ones((2, 5, 256, 256)) # 灰度
    label = np.ones((2, 256, 256))  # 灰度

    pred = torch.from_numpy(pred)
    seg = torch.from_numpy(seg).int()  # 灰度
    label = torch.from_numpy(label).long()
     ce = CE_w_loss()
    loss = ce(pred, label)
    print("loss:",loss.item())

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1802560.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

【香橙派】Orange Pi AIpro体验——国产AI赋能

文章目录 🍔开箱🛸烧录镜像⭐启动系统🎈本机登录🎈远程登陆 🎆AI功能体验🔎总结 🍔开箱 可以看到是很精美的开发组件 这里是香橙派官网 http://www.orangepi.cn/ 我们找到下面图片的内容&#…

uc/OS移植到stm32实现三个任务

文章目录 一、使用CubeMX创建工程二、uc/OS移植三、添加代码四、修改代码五、实践结果六、参考文章七、总结 实践内容 学习嵌入式实时操作系统(RTOS),以uc/OS为例,将其移植到stm32F103上,构建至少3个任务(task&#xf…

python有short类型吗

Python 数字数据类型用于存储数值。 Python 支持三种不同的数值类型:整型(int)、浮点型(float)、复数(complex)。 在其他的编程语言中,比如Java、C这一类的语言中还分有长整型&…

西米支付:刷卡手续费进入高费率时代! 十多家支付机构公布最新收费标准

《非银行支付机构监督管理条例》自5月1日施行以来,越来越多支付机构落实收费透明化。 支付界注意到,日前,拉卡拉、银联商务两家持牌支付公司公布了新的收单业务收费标准。 拉卡拉在其官网公布了最新的“收费项目及收费标准公示”&#xff0…

二叉查找树详解

目录 二叉查找树的定义 二叉查找树的基本操作 查找 插入 建立 删除 二叉树查找树的性质 二叉查找树的定义 二叉查找树是一种特殊的二叉树,又称为排序二叉树、二叉搜索树、二叉排序树。 二叉树的递归定义如下: (1)要么二…

CorelDraw安装时界面显示不全的解决方案

问题原因:安装包权限 解决方案: 1、安装包解压后,找到Setup文件,复制粘贴到当前文件夹并重命名为Getup后,右击Getup文件,选择“以管理员身份运行” 说明:除了命名Gsetup。还可以命名为其他的…

【Java】Java18的新特性

人不走空 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌赋:斯是陋室,惟吾德馨 目录 🌈个人主页:人不走空 💖系列专栏:算法专题 ⏰诗词歌…

史上最有趣嫁妆:晋公盘的传奇

在遥远的春秋时代,晋国的晋文公为他的女儿用心打造了一件独特的嫁妆——晋公盘。 晋公盘由青铜制成,形状独特,工艺精湛。在晋公盘内底中央,一对精美的浮雕龙盘绕成圆形,盘上饰有鸟、龟、鱼、蛙等多种动物,最…

带你学习Mybatis之逆向工程

逆向工程 可以针对单表自动生成MyBatis执行所需要的代码&#xff0c;包括&#xff1a;Mapper.java&#xff0c;Mapper.xml&#xff0c;实体类&#xff0c;这样可以减少重复代码的编写 <dependency> <groupId>org.mybatis.generator</groupId> …

【数据结构】初识数据结构之复杂度与链表

【数据结构】初识数据结构之复杂度与链表 &#x1f525;个人主页&#xff1a;大白的编程日记 &#x1f525;专栏&#xff1a;C语言学习之路 文章目录 【数据结构】初识数据结构之复杂度与链表前言一.数据结构和算法1.1数据结构1.2算法1.3数据结构和算法的重要性 二.时间与空间…

人工智能对话系统源码 手机版+电脑版二合一 全端支持 前后端分离 带完整的安装代码包以及搭建部署教程

系统概述 该系统采用前后端分离的设计模式&#xff0c;前端负责用户界面展示与交互&#xff0c;后端负责数据处理与业务逻辑实现。前后端通过API接口进行通信&#xff0c;实现数据的实时传输与处理。系统支持全端访问&#xff0c;无论是手机还是电脑&#xff0c;都能获得良好的…

Type-C接口,乱成一锅粥了!

前言 小白第一次接触Type-C接口的时候是2017年的夏天&#xff0c;那时候小白买了小米最新发布的小米5x手机&#xff0c;使用的是相当主流的Type-C接口。 在2017年之前&#xff0c;很多手机还是使用Micro-USB作为手机首选的充电接口。 当同宿舍的小伙伴还在为给手机充电需要分辨…

Linux卸载RocketMQ教程【带图文命令巨详细】

巨详细Linux卸载RocketMQ教程 #查询rocketmq进程 ps -ef | grep rocketmq #杀掉相关进程 kill -9 进程id #查找安装目录 find / -name runbroker.sh #删除rocketMQ目录 rm -rf 安装目录框起来的就是进程id&#xff0c;全部杀掉 这里就是我的安装目录&#xff0c;我的删除命令…

基于STM32开发的智能语音控制系统

目录 引言环境准备智能语音控制系统基础代码实现&#xff1a;实现智能语音控制系统 4.1 语音识别模块数据读取4.2 设备控制4.3 实时数据监控与处理4.4 用户界面与反馈显示应用场景&#xff1a;语音控制的家居设备管理问题解决方案与优化收尾与总结 1. 引言 随着人工智能技术…

从零开始实现自己的串口调试助手(10) - 优化 收尾 + 打包

光标位置优化 在接收槽函数中更新光标位置: // 让光标始终在结尾 ui->textEditRev->moveCursor(QTextCursor::End); ui->textEditRev->ensureCursorVisible(); // 让光标可视化 //记得HEX显示槽函数底下也得加上这两行代码 新的接收槽函数如下: void Wid…

10秒钟docker 安装Acunetix

1、拉取镜像&#xff1a; 2、查看镜像&#xff1a; [rootdns-server ~]# docker images REPOSITORY TAG IMAGE ID CREATED SIZE quay.io/hiepnv/acunetix latest f8415551b8f4 2 months ago 1.98GB 3、运行镜像&#xff1a; …

算法训练营day03--203.移除链表元素+707.设计链表+206.反转链表

一、203.移除链表元素 题目链接&#xff1a;https://leetcode.cn/problems/remove-linked-list-elements/ 文章讲解&#xff1a;https://programmercarl.com/0203.%E7%A7%BB%E9%99%A4%E9%93%BE%E8%A1%A8%E5%85%83%E7%B4%A0.html 视频讲解&#xff1a;https://www.bilibili.com…

CentOS7 MySQL5.7.35主从 不停机搭建 以及配置

如需安装MySQL&#xff0c;参照MySQL 5.7.35 安装教程 https://blog.csdn.net/CsethCRM/article/details/119418841一、主&从 环境信息准备 1.1.查看硬盘信息&#xff0c;确保磁盘够用&#xff08;主&从&#xff09; df -h1.2.查看内存信息 &#xff08;主&从&am…

11 IP协议 - IP协议头部

什么是 IP 协议 IP&#xff08;Internet Protocol&#xff09;是一种网络通信协议&#xff0c;它是互联网的核心协议之一&#xff0c;负责在计算机网络中路由数据包&#xff0c;使数据能够在不同设备之间进行有效的传输。IP协议的主要作用包括寻址、分组、路由和转发数据包&am…

Android 应用权限

文章目录 权限声明uses-permissionpermissionpermission-grouppermission-tree其他uses-feature 权限配置 权限声明 Android权限在AndroidManifest.xml中声明&#xff0c;<permission>、 <permission-group> 、<permission-tree> 和<uses-permission>…