Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wandb logger init to hydra runners #894

Merged
merged 10 commits into from
Nov 26, 2024
59 changes: 42 additions & 17 deletions src/fairchem/core/_cli_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import TYPE_CHECKING

import hydra
from omegaconf import OmegaConf
lbluque marked this conversation as resolved.
Show resolved Hide resolved

if TYPE_CHECKING:
import argparse
Expand All @@ -34,29 +35,48 @@


class Submitit(Checkpointable):
lbluque marked this conversation as resolved.
Show resolved Hide resolved
def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None:
def __call__(self, dict_config: DictConfig) -> None:
self.config = dict_config
self.cli_args = cli_args
# TODO: setup_imports is not needed if we stop instantiating models with Registry.
setup_imports()
setup_env_vars()
try:
distutils.setup(map_cli_args_to_dist_config(cli_args))
self.runner: Runner = hydra.utils.instantiate(dict_config.runner)
self.runner.load_state()
self.runner.run()
finally:
distutils.cleanup()

def checkpoint(self, *args, **kwargs):
distutils.setup(map_cli_args_to_dist_config(dict_config.cli_args))
self._init_logger()
runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.load_state()
runner.run()
distutils.cleanup()

def _init_logger(self) -> None:
    """Optionally start the singleton wandb logger for this run.

    Nothing is started unless the config has a ``logger`` entry, this
    process is the master rank, and the run is not in debug mode. Only the
    new wandb logger is intentionally supported here.
    """
    cfg = self.config

    # Guard clauses, checked in the same order as a short-circuiting `and`.
    if "logger" not in cfg:
        return
    if not distutils.is_master():
        return
    if cfg.cli_args.debug:
        # don't start the logger in debug mode
        return

    # The logger config instantiates to a partial function; calling it
    # below with the run details is what actually initializes wandb.
    make_logger = hydra.utils.instantiate(cfg.logger)

    # Pass the logger a plain, fully-resolved container of the config.
    plain_config = OmegaConf.to_container(
        cfg, resolve=True, throw_on_missing=True
    )

    make_logger(
        config=plain_config,
        run_id=cfg.cli_args.timestamp_id,
        run_name=cfg.cli_args.identifier,
        log_dir=cfg.cli_args.logdir,
    )

def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
# TODO: this is yet to be tested properly
logging.info("Submitit checkpointing callback is triggered")
new_runner = Submitit()
self.runner.save_state()
logging.info("Submitit checkpointing callback is completed")
return DelayedSubmission(new_runner, self.config, self.cli_args)


def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
def map_cli_args_to_dist_config(cli_args: DictConfig) -> dict:
return {
"world_size": cli_args.num_nodes * cli_args.num_gpus,
"distributed_backend": "gloo" if cli_args.cpu else "nccl",
Expand All @@ -78,8 +98,8 @@ def get_hydra_config_from_yaml(
return hydra.compose(config_name=config_name, overrides=overrides_args)


def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace):
Submitit()(config, cli_args)
def runner_wrapper(config: DictConfig):
Submitit()(config)


# this is meant as a future replacement for the main entrypoint
Expand All @@ -93,6 +113,11 @@ def main(
cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
timestamp_id = get_timestamp_uid()
log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
# override timestamp id and logdir
args.timestamp_id = timestamp_id
args.logdir = log_dir
os.makedirs(log_dir)
OmegaConf.update(cfg, "cli_args", vars(args), force_add=True)
if args.submit: # Run on cluster
executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
executor.update_parameters(
Expand All @@ -107,7 +132,7 @@ def main(
slurm_qos=args.slurm_qos,
slurm_account=args.slurm_account,
)
job = executor.submit(runner_wrapper, cfg, args)
job = executor.submit(runner_wrapper, cfg)
logger.info(
f"Submitted job id: {timestamp_id}, slurm id: {job.job_id}, logs: {log_dir}"
)
Expand All @@ -131,8 +156,8 @@ def main(
rdzv_backend="c10d",
max_restarts=0,
)
elastic_launch(launch_config, runner_wrapper)(cfg, args)
elastic_launch(launch_config, runner_wrapper)(cfg)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just out of curiosity, is `runner_wrapper` there just to make `elastic_launch` happy?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, `elastic_launch` can only deal with a top-level function.

else:
logger.info("Running in local mode without elastic launch")
distutils.setup_env_local()
runner_wrapper(cfg, args)
runner_wrapper(cfg)