引言
Ignite是Pytorch配套的高级框架,我们可以借其构筑一套标准化的训练流程,规范训练器在每个循环、轮次中的行为。本文将不再赘述Ignite的具体细节或者API,详见官方教程和其他博文。本文将分析Ignite的运行机制、如何将Pytorch训练代码转为Ignite范式,最后给出个人设计的标准化Ignite训练模版。
Ignite简介
Ignite所做的事情就是我们在pytorch里常写的范式用更加机械、更加标注格式展现出来,这也就是为啥其核心被称为–Engine,高效而精密。Pytorh里常用的训练范式如下:
for ep in Epoch:
for batch in train_loader:
model.train()
inputs, targets = batch
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
if it%log_period:
print()
if ep%save_perid:
torch.save()
具体而言,可以拆解为批训练、批完结处理、轮次完结处理三个组成部分,批训练部分是网络训练的基础单元,完成数据当前批次读取、前向传播、反向传播等步骤,批完结处理负责在每个批次结束后输出模型训练的相关信息,轮次完结处理负责在每个epoch结束进行模型的保存、对模型的训练参数进行更新。这三个模型训练的主要组成部分在ignite中得到了完整的封装,围绕批训练构造了一个核心的Engine,将批完结处理和轮次完结处理附加该Engine运行的时间轴中,形成了批训练->批完结处理->轮次完结处理的流水线作业范式,更为详细的时间轴如下1:
以下将从实用性的角度出发给出Ignite的建设框架,最终给出个人设计的Ignite使用模版,后续直接在train.py
文件里直接调用do_train()
函数即可利用Ignite进行模型训练。为讲解需要,中间每个子部分的代码为最终代码中相应部分重新排序得到,最终代码中其顺序会进行调整。
批训练
批训练的代码较为简单,只需将原本的Pytorch版本批处理流程复制粘贴,最后将该过程函数化,并且实例化成Engine即可,代码如下所示,最终启动Engine,即可进行模型的训练。到此为止,实际上已经完成了狭义上的“模型”训练部分。
def create_supervised_trainer(model,optimizer,criterion,
device=None, non_blocking=False,
prepare_batch=_prepare_batch,
output_transform=lambda x, y, y_pred, loss: loss.item()):
"""
有监督模型的Engine创建
Args:
model (`torch.nn.Module`):
optimizer (`torch.optim.Optimizer`):
loss_fn (torch.nn loss function):
device (str, optional):
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理
output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()
Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item
Returns:
Engine: 有监督任务的engine实例
"""
if device:
model.to(device)
def _update(engine, batch):
model.train()
optimizer.zero_grad()
x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
loss=criterion(output,y)
loss.backward()
optimizer.step()
return output_transform(x, y, None, loss)
return Engine(_update)
trainer=create_supervised_trainer(model,optimizer,criterion,device) # 建立ignite的engine
trainer.run(train_loader,max_epochs=cfg['max_epochs'])
批完结处理
批完结处理部分我们常做的操作是输出模型在当前批的损失,Ignite中这一过程通过在Engine上附着于ITERAION_COMPLETEDE时触发的回调函数实现。实际上这只是限定了触发时间,具体进行何种操作,完全依赖于个人的选择。我们只需要知道该函数可以利用engine保留的当前批属性信息进行各种操作即可,具体可以利用哪些属性,见官方API2,本文只利用了常用的几个。
##########################################################################################
########### Events.ITERATION_COMPLETED #############
##########################################################################################
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
"""
隔一定iteration输出模型损失
"""
log_period=int(cfg["log_period"]*len(train_loader)) # 跑了log_period*len输出一次,取值<=1
if engine.state.iteration%log_period==0:
pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.output:.2f}")
pbar.update(log_period)
@trainer.on(Events.ITERATION_COMPLETED)
def scheduler_update(engine):
"""
optional 每个ITER更新学习率
"""
scheduler.step()
轮次完结处理
轮次完结处理和批完结处理相同,也是通过回调函数实现,我们通常在轮次完结处理要进行模型的保存,这里就要做两件事:
- 在val_loader上验证模型效果
- 保留迄今为止效果最好的模型
针对于第一个要求,这里我同样采用了Ignite风格的Engine驱动范式,读者可以自行选择在这里切换为Pytorch范式,构建验证集Engine的代码如下:
def create_supervised_evaluator(model, metric,
device=None, non_blocking=False,
prepare_batch=_prepare_batch,
output_transform=lambda x, y, y_pred: (y_pred,y)):
"""
构造evaluator
:param model:
:param metric: dict,key为metric名字,value为Metric类
:param device:
:param non_blocking:
:param prepare_batch:
:param output_transform:
:return:
"""
if device:
model.to(device)
def _inference(engine, batch):
model.eval()
with torch.no_grad:
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
return output_transform(x, y, output)
engine=Engine(_inference)
# 附着metric
for name, metric in metric.items():
metric.attach(engine,name)
return engine
可以看到和trainer较为不同的点在于去除了opt等等选项,此外,由于保存模型时我们要依据验证集上的metric来判断是否要保存当前模型还是沿用此前的模型,因此额外将一个Metric类附着在了Engine上,它使得模型可以自动收集eval_engine每个轮次的输出,并进行metric的计算,ignite中提供了许多metric选项3,这里笔者给出自己定制mertric的范式如下,主要由reset()
,update()
和commpute()
组成,reset()
完成每个epoch的记录状态重置,update()
则接受某一批次engine的输出值,commpute()
完成最终的metric计算。值得一提的是,在trainer上我们并没有额外附着Loss类,而是直接用engine输出了loss,实际上或许你也可以用相同的方式对eval_engine进行处理。
class CustomMetric(Metric):
def __init__(self):
super(CustomMetric,self).__init__()
def reset(self) -> None:
self._num_correct=0
self._num_examples=0
def update(self, output) -> None:
'''
保存该轮次的输出
:param output: 每个batch engine的输出
:return:
'''
pred,label=output
pred=pred.detach()
label=label.detach()
indices=torch.argmax(pred,dim=1)
correct=torch.eq(indices,label).view(-1)
self._num_correct += torch.sum(correct).item()
self._num_examples += correct.shape[0]
def compute(self):
'''
计算总ACC
:return:
'''
return self._num_correct/self._num_examples
完成第二步的方法Ignite同样进行了封装,即Checkpoint
类4,但笔者也进行了自己的定制化,如下:
class BestCheckPoint():
def __init__(self,save_path,n_saved,model_name):
'''
建立存档点类
:param save_path: 存档点保存路径
:param n_saved: 保留的存档点数目
'''
self.save_path=save_path
self.n_save=n_saved
self.model_name=model_name
self.score=[]
if not os.path.exists(save_path):
os.mkdir(self.save_path)
def update(self,score):
'''
更新最优记录
:param score: 当前模型的metric
:return:
'''
if type(self.score)==torch.Tensor:
score=score.item()
if len(self.score)<self.n_save:
self.score.append(score)
self.score.sort()
self.removed=[]
self._in=[]
return True
else:
value=self.score[0]
if score>value:
self.score.remove(value)
self.score.append(score)
self.score.sort()
return value
else:
return False
def save(self,score,model):
'''
视当前得分判断是否保存当前模型并删除
:param score: 当前模型得分
:param model: 模型
:return:
'''
is_save=self.update(score)
if is_save:
torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))
# pop的存档要删除
if not isinstance(is_save,bool):
# 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证
try:
os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))
except:
print("already removed")
主要就是设置了一个metric池,新的metric进来后判断是否优于池子里最烂的模型,并以此判断是否进行保存。
运行框架
将上述模块封装在一起,我们就可以得到了最终的ignite运行框架,而后只需导入该文件,并运行其中的do_train()
函数即可轻松完成模型训练。
整体代码:
# -*- coding: utf-8 -*-
# ---
# @File: trainer.py
# @Author: sgdy3
# @E-mail: sgdy03@163.com
# @Time: 2023/5/9 19:44
# Describe:
# ---
import os
from tqdm import tqdm
import ignite
import torch
from ignite.engine import Engine
from ignite.utils import convert_tensor
from ignite.engine.engine import Engine, State, Events
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Metric,Accuracy
class BestCheckPoint():
def __init__(self,save_path,n_saved,model_name):
'''
建立存档点类
:param save_path: 存档点保存路径
:param n_saved: 保留的存档点数目
'''
self.save_path=save_path
self.n_save=n_saved
self.model_name=model_name
self.score=[]
if not os.path.exists(save_path):
os.mkdir(self.save_path)
def update(self,score):
'''
更新最优记录
:param score: 当前模型的metric
:return:
'''
if type(self.score)==torch.Tensor:
score=score.item()
if len(self.score)<self.n_save:
self.score.append(score)
self.score.sort()
self.removed=[]
self._in=[]
return True
else:
value=self.score[0]
if score>value:
self.score.remove(value)
self.score.append(score)
self.score.sort()
return value
else:
return False
def save(self,score,model):
'''
视当前得分判断是否保存当前模型并删除
:param score: 当前模型得分
:param model: 模型
:return:
'''
is_save=self.update(score)
if is_save:
torch.save(model.state_dict(), os.path.join(self.save_path, self.model_name + f"_{score:.4f}.pth"))
# pop的存档要删除
if not isinstance(is_save,bool):
# 似乎ignite存在并行机制,单步运行的时候没问题,多步就会发生早就remove的错误,可以通过保存每一次更新后的score验证
try:
os.remove(os.path.join(self.save_path,self.model_name+f"_{is_save:.4f}.pth"))
except:
print("already removed")
class CustomMetric(Metric):
def __init__(self):
super(CustomMetric,self).__init__()
def reset(self) -> None:
self._num_correct=0
self._num_examples=0
def update(self, output) -> None:
'''
保存该轮次的输出
:param output: 每个batch engine的输出
:return:
'''
pred,label=output
pred=pred.detach()
label=label.detach()
indices=torch.argmax(pred,dim=1)
correct=torch.eq(indices,label).view(-1)
self._num_correct += torch.sum(correct).item()
self._num_examples += correct.shape[0]
def compute(self):
'''
计算总ACC
:return:
'''
return self._num_correct/self._num_examples
def do_train(model,optimizer,criterion,scheduler,device,train_loader,val_loader,cfg):
def _prepare_batch(batch, device=None, non_blocking=False):
"""
对dataloader每个batch的输出进行进一步的处理
:param batch: dataloader输出
:param device:
:param non_blocking:
:return:
"""
device = "cuda:" + str(device)
x, y = batch
x = convert_tensor(x,device=device,non_blocking=non_blocking)
y = convert_tensor(y,device=device,non_blocking=non_blocking)
return x,y
def create_supervised_trainer(model,optimizer,criterion,
device=None, non_blocking=False,
prepare_batch=_prepare_batch,
output_transform=lambda x, y, y_pred, loss: loss.item()):
"""
有监督模型的Engine创建
Args:
model (`torch.nn.Module`):
optimizer (`torch.optim.Optimizer`):
loss_fn (torch.nn loss function):
device (str, optional):
non_blocking (bool, optional): if True and this copy is between CPU and GPU, the copy may occur asynchronously
with respect to the host. For other cases, this argument has no effect.
prepare_batch (callable, optional): 批处理函数,对dataloader的输出进行处理
output_transform (callable, optional): 输出变换函数,设定输出,默认情况下,输入为x,y,y_pred,loss,输出loss.item()
Note: engine在每个batch下的最终输出由transform所指定,默认传回loss.item
Returns:
Engine: 有监督任务的engine实例
"""
if device:
model.to(device)
def _update(engine, batch):
model.train()
optimizer.zero_grad()
x,y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
loss=criterion(output,y)
loss.backward()
optimizer.step()
return output_transform(x, y, None, loss)
return Engine(_update)
def create_supervised_evaluator(model, metric,
device=None, non_blocking=False,
prepare_batch=_prepare_batch,
output_transform=lambda x, y, y_pred: (y_pred,y)):
"""
构造evaluator
:param model:
:param metric: dict,key为metric名字,value为Metric类
:param device:
:param non_blocking:
:param prepare_batch:
:param output_transform:
:return:
"""
if device:
model.to(device)
def _inference(engine, batch):
model.eval()
with torch.no_grad:
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
return output_transform(x, y, output)
engine=Engine(_inference)
for name, metric in metric.items():
metric.attach(engine,name)
return engine
trainer=create_supervised_trainer(model,optimizer,criterion,device) # 建立ignite的engine
evaluator=create_supervised_evaluator(model,{"ACC":CustomMetric()})
CP=BestCheckPoint(cfg['save_path'],cfg['n_saved'],cfg['model_name'])
pbar=tqdm(total=len(train_loader)) # 为训练器迭代器建立进度条
##########################################################################################
########### Events.ITERATION_COMPLETED #############
##########################################################################################
@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
"""
隔一定iteration输出模型损失
"""
log_period=cfg["log_period"]
if engine.state.iteration%log_period==0:
pbar.write(f"Epoch {engine.state.epoch}, iter {engine.state.iteration}: Loss {engine.state.metrics['avg_loss']:.2f}")
pbar.update(log_period)
@trainer.on(Events.ITERATION_COMPLETED)
def scheduler_update(engine):
"""
optional 每个ITER更新学习率
"""
scheduler.step()
##########################################################################################
########### Events.EPOCH_COMPLETED #############
##########################################################################################
@trainer.on(Events.EPOCH_COMPLETED)
def save_model(engine):
'''
保存模型
:param engine:
:return:
'''
if engine.state.epoch % cfg['save_period']==0:
evaluator.run(val_loader)
metrics=evaluator.state.metrics
print(f"Validation Results - Epoch[{trainer.state.epoch}] Avg accuracy: {metrics['Acc']:.2f}")
CP.save(metrics['ACC'],model)
@trainer.on(Events.EPOCH_COMPLETED)
def reset_bar(engine):
'''
重置进度条
:param engine:
:return:
'''
pbar.reset()
##########################################################################################
################# training Start ###################
##########################################################################################
trainer.run(train_loader,max_epochs=cfg['max_epochs'])
pbar.close()
参考
Events | Pytorch-Ignite ↩︎
State | Ignite ↩︎
IGNITE.METRICS ↩︎
CHECKPOINT ↩︎