forked from dmlc/dgl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
label_utils.py
88 lines (77 loc) · 3.32 KB
/
label_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from collections import defaultdict
import numpy as np
import torch
def remove_unseen_classes_from_training(train_mask, labels, removed_class):
    """Drop the unseen classes from the training mask (zero-shot label setting).

    Every entry of ``train_mask`` that is set and whose node label belongs to
    ``removed_class`` is cleared; all other entries are left untouched.

    Input:
        train_mask: mask tensor over the nodes (entries equal to 1/True mark
            training nodes); it is not modified.
        labels: tensor holding one class id per node, aligned index-by-index
            with ``train_mask``.
        removed_class: iterable of class ids to treat as unseen. There is no
            default; the caller decides which classes are removed.
    Output:
        train_mask_zs: a clone of ``train_mask`` (same dtype) that only keeps
            nodes of seen classes.
    """
    train_mask_zs = train_mask.clone()

    # Flag every node whose label is one of the removed classes. Comparing the
    # whole label tensor per class avoids a Python-level loop over the nodes
    # with one ``.item()`` call each.
    is_removed = torch.zeros_like(labels, dtype=torch.bool)
    for unseen_class in removed_class:
        is_removed |= labels == unseen_class

    # Only entries that are currently set are cleared; the flag tensor is moved
    # to the mask's device in case labels and mask live on different devices.
    is_removed = is_removed.to(train_mask_zs.device)
    train_mask_zs[(train_mask_zs == 1) & is_removed] = 0
    return train_mask_zs
def get_class_set(labels):
    """Collect the distinct class ids that occur in a nested label list.

    Input: labels [l, [c1, c2, ..]] -- one iterable of class ids per node;
        every id is converted with ``int``.
    Output: the labeled class set as a ``dict_keys`` view ([k1, k2, ..]),
        in order of first appearance.
    """
    # Flatten the per-node label lists into one stream of integer class ids.
    class_ids = (int(label) for node_labels in labels for label in node_labels)
    # A dict keeps one key per distinct id and preserves insertion order.
    return dict.fromkeys(class_ids, 1).keys()
def get_label_attributes(train_mask_zs, nodeids, labellist, features):
    """Get the class center (semantic knowledge) of each seen class.

    If node i is labeled as c, its feature row contributes to class c; the
    class center is the mean of all feature rows contributing to that class.

    Input:
        train_mask_zs: not used by this function; kept so existing callers
            that pass it keep working.
        nodeids: iterable of node ids (anything ``int`` accepts); each id is
            used as a row index into ``features``.
        labellist: iterable aligned with ``nodeids``; entry i is the iterable
            of class ids assigned to node i (multiple labels are supported).
        features: 2-D array indexed as ``features[rows, :]`` and averaged with
            ``np.mean`` (the caller in this file passes a NumPy array).
    Output:
        label_attribute{}: int label -> average feature vector of the nodes
            labeled with it (class centers).
    """
    # Group the labeled node ids by class.
    label_attribute_nodes = defaultdict(list)
    for nodeid, node_labels in zip(nodeids, labellist):
        for label in node_labels:
            label_attribute_nodes[int(label)].append(int(nodeid))

    # A class center is the column-wise mean over the rows of its nodes.
    return {
        label: np.mean(features[nodes, :], axis=0)
        for label, nodes in label_attribute_nodes.items()
    }
def get_labeled_nodes_label_attribute(train_mask_zs, labels, features, cuda):
    """Replace the original labels by their class centers.

    For each node selected by ``train_mask_zs`` with label c, the class
    centers are obtained through ``get_label_attributes`` and the node's
    output row is set to ``label_attribute[c]``.

    Input:
        train_mask_zs: mask tensor selecting the labeled training nodes.
        labels: tensor with one class id per node; ids are converted with
            ``int``.
        features: 2-D feature tensor with one row per node.
        cuda: if truthy, the result is moved to the GPU with ``.cuda()``.
    Output:
        Y_{semantic} [l, ft]: float tensor with one row per selected node,
            holding the class center of that node's label.
    """
    # The index range is created on the CPU, so the mask (and the labels) are
    # brought to the CPU before indexing instead of relying on them already
    # sharing a device.
    mask_cpu = train_mask_zs.cpu()
    nodeids = torch.arange(features.shape[0])[mask_cpu].tolist()
    node_labels = [int(c) for c in labels.cpu()[mask_cpu].numpy().tolist()]
    # get_label_attributes expects one list of labels per node.
    labellist = [[c] for c in node_labels]

    # 1. get the semantic knowledge (class centers) of all seen classes
    label_attribute = get_label_attributes(
        train_mask_zs=train_mask_zs,
        nodeids=nodeids,
        labellist=labellist,
        features=features.cpu().numpy(),
    )

    # 2. replace original labels by their class centers (semantic knowledge)
    feat_dim = features.shape[1]
    res = np.zeros([len(nodeids), feat_dim])
    for row, node_classes in enumerate(labellist):
        # support multiple labels: average the centers of all classes of a node
        centers = np.zeros([len(node_classes), feat_dim])
        for k, class_id in enumerate(node_classes):
            centers[k, :] = label_attribute[class_id]
        res[row, :] = np.mean(centers, axis=0)

    res = torch.FloatTensor(res)
    if cuda:
        res = res.cuda()
    return res