diff --git a/returnn/tf/layers/basic.py b/returnn/tf/layers/basic.py index d42078ca6d..3639c38536 100644 --- a/returnn/tf/layers/basic.py +++ b/returnn/tf/layers/basic.py @@ -1634,7 +1634,10 @@ def __init__(self, position, position_axis, output_dim_via_time_from=None, out_s pos_ndim = position.output.batch_ndim assert 0 <= replace_common_axis < pos_ndim pos_shape = [position.output.get_dim(i) for i in range(pos_ndim)] - output_dim = output_dim_via_time_from.output.time_dimension() + if output_dim_via_time_from: + output_dim = output_dim_via_time_from.output.time_dimension() + else: + output_dim = out_spatial_dim.get_dim_value() input_shape = pos_shape + [self.input_data.get_dim(i) for i in input_extra_axes] input_expanded = self.input_data.copy_compatible_to(common, unbroadcast=True) input_v = input_expanded.placeholder