完整的训练套路(三) train() eval()
1. 什么是tain() eval()
-
在许多代码中我们经常会看到模型开始训练前会先进行一个
model.train()
, 模型的测试之前会有一行model.eval()
-
官方文档
https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module
中给出的大致说明是:如果构建的网络中存在部分特殊层例如Dropout BatchNorm
等,train() eval()
的调用是十分必要的
2. 如何使用
-
model.train()
的位置就是在开启模型的训练之前for i in range(epoch): # start train model.train() for data in train_dataloader: pass
-
model.eval()
的位置就是在开始对测试集进行测试之前for i in range(epoch): # start train model.train() for data in train_dataloader: pass # test model model.eval() with torch.no_grad(): # 不进行梯度调优 for data in test_dataloader: pass