Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
shashi committed Feb 4, 2021
2 parents 7a968a1 + 423b640 commit 9c09e7e
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 21 deletions.
8 changes: 4 additions & 4 deletions page/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ for `simplify` to work. Other required methods are `operation` and `istree`

In addition, the methods for `Base.hash` and `Base.isequal` should also be implemented by the types for the purposes of substitution and equality matching respectively.

### Optional

#### `similarterm(t::MyType, f, args)`
#### `similarterm(t::MyType, f, args[, T])`

Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation).
Construct a new term with the operation `f` and arguments `args`, the term should be similar to `t` in type. if `t` is a `Term` object a new Term is created with the same symtype as `t`. If not, the result is computed as `f(args...)`. Defining this method for your term type will reduce any performance loss in performing `f(args...)` (esp. the splatting, and redundant type computation). T is the symtype of the output term. You can use `promote_symtype` to infer this type.

The below two functions are internal to SymbolicUtils

### Optional

#### `symtype(x)`

The supposed type of values in the domain of x. Tracing tools can use this type to
Expand Down
14 changes: 8 additions & 6 deletions src/abstractalgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ function labels!(dicts, t)
return t
elseif istree(t) && (operation(t) == (*) || operation(t) == (+) || operation(t) == (-))
tt = arguments(t)
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt))
return similarterm(t, operation(t), map(x->labels!(dicts, x), tt), symtype(t))
elseif istree(t) && operation(t) == (^) && length(arguments(t)) > 1 && isnonnegint(arguments(t)[2])
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)))
return similarterm(t, operation(t), map(x->labels!(dicts, x), arguments(t)), symtype(t))
else
sym2term, term2sym = dicts
if haskey(term2sym, t)
Expand All @@ -36,7 +36,8 @@ function labels!(dicts, t)
sym = Sym{symtype(t)}(gensym(nameof(operation(t))))
dicts2 = _dicts(dicts[2])
sym2term[sym] = similarterm(t, operation(t),
map(x->to_mpoly(x, dicts)[1], arguments(t)))
map(x->to_mpoly(x, dicts)[1], arguments(t)),
symtype(t))
else
sym = Sym{symtype(t)}(gensym("literal"))
sym2term[sym] = t
Expand Down Expand Up @@ -110,7 +111,7 @@ function _to_term(reference, x::MPoly, dict, syms)
elseif length(monics) == 0
return 1
else
return similarterm(reference, *, monics)
return similarterm(reference, *, monics, symtype(reference))
end
end

Expand All @@ -123,15 +124,16 @@ function _to_term(reference, x::MPoly, dict, syms)
t = similarterm(reference,
+,
map((x,y)->isone(y) ? x : Int(y)*x,
monoms, x.coeffs[1:length(monoms)]))
monoms, x.coeffs[1:length(monoms)]),
symtype(reference))
end

substitute(t, dict, fold=false)
end

function _to_term(reference, x, dict, vars)
if istree(x)
t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)))
t=similarterm(x, operation(x), _to_term.((reference,), arguments(x), (dict,), (vars,)), symtype(x))
else
if haskey(dict, x)
return dict[x]
Expand Down
31 changes: 21 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ end
function to_symbolic(x)
Base.depwarn("`to_symbolic(x)` is deprecated, define the interface for your " *
"symbolic structure using `istree(x)`, `operation(x)`, `arguments(x)` " *
"and `similarterm(::YourType, f, args)`", :to_symbolic, force=true)
"and `similarterm(::YourType, f, args, symtype)`", :to_symbolic, force=true)

x
end
Expand Down Expand Up @@ -319,13 +319,23 @@ function term(f, args...; type = nothing)
end

"""
similarterm(t, f, args)
similarterm(t, f, args, symtype)
Create a term that is similar in type to `t` such that `symtype(similarterm(f,
args...)) === symtype(f(args...))`.
Create a term that is similar in type to `t`. Extending this function allows packages
using their own expression types with SymbolicUtils to define how new terms should
be created.
## Arguments
- `t` the reference term to use to create similar terms
- `f` is the operation of the term
- `args` is the arguments
- The `symtype` of the resulting term. Best effort will be made to set the symtype of the
resulting similar term to this type.
"""
similarterm(t, f, args) = f(args...)
similarterm(::Term, f, args) = term(f, args...)
similarterm(t, f, args, symtype) = f(args...)
similarterm(t, f, args) = similarterm(t, f, args, _promote_symtype(f, args))
similarterm(::Term, f, args, symtype=nothing) = term(f, args...; type=symtype)

node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init=0) + 1 : 1

Expand Down Expand Up @@ -757,15 +767,16 @@ function mapvalues(f, d1::AbstractDict)
d
end

function similarterm(p::Union{Mul, Add, Pow}, f, args)
if f === (+)
function similarterm(p::Union{Mul, Add, Pow}, f, args, T=nothing)
if T === nothing
T = _promote_symtype(f, args)
end
if f === (+)
Add(T, makeadd(1, 0, args...)...)
elseif f == (*)
T = _promote_symtype(f, args)
Mul(T, makemul(1, args...)...)
elseif f == (^) && length(args) == 2
Pow(args...)
Pow{T, typeof.(args)...}(args...)
else
f(args...)
end
Expand Down
7 changes: 6 additions & 1 deletion test/nf.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
using SymbolicUtils, Test
using SymbolicUtils: polynormalize, Term
using SymbolicUtils: polynormalize, Term, symtype
@testset "polyform" begin
@syms a b c d
@test polynormalize(a * (b + -1 * c) + -1 * (b * a + -1 * c * a)) == 0
@eqtest polynormalize(sin(a+b)+sin(c+d)) == sin(a+b) + sin(c+d)
@eqtest simplify(polynormalize(sin((a+b)^2)^2)) == simplify(sin(a^2+2*(b*a)+b^2)^2)
@test simplify(polynormalize(sin((a+b)^2)^2 + cos((a+b)^2)^2)) == 1
@syms x1::Real f(::Real)::Real

# issue 193
@test isequal(polynormalize(f(x1 + 2.0)), f(2.0 + x1))
@test symtype(polynormalize(f(x1 + 2.0))) == Real

# cleanup rules
@test polynormalize(Term{Number}(identity, 0)) == 0
Expand Down

0 comments on commit 9c09e7e

Please sign in to comment.