Pytorch高级训练框架Ignite详细介绍与常用模版

news2025/1/22 12:30:15

引言

Ignite是Pytorch配套的高级框架,我们可以借其构筑一套标准化的训练流程,规范训练器在每个循环、轮次中的行为。本文将不再赘述Ignite的具体细节或者API,详见官方教程和其他博文。本文将分析Ignite的运行机制、如何将Pytorch训练代码转为Ignite范式,最后给出个人设计的标准化Ignite训练模版。

Ignite简介

 Ignite所做的事情就是我们在pytorch里常写的范式用更加机械、更加标注格式展现出来,这也就是为啥其核心被称为–Engine,高效而精密。Pytorh里常用的训练范式如下:

for ep in Epoch:
	for batch in train_loader:
	    model.train()
	    inputs, targets = batch
	    optimizer.zero_grad()
	    outputs = model(inputs)
	    loss = criterion(outputs, targets)
	    loss.backward()
	    optimizer.step()
	    
		if it%log_period:
			print()
	if ep%save_perid:
		torch.save()

具体而言,可以拆解为批训练、批完结处理、轮次完结处理三个组成部分,批训练部分是网络训练的基础单元,完成数据当前批次读取、前向传播、反向传播等步骤,批完结处理负责在每个批次结束后输出模型训练的相关信息,轮次完结处理负责在每个epoch结束进行模型的保存、对模型的训练参数进行更新。这三个模型训练的主要组成部分在ignite中得到了完整的封装,围绕批训练构造了一个核心的Engine,将批完结处理轮次完结处理附加该Engine运行的时间轴中,形成了批训练->批完结处理->轮次完结处理的流水线作业范式,更为详细的时间轴如下1
在这里插入图片描述
以下将从实用性的角度出发给出Ignite的建设框架,最终给出个人设计的Ignite使用模版,后续直接在train.py文件里直接调用do_train()函数即可利用Ignite进行模型训练。为讲解需要,中间每个子部分的代码为最终代码中相应部分重新排序得到,最终代码中其顺序会进行调整。

批训练

 批训练的代码较为简单,只需将原本的Pytorch版本批处理流程复制粘贴,最后将该过程函数化,并且实例化成Engine即可,代码如下所示,最终启动Engine,即可进行模型的训练。到此为止,实际上已经完成了狭义上的“模型”训练部分。

 def create_supervised_trainer(model,optimizer,criterion,
                              device=None, non_blocking=False,
                              prepare_batch=_prepare_batch,
                              output_transform=lambda x, y, y_pred, loss: loss.item()):
      """
      有监督模型的Engine创建

      Args:
          model (`torch.nn.Module`):
          optimizer (`torch.optim.Optimizer`):
          loss_fn (torch.nn loss function):
          device (str, optional):
          non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
              with respect to the host. For other cases, this argument has no effect.
          prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理
          output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()

      Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item

      Returns:
          Engine: 有监督任务的engine实例
      """
      if device:
          model.to(device)

      def _update(engine, batch):
          model.train()
          optimizer.zero_grad()
          x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)
          output = model(x)
          loss=criterion(output,y)
          loss.backward()
          optimizer.step()
          return output_transform(x, y, None, loss)

      return Engine(_update)

trainer=create_supervised_trainer(model,optimizer,criterion,device)  # 建立ignite的engine
trainer.run(train_loader,max_epochs=cfg['max_epochs'])

批完结处理

 批完结处理部分我们常做的操作是输出模型在当前批的损失,Ignite中这一过程通过在Engine上附着于ITERAION_COMPLETEDE时触发的回调函数实现。实际上这只是限定了触发时间,具体进行何种操作,完全依赖于个人的选择。我们只需要知道该函数可以利用engine保留的当前批属性信息进行各种操作即可,具体可以利用哪些属性,见官方API2,本文只利用了常用的几个。

    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        """
        隔一定iteration输出模型损失
        """
        log_period=int(cfg["log_period"]*len(train_loader))  # 跑了log_period*len输出一次,取值<=1
        if engine.state.iteration%log_period==0:
            pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.output:.2f}")
            pbar.update(log_period)


    @trainer.on(Events.ITERATION_COMPLETED)
    def scheduler_update(engine):
        """
        optional 每个ITER更新学习率
        """
        scheduler.step()

轮次完结处理

  轮次完结处理和批完结处理相同,也是通过回调函数实现,我们通常在轮次完结处理要进行模型的保存,这里就要做两件事:

  1. 在val_loader上验证模型效果
  2. 保留迄今为止效果最好的模型

针对于第一个要求,这里我同样采用了Ignite风格的Engine驱动范式,读者可以自行选择在这里切换为Pytorch范式,构建验证集Engine的代码如下:

    def create_supervised_evaluator(model, metric,
                                device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred: (y_pred,y)):
        """
        构造evaluator
        :param model:
        :param metric: dict,key为metric名字,value为Metric类
        :param device:
        :param non_blocking:
        :param prepare_batch:
        :param output_transform:
        :return:
        """
        if device:
            model.to(device)

        def _inference(engine, batch):
            model.eval()
            with torch.no_grad:
                x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
                output = model(x)
            return output_transform(x, y, output)

        engine=Engine(_inference)

		# 附着metric
        for name, metric in metric.items():
            metric.attach(engine,name)

        return engine

可以看到和trainer较为不同的点在于去除了opt等等选项,此外,由于保存模型时我们要依据验证集上的metric来判断是否要保存当前模型还是沿用此前的模型,因此额外将一个Metric类附着在了Engine上,它使得模型可以自动收集eval_engine每个轮次的输出,并进行metric的计算,ignite中提供了许多metric选项3,这里笔者给出自己定制mertric的范式如下,主要由reset(),update()commpute()组成,reset()完成每个epoch的记录状态重置,update()则接受某一批次engine的输出值,commpute()完成最终的metric计算。值得一提的是,在trainer上我们并没有额外附着Loss类,而是直接用engine输出了loss,实际上或许你也可以用相同的方式对eval_engine进行处理。

class CustomMetric(Metric):
    def __init__(self):
        super(CustomMetric,self).__init__()

    def reset(self) -> None:
        self._num_correct=0
        self._num_examples=0

    def update(self, output) -> None:
        '''
        保存该轮次的输出
        :param output: 每个batch engine的输出
        :return:
        '''
        pred,label=output
        pred=pred.detach()
        label=label.detach()

        indices=torch.argmax(pred,dim=1)
        correct=torch.eq(indices,label).view(-1)

        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        '''
        计算总ACC
        :return:
        '''
        return self._num_correct/self._num_examples

 完成第二步的方法Ignite同样进行了封装,即Checkpoint4,但笔者也进行了自己的定制化,如下:

class BestCheckPoint():
    def __init__(self,save_path,n_saved,model_name):
        '''
        建立存档点类
        :param save_path: 存档点保存路径
        :param n_saved:  保留的存档点数目
        '''
        self.save_path=save_path
        self.n_save=n_saved
        self.model_name=model_name
        self.score=[]
        if not os.path.exists(save_path):
            os.mkdir(self.save_path)

    def update(self,score):
        '''
        更新最优记录
        :param score: 当前模型的metric
        :return:
        '''
        if type(self.score)==torch.Tensor:
            score=score.item()
        if len(self.score)<self.n_save:
            self.score.append(score)
            self.score.sort()
            self.removed=[]
            self._in=[]
            return True
        else:
            value=self.score[0]
            if score>value:
                self.score.remove(value)
                self.score.append(score)
                self.score.sort()
                return value
            else:
                return False

    def save(self,score,model):
        '''
        视当前得分判断是否保存当前模型并删除
        :param score: 当前模型得分
        :param model: 模型
        :return:
        '''
        is_save=self.update(score)
        if is_save:
            torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))
            # pop的存档要删除
            if not isinstance(is_save,bool):
                # 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证
                try:
                    os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))
                except:
                    print("already removed")

主要就是设置了一个metric池,新的metric进来后判断是否优于池子里最烂的模型,并以此判断是否进行保存。

运行框架

 将上述模块封装在一起,我们就可以得到了最终的ignite运行框架,而后只需导入该文件,并运行其中的do_train()函数即可轻松完成模型训练。

整体代码:

# -*- coding: utf-8 -*-
# ---
# @File: trainer.py
# @Author: sgdy3
# @E-mail: sgdy03@163.com
# @Time: 2023/5/9 19:44
# Describe: 
# ---
import os

from tqdm import tqdm
import ignite
import torch
from ignite.engine import Engine
from ignite.utils import convert_tensor
from ignite.engine.engine import Engine, State, Events
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Metric,Accuracy


class BestCheckPoint():
    def __init__(self,save_path,n_saved,model_name):
        '''
        建立存档点类
        :param save_path: 存档点保存路径
        :param n_saved:  保留的存档点数目
        '''
        self.save_path=save_path
        self.n_save=n_saved
        self.model_name=model_name
        self.score=[]
        if not os.path.exists(save_path):
            os.mkdir(self.save_path)

    def update(self,score):
        '''
        更新最优记录
        :param score: 当前模型的metric
        :return:
        '''
        if type(self.score)==torch.Tensor:
            score=score.item()
        if len(self.score)<self.n_save:
            self.score.append(score)
            self.score.sort()
            self.removed=[]
            self._in=[]
            return True
        else:
            value=self.score[0]
            if score>value:
                self.score.remove(value)
                self.score.append(score)
                self.score.sort()
                return value
            else:
                return False

    def save(self,score,model):
        '''
        视当前得分判断是否保存当前模型并删除
        :param score: 当前模型得分
        :param model: 模型
        :return:
        '''
        is_save=self.update(score)
        if is_save:
            torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))
            # pop的存档要删除
            if not isinstance(is_save,bool):
                # 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证
                try:
                    os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))
                except:
                    print("already removed")



class CustomMetric(Metric):
    def __init__(self):
        super(CustomMetric,self).__init__()

    def reset(self) -> None:
        self._num_correct=0
        self._num_examples=0

    def update(self, output) -> None:
        '''
        保存该轮次的输出
        :param output: 每个batch engine的输出
        :return:
        '''
        pred,label=output
        pred=pred.detach()
        label=label.detach()

        indices=torch.argmax(pred,dim=1)
        correct=torch.eq(indices,label).view(-1)

        self._num_correct += torch.sum(correct).item()
        self._num_examples += correct.shape[0]

    def compute(self):
        '''
        计算总ACC
        :return:
        '''
        return self._num_correct/self._num_examples


def do_train(model,optimizer,criterion,scheduler,device,train_loader,val_loader,cfg):

    def _prepare_batch(batch, device=None, non_blocking=False):
        """
        对dataloader每个batch的输出进行进一步的处理
        :param batch: dataloader输出
        :param device:
        :param non_blocking:
        :return:
        """
        device = "cuda:" + str(device)
        x, y = batch
        x = convert_tensor(x,device=device,non_blocking=non_blocking)
        y = convert_tensor(y,device=device,non_blocking=non_blocking)
        return x,y


    def create_supervised_trainer(model,optimizer,criterion,
                                device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred, loss: loss.item()):
        """
        有监督模型的Engine创建

        Args:
            model (`torch.nn.Module`):
            optimizer (`torch.optim.Optimizer`):
            loss_fn (torch.nn loss function):
            device (str, optional):
            non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
                with respect to the host. For other cases, this argument has no effect.
            prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理
            output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()

        Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item

        Returns:
            Engine: 有监督任务的engine实例
        """
        if device:
            model.to(device)

        def _update(engine, batch):
            model.train()
            optimizer.zero_grad()
            x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)
            output = model(x)
            loss=criterion(output,y)
            loss.backward()
            optimizer.step()
            return output_transform(x, y, None, loss)

        return Engine(_update)

    def create_supervised_evaluator(model, metric,
                                device=None, non_blocking=False,
                                prepare_batch=_prepare_batch,
                                output_transform=lambda x, y, y_pred: (y_pred,y)):
        """
        构造evaluator
        :param model:
        :param metric: dict,key为metric名字,value为Metric类
        :param device:
        :param non_blocking:
        :param prepare_batch:
        :param output_transform:
        :return:
        """
        if device:
            model.to(device)

        def _inference(engine, batch):
            model.eval()
            with torch.no_grad:
                x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
                output = model(x)
            return output_transform(x, y, output)

        engine=Engine(_inference)

        for name, metric in metric.items():
            metric.attach(engine,name)

        return engine



    trainer=create_supervised_trainer(model,optimizer,criterion,device)  # 建立ignite的engine
    evaluator=create_supervised_evaluator(model,{"ACC":CustomMetric()})
    CP=BestCheckPoint(cfg['save_path'],cfg['n_saved'],cfg['model_name'])
    pbar=tqdm(total=len(train_loader))  # 为训练器迭代器建立进度条




    ##########################################################################################
    ###########                    Events.ITERATION_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        """
        隔一定iteration输出模型损失
        """
        log_period=cfg["log_period"]
        if engine.state.iteration%log_period==0:
            pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.metrics['avg_loss']:.2f}")
            pbar.update(log_period)

    @trainer.on(Events.ITERATION_COMPLETED)
    def scheduler_update(engine):
        """
        optional 每个ITER更新学习率
        """
        scheduler.step()

    ##########################################################################################
    ###########                    Events.EPOCH_COMPLETED                    #############
    ##########################################################################################

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_model(engine):
        '''
        保存模型
        :param engine:
        :return:
        '''
        if engine.state.epoch % cfg['save_period']==0:
            evaluator.run(val_loader)
            metrics=evaluator.state.metrics
            print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Acc']:.2f}")
            CP.save(metrics['ACC'],model)

    @trainer.on(Events.EPOCH_COMPLETED)
    def reset_bar(engine):
        '''
        重置进度条
        :param engine:
        :return:
        '''
        pbar.reset()

    ##########################################################################################
    #################                    training Start                    ###################
    ##########################################################################################
    trainer.run(train_loader,max_epochs=cfg['max_epochs'])
    pbar.close()

参考


  1. Events | Pytorch-Ignite ↩︎

  2. State | Ignite ↩︎

  3. IGNITE.METRICS ↩︎

  4. CHECKPOINT ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.coloradmin.cn/o/510918.html

如若内容造成侵权/违法违规/事实不符,请联系多彩编程网进行投诉反馈,一经查实,立即删除!

相关文章

Kubectl-AI: 一款 OpenAI GPT 自动生成应用 K8s yaml神器

首页: 官网 下载安装 wget https://github.com/sozercan/kubectl-ai/releases/download/v0.0.10/kubectl-ai_linux_amd64.tar.gz tar xvf kubectl-ai_linux_amd64.tar.gz -C /usr/local/bin/kubectl-ai需要OpenAI API密钥或Azure OpenAI服务 API密钥和端点以及有效的Kubernet…

系统移植——linux内核移植——分析内核编译过程

uImage镜像文件 1.进入linux内核源码目录 ubuntuubuntu:~$ cd FSMP1A/linux-stm32mp-5.10.61-stm32mp-r2-r0/linux-5.10.61/ 打开Makefile文件 vi Makefile 搜索include 因为 $(SRCARCH)->arm 所以上述指令为 arch/arm/Makefile 2.进入linux内核源码目录下,arch/arm目录下…

Windows 11 本地部署 Stable Diffusion web UI

Windows 11 本地部署 Stable Diffusion web UI 0. 什么是 Stable Diffusion1. 什么是 Stable Diffusion web UI2. Github 地址3. 安装 CUDA Toolkit 11.84. 安装 cuDNN v8.9.1 for CUDA 11.x5. 配置环境变量6. 安装 Python 3.10.67. 安装 Stable Diffusion web UI8. 启动 Stabl…

吊打面试官的Java项目经验一:物流系统

引言&#xff1a; java面试一般分为两部分&#xff0c;技术面试和项目面试&#xff0c;相信大多数小伙伴们都刷过很多技术性的面试题&#xff0c;连博主本人也刷过很多无聊的面试题&#xff0c;但是对于项目经验的面试&#xff0c;可能很多刚入行小伙伴属于一个空白期&#xff…

【软考|软件设计师】编辑距离算法

目录 编辑距离算法&#xff1a; 步骤&#xff1a; 实例&#xff1a; 题&#xff1a; 完整代码如下&#xff1a; 调试&#xff1a; 代码解析&#xff1a; 具体过程参考&#xff1a; 编辑距离算法&#xff1a; 是一种计算两个自符串之间差异程度的方法&#xff0c;它通过…

现场工程师出马:VMware+LVM卷快速在windows Server上部署Kafka集群

最近遇到的疑难现场问题层出不穷&#xff0c;本次遭遇的挑战是在4台windows Server 服务器上部署Kafka集群。这是一种比较少见的操作&#xff0c;原因是有些依赖的驱动对虚拟化支持不好&#xff0c;只能运行在实体win机上。 原有的上层业务是由B团队开发运维&#xff0c;现在B…

今年的博客数量上两百了

今年的博客数量上两百了 不知不觉在 C S D N CSDN CSDN中写了那么多篇文章。与 C S D N CSDN CSDN相伴的生活中&#xff0c;我过得很充实。

并发编程10:Java对象内存布局和对象头

文章目录 10.1 面试题10.2 Object object new Object()谈谈你对这句话的理解&#xff1f;10.3 对象在堆内存中布局10.3.1 权威定义----周志明老师JVM10.3.2 对象在堆内存中的存储布局 10.4 再说对象头的MarkWord10.5 聊聊Object obj new Object()10.5.1 运行结果展示10.5.2 压…

C++入门(命名空间、缺省参数、函数重载、引用、内联函数)

全文目录 引言C输入与输出命名空间概念使用使用域作用限定符::使用某个成员使用using namespace 引入整个命名空间域使用using引入某个成员 缺省参数概念分类 函数重载定义与调用原理 引用定义需要注意 使用引用作为返回型参数引用作为返回值 引用与指针的区别 内联函数总结 引…

华为OD机试真题 Java 实现【猜字谜】【2023Q2】

一、题目描述 小王设计了一人简单的清字谈游戏&#xff0c;游戏的迷面是一人错误的单词&#xff0c;比如nesw&#xff0c;玩家需要猜出谈底库中正确的单词。猜中的要求如 对于某个谜面和谜底单词&#xff0c;满足下面任一条件都表示猜中&#xff1a; 变换顺序以后一样的&…

np保存数据为txt或者csv格式

目录 1、基础参数 2、参数详解 2.1、fmt 2.2、delimiter 2.3、newline 2.4、header 1、基础参数 numpy.savetxt(fname,arrry,fmt%.18e,delimiter ,newline\n,header,footer,comments# ,encodingNone,) 2、参数详解 fname:要存入的文件、文件名、或生成器。arrry:要存储…

xxl-Job分布式任务调度 入门

1.概述 1.1 什么是任务调度 我们可以先思考一下业务场景的解决方案&#xff1a; 某电商系统需要在每天上午10点&#xff0c;下午3点&#xff0c;晚上8点发放一批优惠券。 某银行系统需要在信用卡到期还款日的前三天进行短信提醒。 某财务系统需要在每天凌晨0:10结算前一天的…

C高级第二天

#include <stdio.h> #include <stdlib.h> #include <string.h> int main(int argc,const char *argv[]) { int n 0, m 0, MAX 0; int arr[n][m]; printf("请输入矩阵行数、列数>>>"); scanf("%d%d", &n…

【动态规划】线性DP

目录 一&#xff1a;思考方式 二&#xff1a;例题 例题1&#xff1a;数字三角形 例题二&#xff1a;最长上升子序列​​​​​​​ 例题三&#xff1a;最长公共子序列 一&#xff1a;思考方式 线性dp就是一条线上的动态规划 二&#xff1a;例题 例题1&#xff1a;数字三…

Python基础(三)

目录 1、Python的输入函数input() 1、input函数介绍 1.1作用&#xff1a; 1.2返回值类型&#xff1a; 1.3值得存储&#xff1a; 2、input函数的基本使用 2、Python中的运算符 2.1算术运算符 2.1.1标准算术运算符 2.1.2取余运算符(%) 2.1.3幂运算符(**) 2.1.4特殊运…

分布式锁的多种实现方式

1、不使用分布式锁 synchronized (this){int stock Integer.parseInt(Objects.requireNonNull(stringRedisTemplate.opsForValue().get("stock")));if (stock > 0) {int realStock stock - 1;// 更新库存stringRedisTemplate.opsForValue().set("stock&qu…

vi编辑器的三种模式及其对应模式下常用指令

vi是Linux系统的第一个全屏幕交互式编辑工具&#xff0c;在嵌入式的 学习中是一个不可或缺的强大的文本编辑工具。 一、三种模式 命令模式 如何进入命令模式&#xff1a;按esc键 复制&#xff1a;yy nyy(n&#xff1a;行数) 删除(剪切): dd ndd 粘贴&#xff1a;p 撤销&…

【Java】java | 将可运行jar打包成exe可执行文件

一、说明 1、javafx桌面程序&#xff0c;但又不想安装jre环境 2、需要将可执行jar打包成exe 3、使用工具exe4j 二、操作步骤 1、下载exe4j https://exe4j.apponic.com/ 2、安装 说明1&#xff1a; 在d盘建个exe4j的文件夹 说明2&#xff1a; 建个output文件jar&#xff0c;存放…

计算机组成原理——计算机系统的组成

一台完整的计算机包括硬件和软件两部分&#xff0c;另外还有一部分固话的软件成为固件(Frimware)&#xff0c;固件兼具软件和硬件的特性&#xff0c;常见的如个人计算机中的BIOS&#xff0c;BIOS&#xff08;Basic Input/Output System&#xff09;是个人计算机上的一个基本输入…

React 路由

React 的路由跳转需要引用第三方的 React Router npm i react-router-dom5.2.0 React Router 分为 BrowserRouter 和 HashRouter 如果我们的应用有服务器响应 web 的请求&#xff0c;建议使用<BrowserRouter>组件; 如果使用静态文件服务器&#xff0c;建议使用<Hash…