Skip to content

Commit

Permalink
Dim dim value precomputed more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Oct 10, 2023
1 parent 4b8fd5d commit 2d29957
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 14 deletions.
24 changes: 12 additions & 12 deletions returnn/tensor/_dim_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,18 +1830,6 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
if self._dyn_size_max_value is not None: # fast path, precomputed
assert self._dyn_size_max_value.raw_tensor is not None
return self._dyn_size_max_value
if self.is_batch_dim():
res = None
if self._extra and self._extra.src_data:
res = self._extra.src_data.get_batch_dim()
elif self.batch:
res = self.batch.dim
if isinstance(res, int):
return res
if res is not None:
res = _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
self._dyn_size_max_value = res
return res
if (
self._extra
and self._extra.src_data is not None
Expand Down Expand Up @@ -1876,6 +1864,18 @@ def get_dim_value_tensor(self) -> Union[int, _t.Tensor]:
assert res.raw_tensor is not None
self._dyn_size_max_value = res
return res
if self.is_batch_dim():
res = None
if self._extra and self._extra.src_data:
res = self._extra.src_data.get_batch_dim()
elif self.batch:
res = self.batch.dim
if isinstance(res, int):
return res
if res is not None:
res = _t.Tensor("batch", dims=(), dtype=rf.get_default_array_index_dtype(), raw_tensor=res)
self._dyn_size_max_value = res
return res
raise Exception("%s: need placeholder, self.dimension or self.dyn_size for dim value" % self)

def axis_split_info(self):
Expand Down
1 change: 1 addition & 0 deletions returnn/tf/frontend_layers/config_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _cleanup_net_dict_value(elem):
# Reset it now. The TF engine should redefine it again.
if elem.dyn_size_ext:
elem.dyn_size_ext.raw_tensor = None
elem._dyn_size_max_value = None
return elem

# Do some cleanup.
Expand Down
4 changes: 2 additions & 2 deletions returnn/tf/frontend_layers/make_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,15 +309,15 @@ def register_extern_data(data: Tensor[rfl.Layer]):
# Undefined dynamic dim tag. Set default data template.
orig_tag.dyn_size_ext = tag.dyn_size_ext = Tensor(
name=f"{tag.name or (data.name + f'[{i}]')}_default_dyn_size_ext",
dim_tags=[batch_dim],
dims=[batch_dim],
dtype=data.size_dtype,
batch=data.batch,
)
if tag.is_batch_dim() and not tag.dyn_size_ext and tag.dimension is None:
# Undefined batch dim tag. Set default data template.
batch_dim.dyn_size_ext = orig_tag.dyn_size_ext = tag.dyn_size_ext = Tensor(
name=f"batch_dim_default_dyn_size_ext",
dim_tags=[],
dims=[],
dtype=data.size_dtype,
batch=data.batch,
)
Expand Down

0 comments on commit 2d29957

Please sign in to comment.