-
Notifications
You must be signed in to change notification settings - Fork 53
/
Copy pathcopyGraph.py
73 lines (61 loc) · 2.74 KB
/
copyGraph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import print_function
import os
import h5py
import numpy as np
import argparse
import scipy.io as sio
from config import get_data_dir
# python 3 compatibility
try:
import cPickle as pickle
except ImportError:
import pickle
# Note that just like in RCC & RCC-DR, the graph is built on original data.
# Once the features are extracted from the pretrained SDAE,
# they are merged along with the mkNN graph data into a single file using this module.
# Command-line interface. main() joins the three file arguments onto the
# dataset directory returned by get_data_dir(), and cannot run without them,
# so argparse enforces them up front instead of letting main() fail later.
parser = argparse.ArgumentParser(
    description='This module is used to merge graph and extracted features into single file')
parser.add_argument('--data', dest='db', type=str, default='mnist', help='name of the dataset')
parser.add_argument('--graph', dest='g', type=str, required=True,
                    help='path to the graph file, relative to the dataset directory')
parser.add_argument('--features', dest='feat', type=str, required=True,
                    help='path to the feature file, relative to the dataset directory')
parser.add_argument('--out', dest='out', type=str, required=True,
                    help='path to the output file, relative to the dataset directory '
                         '(".h5" or ".mat" is appended)')
parser.add_argument('--h5', dest='h5', action='store_true', help='to store as h5py file')
def merge_features_and_graph(features, graph):
    """Check that two sources hold the same samples and merge them into one dict.

    ``features`` must provide the keys 'data', 'labels' and 'Z'; ``graph``
    must provide 'X', 'gtlabels' and 'w'. Each value may be a numpy array or
    any object whose ``[:]`` slice yields one (e.g. an open h5py dataset).

    Returns a dict with keys 'gtlabels', 'X', 'Z', 'w'. 'X', 'Z' and 'w' are
    cast to float32; 'gtlabels' is taken from the feature source unchanged.

    Raises ValueError if the raw data in ``features['data']`` and
    ``graph['X']`` differ (in shape or in any value, NaN counting as a
    mismatch).
    """
    labels = features['labels'][:]
    x_feat = features['data'][:].astype(np.float32)
    x_graph = graph['X'][:].astype(np.float32)
    n_feat = len(labels)
    # Slice with [:] before transposing: h5py datasets have no ``.T``.
    # For .mat input gtlabels is typically a (1, N) row, whose transpose has
    # length N.
    n_graph = len(graph['gtlabels'][:].T)
    # The graph is built on the original data, so both files must contain the
    # same samples in the same order; compare them one flattened row per sample.
    if not np.array_equal(x_feat.reshape((n_feat, -1)), x_graph.reshape((n_graph, -1))):
        raise ValueError('feature file and graph file contain different data')
    return {
        'gtlabels': labels,
        'X': x_feat,
        'Z': features['Z'][:].astype(np.float32),
        'w': graph['w'][:].astype(np.float32),
    }


def main(args):
    """Merge the extracted features and the mkNN graph of a dataset into one file.

    File arguments are resolved against ``get_data_dir(args.db)``.

    Args:
        args: namespace with ``db`` (dataset name), ``feat`` (feature file),
            ``g`` (graph file), ``out`` (output path without extension) and
            ``h5`` (bool). With ``h5`` set, both inputs are read as HDF5 and
            ``<out>.h5`` is written. Otherwise the features are unpickled, the
            graph is read with ``scipy.io.loadmat``, and ``<out>.mat`` is written.

    Returns:
        The merged dict with keys 'gtlabels', 'X', 'Z', 'w'.

    Raises:
        ValueError: a file argument is None, or the two inputs hold different data.
        FileNotFoundError: the feature file or the graph file does not exist.
    """
    for flag, value in (('--features', args.feat), ('--graph', args.g), ('--out', args.out)):
        if value is None:
            raise ValueError('{} must be given'.format(flag))

    datadir = get_data_dir(args.db)
    featurefile = os.path.join(datadir, args.feat)
    graphfile = os.path.join(datadir, args.g)
    outputfile = os.path.join(datadir, args.out)

    missing = [path for path in (featurefile, graphfile) if not os.path.isfile(path)]
    if missing:
        print('one or both the files not found')
        raise FileNotFoundError('not found: ' + ', '.join(missing))

    if args.h5:
        # Load and validate with the inputs open, close them, and only then
        # create the output, so a failed check never leaves a truncated file.
        with h5py.File(featurefile, 'r') as data0, h5py.File(graphfile, 'r') as data1:
            joined_data = merge_features_and_graph(data0, data1)
        with h5py.File(outputfile + '.h5', 'w') as data2:
            for key, value in joined_data.items():
                data2.create_dataset(key, data=value)
    else:
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # feature files from a trusted source.
        with open(featurefile, 'rb') as fo:
            data0 = pickle.load(fo)
        data1 = sio.loadmat(graphfile)
        joined_data = merge_features_and_graph(data0, data1)
        sio.savemat(outputfile + '.mat', joined_data)
    return joined_data
if __name__ == '__main__':
    # Script entry point: read the command line and run the merge.
    main(parser.parse_args())