Skip to content

Commit

Permalink
TF SelectSearchSourcesLayer fix no search choices
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Sep 15, 2023
1 parent 9d4d63b commit 66d09f9
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,21 +644,29 @@ def select_if_needed(cls, layer, search_choices):
:rtype: LayerBase
"""
assert isinstance(layer, LayerBase)
if not search_choices:
return layer
if layer.network.is_extra_internal_template_construction():
assert layer.output.placeholder is None # we expect a template
return layer
if not search_choices and isinstance(layer, SelectSearchSourcesLayer):
layer = layer.sources[0]
layer_search_choices = layer.get_search_choices()
if layer_search_choices and layer_search_choices.keep_raw:
return layer
if layer_search_choices == search_choices:
assert layer.output.beam == search_choices.get_beam_info(), "%r != %r. %s" % (
layer.output.beam,
search_choices.get_beam_info(),
layer.network.debug_search_choices(layer) or "debug search dumped",
)
if search_choices:
assert layer.output.beam == search_choices.get_beam_info(), "%r != %r. %s" % (
layer.output.beam,
search_choices.get_beam_info(),
layer.network.debug_search_choices(layer) or "debug search dumped",
)
else:
assert not layer.output.beam, "%r should be None. %s" % (
layer.output.beam,
layer.network.debug_search_choices(layer) or "debug search dumped",
)
return layer
if not search_choices:
raise Exception(f"Layer {layer} required without search choices.")
if layer.output.batch_dim_axis is None: # e.g. VariableLayer, ConstantLayer, or so
return layer
layer = SelectSearchSourcesLayer(sources=[layer], search_choices_layer=search_choices.owner)
Expand Down

0 comments on commit 66d09f9

Please sign in to comment.