-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathmain.py
33 lines (23 loc) · 1.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import argparse
def _resolve_components(kind):
    """Return the ``(Parser, Trainer)`` classes that belong to *kind*.

    The project modules are imported inside each branch, so only the pair
    for the requested work type is ever imported.

    Args:
        kind: Work type name taken from the ``--type`` option.

    Returns:
        A 2-tuple ``(Parser, Trainer)`` of classes.

    Raises:
        ValueError: If *kind* is not one of the four names handled below.
    """
    if kind == 'classification_TU':
        from parsers.classification_TU import Parser
        from trainers.trainer_classification_TU import Trainer
        return Parser, Trainer

    if kind == 'classification_OGB':
        from parsers.classification_OGB import Parser
        from trainers.trainer_classification_OGB import Trainer
        return Parser, Trainer

    if kind == 'reconstruction_ZINC':
        from parsers.reconstruction_ZINC import Parser
        from trainers.trainer_reconstruction_ZINC import Trainer
        return Parser, Trainer

    if kind == 'reconstruction_synthetic':
        from parsers.reconstruction_synthetic import Parser
        from trainers.trainer_reconstruction_synthetic import Trainer
        return Parser, Trainer

    # Nothing matched: report the offending name exactly as it was given.
    raise ValueError("Work Type Name <{}> is Unknown".format(kind))


def main(work_type_args):
    """Run training for the work type named by ``work_type_args.type``.

    Args:
        work_type_args: Object exposing a ``type`` attribute (for example
            the namespace produced by ``argparse``).

    Raises:
        ValueError: If ``work_type_args.type`` is not a known work type.
    """
    parser_cls, trainer_cls = _resolve_components(work_type_args.type)

    # The work-type specific parser produces the arguments for its trainer.
    run_args = parser_cls().parse()
    runner = trainer_cls(run_args)
    runner.train()
if __name__ == '__main__':
    # Script entry point: only ``--type`` is read here.
    selector = argparse.ArgumentParser()
    selector.add_argument('--type', type=str, required=True)

    # parse_known_args() tolerates options this parser does not define and
    # returns them separately; they are ignored here (presumably they are
    # picked up later by the work-type specific Parser -- confirm there).
    known_args, _unrecognised = selector.parse_known_args()
    main(known_args)