forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
35 lines (29 loc) · 789 Bytes
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
def shuffle_walks(walks):
    """Return ``walks`` with its first-dimension entries in a random order.

    A random permutation of ``range(walks.size()[0])`` is drawn with
    ``torch.randperm`` (so the result depends on torch's global RNG state)
    and used to index ``walks``.

    Args:
        walks: a torch tensor with at least one dimension; presumably one
            random walk per row — TODO confirm against callers.

    Returns:
        ``walks`` indexed by the random permutation.
    """
    num_walks = walks.size()[0]
    permutation = torch.randperm(num_walks)
    shuffled = walks[permutation]
    return shuffled
def sum_up_params(model, *, exit_after=True):
    """Count the model parameters, print the total and, by default, exit.

    The total is printed as ``#params <N>``.

    Args:
        model: object exposing the tensors ``u_embeddings.weight``,
            ``lookup_table``, ``index_emb_posu`` and ``grad_u`` (required),
            and optionally ``index_emb_negu``, ``state_sum_u``, ``grad_avg``
            and ``context_weight``. Optional attributes that are missing or
            ``None`` are skipped.
        exit_after: keyword-only. When True (the default, preserving the
            original behavior) the process is terminated with ``sys.exit()``
            after printing. When False, the count is returned instead.

    Returns:
        The total count (only reached when ``exit_after`` is False).

    Raises:
        AttributeError: if one of the required attributes is missing.
        SystemExit: after printing, when ``exit_after`` is True.
    """
    import sys

    # numel() works on tensors of any device, so no .cpu() copy is needed;
    # copying large GPU-resident tables to host just to read their size is
    # wasted time and host memory.
    # NOTE(review): the *2 factors are kept from the original; they presumably
    # account for a same-sized counterpart (e.g. a v-side table) that is not
    # stored as its own attribute here — confirm.
    total = model.u_embeddings.weight.numel() * 2
    total += model.lookup_table.numel()
    total += model.index_emb_posu.numel() * 2
    total += model.grad_u.numel() * 2

    # (attribute name, multiplier) for tensors only some models carry.
    optional_attrs = (
        ("index_emb_negu", 2),
        ("state_sum_u", 2),
        ("grad_avg", 1),
        ("context_weight", 1),
    )
    for name, factor in optional_attrs:
        # Explicit presence check instead of a bare `except:` that would also
        # swallow KeyboardInterrupt/SystemExit and real bugs.
        tensor = getattr(model, name, None)
        if tensor is not None:
            total += tensor.numel() * factor

    print("#params " + str(total))
    if exit_after:
        # sys.exit() instead of the site-module `exit()` helper, which is not
        # guaranteed to exist outside the interactive interpreter.
        sys.exit()
    return total