"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import copy
import os
import time
from pathlib import Path

import submitit

from ocpmodels.common import distutils
from ocpmodels.common.flags import flags
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import (
    build_config,
    create_grid,
    save_experiment_log,
    setup_imports,
)
from ocpmodels.trainers import ForcesTrainer
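

# Runner is the unit of work that submitit pickles and executes on a SLURM
# worker. Subclassing submitit.helpers.Checkpointable means submitit calls
# checkpoint() below on preemption or timeout, so the job can requeue itself
# and resume from the latest saved checkpoint.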
class Runner(submitit.helpers.Checkpointable):
    def __init__(self):
        self.config = None
        self.chkpt_path = None
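
    # `args` used below is the argparse namespace defined at module level in
    # the __main__ block; __call__ reads it as a global.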
    def __call__(self, config):
        self.config = copy.deepcopy(config)

        if args.distributed:
            distutils.setup(config)

        try:
            setup_imports()
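            # Look up the trainer class registered under config["trainer"]
            # (defaulting to "simple") and instantiate it from the parsed
            # config.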
            trainer = registry.get_trainer_class(
                config.get("trainer", "simple")
            )(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier=config["identifier"],
                run_dir=config.get("run_dir", "./"),
                is_debug=config.get("is_debug", False),
                is_vis=config.get("is_vis", False),
                print_every=config.get("print_every", 10),
                seed=config.get("seed", 0),
                logger=config.get("logger", "tensorboard"),
                local_rank=config["local_rank"],
                amp=config.get("amp", False),
                cpu=config.get("cpu", False),
            )

            if config["checkpoint"] is not None:
                trainer.load_pretrained(config["checkpoint"], config["nonddp"])

            # Save checkpoint path to runner state for SLURM resubmissions.
            self.chkpt_path = os.path.join(
                trainer.config["cmd"]["checkpoint_dir"], "checkpoint.pt"
            )

            start_time = time.time()
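
            # Dispatch on the requested mode: training, test-set prediction,
            # or structure relaxations.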
            if config["mode"] == "train":
                trainer.train()
            elif config["mode"] == "predict":
                assert (
                    trainer.test_loader is not None
                ), "Test dataset is required for making predictions"
                assert config["checkpoint"]
                results_file = "predictions"
                trainer.predict(
                    trainer.test_loader,
                    results_file=results_file,
                    disable_tqdm=False,
                )
            elif config["mode"] == "run-relaxations":
                assert isinstance(
                    trainer, ForcesTrainer
                ), "Relaxations are only possible for ForcesTrainer"
                assert (
                    trainer.relax_dataset is not None
                ), "Relax dataset is required for running relaxations"
                assert config["checkpoint"]
                trainer.run_relaxations()

            distutils.synchronize()
            if distutils.is_master():
                print("Total time taken = ", time.time() - start_time)
        finally:
            if args.distributed:
                distutils.cleanup()
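
    # Called by submitit when the job is preempted or times out: requeue a
    # fresh Runner, pointing the config at the latest checkpoint if one was
    # written.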
    def checkpoint(self, *args, **kwargs):
        new_runner = Runner()
        # Guard against preemption before __call__ has set chkpt_path.
        if self.chkpt_path is not None and os.path.isfile(self.chkpt_path):
            self.config["checkpoint"] = self.chkpt_path
        return submitit.helpers.DelayedSubmission(new_runner, self.config)


if __name__ == "__main__":
    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)

    if args.submit:  # Run on cluster
        if args.sweep_yml:  # Run grid search
            configs = create_grid(config, args.sweep_yml)
        else:
            configs = [config]

        print(f"Submitting {len(configs)} jobs")
        executor = submitit.AutoExecutor(
            folder=args.logdir / "%j", slurm_max_num_timeout=3
        )
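        # submitit expands "%j" in the log folder to the SLURM job id;
        # slurm_max_num_timeout bounds how many times a timed-out job is
        # requeued via checkpoint().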
        executor.update_parameters(
            name=args.identifier,
            mem_gb=args.slurm_mem,
            timeout_min=args.slurm_timeout * 60,
            slurm_partition=args.slurm_partition,
            gpus_per_node=args.num_gpus,
            cpus_per_task=(args.num_workers + 1),
            tasks_per_node=(args.num_gpus if args.distributed else 1),
            nodes=args.num_nodes,
        )
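        # map_array submits the configs as a single SLURM job array, one job
        # per config.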
        jobs = executor.map_array(Runner(), configs)
        print("Submitted jobs:", ", ".join([job.job_id for job in jobs]))

        log_file = save_experiment_log(args, jobs, configs)
        print(f"Experiment log saved to: {log_file}")
    else:  # Run locally
        Runner()(config)
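
# Example local invocation (config path hypothetical):
#   python main.py --mode train --config-yml configs/s2ef/schnet.yml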