Skip to content

Commit

Permalink
reverse padding fixes
Browse files (browse the repository at this point in the history)
  • Loading branch information
priyakasimbeg committed Sep 28, 2023
1 parent fc07904 commit 1ea08b0
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions algorithmic_efficiency/data_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -28,15 +28,8 @@ def shard_and_maybe_pad_np(
inputs = batch['inputs']
current_batch_size = inputs[0].shape[0] if isinstance(
inputs, tuple) else inputs.shape[0]
if global_batch_size is not None:
assert global_batch_size >= current_batch_size, \
'global_batch_size must be larger than or equal to current_batch_size.'
# Always pad to global_batch_size if it is provided.
pad_to_global_batch_size = global_batch_size > current_batch_size
else:
pad_to_global_batch_size = False
remainder_size = current_batch_size % local_device_count
if remainder_size != 0 or pad_to_global_batch_size:
if remainder_size != 0:
if global_batch_size is not None:
pad_size = global_batch_size - current_batch_size
else:
Expand All @@ -57,8 +50,8 @@ def _prepare(x):
x = x._numpy() # pylint: disable=protected-access

# Pad if remainder_size != 0 (should only be possible during evaluation).
if remainder_size != 0 or pad_to_global_batch_size:
x = pad(x, pad_size, padding_value=padding_value)
if remainder_size != 0:
x = pad(x, pad_size, 'jax', padding_value=padding_value)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
Expand All @@ -68,13 +61,21 @@ def _prepare(x):
return jax.tree_map(_prepare, batch)


def pad(tensor: np.ndarray,
def pad(tensor: spec.Tensor,
pad_size: int,
padding_value: int = 0) -> np.ndarray:
if tensor.ndim > 1:
framework: str,
padding_value: int = 0) -> spec.Tensor:
if len(tensor) > 1:
pad_size = (pad_size, *tensor.shape[1:])
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
if framework == 'pytorch':
padding = torch.full(
pad_size, padding_value, dtype=tensor.dtype, device=tensor.device)
padded_tensor = torch.cat((tensor, padding), dim=0)
elif framework == 'jax':
padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
padded_tensor = np.concatenate((tensor, padding), axis=0)
else:
raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.')
return padded_tensor


Expand Down

0 comments on commit 1ea08b0

Please sign in to comment.