From ad600f2dad615b32aa85774bc0798bf38fa4a343 Mon Sep 17 00:00:00 2001
From: Albert Zeyer
Date: Thu, 14 Sep 2023 17:29:43 +0200
Subject: [PATCH] TF PositionalEncodingLayer inside rec, position state (#1390)

This fixes the case when the layer is used inside a masked computation:
instead of using the rec step index of the enclosing loop, the position is
now kept as recurrent layer state and advanced per executed step.
---
 returnn/tf/layers/rec.py        | 56 ++++++++++++++++++++++++---------
 tests/test_TFNetworkRecLayer.py | 13 ++++++++
 2 files changed, 55 insertions(+), 14 deletions(-)

diff --git a/returnn/tf/layers/rec.py b/returnn/tf/layers/rec.py
index 9cb136a5e1..64edefa689 100644
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -8283,19 +8283,16 @@ def __init__(self, axis=NotSpecified, add_to_input=False, constant=-1, offset=No
                 axis = "T"
             else:
                 axis = single_step_dim
+        use_constant = constant > -1
         if axis != single_step_dim:
             src_axis_int = source.get_axis_from_description(axis, allow_int=False)
             out_axis_int = self.output.get_axis_from_description(axis, allow_int=False)
-            if constant > -1:
+            if use_constant:
                 position = constant * tf.ones([1] * output_templ_wo_feat.batch_ndim, tf.int32)
                 if offset_data:
                     position += offset_data.placeholder  # (batch, len)
                 # signal has shape (1, len) or (batch, len) or (1, 1) or more ones
                 signal = get_positional_encoding(num_channels=self.output.dim, position=position)
-                if not add_to_input and not offset_data:  # Need to tile the time dimension
-                    tiles = [1] * self.output.batch_ndim
-                    tiles[out_axis_int] = tf_util.get_shape_dim(source.placeholder, src_axis_int)
-                    signal = tf.tile(signal, tiles)
             else:
                 length = tf_util.get_shape_dim(source.placeholder, src_axis_int)
                 position = tf.range(length)  # (len,)
@@ -8308,23 +8305,32 @@ def __init__(self, axis=NotSpecified, add_to_input=False, constant=-1, offset=No
                 # signal has shape (1,len,n_out) or (batch,len,n_out)
                 signal = get_positional_encoding(num_channels=self.output.dim, position=position)
         else:  # single step
-            if constant > -1:
-                position = tf.convert_to_tensor([constant])
+            out_axis_int = None
+            if use_constant:
+                position = tf.fill(value=constant, dims=[self.get_batch_dim()])  # [B]
             else:
-                position = tf.convert_to_tensor([self.network.get_rec_step_index()])
+                position = self._rec_previous_layer.rec_vars_outputs["position"] + 1  # [B]
+            self.rec_vars_outputs["position"] = position
             if offset_data:
                 position += offset_data.placeholder  # (batch,)
-            signal = get_positional_encoding(
-                num_channels=self.output.dim, position=position
-            )  # (1,n_out) or (batch,n_out)
+            signal = get_positional_encoding(num_channels=self.output.dim, position=position)  # (batch,n_out)

         if add_to_input:
             signal += source.placeholder
+            # No need for tiling because the source should have exactly all relevant dims.
         else:
-            # tile to batch dimension explicitly, as batch_dim=1 will not be automatically unbroadcasted
+            # Check whether we need to tile.
             tiles = [1] * self.output.batch_ndim
-            tiles[self.output.batch_dim_axis] = self.get_batch_dim()
-            signal = tf.tile(signal, tiles)
+            for axis_, dim in enumerate(self.output.dims[:-1]):
+                if offset and dim in offset.output.dims:
+                    continue  # already unbroadcasted above via offset
+                if axis != single_step_dim and not use_constant and out_axis_int == axis_:
+                    continue  # already handled above via time axis
+                if axis == single_step_dim and dim.is_batch_dim():
+                    continue  # already handled above, state has batch dim
+                tiles[axis_] = self.get_batch_dim() if dim.is_batch_dim() else dim.get_dim_value()
+            if any([(not isinstance(t, int) or t > 1) for t in tiles]):
+                signal = tf.tile(signal, tiles)
         self.output.placeholder = signal

     @classmethod
@@ -8359,6 +8365,28 @@ def get_out_data_from_opts(cls, name, network, add_to_input=False, sources=(), *
             name=name, network=network, sources=sources, **kwargs
         )

+    # noinspection PyMethodOverriding
+    @classmethod
+    def get_rec_initial_extra_outputs(cls, batch_dim, rec_layer, network, **kwargs):
+        """
+        :param tf.Tensor batch_dim: for this layer, might be with beam
+        :param returnn.tf.layers.rec.RecLayer|LayerBase|None rec_layer: for the scope
+        :param returnn.tf.network.TFNetwork network:
+        :rtype: dict[str,tf.Tensor]
+        """
+        return {"position": tf.fill(value=-1, dims=[batch_dim])}
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def get_rec_initial_extra_outputs_shape_invariants(cls, rec_layer, network, **kwargs):
+        """
+        :param returnn.tf.layers.rec.RecLayer|LayerBase|None rec_layer: for the scope
+        :param returnn.tf.network.TFNetwork network:
+        :return: optional shapes for the tensors by get_rec_initial_extra_outputs
+        :rtype: dict[str,tf.TensorShape]
+        """
+        return {"position": tf.TensorShape([None])}
+

 class KenLmStateLayer(_ConcatInputLayer):
     """
diff --git a/tests/test_TFNetworkRecLayer.py b/tests/test_TFNetworkRecLayer.py
index 3d2eb032a5..211becec41 100644
--- a/tests/test_TFNetworkRecLayer.py
+++ b/tests/test_TFNetworkRecLayer.py
@@ -7304,6 +7304,19 @@ def test_reclayer_optimize_out_cumsum_unrelated_axis():
     )


+def test_reclayer_optimize_out_pos_enc_layer():
+    feat_dim = FeatureDim("feat", dimension=11)
+    check_reclayer_optimize_out(
+        feat_dim=feat_dim,
+        subnet_layer_dict={
+            "class": "positional_encoding",
+            "out_dim": feat_dim,
+            "from": "data:source",
+            "add_to_input": True,
+        },
+    )
+
+
 def test_reclayer_optimize_out_rel_pos_enc_layer():
     # https://github.com/rwth-i6/returnn/issues/1253
     time_dim = SpatialDim("time")
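
Notes on the case being fixed: below is a minimal sketch of a setup that hits
the single-step branch inside a masked computation, in the usual RETURNN
net-dict style. Layer names ("rec", "masked", "output") and the mask source are
illustrative assumptions; only "class": "positional_encoding" and its options
come from the patch itself.

    # Hypothetical net dict: positional encoding inside a rec loop, wrapped in
    # a masked computation, so the layer only executes on unmasked frames.
    network = {
        "rec": {
            "class": "rec",
            "from": "data",
            "unit": {
                "masked": {
                    "class": "masked_computation",
                    "mask": "base:mask",  # assumed boolean mask layer, one flag per step
                    "from": "data:source",
                    "unit": {
                        "class": "positional_encoding",
                        "from": "data",
                        "add_to_input": True,
                    },
                },
                "output": {"class": "copy", "from": "masked"},
            },
        },
    }

Previously the position came from network.get_rec_step_index(), the step counter
of the enclosing rec loop, which keeps advancing even on frames where the masked
computation does not run. With this patch the position is recurrent layer state
("position" in rec_vars_outputs), so it only advances when the layer executes.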
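
The state recurrence itself is small; here is a simplified standalone sketch in
plain Python/NumPy (step() and the variable names are illustrative, not the
actual TF code):

    import numpy as np

    def step(prev_position, offset=None):
        """One executed step: advance the carried position, then apply the offset."""
        new_state = prev_position + 1  # initial state is -1, so the first step yields 0
        position = new_state + offset if offset is not None else new_state
        return position, new_state  # offset shifts the emitted signal, not the state

    state = np.full([3], -1)  # batch of 3, mirrors get_rec_initial_extra_outputs()
    for _ in range(2):
        position, state = step(state)
    print(position)  # [1 1 1] after two executed steps

Note the ordering in the patch: self.rec_vars_outputs["position"] is assigned
before offset_data is added, so the carried state stays the raw per-batch step
count, while the offset only shifts the encoding that is emitted.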