# Copyright 2018 Stanford University
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file contains code to read tokenized data from file,
truncate, pad and process it into batches ready for training"""
from __future__ import absolute_import
from __future__ import division
import random
import time
import re
import numpy as np
from six.moves import xrange
from vocab import PAD_ID, UNK_ID


class Batch(object):
    """A class to hold the information needed for a training batch"""

    def __init__(self, context_ids, context_mask, context_tokens, qn_ids, qn_mask, qn_tokens, ans_span, ans_tokens, uuids=None):
        """
        Inputs:
          {context/qn}_ids: Numpy arrays.
            Shape (batch_size, {context_len/question_len}). Contains padding.
          {context/qn}_mask: Numpy arrays, same shape as _ids.
            Contains 1s where there is real data, 0s where there is padding.
          {context/qn/ans}_tokens: Lists (length batch_size) containing lists (unpadded) of tokens (strings)
          ans_span: numpy array, shape (batch_size, 2)
          uuids: a list (length batch_size) of strings.
            Not needed for training. Used by official_eval mode.
        """
self.context_ids = context_ids
self.context_mask = context_mask
self.context_tokens = context_tokens
self.qn_ids = qn_ids
self.qn_mask = qn_mask
self.qn_tokens = qn_tokens
self.ans_span = ans_span
self.ans_tokens = ans_tokens
self.uuids = uuids
self.batch_size = len(self.context_tokens)
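
# Illustrative shapes for a Batch built with batch_size=32, context_len=300,
# question_len=30 (hypothetical settings):
#   context_ids / context_mask: (32, 300)
#   qn_ids / qn_mask:           (32, 30)
#   ans_span:                   (32, 2)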


def split_by_whitespace(sentence):
    """Splits a tokenized sentence string on runs of whitespace, dropping empty strings."""
    return [w for w in sentence.strip().split() if w]
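
# Illustrative example:
#   split_by_whitespace("  i do n't  know ") -> ["i", "do", "n't", "know"]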


def intstr_to_intlist(string):
"""Given a string e.g. '311 9 1334 635 6192 56 639', returns as a list of integers"""
return [int(s) for s in string.split()]


def sentence_to_token_ids(sentence, word2id):
    """Turns an already-tokenized sentence string into word indices
    e.g. "i do n't know" -> [9, 32, 16, 96]
    Note: any token that isn't in the word2id mapping gets mapped to the id for UNK
    """
tokens = split_by_whitespace(sentence) # list of strings
ids = [word2id.get(w, UNK_ID) for w in tokens]
return tokens, ids
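
# Illustrative example (hypothetical toy word2id; real ids come from the vocab module):
#   sentence_to_token_ids("the purple cat", {"the": 2, "cat": 3})
#     -> (["the", "purple", "cat"], [2, UNK_ID, 3])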


def padded(token_batch, batch_pad=0):
    """
    Inputs:
      token_batch: List (length batch_size) of lists of ints.
      batch_pad: Int. Length to pad to. If 0, pad to the length of the longest sequence in token_batch.
    Returns:
      List (length batch_size) of padded lists of ints.
      All are the same length: batch_pad if batch_pad != 0, otherwise the maximum length in token_batch.
    """
    maxlen = max(len(x) for x in token_batch) if batch_pad == 0 else batch_pad
    return [token_list + [PAD_ID] * (maxlen - len(token_list)) for token_list in token_batch]
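
# Illustrative example (assumes PAD_ID == 0, as defined in the vocab module):
#   padded([[3, 7, 7], [5]])    -> [[3, 7, 7], [5, 0, 0]]
#   padded([[3, 7, 7], [5]], 4) -> [[3, 7, 7, 0], [5, 0, 0, 0]]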


def refill_batches(batches, word2id, context_file, qn_file, ans_file, batch_size, context_len, question_len, discard_long):
    """
    Adds more batches into the "batches" list.
    Inputs:
      batches: list to add batches to
      word2id: dictionary mapping word (string) to word id (int)
      context_file, qn_file, ans_file: open file objects for the {train/dev}.{context/question/answer} data files
      batch_size: int. how big to make the batches
      context_len, question_len: max length of context and question respectively
      discard_long: If True, discard any examples that are longer than context_len or question_len.
        If False, truncate those examples instead.
    """
print "Refilling batches..."
tic = time.time()
    examples = [] # list of (context_ids, context_tokens, qn_ids, qn_tokens, ans_span, ans_tokens) tuples
context_line, qn_line, ans_line = context_file.readline(), qn_file.readline(), ans_file.readline() # read the next line from each
while context_line and qn_line and ans_line: # while you haven't reached the end
# Convert tokens to word ids
context_tokens, context_ids = sentence_to_token_ids(context_line, word2id)
qn_tokens, qn_ids = sentence_to_token_ids(qn_line, word2id)
ans_span = intstr_to_intlist(ans_line)
# read the next line from each file
context_line, qn_line, ans_line = context_file.readline(), qn_file.readline(), ans_file.readline()
# get ans_tokens from ans_span
assert len(ans_span) == 2
if ans_span[1] < ans_span[0]:
print "Found an ill-formed gold span: start=%i end=%i" % (ans_span[0], ans_span[1])
continue
ans_tokens = context_tokens[ans_span[0] : ans_span[1]+1] # list of strings
# discard or truncate too-long questions
if len(qn_ids) > question_len:
if discard_long:
continue
else: # truncate
qn_ids = qn_ids[:question_len]
# discard or truncate too-long contexts
if len(context_ids) > context_len:
if discard_long:
continue
else: # truncate
context_ids = context_ids[:context_len]
# add to examples
examples.append((context_ids, context_tokens, qn_ids, qn_tokens, ans_span, ans_tokens))
# stop refilling if you have 160 batches
if len(examples) == batch_size * 160:
break
# Once you've either got 160 batches or you've reached end of file:
# Sort by question length
# Note: if you sort by context length, then you'll have batches which contain the same context many times (because each context appears several times, with different questions)
examples = sorted(examples, key=lambda e: len(e[2]))
# Make into batches and append to the list batches
for batch_start in xrange(0, len(examples), batch_size):
        # Note: each of these is a tuple of length batch_size (shorter on the last iteration if len(examples) isn't a multiple of batch_size)
context_ids_batch, context_tokens_batch, qn_ids_batch, qn_tokens_batch, ans_span_batch, ans_tokens_batch = zip(*examples[batch_start:batch_start+batch_size])
batches.append((context_ids_batch, context_tokens_batch, qn_ids_batch, qn_tokens_batch, ans_span_batch, ans_tokens_batch))
# shuffle the batches
random.shuffle(batches)
toc = time.time()
print "Refilling batches took %.2f seconds" % (toc-tic)
return


def get_batch_generator(word2id, context_path, qn_path, ans_path, batch_size, context_len, question_len, discard_long):
    """
    This function returns a generator object that yields batches.
    The last batch in the dataset may be a partial batch (smaller than batch_size).
    Read this to understand generators and the yield keyword in Python: https://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
    Inputs:
      word2id: dictionary mapping word (string) to word id (int)
      context_path, qn_path, ans_path: paths to {train/dev}.{context/question/answer} data files
      batch_size: int. how big to make the batches
      context_len, question_len: max length of context and question respectively
      discard_long: If True, discard any examples that are longer than context_len or question_len.
        If False, truncate those examples instead.
    """
context_file, qn_file, ans_file = open(context_path), open(qn_path), open(ans_path)
batches = []
while True:
if len(batches) == 0: # add more batches
refill_batches(batches, word2id, context_file, qn_file, ans_file, batch_size, context_len, question_len, discard_long)
if len(batches) == 0:
break
# Get next batch. These are all lists length batch_size
(context_ids, context_tokens, qn_ids, qn_tokens, ans_span, ans_tokens) = batches.pop(0)
# Pad context_ids and qn_ids
qn_ids = padded(qn_ids, question_len) # pad questions to length question_len
context_ids = padded(context_ids, context_len) # pad contexts to length context_len
# Make qn_ids into a np array and create qn_mask
qn_ids = np.array(qn_ids) # shape (batch_size, question_len)
qn_mask = (qn_ids != PAD_ID).astype(np.int32) # shape (batch_size, question_len)
# Make context_ids into a np array and create context_mask
context_ids = np.array(context_ids) # shape (batch_size, context_len)
context_mask = (context_ids != PAD_ID).astype(np.int32) # shape (batch_size, context_len)
# Make ans_span into a np array
ans_span = np.array(ans_span) # shape (batch_size, 2)
# Make into a Batch object
batch = Batch(context_ids, context_mask, context_tokens, qn_ids, qn_mask, qn_tokens, ans_span, ans_tokens)
yield batch
return
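

if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original training
    # pipeline. It writes a tiny toy dataset to temporary files and iterates over the
    # resulting batches. The toy word2id below is hypothetical; the real mapping is
    # built from the GloVe vocabulary elsewhere in the project.
    import os
    import tempfile

    tmpdir = tempfile.mkdtemp()
    toy_data = {
        "context": ["the cat sat on the mat", "dogs chase cats"],
        "question": ["where did the cat sit", "what do dogs chase"],
        "answer": ["4 5", "2 2"],  # (start, end) token indices into each context
    }
    paths = {}
    for name, lines in toy_data.items():
        paths[name] = os.path.join(tmpdir, "toy.%s" % name)
        with open(paths[name], "w") as f:
            f.write("\n".join(lines) + "\n")

    toy_word2id = {"the": 2, "cat": 3, "sat": 4, "on": 5, "mat": 6}
    for batch in get_batch_generator(toy_word2id, paths["context"], paths["question"],
                                     paths["answer"], batch_size=2, context_len=10,
                                     question_len=10, discard_long=False):
        print("batch_size=%d, context_ids shape=%s, ans_span=%s"
              % (batch.batch_size, batch.context_ids.shape, batch.ans_span.tolist()))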