这篇博客瞄准的是 pytorch 官方教程中 Learn the Basics
章节的 Save and Load the Model
部分。
- 官网链接:https://pytorch.org/tutorials/beginner/basics/saveloadrun_tutorial.html
完整网盘链接: https://pan.baidu.com/s/1L9PVZ-KRDGVER-AJnXOvlQ?pwd=aa2m 提取码: aa2m
Save and Load the Model
这部分主要介绍如何通过保存、加载和运行模型预测来持久化模型状态。
Step1. 导入依赖包
import torch
import torchvision.models as models
Step2. 保存与加载模型参数
- 保存权重:PyTorch 模型将学习到的参数存储在名为
state_dict
的内部状态dict
对象中。这些参数可以通过torch.save
方法保存; - 加载权重:需要先创建 相同的模型 实例,然后使用
load_state_dict()
方法加载参数,通常情况下设置weights_only=True
加载最不容易出错;
- 保存模型参数【这一步之行后会先下载
IMAGENET1K_V1
权重】:
model = models.vgg16(weights="IMAGENET1K_V1")
torch.save(model.state_dict(), "model_weights.pth")
- 加载模型参数:
model = models.vgg16()
model.load_state_dict(torch.load("model_weights.pth", weights_only=True))
model.eval()
Step3. 保存与加载完整模型
上面的模型保存与加载方式只限于对模型 参数 的操作,并不会将整个模型结构保存下来,使用下面的方式可以连同模型结构一起保存。
- 保存模型参数 + 结构:
torch.save(model, "model.pth")
- 加载模型参数 + 结构
model = torch.load("model.pth", weights_only=False)