-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
35 lines (26 loc) · 1.06 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#!/usr/bin/env python
import torch
from vocabulary import SpecialToken
class Word2Idx(object):
    """Callable that maps a tokenized sequence to a tensor of vocabulary indices.

    The wrapped vocabulary only needs to expose ``get_index(word)``; that is
    the sole method this class calls on it.
    """

    def __init__(self, vocabulary):
        """Store the vocabulary used for word -> index lookups."""
        self._vocab = vocabulary

    def __call__(self, tokenized_sequence):
        """Encode ``tokenized_sequence`` as a ``torch.LongTensor`` of indices.

        Each token is lower-cased before lookup. The index of the START
        special token is placed in front of the encoded tokens and the index
        of the END special token after them.
        """
        lookup = self._vocab.get_index

        # Lookups are issued in output order: START first, tokens, then END.
        start_idx = lookup(SpecialToken.START.value.word)
        token_indices = [lookup(token.lower()) for token in tokenized_sequence]
        end_idx = lookup(SpecialToken.END.value.word)

        return torch.LongTensor([start_idx] + token_indices + [end_idx])
class IdxToWord(object):
    """Callable that maps a sequence of vocabulary indices back to words.

    The wrapped vocabulary only needs to expose ``get_word(idx)``; that is
    the sole method this class calls on it.
    """

    def __init__(self, vocabulary):
        """Store the vocabulary used for index -> word lookups."""
        self._vocab = vocabulary

    def __call__(self, idx_sequence):
        """Decode ``idx_sequence`` into a list of words.

        Decoding stops at the first index equal to the END token's index.
        Words belonging to a special token are dropped, with the exception
        of the UNK word, which is kept in the output.
        """
        reserved_words = [member.value.word for member in SpecialToken]
        end_index = SpecialToken.END.value.index
        unk_word = SpecialToken.UNK.value.word

        words = []
        for idx in idx_sequence:
            # Everything from the END marker onwards is ignored.
            if idx == end_index:
                break
            word = self._vocab.get_word(idx)
            # Skip special tokens, except UNK which stays visible.
            if word in reserved_words and word != unk_word:
                continue
            words.append(word)
        return words