一, 炼丹之痛
无论是学术研究还是工业落地,pytorch几乎都是目前炼丹的首选框架。
pytorch的胜出不仅在于其简洁一致的api设计,更在于其生态中丰富和强大的模型库。
但是我们会发现不同的pytorch模型库提供的训练和验证代码非常不一样。
torchvision官方提供的范例代码主要是一个关联了非常多依赖函数的train_one_epoch和evaluate函数,针对检测和分割各有一套。
yolo系列的主要是支持ddp模式的各种风格迥异的Trainer,每个不同的yolo版本都会改动很多导致不同yolo版本之间都难以通用。
抱抱脸的transformer库在借鉴了pytorch_lightning的基础上也搞了一个自己的Trainer,但与pytorch_lightning并不兼容。
非常有名的facebook的目标检测库detectron2, 也是搞了一个它自己的Trainer,配合一个全局的cfg参数设置对象来训练模型。
还有我用的比较多的语义分割的segmentation_models.pytorch这个库,设计了一个TrainEpoch和一个ValidEpoch来做训练和验证。
在学习和使用这些不同的pytorch模型库时,尝试阅读理解和改动这些训练和验证相关的代码让我受到了一万点伤害。
有些设计非常糟糕,嵌套了十几层,有些实现非常dirty,各种带下划线的私有变量满天飞。
让你每次想要改动一下加入一些自己想要的功能时就感到望而却步。
我不就想finetune一下模型嘛,何必拿这么多垃圾代码搞我?
一, 梦中情炉
这一切的苦不由得让我怀念起tensorflow中keras的美好了。
还记得keras那compile, fit, evalute三连击吗?一切都像行云流水般自然,真正的for humans。
而且你看任何用keras实现的模型库,训练和验证都几乎可以用这一套相同的接口,没有那么多莫名奇妙的野生Trainer。
我能否基于pytorch打造一个接口和keras一样简洁易用,功能强大,但是实现代码非常简短易懂,便于修改的模型训练工具呢?
从2020年7月左右发布1.0版本到最新发布的3.85版本,我陆陆续续在工作中一边使用一边雕琢打磨一个工具,总共提交修改了70多次。
最近,我细心雕琢的这个作品终于长成了我心目中接近完美的样子。
她有一个美丽的名字:torchkeras.
是的,她兼具torch的灵动,也有keras的优雅,并且她的美丽,无与伦比~
她,就是我的梦中情炉~
二, 设计理念
torchkeras 是一个通用的pytorch模型训练模版工具,按照如下目标进行设计和实现:
好看 (代码优雅,日志美丽,自带可视化)
好用 (使用方便,支持 进度条、评估指标、early-stopping等常用功能,支持tensorboard,wandb回调函数等扩展功能,支持多GPU的ddp模式训练)
好改 (修改简单,核心代码模块化,仅约200行,并提供丰富的修改使用案例)
三, 使用方法
安装torchkeras
pip install torchkeras
通过使用torchkeras,你不需要写自己的pytorch模型训练循环。你只要做这样两步就可以了。
(1) 创建你的模型结构net,然后把它和损失函数传入torchkeras.KerasModel构建一个model。
(2) 使用model的fit方法在你的训练数据和验证数据上进行训练,训练数据和验证数据需要封装成两个DataLoader.
核心使用代码就像下面这样:
import torch
import torchkeras
import torchmetrics
model = torchkeras.KerasModel(net,
loss_fn = nn.BCEWithLogitsLoss(),
optimizer= torch.optim.Adam(net.parameters(),lr = 1e-4),
metrics_dict = {"acc":torchmetrics.Accuracy(task='binary')}
)
dfhistory=model.fit(train_data=dl_train,
val_data=dl_val,
epochs=20,
patience=3,
ckpt_path='checkpoint.pt',
monitor="val_acc",
mode="max",
plot=True
)
在jupyter notebook中执行训练代码,你将看到类似下面的训练可视化图像和训练日志进度条。
四, 主要特性
torchkeras 支持以下这些功能特性,稳定支持这些功能的起始版本以及这些功能借鉴或者依赖的库的来源见下表。
五, 基本范例
以下范例是torchkeras的基础范例,演示了torchkeras的主要功能。
包括基础训练,使用wandb可视化,使用wandb调参,使用tensorboard可视化,使用多GPU的ddp模式训练等。
六, 进阶范例
下面的范例为torchkeras的进阶使用范例,由于输入数据结构的差异,这些范例有些需要对torchkeras的核心模块StepRunner进行修改。
这种修改实际上是非常简单的,保持每个模块的输出与原始实现格式一致就行,中间处理逻辑根据需要灵活调整。
如果你想掌握一个东西,那么就去使用它,如果你想真正理解一个东西,那么尝试去改变它。———— 爱因斯坦
除了上面这些范例,其它的一些常见的模型库,例如开头提到的torchvision中的检测和分割模型,huggingface中的transformers库,ultralytics库中的各类模型等,都可以用这套模版进行训练和finetune.
相信我,她会让你的的炼丹生活变得快乐许多~
公众号算法美食屋后台回复关键词:梦中情炉,获取她的github仓库链接~
如果有任何使用上的疑问或者改进意见,可在公众号后台回复关键词:加群,添加我的微信进行交流。