diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index dcc8754a9..559859515 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -279,7 +279,7 @@ def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - # meta_data.update({'rng_seed': rng_seed}) + meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) class MetricLogger(object): diff --git a/submission_runner.py b/submission_runner.py index 8f55cd882..fb8df198c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -177,6 +177,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -273,7 +274,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng, preemption_count) + logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -533,7 +534,9 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, rng, + hyperparameters, + rng_seed, + rng, profiler, max_global_steps, tuning_dir_name, @@ -559,7 +562,7 @@ def score_submission_on_workload(workload: spec.Workload, workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - None, rng, profiler, max_global_steps, log_dir, + None, rng_seed, rng, profiler, max_global_steps, log_dir, save_checkpoints=save_checkpoints) return score