From 039018e06a6b2091b585e2a7dcb916959bb8f441 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Fri, 29 Sep 2023 00:12:27 +0000 Subject: [PATCH] disable pad_to_global_batch_size check --- algorithmic_efficiency/data_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 14e3c7c6c..9d7014e6b 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -36,7 +36,8 @@ def shard_and_maybe_pad_np( else: pad_to_global_batch_size = False remainder_size = current_batch_size % local_device_count - if remainder_size != 0 or pad_to_global_batch_size: + # if remainder_size != 0 or pad_to_global_batch_size: + if remainder_size != 0: if global_batch_size is not None: pad_size = global_batch_size - current_batch_size else: @@ -57,7 +58,8 @@ def _prepare(x): x = x._numpy() # pylint: disable=protected-access # Pad if remainder_size != 0 (should only be possible during evaluation). - if remainder_size != 0 or pad_to_global_batch_size: + # if remainder_size != 0 or pad_to_global_batch_size: + if remainder_size != 0: x = pad(x, pad_size, padding_value=padding_value) # Reshape (global_batch_size, ...) to