TF PositionalEncodingLayer inside rec, position state (#1390)
This fixes the case when the layer is used inside a masked computation: in single-step mode, the position is now carried as explicit recurrent state ("position" in rec_vars_outputs) instead of being derived from the global rec step index.
albertz authored Sep 14, 2023
1 parent 9f30825 commit ad600f2
Showing 2 changed files with 55 additions and 14 deletions.
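
Background on the fix: in single-step mode the layer previously took the position from network.get_rec_step_index(), the global index of the surrounding rec loop. Inside a MaskedComputationLayer the layer only executes on the masked frames, so the global index overcounts the true per-sequence position. The diff below instead carries the position as recurrent state, starting at -1 and incremented on each executed step. A minimal standalone sketch of the idea in plain TensorFlow (not RETURNN's layer API; the mask values are made up for illustration):

    import tensorflow as tf

    # A per-sequence position counter carried as state, vs. the global loop index.
    # With masked computation the layer runs only on some frames, so the global
    # index overcounts; the carried counter advances only when the layer runs.
    mask = tf.constant([[True, True, False], [True, False, True], [False, True, True]])  # [B, T]
    position = tf.fill([3], -1)  # initial state, cf. get_rec_initial_extra_outputs below
    for i in range(3):  # the rec loop; i is the global step index
        position = tf.where(mask[:, i], position + 1, position)
    print(position.numpy())  # [1 1 1], while the global index ended at 2 for every sequence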
56 changes: 42 additions & 14 deletions returnn/tf/layers/rec.py
@@ -8283,19 +8283,16 @@ def __init__(self, axis=NotSpecified, add_to_input=False, constant=-1, offset=No
                 axis = "T"
             else:
                 axis = single_step_dim
+        use_constant = constant > -1
         if axis != single_step_dim:
             src_axis_int = source.get_axis_from_description(axis, allow_int=False)
             out_axis_int = self.output.get_axis_from_description(axis, allow_int=False)
-            if constant > -1:
+            if use_constant:
                 position = constant * tf.ones([1] * output_templ_wo_feat.batch_ndim, tf.int32)
                 if offset_data:
                     position += offset_data.placeholder  # (batch, len)
                 # signal has shape (1, len) or (batch, len) or (1, 1) or more ones
                 signal = get_positional_encoding(num_channels=self.output.dim, position=position)
-                if not add_to_input and not offset_data:  # Need to tile the time dimension
-                    tiles = [1] * self.output.batch_ndim
-                    tiles[out_axis_int] = tf_util.get_shape_dim(source.placeholder, src_axis_int)
-                    signal = tf.tile(signal, tiles)
             else:
                 length = tf_util.get_shape_dim(source.placeholder, src_axis_int)
                 position = tf.range(length)  # (len,)
@@ -8308,23 +8305,32 @@ def __init__(self, axis=NotSpecified, add_to_input=False, constant=-1, offset=No
                 # signal has shape (1,len,n_out) or (batch,len,n_out)
                 signal = get_positional_encoding(num_channels=self.output.dim, position=position)
         else:  # single step
-            if constant > -1:
-                position = tf.convert_to_tensor([constant])
+            out_axis_int = None
+            if use_constant:
+                position = tf.fill(value=constant, dims=[self.get_batch_dim()])  # [B]
             else:
-                position = tf.convert_to_tensor([self.network.get_rec_step_index()])
+                position = self._rec_previous_layer.rec_vars_outputs["position"] + 1  # [B]
+                self.rec_vars_outputs["position"] = position
             if offset_data:
                 position += offset_data.placeholder  # (batch,)
-            signal = get_positional_encoding(
-                num_channels=self.output.dim, position=position
-            )  # (1,n_out) or (batch,n_out)
+            signal = get_positional_encoding(num_channels=self.output.dim, position=position)  # (batch,n_out)

         if add_to_input:
             signal += source.placeholder
+            # No need for tiling because the source should have exactly all relevant dims.
         else:
-            # tile to batch dimension explicitly, as batch_dim=1 will not be automatically unbroadcasted
+            # Check whether we need to tile.
             tiles = [1] * self.output.batch_ndim
-            tiles[self.output.batch_dim_axis] = self.get_batch_dim()
-            signal = tf.tile(signal, tiles)
+            for axis_, dim in enumerate(self.output.dims[:-1]):
+                if offset and dim in offset.output.dims:
+                    continue  # already unbroadcasted above via offset
+                if axis != single_step_dim and not use_constant and out_axis_int == axis_:
+                    continue  # already handled above via time axis
+                if axis == single_step_dim and dim.is_batch_dim():
+                    continue  # already handled above, state has batch dim
+                tiles[axis_] = self.get_batch_dim() if dim.is_batch_dim() else dim.get_dim_value()
+            if any([(not isinstance(t, int) or t > 1) for t in tiles]):
+                signal = tf.tile(signal, tiles)
         self.output.placeholder = signal

     @classmethod
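
Shape note: get_positional_encoding returns the sinusoidal timing signal for the given integer positions, with the channel dim appended last, hence (batch, n_out) in the single-step branch above. A rough numpy sketch, assuming the usual tensor2tensor-style scheme (not RETURNN's actual implementation):

    import numpy as np

    def positional_encoding(position, num_channels, max_timescale=1.0e4):
        # position: int array of any shape -> floats of shape position.shape + (num_channels,)
        num_timescales = num_channels // 2
        log_increment = np.log(max_timescale) / max(num_timescales - 1, 1)
        inv_timescales = np.exp(np.arange(num_timescales) * -log_increment)  # geometric timescales
        scaled = position[..., None].astype("float32") * inv_timescales  # (..., num_timescales)
        return np.concatenate([np.sin(scaled), np.cos(scaled)], axis=-1)  # (..., num_channels)

    print(positional_encoding(np.arange(3), 8).shape)  # (3, 8): one vector per position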
@@ -8359,6 +8365,28 @@ def get_out_data_from_opts(cls, name, network, add_to_input=False, sources=(), *
             name=name, network=network, sources=sources, **kwargs
         )

+    # noinspection PyMethodOverriding
+    @classmethod
+    def get_rec_initial_extra_outputs(cls, batch_dim, rec_layer, network, **kwargs):
+        """
+        :param tf.Tensor batch_dim: for this layer, might be with beam
+        :param returnn.tf.layers.rec.RecLayer|LayerBase|None rec_layer: for the scope
+        :param returnn.tf.network.TFNetwork network:
+        :rtype: dict[str,tf.Tensor]
+        """
+        return {"position": tf.fill(value=-1, dims=[batch_dim])}
+
+    # noinspection PyMethodOverriding
+    @classmethod
+    def get_rec_initial_extra_outputs_shape_invariants(cls, rec_layer, network, **kwargs):
+        """
+        :param returnn.tf.layers.rec.RecLayer|LayerBase|None rec_layer: for the scope
+        :param returnn.tf.network.TFNetwork network:
+        :return: optional shapes for the tensors by get_rec_initial_extra_outputs
+        :rtype: dict[str,tf.TensorShape]
+        """
+        return {"position": tf.TensorShape([None])}
+

 class KenLmStateLayer(_ConcatInputLayer):
     """
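
These two new class methods hook into RecLayer's generic state mechanism: the initial "position" is -1 per batch entry, so the first executed step computes prev + 1 == 0, and the TensorShape([None]) invariant keeps the batch dim dynamic (e.g. it can grow with the beam during search). As an analogy (not RecLayer's actual code), this is how such state threads through a raw tf.while_loop:

    import tensorflow as tf

    batch_dim = tf.constant(2)
    position0 = tf.fill([batch_dim], -1)  # cf. get_rec_initial_extra_outputs

    def body(i, position):
        return i + 1, position + 1  # per-step update, cf. rec_vars_outputs["position"]

    _, position = tf.while_loop(
        cond=lambda i, position: i < 3,
        body=body,
        loop_vars=(tf.constant(0), position0),
        shape_invariants=(tf.TensorShape(()), tf.TensorShape([None])),  # batch dim left dynamic
    )
    print(position.numpy())  # [2 2]: three steps take -1 -> 0 -> 1 -> 2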
13 changes: 13 additions & 0 deletions tests/test_TFNetworkRecLayer.py
@@ -7304,6 +7304,19 @@ def test_reclayer_optimize_out_cumsum_unrelated_axis():
     )


+def test_reclayer_optimize_out_pos_enc_layer():
+    feat_dim = FeatureDim("feat", dimension=11)
+    check_reclayer_optimize_out(
+        feat_dim=feat_dim,
+        subnet_layer_dict={
+            "class": "positional_encoding",
+            "out_dim": feat_dim,
+            "from": "data:source",
+            "add_to_input": True,
+        },
+    )
+
+
 def test_reclayer_optimize_out_rel_pos_enc_layer():
     # https://github.com/rwth-i6/returnn/issues/1253
     time_dim = SpatialDim("time")
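
The new test runs check_reclayer_optimize_out on the layer, i.e. it asserts that step-by-step execution inside the rec loop (using the new position state) matches the optimized-out execution over the whole time axis. Conceptually, in a simplified numpy analogy (enc stands in for get_positional_encoding):

    import numpy as np

    def enc(position, num_channels=11):  # stand-in for get_positional_encoding
        return np.sin(position / 1.0e4 ** (np.arange(num_channels) / num_channels))

    whole_seq = np.stack([enc(t) for t in range(5)])  # optimized out: all frames at once
    position = -1  # initial recurrent state
    step_by_step = []
    for _ in range(5):  # inside the rec loop: one frame per step
        position += 1
        step_by_step.append(enc(position))
    assert np.allclose(whole_seq, np.stack(step_by_step))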
