From 2f76cb9e324258fe2e974379faa2f5ed4702255a Mon Sep 17 00:00:00 2001
From: runame
Date: Wed, 13 Sep 2023 15:56:25 +0200
Subject: [PATCH 1/3] Simplify pad function

---
 algorithmic_efficiency/data_utils.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py
index 38744716b..96fc699c0 100644
--- a/algorithmic_efficiency/data_utils.py
+++ b/algorithmic_efficiency/data_utils.py
@@ -51,7 +51,7 @@ def _prepare(x):
 
     # Pad if remainder_size != 0 (should only be possible during evaluation).
     if remainder_size != 0:
-      x = pad(x, pad_size, 'jax', padding_value=padding_value)
+      x = pad(x, pad_size, padding_value=padding_value)
 
     # Reshape (global_batch_size, ...) to
     # (local_device_count, per_device_batch_size, ...).
@@ -61,21 +61,13 @@ def _prepare(x):
   return jax.tree_map(_prepare, batch)
 
 
-def pad(tensor: spec.Tensor,
+def pad(tensor: np.ndarray,
         pad_size: int,
-        framework: str,
-        padding_value: int = 0) -> spec.Tensor:
+        padding_value: int = 0) -> np.ndarray:
   if len(tensor) > 1:
     pad_size = (pad_size, *tensor.shape[1:])
-  if framework == 'pytorch':
-    padding = torch.full(
-        pad_size, padding_value, dtype=tensor.dtype, device=tensor.device)
-    padded_tensor = torch.cat((tensor, padding), dim=0)
-  elif framework == 'jax':
-    padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
-    padded_tensor = np.concatenate((tensor, padding), axis=0)
-  else:
-    raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.')
+  padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
+  padded_tensor = np.concatenate((tensor, padding), axis=0)
   return padded_tensor

From 35c873615d65cc1c435bf2b2bd3dce51dcd17de4 Mon Sep 17 00:00:00 2001
From: runame
Date: Thu, 14 Sep 2023 18:27:20 +0200
Subject: [PATCH 2/3] Always pad to global_batch_size when it is provided

---
 algorithmic_efficiency/data_utils.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py
index 96fc699c0..245d3768e 100644
--- a/algorithmic_efficiency/data_utils.py
+++ b/algorithmic_efficiency/data_utils.py
@@ -28,8 +28,15 @@ def shard_and_maybe_pad_np(
   inputs = batch['inputs']
   current_batch_size = inputs[0].shape[0] if isinstance(
       inputs, tuple) else inputs.shape[0]
+  if global_batch_size is not None:
+    assert global_batch_size >= current_batch_size, \
+        'global_batch_size must be larger than or equal to current_batch_size.'
+    # Always pad to global_batch_size if it is provided.
+    pad_to_global_batch_size = global_batch_size > current_batch_size
+  else:
+    pad_to_global_batch_size = False
   remainder_size = current_batch_size % local_device_count
-  if remainder_size != 0:
+  if remainder_size != 0 or pad_to_global_batch_size:
     if global_batch_size is not None:
       pad_size = global_batch_size - current_batch_size
     else:
@@ -50,7 +57,7 @@ def _prepare(x):
       x = x._numpy()  # pylint: disable=protected-access
 
     # Pad if remainder_size != 0 (should only be possible during evaluation).
-    if remainder_size != 0:
+    if remainder_size != 0 or pad_to_global_batch_size:
       x = pad(x, pad_size, padding_value=padding_value)
 
     # Reshape (global_batch_size, ...) to
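Note for reviewers: the snippet below is a standalone sketch of the padding behaviour these first two patches converge on, using made-up values (a 6-example eval batch, 4 local devices, global_batch_size=8) and np.pad standing in for the patched pad() helper; it is an illustration, not code from the repository.

import numpy as np

# Made-up evaluation batch: 6 examples of 3 features on 4 local devices,
# with a requested global_batch_size of 8.
inputs = np.ones((6, 3), dtype=np.int32)
global_batch_size = 8
local_device_count = 4

current_batch_size = inputs.shape[0]
if global_batch_size is not None:
  assert global_batch_size >= current_batch_size
  # Always pad to global_batch_size if it is provided.
  pad_to_global_batch_size = global_batch_size > current_batch_size
else:
  pad_to_global_batch_size = False
remainder_size = current_batch_size % local_device_count

if remainder_size != 0 or pad_to_global_batch_size:
  if global_batch_size is not None:
    pad_size = global_batch_size - current_batch_size
  else:
    pad_size = local_device_count - remainder_size
  # np.pad stands in for the patched pad() helper.
  inputs = np.pad(inputs, ((0, pad_size), (0, 0)), constant_values=0)

# The padded batch now splits evenly across the local devices.
print(inputs.shape)                    # (8, 3)
print(inputs.reshape(4, -1, 3).shape)  # (4, 2, 3)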
From ad64fd18c9e3907f32347e49953a7067293fe36b Mon Sep 17 00:00:00 2001
From: runame
Date: Thu, 14 Sep 2023 18:28:07 +0200
Subject: [PATCH 3/3] Fix pad_size in pad function

---
 algorithmic_efficiency/data_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py
index 245d3768e..14e3c7c6c 100644
--- a/algorithmic_efficiency/data_utils.py
+++ b/algorithmic_efficiency/data_utils.py
@@ -71,7 +71,7 @@ def _prepare(x):
 def pad(tensor: np.ndarray,
         pad_size: int,
         padding_value: int = 0) -> np.ndarray:
-  if len(tensor) > 1:
+  if tensor.ndim > 1:
     pad_size = (pad_size, *tensor.shape[1:])
   padding = np.full(pad_size, padding_value, dtype=tensor.dtype)
   padded_tensor = np.concatenate((tensor, padding), axis=0)
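For context on the one-line fix above, here is a small self-contained illustration of why the check moves from len(tensor) > 1 to tensor.ndim > 1; the (1, 10) batch and pad_size of 3 are made-up values, not taken from the patch.

import numpy as np

# A batch holding a single multi-dimensional example.
single_example_batch = np.zeros((1, 10))
pad_size = 3

# Old check: len(single_example_batch) > 1 is False, so pad_size stayed a
# scalar, np.full produced a 1-D (3,) padding array, and np.concatenate
# failed against the 2-D (1, 10) batch.
# New check: .ndim > 1 is True, so pad_size becomes (3, 10) and the
# concatenation along the batch axis succeeds.
if single_example_batch.ndim > 1:
  pad_size = (pad_size, *single_example_batch.shape[1:])
padding = np.full(pad_size, 0.0, dtype=single_example_batch.dtype)
padded = np.concatenate((single_example_batch, padding), axis=0)
print(padded.shape)  # (4, 10)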