分类任务实现模型集成代码模版

news2024/11/23 13:16:34

分类任务实现模型(投票式)集成代码模版

简介

本实验使用上一博客的深度学习分类模型训练代码模板-CSDN博客,自定义投票式集成,手动实现模型集成(投票法)的代码。最后通过tensorboard进行可视化,对每个基学习器的性能进行对比,直观的看出模型集成的作用。

代码

# -*- coding:utf-8 -*-
import os
import torch
import torchvision
import torchmetrics
import torch.nn as nn
import my_utils as utils
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchensemble.utils import set_module
from torchensemble.voting import VotingClassifier

classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


def get_args_parser(add_help=True):
    import argparse

    parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

    parser.add_argument("--data-path", default=r"E:\Pytorch-Tutorial-2nd\data\datasets\cifar10-office", type=str,
                        help="dataset path")
    parser.add_argument("--model", default="resnet8", type=str, help="model name")
    parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
    parser.add_argument(
        "-b", "--batch-size", default=128, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
    )
    parser.add_argument("--epochs", default=200, type=int, metavar="N", help="number of total epochs to run")
    parser.add_argument(
        "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 16)"
    )
    parser.add_argument("--opt", default="SGD", type=str, help="optimizer")
    parser.add_argument("--random-seed", default=42, type=int, help="random seed")
    parser.add_argument("--lr", default=0.1, type=float, help="initial learning rate")
    parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
    parser.add_argument(
        "--wd",
        "--weight-decay",
        default=1e-4,
        type=float,
        metavar="W",
        help="weight decay (default: 1e-4)",
        dest="weight_decay",
    )
    parser.add_argument("--lr-step-size", default=80, type=int, help="decrease lr every step-size epochs")
    parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
    parser.add_argument("--print-freq", default=80, type=int, help="print frequency")
    parser.add_argument("--output-dir", default="./Result", type=str, help="path to save outputs")
    parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")

    return parser


def main():
    args = get_args_parser().parse_args()
    utils.setup_seed(args.random_seed)
    args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    device = args.device
    data_dir = args.data_path
    result_dir = args.output_dir
    # ------------------------------------  log ------------------------------------
    logger, log_dir = utils.make_logger(result_dir)
    writer = SummaryWriter(log_dir=log_dir)

    # ------------------------------------ step1: dataset ------------------------------------

    normMean = [0.4948052, 0.48568845, 0.44682974]
    normStd = [0.24580306, 0.24236229, 0.2603115]
    normTransform = transforms.Normalize(normMean, normStd)
    train_transform = transforms.Compose([
        transforms.Resize(32),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        normTransform
    ])

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normTransform
    ])

    # root变量下需要存放cifar-10-python.tar.gz 文件
    # cifar-10-python.tar.gz可从 "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" 下载
    train_set = torchvision.datasets.CIFAR10(root=data_dir, train=True, transform=train_transform, download=True)
    test_set = torchvision.datasets.CIFAR10(root=data_dir, train=False, transform=valid_transform, download=True)

    # 构建DataLoder
    train_loader = DataLoader(dataset=train_set, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    valid_loader = DataLoader(dataset=test_set, batch_size=args.batch_size, num_workers=args.workers)

    # ------------------------------------ tep2: model ------------------------------------
    model_base = utils.resnet20()
    # model_base = utils.LeNet5()
    model = MyEnsemble(estimator=model_base, n_estimators=3, logger=logger, device=device, args=args,
                       classes=classes, writer=writer, save_dir=log_dir)
    model.set_optimizer(args.opt, lr=args.lr, weight_decay=args.weight_decay)
    model.fit(train_loader, test_loader=valid_loader, epochs=args.epochs)


class MyEnsemble(VotingClassifier):
    def __init__(self, **kwargs):
        # logger, device, args, classes, writer
        super(VotingClassifier, self).__init__(kwargs["estimator"], kwargs["n_estimators"])
        self.logger = kwargs["logger"]
        self.writer = kwargs["writer"]
        self.device = kwargs["device"]
        self.args = kwargs["args"]
        self.classes = kwargs["classes"]
        self.save_dir = kwargs["save_dir"]

    @staticmethod
    def save(model, save_dir, logger):
        """Implement model serialization to the specified directory."""
        if save_dir is None:
            save_dir = "./"

        if not os.path.isdir(save_dir):
            os.mkdir(save_dir)

        # Decide the base estimator name
        if isinstance(model.base_estimator_, type):
            base_estimator_name = model.base_estimator_.__name__
        else:
            base_estimator_name = model.base_estimator_.__class__.__name__

        # {Ensemble_Model_Name}_{Base_Estimator_Name}_{n_estimators}
        filename = "{}_{}_{}_ckpt.pth".format(
            type(model).__name__,
            base_estimator_name,
            model.n_estimators,
        )

        # The real number of base estimators in some ensembles is not same as
        # `n_estimators`.
        state = {
            "n_estimators": len(model.estimators_),
            "model": model.state_dict(),
            "_criterion": model._criterion,
        }
        save_dir = os.path.join(save_dir, filename)

        logger.info("Saving the model to `{}`".format(save_dir))

        # Save
        torch.save(state, save_dir)

        return

    def fit(self, train_loader, epochs=100, log_interval=100, test_loader=None, save_model=True, save_dir=None, ):

        # 模型、优化器、学习率调整器、评估器 列表创建
        estimators = []
        for _ in range(self.n_estimators):
            estimators.append(self._make_estimator())

        optimizers = []
        schedulers = []
        for i in range(self.n_estimators):
            optimizers.append(set_module.set_optimizer(estimators[i],
                                                       self.optimizer_name, **self.optimizer_args))
            scheduler_ = torch.optim.lr_scheduler.MultiStepLR(optimizers[i], milestones=[100, 150],
                                                              gamma=self.args.lr_gamma)  # 设置学习率下降策略
            # scheduler_ = torch.optim.lr_scheduler.StepLR(optimizers[i], step_size=self.args.lr_step_size,
            #                                             gamma=self.args.lr_gamma)  # 设置学习率下降策略
            schedulers.append(scheduler_)

        acc_metrics = []
        for i in range(self.n_estimators):
            # task类型与任务一致
            # num_classes与分类任务的类别数一致
            acc_metrics.append(torchmetrics.Accuracy(task="multiclass", num_classes=len(self.classes)))

        self._criterion = nn.CrossEntropyLoss()

        # epoch循环迭代
        best_acc = 0.
        for epoch in range(epochs):

            # training
            for model_idx, (estimator, optimizer, scheduler) in enumerate(zip(estimators, optimizers, schedulers)):
                loss_m_train, acc_m_train, mat_train = \
                    utils.ModelTrainerEnsemble.train_one_epoch(
                        train_loader, estimator, self._criterion, optimizer, scheduler, epoch,
                        self.device, self.args, self.logger, self.classes)
                # 学习率更新
                scheduler.step()

                # 记录
                self.writer.add_scalars('Loss_group', {'train_loss_{}'.format(model_idx):
                                                           loss_m_train.avg}, epoch)
                self.writer.add_scalars('Accuracy_group', {'train_acc_{}'.format(model_idx):
                                                               acc_m_train.avg}, epoch)
                self.writer.add_scalar('learning rate', scheduler.get_last_lr()[0], epoch)
                # 训练混淆矩阵图
                conf_mat_figure_train = utils.show_conf_mat(mat_train, classes, "train", save_dir, epoch=epoch,
                                                            verbose=epoch == epochs - 1, save=False)
                self.writer.add_figure('confusion_matrix_train', conf_mat_figure_train, global_step=epoch)

            # validate
            loss_valid_meter, acc_valid, top1_group, mat_valid = \
                utils.ModelTrainerEnsemble.evaluate(test_loader, estimators, self._criterion, self.device, self.classes)

            # 日志
            self.writer.add_scalars('Loss_group', {'valid_loss':
                                                       loss_valid_meter.avg}, epoch)
            self.writer.add_scalars('Accuracy_group', {'valid_acc':
                                                           acc_valid * 100}, epoch)
            # 验证混淆矩阵图
            conf_mat_figure_valid = utils.show_conf_mat(mat_valid, classes, "valid", save_dir, epoch=epoch,
                                                        verbose=epoch == epochs - 1, save=False)
            self.writer.add_figure('confusion_matrix_valid', conf_mat_figure_valid, global_step=epoch)

            self.logger.info(
                'Epoch: [{:0>3}/{:0>3}]  '
                'Train Loss avg: {loss_train:>6.4f}  '
                'Valid Loss avg: {loss_valid:>6.4f}  '
                'Train Acc@1 avg:  {top1_train:>7.2f}%   '
                'Valid Acc@1 avg: {top1_valid:>7.2%}    '
                'LR: {lr}'.format(
                    epoch, self.args.epochs, loss_train=loss_m_train.avg, loss_valid=loss_valid_meter.avg,
                    top1_train=acc_m_train.avg, top1_valid=acc_valid, lr=schedulers[0].get_last_lr()[0]))

            for model_idx, top1_meter in enumerate(top1_group):
                self.writer.add_scalars('Accuracy_group',
                                        {'valid_acc_{}'.format(model_idx): top1_meter.compute() * 100}, epoch)

            if acc_valid > best_acc:
                best_acc = acc_valid
                self.estimators_ = nn.ModuleList()
                self.estimators_.extend(estimators)
                if save_model:
                    self.save(self, self.save_dir, self.logger)


if __name__ == "__main__":
    main()

效果图

本实验采用3个学习器进行投票式集成,因此绘制了7条曲线,其中各学习器在训练和验证各有2条曲线,集成模型的结果通过 valid_acc输出(蓝色),通过下图可发现,集成模型与三个基学习器相比,分类准确率都能提高3-4百分点左右,是非常高的提升了。

image-20240830103703565

image-20240830154555390

image-20240830154619630

参考

7.7 TorchEnsemble 模型集成库 · PyTorch实用教程(第二版) (tingsongyu.github.io)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/2091565.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Java之初始泛型

1 包装类 在Java中,由于基本类型不是继承自Object,为了在泛型代码中可以支持基本类型,Java给每个基本类型都对应了一个包装类型。 1.1 基本数据类型和对应的包装类 基本数据类型包装类byteByteshortShortintIntegerlongLongfloatFloatdoub…

RAG最佳实践:用 ElasticSearch 打造AI搜索系统与RAG 应用全流程详解!

前面一篇文章《RAG 向量数据库:掌握 Elasticsearch 作为向量数据库的终极指南》中,介绍了使用ElasticSerach作为向量数据的安装和使用指南。 今天这篇文章将介绍如何使用 Elasticsearch 搭建AI搜索系统和RAG应用系统。 Elasticsearch 搭建 AI 搜索系统 在 Elasticsearch 中…

游泳耳机哪个牌子的好?四大口碑精品游泳耳机专业推荐!

在追求健康生活的同时,游泳成为了许多人选择的锻炼方式。它不仅能够帮助人们塑造身材,还能有效缓解压力。而在游泳过程中,音乐的陪伴无疑能让人更加享受这段时光。因此,一款适合游泳时使用的耳机,成为了游泳爱好者们不…

java程序CUP持续飙高

1.top 2.定位进程中使用CPU最高的线程 top -Hp 70688 3.将线程ID转为十六进制 printf "0x%x\n" 28760 4.jstack工具跟踪堆栈定位代码 jstack 70688 | grep 0x7058 -A 10

尺度和位置敏感的红外小目标检测

Infrared Small Target Detection with Scale and Location Sensitivity 在本文中,着重于以更有效的损失和更简单的模型结构来提升检测性能。 问题一 红外小目标检测(IRSTD)一直由基于深度学习的方法主导。然而,这些方法主要集中…

python-春游

[题目描述] 老师带领同学们春游。已知班上有 N 位同学,每位同学有从 0 到 N−1 的唯一编号。到了集合时间,老师确认是否所有同学都到达了集合地点,就让同学们报出自己的编号。到达的同学都会报出自己的编号,不会报出别人的编号&am…

单链表应用

基于单链表实现通讯录项目 //Contact.c #define _CRT_SECURE_NO_WARNINGS 1 #include"contact.h" #include"list.h"//初始化通讯录 void InitContact(contact** con) {con NULL;} //添加通讯录数据 void AddContact(contact** con) {PeoInfo info;printf…

无主灯设计:吊顶之问与光影艺术的探索

在现代家居设计中,照明不仅仅是为了满足基本的照明需求,更是一种艺术和情感的表达。随着无主灯设计越来越受到人们的青睐,许多业主开始考虑一个问题:进行无主灯设计时,是否一定需要吊顶呢?本文将深入探讨这…

2017年系统架构师案例分析试题五

目录 案例 【题目】 【问题 1】(5 分) 【问题 2】(16 分) 【问题 3】(4 分) 【答案】 【问题 1】解析 【问题 2】解析 【问题 3】答案 相关推荐 案例 阅读以下关于 Web 系统架构设计的叙述,在答题纸上回答问题 1 至问题 3。 【题目】 某电子商务企业因发…

小波神经网络的时间序列的短时交通流量预测

小波神经网络的时间序列的短时交通流量预测 通过小波分析进行负荷序列分 解, 获得不同频率负荷分量规律 ; 由粒子群算法进行粒子群适应度排序 , 提升算法收敛速度和收敛能力 ; 为避免算法陷入局部 收敛性, 引入混沌理论来增强全局搜索能力 。 预测结果

linux 系统如何进行nfs(第五节)

网上的截图: 自己的操作: 首先是 在虚拟机中的操作。 然后是在开发板上的操作。 已经是没有问题了。

AI绘画【Stable Diffusion】抽卡必备!时间管理大师Agent Scheduler插件,一键设置任务,让你的休息时间充分利用起来!

大家好,我是灵魂画师向阳 相信大家在玩 Stable Diffusion 的时候一直有一个痛点,每次出图抽卡时都只能等待上一次抽卡结束,才能继续下一次抽卡; 特别是当我们想抽大量的卡来测试不同的模型,不同的参数的效果时&#…

大学生社团管理系统

一、项目概述 Hi,大家好,今天分享的项目是《大学生社团管理系统》。 随着校园文化的不断丰富,大学里各种社团越来越多,社团活动也越来越频繁,社团管理就显得繁琐,传统的人工管理方式比较麻烦,…

Client客户端模块

一.Client模块介绍 二.Client具体实现 1.消费者/订阅者模块 2.信道管理模块 3.异步线程模块 4.连接管理模块 这个模块同样是针对muduo库客户端连接的二次封装,向用户提供创建channel信道的接口,创建信道后,可以通过信道来获取指定服务。 三…

游泳耳机哪个牌子好?四大硬核爆款游泳耳机推荐种草!

随着人们对健康生活方式的不断追求,游泳作为一项全身性的运动受到了越来越多人的喜爱。与此同时,为了在水下也能享受音乐的乐趣,游泳耳机应运而生,并迅速成为泳池和海滩上不可或缺的装备之一。面对市面上琳琅满目的游泳耳机产品&a…

线性表之静态链表

1. 静态链表的设计 1.1 定义静态链表 链表是由多个相同类型的节点组成的线性表,它的每个节点都包含一个数据项和一个指向下一个节点的指针,链表中各个节点的地址是不连续的。 下面是一个用于存储整形数据的链表节点结构: struct Node {int…

深度学习与大模型第1课环境搭建

深度学习与大模型第1课 环境搭建 1. 安装 Anaconda 首先,您需要安装 Anaconda,这是一个开源的 Python 发行版,能够简化包管理和环境管理。以下是下载链接及提取码: 链接:https://pan.baidu.com/s/1Na2xOFpBXQMgzXA…

Text Control 控件教程:智能文档处理 (IDP)

TX Text Control 是一款功能类似于 MS Word 的文字处理控件,包括文档创建、编辑、打印、邮件合并、格式转换、拆分合并、导入导出、批量生成等功能。广泛应用于企业文档管理,网站内容发布,电子病历中病案模板创建、病历书写、修改历史、连续打…

【云计算】什么是云计算服务|为什么出现了云计算|云计算的服务模式

文章目录 什么是云计算服务本地部署VS云计算SaaS PaaS IaaS公有云、私有云、混合云为什么优先发展云计算服务的厂商是亚马逊、阿里巴巴等公司 什么是云计算服务 根据不同的目标用户,云计算服务(Cloud Computing Services)分为两种&#xff1…