Skip to content

Commit

Permalink
Add wandb logger init to hydra runners (#894)
Browse files Browse the repository at this point in the history
* add wandb logger init to hydra runners

* update to reading dict vars

* update to reading dict vars

* get rid of finally clause

* move logger init

* add deprecation comment

* Revert "add deprecation comment"

This reverts commit f9760e7.
  • Loading branch information
rayg1234 authored Nov 26, 2024
1 parent 21af12f commit 72614bf
Showing 1 changed file with 42 additions and 17 deletions.
59 changes: 42 additions & 17 deletions src/fairchem/core/_cli_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import TYPE_CHECKING

import hydra
from omegaconf import OmegaConf

if TYPE_CHECKING:
import argparse
Expand All @@ -34,29 +35,48 @@


class Submitit(Checkpointable):
    """Callable job entry point that also implements the submitit checkpoint hook.

    Calling an instance sets up the distributed environment from
    ``dict_config.cli_args``, optionally starts the configured logger, then
    instantiates the runner from ``dict_config.runner`` and runs it.
    """

    def __call__(self, dict_config: DictConfig) -> None:
        """Run the job described by ``dict_config``.

        Args:
            dict_config: Hydra config; this method reads its ``cli_args`` and
                ``runner`` entries (and ``logger`` when present).
        """
        self.config = dict_config
        # TODO: setup_imports is not needed if we stop instantiating models with Registry.
        setup_imports()
        setup_env_vars()
        distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
        self._init_logger()
        # Keep the runner on the instance: checkpoint() needs it to save state.
        self.runner: Runner = hydra.utils.instantiate(dict_config.runner)
        self.runner.load_state()
        self.runner.run()
        distutils.cleanup()

    def _init_logger(self) -> None:
        """Start the logger from ``self.config.logger`` when the conditions below hold.

        Nothing happens unless the config has a ``logger`` entry, this process
        is the master rank, and ``cli_args.debug`` is falsy.
        """
        # optionally instantiate a singleton wandb logger, intentionally only supporting the new wandb logger
        # don't start logger if in debug mode
        if (
            "logger" in self.config
            and distutils.is_master()
            and not self.config.cli_args.debug
        ):
            # get a partial function from the config and instantiate wandb with it
            logger_initializer = hydra.utils.instantiate(self.config.logger)
            # Resolve interpolations so the logger receives a plain container.
            simple_config = OmegaConf.to_container(
                self.config, resolve=True, throw_on_missing=True
            )
            logger_initializer(
                config=simple_config,
                run_id=self.config.cli_args.timestamp_id,
                run_name=self.config.cli_args.identifier,
                log_dir=self.config.cli_args.logdir,
            )

    def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
        """Save the runner state and return a resubmission of this job.

        Returns:
            A ``DelayedSubmission`` wrapping a fresh ``Submitit`` instance and
            the stored config, matching the single-argument ``__call__``.
        """
        # TODO: this is yet to be tested properly
        logging.info("Submitit checkpointing callback is triggered")
        new_runner = Submitit()
        self.runner.save_state()
        logging.info("Submitit checkpointing callback is completed")
        # cli_args now lives inside the config, so only the config is passed on.
        return DelayedSubmission(new_runner, self.config)


def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict:
return {
"world_size": cli_args.num_nodes * cli_args.num_gpus,
"distributed_backend": "gloo" if cli_args.cpu else "nccl",
Expand All @@ -78,8 +98,8 @@ def get_hydra_config_from_yaml(
return hydra.compose(config_name=config_name, overrides=overrides_args)


def runner_wrapper(config: DictConfig):
    """Build a ``Submitit`` instance and call it with ``config``.

    A plain module-level function, passed by ``main`` to the job launchers.
    """
    Submitit()(config)


# this is meant as a future replacement for the main entrypoint
Expand All @@ -93,6 +113,11 @@ def main(
cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
timestamp_id = get_timestamp_uid()
log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
# override timestamp id and logdir
args.timestamp_id = timestamp_id
args.logdir = log_dir
os.makedirs(log_dir)
OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)
if args.submit: # Run on cluster
executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
executor.update_parameters(
Expand All @@ -107,7 +132,7 @@ def main(
slurm_qos=args.slurm_qos,
slurm_account=args.slurm_account,
)
job = executor.submit(runner_wrapper, cfg, args)
job = executor.submit(runner_wrapper, cfg)
logger.info(
f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}"
)
Expand All @@ -131,8 +156,8 @@ def main(
rdzv_backend="c10d",
max_restarts=0,
)
elastic_launch(launch_config, runner_wrapper)(cfg, args)
elastic_launch(launch_config, runner_wrapper)(cfg)
else:
logger.info("Running in local mode without elastic launch")
distutils.setup_env_local()
runner_wrapper(cfg, args)
runner_wrapper(cfg)

0 comments on commit 72614bf

Please sign in to comment.