Skip to content

Commit

Permalink
Dim auto_generated flag (#950)
Browse files Browse the repository at this point in the history
Allows for better Dim is_equal
which does not rely on the description.

#634
  • Loading branch information
albertz authored Feb 16, 2022
1 parent 916d70d commit 3650515
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 71 deletions.
2 changes: 1 addition & 1 deletion returnn/tf/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def _base_get_out_data_from_opts(cls, network, name,
feature_dim_tag = out_dim
else:
dim = out_type.get("dim", None)
feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim)
feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim, auto_generated=True)
if feature_dim_axis in (NotSpecified, None):
if sources_data.feature_dim_axis is None:
feature_dim_axis = len(dim_tags)
Expand Down
86 changes: 47 additions & 39 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def get_out_data_from_opts(
if out_dim:
assert out_dim.dimension == new_dim
else:
out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim)
out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True)
return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name)


Expand Down Expand Up @@ -1219,7 +1219,7 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T
out_spatial_dim = Dim(
kind=Dim.Types.Spatial,
description="sliced-time:%s" % name,
dimension=size)
dimension=size, auto_generated=True)
gather_positions_data = gather_positions_data.copy_add_dim_by_tag(
out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim)
position = InternalLayer(
Expand Down Expand Up @@ -2527,7 +2527,9 @@ def get_out_data_from_opts(cls, name, shape, dtype="float32", **kwargs):
:param str dtype:
:rtype: Data
"""
dim_tags = [d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d) for i, d in enumerate(shape)]
dim_tags = [
d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d, auto_generated=True)
for i, d in enumerate(shape)]
return Data(name="%s_output" % name, dim_tags=dim_tags, dtype=dtype)


Expand Down Expand Up @@ -2613,7 +2615,7 @@ def get_out_data_from_opts(cls, name, network, shape, maxval, minval=0, dtype="i
elif isinstance(d, int):
d = Dim(
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
description="%s:static:%i" % (name, i),
description="%s:static:%i" % (name, i), auto_generated=True,
dimension=d)
else:
raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
Expand Down Expand Up @@ -2679,12 +2681,12 @@ def get_out_data_from_opts(cls, name, limit, start=0, delta=1, dtype=None, spars
else:
dtype = "int32"
dim = len(range(start, limit, delta))
tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name)
tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name, auto_generated=True)
if out_spatial_dim:
tag.declare_same_as(out_spatial_dim)
sparse_dim = None
if sparse:
sparse_dim = SpatialDim("%s:range-indices" % name)
sparse_dim = SpatialDim("%s:range-indices" % name, auto_generated=True)
return Data(name="%s_output" % name, dim_tags=[tag], dtype=dtype, sparse_dim=sparse_dim)


Expand Down Expand Up @@ -2799,7 +2801,7 @@ def get_out_data_from_opts(cls, name, sources, dtype="int32", sparse=False, out_
dim_tag = Dim.get_tag_from_size_tensor(source.placeholder)
if not dim_tag:
dim_tag = Dim(
kind=Dim.Types.Spatial, description="%s_input_len" % name,
kind=Dim.Types.Spatial, description="%s_input_len" % name, auto_generated=True,
batch=source.batch, control_flow_ctx=source.control_flow_ctx,
dyn_size_ext=source)
if source.placeholder is not None:
Expand Down Expand Up @@ -2940,7 +2942,7 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None,
if out_dim:
assert out_dim.dimension == dim
else:
out_dim = FeatureDim("%s:gating" % name, dimension=dim)
out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True)
if n_out is not NotSpecified:
assert n_out == dim
return Data(
Expand Down Expand Up @@ -3079,15 +3081,15 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window
filter_size=window_size, stride=stride, dilation_rate=1, padding=padding)
out_spatial_dim = Dim(
kind=Dim.Types.Spatial, description="%s:spatial" % name,
dimension=dim, derived_from_tag=in_spatial_dim,
dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True,
batch=data.batch, control_flow_ctx=data.control_flow_ctx)
data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim)
new_dim_axis = axis + 1 # add new axis right after
if window_dim:
assert window_dim.dimension == window_size
else:
window_dim = Dim(
kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size)
kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True)
return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True)

# noinspection PyMethodOverriding
Expand Down Expand Up @@ -3585,9 +3587,11 @@ def _get_axis_size_splits_num_splits(cls, name, input_data, axis=None,
err_prefix, out_dims, dim, input_data)
if not out_dims:
assert size_splits
out_dims = [Dim(
kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
dimension=size_splits[idx]) for idx in range(len(size_splits))]
out_dims = [
Dim(
kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
dimension=size_splits[idx], auto_generated=True)
for idx in range(len(size_splits))]
return axis, out_dims

def _make_split_layer(self, idx):
Expand Down Expand Up @@ -3867,6 +3871,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
kind=axis_dim_tag.kind,
description="%s_split_dims%i_rem" % (name, rem_dim_idx),
dimension=resolved_shape_dims[rem_dim_idx],
auto_generated=True,
derived_from_tag=axis_dim_tag,
batch=axis_dim_tag.batch, control_flow_ctx=axis_dim_tag.control_flow_ctx)
if rem_dim.dimension is None and axis_dim_tag.dyn_size_ext is not None:
Expand All @@ -3883,7 +3888,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
Dim(
kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial,
description="%s_split_dims%i" % (name, i),
dimension=shape_dim)
dimension=shape_dim, auto_generated=True)
if rem_dim is None or i != rem_dim_idx else rem_dim
for i, shape_dim in enumerate(resolved_shape_dims))
out_batch = data.batch
Expand Down Expand Up @@ -4158,7 +4163,7 @@ def get_out_data_from_opts(cls, name, sources, num_axes, in_dim="T", out_dims=No
assert not declare_same_sizes_as
else:
out_dims = [
SpatialDim("%s:unflatten-nd:%i" % (name, i))
SpatialDim("%s:unflatten-nd:%i" % (name, i), auto_generated=True)
for i in range(num_axes)]
if declare_same_sizes_as:
for i, other in declare_same_sizes_as.items():
Expand Down Expand Up @@ -4238,7 +4243,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
else:
new_dim = Dim(
kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial,
description="%s_expand_dims" % name,
description="%s_expand_dims" % name, auto_generated=True,
dimension=dim)
data = data.copy_template(name="%s_output" % name)
data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
Expand Down Expand Up @@ -4394,7 +4399,7 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None,
if isinstance(repetitions, int):
out_dim = tag * repetitions
else:
out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag)
out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True)
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim)


Expand Down Expand Up @@ -4804,12 +4809,12 @@ def map_axis_name(s):
pass
else:
out.sparse_dim = Dim(
kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name)
kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name, auto_generated=True)
if increase_sparse_dim:
assert out.sparse
out.sparse_dim = Dim(
kind=Dim.Types.Feature, dimension=out.sparse_dim.dimension + 1,
description="%s:inc-sparse-dim" % name)
description="%s:inc-sparse-dim" % name, auto_generated=True)
if batch_dim_base:
out.batch = batch_dim_base.output.batch
return out
Expand Down Expand Up @@ -5086,7 +5091,7 @@ def transform_input(cls, input_data, network, in_dim=None, in_spatial_dims=None,
cls._check_defined_in_spatial_dims(len(in_spatial_dims) == 1)
if input_expand_dims:
for i in range(input_expand_dims):
dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1)
dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1, auto_generated=True)
input_data = input_data.copy_add_dim_by_tag(dim_tag, unbroadcast=True)
in_spatial_dims.append(dim_tag)
if input_split_feature_dim:
Expand Down Expand Up @@ -5263,10 +5268,10 @@ def get_out_data_from_opts(
filter_size=filter_size[i], stride=strides[i], dilation_rate=dilation_rate[i], padding=padding)
dim_tags.append(Dim(
kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
derived_from_tag=old_tag, undefined=not old_tag))
derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
if not out_dim:
assert n_out
out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
dim_tags.append(out_dim)
feature_dim_axis = NotSpecified
# Swap the dims if the input dim order doesn't fit the flag auto_use_channel_first.
Expand Down Expand Up @@ -5767,10 +5772,10 @@ def get_out_data_from_opts(cls, name, sources, network,
padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2
dim_tags.append(Dim(
kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
derived_from_tag=old_tag, undefined=not old_tag))
derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
if not out_dim:
assert n_out
out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
dim_tags.append(out_dim)
return Data(
name="%s_output" % name, dim_tags=dim_tags,
Expand Down Expand Up @@ -5983,7 +5988,8 @@ def get_out_data_from_opts(cls, name, sources, mode="", axes=None, axis=None, ke
out_time_dim_axis = x.time_dim_axis
if keep_dims:
for i in axes:
y_dim_tags[i] = Dim(kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i))
y_dim_tags[i] = Dim(
kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i), auto_generated=True)
else:
if out_batch_dim_axis in axes:
out_batch_dim_axis = None
Expand Down Expand Up @@ -6184,7 +6190,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None,
out = common_source.copy_template(name="%s_output" % name)
if not out_spatial_dim:
out_spatial_dim = Dim(
kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources))
kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True)
assert out_spatial_dim.dimension == len(sources)
out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True)
return out
Expand Down Expand Up @@ -6316,7 +6322,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding=None, size=None, ke
dim_tags = list(data.dim_tags)
for i, a in enumerate(axes):
dim_tags[a] = Dim(
kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i])
kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i],
auto_generated=True)
data = data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True)
else:
assert all([d == 1 for d in res_dims])
Expand Down Expand Up @@ -6467,7 +6474,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, size_base
assert not out_dim
out_dim = size_base.output.get_time_dim_tag()
if not out_dim:
out_dim = (repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name)) + in_dim
out_dim = (
repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name, auto_generated=True)) + in_dim
assert out_dim.dimension == out_dim_int
x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim)
if isinstance(repeat, LayerBase):
Expand Down Expand Up @@ -6630,7 +6638,7 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
data = data.copy_move_axis(old_axis=axis, new_axis=0)
data = data.copy_with_batch_dim_axis(1)
if not out_dim:
out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name)
out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name, auto_generated=True)
data = data.copy_template_replace_dim_tag(axis=0, new_dim_tag=out_dim)
data.time_dim_axis = 0
return data
Expand Down Expand Up @@ -7126,7 +7134,7 @@ def find_axis(a_axis, b_axis):

if not b_var_dims and add_var2_if_empty:
b_var_dims.append(
SpatialDim("%s:dot:dummy-var2" % name, dimension=1))
SpatialDim("%s:dot:dummy-var2" % name, dimension=1, auto_generated=True))

dim_tags = list(a_rem_dims + a_var_dims + b_var_dims)
return Data(
Expand Down Expand Up @@ -7189,9 +7197,9 @@ def __init__(self, axis, amount, pad=True, adjust_size_info=True, **kwargs):
self.output.size_placeholder[axis_wob] + size_delta, 0, tf.shape(shifted)[axis])
from ..util.data import Dim
Dim(
kind=Dim.Types.Spatial, description="shift_axis",
kind=Dim.Types.Spatial, description="%s_shift_axis" % self.name,
dyn_size=new_size, batch=self.output.batch,
src_data=self.output, src_axis=axis)
src_data=self.output, src_axis=axis, auto_generated=True)
self.output.size_placeholder[axis_wob] = new_size

@classmethod
Expand All @@ -7210,7 +7218,7 @@ def get_out_data_from_opts(cls, name, amount, axis, pad, sources=(), **kwargs):
axis = out.get_axis_from_description(axis)
tag = out.dim_tags[axis]
dim = None if tag.dimension is None else max(0, tag.dimension - abs(amount))
tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim)
tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim, auto_generated=True)
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=tag)


Expand Down Expand Up @@ -7319,7 +7327,7 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa
if out_dim:
assert out_dim.dimension == dim
else:
out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim)
out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True)
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)


Expand Down Expand Up @@ -7410,7 +7418,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
axis = out.get_axis_from_description(axis, allow_int=False)
in_dim = out.dim_tags[axis]
if not out_dim:
out_dim = Dim(kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim)
out_dim = Dim(
kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim, auto_generated=True)
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)


Expand Down Expand Up @@ -8582,18 +8591,17 @@ def get_out_data_from_opts(cls, name, network,
elif isinstance(d, int):
d = Dim(
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
description="%s:static:%i" % (name, i),
description="%s:static:%i" % (name, i), auto_generated=True,
dimension=d)
else:
raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
dim_tags.append(d)
if add_time_axis:
dim_tags.insert(
0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1))
0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1, auto_generated=True))
if add_batch_axis:
dim_tags.insert(
0, Dim(
kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
0, Dim(kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
return Data(
name="%s_output" % name, dim_tags=dim_tags, dtype=dtype,
batch=network.get_global_batch_info() if add_batch_axis else None)
Expand Down
Loading

0 comments on commit 3650515

Please sign in to comment.