diff --git a/returnn/frontend/array_.py b/returnn/frontend/array_.py
index e8efa3640f..348caffbb1 100644
--- a/returnn/frontend/array_.py
+++ b/returnn/frontend/array_.py
@@ -28,6 +28,7 @@
     "pad",
     "cum_concat_step",
     "masked_select",
+    "sequence_mask",
     "pack_padded",
     "gather",
     "slice",
@@ -352,6 +353,8 @@ def masked_select(
     tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Optional[Dim] = None
 ) -> Tuple[Tensor, Dim]:
     """
+    In TF, this is ``boolean_mask``.
+
     :param tensor:
     :param mask:
     :param dims: the order of the dims defines the format. those dims should be exactly the dims of the mask.
@@ -364,8 +367,19 @@ def masked_select(
     return tensor._raw_backend.masked_select(tensor, mask=mask, dims=dims, out_dim=out_dim)
 
 
+def sequence_mask(dims: Sequence[Dim], *, device: Optional[str] = None) -> Tensor:
+    """
+    :param dims:
+    :param device:
+    """
+    assert len(dims) > 0
+    dyn_dims = [d for d in dims if d.need_masking()]
+    assert len(dyn_dims) == 1  # not implemented otherwise yet...
+    return dyn_dims[0].get_mask(dim_order=dims, device=device)
+
+
 def pack_padded(
-    source: Tensor, *, dims: Sequence[Dim], enforce_sorted: bool = True, out_dim: Optional[Dim] = None
+    source: Tensor, *, dims: Sequence[Dim], enforce_sorted: bool = False, out_dim: Optional[Dim] = None
 ) -> Tuple[Tensor, Dim]:
     """
     Like pack_padded_sequence. Usually the sequences are padded when they have different lengths.
@@ -380,10 +394,7 @@ def pack_padded(
     :return: packed tensor, new packed dim
     """
     assert not enforce_sorted  # not implemented yet...
-    assert len(dims) > 0
-    dyn_dims = [d for d in dims if d.need_masking()]
-    assert len(dyn_dims) == 1  # not implemented otherwise yet...
-    mask = source.get_sequence_mask_tensor(source.get_axis_from_description(dyn_dims[0]))
+    mask = rf.sequence_mask(dims, device=source.device)
     return rf.masked_select(source, mask=mask, dims=dims, out_dim=out_dim)
 
 