Dim is_static is_dynamic more reasonable (#1438)
albertz authored Oct 17, 2023
1 parent 21a8923 commit 2dfe3f6
Showing 10 changed files with 46 additions and 36 deletions.
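
The gist of the change: Dim.is_dynamic() no longer excludes the batch dim (it now simply checks whether dimension is None), Dim.is_static() becomes its exact complement, the old "dynamic in the sense of per-sequence lengths" uses move to the new Dim.is_dynamic_seq_length(), and need_masking() now guesses False for a batch dim whose sizes are unknown. A minimal sketch of the resulting semantics, assuming the returnn.tensor API; the dim names and sizes below are illustrative and not part of the commit:

    from returnn.tensor import Dim, batch_dim

    feature_dim = Dim(128, name="feature")  # static: the size is a fixed number
    time_dim = Dim(None, name="time")       # dynamic: per-sequence lengths

    assert feature_dim.is_static() and not feature_dim.is_dynamic()

    # The batch dim now counts as dynamic (its size is not a fixed number)...
    assert batch_dim.is_dynamic() and not batch_dim.is_static()
    # ...but it is not a seq-length dim and does not need masking.
    assert not batch_dim.is_dynamic_seq_length() and not batch_dim.need_masking()

    # A dim with per-sequence lengths is dynamic, is a seq-length dim, and needs masking.
    assert time_dim.is_dynamic() and time_dim.is_dynamic_seq_length() and time_dim.need_masking()
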
2 changes: 1 addition & 1 deletion returnn/frontend/array_.py
@@ -381,7 +381,7 @@ def pack_padded(
"""
assert not enforce_sorted # not implemented yet...
assert len(dims) > 0
- dyn_dims = [d for d in dims if d.is_dynamic()]
+ dyn_dims = [d for d in dims if d.need_masking()]
assert len(dyn_dims) == 1 # not implemented otherwise yet...
mask = source.get_sequence_mask_tensor(source.get_axis_from_description(dyn_dims[0]))
return rf.masked_select(source, mask=mask, dims=dims, out_dim=out_dim)
2 changes: 1 addition & 1 deletion returnn/frontend/run_ctx.py
@@ -405,7 +405,7 @@ def _default_dim_order(tensor: Tensor) -> Sequence[Dim]:
if tensor.have_time_axis():
rem_dims.remove(tensor.get_time_dim_tag())
dims.append(tensor.get_time_dim_tag())
- dyn_dims = [d for d in rem_dims if d.is_dynamic()]
+ dyn_dims = [d for d in rem_dims if d.is_dynamic_seq_length()]
if len(dyn_dims) > 1:
raise Exception(
f"Cannot infer order of dims automatically for output {tensor}. Please specify `dims` explicitly."
32 changes: 21 additions & 11 deletions returnn/tensor/_dim_extra.py
@@ -521,7 +521,7 @@ def get_for_batch_ctx(
if self.batch == batch:
return self
return batch.batch_dim_tag
- if not self.is_dynamic():
+ if self.is_static():
# If static dim, no effect.
assert not self.batch
return self
@@ -896,7 +896,7 @@ def is_dim_known(self):
"""
if self.is_batch_dim():
return True
- if not self.is_dynamic() and self.dimension is not None:
+ if not self.dyn_size_ext and self.dimension is not None:
return True
if self.dyn_size_ext:
return True
@@ -915,8 +915,8 @@ def is_dim_known_in_batch_ctx(self: Dim, batch: BatchInfo, ctx: Optional[Control

if self.is_batch_dim():
return True
- if not self.is_dynamic():
-     return self.dimension is not None
+ if self.is_static():
+     return True
dim = self.get_for_batch_ctx(batch=batch, ctx=ctx, allow_none=True)
if dim:
return bool(dim.dyn_size_ext)
@@ -930,17 +930,25 @@ def is_dim_known_in_batch_ctx(self: Dim, batch: BatchInfo, ctx: Optional[Control
return True
return False

+ def is_dynamic_seq_length(self) -> bool:
+     """
+     :return: whether the dim is not static. usually means that it has seq lengths
+     """
+     return self.dimension is None and (
+         (self.dyn_size_ext and self.dyn_size_ext.dims) or (not self.dyn_size_ext and not self.is_batch_dim())
+     )
+
def is_dynamic(self) -> bool:
"""
:return: whether the dim is not static. usually means that it has seq lengths
"""
- return self.dimension is None and not self.is_batch_dim()
+ return self.dimension is None

def is_static(self) -> bool:
"""
:return: static
"""
- return not self.is_dynamic()
+ return self.dimension is not None

def need_masking(self):
"""
@@ -952,8 +960,10 @@ def need_masking(self):
return False
if self.capacity is not None:
return True
- if not self.dyn_size_ext:
-     return True # unknown
+ if not self.dyn_size_ext: # unknown, so we can only guess
+     if self.is_batch_dim():
+         return False
+     return True
return self.dyn_size_ext.batch_ndim > 0

def can_be_used_as_dim(self):
@@ -1078,7 +1088,7 @@ def complete_dyn_size(self, *, template_only=False, _backend=None):
:param bool template_only:
:param _backend:
"""
- if not self.is_dynamic():
+ if self.is_static():
return
self._validate_in_current_graph()
if self.dyn_size_ext and (self.dyn_size_ext.placeholder is not None or template_only):
@@ -1719,7 +1729,7 @@ def declare_same_as(self: _d.Dim, other: _d.Dim):
other_same_base._make_extra().derived_from_op = self.derived_from_op
elif other_same_base.derived_from_op and not self.derived_from_op:
self._make_extra().derived_from_op = other_same_base.derived_from_op
- if self._extra and not other_same_base.is_dynamic():
+ if self._extra and other_same_base.is_static():
# Those might be set via get_batch_for_ctx for an undefined dim,
# which now becomes static due to `other`.
self._extra.batch = None
@@ -2815,7 +2825,7 @@ def _representative_tag(terms: Sequence[Dim]) -> Optional[Dim]:
# Also see _OpLinearTerm.representative_tag().
# First find any dynamic.
for term_ in terms:
- if term_.is_dynamic():
+ if term_.is_dynamic_seq_length():
return term_
# Now find non-unspecified.
for term_ in terms:
2 changes: 1 addition & 1 deletion returnn/tensor/_tensor_extra.py
@@ -3131,7 +3131,7 @@ def get_dyn_size_tags(self):
:return: all dim tags with dynamic size
:rtype: list[Dim]
"""
- return [dim_tag for dim_tag in self._dims if dim_tag.is_dynamic()]
+ return [dim_tag for dim_tag in self._dims if dim_tag.is_dynamic_seq_length()]

def get_size_dim_tag(self, number):
"""
2 changes: 1 addition & 1 deletion returnn/tf/frontend_layers/layer.py
@@ -1094,7 +1094,7 @@ def make_net_dict_raw(self, net: Net, *, _stack: Optional[_StackInfo] = None) ->
)
dim_tags = list(data_template.dim_tags)
for dim in dim_tags:
- if dim.is_batch_dim() or not dim.is_dynamic():
+ if dim.is_batch_dim() or dim.is_static():
continue
# We need dyn_size_ext to know the implicit dims, to correctly set out_shape.
# If dyn_size_ext is not set yet, try to complete it.
2 changes: 1 addition & 1 deletion returnn/tf/frontend_low_level/_backend.py
@@ -500,7 +500,7 @@ def reduce(source: _TT, *, mode: str, axis: Union[Dim, Sequence[Dim]], use_mask:
i, d = [
(i, d) for i, d in enumerate(size_actual.dim_tags) if d not in out_data.dim_tags
][0]
- assert not d.is_dynamic() # not implemented
+ assert not d.need_masking() # not implemented
size_all *= d.get_dim_value()
s = tf.reduce_sum(size_actual.placeholder, axis=i)
size_actual = size_actual.copy_template_excluding_axis(i)
24 changes: 12 additions & 12 deletions returnn/tf/layers/basic.py
@@ -2388,7 +2388,7 @@ def get_out_data_from_opts(
sparse=sparse,
dim=None if sparse else NotSpecified,
)
- if not dim.is_dynamic(): # static
+ if dim.is_static():
return Data(
name="%s_static_dim" % name,
dim_tags=(),
@@ -2455,15 +2455,15 @@ def __init__(
energy_data = self.input_data
assert energy_data.dtype.startswith("float")
axis = self._get_axis_to_reduce(input_data=energy_data, axis=axis, exception_prefix=self)
- if not energy_data.dim_tags[axis].is_dynamic():
+ if energy_data.dim_tags[axis].is_static():
self.recurrent = False
# tf.nn.softmax operates on the last axis.
energy_data = energy_data.copy_move_axis(axis, -1)
energy = energy_data.placeholder
axis = energy_data.batch_ndim - 1
# if the time-axis is static, we can skip the masking
if use_time_mask is None:
- use_time_mask = energy_data.is_axis_dynamic(axis)
+ use_time_mask = energy_data.dims[axis].need_masking()
if start or window_start is not None or window_size is not None:
assert use_time_mask
if use_time_mask:
@@ -4725,9 +4725,9 @@ def __init__(self, axis, dims, pad_to_multiples=None, pad_value=0, **kwargs):
if isinstance(axis, int):
data = data.copy_as_batch_major()
axis = data.get_axis_from_description(axis)
- old_dim = data.dim_tags[axis]
+ old_dim: Dim = data.dim_tags[axis]
if pad_to_multiples is None:
- pad_to_multiples = data.is_axis_dynamic(axis)
+ pad_to_multiples = old_dim.is_dynamic()

from returnn.tf.util.basic import get_shape

@@ -4761,7 +4761,7 @@ def __init__(self, axis, dims, pad_to_multiples=None, pad_value=0, **kwargs):
rem_const_size = None
if len(new_pos_dims) == len(dims) - 1:
rem_const_size = util.prod(new_pos_dims)
- assert not data.is_axis_dynamic(axis) or pad_to_multiples or rem_const_size == 1
+ assert old_dim.is_static() or pad_to_multiples or rem_const_size == 1
if pad_to_multiples and (not isinstance(rem_const_size, int) or rem_const_size != 1):
indices = [i for i, d in enumerate(dims) if isinstance(d, int) and d == -1]
assert len(indices) == 1, "%s: exactly one -1 dim in %r expected" % (self, dims)
@@ -4866,10 +4866,10 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
if isinstance(axis, int):
data = data.copy_as_batch_major()
axis = data.get_axis_from_description(axis)
+ axis_dim_tag: Dim = data.dim_tags[axis]
if pad_to_multiples is None:
- pad_to_multiples = data.is_axis_dynamic(axis)
+ pad_to_multiples = axis_dim_tag.is_dynamic()

- axis_dim_tag = data.dim_tags[axis]
rem_dim_indices = [
i
for i, d in enumerate(dims)
@@ -5145,7 +5145,7 @@ def __init__(self, axis="T", batch_major=True, **kwargs):
x = self.input_data
axis = x.get_axis_from_description(axis, allow_int=False)
assert axis != x.batch_dim_axis
- if x.is_axis_dynamic(axis):
+ if x.dims[axis].need_masking():
if batch_major:
self.output.placeholder = tf_util.flatten_with_seq_len_mask(
x.placeholder,
@@ -7611,7 +7611,7 @@ def reduce(cls, input_data, mode, axes=None, keep_dims=False, enforce_batch_dim_
# We need to remove this.
# https://github.com/rwth-i6/returnn/issues/1242
i, d = [(i, d) for i, d in enumerate(size_actual.dim_tags) if d not in out_data.dim_tags][0]
- assert not d.is_dynamic() # not implemented
+ assert not d.need_masking() # not implemented
size_all *= d.get_dim_value()
s = tf.reduce_sum(size_actual.placeholder, axis=i)
size_actual = size_actual.copy_template_excluding_axis(i)
@@ -10429,7 +10429,7 @@ def __init__(self, sorted_sequence, values, axis="T", side="left", **kwargs):
transposed_sorted_data = sorted_data.copy_transpose(perm=sorted_batch_axes + [sorted_axis]) # [B,T]
transposed_values_data = values_data.copy_transpose(perm=values_batch_axes + values_non_batch_axes) # [B,F]
x = transposed_sorted_data.placeholder # [B,T]
- if transposed_sorted_data.is_axis_dynamic(axis=-1):
+ if transposed_sorted_data.dims[-1].need_masking():
from returnn.tf.util.basic import where_bc, sequence_mask

seq_mask = transposed_sorted_data.get_sequence_mask_broadcast(axis=-1)
@@ -13254,7 +13254,7 @@ def get_value(self):
error_signal = (
self.output.placeholder - self.align_layer.output.copy_compatible_to(self.output).placeholder
)
- if self.output.is_time_axis_dynamic():
+ if self.output.get_time_dim_tag().need_masking():
seq_mask_bc = self.output.get_sequence_mask_broadcast()
error_signal = where_bc(seq_mask_bc, error_signal, 0.0)
if self.loss_wrt_to_act_in:
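
Several hunks above (the softmax-over-a-spatial-axis __init__ and the reduce implementations) now consult need_masking() to decide whether padded frames must be excluded. A minimal NumPy sketch of the masking pattern that this guards; this is not RETURNN code, and the helper name, shapes, and values are made up:

    import numpy as np

    def masked_softmax(energy: np.ndarray, seq_lens: np.ndarray) -> np.ndarray:
        """energy: [B, T] scores; seq_lens: [B] valid length per sequence."""
        time = energy.shape[1]
        mask = np.arange(time)[None, :] < seq_lens[:, None]  # [B, T], True inside the sequence
        energy = np.where(mask, energy, -np.inf)              # padded frames get zero probability
        energy = energy - energy.max(axis=-1, keepdims=True)  # numerical stability
        exp = np.exp(energy)
        return exp / exp.sum(axis=-1, keepdims=True)

    probs = masked_softmax(np.random.randn(2, 5), np.array([5, 3]))
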
8 changes: 5 additions & 3 deletions returnn/tf/layers/rec.py
@@ -1878,15 +1878,17 @@ def __call__(lself, name, is_prev_time_frame=False):
# Now we do the same logic more explicitly, to directly take over dim tags,
# and not work with size_placeholder.
earlier_dyn_dims = [
- (i, d) for i, d in enumerate(earlier_layer_output.dims) if d.is_dynamic()
+ (i, d) for i, d in enumerate(earlier_layer_output.dims) if d.is_dynamic_seq_length()
]
- new_dyn_dims = [(i, d) for i, d in enumerate(layer_.output.dims) if d.is_dynamic()]
+ new_dyn_dims = [
+     (i, d) for i, d in enumerate(layer_.output.dims) if d.is_dynamic_seq_length()
+ ]
if len(earlier_dyn_dims) == len(new_dyn_dims):
out_dims = list(layer_.output.dims)
for (new_axis, new_dim), (old_axis, old_dim) in zip(new_dyn_dims, earlier_dyn_dims):
new_dim: Dim
old_dim: Dim
- assert old_dim.is_dynamic() and new_dim.is_dynamic()
+ assert old_dim.is_dynamic_seq_length() and new_dim.is_dynamic_seq_length()
if new_dim.dyn_size_ext and new_dim.dyn_size_ext.raw_tensor is not None:
continue
if not old_dim.dyn_size_ext:
2 changes: 1 addition & 1 deletion returnn/tf/util/basic.py
@@ -326,7 +326,7 @@ def copy_compatible_reduce(source, target, reduce_type):
return source.copy_compatible_to(target, check_sparse=False, check_dtype=False)
# extra_dims now contains dims only in source but not in target
for d in extra_dims:
- assert not d.is_dynamic(), "%r, %r, cannot reduce dynamic dim %r (just not implemented here...)" % (
+ assert not d.need_masking(), "%r, %r, cannot reduce dynamic dim %r (just not implemented here...)" % (
source,
target,
d,
6 changes: 2 additions & 4 deletions tools/torch_export_to_onnx.py
@@ -241,15 +241,13 @@ def main():

dynamic_axes = {}
for k, v in list(extern_data.data.items()) + list(model_outputs.data.items()):
- dynamic_axes[k] = {i: dim.name for i, dim in enumerate(v.dims) if dim.is_dynamic() or dim.is_batch_dim()}
+ dynamic_axes[k] = {i: dim.name for i, dim in enumerate(v.dims) if dim.is_dynamic()}
for i, dim in enumerate(v.dims):
if dim.dyn_size_ext and dim.dyn_size_ext.dims == ():
continue
if dim.dyn_size_ext:
dynamic_axes[f"{k}:size{i}"] = {
- j: dim_.name
- for j, dim_ in enumerate(dim.dyn_size_ext.dims)
- if dim_.is_dynamic() or dim_.is_batch_dim()
+ j: dim_.name for j, dim_ in enumerate(dim.dyn_size_ext.dims) if dim_.is_dynamic()
}

print("*** Input names:", list(extern_data_raw.keys()))
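
With is_dynamic() now also covering the batch dim, the "or dim.is_batch_dim()" parts of the dynamic_axes comprehensions above become redundant. For reference, a small self-contained sketch of how such a mapping is consumed by torch.onnx.export; the model, tensor names, and shapes are made up for illustration:

    import torch

    class TinyModel(torch.nn.Module):
        def forward(self, data: torch.Tensor) -> torch.Tensor:
            return data * 2

    # Axis indices listed in dynamic_axes are exported as symbolic (dynamic) dimensions.
    dynamic_axes = {
        "data": {0: "batch", 1: "time"},
        "output": {0: "batch", 1: "time"},
    }
    torch.onnx.export(
        TinyModel(),
        (torch.randn(3, 7, 128),),  # example input: [batch=3, time=7, feature=128]
        "tiny_model.onnx",
        input_names=["data"],
        output_names=["output"],
        dynamic_axes=dynamic_axes,
    )
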
