Merge pull request #515 from runame/fix-padding
Fix `shard_and_maybe_pad_np` function
priyakasimbeg authored Sep 21, 2023
2 parents d3fcbb6 + ad64fd1 commit a6d06df
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions algorithmic_efficiency/data_utils.py
@@ -28,8 +28,15 @@ def shard_and_maybe_pad_np(
   inputs = batch['inputs']
   current_batch_size = inputs[0].shape[0] if isinstance(
       inputs, tuple) else inputs.shape[0]
+  if global_batch_size is not None:
+    assert global_batch_size >= current_batch_size, \
+        'global_batch_size must be larger than or equal to current_batch_size.'
+    # Always pad to global_batch_size if it is provided.
+    pad_to_global_batch_size = global_batch_size > current_batch_size
+  else:
+    pad_to_global_batch_size = False
   remainder_size = current_batch_size % local_device_count
-  if remainder_size != 0:
+  if remainder_size != 0 or pad_to_global_batch_size:
   if global_batch_size is not None:
     pad_size = global_batch_size - current_batch_size
   else:
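
Note: the previous guard, `if remainder_size != 0:`, skipped padding whenever the incoming batch size happened to divide `local_device_count`, even if the batch was still smaller than `global_batch_size`. A minimal standalone sketch of the fixed decision logic (the numbers are illustrative, not taken from a real run):

# Illustrative setup: 8 local devices, a global batch size of 64, and an
# evaluation batch that arrived with only 48 examples.
local_device_count = 8
global_batch_size = 64
current_batch_size = 48

if global_batch_size is not None:
  assert global_batch_size >= current_batch_size
  # Always pad to global_batch_size if it is provided.
  pad_to_global_batch_size = global_batch_size > current_batch_size  # True
else:
  pad_to_global_batch_size = False

remainder_size = current_batch_size % local_device_count  # 48 % 8 == 0
# The old condition (remainder_size != 0) is False here, so the batch would
# silently have stayed at 48; the new condition still triggers padding to 64.
if remainder_size != 0 or pad_to_global_batch_size:
  pad_size = global_batch_size - current_batch_size  # 16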
@@ -50,8 +57,8 @@ def _prepare(x):
       x = x._numpy()  # pylint: disable=protected-access

     # Pad if remainder_size != 0 (should only be possible during evaluation).
-    if remainder_size != 0:
-      x = pad(x, pad_size, 'jax', padding_value=padding_value)
+    if remainder_size != 0 or pad_to_global_batch_size:
+      x = pad(x, pad_size, padding_value=padding_value)

     # Reshape (global_batch_size, ...) to
     # (local_device_count, per_device_batch_size, ...).
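
For reference, the reshape described by the trailing comment splits the padded `(global_batch_size, ...)` array into per-device shards. The actual reshape line sits outside this hunk; a hypothetical numpy equivalent:

import numpy as np

local_device_count = 8
x = np.zeros((64, 10))  # (global_batch_size, ...) after padding
x = x.reshape((local_device_count, -1) + x.shape[1:])
print(x.shape)  # (8, 8, 10): (local_device_count, per_device_batch_size, ...)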
@@ -61,21 +68,13 @@ def _prepare(x):
   return jax.tree_map(_prepare, batch)


-def pad(tensor: spec.Tensor,
+def pad(tensor: np.ndarray,
         pad_size: int,
-        framework: str,
-        padding_value: int = 0) -> spec.Tensor:
-  if len(tensor) > 1:
+        padding_value: int = 0) -> np.ndarray:
+  if tensor.ndim > 1:
     pad_size = (pad_size, *tensor.shape[1:])
-  if framework == 'pytorch':
-    padding = torch.full(
-        pad_size, padding_value, dtype=tensor.dtype, device=tensor.device)
-    padded_tensor = torch.cat((tensor, padding), dim=0)
-  elif framework == 'jax':
-    padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
-    padded_tensor = np.concatenate((tensor, padding), axis=0)
-  else:
-    raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.')
+  padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
+  padded_tensor = np.concatenate((tensor, padding), axis=0)
   return padded_tensor

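The simplified helper is numpy-only and drops the `framework` argument; note also that the old `if len(tensor) > 1:` tested the size of the leading axis rather than the rank, which the new `tensor.ndim > 1` check fixes. A runnable usage sketch (function body copied from the `+` lines above; shapes are illustrative):

import numpy as np

def pad(tensor: np.ndarray, pad_size: int, padding_value: int = 0) -> np.ndarray:
  # For multi-dimensional inputs, extend pad_size to a full shape so the
  # padding is appended along the leading (batch) axis only.
  if tensor.ndim > 1:
    pad_size = (pad_size, *tensor.shape[1:])
  padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
  return np.concatenate((tensor, padding), axis=0)

batch = np.ones((48, 10), dtype=np.int32)
padded = pad(batch, pad_size=16)
print(padded.shape)  # (64, 10); rows 48-63 hold padding_value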