Skip to content

Commit

Permalink
Tensor set_dynamic_size fix, make sure it is always set (#1429)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz authored Oct 13, 2023
1 parent e208b9a commit f858440
Showing 1 changed file with 12 additions and 8 deletions.
20 changes: 12 additions & 8 deletions returnn/tensor/_tensor_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -2717,24 +2717,28 @@ def set_dynamic_size(self, axis, sizes):
sizes_tag = Dim.get_tag_from_size_tensor(sizes)
if sizes_tag:
assert sizes_tag.is_same_size_tensor(sizes)
tag = self.dim_tags[axis]
tag = self._dims[axis]
assert tag.is_dynamic()
if tag.is_same_size_tensor(sizes):
return # nothing to do
if tag.dyn_size is None:
pass # nothing to do
elif tag.dyn_size is None:
if sizes_tag: # special rule for older code: overtake previous existing
assert sizes_tag.is_same_size_tensor(sizes)
self._dims = self._dims[:axis] + (sizes_tag,) + self._dims[axis + 1 :]
tag = sizes_tag
else:
# Assign now. This should also set the dim tag on sizes.
new_tag = tag.set_tag_on_size_tensor(sizes, batch=self.batch)
if new_tag is not tag:
self._dims = self._dims[:axis] + (new_tag,) + self._dims[axis + 1 :]
tag = tag.set_tag_on_size_tensor(sizes, batch=self.batch)
else:
# Reset to some new size.
# Use new dim tag, or previous existing attached to size.
assert sizes_tag, "%s: assign dyn sizes %s without defined dim tag" % (self, sizes)
self._dims = self._dims[:axis] + (sizes_tag,) + self._dims[axis + 1 :]
tag = sizes_tag
if self.batch:
tag = tag.get_for_batch_ctx(batch=self.batch, ctx=self.control_flow_ctx)
if tag is not self._dims[axis]:
self._dims = self._dims[:axis] + (tag,) + self._dims[axis + 1 :]
if tag.dyn_size is None:
tag.dyn_size = sizes

def get_dynamic_axes(self):
"""
Expand Down

0 comments on commit f858440

Please sign in to comment.