diff --git a/Project.toml b/Project.toml index 4bc2bb891..7bed4b959 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "2.0.0" +version = "2.0.1" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" diff --git a/src/interface.jl b/src/interface.jl index 687a802a8..355137ecb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -78,5 +78,7 @@ with `head` as the head and `args` as the arguments, `type` as the symtype and `metadata` as the metadata. By default this will execute `head(args...)`. `x` parameter can also be a `Type`. The `exprhead` keyword argument is useful when manipulating `Expr`s. + +`similarterm` is deprecated see help for `maketerm` instead. """ function similarterm end diff --git a/src/polyform.jl b/src/polyform.jl index 21e04ac9b..88019d5ce 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -118,10 +118,11 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) # create a new symbol to store this y = if recurse - similarterm(x, - op, - map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), - args), symtype(x)) + maketerm(typeof(x), + op, + map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args), + symtype(x), + metadata(x)) else x end @@ -175,11 +176,11 @@ isexpr(x::PolyForm) = true iscall(x::Type{<:PolyForm}) = true iscall(x::PolyForm) = true -function similarterm(t::PolyForm, f, args, symtype; metadata=nothing) - basic_similarterm(t, f, args, symtype; metadata=metadata) +function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata) + basicsymbolic(t, f, args, symtype, metadata) end -function similarterm(::PolyForm, f::Union{typeof(*), typeof(+), typeof(^)}, - args, symtype; metadata=nothing) +function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, + args, symtype, metadata) f(args...) end @@ -248,8 +249,10 @@ multivariate polynomials implementation. expand(expr) = unpolyize(PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true)) function unpolyize(x) - simterm(x, f, args; kw...) = similarterm(x, f, args, symtype(x); kw...) - Postwalk(identity, similarterm=simterm)(x) + # we need a special makterm here because the default one used in Postwalk will call + # promote_symtype to get the new type, but we just want to forward that in case + # promote_symtype is not defined for some of the expressions here. + Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x) end function toterm(x::PolyForm) @@ -301,7 +304,7 @@ function add_divs(x, y) end end -function frac_similarterm(x, f, args; kw...) +function frac_maketerm(T, f, args, stype, metadata) if f in (*, /, \, +, -) f(args...) elseif f == (^) @@ -311,7 +314,7 @@ function frac_similarterm(x, f, args; kw...) args[1]^args[2] end else - similarterm(x, f, args; kw...) + maketerm(T, f, args, stype, metadata) end end @@ -333,8 +336,8 @@ function simplify_fractions(x; polyform=false) sdiv(a) = isdiv(a) ? simplify_div(a) : a expr = Postwalk(sdiv ∘ quick_cancel, - similarterm=frac_similarterm)(Postwalk(add_with_div, - similarterm=frac_similarterm)(x)) + maketerm=frac_maketerm)(Postwalk(add_with_div, + maketerm=frac_maketerm)(x)) polyform ? expr : unpolyize(expr) end diff --git a/src/rewriters.jl b/src/rewriters.jl index 81ae2dbe0..3b3bba5e5 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -33,7 +33,7 @@ module Rewriters using SymbolicUtils: @timer using TermInterface -import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count +import SymbolicUtils: iscall, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough # Cache of printed rules to speed up @timer @@ -167,24 +167,41 @@ end struct Walk{ord, C, F, threaded} rw::C thread_cutoff::Int - similarterm::F + maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function + # that behaves like `similarterm` here, we use `compatmaker` to wrap + # maketerm-like input to do this, with a warning if similarterm provided + # we need this workaround to deprecate because similarterm takes value + # but maketerm only knows the type. end function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded} irw = instrument(x.rw, f) - Walk{ord, typeof(irw), typeof(x.similarterm), threaded}(irw, + Walk{ord, typeof(irw), typeof(x.maketerm), threaded}(irw, x.thread_cutoff, - x.similarterm) + x.maketerm) end using .Threads -function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm) - Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm) +function compatmaker(similarterm, maketerm) + # XXX: delete this and only use maketerm in a future release. + if similarterm isa Nothing + function (x, f, args, type=_promote_symtype(f, args); metadata) + maketerm(typeof(x), f, args, type, metadata) + end + else + Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm) + similarterm + end +end +function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) + maker = compatmaker(similarterm, maketerm) + Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) end -function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm) - Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm) +function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) + maker = compatmaker(similarterm, maketerm) + Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) end struct PassThrough{C} @@ -202,8 +219,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} x = p.rw(x) end - if istree(x) - x = p.similarterm(x, operation(x), map(PassThrough(p), + if iscall(x) + x = p.maketerm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)), metadata=metadata(x)) end @@ -228,7 +245,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} end end args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.similarterm(x, operation(x), args, metadata=metadata(x)) + t = p.maketerm(x, operation(x), args, metadata=metadata(x)) end return ord === :post ? p.rw(t) : t else diff --git a/src/rule.jl b/src/rule.jl index 05941b764..89b1242bd 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -408,7 +408,7 @@ function (acr::ACRule)(term) if result !== nothing # Assumption: inds are unique length(args) == length(inds) && return result - return similarterm(term, f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term)) + return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term), metadata(term)) end end end diff --git a/src/substitute.jl b/src/substitute.jl index 73ea7659d..99ac134a0 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -31,11 +31,11 @@ function substitute(expr, dict; fold=true) args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr)) end - similarterm(expr, - op, - args, - symtype(expr); - metadata=metadata(expr)) + maketerm(typeof(expr), + op, + args, + symtype(expr), + metadata(expr)) else expr end diff --git a/src/utils.jl b/src/utils.jl index 90d7c407c..69b6e8e2d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -53,7 +53,7 @@ function fold(t) # evaluate it return operation(t)(tt...) else - return similarterm(t, operation(t), tt) + return maketerm(typeof(t), operation(t), tt, symtype(t), metadata(t)) end else return t @@ -147,19 +147,19 @@ function flatten_term(⋆, x) push!(flattened_args, t) end end - similarterm(x, ⋆, flattened_args) + maketerm(typeof(x), ⋆, flattened_args, symtype(x), metadata(x)) end function sort_args(f, t) args = arguments(t) if length(args) < 2 - return similarterm(t, f, args) + return maketerm(typeof(t), f, args, symtype(t), metadata(t)) elseif length(args) == 2 x, y = args - return similarterm(t, f, x <ₑ y ? [x,y] : [y,x]) + return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], symtype(t), metadata(t)) end args = args isa Tuple ? [args...] : args - similarterm(t, f, sort(args, lt=<ₑ)) + maketerm(typeof(t), f, sort(args, lt=<ₑ), symtype(t), metadata(t)) end # Linked List interface @@ -225,7 +225,7 @@ macro matchable(expr) SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)] Base.length(x::$name) = $(length(fields) + 1) - SymbolicUtils.similarterm(x::$name, f, args, type; kw...) = f(args...) + SymbolicUtils.maketerm(x::$name, f, args, type, metadata) = f(args...) end |> esc end diff --git a/test/basics.jl b/test/basics.jl index cc3464eab..36228324c 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -214,20 +214,20 @@ end @test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14))) end -@testset "similarterm" begin +@testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1)) - @test isequal(SymbolicUtils.similarterm(b^2, ^, [b^2, 1//2]), b) + @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], Number, nothing).dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], Number, nothing), b) - # test that similarterm doesn't hard-code BasicSymbolic subtype + # test that maketerm doesn't hard-code BasicSymbolic subtype # and is consistent with BasicSymbolic arithmetic operations - @test isequal(SymbolicUtils.similarterm(a / b, *, [a / b, c]), (a / b) * c) - @test isequal(SymbolicUtils.similarterm(a * b, *, [0, c]), 0) - @test isequal(SymbolicUtils.similarterm(a^b, ^, [a * b, 3]), (a * b)^3) + @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], Number, nothing), (a / b) * c) + @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], Number, nothing), 0) + @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, nothing), (a * b)^3) - # test that similarterm sets metadata correctly + # test that maketerm sets metadata correctly metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1") - s = SymbolicUtils.similarterm(a^b, ^, [a * b, 3]; metadata = metadata) + s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, metadata) @test hasmetadata(s, Ctx1) @test getmetadata(s, Ctx1) == "meta_1" end diff --git a/test/runtests.jl b/test/runtests.jl index 3098331ab..ad533ae15 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,6 @@ else include("cse.jl") include("interface.jl") # Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed - # include("fuzz.jl") + include("fuzz.jl") include("adjoints.jl") end