Skip to content

Commit

Permalink
Fix bug rewriting sum with dims kwarg
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Sep 20, 2023
1 parent c4d711d commit 3423c53
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MutableArithmetics"
uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0"
authors = ["Gilles Peiffer", "Benoît Legat", "Sascha Timme"]
version = "1.3.2"
version = "1.3.3"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
26 changes: 18 additions & 8 deletions src/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ function _is_parameters(expr)
return Meta.isexpr(expr, :call, 3) && Meta.isexpr(expr.args[2], :parameters)
end

function _is_kwarg(expr, kwarg::Symbol)
return Meta.isexpr(expr, :kw) && expr.args[1] == kwarg
end

"""
_rewrite_generic(stack::Expr, expr::Expr)
Expand Down Expand Up @@ -78,14 +82,20 @@ function _rewrite_generic(stack::Expr, expr::Expr)
# come in two forms: `sum(i for i=I, j=J)` or `sum(i for i=I for j=J)`.
# The latter is a `:flatten` expression and needs additional handling,
# but we delay this complexity for _rewrite_generic_generator.
if Meta.isexpr(expr.args[2], :parameters, 1) &&
Meta.isexpr(expr.args[2].args[1], :kw, 2) &&
expr.args[2].args[1].args[1] == :init
# sum(iter ; init) form!
root = gensym()
init, _ = _rewrite_generic(stack, expr.args[2].args[1].args[2])
push!(stack.args, :($root = $init))
return _rewrite_generic_generator(stack, :+, expr.args[3], root)
if Meta.isexpr(expr.args[2], :parameters)
# The summation has keyword arguments. We can deal with `init`, but
# not any of the others.
p = expr.args[2]
if length(p.args) == 1 && _is_kwarg(p.args[1], :init)
# sum(iter ; init) form!
root = gensym()
init, _ = _rewrite_generic(stack, p.args[1].args[2])
push!(stack.args, :($root = $init))
return _rewrite_generic_generator(stack, :+, expr.args[3], root)
else
# We don't know how to deal with this
return esc(expr), false
end
else
# Summations use :+ as the reduction operator.
init_expr = expr.args[2].args[end]
Expand Down
53 changes: 53 additions & 0 deletions test/rewrite_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,59 @@ function test_rewrite_expression()
return
end

function test_rewrite_generic_sum_dims()
x = [1 2; 3 4]
@test ==(
MA.@rewrite(sum(x; dims = 1), move_factors_into_sums = false),
[4 6],
)
@test ==(
MA.@rewrite(sum(x; dims = 2), move_factors_into_sums = false),
[3; 7;;],
)
@test ==(
MA.@rewrite(sum(x; dims = 1, init = 0), move_factors_into_sums = false),
[4 6],
)
@test ==(
MA.@rewrite(sum(x; dims = 2, init = 0), move_factors_into_sums = false),
[3; 7;;],
)
@test ==(
MA.@rewrite(sum(x; init = 0, dims = 1), move_factors_into_sums = false),
[4 6],
)
@test ==(
MA.@rewrite(sum(x; init = 0, dims = 2), move_factors_into_sums = false),
[3; 7;;],
)
@test ==(
MA.@rewrite(sum(x, dims = 1), move_factors_into_sums = false),
[4 6],
)
@test ==(
MA.@rewrite(sum(x, dims = 2), move_factors_into_sums = false),
[3; 7;;],
)
@test ==(
MA.@rewrite(sum(x, dims = 1, init = 0), move_factors_into_sums = false),
[4 6],
)
@test ==(
MA.@rewrite(sum(x, dims = 2, init = 0), move_factors_into_sums = false),
[3; 7;;],
)
@test ==(
MA.@rewrite(sum(x, init = 0, dims = 1), move_factors_into_sums = false),
[4 6],
)
@test ==(
MA.@rewrite(sum(x, init = 0, dims = 2), move_factors_into_sums = false),
[3; 7;;],
)
return
end

end # module

TestRewriteGeneric.runtests()

0 comments on commit 3423c53

Please sign in to comment.