From 3f6d1ed7d8e4a84a65ba57a97541dcb25cc036f1 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Thu, 16 Nov 2023 16:47:18 +0000 Subject: [PATCH] RF sequence_mask --- returnn/frontend/array_.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py index e8efa3640f..348caffbb1 100644 --- a/returnn/frontend/array_.py +++ b/returnn/frontend/array_.py @@ -28,6 +28,7 @@ "pad", "cum_concat_step", "masked_select", + "sequence_mask", "pack_padded", "gather", "slice", @@ -352,6 +353,8 @@ def masked_select( tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Optional[Dim] = None ) -> Tuple[Tensor, Dim]: """ + In TF, this is ``boolean_mask``. + :param tensor: :param mask: :param dims: the order of the dims defines the format. those dims should be exactly the dims of the mask. @@ -364,8 +367,19 @@ def masked_select( return tensor._raw_backend.masked_select(tensor, mask=mask, dims=dims, out_dim=out_dim) +def sequence_mask(dims: Sequence[Dim], *, device: Optional[str] = None) -> Tensor: + """ + :param dims: + :param device: + """ + assert len(dims) > 0 + dyn_dims = [d for d in dims if d.need_masking()] + assert len(dyn_dims) == 1 # not implemented otherwise yet... + return dyn_dims[0].get_mask(dim_order=dims, device=device) + + def pack_padded( - source: Tensor, *, dims: Sequence[Dim], enforce_sorted: bool = True, out_dim: Optional[Dim] = None + source: Tensor, *, dims: Sequence[Dim], enforce_sorted: bool = False, out_dim: Optional[Dim] = None ) -> Tuple[Tensor, Dim]: """ Like pack_padded_sequence. Usually the sequences are padded when they have different lengths. @@ -380,10 +394,7 @@ def pack_padded( :return: packed tensor, new packed dim """ assert not enforce_sorted # not implemented yet... - assert len(dims) > 0 - dyn_dims = [d for d in dims if d.need_masking()] - assert len(dyn_dims) == 1 # not implemented otherwise yet... - mask = source.get_sequence_mask_tensor(source.get_axis_from_description(dyn_dims[0])) + mask = rf.sequence_mask(dims, device=source.device) return rf.masked_select(source, mask=mask, dims=dims, out_dim=out_dim)