From 6ae497b6e553476d1c22d6b5b3c24dc3216c01ed Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 6 Oct 2023 17:12:01 +0200 Subject: [PATCH] RF backend make_output_tensor, TF-layers disallow transpose Fix #1410 --- returnn/frontend/_backend.py | 13 +++++++++++++ returnn/frontend/run_ctx.py | 4 ++-- returnn/tf/frontend_layers/_backend.py | 8 ++++++-- 3 files changed, 21 insertions(+), 4 deletions(-) diff --git a/returnn/frontend/_backend.py b/returnn/frontend/_backend.py index 1c6df783cf..46949b9a26 100644 --- a/returnn/frontend/_backend.py +++ b/returnn/frontend/_backend.py @@ -269,6 +269,19 @@ def transpose(tensor: Tensor, perm: Sequence[Union[Dim, int]], *, allow_int: boo out.raw_tensor = backend.transpose_raw(tensor.raw_tensor, perm_) return out + @staticmethod + def make_output_tensor(tensor: Tensor, dims: Sequence[Dim], *, name: str) -> Tensor: + """ + :param tensor: + :param dims: + :param name: + :return: tensor with dims order like in dims + """ + # noinspection PyProtectedMember + tensor = tensor._raw_backend.transpose(tensor, dims, allow_int=False) + tensor = tensor.copy(name=name) + return tensor + @staticmethod def expand_dims_raw(raw_tensor: T, axis: int) -> T: """ diff --git a/returnn/frontend/run_ctx.py b/returnn/frontend/run_ctx.py index 0933d25ad0..164a5c0702 100644 --- a/returnn/frontend/run_ctx.py +++ b/returnn/frontend/run_ctx.py @@ -237,8 +237,8 @@ def mark_as_output(self, tensor: Union[Tensor, Any], name: str, *, dims: Optiona # We try some reasonable defaults, specifically: BTF or BF. dims = _default_dim_order(tensor) assert set(dims) == set(tensor.dims), f"mark_as_output: tensor {tensor} does not have the dims {dims}" - tensor = tensor.copy_transpose(dims, allow_int=False) - tensor = tensor.copy(name=name) + # noinspection PyProtectedMember + tensor = tensor._raw_backend.make_output_tensor(tensor, dims, name=name) assert name not in self.outputs.data self.outputs.data[name] = tensor diff --git a/returnn/tf/frontend_layers/_backend.py b/returnn/tf/frontend_layers/_backend.py index a90639a45c..170f3506c3 100644 --- a/returnn/tf/frontend_layers/_backend.py +++ b/returnn/tf/frontend_layers/_backend.py @@ -180,8 +180,12 @@ def reshape_raw(raw_tensor: Layer, shape: Union[Sequence[Union[int, Layer]], Lay @staticmethod def transpose(tensor: Tensor, perm: Sequence[Union[Dim, int]], *, allow_int: bool = False) -> Tensor: """transpose""" - assert all(isinstance(d, Dim) for d in perm) # axis as int not supported - return rfl.make_layer({"class": "transpose", "from": tensor, "perm": perm}, name="transpose") + raise Exception("TF-layers backend: order of dims is irrelevant") + + @staticmethod + def make_output_tensor(tensor: Tensor, dims: Sequence[Dim], *, name: str) -> Tensor: + """only func where we have explicitly defined dim order in the output""" + return rfl.make_layer({"class": "transpose", "from": tensor, "perm": dims}, name=name) @staticmethod def copy(tensor: Tensor) -> Tensor: