From f8584409aaef75e3d0ad53a42d643c06405477b5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 13 Oct 2023 14:34:56 +0200 Subject: [PATCH] Tensor set_dynamic_size fix, make sure it is always set (#1429) --- returnn/tensor/_tensor_extra.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/returnn/tensor/_tensor_extra.py b/returnn/tensor/_tensor_extra.py index f42cbb32c2..4725ad43ec 100644 --- a/returnn/tensor/_tensor_extra.py +++ b/returnn/tensor/_tensor_extra.py @@ -2717,24 +2717,28 @@ def set_dynamic_size(self, axis, sizes): sizes_tag = Dim.get_tag_from_size_tensor(sizes) if sizes_tag: assert sizes_tag.is_same_size_tensor(sizes) - tag = self.dim_tags[axis] + tag = self._dims[axis] assert tag.is_dynamic() if tag.is_same_size_tensor(sizes): - return # nothing to do - if tag.dyn_size is None: + pass # nothing to do + elif tag.dyn_size is None: if sizes_tag: # special rule for older code: overtake previous existing assert sizes_tag.is_same_size_tensor(sizes) - self._dims = self._dims[:axis] + (sizes_tag,) + self._dims[axis + 1 :] + tag = sizes_tag else: # Assign now. This should also set the dim tag on sizes. - new_tag = tag.set_tag_on_size_tensor(sizes, batch=self.batch) - if new_tag is not tag: - self._dims = self._dims[:axis] + (new_tag,) + self._dims[axis + 1 :] + tag = tag.set_tag_on_size_tensor(sizes, batch=self.batch) else: # Reset to some new size. # Use new dim tag, or previous existing attached to size. assert sizes_tag, "%s: assign dyn sizes %s without defined dim tag" % (self, sizes) - self._dims = self._dims[:axis] + (sizes_tag,) + self._dims[axis + 1 :] + tag = sizes_tag + if self.batch: + tag = tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx) + if tag is not self._dims[axis]: + self._dims = self._dims[:axis] + (tag,) + self._dims[axis + 1 :] + if tag.dyn_size is None: + tag.dyn_size = sizes def get_dynamic_axes(self): """