# main.py (forked from nathanwispinski/meta-rl)
"""Launch distributed training or testing."""
import sys
from absl import app, flags
from ml_collections.config_flags import config_flags
import numpy as np
import jax
import modules.managers as managers
import modules.workers as workers
import modules.evaluators as evaluators
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file('config') # NOTE: this is prod
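
# Example invocation (the config path is illustrative; this repo's actual config
# filename may differ):
#   python main.py --config=path/to/config.py
# The config file is expected to define get_config() returning an ml_collections
# ConfigDict with at least the fields read in main() (e.g. `phase`, `random_seed`).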


def main(_):
    """Launch distributed training or testing.

    Launches one manager, some number of workers, and some number of evaluators.
    Workers take a copy of the agent and run it through their own environment.
    Workers then compute the gradient on their experience and send gradients to the manager.
    Manager updates the global agent parameters with each worker's gradients.
    Evaluators take a periodic copy of the agent and evaluate agent performance.
    """
    jax.config.update('jax_platform_name', 'cpu')  # Make sure main() runs on CPU

    config = FLAGS.config
    phase = config.phase

    # Random seeds
    random_seed = config.random_seed
    np.random.seed(random_seed)
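
    # The manager owns the global agent parameters and applies the gradients it
    # receives from workers (see the module docstring above).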
    started_manager, manager, config = managers.create_manager(
        manager_type=phase,
        config=config
    )
    all_workers = workers.create_workers(
        worker_type=phase,
        config=config,
        manager=manager,
    )
    all_evaluators = evaluators.create_evaluators(
        evaluator_type=phase,
        config=config,
        manager=manager,
    )
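
    # Run workers and evaluators as subprocesses; block until all of them finish.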
    all_subprocesses = all_workers + all_evaluators
    for subprocess in all_subprocesses:
        subprocess.start()
    for subprocess in all_subprocesses:
        subprocess.join()

    # All workers done, save and shutdown manager
    started_manager.shutdown()
    sys.exit()
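

# --- Illustrative sketch (not part of the original script) --------------------
# The module docstring describes an A3C-style update: each worker computes
# gradients on its own experience and the manager folds them into the global
# agent parameters. The helper below is a minimal, hypothetical illustration of
# that gradient flow in plain JAX. The names (`global_params`, `worker_batch`)
# and the toy linear model are assumptions for illustration only; the real logic
# lives in modules.managers and modules.workers, whose APIs are not shown here.
def _sketch_worker_gradient_update(global_params, worker_batch, learning_rate=1e-3):
    """Toy example: one worker gradient step applied to the global parameters."""
    import jax.numpy as jnp  # Local import; only needed for this sketch.

    def loss_fn(params, batch):
        inputs, targets = batch
        preds = inputs @ params['w'] + params['b']  # Toy linear model.
        return jnp.mean((preds - targets) ** 2)

    # Worker side: compute gradients on a local copy of the agent parameters.
    grads = jax.grad(loss_fn)(global_params, worker_batch)

    # Manager side: apply the worker's gradients to the global parameters (plain SGD).
    return jax.tree_util.tree_map(
        lambda p, g: p - learning_rate * g, global_params, grads
    )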


if __name__ == "__main__":
    app.run(main)