Skip to content

Commit

Permalink
Dim complete_dyn_size, better combine ops
Browse files Browse the repository at this point in the history
Fix #1410
  • Loading branch information
albertz committed Oct 6, 2023
1 parent 206bbea commit eda1381
Showing 1 changed file with 8 additions and 12 deletions.
20 changes: 8 additions & 12 deletions returnn/tensor/_dim_extra.py
Original file line number Diff line number Diff line change
Expand Up @@ -1089,18 +1089,17 @@ def _bin_op(a, b):
return a
if isinstance(b, _t.Tensor):
return b
raise Exception(f"Dim complete_dyn_size: bin_op: expect to get one Tensor, got {a} and {b}")
if kind == "add":
return _relu(a + b)
return _relu(rf.combine_bc(a, "add", b))
elif kind == "sub":
return _relu(a - b)
return _relu(rf.combine_bc(a, "sub", b))
elif kind == "mul":
return a * b
return rf.combine_bc(a, "mul", b)
elif kind in ("floordiv", "truediv"): # truediv assumes there is no remainder
return a // b
return rf.combine_bc(a, "floordiv", b)
elif kind == "ceildiv":
if isinstance(a, _t.Tensor):
return rf.ceil_divide(a, b)
return -(-a // b)
return rf.combine_bc(a, "ceildiv", b)
else:
raise ValueError("unknown op kind %r" % op.kind)

Expand Down Expand Up @@ -1148,17 +1147,14 @@ def _relu(a):
if y is None:
y = x.copy(name=y_name)
continue
if x.dim_tags != y.dim_tags:
if tf:
common = _t.Tensor.get_common_data([x, y], allow_broadcast_all_sources=True)
x_ = x.copy_compatible_to_dims(common.dims) if x.dims else x
y_ = y.copy_compatible_to_dims(common.dims) if y.dims else y
y = common
else:
x_, y_ = x, y
if tf:
y.placeholder = _bin_op_tf(y_.placeholder, x_.placeholder)
else:
y = _bin_op(y_, x_)
y = _bin_op(y, x)
assert y, f"op {op}?"
if self.dyn_size_ext:
assert self.dyn_size_ext.dim_tags == y.dim_tags
Expand Down

0 comments on commit eda1381

Please sign in to comment.