From 2a231ef91a2e38e7a6ae4758601d27a758a3451c Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 24 Oct 2024 22:36:08 +0000
Subject: [PATCH 1/7] add wandb logger init to hydra runners

---
 src/fairchem/core/_cli_hydra.py | 51 ++++++++++++++++++++++++---------
 1 file changed, 38 insertions(+), 13 deletions(-)

diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py
index 79992616d..b630ad405 100644
--- a/src/fairchem/core/_cli_hydra.py
+++ b/src/fairchem/core/_cli_hydra.py
@@ -12,6 +12,7 @@
 from typing import TYPE_CHECKING

 import hydra
+from omegaconf import OmegaConf

 if TYPE_CHECKING:
     import argparse
@@ -32,14 +33,32 @@ class Submitit(Checkpointable):
-    def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None:
+    def __call__(self, dict_config: DictConfig) -> None:
         self.config = dict_config
-        self.cli_args = cli_args
         # TODO: setup_imports is not needed if we stop instantiating models with Registry.
         setup_imports()
         setup_env_vars()
         try:
-            distutils.setup(map_cli_args_to_dist_config(cli_args))
+            distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
+            # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
+            # don't start logger if in debug mode
+            if (
+                "logger" in dict_config
+                and distutils.is_master()
+                and not dict_config.cli_args.debug
+            ):
+                # get a partial function from the config and instantiate wandb with it
+                logger_initializer = hydra.utils.instantiate(dict_config.logger)
+                simple_config = OmegaConf.to_container(
+                    self.config, resolve=True, throw_on_missing=True
+                )
+                logger_initializer(
+                    config=simple_config,
+                    run_id=self.config.cli_args.timestamp_id,
+                    run_name=self.config.cli_args.identifier,
+                    log_dir=self.config.cli_args.logdir,
+                )
+
             runner: Runner = hydra.utils.instantiate(dict_config.runner)
             runner.load_state()
             runner.run()
@@ -47,6 +66,7 @@ def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> Non
             distutils.cleanup()

     def checkpoint(self, *args, **kwargs):
+        # TODO: this is yet to be tested properly
         logging.info("Submitit checkpointing callback is triggered")
         new_runner = Runner()
         new_runner.save_state()
@@ -54,13 +74,13 @@
         return DelayedSubmission(new_runner, self.config)


-def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
+def map_cli_args_to_dist_config(cli_args: dict) -> dict:
     return {
-        "world_size": cli_args.num_nodes * cli_args.num_gpus,
-        "distributed_backend": "gloo" if cli_args.cpu else "nccl",
-        "submit": cli_args.submit,
+        "world_size": cli_args["num_nodes"] * cli_args["num_gpus"],
+        "distributed_backend": "gloo" if cli_args["cpu"] else "nccl",
+        "submit": cli_args["submit"],
         "summit": None,
-        "cpu": cli_args.cpu,
+        "cpu": cli_args["cpu"],
         "use_cuda_visibile_devices": True,
     }
@@ -76,8 +96,8 @@ def get_hydra_config_from_yaml(
     return hydra.compose(config_name=config_name, overrides=overrides_args)


-def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace):
-    Submitit()(config, cli_args)
+def runner_wrapper(config: DictConfig):
+    Submitit()(config)


 # this is meant as a future replacement for the main entrypoint
@@ -91,6 +111,11 @@ def main(
     cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
     timestamp_id = get_timestamp_uid()
     log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
+    # override timestamp id and logdir
+    args.timestamp_id = timestamp_id
+    args.logdir = log_dir
+    os.makedirs(log_dir)
+    OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)
     if args.submit:  # Run on cluster
         executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
         executor.update_parameters(
@@ -105,7 +130,7 @@
             slurm_qos=args.slurm_qos,
             slurm_account=args.slurm_account,
         )
-        job = executor.submit(runner_wrapper, cfg, args)
+        job = executor.submit(runner_wrapper, cfg)
         logger.info(
             f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}"
         )
@@ -119,8 +144,8 @@
                 rdzv_backend="c10d",
                 max_restarts=0,
             )
-            elastic_launch(launch_config, runner_wrapper)(cfg, args)
+            elastic_launch(launch_config, runner_wrapper)(cfg)
         else:
             logger.info("Running in local mode without elastic launch")
             distutils.setup_env_local()
-            runner_wrapper(cfg, args)
+            runner_wrapper(cfg)
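The wiring added in this patch only assumes that `dict_config.logger` resolves, via `hydra.utils.instantiate`, to a callable accepting the keyword arguments `config`, `run_id`, `run_name`, and `log_dir`. A minimal sketch of a compatible logger class and config entry follows; the class name `ExampleWandBLogger` and the config path are illustrative assumptions, not the actual fairchem API.

```python
# Sketch only: a logger whose signature matches the call made by Submitit above.
# ExampleWandBLogger is a hypothetical name; the real fairchem wandb logger may differ.
from __future__ import annotations

from typing import Any


class ExampleWandBLogger:
    def __init__(
        self,
        config: dict[str, Any],
        run_id: str,
        run_name: str | None,
        log_dir: str,
    ) -> None:
        # A real implementation would call wandb.init(...) here; this sketch only
        # records the arguments so the calling convention stays visible.
        self.config = config
        self.run_id = run_id
        self.run_name = run_name
        self.log_dir = log_dir


# A matching `logger` entry in the experiment YAML could then look roughly like:
#
#   logger:
#     _target_: mypackage.loggers.ExampleWandBLogger
#     _partial_: true
#
# With `_partial_: true` (supported in recent hydra versions),
# hydra.utils.instantiate(dict_config.logger) returns a functools.partial that
# Submitit completes with config, run_id, run_name, and log_dir at runtime.
```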
From 568a8e29525b418045972212f285e8c8ece55a6e Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 24 Oct 2024 22:40:52 +0000
Subject: [PATCH 2/7] update to reading dict vars

---
 src/fairchem/core/_cli_hydra.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py
index b630ad405..1e64b2a6e 100644
--- a/src/fairchem/core/_cli_hydra.py
+++ b/src/fairchem/core/_cli_hydra.py
@@ -38,6 +38,9 @@ def __call__(self, dict_config: DictConfig) -> None:
         # TODO: setup_imports is not needed if we stop instantiating models with Registry.
         setup_imports()
         setup_env_vars()
+        import pdb
+
+        pdb.set_trace()
         try:
             distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
             # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
@@ -74,13 +77,13 @@
-def map_cli_args_to_dist_config(cli_args: dict) -> dict:
+def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict:
     return {
-        "world_size": cli_args["num_nodes"] * cli_args["num_gpus"],
-        "distributed_backend": "gloo" if cli_args["cpu"] else "nccl",
-        "submit": cli_args["submit"],
+        "world_size": cli_args.num_nodes * cli_args.num_gpus,
+        "distributed_backend": "gloo" if cli_args.cpu else "nccl",
+        "submit": cli_args.submit,
         "summit": None,
-        "cpu": cli_args["cpu"],
+        "cpu": cli_args.cpu,
         "use_cuda_visibile_devices": True,
     }

From 6c8ecd09183ed2e5c2ded36dd2a1e8648532c572 Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 24 Oct 2024 22:41:05 +0000
Subject: [PATCH 3/7] update to reading dict vars

---
 src/fairchem/core/_cli_hydra.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py
index 1e64b2a6e..f78fa0f06 100644
--- a/src/fairchem/core/_cli_hydra.py
+++ b/src/fairchem/core/_cli_hydra.py
@@ -38,9 +38,6 @@ def __call__(self, dict_config: DictConfig) -> None:
         # TODO: setup_imports is not needed if we stop instantiating models with Registry.
         setup_imports()
         setup_env_vars()
-        import pdb
-
-        pdb.set_trace()
         try:
             distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
             # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
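Some context on why the attribute-style reads in patch 2 work: the `OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)` call from patch 1 stores the argparse values as a nested `DictConfig`, and omegaconf exposes its keys through both attribute and item access. A small standalone sketch (the values here are made up for illustration):

```python
from omegaconf import OmegaConf

# Stand-in for the argparse values merged into the hydra config; numbers are made up.
cfg = OmegaConf.create({"cli_args": {"num_nodes": 2, "num_gpus": 8, "cpu": False}})

# A DictConfig supports both attribute-style and key-style access,
# so cli_args.num_gpus and cli_args["num_gpus"] read the same value.
assert cfg.cli_args.num_nodes * cfg.cli_args.num_gpus == 16
assert cfg.cli_args["num_gpus"] == cfg.cli_args.num_gpus
```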
From ffdea382db02d2ba3abb8fe0b3adf930befbc24e Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 24 Oct 2024 22:52:53 +0000
Subject: [PATCH 4/7] get rid of finally clause

---
 src/fairchem/core/_cli_hydra.py | 50 ++++++++++++++++++----------------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py
index f78fa0f06..17d27564f 100644
--- a/src/fairchem/core/_cli_hydra.py
+++ b/src/fairchem/core/_cli_hydra.py
@@ -38,32 +38,30 @@ def __call__(self, dict_config: DictConfig) -> None:
         # TODO: setup_imports is not needed if we stop instantiating models with Registry.
         setup_imports()
         setup_env_vars()
-        try:
-            distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
-            # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
-            # don't start logger if in debug mode
-            if (
-                "logger" in dict_config
-                and distutils.is_master()
-                and not dict_config.cli_args.debug
-            ):
-                # get a partial function from the config and instantiate wandb with it
-                logger_initializer = hydra.utils.instantiate(dict_config.logger)
-                simple_config = OmegaConf.to_container(
-                    self.config, resolve=True, throw_on_missing=True
-                )
-                logger_initializer(
-                    config=simple_config,
-                    run_id=self.config.cli_args.timestamp_id,
-                    run_name=self.config.cli_args.identifier,
-                    log_dir=self.config.cli_args.logdir,
-                )
-
-            runner: Runner = hydra.utils.instantiate(dict_config.runner)
-            runner.load_state()
-            runner.run()
-        finally:
-            distutils.cleanup()
+        distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
+        # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
+        # don't start logger if in debug mode
+        if (
+            "logger" in dict_config
+            and distutils.is_master()
+            and not dict_config.cli_args.debug
+        ):
+            # get a partial function from the config and instantiate wandb with it
+            logger_initializer = hydra.utils.instantiate(dict_config.logger)
+            simple_config = OmegaConf.to_container(
+                self.config, resolve=True, throw_on_missing=True
+            )
+            logger_initializer(
+                config=simple_config,
+                run_id=self.config.cli_args.timestamp_id,
+                run_name=self.config.cli_args.identifier,
+                log_dir=self.config.cli_args.logdir,
+            )
+
+        runner: Runner = hydra.utils.instantiate(dict_config.runner)
+        runner.load_state()
+        runner.run()
+        distutils.cleanup()

     def checkpoint(self, *args, **kwargs):
         # TODO: this is yet to be tested properly

From 0c3fb348872b9d526df90baa37d0676405772ec6 Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Thu, 24 Oct 2024 22:55:50 +0000
Subject: [PATCH 5/7] move logger init

---
 src/fairchem/core/_cli_hydra.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/fairchem/core/_cli_hydra.py b/src/fairchem/core/_cli_hydra.py
index 17d27564f..766bcadb2 100644
--- a/src/fairchem/core/_cli_hydra.py
+++ b/src/fairchem/core/_cli_hydra.py
@@ -39,15 +39,22 @@ def __call__(self, dict_config: DictConfig) -> None:
         setup_imports()
         setup_env_vars()
         distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
+        self._init_logger()
+        runner: Runner = hydra.utils.instantiate(dict_config.runner)
+        runner.load_state()
+        runner.run()
+        distutils.cleanup()
+
+    def _init_logger(self) -> None:
         # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
         # don't start logger if in debug mode
         if (
-            "logger" in dict_config
+            "logger" in self.config
             and distutils.is_master()
-            and not dict_config.cli_args.debug
+            and not self.config.cli_args.debug
         ):
             # get a partial function from the config and instantiate wandb with it
-            logger_initializer = hydra.utils.instantiate(dict_config.logger)
+            logger_initializer = hydra.utils.instantiate(self.config.logger)
             simple_config = OmegaConf.to_container(
                 self.config, resolve=True, throw_on_missing=True
             )
@@ -58,12 +65,7 @@ def __call__(self, dict_config: DictConfig) -> None:
                 log_dir=self.config.cli_args.logdir,
             )

-        runner: Runner = hydra.utils.instantiate(dict_config.runner)
-        runner.load_state()
-        runner.run()
-        distutils.cleanup()
-
-    def checkpoint(self, *args, **kwargs):
+    def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
         # TODO: this is yet to be tested properly
         logging.info("Submitit checkpointing callback is triggered")
         new_runner = Runner()
From f9760e70e05b082e0d2b90344f196f889f9b1424 Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Mon, 25 Nov 2024 21:17:15 +0000
Subject: [PATCH 6/7] add deprecation comment

---
 src/fairchem/core/trainers/ocp_trainer.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py
index 9a13faed6..3b54a069d 100644
--- a/src/fairchem/core/trainers/ocp_trainer.py
+++ b/src/fairchem/core/trainers/ocp_trainer.py
@@ -79,9 +79,7 @@ def __init__(
         loss_functions: dict[str, str | float],
         evaluation_metrics: dict[str, str],
         identifier: str,
-        # TODO: dealing with local rank is dangerous
-        # T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use
-        local_rank: int,
+        local_rank: int,  # DEPRECATED, DO NOT USE
         timestamp_id: str | None = None,
         run_dir: str | None = None,
         is_debug: bool = False,

From 42833c98b9123af8c8355888ee0c4f9517e31506 Mon Sep 17 00:00:00 2001
From: Ray Gao
Date: Mon, 25 Nov 2024 21:17:51 +0000
Subject: [PATCH 7/7] Revert "add deprecation comment"

This reverts commit f9760e70e05b082e0d2b90344f196f889f9b1424.
---
 src/fairchem/core/trainers/ocp_trainer.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py
index 3b54a069d..9a13faed6 100644
--- a/src/fairchem/core/trainers/ocp_trainer.py
+++ b/src/fairchem/core/trainers/ocp_trainer.py
@@ -79,7 +79,9 @@ def __init__(
         loss_functions: dict[str, str | float],
         evaluation_metrics: dict[str, str],
         identifier: str,
-        local_rank: int,  # DEPRECATED, DO NOT USE
+        # TODO: dealing with local rank is dangerous
+        # T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use
+        local_rank: int,
         timestamp_id: str | None = None,
         run_dir: str | None = None,
         is_debug: bool = False,
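The TODO restored by the revert above refers to the general pattern of scoping each worker to one GPU through `CUDA_VISIBLE_DEVICES` instead of threading `local_rank` into the trainer. A minimal illustration of that pattern follows; it is a sketch of the idea only, not fairchem's launcher code, and `select_device_for_rank` is a hypothetical helper name.

```python
import os


def select_device_for_rank(local_rank: int) -> str:
    """Restrict this process to a single GPU so downstream code can always use cuda:0.

    Sketch of the approach hinted at in the TODO above; it must run before anything
    initializes CUDA in this process, and it is not fairchem's actual implementation.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)
    # With only one visible device, its index inside this process is always 0.
    return "cuda:0"
```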