From beb1ff7e09cdabb10948966afa9ff74f0c8feb34 Mon Sep 17 00:00:00 2001
From: priyakasimbeg
Date: Fri, 29 Sep 2023 02:11:41 +0000
Subject: [PATCH] padding

---
 algorithmic_efficiency/data_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py
index 9efbd6e21..40e2d7542 100644
--- a/algorithmic_efficiency/data_utils.py
+++ b/algorithmic_efficiency/data_utils.py
@@ -37,8 +37,8 @@ def shard_and_maybe_pad_np(
   else:
     pad_to_global_batch_size = False
   remainder_size = current_batch_size % local_device_count
-  # if pad_to_global_batch_size:
-  if remainder_size != 0:
+  if pad_to_global_batch_size:
+  # if remainder_size != 0:
     logging.info("PADDDINGGGGGGG")
     logging.info(f"current batch size {current_batch_size}")
     if global_batch_size is not None:
@@ -63,8 +63,8 @@ def _prepare(x):
       x = x._numpy()  # pylint: disable=protected-access
 
     # Pad if remainder_size != 0 (should only be possible during evaluation).
-    # if pad_to_global_batch_size:
-    if remainder_size != 0:
+    if pad_to_global_batch_size:
+    # if remainder_size != 0:
       logging.info("PADDDINGGGGGG in _prepare")
       logging.info(f"current batch size {current_batch_size}")
       x = pad(x, pad_size, padding_value=padding_value)
@@ -73,7 +73,7 @@ def _prepare(x):
     # (local_device_count, per_device_batch_size, ...).
     # Assumes that `global_batch_size % local_device_count == 0`.
     return x.reshape((local_device_count, -1, *x.shape[1:]))
 
-  if remainder_size != 0:
+  if pad_to_global_batch_size != 0:
     print(batch.keys())
   return jax.tree_map(_prepare, batch)