"""General utilities for training.

Author:
    Shrey Desai
"""

import os
import json
import gzip
import pickle

import torch
from tqdm import tqdm


def cuda(args, tensor):
    """
    Places tensor on CUDA device (by default, uses cuda:0).

    Args:
        args: Command-line arguments (expects a `use_gpu` flag).
        tensor: PyTorch tensor.

    Returns:
        Tensor on CUDA device.
    """
    if args.use_gpu and torch.cuda.is_available():
        return tensor.cuda()
    return tensor


def unpack(tensor):
    """
    Unpacks tensor into a Python list.

    Args:
        tensor: PyTorch tensor.

    Returns:
        Python list with tensor contents.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy().tolist()


def load_dataset(path):
    """
    Loads MRQA-formatted dataset from path.

    Args:
        path: Dataset path, e.g. "datasets/squad_train.jsonl.gz".

    Returns:
        Dataset metadata and samples.
    """
    with gzip.open(path, 'rb') as f:
        elems = [
            json.loads(l.rstrip())
            for l in tqdm(f, desc=f'loading \'{path}\'', leave=False)
        ]
    meta, samples = elems[0], elems[1:]
    return (meta, samples)
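
# Example usage (a minimal sketch; assumes an MRQA-formatted, gzipped JSONL
# file is available at the path from the docstring above):
#
#     meta, samples = load_dataset('datasets/squad_train.jsonl.gz')
#     print(len(samples))  # number of samples in the split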


def load_cached_embeddings(path):
    """
    Loads embeddings from a pickle cache, if it exists; otherwise, embeddings
    are loaded into memory and cached for future accesses.

    Args:
        path: Embedding path, e.g. "glove/glove.6B.300d.txt".

    Returns:
        Dictionary mapping words (strings) to vectors (lists of floats).
    """
    bare_path = os.path.splitext(path)[0]
    cached_path = f'{bare_path}.pkl'
    if os.path.exists(cached_path):
        with open(cached_path, 'rb') as f:
            return pickle.load(f)
    embedding_map = load_embeddings(path)
    with open(cached_path, 'wb') as f:
        pickle.dump(embedding_map, f)
    return embedding_map
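
# Example usage (a sketch, assuming the GloVe file named in the docstring has
# been downloaded; the first call builds 'glove/glove.6B.300d.pkl' next to it,
# and later calls read the pickle directly):
#
#     embeddings = load_cached_embeddings('glove/glove.6B.300d.txt')
#     vector = embeddings.get('the')  # list of 300 floats, or None if missing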


def load_embeddings(path):
    """
    Loads GloVe-style embeddings into memory. This is *extremely slow* if used
    standalone -- `load_cached_embeddings` is almost always preferable.

    Args:
        path: Embedding path, e.g. "glove/glove.6B.300d.txt".

    Returns:
        Dictionary mapping words (strings) to vectors (lists of floats).
    """
    embedding_map = {}
    with open(path) as f:
        next(f)  # Skip header.
        for line in f:
            try:
                pieces = line.rstrip().split()
                embedding_map[pieces[0]] = [float(weight) for weight in pieces[1:]]
            except (ValueError, IndexError):
                # Skip malformed lines (e.g. empty rows or tokens that split
                # into non-numeric weight entries).
                pass
    return embedding_map


def search_span_endpoints(start_probs, end_probs, window=15):
    """
    Finds an optimal answer span given start and end probabilities.
    Specifically, this algorithm finds the optimal start probability p_s, then
    searches for the end probability p_e such that p_s * p_e (joint probability
    of the answer span) is maximized. The search is locally constrained to
    tokens lying within `window` tokens of the optimal starting point.

    Args:
        start_probs: Distribution over start positions.
        end_probs: Distribution over end positions.
        window: Specifies the context size from which the optimal span
            endpoint is chosen. This hyperparameter follows directly from
            the DrQA paper (https://arxiv.org/abs/1704.00051).

    Returns:
        Optimal starting and ending indices for the answer span. Note that the
        chosen end index is *inclusive*.
    """
    max_start_index = start_probs.index(max(start_probs))
    max_end_index = -1
    max_joint_prob = 0.

    for end_index in range(len(end_probs)):
        if max_start_index <= end_index <= max_start_index + window:
            joint_prob = start_probs[max_start_index] * end_probs[end_index]
            if joint_prob > max_joint_prob:
                max_joint_prob = joint_prob
                max_end_index = end_index

    return (max_start_index, max_end_index)
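

# A minimal, self-contained sketch of how the span search behaves on toy
# distributions (the probabilities below are made up for illustration):
if __name__ == '__main__':
    toy_start_probs = [0.05, 0.70, 0.10, 0.10, 0.05]
    toy_end_probs = [0.05, 0.10, 0.15, 0.60, 0.10]
    start, end = search_span_endpoints(toy_start_probs, toy_end_probs)
    print(f'answer span: tokens {start}..{end} (inclusive)')  # prints 1..3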