Skip to content

Commit

Permalink
Add num_batch configs
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Sep 25, 2023
1 parent cc8b820 commit 86ad0af
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ def _build_input_queue(
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
del num_batches

not_train = split != 'train'
per_device_batch_size = int(global_batch_size / N_GPUS)

Expand All @@ -149,6 +147,7 @@ def _build_input_queue(
split=split,
data_dir=data_dir,
global_batch_size=global_batch_size,
num_batches=num_batches,
repeat_final_dataset=repeat_final_dataset)
weights = None
while True:
Expand Down
3 changes: 2 additions & 1 deletion algorithmic_efficiency/workloads/criteo1tb/workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,13 @@ def _build_input_queue(
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
del cache
del num_batches
ds = input_pipeline.get_criteo1tb_dataset(
split=split,
shuffle_rng=data_rng,
data_dir=data_dir,
num_dense_features=self.num_dense_features,
global_batch_size=global_batch_size,
num_batches=num_batches,
repeat_final_dataset=repeat_final_dataset)

for batch in iter(ds):
Expand Down Expand Up @@ -132,6 +132,7 @@ def _eval_model_on_split(self,
split=split,
data_dir=data_dir,
global_batch_size=global_batch_size,
num_batches=num_batches,
repeat_final_dataset=True)
loss = 0.0
for _ in range(num_batches):
Expand Down

0 comments on commit 86ad0af

Please sign in to comment.