From df542c2ea88bd7d89ef9547596005fa7ff9ae140 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 18 Sep 2023 23:48:28 +0000 Subject: [PATCH 01/10] add rng_seed flag and save seed to metadata --- algorithmic_efficiency/logger_utils.py | 8 ++++++++ submission_runner.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index af2e61581..39c039f18 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,6 +275,14 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data +def save_meta_data(workload: spec.Workload, + rng_seed: int, + preemption_count: int): + meta_data = get_meta_data(workload) + meta_data.update({'rng_seed': rng_seed}) + meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') + write_json(meta_file_name, meta_data) + class MetricLogger(object): """Used to log all measurements during training. diff --git a/submission_runner.py b/submission_runner.py index f4ee32ede..bed3d1e22 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -133,6 +133,10 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_integer('rng_seed', + None, + 'Value of rng seed. 
If None, a random seed will' + 'be generated from hardware.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -267,10 +271,8 @@ def train_once( global_step, preemption_count, checkpoint_dir=log_dir) - meta_data = logger_utils.get_meta_data(workload) - meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.write_json(meta_file_name, meta_data) + logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -449,7 +451,8 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space: Optional[str] = None, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True): + save_checkpoints: Optional[bool] = True, + rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -496,7 +499,8 @@ def score_submission_on_workload(workload: spec.Workload, all_metrics = [] for hi, hyperparameters in enumerate(tuning_search_space): # Generate a new seed from hardware sources of randomness for each trial. 
- rng_seed = struct.unpack('I', os.urandom(4))[0] + if not rng_seed: + rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) rng = prng.PRNGKey(rng_seed) # Because we initialize the PRNGKey with only a single 32 bit int, in the @@ -610,7 +614,8 @@ def main(_): tuning_search_space=FLAGS.tuning_search_space, num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints) + save_checkpoints=FLAGS.save_checkpoints, + rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: From 5ff2ec2f7affd2091fcddc1b227122f34fae2a78 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 18 Sep 2023 23:56:13 +0000 Subject: [PATCH 02/10] fix --- algorithmic_efficiency/logger_utils.py | 3 +-- submission_runner.py | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 39c039f18..559859515 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -277,10 +277,9 @@ def get_meta_data(workload: spec.Workload) -> dict: def save_meta_data(workload: spec.Workload, rng_seed: int, - preemption_count: int): + meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) - meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') write_json(meta_file_name, meta_data) class MetricLogger(object): diff --git a/submission_runner.py b/submission_runner.py index bed3d1e22..1f4bcf603 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -271,6 +271,7 @@ def train_once( global_step, preemption_count, checkpoint_dir=log_dir) + meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, 
f'flags_{preemption_count}.json') From df016238f4d9b7e1a6ff21d90e155104ea1fffd1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 00:02:52 +0000 Subject: [PATCH 03/10] fix --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 1f4bcf603..8f55cd882 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -273,7 +273,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng_seed, preemption_count) + logger_utils.save_meta_data(workload, rng, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) From 87ecd5b477b4e99484fe53216fff2282651155b6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 00:05:58 +0000 Subject: [PATCH 04/10] debug --- algorithmic_efficiency/logger_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 559859515..dcc8754a9 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -279,7 +279,7 @@ def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - meta_data.update({'rng_seed': rng_seed}) + # meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) class MetricLogger(object): From 828765cbf400e05641a2641beb548baea6a60939 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 00:12:08 +0000 Subject: [PATCH 05/10] fix --- algorithmic_efficiency/logger_utils.py | 2 +- submission_runner.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git 
a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index dcc8754a9..559859515 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -279,7 +279,7 @@ def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) - # meta_data.update({'rng_seed': rng_seed}) + meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) class MetricLogger(object): diff --git a/submission_runner.py b/submission_runner.py index 8f55cd882..fb8df198c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -177,6 +177,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -273,7 +274,7 @@ def train_once( checkpoint_dir=log_dir) meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json') logging.info(f'Saving meta data to {meta_file_name}.') - logger_utils.save_meta_data(workload, rng, preemption_count) + logger_utils.save_meta_data(workload, rng_seed, preemption_count) flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json') logging.info(f'Saving flags to {flag_file_name}.') logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict()) @@ -533,7 +534,9 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, rng, + hyperparameters, + rng_seed, + rng, profiler, max_global_steps, tuning_dir_name, @@ -559,7 +562,7 @@ def score_submission_on_workload(workload: spec.Workload, workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - None, rng, profiler, max_global_steps, log_dir, + None, rng_seed, rng, profiler, max_global_steps, 
log_dir, save_checkpoints=save_checkpoints) return score From f861353ccc7bf7a50b064ebb207163758ba13074 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 21:43:04 +0000 Subject: [PATCH 06/10] lint fix --- algorithmic_efficiency/logger_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 559859515..18652dcaa 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,8 +275,8 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data -def save_meta_data(workload: spec.Workload, - rng_seed: int, +def save_meta_data(workload: spec.Workload, + rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) From 18a8c20362b71c3a90361240832380348fcb7cfc Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 21:53:16 +0000 Subject: [PATCH 07/10] pylint --- submission_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index fb8df198c..af3741812 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -134,7 +134,7 @@ True, 'Whether or not to checkpoint the model at every eval.') flags.DEFINE_integer('rng_seed', - None, + None, 'Value of rng seed. 
If None, a random seed will' 'be generated from hardware.') FLAGS = flags.FLAGS @@ -177,7 +177,7 @@ def train_once( update_params: spec.UpdateParamsFn, data_selection: spec.DataSelectionFn, hyperparameters: Optional[spec.Hyperparameters], - rng_seed: int, + rng_seed: int, rng: spec.RandomState, profiler: Profiler, max_global_steps: int = None, @@ -534,7 +534,7 @@ def score_submission_on_workload(workload: spec.Workload, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, - hyperparameters, + hyperparameters, rng_seed, rng, profiler, From 25b05b848a52ce4423772ba5843c69d1fca3414d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 19 Sep 2023 21:55:57 +0000 Subject: [PATCH 08/10] formatting --- algorithmic_efficiency/logger_utils.py | 5 ++--- submission_runner.py | 9 +++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index 18652dcaa..2b3cf86f6 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -275,13 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict: return meta_data -def save_meta_data(workload: spec.Workload, - rng_seed: int, - meta_file_name: str): +def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str): meta_data = get_meta_data(workload) meta_data.update({'rng_seed': rng_seed}) write_json(meta_file_name, meta_data) + class MetricLogger(object): """Used to log all measurements during training. diff --git a/submission_runner.py b/submission_runner.py index af3741812..8096eeda3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -133,10 +133,11 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') -flags.DEFINE_integer('rng_seed', - None, - 'Value of rng seed. 
If None, a random seed will' 'be generated from hardware.') +flags.DEFINE_integer( + 'rng_seed', + None, + 'Value of rng seed. If None, a random seed will' + 'be generated from hardware.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() From d54b8660c8a05f7ada6a0385cfc212bc2fa0115b Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 20 Sep 2023 22:18:17 +0000 Subject: [PATCH 09/10] pass rng_seed arg for self-tuning submission as well --- submission_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 8096eeda3..2289d39d3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -554,7 +554,8 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) else: - rng_seed = struct.unpack('q', os.urandom(8))[0] + if not rng_seed: + rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) # If the submission is responsible for tuning itself, we only need to run it # once and return the total time. From a7b60fa9fead9b453245dce35e5487deacef09be Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 21 Sep 2023 00:34:16 +0000 Subject: [PATCH 10/10] pin ml_dtypes version --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 6f53cd51b..a7ce5ebb2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -115,6 +115,7 @@ jax_core_deps = # Todo(kasimbeg): verify if this is necessary after we # upgrade jax. chex==0.1.7 + ml_dtypes==0.2.0 # JAX CPU jax_cpu =