文章目录
- 一、.pth文件简介
- 二、如何保存.pth文件
- 三、如何加载.pth文件
- 跨硬件加载
- 加载后操作
- 四、.pth文件的结构与内容
- 解析.pth文件示例
- 五、.pth文件的优缺点
- 优点
- 缺点
- 六、常见应用场景
- 七、模型文件体积优化技巧
- 问题背景
- 解决方案
- 效果对比
- 八、总结
- 九、参考
一、.pth文件简介
.pth
文件是PyTorch中用于保存和加载模型参数的标准文件格式。它不仅可以存储模型的权重(如各层的参数、偏置等),还能保存优化器状态、训练进度(如当前epoch、损失值)等元数据。通过.pth
文件,开发者能够快速保存训练好的模型,并在后续任务中复用或恢复训练,避免重复计算资源消耗。
在深度学习中,模型训练通常耗时较长。使用.pth
文件保存中间状态或最终模型,可显著提升开发效率,尤其适用于迁移学习、模型部署和协作共享等场景。
二、如何保存.pth文件
PyTorch提供了torch.save()
函数来保存模型,支持三种主要方式:
-
保存完整模型(包含结构和参数):
import torch import torch.nn as nn class SimpleModel(nn.Module): def __init__(self): super().__init__() self.fc = nn.Linear(10, 2) def forward(self, x): return self.fc(x) model = SimpleModel() torch.save(model, 'model.pth') # 保存整个模型
-
仅保存模型参数(推荐方式,灵活且轻量):
torch.save(model.state_dict(), 'model_params.pth') # 仅保存权重
-
保存优化器状态和训练进度:
checkpoint = { 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'epoch': 10, 'loss': 0.02 } torch.save(checkpoint, 'checkpoint.pth')
三、如何加载.pth文件
加载文件时需注意模型结构的匹配,具体方法如下:
-
加载完整模型:
loaded_model = torch.load('model.pth') loaded_model.eval() # 切换到推理模式
-
加载模型参数(需预先定义相同结构的模型):
model = SimpleModel() model.load_state_dict(torch.load('model_params.pth'))
-
恢复训练状态(加载模型、优化器和元数据):
checkpoint = torch.load('checkpoint.pth') model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss']
跨硬件加载
PyTorch支持在不同硬件设备间灵活加载模型,需通过map_location
参数指定目标设备:
-
从GPU加载到CPU(适用于无GPU环境):
model = torch.load('model.pth', map_location=torch.device('cpu'))
-
从CPU加载到GPU(需确保当前环境有可用GPU):
# 方式1:自动选择当前GPU(默认cuda:0) model = torch.load('model.pth', map_location=torch.device('cuda')) # 方式2:指定具体GPU(如cuda:1) model = torch.load('model.pth', map_location=torch.device('cuda:1'))
-
多GPU间的加载(如将原GPU 0的模型加载到GPU 1):
model = torch.load('model.pth', map_location={'cuda:0': 'cuda:1'})
-
通用加载方法(动态适配当前设备):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = torch.load('model.pth', map_location=device)
加载后操作
- 切换设备:若加载到GPU后需手动将模型参数移至设备:
model = model.to(device) # device为'torch.device('cuda')'或'torch.device('cpu')'
- 模式设置:根据任务切换模型模式:
model.train() # 训练模式(启用Dropout/BatchNorm) model.eval() # 推理模式(关闭Dropout/BatchNorm)
四、.pth文件的结构与内容
.pth
文件本质是一个序列化的Python有序字典(collections.OrderedDict
),可能包含以下内容:
- 模型参数:各层的权重和偏置(通过
state_dict()
获取)。 - 优化器状态:如Adam优化器的动量、学习率等。
- 训练元数据:当前epoch、损失值、学习率调度状态等。
解析.pth文件示例
import torch
pthfile = 'model.pth'
model = torch.load(pthfile, map_location='cpu') # 强制加载到CPU
print("字典类型:", type(model)) # 输出: <class 'collections.OrderedDict'>
print("字典长度:", len(model)) # 输出层数或键值对数量
print("键值列表:", model.keys()) # 输出各层的名称(如conv1.weight, fc.bias)
五、.pth文件的优缺点
优点
- 灵活性:支持保存模型参数、优化器状态及自定义元数据。
- 高效性:与PyTorch无缝集成,加载速度快。
- 可扩展性:可自由添加额外信息(如训练超参数)。
缺点
- 依赖模型结构:仅保存参数,加载时需确保模型定义一致。
- 文件体积较大:若保存优化器、学习率调度器等额外信息,文件会显著增大(详见第七节优化技巧)。
六、常见应用场景
- 模型持久化:保存训练好的模型用于后续推理或部署。
- 训练恢复:从上次中断的位置继续训练,避免资源浪费。
- 迁移学习:复用预训练模型的权重加速新任务训练。
- 协作共享:通过分享
.pth
文件,他人可直接加载模型,无需重新训练。
七、模型文件体积优化技巧
问题背景
训练保存的.pth文件体积较大,通常是因为额外保存了优化器、学习率调度器、epoch等信息。例如:
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': 200,
'loss': 0.01
}
解决方案
若只需部署模型或共享权重,可仅保存模型参数:
# 加载完整检查点后提取权重
checkpoint = torch.load('large_checkpoint.pth')
model_weights = checkpoint['model']
# 重新保存为精简文件
torch.save({'model': model_weights}, 'small_model.pth')
效果对比
- 原始文件(含优化器、调度器等):500MB
- 精简后(仅模型权重):150MB
八、总结
.pth
文件是PyTorch生态中管理模型的核心工具,其灵活性和高效性使其成为模型保存、恢复和共享的首选格式。使用时需注意以下几点:
- 结构一致性:若仅保存参数,加载前需确保模型结构与保存时一致。
- 硬件兼容性:跨设备加载时需指定
map_location
参数。 - 文件体积控制:通过仅保存模型权重,可显著减小文件大小。
掌握这些技巧,能显著提升深度学习工作流的效率,助力模型快速迭代与部署。
扩展阅读:
- PyTorch官方文档:SAVING AND LOADING MODELS
- 实战案例:使用.pth文件部署模型到生产环境
九、参考
【Pytorch】一文详细介绍 pth格式 文件
.pth文件的解析和用法
pytorch解析.pth模型文件
深度学习中为什么保存的训练pth模型权重那么大(附解决代码)