文章目录
- 保存最优模型
- 一、两种保存方法
- 1. 保存模型参数
- 2. 保存完整模型
- 二、迭代模型
- 总结
保存最优模型
我们在迭代模型训练时,随着次数初始的增多,模型的准确率会逐渐的上升,但是同时也随着迭代次数越来越多,由于模型会开始学习到训练数据中的噪声或非共性特征,发生过拟合现象,使得模型的准确率会上下震荡甚至于下降。
本篇就是介绍我们如何在进行那么多次迭代之中,找到训练最好效果时,模型的参数或完整模型。也方便以后使用模型时直接使用。
一、两种保存方法
我们知道,一个模型到底好不好,主要体现在对测试集数据结果上的表现,所以我们的方法主要从测试集入手,计算每次迭代测试集数据的准确率,取到准确率最大时对应的模型和参数。
那么,我们该如何保存模型和参数呢?介绍一个小东西:
- 文件拓展名:pt\pth,t7,使用pt\pth或t7作为模型文件扩展名,保存模型的整个状态(包括模型架构和参数)或仅保存模型的参数(即状态字典,state_dict)。
1. 保存模型参数
方法:
torch.save(model.state_dict(),path)
# model.state_dict()是一个从参数名称映射到参数张量的字典对象,它包含了模型的所有权重和偏置项
# path为创建的保存模型的文件
通过比较每一次迭代准确率的大小,取准确率最大时模型的参数:
best_acc = 0
"""-----测试集-----"""
def test(dataloader,model,loss_fn):
global best_acc
size = len(dataloader.dataset) # 总数据大小
num_batches = len(dataloader) # 划分的小批次数量
model.eval()
test_loss,correct = 0,0
with torch.no_grad():
for x,y in dataloader:
x,y = x.to(device),y.to(device)
pred = model.forward(x)
test_loss += loss_fn(pred,y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 预测正确的个数
test_loss /= num_batches
correct /= size
correct = round(correct, 4)
print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")
# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
if correct > best_acc:
best_acc = correct
# 1. 保存模型参数方法:torch.save(model.state_dict(),path) (w,b)
print(model.state_dict().keys()) # 输出模型参数名称cnn
torch.save(model.state_dict(),"best.pth")
2. 保存完整模型
方法:
torch.save(model,path)
# 直接得到整个模型
依旧是通过比较每一次迭代准确率的大小,但是取准确率最大时的整个模型:
def test(dataloader,model,loss_fn):
global best_acc
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss,correct = 0,0
with torch.no_grad():
for x,y in dataloader:
x,y = x.to(device),y.to(device)
pred = model.forward(x)
test_loss += loss_fn(pred,y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
a = (pred.argmax(1) == y)
b = (pred.argmax(1) == y).type(torch.float)
test_loss /= num_batches
correct /= size
correct = round(correct, 4)
print(f"Test result: \n Accuracy:{(100*correct)}%,Avg loss:{test_loss}")
# 保存最优模型的方法(文件扩展名一般:pt\pth,t7)
if correct > best_acc:
best_acc = correct
# 2. 保存完整模型(w,b,模型cnn)
torch.save(model,"best1.pt")
二、迭代模型
接下来就要迭代模型,得到最优的模型:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=0.001,weight_decay=0.0001)
epochs = 150
# training_data、test_data:数据预处理好的数据
train_dataloader = DataLoader(training_data,batch_size=64,shuffle=True)
test_dataloader = DataLoader(test_data,batch_size=64,shuffle=True)
for t in range(epochs):
print(f"Epoch {t+1} \n-------------------------")
train(train_dataloader,model,loss_fn,optimizer)
test(test_dataloader,model,loss_fn)
print("Done!")
在每轮数据迭代后,project工程栏中的best1.pt与best.pth文件中模型会随着迭代及时更新,迭代结束后,文件中保存的就是最优模型以及最优的模型参数。
总结
本篇介绍了:
- 为什么随着迭代次数越来越多,模型的准确率会上下震荡甚至于下降。—> 过拟合
- pt\pth,t7三个扩展名,用于保存完整模型或者模型参数。
- 模型的好坏,通过体现在测试集的结果上。
- 保存最优模型的两种方法:保存模型参数和保存完整模型。