Skip to content

Commit

Permalink
padding
Browse files Browse the repository at this point in the history
  • Loading branch information
priyakasimbeg committed Sep 29, 2023
1 parent 65054cf commit beb1ff7
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions algorithmic_efficiency/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def shard_and_maybe_pad_np(
else:
pad_to_global_batch_size = False
remainder_size = current_batch_size % local_device_count
# if pad_to_global_batch_size:
if remainder_size != 0:
if pad_to_global_batch_size:
# if remainder_size != 0:
logging.info("PADDDINGGGGGGG")
logging.info(f"current batch size {current_batch_size}")
if global_batch_size is not None:
Expand All @@ -63,8 +63,8 @@ def _prepare(x):
x = x._numpy() # pylint: disable=protected-access

# Pad if remainder_size != 0 (should only be possible during evaluation).
# if pad_to_global_batch_size:
if remainder_size != 0:
if pad_to_global_batch_size:
# if remainder_size != 0:
logging.info("PADDDINGGGGGG in _prepare")
logging.info(f"current batch size {current_batch_size}")
x = pad(x, pad_size, padding_value=padding_value)
Expand All @@ -73,7 +73,7 @@ def _prepare(x):
# (local_device_count, per_device_batch_size, ...).
# Assumes that `global_batch_size % local_device_count == 0`.
return x.reshape((local_device_count, -1, *x.shape[1:]))
if remainder_size != 0:
if pad_to_global_batch_size != 0:
print(batch.keys())
return jax.tree_map(_prepare, batch)

Expand Down

0 comments on commit beb1ff7

Please sign in to comment.