Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dim auto_generated flag #950

Merged
3 commits merged on Feb 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion returnn/tf/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def _base_get_out_data_from_opts(cls, network, name,
feature_dim_tag = out_dim
else:
dim = out_type.get("dim", None)
feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim)
feature_dim_tag = FeatureDim("%s:feature-dense" % name, dim, auto_generated=True)
if feature_dim_axis in (NotSpecified, None):
if sources_data.feature_dim_axis is None:
feature_dim_axis = len(dim_tags)
Expand Down
86 changes: 47 additions & 39 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,7 +1039,7 @@ def get_out_data_from_opts(
if out_dim:
assert out_dim.dimension == new_dim
else:
out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim)
out_dim = Dim(kind=dim_tag.kind, description="%s:slice" % name, dimension=new_dim, auto_generated=True)
return input_data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim, name="%s_output" % name)


Expand Down Expand Up @@ -1219,7 +1219,7 @@ def get_out_data_from_opts(cls, name, sources=(), start=None, size=None, axis="T
out_spatial_dim = Dim(
kind=Dim.Types.Spatial,
description="sliced-time:%s" % name,
dimension=size)
dimension=size, auto_generated=True)
gather_positions_data = gather_positions_data.copy_add_dim_by_tag(
out_spatial_dim, unbroadcast=True, axis=start_data.batch_ndim)
position = InternalLayer(
Expand Down Expand Up @@ -2527,7 +2527,9 @@ def get_out_data_from_opts(cls, name, shape, dtype="float32", **kwargs):
:param str dtype:
:rtype: Data
"""
dim_tags = [d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d) for i, d in enumerate(shape)]
dim_tags = [
d if isinstance(d, Dim) else SpatialDim("%s:dim%i" % (name, i), d, auto_generated=True)
for i, d in enumerate(shape)]
return Data(name="%s_output" % name, dim_tags=dim_tags, dtype=dtype)


Expand Down Expand Up @@ -2613,7 +2615,7 @@ def get_out_data_from_opts(cls, name, network, shape, maxval, minval=0, dtype="i
elif isinstance(d, int):
d = Dim(
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
description="%s:static:%i" % (name, i),
description="%s:static:%i" % (name, i), auto_generated=True,
dimension=d)
else:
raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
Expand Down Expand Up @@ -2679,12 +2681,12 @@ def get_out_data_from_opts(cls, name, limit, start=0, delta=1, dtype=None, spars
else:
dtype = "int32"
dim = len(range(start, limit, delta))
tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name)
tag = Dim(kind=Dim.Types.Spatial, dimension=dim, description="%s:range" % name, auto_generated=True)
if out_spatial_dim:
tag.declare_same_as(out_spatial_dim)
sparse_dim = None
if sparse:
sparse_dim = SpatialDim("%s:range-indices" % name)
sparse_dim = SpatialDim("%s:range-indices" % name, auto_generated=True)
return Data(name="%s_output" % name, dim_tags=[tag], dtype=dtype, sparse_dim=sparse_dim)


Expand Down Expand Up @@ -2799,7 +2801,7 @@ def get_out_data_from_opts(cls, name, sources, dtype="int32", sparse=False, out_
dim_tag = Dim.get_tag_from_size_tensor(source.placeholder)
if not dim_tag:
dim_tag = Dim(
kind=Dim.Types.Spatial, description="%s_input_len" % name,
kind=Dim.Types.Spatial, description="%s_input_len" % name, auto_generated=True,
batch=source.batch, control_flow_ctx=source.control_flow_ctx,
dyn_size_ext=source)
if source.placeholder is not None:
Expand Down Expand Up @@ -2940,7 +2942,7 @@ def get_out_data_from_opts(cls, name, sources, n_out=NotSpecified, out_dim=None,
if out_dim:
assert out_dim.dimension == dim
else:
out_dim = FeatureDim("%s:gating" % name, dimension=dim)
out_dim = FeatureDim("%s:gating" % name, dimension=dim, auto_generated=True)
if n_out is not NotSpecified:
assert n_out == dim
return Data(
Expand Down Expand Up @@ -3079,15 +3081,15 @@ def get_out_data_from_opts(cls, name, network, sources, window_size=None, window
filter_size=window_size, stride=stride, dilation_rate=1, padding=padding)
out_spatial_dim = Dim(
kind=Dim.Types.Spatial, description="%s:spatial" % name,
dimension=dim, derived_from_tag=in_spatial_dim,
dimension=dim, derived_from_tag=in_spatial_dim, auto_generated=True,
batch=data.batch, control_flow_ctx=data.control_flow_ctx)
data = data.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_spatial_dim)
new_dim_axis = axis + 1 # add new axis right after
if window_dim:
assert window_dim.dimension == window_size
else:
window_dim = Dim(
kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size)
kind=Dim.Types.Spatial, description="%s:window" % name, dimension=window_size, auto_generated=True)
return data.copy_add_dim_by_tag(axis=new_dim_axis, dim_tag=window_dim, unbroadcast=True)

# noinspection PyMethodOverriding
Expand Down Expand Up @@ -3585,9 +3587,11 @@ def _get_axis_size_splits_num_splits(cls, name, input_data, axis=None,
err_prefix, out_dims, dim, input_data)
if not out_dims:
assert size_splits
out_dims = [Dim(
kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
dimension=size_splits[idx]) for idx in range(len(size_splits))]
out_dims = [
Dim(
kind=input_data.dim_tags[axis].kind, description="%s_split%i" % (name, idx),
dimension=size_splits[idx], auto_generated=True)
for idx in range(len(size_splits))]
return axis, out_dims

def _make_split_layer(self, idx):
Expand Down Expand Up @@ -3867,6 +3871,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
kind=axis_dim_tag.kind,
description="%s_split_dims%i_rem" % (name, rem_dim_idx),
dimension=resolved_shape_dims[rem_dim_idx],
auto_generated=True,
derived_from_tag=axis_dim_tag,
batch=axis_dim_tag.batch, control_flow_ctx=axis_dim_tag.control_flow_ctx)
if rem_dim.dimension is None and axis_dim_tag.dyn_size_ext is not None:
Expand All @@ -3883,7 +3888,7 @@ def get_out_data_from_opts(cls, name, axis, dims, pad_to_multiples=None, sources
Dim(
kind=axis_dim_tag.kind if not axis_dim_tag.is_batch_dim() else Dim.Types.Spatial,
description="%s_split_dims%i" % (name, i),
dimension=shape_dim)
dimension=shape_dim, auto_generated=True)
if rem_dim is None or i != rem_dim_idx else rem_dim
for i, shape_dim in enumerate(resolved_shape_dims))
out_batch = data.batch
Expand Down Expand Up @@ -4158,7 +4163,7 @@ def get_out_data_from_opts(cls, name, sources, num_axes, in_dim="T", out_dims=No
assert not declare_same_sizes_as
else:
out_dims = [
SpatialDim("%s:unflatten-nd:%i" % (name, i))
SpatialDim("%s:unflatten-nd:%i" % (name, i), auto_generated=True)
for i in range(num_axes)]
if declare_same_sizes_as:
for i, other in declare_same_sizes_as.items():
Expand Down Expand Up @@ -4238,7 +4243,7 @@ def get_out_data_from_opts(cls, name, axis, dim=1, sources=(), **kwargs):
else:
new_dim = Dim(
kind=Dim.Types.Feature if init_axis.lower() == "f" else Dim.Types.Spatial,
description="%s_expand_dims" % name,
description="%s_expand_dims" % name, auto_generated=True,
dimension=dim)
data = data.copy_template(name="%s_output" % name)
data = data.copy_add_dim_by_tag(new_dim, unbroadcast=True, axis=axis)
Expand Down Expand Up @@ -4394,7 +4399,7 @@ def get_out_data_from_opts(cls, name, sources, axis, repetitions, out_dim=None,
if isinstance(repetitions, int):
out_dim = tag * repetitions
else:
out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag)
out_dim = Dim(description="repeated:%s" % name, kind=tag.kind, derived_from_tag=tag, auto_generated=True)
return data.copy_template_replace_dim_tag(axis=data.get_batch_axis(0), new_dim_tag=out_dim)


Expand Down Expand Up @@ -4804,12 +4809,12 @@ def map_axis_name(s):
pass
else:
out.sparse_dim = Dim(
kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name)
kind=Dim.Types.Feature, dimension=set_sparse_dim, description="%s:set-sparse-dim" % name, auto_generated=True)
if increase_sparse_dim:
assert out.sparse
out.sparse_dim = Dim(
kind=Dim.Types.Feature, dimension=out.sparse_dim.dimension + 1,
description="%s:inc-sparse-dim" % name)
description="%s:inc-sparse-dim" % name, auto_generated=True)
if batch_dim_base:
out.batch = batch_dim_base.output.batch
return out
Expand Down Expand Up @@ -5086,7 +5091,7 @@ def transform_input(cls, input_data, network, in_dim=None, in_spatial_dims=None,
cls._check_defined_in_spatial_dims(len(in_spatial_dims) == 1)
if input_expand_dims:
for i in range(input_expand_dims):
dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1)
dim_tag = SpatialDim("input_expand_dims:%i" % i, dimension=1, auto_generated=True)
input_data = input_data.copy_add_dim_by_tag(dim_tag, unbroadcast=True)
in_spatial_dims.append(dim_tag)
if input_split_feature_dim:
Expand Down Expand Up @@ -5263,10 +5268,10 @@ def get_out_data_from_opts(
filter_size=filter_size[i], stride=strides[i], dilation_rate=dilation_rate[i], padding=padding)
dim_tags.append(Dim(
kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
derived_from_tag=old_tag, undefined=not old_tag))
derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
if not out_dim:
assert n_out
out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
dim_tags.append(out_dim)
feature_dim_axis = NotSpecified
# Swap the dims if the input dim order doesn't fit the flag auto_use_channel_first.
Expand Down Expand Up @@ -5767,10 +5772,10 @@ def get_out_data_from_opts(cls, name, sources, network,
padding=padding, output_padding=output_padding[i]) - remove_padding[i] * 2
dim_tags.append(Dim(
kind=Dim.Types.Spatial, description="%s:conv:s%i" % (name, i), dimension=new_dim,
derived_from_tag=old_tag, undefined=not old_tag))
derived_from_tag=old_tag, undefined=not old_tag, auto_generated=True))
if not out_dim:
assert n_out
out_dim = FeatureDim("%s:channel" % name, dimension=n_out)
out_dim = FeatureDim("%s:channel" % name, dimension=n_out, auto_generated=True)
dim_tags.append(out_dim)
return Data(
name="%s_output" % name, dim_tags=dim_tags,
Expand Down Expand Up @@ -5983,7 +5988,8 @@ def get_out_data_from_opts(cls, name, sources, mode="", axes=None, axis=None, ke
out_time_dim_axis = x.time_dim_axis
if keep_dims:
for i in axes:
y_dim_tags[i] = Dim(kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i))
y_dim_tags[i] = Dim(
kind=y_dim_tags[i].kind, dimension=1, description="%s:keep-dim-%i" % (name, i), auto_generated=True)
else:
if out_batch_dim_axis in axes:
out_batch_dim_axis = None
Expand Down Expand Up @@ -6184,7 +6190,7 @@ def get_out_data_from_opts(cls, name, sources, axis=None, out_spatial_dim=None,
out = common_source.copy_template(name="%s_output" % name)
if not out_spatial_dim:
out_spatial_dim = Dim(
kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources))
kind=Dim.Types.Spatial, description="%s:stack" % name, dimension=len(sources), auto_generated=True)
assert out_spatial_dim.dimension == len(sources)
out = out.copy_add_dim_by_tag(axis=axis, dim_tag=out_spatial_dim, unbroadcast=True)
return out
Expand Down Expand Up @@ -6316,7 +6322,8 @@ def get_out_data_from_opts(cls, name, sources, axes, padding=None, size=None, ke
dim_tags = list(data.dim_tags)
for i, a in enumerate(axes):
dim_tags[a] = Dim(
kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i])
kind=dim_tags[a].kind, description="%s:weighted-sum:%i" % (name, i), dimension=res_dims[i],
auto_generated=True)
data = data.copy_template_new_dim_tags(dim_tags, keep_special_axes=True)
else:
assert all([d == 1 for d in res_dims])
Expand Down Expand Up @@ -6467,7 +6474,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, size_base
assert not out_dim
out_dim = size_base.output.get_time_dim_tag()
if not out_dim:
out_dim = (repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name)) + in_dim
out_dim = (
repeat if isinstance(repeat, int) else SpatialDim("%s:repeat" % repeat.name, auto_generated=True)) + in_dim
assert out_dim.dimension == out_dim_int
x = x.copy_template_replace_dim_tag(axis=axis_int, new_dim_tag=out_dim)
if isinstance(repeat, LayerBase):
Expand Down Expand Up @@ -6630,7 +6638,7 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
data = data.copy_move_axis(old_axis=axis, new_axis=0)
data = data.copy_with_batch_dim_axis(1)
if not out_dim:
out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name)
out_dim = Dim(kind=in_dim.kind, description="%s:chunking" % name, auto_generated=True)
data = data.copy_template_replace_dim_tag(axis=0, new_dim_tag=out_dim)
data.time_dim_axis = 0
return data
Expand Down Expand Up @@ -7126,7 +7134,7 @@ def find_axis(a_axis, b_axis):

if not b_var_dims and add_var2_if_empty:
b_var_dims.append(
SpatialDim("%s:dot:dummy-var2" % name, dimension=1))
SpatialDim("%s:dot:dummy-var2" % name, dimension=1, auto_generated=True))

dim_tags = list(a_rem_dims + a_var_dims + b_var_dims)
return Data(
Expand Down Expand Up @@ -7189,9 +7197,9 @@ def __init__(self, axis, amount, pad=True, adjust_size_info=True, **kwargs):
self.output.size_placeholder[axis_wob] + size_delta, 0, tf.shape(shifted)[axis])
from ..util.data import Dim
Dim(
kind=Dim.Types.Spatial, description="shift_axis",
kind=Dim.Types.Spatial, description="%s_shift_axis" % self.name,
dyn_size=new_size, batch=self.output.batch,
src_data=self.output, src_axis=axis)
src_data=self.output, src_axis=axis, auto_generated=True)
self.output.size_placeholder[axis_wob] = new_size

@classmethod
Expand All @@ -7210,7 +7218,7 @@ def get_out_data_from_opts(cls, name, amount, axis, pad, sources=(), **kwargs):
axis = out.get_axis_from_description(axis)
tag = out.dim_tags[axis]
dim = None if tag.dimension is None else max(0, tag.dimension - abs(amount))
tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim)
tag = Dim(kind=tag.kind, description="%s_shift_axis" % name, dimension=dim, auto_generated=True)
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=tag)


Expand Down Expand Up @@ -7319,7 +7327,7 @@ def get_out_data_from_opts(cls, factor, axis, sources, name, out_dim=None, **kwa
if out_dim:
assert out_dim.dimension == dim
else:
out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim)
out_dim = Dim(kind=tag.kind, description="%s_resize" % name, dimension=dim, auto_generated=True)
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)


Expand Down Expand Up @@ -7410,7 +7418,8 @@ def get_out_data_from_opts(cls, name, sources, axis="T", out_dim=None, **kwargs)
axis = out.get_axis_from_description(axis, allow_int=False)
in_dim = out.dim_tags[axis]
if not out_dim:
out_dim = Dim(kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim)
out_dim = Dim(
kind=in_dim.kind, description="%s_removed_items", dimension=None, derived_from_tag=in_dim, auto_generated=True)
return out.copy_template_replace_dim_tag(axis=axis, new_dim_tag=out_dim)


Expand Down Expand Up @@ -8582,18 +8591,17 @@ def get_out_data_from_opts(cls, name, network,
elif isinstance(d, int):
d = Dim(
kind=Dim.Types.Spatial if i < len(shape) - 1 else Dim.Types.Feature,
description="%s:static:%i" % (name, i),
description="%s:static:%i" % (name, i), auto_generated=True,
dimension=d)
else:
raise TypeError("Layer %r: invalid type %s in shape %r" % (name, type(d), shape))
dim_tags.append(d)
if add_time_axis:
dim_tags.insert(
0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1))
0, Dim(kind=Dim.Types.Time, description="%s:dummy-time" % name, dimension=1, auto_generated=True))
if add_batch_axis:
dim_tags.insert(
0, Dim(
kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
0, Dim(kind=Dim.Types.Batch, description="batch", batch=network.get_global_batch_info()))
return Data(
name="%s_output" % name, dim_tags=dim_tags, dtype=dtype,
batch=network.get_global_batch_info() if add_batch_axis else None)
Expand Down
Loading