get_common_data, add name parameter
Before, the output name was just the first input's name.
This was very confusing when get_common_data raised an error because of implicit broadcasting:
you would then think the error was in the first input,
while it actually was in the layer that combined the dims.
Zettelkasten committed Dec 16, 2021
1 parent 31dfe3b commit ddfb7e3
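
For illustration, a minimal usage sketch (not part of the commit; it assumes a RETURNN checkout that includes this change plus an installed TensorFlow, and the data names and shape below are made up):

from returnn.tf.util.data import Data

# Two compatible inputs, as a combining layer would see them
# (hypothetical names and shape, purely for illustration).
a = Data(name="first_input", shape=(None, 512))   # [B, T, 512]
b = a.copy_template(name="second_input")

# Before this commit, the common template silently reused the first
# input's name ("first_input"), so a broadcasting error seemed to come
# from that input. With an explicit name, the template points at the
# combining layer instead.
common = Data.get_common_data([a, b], name="my_combine_layer_sources")
print(common.name)  # -> "my_combine_layer_sources"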
Showing 3 changed files with 14 additions and 12 deletions.
2 changes: 1 addition & 1 deletion returnn/tf/layers/base.py
@@ -367,7 +367,7 @@ def _base_get_out_data_from_opts(cls, network, name,
allow_broadcast_all_sources = True
sources_data = Data.get_common_data(
sources_data_list, ignore_feature_dim=True,
allow_broadcast_all_sources=allow_broadcast_all_sources) if sources_data_list else None
allow_broadcast_all_sources=allow_broadcast_all_sources, name="%s_sources" % name) if sources_data_list else None
if sources_data and not sources_data.sparse and not out_type.get("sparse", False):
out_type.setdefault("dtype", sources_data.dtype)
# You are supposed to set self.output.{batch_dim_axis,time_dim_axis} explicitly,
19 changes: 10 additions & 9 deletions returnn/tf/layers/basic.py
@@ -1051,7 +1051,7 @@ def __init__(self, start, size, min_size=None, out_spatial_dim=None, **kwargs):
data_objs = [start_data]
data_objs += [size.output] if isinstance(size, LayerBase) else []
data_objs += [seq_lens_data] if isinstance(seq_lens_data, Data) else []
common_data = Data.get_common_data(data_objs)
common_data = Data.get_common_data(data_objs, name="%s_inputs" % self.name)
start_data = start_data.copy_compatible_to(common_data, check_sparse=False)
start_t = start_data.placeholder
if size is None:
@@ -3049,7 +3049,7 @@ def _set_output_sizes(self, merge_axes):
if not out_size:
out_size = in_size
else:
new_data = Data.get_common_data([out_size, in_size])
new_data = Data.get_common_data([out_size, in_size], name="%s_output" % self.name)
new_data.placeholder = (
out_size.copy_compatible_to(new_data).placeholder
* in_size.copy_compatible_to(new_data).placeholder)
@@ -5674,12 +5674,13 @@ def __init__(self, axis=None, out_spatial_dim=None, **kwargs):
self.output.placeholder = tf.stack([src.placeholder for src in sources_], axis=axis)

@classmethod
def _get_axis_and_common(cls, sources):
def _get_axis_and_common(cls, sources, name):
"""
:param list[LayerBase] sources:
:param str name:
:rtype: (int,Data)
"""
common_source = Data.get_common_data([src.output for src in sources]).copy_template()
common_source = Data.get_common_data([src.output for src in sources], name=name)
dummy_tag = Dim(kind=Dim.Types.Spatial, dimension=1)
return common_source.get_default_new_axis_for_dim_tag(dummy_tag), common_source

@@ -5692,7 +5693,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None,
:param Dim|None out_spatial_dim:
:rtype: Data
"""
axis_, common_source = cls._get_axis_and_common(sources)
axis_, common_source = cls._get_axis_and_common(sources, name="%s_sources" % name)
if axis is None:
axis = axis_
out = common_source.copy_template(name="%s_output" % name)
@@ -6945,10 +6946,10 @@ def get_out_data_from_opts(cls, eval_locals=None, n_out=NotSpecified, sources=()
allow_broadcast_all_sources = True
out_type_.update(
Data.get_common_data(
[s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources).get_kwargs())
[s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources,
name="%s_output" % kwargs["name"]).get_kwargs())
if n_out is not NotSpecified:
out_type_["dim"] = n_out
out_type_["name"] = "%s_output" % kwargs["name"]
if out_type:
if isinstance(out_type, dict):
if "shape" in out_type:
@@ -7196,14 +7197,14 @@ def get_out_data_from_opts(cls, n_out=NotSpecified, out_type=None, out_shape=Non
allow_broadcast_all_sources = True
out_type_.update(
Data.get_common_data(
[s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources).get_kwargs())
[s.output for s in sources], allow_broadcast_all_sources=allow_broadcast_all_sources,
name="%s_output" % kwargs["name"]).get_kwargs())
if n_out is not NotSpecified:
out_type_["dim"] = n_out
elif out_type_.get("sparse", False):
out_type_["dim"] = 2
out_type_["dtype"] = "bool"
out_type_["vocab"] = None
out_type_["name"] = "%s_output" % kwargs["name"]
if out_type:
if isinstance(out_type, dict):
out_type_.update(out_type)
5 changes: 3 additions & 2 deletions returnn/tf/util/data.py
@@ -5039,11 +5039,12 @@ def get_batch_shape_dim_tags(self):
return self.dim_tags

@classmethod
def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_sources=NotSpecified):
def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_sources=NotSpecified, name=None):
"""
:param list[Data] sources:
:param bool ignore_feature_dim: when set, the feature dim does not have to match in the sources
:param bool|NotSpecified allow_broadcast_all_sources:
:param str|None name:
:return: some generic data where the sources should be compatible to (with copy_compatible_to),
i.e. it contains the union of all axes from all sources (least common multiple).
This is always a template, and a new copy.
@@ -5058,7 +5059,7 @@ def get_common_data(cls, sources, ignore_feature_dim=False, allow_broadcast_all_
common_batch = BatchInfo.get_common_batch_info([src.batch for src in sources if src.batch])
# Try with the (first) largest.
common = [s for s in sources if s.batch_ndim == max_ndim][0]
common = common.copy_template()
common = common.copy_template(name=name)
common.beam = None # this will be reset
if common_batch:
common.batch = common_batch.copy_set_beam(None) # the beam will be reset
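
Since the new parameter defaults to None, existing callers are unaffected: copy_template(name=None) keeps the source's name, i.e. the previous behavior. A short sketch under the same assumptions as above:

from returnn.tf.util.data import Data

a = Data(name="first_input", shape=(None, 512))
b = a.copy_template(name="second_input")

# No explicit name: the common template keeps the name of the source it
# was copied from, as before this commit.
common = Data.get_common_data([a, b])
print(common.name)  # -> "first_input"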
