分类网络-类别不均衡问题之FocalLoss

news2025/1/13 2:53:24

在这里插入图片描述
有训练和测代码如下:(完整代码来自CNN从搭建到部署实战)
train.py

import torch
import torchvision
import time
import argparse
import importlib
from loss import FocalLoss


def parse_args():
    parser = argparse.ArgumentParser('training')
    parser.add_argument('--batch_size', default=128, type=int, help='batch size in training')
    parser.add_argument('--num_epochs', default=5, type=int, help='number of epoch in training')
    parser.add_argument('--model',  default='lenet', help='model name [default: mlp]')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    batch_size = args.batch_size
    num_epochs = args.num_epochs
    model = importlib.import_module('models.'+args.model) 
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = model.net.to(device)

    loss = torch.nn.CrossEntropyLoss()
    
    if args.model == 'mlp':
        optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
    else:
        optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
          
    train_path = r'./Datasets/mnist_png/training'
    test_path = r'./Datasets/mnist_png/testing'
    transform_list = [torchvision.transforms.Grayscale(num_output_channels=1), torchvision.transforms.ToTensor()]
    if args.model == 'alexnet' or args.model == 'vgg':
        transform_list.append(torchvision.transforms.Resize(size=224))
    if args.model == 'googlenet' or args.model == 'resnet':
        transform_list.append(torchvision.transforms.Resize(size=96))
    transform = torchvision.transforms.Compose(transform_list)

    train_dataset = torchvision.datasets.ImageFolder(train_path, transform=transform)
    test_dataset = torchvision.datasets.ImageFolder(test_path, transform=transform)

    train_iter = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    test_iter = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    for epoch in range(num_epochs):
        train_l, train_acc, test_acc, m, n, batch_count, start = 0.0, 0.0, 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            y_hat = net(X)
            l = loss(y_hat, y)
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            train_l += l.cpu().item()
            train_acc += (y_hat.argmax(dim=1) == y).sum().cpu().item()
            m += y.shape[0]
            batch_count += 1
        with torch.no_grad():
            for X, y in test_iter:
                net.eval() # 评估模式
                test_acc += (net(X.to(device)).argmax(dim=1) == y.to(device)).float().sum().cpu().item()
                net.train() # 改回训练模式
                n += y.shape[0]
        print('epoch %d, loss %.6f, train acc %.3f, test acc %.3f, time %.1fs'% (epoch, train_l / batch_count, train_acc / m, test_acc / n, time.time() - start))
        torch.save(net, args.model+".pth")

test.py

import cv2
import torch
import argparse
import importlib
from pathlib import Path
import torchvision.transforms.functional


def parse_args():
    parser = argparse.ArgumentParser('testing')
    parser.add_argument('--model',  default='lenet', help='model name [default: mlp]')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    model = importlib.import_module('models.' + args.model) 
        
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = model.net.to(device)
    net = torch.load(args.model+'.pth')
    net.eval()

    with torch.no_grad():
        imgs_path = Path(r"./Datasets/mnist_png/testing/6/").glob("*")
        acc = 0
        count = 0
        for img_path in imgs_path:
            img = cv2.imread(str(img_path), 0)
            if args.model == 'alexnet' or args.model == 'vgg':  
                img = cv2.resize(img, (224,224))
            if args.model == 'googlenet' or args.model == 'resnet':
                img = cv2.resize(img, (96,96))
            img_tensor = torchvision.transforms.functional.to_tensor(img)
            img_tensor = torch.unsqueeze(img_tensor, 0)
            #print(net(img_tensor.to(device)).argmax(dim=1).item())
            if(net(img_tensor.to(device)).argmax(dim=1).item()==6):
                acc += 1
            count+=1
    print(acc/count)

数据集为mnist手写数字识别,其中训练集中数字0~9的数量分别为:0(5923张),1(6472张),2(5985张),3(6131张),4(5842张),5(5421张),6(5918张),7(6265张),8(5851张),9(5949张), 测试集中数字0~9的数量分别为:0(980张),1(1135张),2(1032张),3(1010张),4(982张),5(892张),6(958张),7(1028张),8(974张),9(1009张)。可见各个类别的数量基本上平衡。测试代码仅测试数字6的准确率,因为后面我们要改变训练集中数字6的数量来进行对比。为了节省时间,仅训练5个epoch。
训练结果:

epoch 0, loss 1.443379, train acc 0.529, test acc 0.877, time 23.4s
epoch 1, loss 0.314123, train acc 0.913, test acc 0.939, time 22.1s
epoch 2, loss 0.174050, train acc 0.949, test acc 0.960, time 21.9s
epoch 3, loss 0.122714, train acc 0.963, test acc 0.971, time 21.8s
epoch 4, loss 0.096798, train acc 0.971, test acc 0.975, time 21.8s

测试结果:

0.9780793319415448

现在将训练集中数字6的数量减少到59张(原来的1/100),来模拟某个类别的数据不平衡的情况。
训练结果:

epoch 0, loss 2.200247, train acc 0.131, test acc 0.373, time 20.8s
epoch 1, loss 0.579792, train acc 0.840, test acc 0.855, time 20.5s
epoch 2, loss 0.177890, train acc 0.950, test acc 0.872, time 20.3s
epoch 3, loss 0.128251, train acc 0.963, test acc 0.880, time 20.5s
epoch 4, loss 0.103937, train acc 0.969, test acc 0.888, time 20.7s

测试结果:

0.04801670146137787

可以看到,训练的准确率下降9%,而测试集直接下降了93%惨不忍睹。

引入FocalLoss模块:(参考https://github.com/QunBB/DeepLearning/blob/main/trick/unbalance/loss_pt.py)

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Union


class FocalLoss(nn.Module):
    def __init__(self, alpha: Union[List[float], float], gamma: Optional[int] = 2, with_logits: Optional[bool] = True):
        """
        :param alpha: 每个类别的权重
        :param gamma:
        :param with_logits: 是否经过softmax或者sigmoid
        """
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = torch.FloatTensor([alpha]) if isinstance(alpha, float) else torch.FloatTensor(alpha)
        self.smooth = 1e-8
        self.with_logits = with_logits

    def _binary_class(self, input, target):
        prob = torch.sigmoid(input) if self.with_logits else input
        prob += self.smooth
        alpha = self.alpha.to(target.device)
        loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * torch.log(prob)
        return loss

    def _multiple_class(self, input, target):
        prob = F.softmax(input, dim=1) if self.with_logits else input

        alpha = self.alpha.to(target.device)
        alpha = alpha.gather(0, target)

        target = target.view(-1, 1)

        prob = prob.gather(1, target).view(-1) + self.smooth  # avoid nan
        logpt = torch.log(prob)

        loss = -alpha * torch.pow(torch.sub(1.0, prob), self.gamma) * logpt
        return loss

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        :param input: 维度为[bs, num_classes]
        :param target: 维度为[bs]
        :return:
        """
        if len(input.shape) > 1 and input.shape[-1] != 1:
            loss = self._multiple_class(input, target)
        else:
            loss = self._binary_class(input, target)

        return loss.mean()

并将train.py的第26行修改成

    loss = FocalLoss([1, 1, 1, 1, 1, 1, 100, 1, 1, 1])

其中列表的数字代表10个类别的权重值。
训练结果:

epoch 0, loss 2.045273, train acc 0.137, test acc 0.467, time 20.7s
epoch 1, loss 0.510476, train acc 0.810, test acc 0.907, time 21.3s
epoch 2, loss 0.148246, train acc 0.922, test acc 0.941, time 21.1s
epoch 3, loss 0.099026, train acc 0.944, test acc 0.953, time 21.2s
epoch 4, loss 0.075481, train acc 0.954, test acc 0.959, time 21.3s

测试结果:

0.9196242171189979

对比看出,FocalLoss可以有效缓解类别不均衡问题(当然并不能完全消除,有足够平衡的高质量数据集肯定更好啦~)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/1115559.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Monocle 3 | 太牛了!单细胞必学R包!~(一)(预处理与降维聚类)

1写在前面 忙碌的一周结束了,终于迎来周末了。🫠 这周的手术真的是做到崩溃,2天的手术都过点了。🫠 真的希望有时间静下来思考一下。🫠 最近的教程可能会陆续写一下Monocle 3,炙手可热啊,欢迎大…

【七:docken+jenkens部署】

一:腾讯云轻量服务器docker部署Jenkins https://blog.csdn.net/qq_35402057/article/details/123589493 步骤1:查询jenkins版本:docker search jenkins步骤2:拉取jenkins镜像 docker pull jenkins/jenkins:lts步骤3:…

设计模式:外观模式(C#、JAVA、JavaScript、C++、Python、Go、PHP)

大家好!本节主要介绍设计模式中的外观模式。 简介: 外观模式,它是一种设计模式,它为子系统中的一组接口提供一个统一的、简单的接口。这种模式主张按照描述和判断资料来评价课程,关键活动是在课程实施的全过程中进行…

【C语言必知必会 | 第八篇】一文带你精通循环结构

引言 C语言是一门面向过程的、抽象化的通用程序设计语言,广泛应用于底层开发。它在编程语言中具有举足轻重的地位。 此文为【C语言必知必会】系列第八篇,进行C语言循环结构的专项练习,结合专题优质题目,带领读者从0开始&#xff0…

丰富功能带来新奇体验,乐划锁屏持续引领行业潮流

在日常使用手机的过程中,为了省电或在口袋里误触,不少用户都会对手机进行锁屏,有这么一款锁屏软件,不仅能够解决以上功能性问题,还有不少其他新奇的功能,那就是乐划锁屏。 现在“手机不离手”、每天翻看手机上百遍的年轻人,最喜欢关注什么样的内容呢?乐划锁屏以图片和视频为主…

RabbitMQ官方案例学习记录

官方文档:RabbitMQ教程 — RabbitMQ (rabbitmq.com) 一、安装RabbitMQ服务 直接使用docker在服务器上安装 docker run -it -d --name rabbitmq -p 5672:5672 -p 15672:15672 rabbitmq:3.12-management 安装完成后,访问15672端口,默认用户…

实时数据更新与Apollo:探索GraphQL订阅

前言 「作者主页」:雪碧有白泡泡 「个人网站」:雪碧的个人网站 「推荐专栏」: ★java一站式服务 ★ ★ React从入门到精通★ ★前端炫酷代码分享 ★ ★ 从0到英雄,vue成神之路★ ★ uniapp-从构建到提升★ ★ 从0到英雄&#xff…

FIFO设计16*8,verilog,源码和视频

名称:FIFO设计16*8,数据显示在数码管 软件:Quartus 语言:Verilog 代码功能: 使用verilog语言设计一个16*8的FIFO,深度16,宽度为8。可对FIFO进行写和读,并将FIFO读出的数据显示到…

【微服务保护】初识 Sentinel —— 探索微服务雪崩问题的解决方案,Sentinel 的安装部署以及将 Sentinel 集成到微服务项目

文章目录 前言一、雪崩问题及其解决方案1.1 什么是雪崩问题1.2 雪崩问题的原因1.3 解决雪崩问题的方法1.4 总结 二、初识 Sentinel 框架2.1 什么是 Sentinel2.2 Sentinel 和 Hystrix 的对比 三、Sentinel 的安装部署四、集成 Sentinel 到微服务 前言 微服务架构在现代软件开发…

Rust通用编程概念

文章目录 变量与可变性数据类型标量类型整数类型浮点类型布尔类型字符类型 复合类型元组数组 函数注释控制流if表达式循环 变量与可变性 变量与可变性 在Rust中,声明变量使用let关键字,并且默认情况下,声明的变量是不可变的,要使变…

C语言实现用递归法将一个整数 n 转换成字符串。例如,输入 483,应输出字符串“483“。n 的位数不确定,可以是任意位数的整数

完整代码&#xff1a; /*用递归法将一个整数 n 转换成字符串。例如&#xff0c;输入 483&#xff0c;应输出字符串"483"。n 的 位数不确定&#xff0c;可以是任意位数的整数。*/ #include<stdio.h> //long long int让n的取值范围更广 void func(long long int…

手搭手zabbix5.0监控redis7

Centos7安装配置Redis7 安装redis #安装gcc yum -y install gcc gcc-c #安装net-tools yum -y install net-tools #官网https://redis.io/ cd /opt/ wget http://download.redis.io/releases/redis-7.0.4.tar.gz 解压至/opt/目录下 tar -zxvf redis-7.0.4.tar.gz -C /opt/ #…

SpringCloud链路追踪——Spring Cloud Sleuth 和 Zipkin 介绍 Windows 下使用初步

前言 在微服务中&#xff0c;随着服务越来越多&#xff0c;对调用链的分析越来越复杂。如何能够分析调用链&#xff0c;定位微服务中的调用瓶颈&#xff0c;并对其进行解决。 本篇博客介绍springCloud中用到的链路追踪的组件&#xff0c;Spring Cloud Sleuth和Zipkin&#xf…

扩散模型学习

第一章 1.1 的原理 给定一批训练数据X&#xff0c;假设其服从某种复杂的真实 分布p(x)&#xff0c;则给定的训练数据可视为从该分布中采样的观测样本x。 生成模型就是估计训练数据的真实分布&#xff0c;使得估计的分布q(x)和真实分布p(x)差距尽可能能的小。 使得所有训练…

【五:Httprunner的介绍使用】

接口自动化框架封装思想的建立。httprunner&#xff08;热加载&#xff1a;动态参数&#xff09;&#xff0c;去应用 意义不大。 day1 一、什么是Httprunner? 1.httprunner是一个面向http协议的通用测试框架&#xff0c;目前最新的版本3.X。以前比较流行的 2.X的版本。2.它的…

FPGA的音乐彩灯VHDL流水灯LED花样,源码和视频

名称&#xff1a;FPGA的音乐彩灯VHDL流水灯LED 软件&#xff1a;Quartus 语言&#xff1a;VHDL 代码功能&#xff1a; &#xff08;1&#xff09;设计一彩灯控制电路&#xff0c;按要求控制8路&#xff08;彩灯由发光 二极管代替&#xff0c;受实验箱限制&#xff0c;多路同…

CUDA 学习记录

1.关于volatile&#xff1a; 对于文章中这个函数&#xff0c; __global__ void reduceUnrollWarps8 (int *g_idata, int *g_odata, unsigned int n) {// set thread IDunsigned int tid threadIdx.x;unsigned int idx blockIdx.x * blockDim.x * 8 threadIdx.x;// convert…

李m圆申论

听话出活 3小时 /处理7500字 /一共5题 /写出2200字 字写得好看点&#xff0c;符号也算字数&#xff0c;占一个格 基本思路&#xff1a;考什么范围答什么 。。。落后&#xff1b;资源闲置、缺乏 申论&#xff1a; 作文题&#xff1a;举例子 处理材料 摘抄&#xff1a; 有人出…

《数据结构、算法与应用C++语言描述》-队列的应用-电路布线问题

《数据结构、算法与应用C语言描述》-队列的应用-电路布线问题 问题描述 在 迷宫老鼠问题中&#xff0c;可以寻找从迷宫入口到迷宫出口的一条最短路径。这种在网格中寻找最短路径的算法有许多应用。例如&#xff0c;在电路布线问题的求解中&#xff0c;一个常用的方法就是在布…

Linux进程(三)--进程切换命令行参数

继上回书Linux进程概念&#xff08;二&#xff09;--进程状态&进程优先级&#xff0c;我们在了解了Linux进程状态和优先级的概念&#xff0c;初步掌握了进程状态的相关知识&#xff0c;最终&#xff0c;我们以Linux进程的优先级&#xff0c;引出了一些其他的概念&#xff1…