Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automatic wrapping of specialized methods, solve many ambiguities #813

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
Expand Down Expand Up @@ -44,6 +45,7 @@ DomainSets = "0.5"
Groebner = "0.1, 0.2"
IfElse = "0.1"
LaTeXStrings = "1.3"
LambertW = "0.4.5"
Latexify = "0.11, 0.12, 0.13, 0.14, 0.15"
MacroTools = "0.5"
NaNMath = "0.3, 1"
Expand All @@ -58,7 +60,6 @@ SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "1.1"
SymbolicUtils = "1.0.1"
TreeViews = "0.3"
LambertW = "0.4.5"
julia = "1.6"

[extras]
Expand Down
4 changes: 4 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ using MacroTools
import MacroTools: splitdef, combinedef, postwalk, striplines
include("wrapper-types.jl")

import ExprTools
export specialize_methods
include("specialize_methods.jl")

include("num.jl")
include("complex.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@ function __init__()
end

end # SymPy

specialize_methods((LinearAlgebra,))
end
45 changes: 45 additions & 0 deletions src/specialize_methods.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
specialize_methods(func, abstract_arg_types, inner_func, mods=nothing)

For any method that implements `func` with signature
fitting `abstract_arg_types`, define methods for corresponding
symbolic types that pass all arguments to `inner_func`.
`mods` is an optional list of modules to look for methods in.
"""
function specialize_methods(func, abstract_arg_types, inner_func, mods=nothing)
ms = isnothing(mods) ? methods(func, abstract_arg_types) : methods(func, abstract_arg_types, mods)
for m in ms
mod = m.module
if mod != @__MODULE__ # do not overwrite method definitions from within this module itself, else: precompilation warnings
sig = ExprTools.signature(m; extra_hygiene=true)
fname = sig[:name]
args = sig[:args]
kwargs = get(sig, :kwargs, Symbol[])
whereparams = get(sig, :whereparams, Symbol[])
args_names = expr_argname.(args)
kwargs_names = expr_kwargname.(kwargs)
body = :($(inner_func)($(args_names...); $(kwargs_names...)))
Base.eval(
@__MODULE__,
wrap_func_expr(
mod, fname, args, kwargs, args_names, kwargs_names, whereparams, body;
abstract_arg_types
)
)
end#of `mod != @__MODULE__`
end#of `for m in ms`
end

"""
specialize_methods(mods=nothing)

Define specialized methods accepting symbolic types for the following functions and
signatures found in modules `mods` via `methods(...)`:

* `Base.:(*)` for arguments of `(AbstractMatrix, AbstractVector)` to redirect to `_matvec`.
* `Base.:(*)` for arguments of `(AbstractMatrix, AbstractMetrax)` to redirect to `_matmul`.
"""
function specialize_methods(mods=nothing)
specialize_methods(Base.:(*), (AbstractMatrix, AbstractVector), _matvec, mods)
specialize_methods(Base.:(*), (AbstractMatrix, AbstractMatrix), _matmul, mods)
end
92 changes: 58 additions & 34 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,58 +57,82 @@ function wraps_type end
has_symwrapper(::Type) = false
is_wrapper_type(::Type) = false

# helper function to extract keyword argument names from expressions
function expr_kwargname(kwarg)
if kwarg isa Expr && kwarg.head == :kw
kwarg.args[1]
elseif kwarg isa Expr && kwarg.head == :(...)
kwarg.args[1]
else
kwarg
end
end

# helper function to extract argument names from expressions
function expr_argname(arg)
if arg isa Expr && (arg.head == :(::) || arg.head == :(...))
arg.args[1]
elseif arg isa Expr
error("$arg not supported as an argument")
else
arg
end
end

function wrap_func_expr(mod, expr)
@assert expr.head == :function || (expr.head == :(=) &&
expr.args[1] isa Expr &&
expr.args[1].head == :call)

def = splitdef(expr)

sig = expr.args[1]
body = def[:body]

fname = def[:name]
args = get(def, :args, [])
kwargs = get(def, :kwargs, [])
args_names = expr_argname.(args)
kwargs_names = expr_kwargname.(kwargs)

wrap_func_expr(mod, fname, args, kwargs, args_names, kwargs_names, Symbol[], body)
end

impl_name = Symbol(fname,"_", hash(string(args)*string(kwargs)))

function kwargname(kwarg)
if kwarg isa Expr && kwarg.head == :kw
kwarg.args[1]
elseif kwarg isa Expr && kwarg.head == :(...)
kwarg.args[1]
else
kwarg
end
end

function argname(arg)
if arg isa Expr && (arg.head == :(::) || arg.head == :(...))
arg.args[1]
elseif arg isa Expr
error("$arg not supported as an argument")
else
arg
end
end

names = vcat(argname.(args), kwargname.(kwargs))

function type_options(arg)
function wrap_func_expr(
mod, fname, args, kwargs, args_names, kwargs_names, whereparams, body;
abstract_arg_types=nothing
)
names = vcat(args_names, kwargs_names)

function type_options(wparams, arg, arg_ind)
pmod = parentmodule(mod)
atype = isnothing(abstract_arg_types) ? Any : abstract_arg_types[arg_ind]
if arg isa Expr && arg.head == :(::)
T = Base.eval(mod, arg.args[2])
T = Base.eval(mod, quote
let $(Symbol(pmod)) = $(pmod); # make name of parent module available in eval scope
#=
NOTE
`typeintersect` is important here for consecutive calls to `specialize_methods`
with conceptually different super types.
E.g.: Consider we first specialize `*(::AbstractMatrix, ::AbstractVector)` to
redirect to `_matvec`, and then `*(::AbstractMatrix, ::AbstractMatrix)` to
redirect to `_matmul`. If we encounter some existing method for `*` which accepts
an `AbstractMatrix` and an `VecOrMat` (type union), then we accidentally redirect
a matrix-vector-product to `_matmul` without `typeintersect`.
=#
typeintersect($(atype), $(arg.args[2]) where {$(wparams...)})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if methods are added by a third party package after @warpped has executed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then specialize_methods would have to be called again. This sort of load-order dependence is also observed in Hyperspecialize, and the only way around it seems the Cassette approach.
Actually, the code in this pull request is not too different from what Hyperspecialize can do, except that some of the Matrix or Vector types occurring in method definitions can be super special, so I had to either traverse the whole type tree and gather every possible subtype for Hyperspecialize's @concretize or inspect the method signatures (what is done now).

However, there are a few things we can do with respect to the points in your other comment:

All relevant types in Base/stdlib should be made to work with the wrapped types defined in this package

specialize_methods accepts a list of modules with argument mods. When calling it in __init__() we can restrict ourselves to LinearAlgebra and other relevant standard modules.

Taking this further if a new package wants to use its array or number types with Symbolics, wherever necessary, it should return the specializations manually or using a different macro.

What happens if methods are added by a third party package after @Warpped has executed?

Whenever a user has custom array types defined in a module, or there is a package with custom array types, specialize_methods can be called with mods set accordingly to automatically “widen”/wrap the relevant definitions to have a generic fallback. This way, the developer does not have to know that Base.:(*)(::AbstractMatrix, ::AbstractVector) should redirect to _matmul most of the time.
If, for example, I get a StaticMatrix somewhere in my computations, then I cannot use it with symbolic arrays:

using Symbolics
using StaticArrays
A = @SMatrix(rand(2,2)) # imagine this is obtained somewhere else
@variables (x::Real)[1:2]
A * x # ERROR: MethodError: *(::SMatrix{2, 2, Float64, 4}, ::Symbolics.Arr{Num, 1}) is ambiguous.

But after
Symbolics.specialize_methods(StaticArrays)
it just works.

For some widely used packages (such as StaticArrays) we could even do this ourselves with @require. Conversely, a package author could opt to place Symbolics.specialize_methods(@__MODULE__) in their __init__, though I have to think about scopes and export statements a bit more.

Going forward, I can try to benchmark the impact of specialize_methods in __init__ (for different values of mods) and maybe devise some tests.

end
end)
has_symwrapper(T) ? (T, :(SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) :
(T,:(SymbolicUtils.Symbolic{<:$T}))
(T, :(SymbolicUtils.Symbolic{<:$T}))
elseif arg isa Expr && arg.head == :(...)
Ts = type_options(arg.args[1])
Ts = type_options(wparams, arg.args[1], arg_ind)
map(x->Vararg{x},Ts)
else
(Any,)
end
end

types = map(type_options, args)
types = [type_options(whereparams, arg, arg_ind) for (arg_ind, arg)=enumerate(args)]

impl_name = Symbol(fname,"_", hash(string(args)*string(kwargs)*string(types)))

impl = :(function $impl_name($(names...))
$body
Expand Down Expand Up @@ -139,9 +163,9 @@ function wrap_func_expr(mod, expr)
quote
$impl
$(methods...)
end |> esc
end
end

macro wrapped(expr)
wrap_func_expr(__module__, expr)
esc(wrap_func_expr(__module__, expr))
end