diff --git a/src/new_rewrite.jl b/src/new_rewrite.jl index 0e30373..2e0a18c 100644 --- a/src/new_rewrite.jl +++ b/src/new_rewrite.jl @@ -139,17 +139,21 @@ function _rewrite(stack::Expr, expr::Expr) rhs = if is_mutable Expr(:call, MA.operate!!, *, arg1, arg2) else - Expr(:call, MA.operate!!, MA.add_mul, MA.Zero(), arg1, arg2) + Expr(:call, *, arg1, arg2) end root = gensym() push!(stack.args, :($root = $rhs)) for i in 4:length(expr.args) arg, _ = _rewrite(stack, expr.args[i]) - rhs = Expr(:call, MA.operate!!, *, root, arg) + rhs = if is_mutable + Expr(:call, MA.operate!!, *, root, arg) + else + Expr(:call, *, root, arg) + end root = gensym() push!(stack.args, :($root = $rhs)) end - return root, true + return root, is_mutable elseif expr.args[1] == :.+ # .+(args...) => add_mul.(add_mul.(arg1, arg2), arg3) @assert length(expr.args) > 1