Skip to content

Commit

Permalink
ScatterNdLayer, small fix for out_spatial_dim
Browse files Browse the repository at this point in the history
  • Loading branch information
albertz committed Feb 7, 2022
1 parent 562fec4 commit 4b500fd
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion returnn/tf/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,10 @@ def __init__(self, position, position_axis, output_dim_via_time_from=None, out_s
pos_ndim = position.output.batch_ndim
assert 0 <= replace_common_axis < pos_ndim
pos_shape = [position.output.get_dim(i) for i in range(pos_ndim)]
output_dim = output_dim_via_time_from.output.time_dimension()
if output_dim_via_time_from:
output_dim = output_dim_via_time_from.output.time_dimension()
else:
output_dim = out_spatial_dim.get_dim_value()
input_shape = pos_shape + [self.input_data.get_dim(i) for i in input_extra_axes]
input_expanded = self.input_data.copy_compatible_to(common, unbroadcast=True)
input_v = input_expanded.placeholder
Expand Down

0 comments on commit 4b500fd

Please sign in to comment.