diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py
index 2b4126149d..20cd15369e 100644
--- a/returnn/tf/layers/base.py
+++ b/returnn/tf/layers/base.py
@@ -367,7 +367,7 @@ def _base_get_out_data_from_opts(cls, network, name,
         allow_broadcast_all_sources = True
       sources_data = Data.get_common_data(
         sources_data_list, ignore_feature_dim=True,
-        allow_broadcast_all_sources=allow_broadcast_all_sources) if sources_data_list else None
+        allow_broadcast_all_sources=allow_broadcast_all_sources, name="%s_sources" % name) if sources_data_list else None
       if sources_data and not sources_data.sparse and not out_type.get("sparse", False):
         out_type.setdefault("dtype", sources_data.dtype)
       # You are supposed to set self.output.{batch_dim_axis,time_dim_axis} explicitly,
diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py
index 7839a4730f..16a4d4611f 100644
--- a/returnn/tf/layers/basic.py
+++ b/returnn/tf/layers/basic.py
@@ -1051,7 +1051,7 @@ def __init__(self, start, size, min_size=None, out_spatial_dim=None, **kwargs):
     data_objs = [start_data]
     data_objs += [size.output] if isinstance(size, LayerBase) else []
     data_objs += [seq_lens_data] if isinstance(seq_lens_data, Data) else []
-    common_data = Data.get_common_data(data_objs)
+    common_data = Data.get_common_data(data_objs, name="%s_inputs" % self.name)
     start_data = start_data.copy_compatible_to(common_data, check_sparse=False)
     start_t = start_data.placeholder
     if size is None:
@@ -3049,7 +3049,7 @@ def _set_output_sizes(self, merge_axes):
       if not out_size:
         out_size = in_size
       else:
-        new_data = Data.get_common_data([out_size, in_size])
+        new_data = Data.get_common_data([out_size, in_size], name="%s_output" % self.name)
         new_data.placeholder = (
           out_size.copy_compatible_to(new_data).placeholder *
           in_size.copy_compatible_to(new_data).placeholder)
@@ -5674,12 +5674,13 @@ def __init__(self, axis=None, out_spatial_dim=None, **kwargs):
     self.output.placeholder = tf.stack([src.placeholder for src in sources_], axis=axis)
 
   @classmethod
-  def _get_axis_and_common(cls, sources):
+  def _get_axis_and_common(cls, sources, name):
     """
     :param list[LayerBase] sources:
+    :param str name:
     :rtype: (int,Data)
     """
-    common_source = Data.get_common_data([src.output for src in sources]).copy_template()
+    common_source = Data.get_common_data([src.output for src in sources], name=name)
     dummy_tag = Dim(kind=Dim.Types.Spatial, dimension=1)
     return common_source.get_default_new_axis_for_dim_tag(dummy_tag), common_source
 
@@ -5692,7 +5693,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None,
     :param Dim|None out_spatial_dim:
     :rtype: Data
     """
-    axis_, common_source = cls._get_axis_and_common(sources)
+    axis_, common_source = cls._get_axis_and_common(sources, name="%s_sources" % name)
     if axis is None:
       axis = axis_
     out = common_source.copy_template(name="%s_output" % name)
@@ -6945,10 +6946,10 @@ def get_out_data_from_opts(cls, eval_locals=None, n_out=NotSpecified, sources=()
         allow_broadcast_all_sources = True
       out_type_.update(
         Data.get_common_data(
-          [s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources).get_kwargs())
+          [s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources,
+          name="%s_output" % kwargs["name"]).get_kwargs())
     if n_out is not NotSpecified:
       out_type_["dim"] = n_out
-    out_type_["name"] = "%s_output" % kwargs["name"]
     if out_type:
       if isinstance(out_type, dict):
         if "shape" in out_type:
@@ -7196,14 +7197,14 @@ def get_out_data_from_opts(cls, n_out=NotSpecified, out_type=None, out_shape=None,
         allow_broadcast_all_sources = True
       out_type_.update(
         Data.get_common_data(
-          [s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources).get_kwargs())
+          [s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources,
+          name="%s_output" % kwargs["name"]).get_kwargs())
     if n_out is not NotSpecified:
       out_type_["dim"] = n_out
     elif out_type_.get("sparse", False):
       out_type_["dim"] = 2
     out_type_["dtype"] = "bool"
     out_type_["vocab"] = None
-    out_type_["name"] = "%s_output" % kwargs["name"]
    if out_type:
       if isinstance(out_type, dict):
         out_type_.update(out_type)
diff --git a/returnn/tf/util/data.py b/returnn/tf/util/data.py
index 869093ba9b..365ac007c1 100644
--- a/returnn/tf/util/data.py
+++ b/returnn/tf/util/data.py
@@ -5039,11 +5039,12 @@ def get_batch_shape_dim_tags(self):
     return self.dim_tags
 
   @classmethod
-  def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_sources=NotSpecified):
+  def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_sources=NotSpecified, name=None):
     """
     :param list[Data] sources:
     :param bool ignore_feature_dim: when set, the feature dim does not have to match in the sources
     :param bool|NotSpecified allow_broadcast_all_sources:
+    :param str|None name:
     :return: some generic data where the sources should be compatible to (with copy_compatible_to),
       i.e. it contains the union of all axes from all sources (least common multiple).
       This is always a template, and a new copy.
@@ -5058,7 +5059,7 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_sources=NotSpecified):
       common_batch = BatchInfo.get_common_batch_info([src.batch for src in sources if src.batch])
     # Try with the (first) largest.
     common = [s for s in sources if s.batch_ndim == max_ndim][0]
-    common = common.copy_template()
+    common = common.copy_template(name=name)
     if common_batch:
       common.batch = common_batch.copy_set_beam(None)  # the beam will be reset
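
Usage sketch (not part of the patch): a minimal example of what the new name
argument does, assuming two compatible template sources; the names a, b and
my_layer_sources are made up for illustration.

  from returnn.tf.util.data import Data

  # Two compatible template sources; b shares the dim tags of a.
  a = Data(name="a", shape=(None, 5))
  b = a.copy_template(name="b")

  # Before this change, the returned common template kept the name of one of
  # the sources; now the caller can label it explicitly (e.g. per layer),
  # which makes error messages and debug output easier to trace back.
  common = Data.get_common_data([a, b], name="my_layer_sources")
  assert common.name == "my_layer_sources"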