Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RF cum_concat_step simplify and other RF things #1665

Merged
merged 13 commits into from
Dec 13, 2024
Merged
24 changes: 9 additions & 15 deletions returnn/frontend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,21 +496,6 @@ def pad(
"""
raise NotImplementedError

@staticmethod
def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor:
"""
Concatenates all previous frames over a time-axis.
See RETURNN :class:`CumConcatLayer` for details.

:param source: same dims as prev_accum except for the accum axis
:param prev_accum: previous accumulated tensor, shape {..., axis}
:param axis: the axis to accumulate over
:param out_spatial_dim: the spatial dim of the output will be this dim. like axis+1.
:return: accumulated. accumulated shape {..., out_spatial_dim},
same shape as prev_accum with axis replaced by out_spatial_dim.
"""
raise NotImplementedError

@staticmethod
def stack(sources: Sequence[Tensor], *, out_dim: Dim) -> Tensor:
"""
Expand Down Expand Up @@ -1095,6 +1080,15 @@ def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Dim) -> Tensor:
out.raw_tensor = source.raw_tensor
return out

@staticmethod
def set_sparse_dim(source: Tensor, sparse_dim: Dim) -> Tensor:
"""set sparse dim"""
# This default implementation works fine as long as the backend
# does not have special treatments of Tensor and dim tags itself (like TF net dict backend).
out = source.copy()
out.sparse_dim = sparse_dim
return out

_AllowedReduceModes = {"sum", "max", "min", "mean", "logsumexp", "any", "all", "argmin", "argmax"}

@staticmethod
Expand Down
18 changes: 14 additions & 4 deletions returnn/frontend/array_.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def concat(
*sources: Tuple[Tensor, Dim],
allow_broadcast: bool = False,
out_dim: Optional[Dim] = None,
handle_dynamic_dims: Optional[bool] = None,
) -> Tuple[Tensor, Dim]:
"""
Concatenates multiple sources in the specified dimension.
Expand All @@ -376,6 +377,7 @@ def concat(
:param sources: list of (tensor, dim) pairs. dim is the axis to concatenate on.
:param allow_broadcast: if True, the sources can have different dims, and the result will be broadcasted.
:param out_dim: reuse existing dim for the resulting concatenated dim, if given
:param handle_dynamic_dims:
:return: concatenated tensor, out_dim
"""
assert sources
Expand All @@ -385,6 +387,9 @@ def concat(
assert src.dims_set - {dim} == dims, f"concat {sources}, need allow_broadcast=True"
if not out_dim:
out_dim = sum(d for _, d in sources)
if handle_dynamic_dims is None or handle_dynamic_dims:
for src, dim in sources[:-1]:
assert dim.is_static(), f"concat {sources}, dim {dim} is not static, not yet implemented..."
# noinspection PyProtectedMember
return sources[0][0]._raw_backend.concat(*sources, allow_broadcast=allow_broadcast, out_dim=out_dim), out_dim

Expand Down Expand Up @@ -507,13 +512,18 @@ def cum_concat_step(
:return: (accumulated, out_spatial_dim). accumulated shape {..., out_spatial_dim},
same shape as prev_accum with axis replaced by out_spatial_dim.
"""
# Note: Before, we had a backend function just for this.
# In case of TF-layers, this was using CumConcatLayer.
# This would allow for automatic optimization when inside a RecLayer.
# However, we don't really need this for eager frameworks,
# and we want to simplify this for now,
# using pure RF code.
if not out_spatial_dim:
out_spatial_dim = axis + 1
# noinspection PyProtectedMember
return (
source._raw_backend.cum_concat_step(source, prev_accum=prev_accum, axis=axis, out_spatial_dim=out_spatial_dim),
out_spatial_dim,
out, (out_spatial_dim,) = rf.pad(
prev_accum, axes=[axis], padding=[(0, 1)], out_dims=[out_spatial_dim], value=source, handle_dynamic_dims=True
)
return out, out_spatial_dim
albertz marked this conversation as resolved.
Show resolved Hide resolved


def stack(sources: Sequence[Tensor], *, out_dim: Optional[Dim] = None) -> Tuple[Tensor, Dim]:
Expand Down
1 change: 1 addition & 0 deletions returnn/frontend/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,7 @@ def _make_indices(
indices, out_spatial_dim = rf.concat(
(q_pos_vec - query_spatial_dim_m1.get_dim_value_tensor(), query_spatial_dim_m1),
(kv_pos_vec, key_value_spatial_dim),
handle_dynamic_dims=False,
)
if query_offset is not None:
indices = indices - query_offset
Expand Down
11 changes: 11 additions & 0 deletions returnn/frontend/dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"range_over_dim_strided",
"range_over_merged_dims",
"replace_dim",
"set_sparse_dim",
"dim_match_priority_when_needed",
"num_elements_of_shape",
"masked_fraction_of_shape",
Expand Down Expand Up @@ -94,6 +95,16 @@ def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Optional[Dim] = None) -
return source._raw_backend.replace_dim(source, in_dim=in_dim, out_dim=out_dim), out_dim


def set_sparse_dim(source: Tensor, sparse_dim: Dim) -> Tensor:
"""
:param source:
:param sparse_dim:
:return: source with sparse_dim set
"""
# noinspection PyProtectedMember
return source._raw_backend.set_sparse_dim(source, sparse_dim)


def dim_match_priority_when_needed(dim: Dim, *other_dims: Dim) -> Dim:
"""
:return: maybe copy of dim with higher match_priority if needed to distinguish from other_dims
Expand Down
22 changes: 8 additions & 14 deletions returnn/tf/frontend_layers/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,20 +375,6 @@ def pad(
name="pad",
)

@staticmethod
def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor:
"""cum_concat_step"""
return rfl.make_layer(
{
"class": "cum_concat",
"from": source,
"state": {"state": prev_accum},
"out_spatial_dim": out_spatial_dim,
"axis": axis,
},
name="cum_concat",
)

@staticmethod
def activation(tensor: Tensor, func: str) -> Tensor:
"""activation"""
Expand Down Expand Up @@ -774,6 +760,14 @@ def replace_dim(source: Tensor, *, in_dim: Dim, out_dim: Dim) -> Tensor:
{"class": "reinterpret_data", "set_dim_tags": {in_dim: out_dim}, "from": source}, name="new_dim"
)

@staticmethod
def set_sparse_dim(source: Tensor, sparse_dim: Dim) -> Tensor:
"""set sparse dim"""
return rfl.make_layer(
{"class": "reinterpret_data", "set_sparse": True, "set_sparse_dim": sparse_dim, "from": source},
name="set_sparse_dim",
)

@staticmethod
def reduce(source: Tensor, *, mode: str, axis: Union[Dim, Sequence[Dim]], use_mask: bool = True) -> Tensor:
"""Reduce"""
Expand Down
41 changes: 20 additions & 21 deletions returnn/torch/frontend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,9 +452,6 @@ def pad(
) -> Tensor:
"""pad"""
assert len(out_dims) == len(axes) == len(padding)
out = source.copy_template_new_dim_tags(
[out_dims[axes.index(dim)] if dim in axes else dim for dim in source.dim_tags], keep_special_axes=True
)
remaining_dims = set(axes)
raw_pad = []
for dim in reversed(source.dims):
Expand All @@ -469,10 +466,24 @@ def pad(
]
if not remaining_dims:
break
if isinstance(value, Tensor):
assert value.dims == (), f"value {value} must be a scalar"
value = value.raw_tensor
out.raw_tensor = torch.nn.functional.pad(source.raw_tensor, pad=raw_pad, mode=mode, value=value)
# Use torch.nn.functional.pad if possible.
if (isinstance(value, Tensor) and value.dims == ()) or (not isinstance(value, Tensor)):
if isinstance(value, Tensor):
assert value.dims == ()
value = value.raw_tensor
out = source.copy_template_new_dim_tags(
[out_dims[axes.index(dim)] if dim in axes else dim for dim in source.dim_tags], keep_special_axes=True
)
out.raw_tensor = torch.nn.functional.pad(source.raw_tensor, pad=raw_pad, mode=mode, value=value)
else: # Fallback to concat.
assert isinstance(value, Tensor)
assert all(dim in source.dims and dim not in axes for dim in value.dims)
assert len(axes) == 1 # not implemented otherwise currently...
ext_dim = Dim(1, name="ext")
value_ext = rf.expand_dim(value, ext_dim)
out = TorchBackend.concat(
(source, axes[0]), (value_ext, ext_dim), allow_broadcast=True, out_dim=out_dims[0]
)
if any(dim.need_masking() for dim in out_dims) and handle_dynamic_dims:
if all(right == 0 for right in raw_pad[1::2]) and mode != "circular":
pass # no masking needed
Expand All @@ -490,24 +501,12 @@ def pad(
rf.copy_to_device((left + middle).dyn_size_ext, out.device),
)
out.raw_tensor = torch.where(
mask.copy_compatible_to(out, check_dtype=False, check_sparse=False).raw_tensor,
mask.copy_compatible_to_dims_raw(out.dims),
out.raw_tensor,
value,
value.copy_compatible_to_dims_raw(out.dims) if isinstance(value, Tensor) else value,
)
return out

@staticmethod
def cum_concat_step(source: Tensor, *, prev_accum: Tensor, axis: Dim, out_spatial_dim: Dim) -> Tensor:
"""cum concat step"""
out = prev_accum.copy_template_replace_dim_tag(
axis=prev_accum.get_axis_from_description(axis),
new_dim_tag=out_spatial_dim,
name=f"{source.name}/cum_concat_step",
)
source_raw = source.copy_compatible_to_dims_raw(prev_accum.dims)
out.raw_tensor = torch.cat((prev_accum.raw_tensor, source_raw), dim=prev_accum.get_axis_from_description(axis))
return out

@staticmethod
def stack(sources: Sequence[Tensor], *, out_dim: Dim) -> Tensor:
"""stack"""
Expand Down
40 changes: 40 additions & 0 deletions tests/test_rf_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,46 @@ def _forward_step(*, model: _Net, extern_data: TensorDict):
assert all(out_.raw_tensor[b, seq_len] == 1.0)


def test_pad_time_right_non_scalar():
time_dim = Dim(Tensor("time", [batch_dim], dtype="int32"))
in_dim = Dim(7, name="in")
extern_data = TensorDict(
{
"data": Tensor("data", [batch_dim, time_dim, in_dim], dtype="float32"),
"value": Tensor("value", [batch_dim], dtype="float32"),
}
)

# noinspection PyShadowingNames
def _forward_step(*, extern_data: TensorDict, **_kwargs):
data, value = extern_data["data"], extern_data["value"]
data.mark_as_output("data", shape=(batch_dim, time_dim, in_dim))
value.mark_as_output("value", shape=(batch_dim,))
out, (new_time,) = rf.pad(data, axes=[time_dim], padding=[(0, 1)], value=value)
out.mark_as_default_output(shape=(batch_dim, new_time, in_dim))

# TF-layers currently does not support this.
res = run_model(extern_data, lambda **_kwargs: rf.Module(), _forward_step, test_tensorflow=False)
data_: Tensor = res["data"]
value_: Tensor = res["value"]
out_: Tensor = res["output"]
assert data_.dims == (batch_dim, time_dim, in_dim)
new_time_dim = out_.dims[1]
assert out_.dims == (batch_dim, new_time_dim, in_dim) and new_time_dim != time_dim
assert time_dim.dyn_size_ext.dims == new_time_dim.dyn_size_ext.dims == (batch_dim,)
batch_size = batch_dim.get_dim_value()
assert batch_size > 1
assert len(set(time_dim.dyn_size_ext.raw_tensor)) > 1 # not all the same
for b in range(batch_size):
seq_len = time_dim.dyn_size_ext.raw_tensor[b]
new_seq_len = new_time_dim.dyn_size_ext.raw_tensor[b]
print(f"batch {b}, seq_len {seq_len}, new_seq_len {new_seq_len}")
assert new_seq_len == seq_len + 1
np.testing.assert_allclose(data_.raw_tensor[b, :seq_len], out_.raw_tensor[b, :seq_len])
print(out_.raw_tensor[b])
assert all(out_.raw_tensor[b, seq_len] == value_.raw_tensor[b])


def test_stack():
batch_dim_ = Dim(3, name="batch")
time_dim = Dim(5, name="time")
Expand Down
Loading