-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathload_data.py
70 lines (64 loc) · 2.32 KB
/
load_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import scipy.io
import numpy as np
import scipy.sparse
import torch
import csv
import json
from os import path
# Absolute path (with trailing slash) of the ``data/`` directory located
# next to this module; built from this file's own location, so it does not
# depend on the current working directory.
DATAPATH = path.dirname(path.abspath(__file__)) + '/data/'
def load_fb100(filename):
    """Load one Facebook100 network from its MATLAB ``.mat`` file.

    ``filename`` is the school name without extension, e.g. ``Rutgers89``,
    ``Cornell5``, ``Wisconsin87`` or ``Amherst41``. The file is read from
    ``DATAPATH + 'facebook100/'``.

    Returns a pair ``(A, metadata)`` holding the ``'A'`` and ``'local_info'``
    entries of the loaded file.

    The metadata columns are: student/faculty, gender, major,
    second major/minor, dorm/house, year / high school.
    A value of 0 denotes a missing entry.
    """
    mat_file = DATAPATH + 'facebook100/' + filename + '.mat'
    contents = scipy.io.loadmat(mat_file)
    return contents['A'], contents['local_info']
def load_twitch(lang):
    """Load one of the Twitch social-network datasets.

    Reads three files from ``data/twitch/<lang>/``:
    ``musae_<lang>_target.csv``, ``musae_<lang>_edges.csv`` and
    ``musae_<lang>_features.json``.

    Args:
        lang: Dataset code; one of 'DE', 'ENGB', 'ES', 'FR', 'PTBR',
            'RU', 'TW'.

    Returns:
        A tuple ``(A, label, features)``:
        * ``A`` -- ``n x n`` ``scipy.sparse.csr_matrix`` built from the
          (source, target) pairs of the edge file, with weight 1 per row.
        * ``label`` -- length-``n`` integer array; entry ``i`` is 1 when
          column 2 of the target row for node ``i`` equals ``"True"``.
        * ``features`` -- ``n x k`` 0/1 matrix of node features with
          all-zero columns removed.

    Raises:
        AssertionError: if ``lang`` is not a known dataset code.
        KeyError: if the node ids in the target file do not cover
            ``0 .. n-1``.
    """
    assert lang in ('DE', 'ENGB', 'ES', 'FR', 'PTBR', 'RU', 'TW'), 'Invalid dataset'
    # NOTE: this path is relative to the current working directory (it does
    # not go through DATAPATH); kept as-is so existing callers see the same
    # behaviour.
    filepath = f"data/twitch/{lang}"
    # Width of the feature matrix before empty columns are dropped.
    num_raw_features = 3170

    label = []
    node_ids = []
    seen_ids = set()
    with open(f"{filepath}/musae_{lang}_target.csv", 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            node_id = int(row[5])
            # handle FR case of non-unique rows: keep first occurrence only
            if node_id not in seen_ids:
                seen_ids.add(node_id)
                label.append(int(row[2] == "True"))
                node_ids.append(node_id)

    src = []
    targ = []
    with open(f"{filepath}/musae_{lang}_edges.csv", 'r') as f:
        reader = csv.reader(f)
        next(reader)  # skip header row
        for row in reader:
            src.append(int(row[0]))
            targ.append(int(row[1]))

    with open(f"{filepath}/musae_{lang}_features.json", 'r') as f:
        node_feats = json.load(f)

    src = np.array(src)
    targ = np.array(targ)
    label = np.array(label)
    n = label.shape[0]

    # Labels were collected in file order; position_of[i] is the row at
    # which node id i appeared, so indexing with it sorts labels by node id.
    # (Built with the builtin ``int`` dtype: the ``np.int`` alias was
    # removed from NumPy.)
    position_of = {node_id: idx for idx, node_id in enumerate(node_ids)}
    reorder_node_ids = np.array([position_of[i] for i in range(n)], dtype=int)

    A = scipy.sparse.csr_matrix((np.ones(len(src)), (src, targ)),
                                shape=(n, n))

    features = np.zeros((n, num_raw_features))
    for node, feats in node_feats.items():
        node_idx = int(node)
        # ignore feature entries for nodes without a label row
        if node_idx >= n:
            continue
        features[node_idx, np.array(feats, dtype=int)] = 1
    features = features[:, np.sum(features, axis=0) != 0]  # remove zero cols

    label = label[reorder_node_ids]
    return A, label, features