Merge branch 'dev' into singularity
runame committed Sep 21, 2023
2 parents e7b854c + d3fcbb6 commit 5c48b2b
Showing 6 changed files with 34 additions and 17 deletions.
6 changes: 6 additions & 0 deletions algorithmic_efficiency/logger_utils.py
@@ -275,6 +275,12 @@ def get_meta_data(workload: spec.Workload) -> dict:
   return meta_data
 
 
+def save_meta_data(workload: spec.Workload, rng_seed: int, meta_file_name: str):
+  meta_data = get_meta_data(workload)
+  meta_data.update({'rng_seed': rng_seed})
+  write_json(meta_file_name, meta_data)
+
+
 class MetricLogger(object):
   """Used to log all measurements during training.
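For context, a minimal sketch of how the new helper can be driven from a run directory. The wrapper name save_run_metadata is hypothetical; it simply mirrors the call-site pattern in submission_runner.py further down in this commit:

import os

from algorithmic_efficiency import logger_utils
from algorithmic_efficiency import spec


def save_run_metadata(workload: spec.Workload,
                      rng_seed: int,
                      log_dir: str,
                      preemption_count: int) -> None:
  # One metadata JSON per preemption, with the RNG seed recorded
  # alongside the workload metadata collected by get_meta_data().
  meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
  logger_utils.save_meta_data(workload, rng_seed, meta_file_name)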
2 changes: 1 addition & 1 deletion datasets/README.md
@@ -100,7 +100,7 @@ python3 datasets/dataset_setup.py \
 --imagenet \
 --temp_dir $DATA_DIR/tmp \
 --imagenet_train_url <imagenet_train_url> \
---imagenet_val_url <imagenet_val_url\
+--imagenet_val_url <imagenet_val_url> \
 --framework jax
 
 ```
14 changes: 6 additions & 8 deletions datasets/dataset_setup.py
@@ -298,7 +298,7 @@ def download_criteo1tb(data_dir,
   logging.info(f'Running Criteo 1TB unzip command:\n{unzip_cmd}')
   p = subprocess.Popen(unzip_cmd, shell=True)
   p.communicate()
-  _maybe_prompt_for_deletion(all_days_zip_filepath, interactive_deletion)
+  _maybe_prompt_for_deletion([all_days_zip_filepath], interactive_deletion)
 
   # Unzip the individual days.
   processes = []
@@ -316,9 +316,9 @@ def download_criteo1tb(data_dir,
   _maybe_prompt_for_deletion(gz_paths, interactive_deletion)
 
   # Split into files with 5M lines each: day_1.csv -> day_1_[0-39].csv.
+  unzipped_paths = []
   for batch in range(6):
     batch_processes = []
-    unzipped_paths = []
     for day_offset in range(4):
       day = batch * 4 + day_offset
       unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
@@ -330,7 +330,7 @@ def download_criteo1tb(data_dir,
       batch_processes.append(subprocess.Popen(split_cmd, shell=True))
     for p in batch_processes:
       p.communicate()
-    _maybe_prompt_for_deletion(unzipped_paths, interactive_deletion)
+  _maybe_prompt_for_deletion(unzipped_paths, interactive_deletion)
 
 
 def download_cifar(data_dir, framework):
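The split_cmd assembled just above this hunk sits outside the visible diff, so the following is only a plausible reconstruction matching the day_1.csv -> day_1_[0-39].csv naming in the comment; the exact flags are assumptions (GNU split: -l lines per chunk, -d numeric suffixes):

import os


def make_split_cmd(criteo_dir: str, day: int) -> str:
  # Hypothetical reconstruction: cut day_{day}.csv into 5M-line chunks
  # with numeric suffixes, day_{day}_00, day_{day}_01, and so on.
  unzipped_path = os.path.join(criteo_dir, f'day_{day}.csv')
  split_prefix = os.path.join(criteo_dir, f'day_{day}_')
  return f'split -d -l 5000000 {unzipped_path} {split_prefix}'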
@@ -567,14 +567,12 @@ def download_librispeech(dataset_dir, tmp_dir):
   # After extraction the result is a folder named Librispeech containing audio
   # files in .flac format along with transcripts containing name of audio file
   # and corresponding transcription.
-  # tmp_librispeech_dir = os.path.join(dataset_dir, 'librispeech')
-  # extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
-  # final_data_dir = os.path.join(dataset_dir, 'librispeech_processed')
-  tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech_raw')
-  extracted_data_dir = os.path.join(tmp_dir, 'librispeech_extracted')
+  tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech')
+  extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
+  final_data_dir = os.path.join(dataset_dir, 'librispeech')
 
   _maybe_mkdir(tmp_librispeech_dir)
   _maybe_mkdir(final_data_dir)
 
   for split in ['dev', 'test']:
     for version in ['clean', 'other']:
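Spelled out, the new path scheme relates the three directories as follows; the tmp_dir and dataset_dir values here are hypothetical, only the relative layout matters:

import os

tmp_dir = '/data/tmp'  # hypothetical --temp_dir
dataset_dir = '/data'  # hypothetical --data_dir

# Downloads land under tmp_dir; the tarballs unpack into the usual
# top-level 'LibriSpeech' folder inside the download directory.
tmp_librispeech_dir = os.path.join(tmp_dir, 'librispeech')
extracted_data_dir = os.path.join(tmp_librispeech_dir, 'LibriSpeech')
# Preprocessed output goes straight under dataset_dir.
final_data_dir = os.path.join(dataset_dir, 'librispeech')

print(extracted_data_dir)  # /data/tmp/librispeech/LibriSpeech
print(final_data_dir)      # /data/librispeech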
1 change: 1 addition & 0 deletions datasets/librispeech_preprocess.py
@@ -32,6 +32,7 @@
     'train-clean-360': 104014,
     'train-other-500': 148688,
     'test-clean': 2620,
+    'test-other': 2939,
     'dev-clean': 2703,
     'dev-other': 2864,
 }
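These totals read as expected example counts per LibriSpeech split, which makes a cheap sanity check possible after preprocessing; the helper below is a hypothetical sketch, not part of this file:

def check_example_count(split: str, num_processed: int,
                        expected_counts: dict) -> None:
  # Compare the number of examples actually produced for a split with
  # the expected total from the table above; fail loudly on a mismatch.
  expected = expected_counts[split]
  if num_processed != expected:
    raise ValueError(
        f'{split}: expected {expected} examples, got {num_processed}.')

# Usage sketch: check_example_count('test-other', n, <the dict above>)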
1 change: 1 addition & 0 deletions setup.cfg
@@ -115,6 +115,7 @@ jax_core_deps =
   # Todo(kasimbeg): verify if this is necessary after we
   # upgrade jax.
   chex==0.1.7
+  ml_dtypes==0.2.0
 
 # JAX CPU
 jax_cpu =
27 changes: 19 additions & 8 deletions submission_runner.py
@@ -133,6 +133,11 @@
 flags.DEFINE_boolean('save_checkpoints',
                     True,
                     'Whether or not to checkpoint the model at every eval.')
+flags.DEFINE_integer(
+    'rng_seed',
+    None,
+    'Value of rng seed. If None, a random seed will'
+    'be generated from hardware.')
 FLAGS = flags.FLAGS
 USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()
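The flag defaults to None so existing behavior is unchanged; condensed into a standalone sketch, the fallback added further down in this file works as below (note the truthiness check would also treat an explicit --rng_seed=0 as unset):

import os
import struct
from typing import Optional


def resolve_rng_seed(rng_seed: Optional[int]) -> int:
  # Keep a user-supplied seed for reproducibility; otherwise draw
  # 4 bytes of hardware entropy, as the tuning loop below does.
  if not rng_seed:
    rng_seed = struct.unpack('I', os.urandom(4))[0]
  return rng_seed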

@@ -173,6 +178,7 @@ def train_once(
     update_params: spec.UpdateParamsFn,
     data_selection: spec.DataSelectionFn,
     hyperparameters: Optional[spec.Hyperparameters],
+    rng_seed: int,
     rng: spec.RandomState,
     profiler: Profiler,
     max_global_steps: int = None,
@@ -267,10 +273,9 @@ def train_once(
         global_step,
         preemption_count,
         checkpoint_dir=log_dir)
-    meta_data = logger_utils.get_meta_data(workload)
     meta_file_name = os.path.join(log_dir, f'meta_data_{preemption_count}.json')
     logging.info(f'Saving meta data to {meta_file_name}.')
-    logger_utils.write_json(meta_file_name, meta_data)
+    logger_utils.save_meta_data(workload, rng_seed, preemption_count)
     flag_file_name = os.path.join(log_dir, f'flags_{preemption_count}.json')
     logging.info(f'Saving flags to {flag_file_name}.')
     logger_utils.write_json(flag_file_name, flags.FLAGS.flag_values_dict())
@@ -449,7 +454,8 @@ def score_submission_on_workload(workload: spec.Workload,
                                  tuning_search_space: Optional[str] = None,
                                  num_tuning_trials: Optional[int] = None,
                                  log_dir: Optional[str] = None,
-                                 save_checkpoints: Optional[bool] = True):
+                                 save_checkpoints: Optional[bool] = True,
+                                 rng_seed: Optional[int] = None):
   # Expand paths because '~' may not be recognized
   data_dir = os.path.expanduser(data_dir)
   if imagenet_v2_data_dir:
@@ -496,7 +502,8 @@ def score_submission_on_workload(workload: spec.Workload,
     all_metrics = []
     for hi, hyperparameters in enumerate(tuning_search_space):
       # Generate a new seed from hardware sources of randomness for each trial.
-      rng_seed = struct.unpack('I', os.urandom(4))[0]
+      if not rng_seed:
+        rng_seed = struct.unpack('I', os.urandom(4))[0]
       logging.info('Using RNG seed %d', rng_seed)
       rng = prng.PRNGKey(rng_seed)
       # Because we initialize the PRNGKey with only a single 32 bit int, in the
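For reference, a sketch of where the seed ends up. The diff calls prng.PRNGKey, the package's own wrapper; the sketch below uses plain jax.random instead, and the three stream names are hypothetical:

import jax

rng_seed = 1996  # hypothetical; normally --rng_seed or hardware entropy
rng = jax.random.PRNGKey(rng_seed)  # seeded from a single 32-bit int
# Split before use so each consumer gets an independent stream.
data_rng, opt_rng, model_rng = jax.random.split(rng, 3)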
@@ -528,7 +535,9 @@ def score_submission_on_workload(workload: spec.Workload,
                                    data_dir, imagenet_v2_data_dir,
                                    init_optimizer_state,
                                    update_params, data_selection,
-                                   hyperparameters, rng,
+                                   hyperparameters,
+                                   rng_seed,
+                                   rng,
                                    profiler,
                                    max_global_steps,
                                    tuning_dir_name,
@@ -545,7 +554,8 @@ def score_submission_on_workload(workload: spec.Workload,
       logging.info(f'Total number of evals: {num_evals}')
       logging.info('=' * 20)
   else:
-    rng_seed = struct.unpack('q', os.urandom(8))[0]
+    if not rng_seed:
+      rng_seed = struct.unpack('q', os.urandom(8))[0]
     rng = prng.PRNGKey(rng_seed)
     # If the submission is responsible for tuning itself, we only need to run it
     # once and return the total time.
@@ -554,7 +564,7 @@ def score_submission_on_workload(workload: spec.Workload,
         workload, global_batch_size, global_eval_batch_size,
         data_dir, imagenet_v2_data_dir,
         init_optimizer_state, update_params, data_selection,
-        None, rng, profiler, max_global_steps, log_dir,
+        None, rng_seed, rng, profiler, max_global_steps, log_dir,
         save_checkpoints=save_checkpoints)
   return score

@@ -610,7 +620,8 @@ def main(_):
       tuning_search_space=FLAGS.tuning_search_space,
       num_tuning_trials=FLAGS.num_tuning_trials,
       log_dir=logging_dir_path,
-      save_checkpoints=FLAGS.save_checkpoints)
+      save_checkpoints=FLAGS.save_checkpoints,
+      rng_seed=FLAGS.rng_seed)
   logging.info(f'Final {FLAGS.workload} score: {score}')
 
   if FLAGS.profile:
