Skip to content

Commit

Permalink
test_GenericAttentionLayer_extra_spatial_multi_head fix dims
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Nov 16, 2022
1 parent 8b6b82d commit 277a946
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions tests/test_TFNetworkRecLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7013,8 +7013,8 @@ def test_GenericAttentionLayer_extra_spatial_multi_head():
net = TFNetwork(extern_data=ExternData(), config=Config({"debug_print_layer_output_template": True}))
dec_time = SpatialDim("dec time")
enc_time = SpatialDim("enc time")
num_heads = 8
heads_dim = FeatureDim("heads", dimension=num_heads)
heads_dim = FeatureDim("heads", dimension=8)
feat_dim = FeatureDim("feat", dimension=2048)
kwargs = dict(
name="att", network=net,
weights=InternalLayer(
Expand All @@ -7024,14 +7024,13 @@ def test_GenericAttentionLayer_extra_spatial_multi_head():
base=InternalLayer(
name="enc_value", network=net,
output=Data(
name='enc_value_output', shape=(None, num_heads, 2048), batch_dim_axis=1, auto_create_placeholders=True,
same_dim_tags_as={"t": enc_time})))
name='enc_value_output', dim_tags=[enc_time, batch_dim, heads_dim, feat_dim], auto_create_placeholders=True)))
print("GenericAttentionLayer kwargs:")
pprint(kwargs)
kwargs["output"] = GenericAttentionLayer.get_out_data_from_opts(**kwargs)
layer = GenericAttentionLayer(**kwargs)
layer.output.sanity_check()
assert layer.output.shape == (num_heads, None, 2048) and layer.output.have_time_axis()
assert layer.output.shape == (heads_dim.dimension, None, feat_dim.dimension) and layer.output.have_time_axis()
assert len(layer.output.size_placeholder) == 1
assert list(layer.output.size_placeholder.values())[0] is layer.weights.output.size_placeholder[0]

Expand Down

0 comments on commit 277a946

Please sign in to comment.