Add RNG seed flag to submission runner #514

Merged · 10 commits · Sep 21, 2023
Changes from 8 commits
6 changes: 6 additions & 0 deletions algorithmic_efficiency/logger_utils.py
@@ -275,6 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict:
   return meta_data
 
 
+def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
+  meta_data = get_meta_data(workload)
+  meta_data.update({'rng_seed': rng_seed})
+  write_json(meta_file_name, meta_data)
+
+
 class MetricLogger(object):
   """Used to log all measurements during training.

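For context, a minimal, self-contained sketch of what the new helper does. The real `get_meta_data` and `write_json` live in `logger_utils`; the stand-ins and field values below are hypothetical:

```python
import json


def write_json(name: str, data: dict) -> None:
  # Stand-in for logger_utils.write_json: dump a dict to a JSON file.
  with open(name, 'w') as f:
    json.dump(data, f, indent=2)


def get_meta_data(workload) -> dict:
  # Stand-in for logger_utils.get_meta_data, which collects workload and
  # system properties; the fields here are made up for illustration.
  return {'workload': str(workload), 'framework': 'jax'}


def save_meta_data(workload, rng_seed: int, meta_file_name: str) -> None:
  # Mirrors the new helper: fetch the metadata dict, attach the seed, write.
  meta_data = get_meta_data(workload)
  meta_data.update({'rng_seed': rng_seed})
  write_json(meta_file_name, meta_data)


save_meta_data('mnist', rng_seed=1996, meta_file_name='meta_data_0.json')
```

The only behavioral change relative to the previous inline code is the extra `rng_seed` key in the written JSON.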
24 changes: 17 additions & 7 deletions submission_runner.py
@@ -133,6 +133,11 @@
 flags.DEFINE_boolean('save_checkpoints',
                      True,
                      'Whether or not to checkpoint the model at every eval.')
+flags.DEFINE_integer(
+    'rng_seed',
+    None,
+    'Value of rng seed. If None, a random seed will '
+    'be generated from hardware.')
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
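The flag defaults to None, in which case the runner keeps its existing fallback of drawing a seed from hardware entropy. A minimal sketch of that fallback:

```python
import os
import struct

# Read 4 bytes from the OS entropy pool and interpret them as an
# unsigned 32-bit integer, giving a seed in [0, 2**32 - 1].
rng_seed = struct.unpack('I', os.urandom(4))[0]
print(rng_seed)
```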

@@ -173,6 +178,7 @@ def train_once(
     update_params: spec.UpdateParamsFn,
     data_selection: spec.DataSelectionFn,
     hyperparameters: Optional[spec.Hyperparameters],
+    rng_seed: int,
     rng: spec.RandomState,
     profiler: Profiler,
     max_global_steps: int = None,
@@ -267,10 +273,9 @@
           global_step,
           preemption_count,
           checkpoint_dir=log_dir)
-      meta_data = logger_utils.get_meta_data(workload)
       meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
       logging.info(f'Saving meta data to {meta_file_name}.')
-      logger_utils.write_json(meta_file_name, meta_data)
+      logger_utils.save_meta_data(workload, rng_seed, meta_file_name)
       flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
       logging.info(f'Saving flags to {flag_file_name}.')
       logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
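Since the seed is now persisted alongside the rest of the trial metadata, a finished run can be reproduced by reading it back and passing it via --rng_seed. A sketch, assuming a metadata file written by the runner (the path here is hypothetical):

```python
import json

# meta_data_{preemption_count}.json is written under the trial's log dir.
with open('meta_data_0.json') as f:
  meta_data = json.load(f)

# Pass this value back as --rng_seed to rerun the trial with the
# same randomness.
print(meta_data['rng_seed'])
```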
@@ -449,7 +454,8 @@ def score_submission_on_workload(workload: spec.Workload,
                                  tuning_search_space: Optional[str] = None,
                                  num_tuning_trials: Optional[int] = None,
                                  log_dir: Optional[str] = None,
-                                 save_checkpoints: Optional[bool] = True):
+                                 save_checkpoints: Optional[bool] = True,
+                                 rng_seed: Optional[int] = None):
   # Expand paths because '~' may not be recognized
   data_dir = os.path.expanduser(data_dir)
   if imagenet_v2_data_dir:
@@ -496,7 +502,8 @@ def score_submission_on_workload(workload: spec.Workload,
     all_metrics = []
     for hi, hyperparameters in enumerate(tuning_search_space):
       # Generate a new seed from hardware sources of randomness for each trial.
-      rng_seed = struct.unpack('I', os.urandom(4))[0]
+      if rng_seed is None:
+        rng_seed = struct.unpack('I', os.urandom(4))[0]
       logging.info('Using RNG seed %d', rng_seed)
       rng = prng.PRNGKey(rng_seed)
       # Because we initialize the PRNGKey with only a single 32 bit int, in the
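One subtlety of this revision: `rng_seed` is a plain local once set, so after the first trial obtains a seed (whether from the flag or from hardware), subsequent trials in the tuning loop reuse the same value. A standalone sketch of the selection logic; the helper name `seed_for_trials` is hypothetical:

```python
import os
import struct


def seed_for_trials(num_trials: int, flag_rng_seed=None) -> list:
  # Sketch of the per-trial seed logic in score_submission_on_workload.
  rng_seed = flag_rng_seed
  seeds = []
  for _ in range(num_trials):
    if rng_seed is None:
      # Hardware fallback, as in the runner.
      rng_seed = struct.unpack('I', os.urandom(4))[0]
    seeds.append(rng_seed)
  return seeds


print(seed_for_trials(3, flag_rng_seed=1996))  # [1996, 1996, 1996]
print(seed_for_trials(3))  # one hardware-drawn seed, repeated 3 times
```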
@@ -528,7 +535,9 @@ def score_submission_on_workload(workload: spec.Workload,
           data_dir, imagenet_v2_data_dir,
           init_optimizer_state,
           update_params, data_selection,
-          hyperparameters, rng,
+          hyperparameters,
+          rng_seed,
+          rng,
           profiler,
           max_global_steps,
           tuning_dir_name,
@@ -554,7 +563,7 @@ def score_submission_on_workload(workload: spec.Workload,
         workload, global_batch_size, global_eval_batch_size,
         data_dir, imagenet_v2_data_dir,
         init_optimizer_state, update_params, data_selection,
-        None, rng, profiler, max_global_steps, log_dir,
+        None, rng_seed, rng, profiler, max_global_steps, log_dir,
         save_checkpoints=save_checkpoints)
   return score

@@ -610,7 +619,8 @@ def main(_):
       tuning_search_space=FLAGS.tuning_search_space,
       num_tuning_trials=FLAGS.num_tuning_trials,
       log_dir=logging_dir_path,
-      save_checkpoints=FLAGS.save_checkpoints)
+      save_checkpoints=FLAGS.save_checkpoints,
+      rng_seed=FLAGS.rng_seed)
   logging.info(f'Final {FLAGS.workload} score: {score}')
 
   if FLAGS.profile: