From 1d325371b7ca447bccfbafbd8fed149eab96dd9c Mon Sep 17 00:00:00 2001
From: Juhan Bae
Date: Tue, 19 Sep 2023 14:51:49 -0400
Subject: [PATCH] minor

---
 .../criteo1tb/criteo1tb_pytorch/workload.py |  2 +-
 .../workloads/criteo1tb/input_pipeline.py   | 12 ++++++------
 2 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
index 83a850c71..f24ddb392 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -22,7 +22,7 @@ class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
 
   @property
   def eval_batch_size(self) -> int:
-    return 8192
+    return 8192 * 8
 
   def _per_example_sigmoid_binary_cross_entropy(
       self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
diff --git a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py b/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
index 00163757a..b9f8a9399 100644
--- a/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
+++ b/algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
@@ -127,16 +127,16 @@ def get_criteo1tb_dataset(split: str,
   # print("Amount2:")
   # print(len(list(ds)))
 
-  if is_training:
-    ds = ds.repeat()
+  # if is_training:
+  #   ds = ds.repeat()
   ds = ds.prefetch(10)
-  if num_batches is not None:
-    ds = ds.take(num_batches)
+  # if num_batches is not None:
+  #   ds = ds.take(num_batches)
   # We do not use ds.cache() because the dataset is so large that it would OOM.
-  if repeat_final_dataset:
-    ds = ds.repeat()
+  # if repeat_final_dataset:
+  #   ds = ds.repeat()
 
   ds = map(
       functools.partial(