pytorch如何使用Focal Loss

news2025/1/12 7:53:54

Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种损失函数。是对标准的Cross Entropy Loss 的一种改进。 FL对于简单样本(p比较大)回应较小的loss。 如论文中的图1, 在p=0.6时, 标准的CE然后又较大的loss, 但是对于FL就有相对较小的loss回应。这样就是对简单样本的一种decay。其中alpha 是对每个类别在训练数据中的频率有关, 但是下面的实现我们是基于alpha=1进行实验的。

在这里插入图片描述

PyTorch中使用Focal Loss,你可以按照以下步骤进行操作

方法一:

1、创建FocalLoss.py文件,添加一下代码

在这里插入图片描述

代码修改处:

  • classnum 处改为你分类的数量
  • P = F.softmax(inputs) 改为 P = F.softmax(inputs,dim=1)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    r"""
        This criterion is a implemenation of Focal Loss, which is proposed in 
        Focal Loss for Dense Object Detection.

            Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

        The losses are averaged across observations for each minibatch.

        Args:
            alpha(1D Tensor, Variable) : the scalar factor for this criterion
            gamma(float, double) : gamma > 0; reduces the relative loss for well-classified examples (p > .5), 
                                   putting more focus on hard, misclassified examples
            size_average(bool): By default, the losses are averaged over observations for each minibatch.
                                However, if the field size_average is set to False, the losses are
                                instead summed for each minibatch.


    """
    def __init__(self, class_num=5, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            self.alpha = Variable(torch.ones(class_num, 1))
        else:
            if isinstance(alpha, Variable):
                self.alpha = alpha
            else:
                self.alpha = Variable(alpha)
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs)

        class_mask = inputs.data.new(N, C).fill_(0)
        class_mask = Variable(class_mask)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids.data, 1.)
        #print(class_mask)


        if inputs.is_cuda and not self.alpha.is_cuda:
            self.alpha = self.alpha.cuda()
        alpha = self.alpha[ids.data.view(-1)]

        probs = (P*class_mask).sum(1).view(-1,1)

        log_p = probs.log()
        #print('probs size= {}'.format(probs.size()))
        #print(probs)

        batch_loss = -alpha*(torch.pow((1-probs), self.gamma))*log_p 
        #print('-----bacth_loss------')
        #print(batch_loss)


        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss

2、在你的训练函数里加入模块

from FocalLoss import FocalLoss

loss = FocalLoss()

方法二:

首先,确保你已经导入了torchtorch.nn模块,其中torch.nn提供了各种常见的损失函数。

import torch
import torch.nn as nn

然后,定义一个自定义的Focal Loss类,继承自torch.nn.Module。在类的构造函数中,可以指定Focal Loss所需的参数,例如γ(调节因子)和权重。

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight

    def forward(self, inputs, targets):
        ce_loss = nn.CrossEntropyLoss(weight=self.weight)(inputs, targets)  # 使用交叉熵损失函数计算基础损失
        pt = torch.exp(-ce_loss)  # 计算预测的概率
        focal_loss = (1 - pt) ** self.gamma * ce_loss  # 根据Focal Loss公式计算Focal Loss
        return focal_loss

接下来,在模型训练时,使用自定义的Focal Loss替代交叉熵损失函数即可。

# 定义模型
model = YourModel()

# 定义损失函数(使用自定义的Focal Loss)
criterion = FocalLoss(gamma=2, weight=None)

# 初始化优化器等

# 开始训练循环
for epoch in range(num_epochs):
    # 前向传播、计算损失
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # 反向传播、更新模型参数
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # 其他操作(如打印训练日志等)

通过以上步骤,就可以在PyTorch中将损失函数由交叉熵损失函数换为Focal Loss。请注意,上述代码示例中的一些细节(例如模型、输入、优化器等)可能需要根据你的实际情况进行修改和补充。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/961592.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

openGauss学习笔记-57 openGauss 高级特性-并行查询

文章目录 openGauss学习笔记-57 openGauss 高级特性-并行查询57.1 适用场景与限制57.2 资源对SMP性能的影响57.3 其他因素对SMP性能的影响57.4 配置步骤 openGauss学习笔记-57 openGauss 高级特性-并行查询 openGauss的SMP并行技术是一种利用计算机多核CPU架构来实现多线程并行…

使用Vue3和Vite升级你的Vue2+Webpack项目

🌷🍁 博主猫头虎(🐅🐾)带您 Go to New World✨🍁 🦄 博客首页——🐅🐾猫头虎的博客🎐 🐳 《面试题大全专栏》 🦕 文章图文…

十一、做高并发内存池项目过程中遇到的bug以及调试bug的方法和心得

十一、做高并发内存池项目过程中遇到的bug以及调试bug的方法和心得 第一个bug是内存问题,程序直接崩溃,问题出现在:GetOneSpan函数中的切分span的时候结尾的span1的next没有置空。 第二个bug是还小内存块给span的时候找不到小内存所属的spa…

深入理解 JVM 之——Java 内存区域与溢出异常

更好的阅读体验 \huge{\color{red}{更好的阅读体验}} 更好的阅读体验 本篇为深入理解 Java 虚拟机第二章内容,推荐在学习前先掌握基础的 Linux 操作、编译原理、计算机组成原理等计算机基础以及扎实的 C/C 功底。 该系列的 GitHub 仓库:https://github…

LVS 实现四层负载均衡项目实战--DR模式(直接路由Direct Routing)

一、原理 负载均衡器和RS都使用同一个IP对外服务。但只有DB对ARP请求进行响应,所有RS对本身这个IP的ARP请求保持静默。也就是说,网关会把对这个服务IP的请求全部定向给DB,而DB收到数据包后根据调度算法,找出对应的RS,把目的MAC地址改为RS的MAC&#xff08…

MacOS 为指定应用添加指定权限(浏览器无法使用摄像头、麦克风终极解决方案)

起因:需要浏览器在线做一些测评,但我的 Chrome 没有摄像头/麦克风权限,并且在设置中是没有手动添加按钮的。 我尝试了重装软件,更新系统(上面的 13.5 就是这么来的,我本来都半年懒得更新系统了&#xff09…

Tomcat安装及环境配置

一、首先确认自己是否已经安装JDK WinR打开运行,输入cmd回车,在DOS窗口中输入java 出现这些说明已经安装了,然后就是查看自己的版本 然后输入java -version 可以看到版本是1.8的 如果没有配置java环境变量,打开系统设置&#…

Linux安装JenkinsCLI

项目简介安装目录 mkdir -p /opt/jenkinscli && cd /opt/jenkinscli JenkinsCLI下载 wget http://<your-jenkins-server>/jnlpJars/jenkins-cli.jar # <your-jenkins-server> 替换为你的 Jenkins 服务器地址 JenkinsCLI授权 Dashboard-->Configure Glob…

渣土车识别监测 渣土车未盖篷布识别抓拍算法

渣土车识别监测 渣土车未盖篷布识别抓拍算法通过yolov7深度学习训练模型框架&#xff0c;渣土车识别监测 渣土车未盖篷布识别抓拍算法在指定区域内实时监测渣土车的进出状况以及对渣土车未盖篷布违规的抓拍和预警。YOLOv7 的策略是使用组卷积来扩展计算块的通道和基数。研究者将…

交叉导轨具有那些方面的优势?

交叉导轨和直线导轨这两种导轨经常被拿出来对比&#xff0c;从结构上来看&#xff0c;交叉导轨是分体结构&#xff0c;组装起来相对于直线导轨比较繁琐&#xff1b;从功能上来看&#xff0c;交叉导轨适合于短行程&#xff0c;高频率、高精度的场合。直线导轨可以做到6000mm单支…

Paimon+StarRocks 湖仓一体数据分析方案

摘要&#xff1a;本文整理自阿里云高级开发工程师曾庆栋&#xff08;曦乐&#xff09;在 Streaming Lakehouse Meetup 的分享。内容主要分为四个部分&#xff1a; 传统数据仓库分析实现方案简介PaimonStarRocks 构建湖仓一体数据分析实现方案StarRocks 与 Paimon 结合的使用方式…

从零开始,掌握C语言中的数据类型

数据类型 1. 前言2. 预备知识2.1 打印整数2.2 计算机中的单位 3. C语言有哪些数据类型呢&#xff1f;3.1 内置类型和自定义类型 4. 每种类型的大小是多少&#xff1f;5. 为什么有这么多数据类型呢&#xff1f;6. 这么多类型应该如何使用呢&#xff1f;6.1 一个小知识 1. 前言 …

Redis功能实战篇之Session共享

1.使用redis共享session来实现用户登录以及token刷新 当用户请求我们的nginx服务器&#xff0c;nginx基于七层模型走的事HTTP协议&#xff0c;可以实现基于Lua直接绕开tomcat访问redis&#xff0c;也可以作为静态资源服务器&#xff0c;轻松扛下上万并发&#xff0c; 负载均衡…

【VR】Network Manager HUD

&#x1f4a6;本专栏是我关于VR开发的笔记 &#x1f236;本篇是——Network Manager HUD Network Manager HUD组件 简介基础知识 简介 网络管理器 HUD是一种快速启动工具&#xff0c;可帮助您立即开始构建多人游戏&#xff0c;而无需首先构建用于游戏创建/连接/加入的用户界面…

云原生Kubernetes:二进制部署K8S单Master架构(一)

目录 一、理论 1.K8S单Master架构 2. etcd 集群 3.flannel网络 4.K8S单Master架构环境部署 5.部署 etcd 集群 6.部署 docker 引擎 7.flannel网络配置 二、实验 1.二进制部署K8S单Master架构 2. 环境部署 3.部署 etcd 集群 4.部署 docker 引擎 5.flannel网络配置…

Ramp 有点意思的题目

粗一看都不知道这个要干什么&#xff0c;这 B 装得不错。 IyEvdXNyL2Jpbi9lbnYgcHl0aG9uMwoKJycnCktlZXAgdXMgb3V0IG9mIGdvb2dsZSBzZWFyY2ggcmVzdWx0cy4uCgokIG9kIC1kIC9kZXYvdXJhbmRvbSB8IGhlYWQKMDAwMDAwMCAgICAgNjAyMTUgICAyODc3OCAgIDI5MjI3ICAgMjg1NDggICA2MjY4NiAgIDQ1MT…

浏览器端vscode docker搭建(附带python环境)

dockerfile from centos:7 #安装python环境 run yum -y install wget openssl-devel bzip2-devel expat-devel gdbm-devel readline-devel zlib-devel libffi-devel gcc make run wget https://www.python.org/ftp/python/3.9.0/Python-3.9.0.tgz run tar -xvf Python-3.9.…

六、编辑器vim编辑器的使用

1、编辑器 (1)编辑器就是一款软件。 (2)作用就是用来编辑文件&#xff0c;譬如编辑文字、编写代码。 (3)Windows中常用的编辑器&#xff0c;有自带的有记事本(notepad)&#xff0c;比较好用的notepad、VSCode等。 (4)Linux中常用的编辑器&#xff0c;自带的最古老的vi&…

中文完形填空

本文通过ChnSentiCorp数据集介绍了完型填空任务过程&#xff0c;主要使用预训练语言模型bert-base-chinese直接在测试集上进行测试&#xff0c;也简要介绍了模型训练流程&#xff0c;不过最后没有保存训练好的模型。 一.完形填空 完形填空应该大家都比较熟悉&#xff0c;就是把…

文件上传漏洞-upload靶场5-12关

文件上传漏洞-upload靶场5-12关通关笔记&#xff08;windows环境漏洞&#xff09; 简介 ​ 在前两篇文章中&#xff0c;已经说了分析上传漏的思路&#xff0c;在本篇文章中&#xff0c;将带领大家熟悉winodws系统存在的一些上传漏洞。 upload 第五关 &#xff08;大小写绕过…