Skip to content

Commit

Permalink
Merge pull request #603 from JuliaSymbolics/s/fix-depwarn
Browse files Browse the repository at this point in the history
similarterm -> maketerm in Walk and PolyForm
  • Loading branch information
ChrisRackauckas authored May 30, 2024
2 parents 9e3b9de + 0520c78 commit 33b274b
Show file tree
Hide file tree
Showing 9 changed files with 70 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicUtils"
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
authors = ["Shashi Gowda"]
version = "2.0.0"
version = "2.0.1"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down
2 changes: 2 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,7 @@ with `head` as the head and `args` as the arguments, `type` as the symtype
and `metadata` as the metadata. By default this will execute `head(args...)`.
`x` parameter can also be a `Type`. The `exprhead` keyword argument is useful
when manipulating `Expr`s.
`similarterm` is deprecated see help for `maketerm` instead.
"""
function similarterm end
31 changes: 17 additions & 14 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
# create a new symbol to store this

y = if recurse
similarterm(x,
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse),
args), symtype(x))
maketerm(typeof(x),
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args),
symtype(x),
metadata(x))
else
x
end
Expand Down Expand Up @@ -175,11 +176,11 @@ isexpr(x::PolyForm) = true
iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true

function similarterm(t::PolyForm, f, args, symtype; metadata=nothing)
basic_similarterm(t, f, args, symtype; metadata=metadata)
function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata)
basicsymbolic(t, f, args, symtype, metadata)
end
function similarterm(::PolyForm, f::Union{typeof(*), typeof(+), typeof(^)},
args, symtype; metadata=nothing)
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)},
args, symtype, metadata)
f(args...)
end

Expand Down Expand Up @@ -248,8 +249,10 @@ multivariate polynomials implementation.
expand(expr) = unpolyize(PolyForm(expr, Fs=Union{typeof(+), typeof(*), typeof(^)}, recurse=true))

function unpolyize(x)
simterm(x, f, args; kw...) = similarterm(x, f, args, symtype(x); kw...)
Postwalk(identity, similarterm=simterm)(x)
# we need a special makterm here because the default one used in Postwalk will call
# promote_symtype to get the new type, but we just want to forward that in case
# promote_symtype is not defined for some of the expressions here.
Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x)
end

function toterm(x::PolyForm)
Expand Down Expand Up @@ -301,7 +304,7 @@ function add_divs(x, y)
end
end

function frac_similarterm(x, f, args; kw...)
function frac_maketerm(T, f, args, stype, metadata)
if f in (*, /, \, +, -)
f(args...)
elseif f == (^)
Expand All @@ -311,7 +314,7 @@ function frac_similarterm(x, f, args; kw...)
args[1]^args[2]
end
else
similarterm(x, f, args; kw...)
maketerm(T, f, args, stype, metadata)
end
end

Expand All @@ -333,8 +336,8 @@ function simplify_fractions(x; polyform=false)
sdiv(a) = isdiv(a) ? simplify_div(a) : a

expr = Postwalk(sdiv quick_cancel,
similarterm=frac_similarterm)(Postwalk(add_with_div,
similarterm=frac_similarterm)(x))
maketerm=frac_maketerm)(Postwalk(add_with_div,
maketerm=frac_maketerm)(x))

polyform ? expr : unpolyize(expr)
end
Expand Down
39 changes: 28 additions & 11 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ module Rewriters
using SymbolicUtils: @timer
using TermInterface

import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count
import SymbolicUtils: iscall, operation, arguments, unsorted_arguments, metadata, node_count, _promote_symtype
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough

# Cache of printed rules to speed up @timer
Expand Down Expand Up @@ -167,24 +167,41 @@ end
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
similarterm::F
maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function
# that behaves like `similarterm` here, we use `compatmaker` to wrap
# maketerm-like input to do this, with a warning if similarterm provided
# we need this workaround to deprecate because similarterm takes value
# but maketerm only knows the type.
end

function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded}
irw = instrument(x.rw, f)
Walk{ord, typeof(irw), typeof(x.similarterm), threaded}(irw,
Walk{ord, typeof(irw), typeof(x.maketerm), threaded}(irw,
x.thread_cutoff,
x.similarterm)
x.maketerm)
end

using .Threads

function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:post, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
function compatmaker(similarterm, maketerm)
# XXX: delete this and only use maketerm in a future release.
if similarterm isa Nothing
function (x, f, args, type=_promote_symtype(f, args); metadata)
maketerm(typeof(x), f, args, type, metadata)
end
else
Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm)
similarterm
end
end
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
end

function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, similarterm=similarterm)
Walk{:pre, typeof(rw), typeof(similarterm), threaded}(rw, thread_cutoff, similarterm)
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
end

struct PassThrough{C}
Expand All @@ -202,8 +219,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
x = p.rw(x)
end

if istree(x)
x = p.similarterm(x, operation(x), map(PassThrough(p),
if iscall(x)
x = p.maketerm(x, operation(x), map(PassThrough(p),
unsorted_arguments(x)), metadata=metadata(x))
end

Expand All @@ -228,7 +245,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = p.similarterm(x, operation(x), args, metadata=metadata(x))
t = p.maketerm(x, operation(x), args, metadata=metadata(x))
end
return ord === :post ? p.rw(t) : t
else
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ function (acr::ACRule)(term)
if result !== nothing
# Assumption: inds are unique
length(args) == length(inds) && return result
return similarterm(term, f, [result, (args[i] for i in eachindex(args) if i inds)...], symtype(term))
return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i inds)...], symtype(term), metadata(term))
end
end
end
Expand Down
10 changes: 5 additions & 5 deletions src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ function substitute(expr, dict; fold=true)
args = map(x->substitute(x, dict, fold=fold), unsorted_arguments(expr))
end

similarterm(expr,
op,
args,
symtype(expr);
metadata=metadata(expr))
maketerm(typeof(expr),
op,
args,
symtype(expr),
metadata(expr))
else
expr
end
Expand Down
12 changes: 6 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function fold(t)
# evaluate it
return operation(t)(tt...)
else
return similarterm(t, operation(t), tt)
return maketerm(typeof(t), operation(t), tt, symtype(t), metadata(t))
end
else
return t
Expand Down Expand Up @@ -147,19 +147,19 @@ function flatten_term(⋆, x)
push!(flattened_args, t)
end
end
similarterm(x, , flattened_args)
maketerm(typeof(x), , flattened_args, symtype(x), metadata(x))
end

function sort_args(f, t)
args = arguments(t)
if length(args) < 2
return similarterm(t, f, args)
return maketerm(typeof(t), f, args, symtype(t), metadata(t))
elseif length(args) == 2
x, y = args
return similarterm(t, f, x <ₑ y ? [x,y] : [y,x])
return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], symtype(t), metadata(t))
end
args = args isa Tuple ? [args...] : args
similarterm(t, f, sort(args, lt=<ₑ))
maketerm(typeof(t), f, sort(args, lt=<), symtype(t), metadata(t))
end

# Linked List interface
Expand Down Expand Up @@ -225,7 +225,7 @@ macro matchable(expr)
SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),))
SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)]
Base.length(x::$name) = $(length(fields) + 1)
SymbolicUtils.similarterm(x::$name, f, args, type; kw...) = f(args...)
SymbolicUtils.maketerm(x::$name, f, args, type, metadata) = f(args...)
end |> esc
end

Expand Down
18 changes: 9 additions & 9 deletions test/basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,20 +214,20 @@ end
@test_reference "inspect_output/sub14.txt" sprint(io->SymbolicUtils.inspect(io, SymbolicUtils.pluck(ex, 14)))
end

@testset "similarterm" begin
@testset "maketerm" begin
@syms a b c
@test isequal(SymbolicUtils.similarterm((b + c), +, [a, (b+c)]).dict, Dict(a=>1,b=>1,c=>1))
@test isequal(SymbolicUtils.similarterm(b^2, ^, [b^2, 1//2]), b)
@test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], Number, nothing).dict, Dict(a=>1,b=>1,c=>1))
@test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], Number, nothing), b)

# test that similarterm doesn't hard-code BasicSymbolic subtype
# test that maketerm doesn't hard-code BasicSymbolic subtype
# and is consistent with BasicSymbolic arithmetic operations
@test isequal(SymbolicUtils.similarterm(a / b, *, [a / b, c]), (a / b) * c)
@test isequal(SymbolicUtils.similarterm(a * b, *, [0, c]), 0)
@test isequal(SymbolicUtils.similarterm(a^b, ^, [a * b, 3]), (a * b)^3)
@test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], Number, nothing), (a / b) * c)
@test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], Number, nothing), 0)
@test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, nothing), (a * b)^3)

# test that similarterm sets metadata correctly
# test that maketerm sets metadata correctly
metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1")
s = SymbolicUtils.similarterm(a^b, ^, [a * b, 3]; metadata = metadata)
s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, metadata)
@test hasmetadata(s, Ctx1)
@test getmetadata(s, Ctx1) == "meta_1"
end
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ else
include("cse.jl")
include("interface.jl")
# Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed
# include("fuzz.jl")
include("fuzz.jl")
include("adjoints.jl")
end

0 comments on commit 33b274b

Please sign in to comment.