TF PositionalEncodingLayer better, small fix
albertz committed Sep 13, 2023
1 parent 3b01151 commit 3784a39
Showing 1 changed file with 3 additions and 2 deletions.
--- a/returnn/tf/layers/rec.py
+++ b/returnn/tf/layers/rec.py
@@ -8309,10 +8309,11 @@ def __init__(self, axis=NotSpecified, add_to_input=False, constant=-1, offset=No
             signal = get_positional_encoding(num_channels=self.output.dim, position=position)
         else:  # single step
             if constant > -1:
-                position = tf.convert_to_tensor([constant])
+                position = tf.convert_to_tensor([constant])  # (1,)
             else:
                 position = self._rec_previous_layer.rec_vars_outputs["position"] + 1
                 self.rec_vars_outputs["position"] = position
+                position = tf.expand_dims(position, axis=0)  # (1,)
             if offset_data:
                 position += offset_data.placeholder  # (batch,)
             signal = get_positional_encoding(
@@ -8337,7 +8338,7 @@ def transform_config_dict(cls, d, network, get_layer):
"""
if d.get("from", None) is None:
if network.is_inside_rec_layer():
d["from"] = []
d["from"] = [":i"]
else:
d["from"] = ["data"]
super(PositionalEncodingLayer, cls).transform_config_dict(d, network=network, get_layer=get_layer)
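
For context on the first hunk: the sketch below shows what a sinusoidal positional encoding of this kind typically computes (the standard Transformer formula; an assumption-based illustration, not RETURNN's actual get_positional_encoding), and why position is kept as a 1-D tensor, shape (1,) for a single step, or (batch,) once the offset is added.

import math
import tensorflow as tf


def sinusoidal_positional_encoding(position, num_channels, max_timescale=1.0e4):
    """
    :param tf.Tensor position: 1-D int tensor, e.g. shape (1,) or (batch,)
    :param int num_channels: output feature dim (assumed even here)
    :param float max_timescale:
    :return: float tensor of shape (len(position), num_channels)
    """
    position = tf.cast(position, tf.float32)  # (pos,)
    num_timescales = num_channels // 2
    log_timescale_increment = math.log(max_timescale) / max(num_timescales - 1, 1)
    inv_timescales = tf.exp(tf.range(num_timescales, dtype=tf.float32) * -log_timescale_increment)  # (ts,)
    scaled_time = tf.expand_dims(position, -1) * tf.expand_dims(inv_timescales, 0)  # (pos, ts)
    return tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=-1)  # (pos, num_channels)


# Single decoder step at position 5 with 8 channels -> tensor of shape (1, 8):
print(sinusoidal_positional_encoding(tf.convert_to_tensor([5]), num_channels=8))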

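The transform_config_dict change in the second hunk means a PositionalEncodingLayer declared inside a rec loop without an explicit "from" now reads from the loop index layer ":i" instead of an empty source list. A hypothetical RETURNN config fragment where that default would apply (layer names and dimensions are made up for illustration, not taken from this commit):

network = {
    "decoder": {
        "class": "rec",
        "target": "classes",
        "unit": {
            # "from" omitted here -> transform_config_dict now inserts [":i"]
            "pos_enc": {"class": "positional_encoding", "n_out": 64},
            "output": {"class": "linear", "activation": "tanh",
                       "from": ["pos_enc", "prev:output"], "n_out": 64},
        },
    },
}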