Skip to content

Commit

Permalink
Crop padding in collate (#343)
Browse files Browse the repository at this point in the history
* update

* update

* return_type

* update
  • Loading branch information
liamhazan authored Feb 11, 2024
1 parent 735f5cc commit 1b50c1d
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions fuse/data/utils/collates.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,41 @@ def _pad_size(value: torch.Tensor, dim: int) -> List[int]:
padded_values.append(padded_value)

return default_collate(padded_values)

@staticmethod
def crop_padding(
input_ids_list: List[torch.Tensor], pad_token_id: int
) -> torch.Tensor:
"""
Crop padding of a batch of input_ids 1D tensors to the minimum length possible.
Args:
input_ids_list (list of torch.Tensor): List of input_ids tensors, where each tensor represents a sequence.
pad_token_id (int): ID of the padding token used in input_ids tensors.
Returns:
torch.Tensor: Batched and cropped input_ids tensor with padding removed to the maximum length.
Example:
>>> input_ids_list = [
... torch.tensor([101, 2054, 2003, 0, 0, 0, 0, 0, 0, 0]),
... torch.tensor([101, 2023, 2003, 1037, 1999, 0, 0, 0, 0, 0]),
... torch.tensor([101, 2002, 0, 0, 0, 0, 0, 0, 0, 0]),
... ]
>>> pad_token_id = 0
>>> cropped_batch = crop_padding_to_max_length(input_ids_list, pad_token_id)
>>> print(cropped_batch)
tensor([[ 101, 2054, 2003, 0, 0],
[ 101, 2023, 2003, 1037, 1999],
[ 101, 2002, 0, 0, 0]])
Note:
This function assumes that the input_ids tensors are already padded, and it crops the sequences
to the minimum length by removing trailing padding tokens.
"""
min_length = min(
len(ids) - (ids == pad_token_id).sum().item() for ids in input_ids_list
)
cropped_sequences = [ids[:min_length] for ids in input_ids_list]
batched_sequences = torch.stack(cropped_sequences, dim=0)
return batched_sequences

0 comments on commit 1b50c1d

Please sign in to comment.