Skip to content

Commit

Permalink
test_GatherLayer_broadcast_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
Zettelkasten committed Dec 17, 2021
1 parent a94072b commit d620a08
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3400,6 +3400,24 @@ def test_GatherLayer_search_beam():
"initial_output": 0}}}})


def test_GatherLayer_broadcast_dim():
from returnn.tf.util.data import batch_dim
head_dim = SpatialDim("head", 1) # previously, this dim would match all others and therefore fail.
round_dim = SpatialDim("round", 2)
chunk_dim = SpatialDim("chunk")
time_dim = SpatialDim("time")
config = Config({"extern_data": {
"source": {"dim_tags": [batch_dim, head_dim, time_dim]},
"position": {"dim_tags": [batch_dim, head_dim, round_dim, chunk_dim], "dtype": "int32"}},
"debug_print_layer_output_template": True})
net = TFNetwork(config=config)
net.construct_from_dict({
"output": {
'class': 'gather', 'from': 'data:source', 'position': 'data:position', 'axis': time_dim,
'out_shape': {batch_dim, head_dim, round_dim, chunk_dim}}
})


def test_SliceNdLayer():
n_batch = 5
n_time = 7
Expand Down

0 comments on commit d620a08

Please sign in to comment.