Data same_dim_tags_as fix auto_create_placeholders (#1159)
Don't first create a new size placeholder and only later call declare_same_as.

This is especially required once declare_same_as becomes stricter (#1143).

Fix wrong batch info:

The dim tag could carry old, invalid batch info,
e.g. the global batch_dim when it comes from an old run.

If we really needed that info, we would have to validate the dim tag first,
but it is probably better to just remove it and clean it up.

Engine: reset the global batch dim.
albertz authored Oct 19, 2022
1 parent a6236c5 commit 0189da5
Showing 4 changed files with 40 additions and 26 deletions.
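
To illustrate what the fix guarantees, here is a minimal sketch distilled from the new test in tests/test_TFUtil.py below; it assumes graph mode, as in the test, and is not itself part of the commit:

import tensorflow as tf
from returnn.tf.util.data import Data, SpatialDim

with tf.Graph().as_default():
  time_tag = SpatialDim("time")
  # Two extern-data entries share the same time dim tag via same_dim_tags_as.
  data = Data("data", dim=3, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
  classes = Data("classes", dim=3, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
  # With this commit, the seq-length placeholder is created once, on the shared time_tag,
  # after same_dim_tags_as handling, instead of creating one placeholder per Data object
  # and then merging them via declare_same_as.
  assert data.get_sequence_lengths() is classes.get_sequence_lengths() is time_tag.dyn_size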
returnn/tf/engine.py: 2 additions & 0 deletions
@@ -1293,6 +1293,8 @@ def _init_network(self, net_desc, epoch=None):
     use_dataset_pipeline = False
     if self.config.is_true("dataset_pipeline"):
       use_dataset_pipeline = True
+    from returnn.tf.util.data import batch_dim
+    batch_dim.batch = None  # make sure it is reset
     extern_data = ExternData()
     extern_data.init_from_config(config=self.config, auto_create_placeholders=not use_dataset_pipeline)
     if use_dataset_pipeline:
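
For context on the engine change above, a minimal sketch of the failure mode; batch_dim here is the global singleton from returnn.tf.util.data, exactly as imported in the diff:

from returnn.tf.util.data import batch_dim

# After an earlier run, the global singleton batch_dim can still carry batch info
# that is tied to the previous graph/dataset (stale BatchInfo).
leftover = batch_dim.batch  # possibly a leftover BatchInfo from the old run, or None

# What the engine now does at the start of _init_network, before building ExternData:
batch_dim.batch = None  # make sure it is reset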
returnn/tf/util/data.py: 8 additions & 14 deletions
@@ -2823,11 +2823,6 @@ def __init__(self, name,
       if time_dim_axis is NotSpecified:
         time_dim_axis = _default_time_dim_axis_dim_tags(dim_tags)
       dim_tags = tuple(dim_tags)
-      if auto_create_placeholders:
-        _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=dim_tags)
-        if batch_dim_axis_ is not None:
-          if dim_tags[batch_dim_axis_].batch and not self._batch:
-            self._batch = dim_tags[batch_dim_axis_].batch
       del shape_
       del batch_dim_axis_
     else:
@@ -2846,7 +2841,7 @@ def __init__(self, name,
       dim_tags = _infer_dim_tags_tuple_from_shape(
         shape, batch_dim_axis=batch_dim_axis, time_dim_axis=time_dim_axis, feature_dim_axis=feature_dim_axis,
         size_placeholder=size_placeholder, name=name,
-        auto_create_placeholders=auto_create_placeholders,
+        extern_data=auto_create_placeholders,
         dim_tags=dim_tags, sparse=sparse)
       del batch_dim_axis
       del shape
@@ -2893,8 +2888,10 @@ def __init__(self, name,
         base_tag = self._dim_tags[_axis]
         if base_tag != _dim_tag:
           base_tag.declare_same_as(_dim_tag)
-          if _dim_tag.dyn_size is not None:
-            self.set_dynamic_size(_axis, _dim_tag.dyn_size)
         self._dim_tags = self._dim_tags[:_axis] + (_dim_tag,) + self._dim_tags[_axis + 1:]
+    if auto_create_placeholders:
+      # Do that after same_dim_tags_as handling.
+      _auto_create_size_placeholders_on_dim_tags(name=name, dim_tags=self._dim_tags)
     self._adapt_batch_consistent_dim_tags()
     self.sanity_check(assume_complete=False)

@@ -5780,7 +5777,7 @@ def _infer_dim_tags_tuple_from_shape(
   size_placeholder,
   dim_tags,
   name,
-  auto_create_placeholders
+  extern_data
 ):
   """
   :param tuple[int|None]|list[int|None] shape: this is without batch-dim-axis
@@ -5790,7 +5787,7 @@
   :param bool sparse:
   :param dict[int,tf.Tensor]|None size_placeholder: key is axis without batch-dim
   :param dict[int,Dim]|None dim_tags: some existing explicitly specified dim tags. key is axis with batch-dim
-  :param bool auto_create_placeholders:
+  :param bool extern_data:
   :param str name:
   :return: dim tags tuple
   :rtype: tuple[Dim]
@@ -5808,8 +5805,6 @@
   dim_tags = dim_tags.copy() if dim_tags else {}
   if batch_dim_axis is not None and batch_dim_axis not in dim_tags:
     dim_tags[batch_dim_axis] = Dim(kind=Dim.Types.Batch, description="batch:%s" % name)
-  # noinspection PyShadowingNames
-  batch_dim = dim_tags[batch_dim_axis] if batch_dim_axis is not None else None
   # Note: Consistent to Data.get_dim_tag,
   # prefer interpretation as spatial axis if there is a dynamic size or this is marked as time axis.
   if size_placeholder:
@@ -5833,7 +5828,7 @@
     axis_wo_b = _get_axis_wo_b(axis, batch_dim_axis=batch_dim_axis)
     dyn_size = size_placeholder.get(axis_wo_b) if (size_placeholder and axis_wo_b is not None) else None
     dim = batch_shape[axis]
-    if auto_create_placeholders and dim is None and dyn_size is None and axis != batch_dim_axis:
+    if extern_data and dim is None and dyn_size is None and axis != batch_dim_axis:
       if not tag:
         if axis == time_dim_axis:
           tag_name = "time"
@@ -5845,7 +5840,6 @@
           # This is such that Dim.is_equal behaves as before, e.g. in Data.get_common_data.
           kind=Dim.Types.Spatial)
         dim_tags[axis] = tag
-      _create_size_placeholder(name=name, axis_wo_b=axis_wo_b, tag=tag, batch_dim=batch_dim)
       dyn_size = tag.dyn_size
     if tag:
       # Just some sanity checks.
tests/test_TFNetworkLayer.py: 14 additions & 12 deletions
@@ -2223,12 +2223,13 @@ def test_SplitDimsLayer_dim_tags():
   feat_dim = FeatureDim("feat", 3)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
-      'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': time_dim, 'dims': [rem_dim, window_dim],
+        'out_shape': {batch_dim, rem_dim, window_dim, feat_dim}}
+    })
 
 
 def test_SplitDimsLayer_dim_tags_expand():
@@ -2238,12 +2239,13 @@ def test_SplitDimsLayer_dim_tags_expand():
   expand_dim = SpatialDim("expand_dim", 1)
   config = Config({
     "extern_data": {"data": {"dim_tags": [batch_dim, time_dim, feat_dim]}}})
-  net = TFNetwork(config=config)
-  net.construct_from_dict({
-    "output": {
-      'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
-      'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
-  })
+  with make_scope():
+    net = TFNetwork(config=config)
+    net.construct_from_dict({
+      "output": {
+        'class': 'split_dims', 'from': 'data', 'axis': feat_dim, 'dims': [feat_dim, expand_dim],
+        'out_shape': {batch_dim, time_dim, feat_dim, expand_dim}}
+    })
 
 
 def test_SplitDimsLayer_dim_tags_split_batch_simple():
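
For context on the test change above: make_scope() is the helper these tests use to run each network construction in a fresh TF graph and session, so placeholders and the (now reset) global batch_dim cannot leak between tests. Roughly, such a helper looks like the sketch below; the actual definition lives in tests/test_TFNetworkLayer.py and is not part of this commit:

import contextlib
import tensorflow as tf


@contextlib.contextmanager
def make_scope():
  # Open a fresh graph and a session on it for the duration of one test.
  with tf.Graph().as_default() as graph:
    with tf.compat.v1.Session(graph=graph) as session:
      yield session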
tests/test_TFUtil.py: 16 additions & 0 deletions
@@ -1243,6 +1243,22 @@ def test_Data_verify_out_shape_optional_implicit_dim():
   x.verify_out_shape({time_dim, feat_dim}, allow_missing_implicit_dims=True)
 
 
+def test_Data_auto_create_placeholders_same_dim_tags_as_existing():
+  # Came up via: https://github.com/rwth-i6/returnn/pull/1143
+  n_out = 3
+  time_tag = SpatialDim("time")
+  with tf.Graph().as_default() as graph, tf_compat.v1.Session(graph=graph) as session:
+    assert isinstance(graph, tf.Graph)
+    data = Data("data", dim=n_out, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    classes = Data("classes", dim=n_out, sparse=True, same_dim_tags_as={"t": time_tag}, auto_create_placeholders=True)
+    assert time_tag.dyn_size is not None  # this is not so relevant and might change
+    seq_len = time_tag.dyn_size
+    assert seq_len is data.get_sequence_lengths() is classes.get_sequence_lengths()
+    assert seq_len.op.type == "Placeholder"
+    placeholder_ops = [op for op in graph.get_operations() if op.type == "Placeholder"]
+    assert_equal(set(placeholder_ops), {data.placeholder.op, classes.placeholder.op, time_tag.dyn_size.op})
+
+
 def test_Dim_copy():
   # https://github.com/rwth-i6/returnn/issues/860
   import copy
