Skip to content

Commit

Permalink
TF PositionalEncodingLayer: compatibility with masked computation
Browse files (browse the repository at this point in the history)
  • Loading branch information
albertz committed Sep 13, 2023
1 parent 816492a commit ce10646
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions returnn/tf/layers/rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8309,11 +8309,10 @@ def __init__(self, axis=NotSpecified, add_to_input=False, constant=-1, offset=No
signal = get_positional_encoding(num_channels=self.output.dim, position=position)
else: # single step
if constant > -1:
position = tf.convert_to_tensor([constant]) # (1,)
position = tf.convert_to_tensor([constant]) # [1]
else:
position = self._rec_previous_layer.rec_vars_outputs["position"] + 1
position = self._rec_previous_layer.rec_vars_outputs["position"] + 1 # [B]
self.rec_vars_outputs["position"] = position
position = tf.expand_dims(position, axis=0) # (1,)
if offset_data:
position += offset_data.placeholder # (batch,)
signal = get_positional_encoding(
Expand Down Expand Up @@ -8370,7 +8369,7 @@ def get_rec_initial_extra_outputs(cls, batch_dim, rec_layer, network, **kwargs):
:param returnn.tf.network.TFNetwork network:
:rtype: dict[str,tf.Tensor]
"""
return {"position": tf.constant(-1, shape=(), dtype=tf.int32)}
return {"position": tf.fill(value=-1, dims=[batch_dim])}

# noinspection PyMethodOverriding
@classmethod
Expand All @@ -8381,7 +8380,7 @@ def get_rec_initial_extra_outputs_shape_invariants(cls, rec_layer, network, **kw
:return: optional shapes for the tensors by get_rec_initial_extra_outputs
:rtype: dict[str,tf.TensorShape]
"""
return {"position": tf.TensorShape(())}
return {"position": tf.TensorShape([None])}


class KenLmStateLayer(_ConcatInputLayer):
Expand Down

0 comments on commit ce10646

Please sign in to comment.