import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence

def load_treebank():
    # Penn Treebank POS-tagging data shipped with NLTK
    from nltk.corpus import treebank
    # Split each tagged sentence into a token sequence and a tag sequence
    sents, postags = zip(*(zip(*sent) for sent in treebank.tagged_sents()))

    # Vocab is the vocabulary helper class defined earlier; "<pad>" is reserved for padding
    vocab = Vocab.build(sents, reserved_tokens=["<pad>"])
    tag_vocab = Vocab.build(postags)

    # Use the first 3000 sentences for training and the rest for testing
    train_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags))
                  for sentence, tags in zip(sents[:3000], postags[:3000])]
    test_data = [(vocab.convert_tokens_to_ids(sentence), tag_vocab.convert_tokens_to_ids(tags))
                 for sentence, tags in zip(sents[3000:], postags[3000:])]
    return train_data, test_data, vocab, tag_vocab
def collate_fn(examples):
    # Each example is a (token id sequence, tag id sequence) pair;
    # record the true lengths before padding
    lengths = torch.tensor([len(ex[0]) for ex in examples])
    inputs = [torch.tensor(ex[0]) for ex in examples]
    targets = [torch.tensor(ex[1]) for ex in examples]
    # Pad both inputs and targets to the longest sequence in the batch
    # (vocab is the token vocabulary returned by load_treebank)
    inputs = pad_sequence(inputs, batch_first=True, padding_value=vocab["<pad>"])
    targets = pad_sequence(targets, batch_first=True, padding_value=vocab["<pad>"])
    # The final element is a boolean mask marking the non-padding positions
    return inputs, lengths, targets, inputs != vocab["<pad>"]
class LSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_class):
        super(LSTM, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.output = nn.Linear(hidden_dim, num_class)
    def forward(self, inputs, lengths):
        embeddings = self.embeddings(inputs)
        # Pack the padded batch so the LSTM skips padding positions
        x_pack = pack_padded_sequence(embeddings, lengths, batch_first=True, enforce_sorted=False)
        hidden, (hn, cn) = self.lstm(x_pack)
        # Unpack to recover the per-token hidden states for sequence labeling
        hidden, _ = pad_packed_sequence(hidden, batch_first=True)
        outputs = self.output(hidden)
        log_probs = F.log_softmax(outputs, dim=-1)
        return log_probs
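
Below is a minimal sketch of how these pieces could be wired together for training. The LstmDataset wrapper, the hyperparameters, the use of len(vocab) and len(tag_vocab), and the masked NLLLoss are illustrative assumptions, not part of the original listing.

# Minimal training sketch; names and hyperparameters below are illustrative assumptions.
from torch.utils.data import DataLoader, Dataset

class LstmDataset(Dataset):
    # Hypothetical thin wrapper so DataLoader can index the (inputs, targets) pairs
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, i):
        return self.data[i]

train_data, test_data, vocab, tag_vocab = load_treebank()
train_loader = DataLoader(LstmDataset(train_data), batch_size=32,
                          collate_fn=collate_fn, shuffle=True)

# Assumes the Vocab class supports len(); substitute however vocabulary sizes are exposed
model = LSTM(len(vocab), embedding_dim=128, hidden_dim=256, num_class=len(tag_vocab))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
nll_loss = nn.NLLLoss()

model.train()
for epoch in range(5):
    for inputs, lengths, targets, mask in train_loader:
        log_probs = model(inputs, lengths)
        # Score only the non-padding positions selected by the mask
        loss = nll_loss(log_probs[mask], targets[mask])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Because the forward pass returns log-probabilities, NLLLoss is the natural fit here; indexing with the boolean mask keeps padded positions from contributing to the loss.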