From 1b50c1d062fefa6be4d15c9adf76426ae1d3d95f Mon Sep 17 00:00:00 2001
From: Liam Hazan <68565394+liamhazan@users.noreply.github.com>
Date: Sun, 11 Feb 2024 14:33:58 +0200
Subject: [PATCH] Crop padding in collate (#343)

* update

* update

* return_type

* update
---
 fuse/data/utils/collates.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/fuse/data/utils/collates.py b/fuse/data/utils/collates.py
index ea5a0049..a07b6f6b 100644
--- a/fuse/data/utils/collates.py
+++ b/fuse/data/utils/collates.py
@@ -190,3 +190,41 @@ def _pad_size(value: torch.Tensor, dim: int) -> List[int]:
             padded_values.append(padded_value)
 
         return default_collate(padded_values)
+
+    @staticmethod
+    def crop_padding(
+        input_ids_list: List[torch.Tensor], pad_token_id: int
+    ) -> torch.Tensor:
+        """
+        Crop the padding of a batch of 1D input_ids tensors down to the length of the longest actual sequence.
+
+        Args:
+            input_ids_list (List[torch.Tensor]): List of input_ids tensors, where each tensor represents a sequence.
+            pad_token_id (int): ID of the padding token used in the input_ids tensors.
+
+        Returns:
+            torch.Tensor: Batched input_ids tensor, cropped to the length of the longest non-padded sequence.
+
+        Example:
+            >>> input_ids_list = [
+            ...     torch.tensor([101, 2054, 2003, 0, 0, 0, 0, 0, 0, 0]),
+            ...     torch.tensor([101, 2023, 2003, 1037, 1999, 0, 0, 0, 0, 0]),
+            ...     torch.tensor([101, 2002, 0, 0, 0, 0, 0, 0, 0, 0]),
+            ... ]
+            >>> pad_token_id = 0
+            >>> cropped_batch = CollateDefault.crop_padding(input_ids_list, pad_token_id)
+            >>> print(cropped_batch)
+            tensor([[ 101, 2054, 2003,    0,    0],
+                    [ 101, 2023, 2003, 1037, 1999],
+                    [ 101, 2002,    0,    0,    0]])
+
+        Note:
+            This function assumes that the input_ids tensors are right-padded (all padding tokens are
+            trailing), and it crops the batch by removing the trailing padding shared by all sequences.
+        """
+        max_length = max(
+            len(ids) - (ids == pad_token_id).sum().item() for ids in input_ids_list
+        )
+        cropped_sequences = [ids[:max_length] for ids in input_ids_list]
+        batched_sequences = torch.stack(cropped_sequences, dim=0)
+        return batched_sequences
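
To make the call pattern concrete, here is a minimal usage sketch (not part of the patch). It assumes the method is exposed as a staticmethod on CollateDefault in fuse.data.utils.collates, as the file context suggests; the token IDs below are made-up tokenizer output.

    import torch

    # assumed import path, based on the file touched by this patch
    from fuse.data.utils.collates import CollateDefault

    # hypothetical right-padded sequences, e.g. produced by a tokenizer
    # with pad_token_id = 0 and a fixed maximum length of 8
    input_ids_list = [
        torch.tensor([101, 7592, 2088, 102, 0, 0, 0, 0]),
        torch.tensor([101, 2129, 102, 0, 0, 0, 0, 0]),
    ]

    batch = CollateDefault.crop_padding(input_ids_list, pad_token_id=0)
    print(batch.shape)  # torch.Size([2, 4]) -- cropped to the longest real sequence

Because all real tokens are preserved and only shared trailing padding is dropped, this is a cheap way to shrink batches that were padded to a fixed maximum length before being collated.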