Skip to content

Commit

Permalink
RF sequence_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 16, 2023
1 parent 48e7244 commit 3f6d1ed
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions returnn/frontend/array_.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"pad",
"cum_concat_step",
"masked_select",
"sequence_mask",
"pack_padded",
"gather",
"slice",
Expand Down Expand Up @@ -352,6 +353,8 @@ def masked_select(
tensor: Tensor, *, mask: Tensor, dims: Sequence[Dim], out_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
"""
In TF, this is ``boolean_mask``.
:param tensor:
:param mask:
:param dims: the order of the dims defines the format. those dims should be exactly the dims of the mask.
Expand All @@ -364,8 +367,19 @@ def masked_select(
return tensor._raw_backend.masked_select(tensor, mask=mask, dims=dims, out_dim=out_dim)


def sequence_mask(dims: Sequence[Dim], *, device: Optional[str] = None) -> Tensor:
"""
:param dims:
:param device:
"""
assert len(dims) > 0
dyn_dims = [d for d in dims if d.need_masking()]
assert len(dyn_dims) == 1 # not implemented otherwise yet...
return dyn_dims[0].get_mask(dim_order=dims, device=device)


def pack_padded(
source: Tensor, *, dims: Sequence[Dim], enforce_sorted: bool = True, out_dim: Optional[Dim] = None
source: Tensor, *, dims: Sequence[Dim], enforce_sorted: bool = False, out_dim: Optional[Dim] = None
) -> Tuple[Tensor, Dim]:
"""
Like pack_padded_sequence. Usually the sequences are padded when they have different lengths.
Expand All @@ -380,10 +394,7 @@ def pack_padded(
:return: packed tensor, new packed dim
"""
assert not enforce_sorted # not implemented yet...
assert len(dims) > 0
dyn_dims = [d for d in dims if d.need_masking()]
assert len(dyn_dims) == 1 # not implemented otherwise yet...
mask = source.get_sequence_mask_tensor(source.get_axis_from_description(dyn_dims[0]))
mask = rf.sequence_mask(dims, device=source.device)
return rf.masked_select(source, mask=mask, dims=dims, out_dim=out_dim)


Expand Down

0 comments on commit 3f6d1ed

Please sign in to comment.