diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py
index fb7449b99..29c1a821e 100644
--- a/algorithmic_efficiency/checkpoint_utils.py
+++ b/algorithmic_efficiency/checkpoint_utils.py
@@ -119,7 +119,9 @@ def maybe_restore_checkpoint(framework: str,
 
   else:
     checkpoint_state = latest_ckpt
-    if isinstance(model_params, torch.nn.DataParallel):
+    if isinstance(
+        model_params,
+        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
       model_params = model_params.module
     model_params.load_state_dict(checkpoint_state['model_params'])
     checkpoint_state['model_params'] = model_params
@@ -196,7 +198,9 @@ def save_checkpoint(framework: str,
     opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
     model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
-    if isinstance(model_params, torch.nn.DataParallel):
+    if isinstance(
+        model_params,
+        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
       model_params = model_params.module
     model_params = model_params.state_dict()
     optimizer_state_dict = {}
diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py
index b7bde226a..609d996e6 100644
--- a/algorithmic_efficiency/logger_utils.py
+++ b/algorithmic_efficiency/logger_utils.py
@@ -16,6 +16,7 @@
 import GPUtil
 import pandas as pd
 import psutil
+import torch.distributed as dist
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.pytorch_utils import pytorch_setup
@@ -43,9 +44,6 @@ def get_log_dir(
     resume_last_run: bool,
     overwrite: bool,
 ) -> Optional[str]:
-  if RANK != 0:
-    return
-
   # Construct path to experiment workload directory.
   experiment_dir = os.path.expanduser(experiment_dir)
   workload_dir_name = f'{workload}_{framework}'
@@ -61,18 +59,25 @@ def get_log_dir(
       logging.info(
           f'Removing existing experiment directory {experiment_path} because '
           '--overwrite was set.')
-      shutil.rmtree(experiment_path)
+      if RANK == 0:
+        shutil.rmtree(experiment_path)
     elif resume_last_run:
       logging.info(
           f'Resuming from experiment directory {experiment_path} because '
           '--resume_last_run was set.')
     else:
-      resume = input(
-          'Found existing experiment dir with the same name: {}. Do you wish '
-          'to resume training from this dir? [y/N]:'.format(experiment_path))
-      if resume.lower() != 'y':
-        sys.exit()
-
+      if RANK == 0:
+        resume = input(
+            'Found existing experiment dir with the same name: {}. Do you wish '
+            'to resume training from this dir? [y/N]:'.format(experiment_path))
+        if resume.lower() != 'y':
+          sys.exit()
+
+      if USE_PYTORCH_DDP:
+        try:
+          dist.barrier()
+        except RuntimeError:
+          sys.exit()
   logging.info(f'Creating experiment directory at {experiment_path}.')
   makedir(experiment_path)
   return experiment_path
diff --git a/submission_runner.py b/submission_runner.py
index a6f8c05a3..551173bf5 100644
--- a/submission_runner.py
+++ b/submission_runner.py
@@ -316,10 +316,12 @@ def train_once(
     flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
     logging.info(f'Saving flags to {flag_file_name}.')
     logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
-    metrics_logger = logger_utils.set_up_loggers(log_dir,
-                                                 flags.FLAGS,
-                                                 hyperparameters)
-    workload.attach_metrics_logger(metrics_logger)
+    metrics_logger = None
+    if RANK == 0:
+      metrics_logger = logger_utils.set_up_loggers(log_dir,
+                                                   flags.FLAGS,
+                                                   hyperparameters)
+      workload.attach_metrics_logger(metrics_logger)
 
   global_start_time = get_time()
   train_state['last_step_end_time'] = global_start_time
@@ -429,7 +431,7 @@ def train_once(
 
         logging_start_time = get_time()
 
-        if log_dir is not None:
+        if log_dir is not None and RANK == 0:
           metrics_logger.append_scalar_metrics(
               latest_eval_result,
               global_step=global_step,
@@ -467,7 +469,7 @@ def train_once(
 
   metrics = {'eval_results': eval_results, 'global_step': global_step}
 
-  if log_dir is not None:
+  if log_dir is not None and RANK == 0:
     metrics_logger.append_scalar_metrics(
         {'score': train_state['accumulated_submission_time']},
         global_step=global_step,
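
For context, here is a minimal standalone sketch of the rank-0-plus-barrier pattern these changes apply: only rank 0 touches the filesystem (and stdin), and all ranks synchronize before using the result. It assumes a `torch.distributed` process group is already initialized when running under DDP; `RANK`, `USE_PYTORCH_DDP`, and `prepare_experiment_dir` are illustrative stand-ins, not identifiers from the repository.

```python
import os
import shutil

import torch.distributed as dist

# Illustrative assumptions: rank and DDP detection via torchrun's env vars.
RANK = int(os.environ.get('RANK', 0))
USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ


def prepare_experiment_dir(experiment_path: str, overwrite: bool) -> str:
  """Hypothetical helper mirroring the control flow in get_log_dir."""
  if RANK == 0:
    # Only rank 0 mutates the filesystem.
    if overwrite and os.path.exists(experiment_path):
      shutil.rmtree(experiment_path)
    os.makedirs(experiment_path, exist_ok=True)
  if USE_PYTORCH_DDP:
    # All ranks wait here until rank 0 has finished setting up the directory.
    dist.barrier()
  return experiment_path
```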