Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
pomonam committed Sep 19, 2023
1 parent 1d82836 commit 1d32537
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):

@property
def eval_batch_size(self) -> int:
return 8192
return 8192 * 8

def _per_example_sigmoid_binary_cross_entropy(
self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
Expand Down
12 changes: 6 additions & 6 deletions algorithmic_efficiency/workloads/criteo1tb/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,16 @@ def get_criteo1tb_dataset(split: str,
# print("Amount2:")
# print(len(list(ds)))

if is_training:
ds = ds.repeat()
# if is_training:
# ds = ds.repeat()
ds = ds.prefetch(10)

if num_batches is not None:
ds = ds.take(num_batches)
# if num_batches is not None:
# ds = ds.take(num_batches)

# We do not use ds.cache() because the dataset is so large that it would OOM.
if repeat_final_dataset:
ds = ds.repeat()
# if repeat_final_dataset:
# ds = ds.repeat()

ds = map(
functools.partial(
Expand Down

0 comments on commit 1d32537

Please sign in to comment.