在学习《动手学深度学习》时,实现下面代码时,报出raise NotImplementedError错误。
import collections
import torch
from d2l import torch as d2l
import math
from torch import nn
class Seq2SeqEncoder(d2l.Encoder):
def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
super(Seq2SeqEncoder,self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size,embed_size)
self.rnn = nn.GRU(embed_size,num_hiddens,num_layers,dropout=dropout)
def forward(self, X, *args):
X = self.embedding(X)
X = X.permute(1,0,2)
output,state = self.rnn(X)
return output,state
encoder = Seq2SeqEncoder(10,8,16,2)
encoder.eval()
X = torch.zeros((4,7),dtype=torch.long)
output,state = encoder(X)
print(output.shape)
class Seq2SeqDecoder(d2l.Decoder):
def __init__(self,vocab_size,embed_size,num_hiddens,num_layers,dropout=0,**kwargs):
super(Seq2SeqDecoder,self).__init__(**kwargs)
self.embedding = nn.Embedding(vocab_size,embed_size)
self.rnn = nn.GRU(embed_size+num_hiddens,num_hiddens,num_layers,dropout=dropout)
self.dense = nn.Linear(num_hiddens,vocab_size)
def init_state(self, enc_outputs, *args):
return enc_outputs[1]
def farward(self,X,state):
X = self.embedding(X).permute(1,0,2)
context = state[-1].repeat(X.shape[0],1,1)
X_and_context = torch.cat((X,context),2)
output,state = self.rnn(X_and_context,state)
output = self.dense(output).permute(1,0,2)
return output,state
decoder = Seq2SeqDecoder(10,8,16,2)
print(decoder.eval())
state = decoder.init_state(encoder(X))
output,state = decoder(X,state)
print(output.shape)
原因是类Seq2SeqDecoder
在继承d2l.Decoder
类时,需要重写父类的方法,而我把forward写成了farward。因此,出现了报错。
在深度学习中,子类继承父类时,需要重写父类的方法。