diff --git a/Project.toml b/Project.toml index a157523..7618fce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MutableArithmetics" uuid = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" authors = ["Gilles Peiffer", "BenoƮt Legat", "Sascha Timme"] -version = "1.3.2" +version = "1.3.3" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/src/rewrite_generic.jl b/src/rewrite_generic.jl index 55dcf6f..08012ad 100644 --- a/src/rewrite_generic.jl +++ b/src/rewrite_generic.jl @@ -51,6 +51,10 @@ function _is_parameters(expr) return Meta.isexpr(expr, :call, 3) && Meta.isexpr(expr.args[2], :parameters) end +function _is_kwarg(expr, kwarg::Symbol) + return Meta.isexpr(expr, :kw) && expr.args[1] == kwarg +end + """ _rewrite_generic(stack::Expr, expr::Expr) @@ -78,14 +82,20 @@ function _rewrite_generic(stack::Expr, expr::Expr) # come in two forms: `sum(i for i=I, j=J)` or `sum(i for i=I for j=J)`. # The latter is a `:flatten` expression and needs additional handling, # but we delay this complexity for _rewrite_generic_generator. - if Meta.isexpr(expr.args[2], :parameters, 1) && - Meta.isexpr(expr.args[2].args[1], :kw, 2) && - expr.args[2].args[1].args[1] == :init - # sum(iter ; init) form! - root = gensym() - init, _ = _rewrite_generic(stack, expr.args[2].args[1].args[2]) - push!(stack.args, :($root = $init)) - return _rewrite_generic_generator(stack, :+, expr.args[3], root) + if Meta.isexpr(expr.args[2], :parameters) + # The summation has keyword arguments. We can deal with `init`, but + # not any of the others. + p = expr.args[2] + if length(p.args) == 1 && _is_kwarg(p.args[1], :init) + # sum(iter ; init) form! + root = gensym() + init, _ = _rewrite_generic(stack, p.args[1].args[2]) + push!(stack.args, :($root = $init)) + return _rewrite_generic_generator(stack, :+, expr.args[3], root) + else + # We don't know how to deal with this + return esc(expr), false + end else # Summations use :+ as the reduction operator. init_expr = expr.args[2].args[end] diff --git a/test/rewrite_generic.jl b/test/rewrite_generic.jl index d6d32eb..0b752a7 100644 --- a/test/rewrite_generic.jl +++ b/test/rewrite_generic.jl @@ -344,6 +344,22 @@ function test_rewrite_expression() return end +function test_rewrite_generic_sum_dims() + @test_rewrite sum([1 2; 3 4]; dims = 1) + @test_rewrite sum([1 2; 3 4]; dims = 2) + @test_rewrite sum([1 2; 3 4]; dims = 1, init = 0) + @test_rewrite sum([1 2; 3 4]; dims = 2, init = 0) + @test_rewrite sum([1 2; 3 4]; init = 0, dims = 1) + @test_rewrite sum([1 2; 3 4]; init = 0, dims = 2) + @test_rewrite sum([1 2; 3 4], dims = 1) + @test_rewrite sum([1 2; 3 4], dims = 2) + @test_rewrite sum([1 2; 3 4], dims = 1, init = 0) + @test_rewrite sum([1 2; 3 4], dims = 2, init = 0) + @test_rewrite sum([1 2; 3 4], init = 0, dims = 1) + @test_rewrite sum([1 2; 3 4], init = 0, dims = 2) + return +end + end # module TestRewriteGeneric.runtests()