-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataloader.py
175 lines (136 loc) · 5.52 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
import torch
import numpy as np
import pandas as pd
import scipy.sparse as sp
from torch.utils.data import Dataset
class GowallaDataset(Dataset):
    """Common base for the Gowalla datasets.

    Reads ``user_list.txt`` and ``item_list.txt`` from ``path`` to determine
    the number of users and items, and declares the accessors that concrete
    datasets are expected to override.
    """

    def __init__(self, train, path='dataset'):
        """Load the user and item counts.

        :param train: truthy for the train split, falsy for the test split;
            in this class it only affects the printed message.
        :param path: directory containing ``user_list.txt`` and
            ``item_list.txt``.
        :raises ValueError: if a list file has no non-blank line or its id
            field is not an integer.
        """
        print('init ' + ('train' if train else 'test') + ' dataset')
        self.n_users_ = self._read_last_id(f'{path}/user_list.txt') + 1
        self.m_items_ = self._read_last_id(f'{path}/item_list.txt') + 1

    @staticmethod
    def _read_last_id(file_path):
        """Return the integer in the second field of the last non-blank line.

        Presumably that field is the largest remapped id in the file, which
        is why callers add one to obtain a count - confirm against the data.
        """
        # The context manager closes the handle; the original leaked it.
        with open(file_path) as handle:
            rows = [line for line in handle if line.strip()]
        if not rows:
            raise ValueError(f'{file_path} contains no data lines')
        # split() instead of slicing off the last character: a final line
        # without a trailing newline must not lose its last digit.
        return int(rows[-1].split()[1])

    def get_all_users(self):
        """Return all user ids; must be overridden by subclasses."""
        # NotImplemented is a comparison sentinel, not an exception; raising
        # it produces a TypeError instead of the intended error.
        raise NotImplementedError

    def get_user_positives(self, user):
        """Return the items ``user`` interacted with; must be overridden."""
        raise NotImplementedError

    def get_user_negatives(self, user, k):
        """Return ``k`` items ``user`` did not interact with; must be overridden."""
        raise NotImplementedError

    @property
    def n_users(self):
        """Number of users derived from ``user_list.txt``."""
        return self.n_users_

    @property
    def m_items(self):
        """Number of items derived from ``item_list.txt``."""
        return self.m_items_

    def __len__(self):
        """Return the number of users."""
        return self.n_users_

    def __getitem__(self, idx):
        """Return the sample at ``idx``; must be overridden by subclasses."""
        raise NotImplementedError
class GowallaTopNDataset(GowallaDataset):
    """Gowalla check-ins exposed as per-user lists of visited locations."""

    def __init__(self, path, train=True):
        """Read the check-in CSV at ``path`` and index locations by user.

        :param path: CSV file read with the column names ``userId``,
            ``timestamp``, ``long``, ``lat``, ``loc_id``.
        :param train: forwarded to the base class constructor.
        """
        super().__init__(train)
        column_names = ['userId', 'timestamp', 'long', 'lat', 'loc_id']
        self.df = pd.read_csv(path, names=column_names)
        self.unique_users = self.df['userId'].unique()
        locations_by_user = self.df.groupby('userId')['loc_id'].apply(list)
        self.user_positive_items = locations_by_user.to_dict()

    def get_all_users(self):
        """Return the distinct user ids found in the CSV."""
        return self.unique_users

    def get_user_positives(self, user):
        """Return the list of location ids for ``user`` (empty if unknown)."""
        return self.user_positive_items.get(user, [])

    def get_user_negatives(self, user, k=10):
        """Sample ``k`` location ids that ``user`` has not visited.

        Ids are drawn one at a time with ``np.random.randint(1, m_items)``
        and kept when they are not among the user's positives; the result may
        contain repeats.
        NOTE(review): the loop does not terminate if every id in that range
        is a positive for ``user``.
        """
        visited = set(self.get_user_positives(user))
        upper_bound = self.m_items
        sampled = []
        while len(sampled) < k:
            item = np.random.randint(1, upper_bound)
            if item in visited:
                continue
            sampled.append(item)
        return sampled
class GowallaLightGCNDataset(GowallaDataset):
    """Gowalla interactions plus a normalized user-item graph.

    Each sample is a user together with sampled positive and negative items;
    the graph is available through :meth:`get_sparse_graph`.
    """

    def __init__(self, path, train=True, n_negatives: int = 10):
        """Read the interactions and build the normalized adjacency tensor.

        :param path: CSV file read with the column names ``userId`` and
            ``loc_id``.
        :param train: forwarded to the base class constructor.
        :param n_negatives: number of positives and of negatives returned
            per sample by ``__getitem__``.
        """
        super().__init__(train)
        self.n_negatives = n_negatives

        interactions = pd.read_csv(path, names=['userId', 'loc_id'])
        user_column = interactions['userId']
        item_column = interactions['loc_id']
        self.unique_users = user_column.unique()
        self.user_positive_items = (
            interactions.groupby('userId')['loc_id'].apply(list).to_dict()
        )

        # One entry of weight 1 per interaction row.
        weights = np.ones(len(interactions), dtype=np.int32)
        row_ids = np.array(user_column.values, dtype=np.int32)
        col_ids = np.array(item_column.values, dtype=np.int32)
        del interactions

        # Users occupy node ids [0, n_users); items are shifted by n_users so
        # both share one square matrix.
        node_count = self.n_users + self.m_items
        user_to_item = sp.csr_matrix(
            (weights, (row_ids, col_ids + self.n_users)),
            shape=(node_count, node_count),
        )
        adjacency = user_to_item + user_to_item.T

        # Symmetric normalization D^-1/2 * A * D^-1/2; nodes with zero degree
        # produce inf under the negative power and are reset to 0.
        degree = np.array(adjacency.sum(1))
        inv_sqrt_degree = np.power(degree, -0.5).flatten()
        inv_sqrt_degree[np.isinf(inv_sqrt_degree)] = 0.
        scaling = sp.diags(inv_sqrt_degree)
        normalized = scaling.dot(adjacency).dot(scaling)

        # Convert the scipy matrix to a torch sparse COO tensor.
        coo = normalized.tocoo()
        index_tensor = torch.LongTensor(np.vstack((coo.row, coo.col)))
        value_tensor = torch.FloatTensor(coo.data)
        self.adj_matrix = torch.sparse_coo_tensor(
            index_tensor, value_tensor, torch.Size(coo.shape)
        )

    def get_all_users(self):
        """Return the distinct user ids found in the CSV."""
        return self.unique_users

    def get_user_positives(self, user):
        """Return the list of item ids for ``user`` (empty if unknown)."""
        return self.user_positive_items.get(user, [])

    def get_user_negatives(self, user, k=10):
        """Sample ``k`` item ids that are not positives of ``user``.

        Ids are drawn one at a time with ``np.random.randint(1, m_items)``
        and kept when they are not among the user's positives; the result may
        contain repeats.
        NOTE(review): the loop does not terminate if every id in that range
        is a positive for ``user``.
        """
        known = set(self.get_user_positives(user))
        upper_bound = self.m_items
        sampled = []
        while len(sampled) < k:
            item = np.random.randint(1, upper_bound)
            if item in known:
                continue
            sampled.append(item)
        return sampled

    def get_sparse_graph(self):
        """
        Returns the normalized graph as a torch.sparse_coo_tensor.
        Before normalization the adjacency has the block form
        A = |0, R|
        |R^T, 0|
        """
        return self.adj_matrix

    def __len__(self):
        """Return the number of distinct users in the CSV."""
        return len(self.unique_users)

    def __getitem__(self, idx):
        """
        returns user, pos_items, neg_items
        :param idx: index of user from unique_users
        :return: the user id, ``n_negatives`` positives drawn with
            ``np.random.choice``, and ``n_negatives`` sampled negatives
        """
        user = self.unique_users[idx]
        positives = np.random.choice(
            self.get_user_positives(user), self.n_negatives
        )
        negatives = self.get_user_negatives(user, self.n_negatives)
        return user, positives, negatives
class GowallaALSDataset(GowallaDataset):
    """Gowalla check-ins exposed as sparse user-item matrices."""

    def __init__(self, path, train=True):
        """Read the check-in CSV at ``path``.

        :param path: CSV file read with the column names ``userId``,
            ``timestamp``, ``long``, ``lat``, ``loc_id``.
        :param train: selects what :meth:`get_dataset` returns; also
            forwarded to the base class constructor.
        """
        super().__init__(train)
        self.path = path
        self.train = train
        # The latitude column is named 'lat' with no leading space, the same
        # name GowallaTopNDataset uses for this CSV layout.
        self.df = pd.read_csv(
            path, names=['userId', 'timestamp', 'long', 'lat', 'loc_id']
        )

    def get_dataset(self, n_users=None, m_items=None):
        """Return the data for the configured split.

        :param n_users: number of matrix rows; defaults to ``self.n_users``.
        :param m_items: number of matrix columns; defaults to ``self.m_items``.
        :return: for the test split, the raw DataFrame only; for the train
            split, a tuple ``(df, user_item, item_user)`` where ``user_item``
            is a CSR matrix holding one unit of weight per check-in row and
            ``item_user`` is its transpose in CSR form.
        """
        if not self.train:
            return self.df

        users = self.df['userId'].values
        items = self.df['loc_id'].values
        ratings = np.ones(len(users))
        if n_users is None:
            n_users = self.n_users
        if m_items is None:
            m_items = self.m_items
        user_item_data = sp.csr_matrix(
            (ratings, (users, items)), shape=(n_users, m_items)
        )
        item_user_data = user_item_data.T.tocsr()
        return self.df, user_item_data, item_user_data