From 2d29957fa31edbb9a67f3d7e804aecdd9d088675 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 10 Oct 2023 10:15:13 +0200 Subject: [PATCH] Dim dim value precomputed more fixes --- returnn/tensor/_dim_extra.py | 24 +++++++++---------- .../tf/frontend_layers/config_entry_points.py | 1 + returnn/tf/frontend_layers/make_layer.py | 4 ++-- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py index bd838a1556..341da27c84 100644 --- a/returnn/tensor/_dim_extra.py +++ b/returnn/tensor/_dim_extra.py @@ -1830,18 +1830,6 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]: if self._dyn_size_max_value is not None: # fast path, precomputed assert self._dyn_size_max_value.raw_tensor is not None return self._dyn_size_max_value - if self.is_batch_dim(): - res = None - if self._extra and self._extra.src_data: - res = self._extra.src_data.get_batch_dim() - elif self.batch: - res = self.batch.dim - if isinstance(res, int): - return res - if res is not None: - res = _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res) - self._dyn_size_max_value = res - return res if ( self._extra and self._extra.src_data is not None @@ -1876,6 +1864,18 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]: assert res.raw_tensor is not None self._dyn_size_max_value = res return res + if self.is_batch_dim(): + res = None + if self._extra and self._extra.src_data: + res = self._extra.src_data.get_batch_dim() + elif self.batch: + res = self.batch.dim + if isinstance(res, int): + return res + if res is not None: + res = _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res) + self._dyn_size_max_value = res + return res raise Exception("%s: need placeholder, self.dimension or self.dyn_size for dim value" % self) def axis_split_info(self): diff --git a/returnn/tf/frontend_layers/config_entry_points.py b/returnn/tf/frontend_layers/config_entry_points.py index 97b8111f3b..4b6590780d 100644 --- a/returnn/tf/frontend_layers/config_entry_points.py +++ b/returnn/tf/frontend_layers/config_entry_points.py @@ -111,6 +111,7 @@ def _cleanup_net_dict_value(elem): # Reset it now. The TF engine should redefine it again. if elem.dyn_size_ext: elem.dyn_size_ext.raw_tensor = None + elem._dyn_size_max_value = None return elem # Do some cleanup. diff --git a/returnn/tf/frontend_layers/make_layer.py b/returnn/tf/frontend_layers/make_layer.py index ae17a0c109..3fa84d61cf 100644 --- a/returnn/tf/frontend_layers/make_layer.py +++ b/returnn/tf/frontend_layers/make_layer.py @@ -309,7 +309,7 @@ def register_extern_data(data: Tensor[rfl.Layer]): # Undefined dynamic dim tag. Set default data template. orig_tag.dyn_size_ext = tag.dyn_size_ext = Tensor( name=f"{tag.name or (data.name + f'[{i}]')}_default_dyn_size_ext", - dim_tags=[batch_dim], + dims=[batch_dim], dtype=data.size_dtype, batch=data.batch, ) @@ -317,7 +317,7 @@ def register_extern_data(data: Tensor[rfl.Layer]): # Undefined batch dim tag. Set default data template. batch_dim.dyn_size_ext = orig_tag.dyn_size_ext = tag.dyn_size_ext = Tensor( name=f"batch_dim_default_dyn_size_ext", - dim_tags=[], + dims=[], dtype=data.size_dtype, batch=data.batch, )