model.py
# Author: GC
from typing import List
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torchcrf import CRF


class BiLSTM_CRF(nn.Module):
    """BiLSTM encoder with a CRF layer on top for sequence tagging.

    Args:
        vocab_size: size of the word vocabulary
        num_tags: total number of tags
        embed_dim: word embedding dimension
        hidden_dim: output dimension of the BiLSTM at each step
        dropout: dropout rate (applied to the embeddings)

    Attributes:
        vocab_size: size of the word vocabulary
        num_tags: total number of tags
    """

    def __init__(
        self,
        vocab_size: int,
        num_tags: int,
        embed_dim: int,
        hidden_dim: int,
        dropout: float,
    ) -> None:
        super(BiLSTM_CRF, self).__init__()
        self.vocab_size = vocab_size
        self.num_tags = num_tags
        # Layers
        self.dropout = nn.Dropout(dropout)
        self.embeds = nn.Embedding(vocab_size, embed_dim)
        # Each direction gets hidden_dim // 2 units, so the concatenated
        # bidirectional output has dimension hidden_dim.
        self.lstm = nn.LSTM(embed_dim, hidden_dim // 2, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, num_tags)
        self.crf = CRF(num_tags)

    def _get_emissions(
        self, seqs: torch.LongTensor, masks: torch.ByteTensor
    ) -> torch.Tensor:
        """Get emission scores from the BiLSTM.

        Args:
            seqs: (seq_len, batch_size), sorted by length in descending order
            masks: (seq_len, batch_size), sorted by length in descending order

        Returns:
            emission scores (seq_len, batch_size, num_tags)
        """
        embeds = self.embeds(seqs)  # (seq_len, batch_size, embed_dim)
        embeds = self.dropout(embeds)
        # Sequence lengths come from the mask; pack_padded_sequence expects
        # them on the CPU.
        packed = pack_padded_sequence(embeds, masks.sum(0).cpu())
        lstm_out, _ = self.lstm(packed)
        lstm_out, _ = pad_packed_sequence(lstm_out)  # (seq_len, batch_size, hidden_dim)
        # Project the hidden states to tag space: (seq_len, batch_size, num_tags)
        emissions = self.hidden2tag(lstm_out)
        return emissions

    def loss(
        self, seqs: torch.LongTensor, tags: torch.LongTensor, masks: torch.ByteTensor
    ) -> torch.Tensor:
        """Negative log-likelihood loss.

        Args:
            seqs: (seq_len, batch_size), sorted by length in descending order
            tags: (seq_len, batch_size), sorted by length in descending order
            masks: (seq_len, batch_size), sorted by length in descending order

        Returns:
            loss
        """
        emissions = self._get_emissions(seqs, masks)
        # The CRF returns the log likelihood; negate it to get a loss to minimize.
        loss = -self.crf(emissions, tags, mask=masks, reduction="mean")
        return loss

    def decode(
        self, seqs: torch.LongTensor, masks: torch.ByteTensor
    ) -> List[List[int]]:
        """Viterbi decode.

        Args:
            seqs: (seq_len, batch_size), sorted by length in descending order
            masks: (seq_len, batch_size), sorted by length in descending order

        Returns:
            list of lists containing the best tag sequence for each batch element
        """
        emissions = self._get_emissions(seqs, masks)
        best_tags = self.crf.decode(emissions, mask=masks)
        return best_tags
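

# A minimal usage sketch, not part of the original file: the vocabulary size,
# tag count, and tensor shapes below are illustrative assumptions. Inputs are
# (seq_len, batch_size), sorted by length with the longest sequence first, and
# the uint8 mask must have its first timestep all ones.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = BiLSTM_CRF(
        vocab_size=100, num_tags=5, embed_dim=16, hidden_dim=8, dropout=0.1
    )

    seqs = torch.randint(1, 100, (4, 2))  # two sequences, max length 4
    tags = torch.randint(0, 5, (4, 2))
    masks = torch.tensor(
        [[1, 1], [1, 1], [1, 1], [1, 0]], dtype=torch.uint8
    )  # lengths 4 and 3, sorted in descending order

    loss = model.loss(seqs, tags, masks)  # scalar; backprop with loss.backward()
    print("loss:", loss.item())
    print("best tag paths:", model.decode(seqs, masks))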