From 5303092a62cd2c4ae5e3ca28a563d114b4c03e6a Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 5 Dec 2023 06:28:26 +0000 Subject: [PATCH] fix --- submission_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 152da65fd..1bae622a0 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -184,6 +184,7 @@ def _reset_cuda_mem(): def train_once( workload: spec.Workload, + workload_name: str, global_batch_size: int, global_eval_batch_size: int, data_dir: str, @@ -559,7 +560,8 @@ def score_submission_on_workload(workload: spec.Workload, with profiler.profile('Train'): if 'imagenet' not in workload_name: imagenet_v2_data_dir = None - timing, metrics = train_once(workload, global_batch_size, + timing, metrics = train_once(workload, workload_name, + global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, @@ -596,7 +598,7 @@ def score_submission_on_workload(workload: spec.Workload, logger_utils.makedir(log_dir) with profiler.profile('Train'): score, _ = train_once( - workload, global_batch_size, global_eval_batch_size, + workload, workload_name, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, init_optimizer_state, update_params, data_selection, None, rng_seed, rng, profiler, max_global_steps, log_dir,