# generic_singletoken_bigram.py
import json
import random

from nltk import FreqDist, MLEProbDist

# Sentence-boundary markers and full-width Chinese punctuation to strip
# from the corpus before building n-grams.
excludes = ["<s>", "</s>", "。", "，", "、", "：", "“", "”", "？"]


class GenericSingleTokenBigramModel:
    def __init__(self) -> None:
        # MLE estimator, kept for reference; FreqDist.freq() already yields
        # maximum-likelihood (relative-frequency) estimates.
        self.estimator = lambda fdist, bins: MLEProbDist(fdist)
        self.raw_data = self.load_data("宋词三百首.json", "paragraphs")
        self.cleaned_tokens = self.clean_token(self.raw_data)
        self.unigrams = self.build_unigrams(self.cleaned_tokens)
        self.bigrams = self.build_bigrams(self.cleaned_tokens)

    def load_data(self, data_file: str, data_content_tag: str) -> list:
        with open(data_file, "r", encoding="utf8") as raw:
            data = json.load(raw)
        contents = [content[data_content_tag] for content in data]
        return contents
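
    # The JSON shape assumed by load_data, inferred from the access pattern
    # above (the exact fields of 宋词三百首.json beyond "paragraphs" are an
    # assumption): a list of objects, each mapping data_content_tag to a list
    # of verse strings, e.g.
    #   [{"paragraphs": ["明月几时有？", "把酒问青天。"]}, ...]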

    def clean_token(self, contents: list) -> list:
        # flatten the per-poem verse lists into one list of strings
        tokens = [item for sublist in contents for item in sublist]
        cleaned_tokens = []
        for token in tokens:
            cleaned_token = "".join(c for c in token if c not in excludes)
            if cleaned_token:  # guard: drop verses that were punctuation-only
                cleaned_tokens.append(cleaned_token)
        return cleaned_tokens

    # Note: FreqDist counts are normalized below, so each stored value can be
    # read directly as a probability.
    def build_bigrams(self, tokens: list) -> FreqDist:
        bigrams = []
        for token in tokens:
            if not token:
                continue
            token_bigrams = []
            for i in range(len(token)):
                if i == 0:
                    # sentence-start bigram
                    token_bigrams.append("".join(["<s>", token[i]]))
                else:
                    token_bigrams.append("".join([token[i - 1], token[i]]))
            # sentence-end bigram, built from the last character
            token_bigrams.append("".join([token[-1], "</s>"]))
            bigrams += token_bigrams
        # normalize step: divide each count by the total so the values sum to 1
        freq_dist = FreqDist(bigrams)
        total = freq_dist.N()
        for word in freq_dist:
            freq_dist[word] /= float(total)
        return freq_dist
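
    # Worked example of the padding scheme above (characters chosen purely
    # for illustration): the cleaned verse "春花" yields the bigrams
    # ["<s>春", "春花", "花</s>"], so sentence starts and ends are modeled
    # alongside interior character pairs.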

    def build_unigrams(self, tokens: list) -> FreqDist:
        # every character of every cleaned verse is a unigram sample
        unigrams = [char for token in tokens for char in token]
        # normalize step: divide each count by the total so the values sum to 1
        freq_dist = FreqDist(unigrams)
        total = freq_dist.N()
        for word in freq_dist:
            freq_dist[word] /= float(total)
        return freq_dist

    def generate(self) -> str:
        """
        Idea: sample a first character from the unigram distribution, then
        restrict to the bigrams starting with that character and sample the
        second character from that (renormalized) conditional distribution.
        """
        # determine the first char by inverse-CDF sampling
        chosen_unigram = None
        p_cur = random.random()
        for unigram in self.unigrams:
            p_cur -= self.unigrams.freq(unigram)
            if p_cur <= 0:
                chosen_unigram = unigram
                break
        if chosen_unigram is None:
            # floating-point round-off can leave p_cur slightly positive
            print("Unigram probability fails to determine a proper first character, randomly sampling...")
            chosen_unigram = random.choice(list(self.unigrams))
        # determine the second char
        p_cur = random.random()
        related_bigrams, sum_of_related_bigram_probs = self.bigrams_starting_with(chosen_unigram)
        for bigram in related_bigrams:
            # renormalize by the probability mass of the related bigrams
            bi_p = self.bigrams.freq(bigram) / sum_of_related_bigram_probs
            p_cur -= bi_p
            if p_cur <= 0:
                return bigram
        print("Bigram probability fails to determine a proper bigram for the leading unigram " + chosen_unigram + ", randomly selecting from related bigrams...")
        return random.choice(related_bigrams)
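
    # Toy trace of the inverse-CDF sampling used in generate() (hypothetical
    # numbers, for illustration only): with probabilities
    # {"a": 0.5, "b": 0.3, "c": 0.2} and p = 0.6, subtracting 0.5 ("a")
    # leaves 0.1 > 0, then subtracting 0.3 ("b") leaves -0.2 <= 0, so "b" is
    # returned; each sample is chosen in proportion to its probability mass.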

    def bigrams_starting_with(self, leading_char: str):
        """Return all bigrams whose first character is leading_char, together
        with their total probability mass (used for renormalization)."""
        wanted_bigrams = []
        sum_of_prob = 0.0
        for bigram in self.bigrams:  # iterating a FreqDist yields its keys
            if bigram[0] == leading_char:
                wanted_bigrams.append(bigram)
                sum_of_prob += self.bigrams.freq(bigram)
        return wanted_bigrams, sum_of_prob

    def generate_multiple(self, number: int) -> list:
        return [self.generate() for _ in range(number)]

    def check_normalization(self):
        # both sums should be (approximately) 1.0 after normalization
        bigram_sum = sum(self.bigrams.values())
        unigram_sum = sum(self.unigrams.values())
        print("Bigram sum: " + str(bigram_sum))
        print("Unigram sum: " + str(unigram_sum))


if __name__ == "__main__":
    bm = GenericSingleTokenBigramModel()
    # In memoriam "舷独", the fixed output of an early implementation that
    # made me realize the problem was the normalization of probabilities.
    generated = bm.generate_multiple(10)
    print(generated)
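    # Optional sanity check (a minimal sketch, using the check_normalization
    # method defined above): both printed sums should be close to 1.0.
    bm.check_normalization()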