From 66d09f93d06a530d0d168e9b607218519657a0dc Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 15 Sep 2023 12:21:01 +0200 Subject: [PATCH] TF SelectSearchSourcesLayer fix no search choices --- returnn/tf/layers/basic.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index 865b0815f4..15584548fd 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -644,21 +644,29 @@ def select_if_needed(cls, layer, search_choices): :rtype: LayerBase """ assert isinstance(layer, LayerBase) - if not search_choices: - return layer if layer.network.is_extra_internal_template_construction(): assert layer.output.placeholder is None # we expect a template return layer + if not search_choices and isinstance(layer, SelectSearchSourcesLayer): + layer = layer.sources[0] layer_search_choices = layer.get_search_choices() if layer_search_choices and layer_search_choices.keep_raw: return layer if layer_search_choices == search_choices: - assert layer.output.beam == search_choices.get_beam_info(), "%r != %r. %s" % ( - layer.output.beam, - search_choices.get_beam_info(), - layer.network.debug_search_choices(layer) or "debug search dumped", - ) + if search_choices: + assert layer.output.beam == search_choices.get_beam_info(), "%r != %r. %s" % ( + layer.output.beam, + search_choices.get_beam_info(), + layer.network.debug_search_choices(layer) or "debug search dumped", + ) + else: + assert not layer.output.beam, "%r should be None. %s" % ( + layer.output.beam, + layer.network.debug_search_choices(layer) or "debug search dumped", + ) return layer + if not search_choices: + raise Exception(f"Layer {layer} required without search choices.") if layer.output.batch_dim_axis is None: # e.g. VariableLayer, ConstantLayer, or so return layer layer = SelectSearchSourcesLayer(sources=[layer], search_choices_layer=search_choices.owner)