forked from rujunhan/ConditionalEmbeddings
datastream.py
import numpy as np
import torch
from torch.autograd import Variable
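
# NB: torch.autograd.Variable has been a no-op alias for Tensor since
# PyTorch 0.4; it is kept in this file only to mirror the original code.
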
class load_data:
    """Streams skip-gram training batches from a tab-separated corpus file
    whose lines have the form "label<TAB>space-tokenized text"."""

    def __init__(self, args):
        # Count lines once so data_size reflects the whole corpus.
        with open(args.source_file, "r") as f:
            self.data_size = len(f.readlines())
        self.batch_size = args.batch
        self.label_map = args.label_map
        # vocab is a {word: index} dict saved with np.save.
        self.vocab = np.load(str(args.vocab), allow_pickle=True).item()
        self.source = open(args.source_file, "r")
        self.end_of_data = False
        self.skips = args.skips  # context window size on each side of the center word
        self.labels = args.label_map.keys()
        self.count = 0

    def reset(self):
        # Rewind for the next epoch, then end this one: the StopIteration
        # raised here propagates out of __next__ and terminates the for-loop
        # driving the iterator.
        self.source.seek(0)
        self.count = 0
        raise StopIteration

    def __iter__(self):
        return self
    def __next__(self):
        data = []
        count = 0
        while True:
            line = self.source.readline()
            if line == "":
                print("end of file!")
                self.reset()  # rewinds the file and raises StopIteration
            self.count += 1
            line = line.strip().split("\t")
            label = line[0]
            if len(line) < 2:
                # Skip lines with no text field.
                continue
            text = line[1]
            text_list = text.split(" ")
            if len(text_list) < self.skips * 2 + 1:
                # Too short to fit a single full context window.
                continue
            elif label in self.labels:
                count += 1
                # Slide a skip-gram window over the sentence. Example with
                # skips = 2 and tokens ["a", "b", "c", "d", "e"]: the only
                # valid center is "c", giving (label, "c", ["a", "b", "d", "e"]).
                for i in range(self.skips, len(text_list) - self.skips):
                    out_text = (
                        text_list[i - self.skips : i]
                        + text_list[i + 1 : i + self.skips + 1]
                    )
                    in_text = text_list[i]
                    data.append((label, in_text, out_text))
            if count >= self.batch_size:
                break
        # Temporary kludge to avoid memory issues
        data = data[:10_000]
        in_idxs, out_idxs, covars = self.create_batch(data, self.vocab)
        return in_idxs, out_idxs, covars
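
    # Shapes produced by create_batch for a raw batch of N triples, each
    # context holding exactly 2 * skips words:
    #   in_idxs:  (N, 1)          center-word indices
    #   out_idxs: (N, 2 * skips)  context-word indices
    #   cvrs:     (N, 1)          covariate (label) ids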
    def create_batch(self, raw_batch, vocab):
        # Unzip the (label, center, context) triples into parallel tuples.
        all_txt = list(zip(*raw_batch))
        # Center-word indices. Words missing from vocab raise KeyError, so the
        # corpus is assumed to be pre-filtered to the vocabulary.
        idxs = list(map(lambda w: vocab[w], all_txt[1]))
        in_idxs = Variable(
            torch.LongTensor(idxs).view(len(raw_batch), 1), requires_grad=False
        )
        # Context-word indices.
        idxs = list(map(lambda output: [vocab[w] for w in output], all_txt[2]))
        out_idxs = Variable(
            torch.LongTensor(idxs).view(len(raw_batch), -1), requires_grad=False
        )
        # Map each label to its covariate id.
        cvrs = list(map(lambda c: self.label_map[c], all_txt[0]))
        cvrs = Variable(torch.LongTensor(cvrs).view(len(raw_batch), -1))
        return in_idxs, out_idxs, cvrs
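

if __name__ == "__main__":
    # Usage sketch, not part of the original module: a minimal illustration of
    # driving the iterator for several epochs. The args fields mirror the ones
    # read in __init__; the paths and label names are hypothetical placeholders.
    from argparse import Namespace

    args = Namespace(
        source_file="train.tsv",  # lines of "label<TAB>space-tokenized text"
        batch=64,                 # sentences consumed per batch
        label_map={"group_a": 0, "group_b": 1},
        vocab="vocab.npy",        # {word: index} dict saved with np.save
        skips=5,                  # context window on each side of the center word
    )

    stream = load_data(args)
    for epoch in range(3):
        for in_idxs, out_idxs, covars in stream:
            pass  # feed the batch to the model here
        # reset() has already rewound the file, so the next epoch restarts cleanly.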