Fix shard_and_maybe_pad_np function #515

Merged Sep 21, 2023 (3 commits)
algorithmic_efficiency/data_utils.py (31 changes: 15 additions & 16 deletions)
```diff
@@ -28,8 +28,15 @@ def shard_and_maybe_pad_np(
   inputs = batch['inputs']
   current_batch_size = inputs[0].shape[0] if isinstance(
       inputs, tuple) else inputs.shape[0]
+  if global_batch_size is not None:
+    assert global_batch_size >= current_batch_size, \
+        'global_batch_size must be larger than or equal to current_batch_size.'
+    # Always pad to global_batch_size if it is provided.
+    pad_to_global_batch_size = global_batch_size > current_batch_size
+  else:
+    pad_to_global_batch_size = False
   remainder_size = current_batch_size % local_device_count
-  if remainder_size != 0:
+  if remainder_size != 0 or pad_to_global_batch_size:
     if global_batch_size is not None:
       pad_size = global_batch_size - current_batch_size
     else:
```
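To see what the new flag buys, here is a minimal standalone sketch of the control flow (variable names mirror `shard_and_maybe_pad_np`; the concrete values are illustrative, not from the PR). The old `remainder_size != 0` check alone skipped padding whenever the batch divided evenly across devices, even if it was still smaller than `global_batch_size`:

```python
# Toy values (assumed for illustration): a batch of 4 examples that already
# divides evenly across 4 local devices, but a requested global batch size of 8.
current_batch_size = 4
local_device_count = 4
global_batch_size = 8

if global_batch_size is not None:
  assert global_batch_size >= current_batch_size
  # Pad whenever the incoming batch is smaller than the requested global size.
  pad_to_global_batch_size = global_batch_size > current_batch_size  # True
else:
  pad_to_global_batch_size = False

remainder_size = current_batch_size % local_device_count  # 4 % 4 == 0
# The old condition (remainder_size != 0) would skip padding here;
# the new condition still pads the batch up to global_batch_size.
if remainder_size != 0 or pad_to_global_batch_size:
  if global_batch_size is not None:
    pad_size = global_batch_size - current_batch_size  # 8 - 4 = 4
  else:
    pad_size = local_device_count - remainder_size

print(pad_size)  # 4
```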
```diff
@@ -50,8 +57,8 @@ def _prepare(x):
       x = x._numpy()  # pylint: disable=protected-access

     # Pad if remainder_size != 0 (should only be possible during evaluation).
-    if remainder_size != 0:
-      x = pad(x, pad_size, 'jax', padding_value=padding_value)
+    if remainder_size != 0 or pad_to_global_batch_size:
+      x = pad(x, pad_size, padding_value=padding_value)

     # Reshape (global_batch_size, ...) to
     # (local_device_count, per_device_batch_size, ...).
```
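For reference, a self-contained NumPy sketch of what `_prepare` does to a single array after this change: pad the leading axis, then reshape to `(local_device_count, per_device_batch_size, ...)`. Shapes are toy values, not from the PR:

```python
import numpy as np

local_device_count = 4
x = np.arange(4 * 3, dtype=np.float32).reshape(4, 3)  # batch of 4, feature dim 3

# Pad four zero rows so the 8-example global batch splits across 4 devices.
pad_size = 4
padding = np.full((pad_size, *x.shape[1:]), 0, dtype=x.dtype)
x = np.concatenate((x, padding), axis=0)  # shape (8, 3)

# Reshape (global_batch_size, ...) to
# (local_device_count, per_device_batch_size, ...).
sharded = x.reshape((local_device_count, -1, *x.shape[1:]))
print(sharded.shape)  # (4, 2, 3)
```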
```diff
@@ -61,21 +68,13 @@ def _prepare(x):
   return jax.tree_map(_prepare, batch)


-def pad(tensor: spec.Tensor,
+def pad(tensor: np.ndarray,
         pad_size: int,
-        framework: str,
-        padding_value: int = 0) -> spec.Tensor:
-  if len(tensor) > 1:
+        padding_value: int = 0) -> np.ndarray:
+  if tensor.ndim > 1:
     pad_size = (pad_size, *tensor.shape[1:])
-  if framework == 'pytorch':
-    padding = torch.full(
-        pad_size, padding_value, dtype=tensor.dtype, device=tensor.device)
-    padded_tensor = torch.cat((tensor, padding), dim=0)
-  elif framework == 'jax':
-    padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
-    padded_tensor = np.concatenate((tensor, padding), axis=0)
-  else:
-    raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.')
+  padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
+  padded_tensor = np.concatenate((tensor, padding), axis=0)
   return padded_tensor
```
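Note that the guard also changed from `len(tensor) > 1` (the size of the leading axis) to `tensor.ndim > 1` (the rank), so a multi-dimensional batch containing a single example now takes the correct branch. Below is a standalone version of the simplified function with an illustrative call (adapted from the diff; the example values are assumptions):

```python
import numpy as np


def pad(tensor: np.ndarray, pad_size: int, padding_value: int = 0) -> np.ndarray:
  # For multi-dimensional inputs, pad only along the leading (batch) axis.
  if tensor.ndim > 1:
    pad_size = (pad_size, *tensor.shape[1:])
  padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
  return np.concatenate((tensor, padding), axis=0)


batch = np.ones((4, 3), dtype=np.float32)
print(pad(batch, 4).shape)  # (8, 3): four rows of zeros appended
```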

