Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Dec 5, 2023
1 parent b419962 commit 5303092
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def _reset_cuda_mem():

def train_once(
workload: spec.Workload,
workload_name: str,
global_batch_size: int,
global_eval_batch_size: int,
data_dir: str,
Expand Down Expand Up @@ -559,7 +560,8 @@ def score_submission_on_workload(workload: spec.Workload,
with profiler.profile('Train'):
if 'imagenet' not in workload_name:
imagenet_v2_data_dir = None
timing, metrics = train_once(workload, global_batch_size,
timing, metrics = train_once(workload, workload_name,
global_batch_size,
global_eval_batch_size,
data_dir, imagenet_v2_data_dir,
init_optimizer_state,
Expand Down Expand Up @@ -596,7 +598,7 @@ def score_submission_on_workload(workload: spec.Workload,
logger_utils.makedir(log_dir)
with profiler.profile('Train'):
score, _ = train_once(
workload, global_batch_size, global_eval_batch_size,
workload, workload_name, global_batch_size, global_eval_batch_size,
data_dir, imagenet_v2_data_dir,
init_optimizer_state, update_params, data_selection,
None, rng_seed, rng, profiler, max_global_steps, log_dir,
Expand Down

0 comments on commit 5303092

Please sign in to comment.