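Help: while testing a general-purpose MLP I built myself, I found that it performs far worse than an equivalent MLP with hard-coded layer sizes. I can't tell where the problem is. Could someone take a look?
My general-purpose MLP: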
```python
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, feature_num, class_num, *hidden_nums):
        super().__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        self.hidden_nums = hidden_nums
        input_num = feature_num
        # Create one hidden layer per entry in hidden_nums, named fc0, fc1, ...
        for i, hidden_num in enumerate(hidden_nums):
            self.__dict__['fc' + str(i)] = nn.Linear(input_num, hidden_num)
            input_num = hidden_num
        self.output = nn.Linear(input_num, class_num)

    def forward(self, x):
        for i in range(len(self.hidden_nums)):
            x = F.relu(self.__dict__['fc' + str(i)](x))
        # class_num == 1: return the raw output; otherwise apply sigmoid
        x = self.output(x)[..., 0] if self.class_num == 1 else F.sigmoid(self.output(x))
        return x
```
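One thing worth checking: `nn.Module` registers a submodule only when it is assigned through `nn.Module.__setattr__` (for example `self.fc0 = ...` or `setattr(self, name, ...)`). Writing straight into `self.__dict__` skips that hook. The layer still runs in `forward`, but its weights never appear in `model.parameters()`, so an optimizer built from `model.parameters()` would never update them, and `model.to(device)` and `state_dict()` would skip them too. The minimal sketch below shows the difference. The toy classes `ViaDict` and `ViaSetattr` are hypothetical and exist only for illustration:

```python
import torch.nn as nn


class ViaDict(nn.Module):
    """Stores the layer by writing directly into the instance __dict__."""

    def __init__(self):
        super().__init__()
        # Bypasses nn.Module.__setattr__, so the layer is NOT registered.
        self.__dict__["fc0"] = nn.Linear(4, 3)


class ViaSetattr(nn.Module):
    """Stores the layer through setattr, which nn.Module intercepts."""

    def __init__(self):
        super().__init__()
        # nn.Module.__setattr__ records the layer in self._modules.
        setattr(self, "fc0", nn.Linear(4, 3))


for cls in (ViaDict, ViaSetattr):
    m = cls()
    names = [name for name, _ in m.named_parameters()]
    print(cls.__name__, names)
# ViaDict []
# ViaSetattr ['fc0.weight', 'fc0.bias']
```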
In principle, instantiating it as

```python
model = MLP(57, 2, 30, 10)
```

should make it equivalent to the following network:
```python
class MLPclassification(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(57, 30)
        self.fc1 = nn.Linear(30, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, x):
        x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        x = F.sigmoid(self.output(x))
        return x
```
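A quick way to test the "should be equivalent" assumption is to count trainable parameters. For 57 → 30 → 10 → 2, the hard-coded network has (57·30 + 30) + (30·10 + 10) + (10·2 + 2) = 1740 + 310 + 22 = 2072 parameters. The sketch below re-declares the two constructors from the post (the `forward` methods are omitted because only parameter registration matters here) and prints how many parameters each model actually exposes. The helper `count_params` is made up for illustration. If the hidden layers stored through `self.__dict__` are not registered, the general MLP should report only the 22 parameters of its `output` layer:

```python
import torch.nn as nn


class MLP(nn.Module):
    # Constructor copied from the post; forward() omitted.
    def __init__(self, feature_num, class_num, *hidden_nums):
        super().__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        self.hidden_nums = hidden_nums
        input_num = feature_num
        for i, hidden_num in enumerate(hidden_nums):
            self.__dict__['fc' + str(i)] = nn.Linear(input_num, hidden_num)
            input_num = hidden_num
        self.output = nn.Linear(input_num, class_num)


class MLPclassification(nn.Module):
    # Constructor copied from the post; forward() omitted.
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(57, 30)
        self.fc1 = nn.Linear(30, 10)
        self.output = nn.Linear(10, 2)


def count_params(model: nn.Module) -> int:
    """Number of scalar weights an optimizer would see via model.parameters()."""
    return sum(p.numel() for p in model.parameters())


print(count_params(MLP(57, 2, 30, 10)))   # 22 (only self.output)
print(count_params(MLPclassification()))  # 2072
```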
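But when I train with `model = MLP(57, 2, 30, 10)` on a binary classification problem, it predicts every sample as class 0.
When I train with `model = MLPclassification()` instead, the predictions are very good.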
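I've spent a long time checking this and still can't find the problem. If anyone knows what's going on, please take a look. Thanks a lot!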
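If the cause is that the hidden layers are written into `self.__dict__` and therefore never registered, one conventional fix is to hold them in an `nn.ModuleList`. Registering each layer with `setattr(self, 'fc' + str(i), ...)` would also work. The sketch below is a rewrite under that assumption. The class name `MLPFixed` and the attribute name `hidden` are made up. The output head is kept as in the post: a raw single output when `class_num == 1`, sigmoid otherwise. Because the layers are named `hidden.0`, `hidden.1`, ... in the `state_dict`, weights saved from `MLPclassification` would need their keys renamed before loading.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPFixed(nn.Module):
    def __init__(self, feature_num, class_num, *hidden_nums):
        super().__init__()
        self.feature_num = feature_num
        self.class_num = class_num
        sizes = [feature_num, *hidden_nums]
        # nn.ModuleList registers every layer, so model.parameters(),
        # model.to(device) and state_dict() all include them.
        self.hidden = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])
        )
        self.output = nn.Linear(sizes[-1], class_num)

    def forward(self, x):
        for layer in self.hidden:
            x = F.relu(layer(x))
        out = self.output(x)
        # Same head as the post: raw output for a single class, sigmoid otherwise.
        return out[..., 0] if self.class_num == 1 else torch.sigmoid(out)


model = MLPFixed(57, 2, 30, 10)
print(sum(p.numel() for p in model.parameters()))  # 2072
print(model(torch.randn(4, 57)).shape)             # torch.Size([4, 2])
```

With the layers registered, an optimizer created as `torch.optim.SGD(model.parameters(), lr=...)` (or any other optimizer) receives all 2072 parameters, the same as `MLPclassification`.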