From eda13819f06947e8a0a03ce44ef590665977970e Mon Sep 17 00:00:00 2001 From: Albert Zeyer Date: Fri, 6 Oct 2023 16:43:58 +0200 Subject: [PATCH] Dim complete_dyn_size, better combine ops Fix #1410 --- returnn/tensor/_dim_extra.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/returnn/tensor/_dim_extra.py b/returnn/tensor/_dim_extra.py index c20f31c817..31650e275c 100644 --- a/returnn/tensor/_dim_extra.py +++ b/returnn/tensor/_dim_extra.py @@ -1089,18 +1089,17 @@ def _bin_op(a, b): return a if isinstance(b, _t.Tensor): return b + raise Exception(f"Dim complete_dyn_size: bin_op: expect to get one Tensor, got {a} and {b}") if kind == "add": - return _relu(a + b) + return _relu(rf.combine_bc(a, "add", b)) elif kind == "sub": - return _relu(a - b) + return _relu(rf.combine_bc(a, "sub", b)) elif kind == "mul": - return a * b + return rf.combine_bc(a, "mul", b) elif kind in ("floordiv", "truediv"): # truediv assumes there is no remainder - return a // b + return rf.combine_bc(a, "floordiv", b) elif kind == "ceildiv": - if isinstance(a, _t.Tensor): - return rf.ceil_divide(a, b) - return -(-a // b) + return rf.combine_bc(a, "ceildiv", b) else: raise ValueError("unknown op kind %r" % op.kind) @@ -1148,17 +1147,14 @@ def _relu(a): if y is None: y = x.copy(name=y_name) continue - if x.dim_tags != y.dim_tags: + if tf: common = _t.Tensor.get_common_data([x, y], allow_broadcast_all_sources=True) x_ = x.copy_compatible_to_dims(common.dims) if x.dims else x y_ = y.copy_compatible_to_dims(common.dims) if y.dims else y y = common - else: - x_, y_ = x, y - if tf: y.placeholder = _bin_op_tf(y_.placeholder, x_.placeholder) else: - y = _bin_op(y_, x_) + y = _bin_op(y, x) assert y, f"op {op}?" if self.dyn_size_ext: assert self.dyn_size_ext.dim_tags == y.dim_tags