forked from wxwilcke/mrgcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmkdataset.py
67 lines (51 loc) · 2.33 KB
/
mkdataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#!/usr/bin/python3
import logging
import argparse
from time import time
import toml
from data.io.knowledge_graph import KnowledgeGraph
from data.io.tarball import Tarball
from data.utils import is_readable, is_writable
from embeddings import graph_structure
from tasks.node_classification import build_dataset
from tasks.utils import strip_graph
def run(args, config):
    """Generate the data structures that make up the dataset.

    Opens the knowledge graph located at ``config['graph']['file']`` and
    feeds it, in order, to ``strip_graph``, ``graph_structure.generate``
    and ``build_dataset``.

    Args:
        args: Parsed command-line arguments. Not used by this function;
            kept so the call signature stays stable.
        config: Configuration mapping. Must provide
            ``config['graph']['file']``; it is also handed unchanged to
            the helper functions.

    Returns:
        The tuple ``(A, X, Y, X_node_map)`` where ``A`` comes from
        ``graph_structure.generate`` and the remaining three values from
        ``build_dataset``.
    """
    # Fetch the logger here rather than relying on the global ``logger``
    # name: that name is only bound when this file is executed as a
    # script, so an imported ``run`` would otherwise raise NameError.
    # ``getLogger`` returns the same object for the same name, so script
    # behaviour is unchanged.
    logging.getLogger(__name__).info("Generating data structures")

    with KnowledgeGraph(path=config['graph']['file']) as kg:
        # NOTE(review): strip_graph presumably modifies ``kg`` before the
        # structure is generated, so the call order is kept as-is -- confirm.
        targets = strip_graph(kg, config)
        A = graph_structure.generate(kg, config)
        X, Y, X_node_map = build_dataset(kg, targets, config)

    return (A, X, Y, X_node_map)
def set_logging(args, timestamp):
    """Configure the root logger to write to ``<log_directory>/<timestamp>.log``.

    Nothing is configured when ``is_writable(args.log_directory)`` is
    falsy; in that case the verbose flag is ignored as well.

    Args:
        args: Parsed command-line arguments; ``log_directory`` and
            ``verbose`` are read.
        timestamp: Value used as the log file's base name.
    """
    directory = args.log_directory
    if is_writable(directory):
        # Only insert a separator when the directory does not already end in one.
        separator = "" if directory.endswith("/") else "/"
        logfile = "{}{}{}.log".format(directory, separator, timestamp)

        logging.basicConfig(
            filename=logfile,
            format='%(asctime)s %(levelname)s [%(module)s/%(funcName)s]: %(message)s',
            level=logging.INFO,
        )

        # Verbose mode mirrors the log records on a stream handler as well.
        if args.verbose:
            root_logger = logging.getLogger()
            root_logger.addHandler(logging.StreamHandler())
if __name__ == "__main__":
    # Taken once up front: used for both the default output file name and
    # the log file name.
    timestamp = int(time())

    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="Configuration file (toml)", required=True, default=None)
    parser.add_argument("-o", "--output", help="Output file (tar)", default=None)
    parser.add_argument("-v", "--verbose", help="Increase output verbosity", action="store_true")
    parser.add_argument("--log_directory", help="Where to save the log file", default="../log/")
    args = parser.parse_args()

    # Validate with explicit checks instead of ``assert``: assertions are
    # removed when Python runs with -O, which would let bad paths through.
    # ``parser.error`` prints the usage text plus the message and exits.
    if not is_readable(args.config):
        parser.error("Configuration file is not readable: {}".format(args.config))
    config = toml.load(args.config)

    # Without an explicit output path, derive one from the configured name
    # and the timestamp, placed in the current directory.
    if args.output is None:
        args.output = './' + config['name'] + '{}.tar'.format(timestamp)
    if not is_writable(args.output):
        parser.error("Output file is not writable: {}".format(args.output))

    set_logging(args, timestamp)
    # Bound at module level on purpose: other code in this file may refer
    # to the global ``logger`` name.
    logger = logging.getLogger(__name__)
    logger.info("Arguments:\n{}".format(
        "\n".join(["\t{}: {}".format(arg, getattr(args, arg)) for arg in vars(args)])))
    logger.info("Configuration:\n{}".format(
        "\n".join(["\t{}: {}".format(k,v) for k,v in config.items()])))

    # Store the four generated structures in the tarball under fixed names.
    with Tarball(args.output, 'w') as tb:
        tb.store(run(args, config), names=['A', 'X', 'Y', 'X_node_map'])

    # Flush and close all logging handlers before the interpreter exits.
    logging.shutdown()