Commit 94b8e54
Merge pull request #755 from mlcommons/juhan/ddp_fix2
Allow DDP checkpointing
priyakasimbeg authored Apr 9, 2024
2 parents 04ea6e1 + 80a93bf commit 94b8e54
Showing 3 changed files with 29 additions and 18 deletions.
algorithmic_efficiency/checkpoint_utils.py (8 changes: 6 additions & 2 deletions)
@@ -119,7 +119,9 @@ def maybe_restore_checkpoint(framework: str,
   else:
     checkpoint_state = latest_ckpt
 
-  if isinstance(model_params, torch.nn.DataParallel):
+  if isinstance(
+      model_params,
+      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
     model_params = model_params.module
   model_params.load_state_dict(checkpoint_state['model_params'])
   checkpoint_state['model_params'] = model_params
@@ -196,7 +198,9 @@ def save_checkpoint(framework: str,
     opt_state = jax.device_get(jax_utils.unreplicate(opt_state))
     model_state = jax.device_get(jax_utils.unreplicate(model_state))
   else:
-    if isinstance(model_params, torch.nn.DataParallel):
+    if isinstance(
+        model_params,
+        (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
       model_params = model_params.module
     model_params = model_params.state_dict()
     optimizer_state_dict = {}
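The pattern both hunks apply can be isolated as a small helper. The sketch below is illustrative, not code from this repository; unwrap_model and the ckpt.pt path are hypothetical names. Both wrappers expose the original network as .module, and saving that inner module's state_dict() keeps checkpoint keys free of the 'module.' prefix, so a checkpoint written under DDP loads cleanly with or without a wrapper.

import torch

def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
  """Hypothetical helper: strip a data-parallel wrapper, if any."""
  if isinstance(
      model,
      (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
    # Both wrappers hold the original network under `.module`.
    return model.module
  return model

# Saving: the state_dict keys carry no 'module.' prefix.
#   torch.save(unwrap_model(model).state_dict(), 'ckpt.pt')
# Restoring: load into the unwrapped module, regardless of wrapper.
#   unwrap_model(model).load_state_dict(torch.load('ckpt.pt'))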
algorithmic_efficiency/logger_utils.py (25 changes: 15 additions & 10 deletions)
@@ -16,6 +16,7 @@
 import GPUtil
 import pandas as pd
 import psutil
+import torch.distributed as dist
 
 from algorithmic_efficiency import spec
 from algorithmic_efficiency.pytorch_utils import pytorch_setup
@@ -43,9 +44,6 @@ def get_log_dir(
     resume_last_run: bool,
     overwrite: bool,
 ) -> Optional[str]:
-  if RANK != 0:
-    return
-
   # Construct path to experiment workload directory.
   experiment_dir = os.path.expanduser(experiment_dir)
   workload_dir_name = f'{workload}_{framework}'
@@ -61,18 +59,25 @@ def get_log_dir(
     logging.info(
         f'Removing existing experiment directory {experiment_path} because '
         '--overwrite was set.')
-    shutil.rmtree(experiment_path)
+    if RANK == 0:
+      shutil.rmtree(experiment_path)
   elif resume_last_run:
     logging.info(
         f'Resuming from experiment directory {experiment_path} because '
         '--resume_last_run was set.')
   else:
-    resume = input(
-        'Found existing experiment dir with the same name: {}. Do you wish '
-        'to resume training from this dir? [y/N]:'.format(experiment_path))
-    if resume.lower() != 'y':
-      sys.exit()
+    if RANK == 0:
+      resume = input(
+          'Found existing experiment dir with the same name: {}. Do you wish '
+          'to resume training from this dir? [y/N]:'.format(experiment_path))
+      if resume.lower() != 'y':
+        sys.exit()
 
+  if USE_PYTORCH_DDP:
+    try:
+      dist.barrier()
+    except RuntimeError:
+      sys.exit()
   logging.info(f'Creating experiment directory at {experiment_path}.')
   makedir(experiment_path)
   return experiment_path
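The coordination this hunk sets up can be sketched on its own. The sketch below is a minimal illustration assuming a torchrun-style launch where the RANK environment variable identifies each process; confirm_resume is a hypothetical name. Only rank 0 prompts the user or deletes directories; the other ranks block on dist.barrier() until rank 0 arrives. If rank 0 exits at the prompt instead, it never reaches the barrier, and the waiting ranks eventually see a RuntimeError, which they turn into their own exit.

import os
import sys

import torch.distributed as dist

RANK = int(os.environ.get('RANK', '0'))  # set per process by torchrun

def confirm_resume(experiment_path: str, use_ddp: bool) -> None:
  """Hypothetical sketch of rank-0-gated user interaction under DDP."""
  if RANK == 0:
    resume = input(f'Resume from {experiment_path}? [y/N]: ')
    if resume.lower() != 'y':
      sys.exit()
  if use_ddp:
    try:
      # Non-zero ranks wait here; released once rank 0 arrives as well.
      dist.barrier()
    except RuntimeError:
      # Rank 0 exited before the barrier, so follow it down.
      sys.exit()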
submission_runner.py (14 changes: 8 additions & 6 deletions)
@@ -316,10 +316,12 @@ def train_once(
     flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
     logging.info(f'Saving flags to {flag_file_name}.')
     logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
-    metrics_logger = logger_utils.set_up_loggers(log_dir,
-                                                 flags.FLAGS,
-                                                 hyperparameters)
-    workload.attach_metrics_logger(metrics_logger)
+    metrics_logger = None
+    if RANK == 0:
+      metrics_logger = logger_utils.set_up_loggers(log_dir,
+                                                   flags.FLAGS,
+                                                   hyperparameters)
+      workload.attach_metrics_logger(metrics_logger)
 
   global_start_time = get_time()
   train_state['last_step_end_time'] = global_start_time
@@ -429,7 +431,7 @@ def train_once(
 
         logging_start_time = get_time()
 
-        if log_dir is not None:
+        if log_dir is not None and RANK == 0:
           metrics_logger.append_scalar_metrics(
               latest_eval_result,
               global_step=global_step,
@@ -467,7 +469,7 @@ def train_once(
 
   metrics = {'eval_results': eval_results, 'global_step': global_step}
 
-  if log_dir is not None:
+  if log_dir is not None and RANK == 0:
     metrics_logger.append_scalar_metrics(
         {'score': train_state['accumulated_submission_time']},
         global_step=global_step,
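The same rank gate must guard every use of the logger, not only its creation: on non-zero ranks metrics_logger stays None, so an unguarded call would raise. Below is a minimal sketch of the call-site pattern, with log_metrics as a hypothetical wrapper (append_scalar_metrics is the method the diff above calls).

import os

RANK = int(os.environ.get('RANK', '0'))  # set per process by torchrun

def log_metrics(metrics_logger, metrics: dict, step: int) -> None:
  """Hypothetical wrapper: no-op everywhere except rank 0."""
  # metrics_logger is created only on rank 0 and is None elsewhere,
  # so both checks are needed before calling into it.
  if metrics_logger is not None and RANK == 0:
    metrics_logger.append_scalar_metrics(metrics, global_step=step)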
