diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py index e358578210..88344d1630 100644 --- a/returnn/tensor/_dim_extra.py +++ b/returnn/tensor/_dim_extra.py @@ -1094,6 +1094,8 @@ def _bin_op_tf(a, b): raise ValueError("unknown op kind %r" % op.kind) def _bin_op(a, b): + if b is None: + return None if a is None: if isinstance(b, int): if not template_only and backend and not tf: @@ -1110,13 +1112,11 @@ def _bin_op(a, b): return b.copy(name=y_name) else: raise TypeError(f"complete_dyn_size: _bin_op: unexpected type {type(b)}") - if b is None: - return None assert isinstance(a, _t.Tensor) if template_only or not backend: if isinstance(b, _t.Tensor): return _t.Tensor.get_common_data([a, b], allow_broadcast_all_sources=True) - return a + return a.copy_template() if tf: if isinstance(b, _t.Tensor): res = _t.Tensor.get_common_data([a, b], allow_broadcast_all_sources=True)