TL;DR
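In local_mode, do not instantiate the model through t.nn.CustModel; instantiate the original class directly.
If you have called trainer.local_mode(), the model held by the trainer must not be an instance of t.nn.CustModel().
A few examples:
Correct: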
model = SANet()
trainer.set_model(model) # set model
optimizer = t.optim.Adam(model.parameters(), lr=0.01)
loss = BCEDiceLoss()
trainer.train(train_set=ds, optimizer=optimizer, loss=loss)
Wrong; this raises the error in the title:
model = t.nn.Sequential(
    t.nn.CustModel(module_name='sanet', class_name='SANet', args=None)
)
trainer.set_model(model) # set model
optimizer = t.optim.Adam(model.parameters(), lr=0.01)
loss = BCEDiceLoss()
trainer.train(train_set=ds, optimizer=optimizer, loss=loss)
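Debugging log; it is not needed for the fix above, so feel free to skip it
The error message told me to init the model. But where exactly? And why doesn't the parent class init the model itself instead of making me do it by hand? I checked the docs and it turned out to be a copy? That can't be right. I then searched all the documentation and found no mention of it.
Eventually I found that a custom trainer has to implement train by hand, plus server_aggregate_procedure.
So I added the sample code below to my own trainer, but it made no difference either.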
def server_aggregate_procedure(self, extra_data={}):
    aggregator = SecureAggregatorServer(communicate_match_suffix='fedprox')
    # the aggregation process is simple: every epoch the server aggregates model and loss once
    for i in range(self.epochs):
        aggregator.model_aggregation()
        merge_loss, _ = aggregator.loss_aggregation()
Then, in the trainer_base class, I found this code:
@property
def model(self):
    if not hasattr(self, '_model'):
        raise AttributeError(
            'model variable is not initialized, remember to call'
            ' super(your_class, self).__init__()')
    if self._model is None:
        raise AttributeError(
            'model is not set, use set_model() function to set training model')
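The two checks in that property correspond to two different mistakes, which is easy to see in isolation. Here is a minimal sketch with a hypothetical stand-in class (TrainerSketch, ForgotSuper and describe are names made up for illustration, not FATE's TrainerBase): the first branch only fires when the base __init__ never ran, the second when set_model() was never called.

```python
class TrainerSketch:
    """Hypothetical stand-in for a trainer base class; not FATE's TrainerBase."""

    def __init__(self):
        # The attribute is created here, so skipping __init__ leaves it missing.
        self._model = None

    @property
    def model(self):
        if not hasattr(self, '_model'):
            raise AttributeError('__init__ of the base class never ran')
        if self._model is None:
            raise AttributeError('set_model() was never called')
        return self._model

    def set_model(self, model):
        self._model = model


class ForgotSuper(TrainerSketch):
    def __init__(self):
        pass  # deliberately does not call super().__init__()


def describe(trainer):
    """Return the model, or the message of the AttributeError that was raised."""
    try:
        return trainer.model
    except AttributeError as err:
        return str(err)
```

Calling describe() on a ForgotSuper instance hits the first branch, on a fresh TrainerSketch the second, and after set_model() it returns the model.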
My case had to be the first one, so it is claiming that I never called super init?
My trainer's __init__ did not name the class explicitly in super() and left it to be inferred. Does it have to be spelled out explicitly?
class SATrainer(TrainerBase):
    def __init__(self, epoch, batch_size, workers):
        super().__init__()
        self.epoch = epoch
        self.batch_size = batch_size
        self.workers = workers
I changed it to the following, but it still had no effect!
class SATrainer(TrainerBase):
    def __init__(self, epoch, batch_size, workers):
        super(SATrainer, self).__init__()
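In hindsight this change could not have helped: in Python 3, super() with no arguments inside a method resolves to the same thing as super(ThisClass, self). A minimal sketch with hypothetical classes (Base, ImplicitSuper and ExplicitSuper are made up for illustration and are not FATE classes):

```python
class Base:
    def __init__(self):
        self._model = None


class ImplicitSuper(Base):
    def __init__(self, epoch):
        super().__init__()  # Python 3 zero-argument form
        self.epoch = epoch


class ExplicitSuper(Base):
    def __init__(self, epoch):
        super(ExplicitSuper, self).__init__()  # explicit two-argument form
        self.epoch = epoch
```

Both subclasses end up with the same attributes, so the spelling of the super call was never the problem.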
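Keep looking! self._model has a setter, so maybe set_model did not succeed? I checked that too, and set_model did succeed.
I was out of ideas, so I filed an issue on GitHub: https://github.com/FederatedAI/FATE/issues/4843
Finally, in the debugger, I found a different cause for the error: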
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/anaconda3/envs/fate/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/fate/lib/python3.9/site-packages/torch/nn/modules/container.py", line 204, in forward
    input = module(input)
  File "/root/anaconda3/envs/fate/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/root/anaconda3/envs/fate/lib/python3.9/site-packages/pipeline/component/nn/backend/torch/cust.py", line 59, in forward
    raise ValueError('model not init, call init_model() function')
ValueError: model not init, call init_model() function
So the problem seems to be in the torch cust.py! Its source:
class CustModel(FateTorchLayer, nn.Module):
    def __init__(self, module_name, class_name, **kwargs):
        super(CustModel, self).__init__()
        assert isinstance(
            module_name, str), 'name must be a str, specify the module in the model_zoo'
        assert isinstance(
            class_name, str), 'class name must be a str, specify the class in the module'
        self.param_dict = {
            'module_name': module_name,
            'class_name': class_name,
            'param': kwargs}
        self._model = None

    def init_model(self):
        if self._model is None:
            self._model = self.get_pytorch_model()

    def forward(self, x):
        if self._model is None:
            raise ValueError('model not init, call init_model() function')
        return self._model(x)

    def get_pytorch_model(self, module_path=None):
        if module_path is None:
            return get_class(
                self.param_dict['module_name'],
                self.param_dict['class_name'],
                self.param_dict['param'],
                MODEL_PATH)
        else:
            return get_class(
                self.param_dict['module_name'],
                self.param_dict['class_name'],
                self.param_dict['param'],
                module_path)
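The pattern in this source is a lazy wrapper: the object only stores how to build the real model, and forward refuses to run until init_model() has built it. A stripped-down sketch of that pattern in plain Python (LazyWrapper, its factory argument and double are hypothetical names for illustration; this is not FATE's implementation and does not use torch):

```python
class LazyWrapper:
    """Hypothetical stand-in for a wrapper that builds the real model lazily."""

    def __init__(self, factory):
        self._factory = factory  # callable that creates the real model
        self._model = None       # nothing is built at construction time

    def init_model(self):
        # Build the wrapped model once; later calls are no-ops.
        if self._model is None:
            self._model = self._factory()

    def __call__(self, x):
        if self._model is None:
            raise ValueError('model not init, call init_model() function')
        return self._model(x)


def double(x):
    """Trivial 'model' used only to exercise the wrapper."""
    return 2 * x
```

With LazyWrapper(lambda: double), calling the wrapper before init_model() raises the same ValueError message as in the traceback, and calling it afterwards forwards to the wrapped callable.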
So I rewrote my code to call init_model explicitly:
cust_model = t.nn.CustModel(module_name='sanet', class_name='SANet', args=None)
cust_model.init_model()
model = t.nn.Sequential(
    cust_model
)
Error:
Exception has occurred: ModuleNotFoundError
No module named 'None'
File "/root/Downloads/project_demo/homo_sanet_core.py", line 78, in <module>
cust_model.init_model()
ModuleNotFoundError: No module named 'None'
So the real cause of this error is that the custom model was never loaded successfully!
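The wording of the message is itself a hint. Python produces exactly this text when it is asked to import a module whose name is the string 'None', which is what you get when a None value is turned into a string somewhere along the way. A minimal sketch using only the standard library (try_import is a hypothetical helper; it does not show where inside FATE the None comes from, which I never traced):

```python
import importlib


def try_import(module_name):
    """Try to import by name; return 'ok' or the ModuleNotFoundError message."""
    try:
        # str() mimics a None value being formatted into a module name.
        importlib.import_module(str(module_name))
        return 'ok'
    except ModuleNotFoundError as err:
        return str(err)
```

Passing None to try_import reproduces the message from the error above, while a real module name such as 'json' imports normally.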
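Then I patched the source to print the param dict, and it looked fine.
Rereading the CustModel documentation, I noticed that when they test a local model they do not go through t.nn.CustModel; they instantiate the class directly, with no hook! Could that be the problem?
After the change, it worked! So that was it, and two working days were gone!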