相较于PyTorch默认的模型存储格式pth而言,onnx具有多端通用,方便部署的优点(据称能加快推理速度,但是未验证),本文将介绍如何使用onnx并将原有的pth权重转换为onnx。
一、配置环境
在控制台中使用如下指令
pip install onnx
pip install onnxruntime
随后在项目中引入环境
import onnx
二、将pth转换为onnx
使用onnx自带的export函数即可,代码如下:
def Convert2Onnx(pth_Path,Onnx_Path,model):
model_loader(model,pth_Path,torch.device('cpu'))
input = torch.rand(1,3,224,224) #需要调整为你的模型输入尺寸,包含batch项
torch.onnx.export(model,input,Onnx_Path,input_names=['Inp'],output_names=['Outs'])
其中,model_loader是一个封装好的自适应pth加载器,可以在部分权重不匹配的情况下加载pth文件,代码如下:
def model_loader(model,model_path,device):
print(' 开始从本地加载权重文件')
model_dict = model.state_dict()
pretrained_dict = torch.load(model_path, map_location = device)
load_key, no_load_key, temp_dict = [], [], {}
for k, v in pretrained_dict.items():
if k in model_dict.keys() and np.shape(model_dict[k]) == np.shape(v):
temp_dict[k] = v
load_key.append(k)
else:
no_load_key.append(k)
model_dict.update(temp_dict)
model.load_state_dict(model_dict)
三、加载onnx模型
同样适用onnx自带的load代码即可
model = onnx.load('onnx_model.onnx’)
四、onnx模型可视化
访问网址:https://netron.app/,并在其中选择自己的onnx模型即可。
五、其他相关操作
1.模型相关信息的获取
# 检查模型是否完整
onnx.checker.check_model(model)
# 获取输出层信息
output = self.model.graph.output
print(output)
2.模型的层编辑
import onnx
from onnx import helper
# 加载模型
model = onnx.load('converted_vig.onnx’)
# 创建中间节点:层名、数据类型、维度信息
prob_info = helper.make_tensor_value_info('layer1',onnx.TensorProto.FLOAT, [1, 3, 320, 280])
# 将构造好的中间节点插入到模型中
model.graph.output.insert(0, prob_info)
#保存新模型
onnx.save(model, 'onnx_model_new.onnx’)
#删除的节点item
model.graph.output.remove(item)