-
Notifications
You must be signed in to change notification settings - Fork 0
/
assistance.py
99 lines (74 loc) · 2.62 KB
/
assistance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import pickle
import copy
import numpy as np
# Reserved special-token ids; create_lookup_tables numbers ordinary
# vocabulary words after these, and '<PAD>' is the id used for padding.
CODES = {
    '<PAD>': 0,
    '<EOS>': 1,
    '<UNK>': 2,
    '<GO>': 3,
}
def load_data(path):
    """
    Read a whole text file and return its contents.

    :param path: Path (or anything ``open`` accepts) of the file to read.
    :return: The file contents as a ``str``, decoded as UTF-8.
    """
    # open() accepts str, bytes and path-like objects directly; the former
    # single-argument os.path.join(path) was a no-op wrapper.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read()
def preprocess_and_save_data(source_path, target_path, text_to_ids, save_path='preprocess.p'):
    """
    Preprocess parallel text data and save it to a pickle file.

    Both files are read, lower-cased, and given their own vocabulary lookup
    tables via ``create_lookup_tables``. ``text_to_ids`` is then called as
    ``text_to_ids(source_text, target_text, source_vocab_to_int,
    target_vocab_to_int)`` and must return a pair ``(source_ids, target_ids)``.

    :param source_path: Path of the source-language text file.
    :param target_path: Path of the target-language text file.
    :param text_to_ids: Callable converting the lower-cased texts to ids.
    :param save_path: Output pickle file (default ``'preprocess.p'``).

    The pickle holds the nested tuple
    ``((source_ids, target_ids), (source_vocab_to_int, target_vocab_to_int),
    (source_int_to_vocab, target_int_to_vocab))``.
    """
    # Preprocess
    source_text = load_data(source_path).lower()
    target_text = load_data(target_path).lower()

    source_vocab_to_int, source_int_to_vocab = create_lookup_tables(source_text)
    target_vocab_to_int, target_int_to_vocab = create_lookup_tables(target_text)

    source_text, target_text = text_to_ids(
        source_text, target_text, source_vocab_to_int, target_vocab_to_int)

    # Save Data
    with open(save_path, 'wb') as out_file:
        pickle.dump((
            (source_text, target_text),
            (source_vocab_to_int, target_vocab_to_int),
            (source_int_to_vocab, target_int_to_vocab)), out_file)
def load_preprocess(path='preprocess.p'):
    """
    Load the data pickled by ``preprocess_and_save_data``.

    No batching happens here; the stored object is returned as-is.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    load files produced by this project.

    :param path: Pickle file to read (default ``'preprocess.p'``).
    :return: The unpickled object. For files written by
        ``preprocess_and_save_data`` this is
        ``((source_ids, target_ids), (source_vocab_to_int, target_vocab_to_int),
        (source_int_to_vocab, target_int_to_vocab))``.
    """
    with open(path, mode='rb') as in_file:
        return pickle.load(in_file)
def create_lookup_tables(text, codes=None):
    """
    Create word <-> id lookup tables for a whitespace-tokenised text.

    The special tokens in ``codes`` keep their ids. Every other distinct
    word gets the next free id, assigned in sorted order so the tables are
    identical from run to run. (Iterating a ``set`` of strings is not
    reproducible under hash randomisation.)

    :param text: Text to build the vocabulary from; split on whitespace.
    :param codes: Optional mapping of special token -> id. Defaults to ``CODES``.
    :return: ``(vocab_to_int, int_to_vocab)`` dictionaries.
    """
    if codes is None:
        codes = CODES
    vocab_to_int = dict(codes)
    # First id not taken by a special token (equals len(codes) for 0..n-1 ids).
    next_id = max(vocab_to_int.values(), default=-1) + 1
    # Words that equal a special token must not overwrite that token's id.
    words = sorted(set(text.split()).difference(vocab_to_int))
    for offset, word in enumerate(words):
        vocab_to_int[word] = next_id + offset
    int_to_vocab = {v_i: v for v, v_i in vocab_to_int.items()}
    return vocab_to_int, int_to_vocab
def save_params(params, path='params.p'):
    """
    Pickle ``params`` to a file.

    :param params: Any picklable object.
    :param path: Destination file (default ``'params.p'``); overwritten if it exists.
    """
    with open(path, 'wb') as out_file:
        pickle.dump(params, out_file)
def load_params(path='params.p'):
    """
    Load parameters previously written with ``save_params``.

    NOTE: ``pickle.load`` can execute arbitrary code while unpickling; only
    load files produced by this project.

    :param path: Pickle file to read (default ``'params.p'``).
    :return: The unpickled object.
    """
    with open(path, mode='rb') as in_file:
        return pickle.load(in_file)
def batch_data(source, target, batch_size):
    """
    Yield aligned, padded (source, target) batches.

    Only complete batches are produced: trailing examples that do not fill
    a whole ``batch_size`` chunk are dropped. Each sentence batch is padded
    with ``pad_sentence_batch`` and converted to a NumPy array.

    :param source: Sequence of source sentences (token-id lists).
    :param target: Sequence of target sentences, aligned with ``source``.
    :param batch_size: Number of sentences per batch.
    :return: Generator of ``(source_batch_array, target_batch_array)`` pairs.
    """
    full_batches = len(source) // batch_size
    for lo in range(0, full_batches * batch_size, batch_size):
        hi = lo + batch_size
        source_array = np.array(pad_sentence_batch(source[lo:hi]))
        target_array = np.array(pad_sentence_batch(target[lo:hi]))
        yield source_array, target_array
def pad_sentence_batch(sentence_batch, pad_int=None):
    """
    Right-pad every sentence in a batch to the length of the longest one.

    :param sentence_batch: Sequence of token-id sequences (lists, tuples, ...).
    :param pad_int: Padding id; defaults to ``CODES['<PAD>']``.
    :return: A new list of lists, all the same length. An empty batch yields ``[]``.
    """
    if pad_int is None:
        pad_int = CODES['<PAD>']
    # default=0 so an empty batch returns [] instead of crashing in max().
    max_sentence = max((len(sentence) for sentence in sentence_batch), default=0)
    # list(...) so tuples concatenate and NumPy rows are not broadcast-added.
    return [list(sentence) + [pad_int] * (max_sentence - len(sentence))
            for sentence in sentence_batch]