Skip to content

Commit

Permalink
test_GatherLayer_broadcast_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
Zettelkasten committed Dec 16, 2021
1 parent d05d1da commit 995ec58
Showing 1 changed file with 16 additions and 21 deletions.
37 changes: 16 additions & 21 deletions tests/test_TFNetworkLayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3377,27 +3377,22 @@ def test_GatherLayer_constant_position():
np.testing.assert_almost_equal(out_seqs[b, f1, f2], values_seqs[b, f1, position, f2])


def test_GatherLayer_search_beam():
from returnn.tf.network import TFNetwork
from returnn.config import Config
with make_scope() as session:
n_out = 5
config = Config({
"debug_print_layer_output_template": True,
"extern_data": {
"data": {"dim": n_out},
"classes": {"dim": n_out, "sparse": True}
}})
net = TFNetwork(config=config, search_flag=True)
net.construct_from_dict({
"output": {
"class": "rec", "from": "data:data", "unit": {
"position": {"class": "reinterpret_data", "from": "prev:output", "set_sparse": False},
"gather": {"class": "gather", "from": "base:data:data", "position": "position", "axis": "t"}, # [B,T,slice,D]
"prob": {"class": "softmax", "from": "gather", "target": "classes", "loss": "ce"},
'output': {
'class': 'choice', 'target': "classes", 'beam_size': 3, 'from': "prob", "input_type": "prob",
"initial_output": 0}}}})
def test_GatherLayer_broadcast_dim():
from returnn.tf.util.data import batch_dim
head_dim = SpatialDim("head", 1) # previously, this dim would match all others and therefore fail.
round_dim = SpatialDim("round", 2)
chunk_dim = SpatialDim("chunk")
time_dim = SpatialDim("time")
config = Config({"extern_data": {
"source": {"dim_tags": [batch_dim, head_dim, time_dim]},
"position": {"dim_tags": [batch_dim, head_dim, round_dim, chunk_dim], "dtype": "int32"}},
"debug_print_layer_output_template": True})
net = TFNetwork(config=config)
net.construct_from_dict({
"output": {
'class': 'gather', 'from': 'data:source', 'position': 'data:position', 'axis': time_dim,
'out_shape': {batch_dim, head_dim, round_dim, chunk_dim}}
})


def test_SliceNdLayer():
Expand Down

0 comments on commit 995ec58

Please sign in to comment.