模型部署总是会遇到pytorch版本推理与训练不匹配的问题,一般报错:
AttributeError: Can't get attribute '_rebuild_parameter_v2' on <module 'torch._utils' from '/usr/local/python3.9.0/lib/python3.9/site-packages/torch/_utils.py'>
提示pytorch 中_utils.py没有实现这个方法’_rebuild_parameter_v2’ ,这就表明新的pytorch增加了一些方法,而旧的pytorch没有实现。为了避免环境升级各种乱七八糟的事情,那我们就手动实现它,在仅实现这个function的情况下实现pytorch的伪升级。
(1)首先看pytorch的github
发现’_rebuild_parameter_v2’这个方法是新添加的
源码路径实现在这:
(2)那我们就把这段代码抄过来
攒成下面一块:
import torch._utils
try:
torch._utils._rebuild_parameter_v2
except AttributeError:
def _set_obj_state(obj, state):
if isinstance(state, tuple):
if not len(state) == 2:
raise RuntimeError(f"Invalid serialized state: {state}")
dict_state = state[0]
slots_state = state[1]
else:
dict_state = state
slots_state = None
for k, v in dict_state.items():
setattr(obj, k, v)
if slots_state:
for k, v in slots_state.items():
setattr(obj, k, v)
return obj
def _rebuild_parameter_v2(data, requires_grad, backward_hooks, state):
param = torch.nn.Parameter(data, requires_grad)
param._backward_hooks = backward_hooks
param = _set_obj_state(param, state)
return param
torch._utils._rebuild_parameter_v2 = _rebuild_parameter_v2
这样,每次加载模型的时候,把上面那段代码拷贝上去,就能执行,不会报错了