diff --git a/page/interface.md b/page/interface.md index 6e48f6769..894ddd821 100644 --- a/page/interface.md +++ b/page/interface.md @@ -40,14 +40,14 @@ for `simplify` to work. Other required methods are `operation` and `istree` In addition, the methods for `Base.hash` and `Base.isequal` should also be implemented by the types for the purposes of substitution and equality matching respectively. -### Optional - -#### `similarterm(t::MyType, f, args)` +#### `similarterm(t::MyType, f, args[, T])` -Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation). +Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation). T is the symtype of the output term. You can use `promote_symtype` to infer this type. The below two functions are internal to SymbolicUtils +### Optional + #### `symtype(x)` The supposed type of values in the domain of x. Tracing tools can use this type to diff --git a/src/abstractalgebra.jl b/src/abstractalgebra.jl index 2b3411e3b..412fd095f 100644 --- a/src/abstractalgebra.jl +++ b/src/abstractalgebra.jl @@ -23,9 +23,9 @@ function labels!(dicts, t) return t elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-)) tt = arguments(t) - return similarterm(t, operation(t), map(x->labels!(dicts, x), tt)) + return similarterm(t, operation(t), map(x->labels!(dicts, x), tt), symtype(t)) elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2]) - return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t))) + return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)), symtype(t)) else sym2term, term2sym = dicts if haskey(term2sym, t) @@ -36,7 +36,8 @@ function labels!(dicts, t) sym = Sym{symtype(t)}(gensym(nameof(operation(t)))) dicts2 = _dicts(dicts[2]) sym2term[sym] = similarterm(t, operation(t), - map(x->to_mpoly(x, dicts)[1], arguments(t))) + map(x->to_mpoly(x, dicts)[1], arguments(t)), + symtype(t)) else sym = Sym{symtype(t)}(gensym("literal")) sym2term[sym] = t @@ -110,7 +111,7 @@ function _to_term(reference, x::MPoly, dict, syms) elseif length(monics) == 0 return 1 else - return similarterm(reference, *, monics) + return similarterm(reference, *, monics, symtype(reference)) end end @@ -123,7 +124,8 @@ function _to_term(reference, x::MPoly, dict, syms) t = similarterm(reference, +, map((x,y)->isone(y) ? x : Int(y)*x, - monoms, x.coeffs[1:length(monoms)])) + monoms, x.coeffs[1:length(monoms)]), + symtype(reference)) end substitute(t, dict, fold=false) @@ -131,7 +133,7 @@ end function _to_term(reference, x, dict, vars) if istree(x) - t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,))) + t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)), symtype(x)) else if haskey(dict, x) return dict[x] diff --git a/src/types.jl b/src/types.jl index ede1d8084..dd188613a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -74,7 +74,7 @@ end function to_symbolic(x) Base.depwarn("`to_symbolic(x)` is deprecated, define the interface for your " * "symbolic structure using `istree(x)`, `operation(x)`, `arguments(x)` " * - "and `similarterm(::YourType, f, args)`", :to_symbolic, force=true) + "and `similarterm(::YourType, f, args, symtype)`", :to_symbolic, force=true) x end @@ -319,13 +319,23 @@ function term(f, args...; type = nothing) end """ - similarterm(t, f, args) + similarterm(t, f, args, symtype) -Create a term that is similar in type to `t` such that `symtype(similarterm(f, -args...)) === symtype(f(args...))`. +Create a term that is similar in type to `t`. Extending this function allows packages +using their own expression types with SymbolicUtils to define how new terms should +be created. + +## Arguments + +- `t` the reference term to use to create similar terms +- `f` is the operation of the term +- `args` is the arguments +- The `symtype` of the resulting term. Best effort will be made to set the symtype of the + resulting similar term to this type. """ -similarterm(t, f, args) = f(args...) -similarterm(::Term, f, args) = term(f, args...) +similarterm(t, f, args, symtype) = f(args...) +similarterm(t, f, args) = similarterm(t, f, args, _promote_symtype(f, args)) +similarterm(::Term, f, args, symtype=nothing) = term(f, args...; type=symtype) node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1 @@ -757,15 +767,16 @@ function mapvalues(f, d1::AbstractDict) d end -function similarterm(p::Union{Mul, Add, Pow}, f, args) - if f === (+) +function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing) + if T === nothing T = _promote_symtype(f, args) + end + if f === (+) Add(T, makeadd(1, 0, args...)...) elseif f == (*) - T = _promote_symtype(f, args) Mul(T, makemul(1, args...)...) elseif f == (^) && length(args) == 2 - Pow(args...) + Pow{T, typeof.(args)...}(args...) else f(args...) end diff --git a/test/nf.jl b/test/nf.jl index b0be1d1bb..ba1feb2ca 100644 --- a/test/nf.jl +++ b/test/nf.jl @@ -1,11 +1,16 @@ using SymbolicUtils, Test -using SymbolicUtils: polynormalize, Term +using SymbolicUtils: polynormalize, Term, symtype @testset "polyform" begin @syms a b c d @test polynormalize(a * (b + -1 * c) + -1 * (b * a + -1 * c * a)) == 0 @eqtest polynormalize(sin(a+b)+sin(c+d)) == sin(a+b) + sin(c+d) @eqtest simplify(polynormalize(sin((a+b)^2)^2)) == simplify(sin(a^2+2*(b*a)+b^2)^2) @test simplify(polynormalize(sin((a+b)^2)^2 + cos((a+b)^2)^2)) == 1 + @syms x1::Real f(::Real)::Real + + # issue 193 + @test isequal(polynormalize(f(x1 + 2.0)), f(2.0 + x1)) + @test symtype(polynormalize(f(x1 + 2.0))) == Real # cleanup rules @test polynormalize(Term{Number}(identity, 0)) == 0