-
Notifications
You must be signed in to change notification settings - Fork 0
/
vocabulary.py
113 lines (84 loc) · 3.66 KB
/
vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python
import collections
import enum
# ==================================================================================================
# -- special tokens -------------------------------------------------------------------------------
# ==================================================================================================
# Lightweight record pairing a special token's surface form with its
# reserved vocabulary index.
SpecialTokenTuple = collections.namedtuple('SpecialTokenTuple', ['word', 'index'])


class SpecialToken(enum.Enum):
    """Reserved vocabulary tokens with fixed, pre-assigned indices."""

    START = SpecialTokenTuple('<start>', 0)
    END = SpecialTokenTuple('<end>', 1)
    PAD = SpecialTokenTuple('<pad>', 2)
    UNK = SpecialTokenTuple('<unk>', 3)
# ==================================================================================================
# -- vocabulary -----------------------------------------------------------------------------------
# ==================================================================================================
class Vocabulary(object):
    """
    Incrementally-built word <-> index mapping.

    Words are lowercased before registration and only receive an index once
    they have been seen at least ``min_freq`` times. Special tokens (see
    SpecialToken) are pre-registered with fixed indices and bypass both the
    counter and the frequency rule.
    """

    def __init__(self, min_freq=1):
        """
        :param min_freq: minimum number of occurrences a word needs before it
            is assigned an index in the vocabulary.
        """
        self.min_freq = min_freq
        self._word2idx = {}
        self._idx2word = {}
        self._counter = collections.Counter()

        # Special tokens are not registered in the internal counter and do not
        # follow the minimum frequency rule, so we insert them manually.
        for special_token in SpecialToken:
            word, idx = special_token.value.word, special_token.value.index
            self._word2idx[word] = idx
            self._idx2word[idx] = word

    def __len__(self):
        return len(self._word2idx)

    def __contains__(self, word):
        return word in self._word2idx

    def get_words(self):
        """
        Returns the list of vocabulary words sorted by word index.
        """
        # Sorting the index keys directly avoids building throwaway
        # (index, word) pairs only to discard the index.
        return [self._idx2word[idx] for idx in sorted(self._idx2word)]

    def add_word(self, word):
        """
        Adds a new word to the vocabulary and updates the internal counter.

        :param word: word to be added.
        :returns: True if the word is successfully added. Otherwise, False.
        """
        # Ensures that the given word is lowercase.
        _word = word.lower()
        self._counter[_word] += 1
        if self._counter[_word] >= self.min_freq and _word not in self:
            new_idx = len(self._word2idx)
            self._word2idx[_word] = new_idx
            self._idx2word[new_idx] = _word
            return True
        return False

    def get_word(self, idx):
        """
        Returns the word for the given index, or the <unk> word if unknown.
        """
        return self._idx2word.get(idx, SpecialToken.UNK.value.word)

    def get_index(self, word):
        """
        Returns the index for the given word, or the <unk> index if unknown.

        NOTE(review): the word is NOT lowercased here, unlike add_word, so
        mixed-case lookups always resolve to <unk> — confirm this is intended
        before changing it, as callers may pre-lowercase.
        """
        return self._word2idx.get(word, SpecialToken.UNK.value.index)

    def frequency(self, word):
        """
        Returns number of occurrences of the given word.
        """
        return self._counter[word]

    def most_common(self, n=5):
        """
        Returns a list of the n most common elements and their counts.
        """
        return self._counter.most_common(n)

    def least_common(self, n=5):
        """
        Returns a list of the n least common elements and their counts.
        """
        # most_common() sorts descending; the reversed slice yields the n
        # least common entries, least frequent first.
        return self._counter.most_common()[:-n - 1:-1]
def build_flickr8k_vocabulary(ann_file, min_freq=1):
    """
    Builds flickr8k vocabulary.

    :param ann_file: Annotation file with the tokenized captions. Each line
        starts with a caption identifier followed by whitespace-separated
        tokens; the identifier is skipped.
    :param min_freq: Word minimum frequency to be added to the vocabulary.
    :returns: vocabulary object.
    """
    vocab = Vocabulary(min_freq)

    # Processing file with tokenized captions. Iterating the file object
    # streams line by line instead of materializing the whole file with
    # readlines(); utf-8 is requested explicitly so decoding does not depend
    # on the platform's default encoding.
    with open(ann_file, encoding='utf-8') as f:
        for line in f:
            # line.split()[1:] drops the leading caption identifier.
            for token in line.split()[1:]:
                vocab.add_word(token)

    return vocab