diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 38744716b..14e3c7c6c 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -28,8 +28,15 @@ def shard_and_maybe_pad_np( inputs = batch['inputs'] current_batch_size = inputs[0].shape[0] if isinstance( inputs, tuple) else inputs.shape[0] + if global_batch_size is not None: + assert global_batch_size >= current_batch_size, \ + 'global_batch_size must be larger than or equal to current_batch_size.' + # Always pad to global_batch_size if it is provided. + pad_to_global_batch_size = global_batch_size > current_batch_size + else: + pad_to_global_batch_size = False remainder_size = current_batch_size % local_device_count - if remainder_size != 0: + if remainder_size != 0 or pad_to_global_batch_size: if global_batch_size is not None: pad_size = global_batch_size - current_batch_size else: @@ -50,8 +57,8 @@ def _prepare(x): x = x._numpy() # pylint: disable=protected-access # Pad if remainder_size != 0 (should only be possible during evaluation). - if remainder_size != 0: - x = pad(x, pad_size, 'jax', padding_value=padding_value) + if remainder_size != 0 or pad_to_global_batch_size: + x = pad(x, pad_size, padding_value=padding_value) # Reshape (global_batch_size, ...) to # (local_device_count, per_device_batch_size, ...). @@ -61,21 +68,13 @@ def _prepare(x): return jax.tree_map(_prepare, batch) -def pad(tensor: spec.Tensor, +def pad(tensor: np.ndarray, pad_size: int, - framework: str, - padding_value: int = 0) -> spec.Tensor: - if len(tensor) > 1: + padding_value: int = 0) -> np.ndarray: + if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) - if framework == 'pytorch': - padding = torch.full( - pad_size, padding_value, dtype=tensor.dtype, device=tensor.device) - padded_tensor = torch.cat((tensor, padding), dim=0) - elif framework == 'jax': - padding = np.full(pad_size, padding_value, dtype=tensor.dtype) - padded_tensor = np.concatenate((tensor, padding), axis=0) - else: - raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.') + padding = np.full(pad_size, padding_value, dtype=tensor.dtype) + padded_tensor = np.concatenate((tensor, padding), axis=0) return padded_tensor