From 3fa6739c0540595f66313c934d637f36ac2d933e Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Tue, 15 Feb 2022 20:34:43 +0100 Subject: [PATCH 1/3] Dim auto_generated flag Allows for better Dim is_equal which does not rely on the description. #634 --- returnn/tf/layers/base.py | 2 +- returnn/tf/layers/basic.py | 86 ++++++++++++++------------ returnn/tf/layers/rec.py | 27 ++++---- returnn/tf/layers/signal_processing.py | 4 +- returnn/tf/util/data.py | 37 +++++++---- 5 files changed, 88 insertions(+), 68 deletions(-) diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py index 02910d5564..08394c2a4a 100644 --- a/returnn/tf/layers/base.py +++ b/returnn/tf/layers/base.py @@ -399,7 +399,7 @@ def _base_get_out_data_from_opts(cls, network, name, feature_dim_tag = out_dim else: dim = out_type.get("dim", None) - feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim) + feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim, auto_generated=True) if feature_dim_axis in (NotSpecified, None): if sources_data.feature_dim_axis is None: feature_dim_axis = len(dim_tags) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index edf78acc1e..0401c206ae 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1039,7 +1039,7 @@ def get_out_data_from_opts( if out_dim: assert out_dim.dimension == new_dim else: - out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim) + out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True) return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name) @@ -1219,7 +1219,7 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T out_spatial_dim = Dim( kind=Dim.Types.Spatial, description="sliced-time:%s" % name, - dimension=size) + dimension=size, auto_generated=True) gather_positions_data = gather_positions_data.copy_add_dim_by_tag( out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim) position = InternalLayer( @@ -2527,7 +2527,9 @@ def get_out_data_from_opts(cls, name, shape, dtype="float32", **kwargs): :param str dtype: :rtype: Data """ - dim_tags = [d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d) for i, d in enumerate(shape)] + dim_tags = [ + d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d, auto_generated=True) + for i, d in enumerate(shape)] return Data(name="%s_output" % name, dim_tags=dim_tags, dtype=dtype) @@ -2613,7 +2615,7 @@ def get_out_data_from_opts(cls, name, network, shape, maxval, minval=0, dtype="i elif isinstance(d, int): d = Dim( kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature, - description="%s:static:%i" % (name, i), + description="%s:static:%i" % (name, i), auto_generated=True, dimension=d) else: raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape)) @@ -2679,12 +2681,12 @@ def get_out_data_from_opts(cls, name, limit, start=0, delta=1, dtype=None, spars else: dtype = "int32" dim = len(range(start, limit, delta)) - tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name) + tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name, auto_generated=True) if out_spatial_dim: tag.declare_same_as(out_spatial_dim) sparse_dim = None if sparse: - sparse_dim = SpatialDim("%s:range-indices" % name) + sparse_dim = SpatialDim("%s:range-indices" % name, auto_generated=True) return Data(name="%s_output" % name, dim_tags=[tag], dtype=dtype, sparse_dim=sparse_dim) @@ -2799,7 +2801,7 @@ def get_out_data_from_opts(cls, name, sources, dtype="int32", sparse=False, out_ dim_tag = Dim.get_tag_from_size_tensor(source.placeholder) if not dim_tag: dim_tag = Dim( - kind=Dim.Types.Spatial, description="%s_input_len" % name, + kind=Dim.Types.Spatial, description="%s_input_len" % name, auto_generated=True, batch=source.batch, control_flow_ctx=source.control_flow_ctx, dyn_size_ext=source) if source.placeholder is not None: @@ -2940,7 +2942,7 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None, if out_dim: assert out_dim.dimension == dim else: - out_dim = FeatureDim("%s:gating" % name, dimension=dim) + out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True) if n_out is not NotSpecified: assert n_out == dim return Data( @@ -3079,7 +3081,7 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window filter_size=window_size, stride=stride, dilation_rate=1, padding=padding) out_spatial_dim = Dim( kind=Dim.Types.Spatial, description="%s:spatial" % name, - dimension=dim, derived_from_tag=in_spatial_dim, + dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True, batch=data.batch, control_flow_ctx=data.control_flow_ctx) data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim) new_dim_axis = axis + 1 # add new axis right after @@ -3087,7 +3089,7 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window assert window_dim.dimension == window_size else: window_dim = Dim( - kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size) + kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True) return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True) # noinspection PyMethodOverriding @@ -3585,9 +3587,11 @@ def _get_axis_size_splits_num_splits(cls, name, input_data, axis=None, err_prefix, out_dims, dim, input_data) if not out_dims: assert size_splits - out_dims = [Dim( - kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx), - dimension=size_splits[idx]) for idx in range(len(size_splits))] + out_dims = [ + Dim( + kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx), + dimension=size_splits[idx], auto_generated=True) + for idx in range(len(size_splits))] return axis, out_dims def _make_split_layer(self, idx): @@ -3867,6 +3871,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources kind=axis_dim_tag.kind, description="%s_split_dims%i_rem" % (name, rem_dim_idx), dimension=resolved_shape_dims[rem_dim_idx], + auto_generated=True, derived_from_tag=axis_dim_tag, batch=axis_dim_tag.batch, control_flow_ctx=axis_dim_tag.control_flow_ctx) if rem_dim.dimension is None and axis_dim_tag.dyn_size_ext is not None: @@ -3883,7 +3888,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources Dim( kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial, description="%s_split_dims%i" % (name, i), - dimension=shape_dim) + dimension=shape_dim, auto_generated=True) if rem_dim is None or i != rem_dim_idx else rem_dim for i, shape_dim in enumerate(resolved_shape_dims)) out_batch = data.batch @@ -4158,7 +4163,7 @@ def get_out_data_from_opts(cls, name, sources, num_axes, in_dim="T", out_dims=No assert not declare_same_sizes_as else: out_dims = [ - SpatialDim("%s:unflatten-nd:%i" % (name, i)) + SpatialDim("%s:unflatten-nd:%i" % (name, i), auto_generated=True) for i in range(num_axes)] if declare_same_sizes_as: for i, other in declare_same_sizes_as.items(): @@ -4238,7 +4243,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs): else: new_dim = Dim( kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial, - description="%s_expand_dims" % name, + description="%s_expand_dims" % name, auto_generated=True, dimension=dim) data = data.copy_template(name="%s_output" % name) data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis) @@ -4394,7 +4399,7 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None, if isinstance(repetitions, int): out_dim = tag * repetitions else: - out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag) + out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True) return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim) @@ -4804,12 +4809,12 @@ def map_axis_name(s): pass else: out.sparse_dim = Dim( - kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name) + kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name, auto_generated=True) if increase_sparse_dim: assert out.sparse out.sparse_dim = Dim( kind=Dim.Types.Feature, dimension=out.sparse_dim.dimension + 1, - description="%s:inc-sparse-dim" % name) + description="%s:inc-sparse-dim" % name, auto_generated=True) if batch_dim_base: out.batch = batch_dim_base.output.batch return out @@ -5086,7 +5091,7 @@ def transform_input(cls, input_data, network, in_dim=None, in_spatial_dims=None, cls._check_defined_in_spatial_dims(len(in_spatial_dims) == 1) if input_expand_dims: for i in range(input_expand_dims): - dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1) + dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1, auto_generated=True) input_data = input_data.copy_add_dim_by_tag(dim_tag, unbroadcast=True) in_spatial_dims.append(dim_tag) if input_split_feature_dim: @@ -5263,10 +5268,10 @@ def get_out_data_from_opts( filter_size=filter_size[i], stride=strides[i], dilation_rate=dilation_rate[i], padding=padding) dim_tags.append(Dim( kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim, - derived_from_tag=old_tag, undefined=not old_tag)) + derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True)) if not out_dim: assert n_out - out_dim = FeatureDim("%s:channel" % name, dimension=n_out) + out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True) dim_tags.append(out_dim) feature_dim_axis = NotSpecified # Swap the dims if the input dim order doesn't fit the flag auto_use_channel_first. @@ -5767,10 +5772,10 @@ def get_out_data_from_opts(cls, name, sources, network, padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2 dim_tags.append(Dim( kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim, - derived_from_tag=old_tag, undefined=not old_tag)) + derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True)) if not out_dim: assert n_out - out_dim = FeatureDim("%s:channel" % name, dimension=n_out) + out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True) dim_tags.append(out_dim) return Data( name="%s_output" % name, dim_tags=dim_tags, @@ -5983,7 +5988,8 @@ def get_out_data_from_opts(cls, name, sources, mode="", axes=None, axis=None, ke out_time_dim_axis = x.time_dim_axis if keep_dims: for i in axes: - y_dim_tags[i] = Dim(kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i)) + y_dim_tags[i] = Dim( + kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i), auto_generated=True) else: if out_batch_dim_axis in axes: out_batch_dim_axis = None @@ -6184,7 +6190,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None, out = common_source.copy_template(name="%s_output" % name) if not out_spatial_dim: out_spatial_dim = Dim( - kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources)) + kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True) assert out_spatial_dim.dimension == len(sources) out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True) return out @@ -6316,7 +6322,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding=None, size=None, ke dim_tags = list(data.dim_tags) for i, a in enumerate(axes): dim_tags[a] = Dim( - kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i]) + kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i], + auto_generated=True) data = data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True) else: assert all([d == 1 for d in res_dims]) @@ -6467,7 +6474,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, size_base assert not out_dim out_dim = size_base.output.get_time_dim_tag() if not out_dim: - out_dim = (repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name)) + in_dim + out_dim = ( + repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name, auto_generated=True)) + in_dim assert out_dim.dimension == out_dim_int x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim) if isinstance(repeat, LayerBase): @@ -6630,7 +6638,7 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs) data = data.copy_move_axis(old_axis=axis, new_axis=0) data = data.copy_with_batch_dim_axis(1) if not out_dim: - out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name) + out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name, auto_generated=True) data = data.copy_template_replace_dim_tag(axis=0, new_dim_tag=out_dim) data.time_dim_axis = 0 return data @@ -7126,7 +7134,7 @@ def find_axis(a_axis, b_axis): if not b_var_dims and add_var2_if_empty: b_var_dims.append( - SpatialDim("%s:dot:dummy-var2" % name, dimension=1)) + SpatialDim("%s:dot:dummy-var2" % name, dimension=1), auto_generated=True) dim_tags = list(a_rem_dims + a_var_dims + b_var_dims) return Data( @@ -7189,9 +7197,9 @@ def __init__(self, axis, amount, pad=True, adjust_size_info=True, **kwargs): self.output.size_placeholder[axis_wob] + size_delta, 0, tf.shape(shifted)[axis]) from ..util.data import Dim Dim( - kind=Dim.Types.Spatial, description="shift_axis", + kind=Dim.Types.Spatial, description="%s_shift_axis" % self.name, dyn_size=new_size, batch=self.output.batch, - src_data=self.output, src_axis=axis) + src_data=self.output, src_axis=axis, auto_generated=True) self.output.size_placeholder[axis_wob] = new_size @classmethod @@ -7210,7 +7218,7 @@ def get_out_data_from_opts(cls, name, amount, axis, pad, sources=(), **kwargs): axis = out.get_axis_from_description(axis) tag = out.dim_tags[axis] dim = None if tag.dimension is None else max(0, tag.dimension - abs(amount)) - tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim) + tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim, auto_generated=True) return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=tag) @@ -7319,7 +7327,7 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa if out_dim: assert out_dim.dimension == dim else: - out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim) + out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True) return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim) @@ -7410,7 +7418,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs) axis = out.get_axis_from_description(axis, allow_int=False) in_dim = out.dim_tags[axis] if not out_dim: - out_dim = Dim(kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim) + out_dim = Dim( + kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim, auto_generated=True) return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim) @@ -8582,18 +8591,17 @@ def get_out_data_from_opts(cls, name, network, elif isinstance(d, int): d = Dim( kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature, - description="%s:static:%i" % (name, i), + description="%s:static:%i" % (name, i), auto_generated=True, dimension=d) else: raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape)) dim_tags.append(d) if add_time_axis: dim_tags.insert( - 0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1)) + 0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1, auto_generated=True)) if add_batch_axis: dim_tags.insert( - 0, Dim( - kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info())) + 0, Dim(kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info())) return Data( name="%s_output" % name, dim_tags=dim_tags, dtype=dtype, batch=network.get_global_batch_info() if add_batch_axis else None) diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py index 0e23aa2df1..6f64fbbb25 100644 --- a/returnn/tf/layers/rec.py +++ b/returnn/tf/layers/rec.py @@ -347,7 +347,7 @@ def transform_source_and_axis(cls, network, source_data=None, have_dyn_seq_len_e # However, there are cases such as the RecUnstackLayer which can also define the time dim. # Expect that we have a subnet. assert have_dyn_seq_len_end or (opts and isinstance(opts.get("unit"), dict)) - axis = SpatialDim("dyn-time:%s%s" % (network.get_absolute_name_prefix(), name)) + axis = SpatialDim("dyn-time:%s%s" % (network.get_absolute_name_prefix(), name), auto_generated=True) assert isinstance(axis, Dim) inside_rec_time_dim = network.get_inside_rec_time_dim(inside_loop=True) over_rec_time_dim = network.get_inside_rec_time_dim(inside_loop=False) @@ -477,7 +477,7 @@ def get_out_data_from_opts(cls, name, network, sources, unit, axis=None, out_dim if n_out is NotSpecified or not n_out: assert out_type and "dim" in out_type n_out = out_type["dim"] - out_dim = FeatureDim("%s:feature" % name, dimension=n_out) + out_dim = FeatureDim("%s:feature" % name, dimension=n_out, auto_generated=True) if out.have_feature_axis(): out = out.copy_template_replace_dim_tag(axis=out.feature_dim_axis, new_dim_tag=out_dim) else: @@ -3638,7 +3638,8 @@ def get_loop_acc_layer(name): else: time_dim_tag = Dim( kind=Dim.Types.Spatial, - description="dyn-time:%s/%s" % (self.parent_rec_layer.get_full_ctx_name(), search_choices)) + description="dyn-time:%s/%s" % (self.parent_rec_layer.get_full_ctx_name(), search_choices), + auto_generated=True) elif is_out_time_dim: self.time_dim_tag.declare_same_as(time_dim_tag) output = ( @@ -4424,7 +4425,7 @@ def get_out_data_from_opts(cls, n_out, name, sources=(), **kwargs): :rtype: Data """ sources_data = Data.get_common_data([src.output for src in sources if src], ignore_feature_dim=True) - feat = FeatureDim("%s:rnn_cell_feat" % name, dimension=n_out) + feat = FeatureDim("%s:rnn_cell_feat" % name, dimension=n_out, auto_generated=True) if sources_data and sources_data.have_time_axis(): dim_tags = (sources_data.get_time_dim_tag(), sources_data.get_batch_dim_tag(), feat) batch_dim_axis = 1 @@ -4833,7 +4834,7 @@ def get_out_data_from_opts(cls, name, sources, out_dim=None, n_out=None, **kwarg from returnn.tf.util.data import batch_dim if not out_dim: assert n_out - out_dim = FeatureDim("%s:hidden-out" % name, n_out) + out_dim = FeatureDim("%s:hidden-out" % name, n_out, auto_generated=True) out = Data("%s_output" % name, dim_tags=[batch_dim, out_dim]) out.beam = sources[0].output.beam out.batch = sources[0].output.batch @@ -6093,7 +6094,7 @@ def get_out_data_from_opts(cls, name, network, sources, beam_dim=None, **kwargs) assert beam, "no beam in %r" % data data.beam = None if beam_dim is None: - beam_dim = SpatialDim("beam:%s" % beam.name, beam.beam_size) + beam_dim = SpatialDim("beam:%s" % beam.name, beam.beam_size, auto_generated=True) assert beam_dim.dimension == beam.beam_size data = data.copy_add_dim_by_tag(beam_dim, unbroadcast=True, axis=1) return data @@ -6857,14 +6858,14 @@ def get_out_data_from_opts(cls, n_out, name, sources, **kwargs): import numpy out = sources[0].output.copy_as_batch_major().copy(name="%s_output" % name) batch_dim_tag = out.dim_tags[out.batch_dim_axis] - feat_tag = FeatureDim("%s_self_att_feat" % name, dimension=n_out) + feat_tag = FeatureDim("%s_self_att_feat" % name, dimension=n_out, auto_generated=True) if len(out.shape_dense) >= 2: if all(out.shape_dense[:-1]): time_dim = numpy.prod(out.shape[:-1]) else: time_dim = None time_tag = Dim( - kind=Dim.Types.Spatial, description="%s_self_att_time" % name, dimension=time_dim) + kind=Dim.Types.Spatial, description="%s_self_att_time" % name, dimension=time_dim, auto_generated=True) dim_tags = (batch_dim_tag, time_tag, feat_tag) else: dim_tags = (batch_dim_tag, feat_tag) @@ -7242,7 +7243,7 @@ def get_out_data_from_opts(cls, name, sources, vocab = Vocabulary(vocab_file=vocab_file, unknown_label=vocab_unknown_label) tag = Dim( kind=Dim.Types.Feature, description="%s_ken_lm_vocab" % name, - dimension=vocab.num_labels, vocab=vocab) + dimension=vocab.num_labels, vocab=vocab, auto_generated=True) data = data.copy_add_dim_by_tag(tag, axis=-1, unbroadcast=True) return data @@ -7808,7 +7809,7 @@ def _create_template(cls, name, network, sources, masked_from, unit, if not out_spatial_dim: out_spatial_dim = Dim( kind=Dim.Types.Spatial, description="%s:masked:time" % name, - derived_from_tag=source_data.get_time_dim_tag()) + derived_from_tag=source_data.get_time_dim_tag(), auto_generated=True) source_data = source_data.copy_template_replace_dim_tag( axis=0, new_dim_tag=out_spatial_dim) @@ -9225,7 +9226,7 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs): data = get_concat_sources_data_template(sources, name="%s_output" % name) # The result will be without batch dim. feature_dim_tag = Dim( - kind=Dim.Types.Feature, description="%s_rel_pos_enc_feat" % name, dimension=n_out) + kind=Dim.Types.Feature, description="%s_rel_pos_enc_feat" % name, dimension=n_out, auto_generated=True) if data.have_time_axis(): time_dim_tag = data.get_time_dim_tag() # TODO using same dim tag twice will not be supported at some future point... @@ -9233,9 +9234,9 @@ def get_out_data_from_opts(cls, name, sources, n_out, **kwargs): else: # length will be ``network.get_rec_step_index() + 1``. dummy_dim_tag = Dim( - kind=Dim.Types.Spatial, description="%s_rel_pos_enc_dummy" % name, dimension=1) + kind=Dim.Types.Spatial, description="%s_rel_pos_enc_dummy" % name, dimension=1, auto_generated=True) time_dim_tag = Dim( - kind=Dim.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None) + kind=Dim.Types.Spatial, description="%s_rel_pos_enc_time" % name, dimension=None, auto_generated=True) data = data.copy_template_new_dim_tags((dummy_dim_tag, time_dim_tag, feature_dim_tag)) return data diff --git a/returnn/tf/layers/signal_processing.py b/returnn/tf/layers/signal_processing.py index d76acb3fdf..c6ea7e6282 100644 --- a/returnn/tf/layers/signal_processing.py +++ b/returnn/tf/layers/signal_processing.py @@ -457,9 +457,9 @@ def _compute_size_placeholder(): new_size = nr_of_full_frames + nf_of_paded_frames from ..util.data import Dim Dim( - kind=Dim.Types.Spatial, description="MultiChannelMultiResolutionStft", + kind=Dim.Types.Spatial, description="%s:MultiChannelMultiResolutionStft" % self.name, dyn_size=new_size, batch=self.output.batch, - src_data=self.output, src_axis=self.output.get_batch_axis(0)) + src_data=self.output, src_axis=self.output.get_batch_axis(0), auto_generated=True) size_placeholder_dict[0] = new_size return size_placeholder_dict diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 70c94a8703..436e6a5933 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -58,6 +58,7 @@ def __init__(self, kind=Types.Unspecified, description=None, vocab=None, dyn_size=None, dyn_size_ext=None, undefined=False, generic=False, special=False, + auto_generated=False, match_priority=0, derived_from_tag=None, derived_from_op=None, batch=None, control_flow_ctx=None, @@ -75,6 +76,9 @@ def __init__(self, kind=Types.Unspecified, description=None, :param bool special: Like `generic`, this can not be a dim tag of :class:`Data`. But this dim tag also does not match anything except itself. So it can be used to represent special placeholders with special meanings like ``single_step``. + :param bool auto_generated: This is auto-generated by RETURNN because it was not explicitly specified by the user. + E.g. for ConvLayer and others. This implies certain behavior on equality, such as comparing the description, + to allow for several independent creations of the dim tag during template construction. :param Dim|None derived_from_tag: Whether this new tag is reduced, down/up sampled, padded etc from this given other tag. In situations where dim tags are being matched (Data.get_common_data), @@ -121,6 +125,7 @@ def __init__(self, kind=Types.Unspecified, description=None, self._undefined = undefined self.generic = generic self.special = special + self.auto_generated = auto_generated # We can have different tag variants per batch info (e.g. with beam), or per control flow ctx. # They each have same_as = self. The same_base should have the base (global) batch info. self._same_for_batch_ctx = {} # type: typing.Dict[typing.Tuple[BatchInfo,typing.Optional[ControlFlowContext]],Dim] # nopep8 @@ -375,6 +380,7 @@ def get_for_batch_ctx(self, batch, ctx, allow_none=False): return None dim_tag = Dim( kind=self.kind, description=self.description, dimension=self.dimension, + auto_generated=self.auto_generated, batch=batch, control_flow_ctx=dyn_size_ext.control_flow_ctx if dyn_size_ext else ctx, dyn_size_ext=dyn_size_ext) dim_tag.same_as = same_base @@ -739,7 +745,7 @@ def is_equal(self, other, ignore_feature_dim=False, allow_same_feature_dim=False # We currently use the description because the identity would not be the same # in case of template construction where a dim tag is once created for a template layer, # and then later again for the real layer. - if self.description == other.description: + if self.auto_generated and other.auto_generated and self.description == other.description: return True return False @@ -781,7 +787,9 @@ def __hash__(self): return hash(base) if self.derived_from_op: return hash(self.derived_from_op) - return hash((base.kind, base.dimension, base.description)) + if self.auto_generated: + return hash((base.kind, base.dimension, base.description)) + return hash(id(base)) def get_same_base(self): """ @@ -1168,7 +1176,8 @@ def _make_constant_static_dim(cls, value, kind=None): dimension=value, kind=kind or Dim.Types.Unspecified, description="unnamed_%sdim_%i" % (kind.name + "_" if kind else "", value), - derived_from_op=Dim.Op(kind="constant", inputs=[], attribs={"value": value})) + derived_from_op=Dim.Op(kind="constant", inputs=[], attribs={"value": value}), + auto_generated=True) def _is_constant_static_dim(self): return self.derived_from_op and self.derived_from_op.kind == "constant" @@ -2568,12 +2577,12 @@ def __init__(self, name, if sparse: assert dim is not NotSpecified, "need dim (num classes) if sparse" assert dim is None or isinstance(dim, int) - sparse_dim = Dim(kind=Dim.Types.Feature, dimension=dim, description="%s:sparse-dim" % name) + sparse_dim = Dim(kind=Dim.Types.Feature, dimension=dim, description="%s:sparse-dim" % name, auto_generated=True) else: sparse_dim = None if isinstance(sparse_dim, int): sparse_dim = Dim( - kind=Dim.Types.Feature, dimension=sparse_dim, description="%s:sparse-dim" % name) + kind=Dim.Types.Feature, dimension=sparse_dim, description="%s:sparse-dim" % name, auto_generated=True) if sparse_dim is not None: assert isinstance(sparse_dim, Dim) assert sparse_dim.can_be_used_as_dim() @@ -2735,7 +2744,7 @@ def template_from_constant(cls, x, name, dtype=None, shape=None, with_batch_dim= assert d == d_ d = Dim( kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature, - description="%s:static:%i" % (name, i), + description="%s:static:%i" % (name, i), auto_generated=True, dimension=d) else: raise TypeError("%r shape[%i] invalid type %r in shape %r" % (name, i, type(d), shape)) @@ -3401,7 +3410,8 @@ def copy_add_dim_by_tag(self, dim_tag, unbroadcast=False, axis=None): dim_tag = dim_tag.copy(same_as_self=True, kind=Dim.Types.Spatial) if not unbroadcast and dim_tag.dimension != 1: dim_tag = Dim( - kind=dim_tag.kind, description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), dimension=1) + kind=dim_tag.kind, description="%s_dummy_dim1" % (dim_tag.description or "unnamed"), dimension=1, + auto_generated=True) data_opts["dim_tags"] = self.dim_tags[:axis] + (dim_tag,) + self.dim_tags[axis:] other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True) for k, a in other_special_axes.items(): @@ -3436,11 +3446,11 @@ def copy_split_feature_dim(self, new_feature_dim): new_feature_dim_axis = self.feature_dim_axis + 1 data_opts = self.get_kwargs(include_special_axes=False) dim_tag_split_rem = Dim( - kind=Dim.Types.Spatial, description="feature_split_rem_%i" % feature_dim_rem, + kind=Dim.Types.Spatial, description="feature_split_rem_%i" % feature_dim_rem, auto_generated=True, dimension=feature_dim_rem) dim_tag_new = Dim( kind=self.dim_tags[self.feature_dim_axis].kind, - description="feature_split_new_%i" % new_feature_dim, + description="feature_split_new_%i" % new_feature_dim, auto_generated=True, dimension=new_feature_dim) dim_tags = ( self.dim_tags[:self.feature_dim_axis] + @@ -3606,7 +3616,7 @@ def copy_time_flattened(self): data_opts["placeholder"] = self.get_placeholder_time_flattened() dim_tag = self.dim_tags[self.time_dim_axis] dim_tag = Dim( - kind=Dim.Types.Spatial, description="%s_flattened" % (dim_tag.description or "unnamed")) + kind=Dim.Types.Spatial, description="%s_flattened" % (dim_tag.description or "unnamed"), auto_generated=True) data_opts["dim_tags"] = ( (dim_tag,) + tuple(tag for (i, tag) in enumerate(self.dim_tags) if i not in (self.batch_dim_axis, self.time_dim_axis))) @@ -3810,7 +3820,7 @@ def copy_template_adding_time_dim(self, name=None, time_dim_axis=0): assert time_dim_axis >= 0 assert 0 <= time_dim_axis <= self.batch_ndim kwargs = self.get_kwargs(include_special_axes=False) - dim_tag = Dim(kind=Dim.Types.Time, description="unknown_time", dimension=None) + dim_tag = Dim(kind=Dim.Types.Time, description="unknown_time", dimension=None, auto_generated=True) dim_tags = self.dim_tags[:time_dim_axis] + (dim_tag,) + self.dim_tags[time_dim_axis:] kwargs["dim_tags"] = dim_tags other_special_axes = self.get_special_axes_dict(counted_with_batch_dim=True, only_available=True) @@ -3864,6 +3874,7 @@ def copy_template_replace_dim(self, axis, new_dim, new_size=None): return self.copy_template() # nothing to do dim_tag = Dim( kind=dim_tag.kind, description="%s_replaced" % (dim_tag.description or "unnamed"), + auto_generated=True, dimension=new_dim, dyn_size=new_size) return self.copy_template_replace_dim_tag(axis=axis, new_dim_tag=dim_tag) @@ -5564,7 +5575,7 @@ def _infer_dim_tags_tuple_from_shape( if axis == feature_dim_axis and dyn_size is None and axis != time_dim_axis: tag = Dim( kind=Dim.Types.Feature, dimension=dim, description="feature:%s" % name, - undefined=dim is None) + undefined=dim is None, auto_generated=True) else: assert axis in spatial_axes description = "time" if axis == time_dim_axis else "spatial%i" % spatial_axes.index(axis) @@ -5579,7 +5590,7 @@ def _infer_dim_tags_tuple_from_shape( description += ":%s" % name tag = Dim( kind=Dim.Types.Spatial, description=description, dimension=dim, dyn_size=dyn_size, - undefined=dim is None and dyn_size is None) + undefined=dim is None and dyn_size is None, auto_generated=True) dim_tags[axis] = tag assert sorted(dim_tags.keys()) == list(range(len(batch_shape))) return tuple(dim_tags[axis] for axis in range(len(batch_shape))) From acf85fe1d2e110bfa0abb25de6f4247e00580196 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 16 Feb 2022 00:23:36 +0100 Subject: [PATCH 2/3] Dim is_equal small fix --- returnn/tf/util/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py index 436e6a5933..b1d29c63a2 100644 --- a/returnn/tf/util/data.py +++ b/returnn/tf/util/data.py @@ -697,15 +697,15 @@ def is_equal(self, other, ignore_feature_dim=False, allow_same_feature_dim=False if self.dimension != other.dimension: return False return self.kind == other.kind - if self.derived_from_op and other.derived_from_op: - if self.derived_from_op == other.derived_from_op: - return True if allow_same_spatial_dim is None: allow_same_spatial_dim = allow_same_feature_dim self_base = self.get_same_derived_base() if derived_matches else self.get_same_base() other_base = other.get_same_derived_base() if derived_matches else other.get_same_base() if self_base is other_base: return True + if self_base.derived_from_op and other_base.derived_from_op: + if self_base.derived_from_op == other_base.derived_from_op: + return True self_kind = self.kind other_kind = other.kind if self_kind == other_kind == self.Types.Feature and ignore_feature_dim: From 5e63218e508f6b68c7a176904c8a4daeb5be8cd5 Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Wed, 16 Feb 2022 00:58:51 +0100 Subject: [PATCH 3/3] Dim auto_generated small fix --- returnn/tf/layers/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 0401c206ae..d383b5fe52 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -7134,7 +7134,7 @@ def find_axis(a_axis, b_axis): if not b_var_dims and add_var2_if_empty: b_var_dims.append( - SpatialDim("%s:dot:dummy-var2" % name, dimension=1), auto_generated=True) + SpatialDim("%s:dot:dummy-var2" % name, dimension=1, auto_generated=True)) dim_tags = list(a_rem_dims + a_var_dims + b_var_dims) return Data(