Skip to content

Commit

Permalink
Extend invoke to accept CodeInstance (JuliaLang#56660)
Browse files Browse the repository at this point in the history
This is an alternative mechanism to JuliaLang#56650 that largely achieves the
same result, but by hooking into `invoke` rather than a generated
function. They are orthogonal mechanisms, and it's possible we want both.
However, in JuliaLang#56650, both Jameson and Valentin were skeptical of the
generated function signature bottleneck. This PR is sort of a hybrid of the
mechanism in JuliaLang#52964 and what I proposed in
JuliaLang#56650 (comment).

In particular, this PR:

1. Extends `invoke` to support a CodeInstance in place of its usual
`types` argument.

2. Adds a new `typeinf` optimized generic. The semantics of this
optimized generic allow the compiler to instead call a companion
`typeinf_edge` function, allowing a mid-inference interpreter switch
(like JuliaLang#52964), without being forced through a concrete signature
bottleneck. However, if calling `typeinf_edge` does not work (e.g.
because the compiler version is mismatched), this still has well defined
semantics, you just don't get inference support.

The additional benefit of the `typeinf` optimized generic is that it
lets custom cache owners tell the runtime how to "cure" code instances
that have lost their native code. Currently the runtime only knows how
to do that for `owner == nothing` CodeInstances (by re-running
inference). This extension is not implemented, but the idea is that the
runtime would be permitted to call the `typeinf` optimized generic on
the dead CodeInstance's `owner` and `def` fields to obtain a cured
CodeInstance (or a user-actionable error from the plugin).

This PR includes an implementation of `with_new_compiler` from JuliaLang#56650.
This PR includes just enough compiler support to make the compiler
optimize this to the same code that JuliaLang#56650 produced:

```
julia> @code_typed with_new_compiler(sin, 1.0)
CodeInfo(
1 ─      $(Expr(:foreigncall, :(:jl_get_tls_world_age), UInt64, svec(), 0, :(:ccall)))::UInt64
│   %2 =   builtin Core.getfield(args, 1)::Float64
│   %3 =    invoke sin(%2::Float64)::Float64
└──      return %3
) => Float64
```

However, the implementation here is extremely incomplete. I'm putting it
up only as a directional sketch to see if people prefer it over JuliaLang#56650.
If so, I would prepare a cleaned up version of this PR that has the
optimized generics as well as the curing support, but not the full
inference integration (which needs a fair bit more work).
  • Loading branch information
Keno authored Dec 3, 2024
1 parent f1b0b01 commit efa917e
Show file tree
Hide file tree
Showing 14 changed files with 242 additions and 15 deletions.
15 changes: 15 additions & 0 deletions Compiler/extras/CompilerDevTools/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.12.0-DEV"
manifest_format = "2.0"
project_hash = "84f495a1bf065c95f732a48af36dd0cd2cefb9d5"

[[deps.Compiler]]
path = "../.."
uuid = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
version = "0.0.2"

[[deps.CompilerDevTools]]
path = "."
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"
version = "0.0.0"
5 changes: 5 additions & 0 deletions Compiler/extras/CompilerDevTools/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
name = "CompilerDevTools"
uuid = "92b2d91f-d2bd-4c05-9214-4609ac33433f"

[deps]
Compiler = "807dbc54-b67e-4c79-8afb-eafe4df6f2e1"
56 changes: 56 additions & 0 deletions Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Developer tooling that runs inference through a separate interpreter
# (`SplitCacheInterp`) whose results are cached under their own cache owner
# (`SplitCacheOwner`), and that calls the resulting `CodeInstance` via `invoke`.
module CompilerDevTools

using Compiler
using Core.IR

# Cache-owner token returned by `Compiler.cache_owner(::SplitCacheInterp)`.
struct SplitCacheOwner; end

# Minimal `AbstractInterpreter`: it stores the inference world, the inference and
# optimization parameters, and a local inference cache.
struct SplitCacheInterp <: Compiler.AbstractInterpreter
    world::UInt
    inf_params::Compiler.InferenceParams
    opt_params::Compiler.OptimizationParams
    inf_cache::Vector{Compiler.InferenceResult}
    function SplitCacheInterp(;
        world::UInt = Base.get_world_counter(),
        inf_params::Compiler.InferenceParams = Compiler.InferenceParams(),
        opt_params::Compiler.OptimizationParams = Compiler.OptimizationParams(),
        inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
        return new(world, inf_params, opt_params, inf_cache)
    end
end

# `AbstractInterpreter` accessors, each forwarding to the corresponding field.
Compiler.InferenceParams(interp::SplitCacheInterp) = interp.inf_params
Compiler.OptimizationParams(interp::SplitCacheInterp) = interp.opt_params
Compiler.get_inference_world(interp::SplitCacheInterp) = interp.world
Compiler.get_inference_cache(interp::SplitCacheInterp) = interp.inf_cache
Compiler.cache_owner(::SplitCacheInterp) = SplitCacheOwner()

import Core.OptimizedGenerics.CompilerPlugins: typeinf, typeinf_edge

# `typeinf` for `SplitCacheOwner`: runs `Compiler.typeinf_ext` with a fresh
# `SplitCacheInterp` for the caller's current world. The call is made through
# `invoke_in_world`, pinned to the `primary_world` of this very method.
@eval @noinline typeinf(::SplitCacheOwner, mi::MethodInstance, source_mode::UInt8) =
    Base.invoke_in_world(
        which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world,
        Compiler.typeinf_ext,
        SplitCacheInterp(; world=Base.tls_world_age()),
        mi,
        source_mode,
    )

# `typeinf_edge` for `SplitCacheOwner`: infers `mi` as an edge of `parent_frame`
# using a `SplitCacheInterp` for `world`. `source_mode` is currently unused.
@eval @noinline function typeinf_edge(::SplitCacheOwner, mi::MethodInstance, parent_frame::Compiler.InferenceState, world::UInt, source_mode::UInt8)
    # TODO: This isn't quite right, we're just sketching things for now
    interp = SplitCacheInterp(; world)
    Compiler.typeinf_edge(interp, mi.def, mi.specTypes, Core.svec(), parent_frame, false, false)
end

# Look up the method matching signature type `tt` in the method table of a
# `SplitCacheInterp` for `world` and return its specialization.
# TODO: This needs special compiler support to properly case split for multiple
# method matches, etc.
@noinline function mi_for_tt(tt, world=Base.tls_world_age())
    interp = SplitCacheInterp(; world)
    match, _ = Compiler.findsup(tt, Compiler.method_table(interp))
    return Base.specialize_method(match)
end

"""
    with_new_compiler(f, args...)

Call `f(args...)` through a `CodeInstance` obtained from
`Core.OptimizedGenerics.CompilerPlugins.typeinf` with a `SplitCacheOwner`, i.e.
inferred by `SplitCacheInterp`, and invoked with `invoke(f, ci, args...)`.
"""
function with_new_compiler(f, args...)
    tt = Base.signature_type(f, typeof(args))
    # Read the world once and use it for the method lookup as well.
    world = Base.tls_world_age()
    new_compiler_ci = Core.OptimizedGenerics.CompilerPlugins.typeinf(
        SplitCacheOwner(), mi_for_tt(tt, world), Compiler.SOURCE_MODE_ABI
    )
    return invoke(f, new_compiler_ci, args...)
end

export with_new_compiler

end
48 changes: 39 additions & 9 deletions Compiler/src/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2218,16 +2218,46 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
ft = widenconst(ft′)
ft === Bottom && return Future(CallMeta(Bottom, Any, EFFECTS_THROWS, NoCallInfo()))
types = argtype_by_index(argtypes, 3)
if types isa Const && types.val isa Method
method = types.val::Method
types = method # argument value
lookupsig = method.sig # edge kind
argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
nargtype = typeintersect(lookupsig, argtype)
nargtype === Bottom && return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
nargtype isa DataType || return Future(CallMeta(Any, Any, Effects(), NoCallInfo())) # other cases are not implemented below
if types isa Const && types.val isa Union{Method, CodeInstance}
method_or_ci = types.val
if isa(method_or_ci, CodeInstance)
our_world = sv.world.this
argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
sig = method_or_ci.def.specTypes
exct = method_or_ci.exctype
if !hasintersect(argtype, sig)
return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
elseif !(argtype <: sig)
exct = Union{exct, TypeError}
end
callee_valid_range = WorldRange(method_or_ci.min_world, method_or_ci.max_world)
if !(our_world in callee_valid_range)
if our_world < first(callee_valid_range)
update_valid_age!(sv, WorldRange(first(sv.world.valid_worlds), first(callee_valid_range)-1))
else
update_valid_age!(sv, WorldRange(last(callee_valid_range)+1, last(sv.world.valid_worlds)))
end
return Future(CallMeta(Bottom, ErrorException, EFFECTS_THROWS, NoCallInfo()))
end
# TODO: When we add curing, we may want to assume this is nothrow
# Mirror the runtime (`jl_f_invoke`): if the CodeInstance has no usable invoke
# pointer, the runtime can only fall back to the method when the CodeInstance
# is native (`owner === nothing`) and its definition is a `Method`. Otherwise it
# raises "Failed to invoke or compile external codeinst", so an `ErrorException`
# has to be accounted for in those cases.
if !(method_or_ci.owner === nothing && method_or_ci.def.def isa Method)
    exct = Union{exct, ErrorException}
end
update_valid_age!(sv, callee_valid_range)
return Future(CallMeta(method_or_ci.rettype, exct, Effects(decode_effects(method_or_ci.ipo_purity_bits), nothrow=(exct===Bottom)),
InvokeCICallInfo(method_or_ci)))
else
method = method_or_ci::Method
types = method # argument value
lookupsig = method.sig # edge kind
argtype = argtypes_to_type(pushfirst!(argtype_tail(argtypes, 4), ft))
nargtype = typeintersect(lookupsig, argtype)
nargtype === Bottom && return Future(CallMeta(Bottom, TypeError, EFFECTS_THROWS, NoCallInfo()))
nargtype isa DataType || return Future(CallMeta(Any, Any, Effects(), NoCallInfo())) # other cases are not implemented below
# Fall through to generic invoke handling
end
else
widenconst(types) >: Method && return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
widenconst(types) >: Union{Method, CodeInstance} && return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3), false)
isexact || return Future(CallMeta(Any, Any, Effects(), NoCallInfo()))
unwrapped = unwrap_unionall(types)
Expand Down
2 changes: 1 addition & 1 deletion Compiler/src/abstractlattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ end
if isa(t, Const)
# don't consider mutable values useful constants
val = t.val
return isa(val, Symbol) || isa(val, Type) || isa(val, Method) || !ismutable(val)
return isa(val, Symbol) || isa(val, Type) || isa(val, Method) || isa(val, CodeInstance) || !ismutable(val)
end
isa(t, PartialTypeVar) && return false # this isn't forwardable
return is_const_prop_profitable_arg(widenlattice(𝕃), t)
Expand Down
10 changes: 9 additions & 1 deletion Compiler/src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,15 @@
# especially try to make sure any recursive and leaf functions have concrete signatures,
# since we won't be able to specialize & infer them at runtime

activate_codegen!() = ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
# Register `typeinf_ext_toplevel` with the runtime (via `jl_set_typeinf_func`) and
# define, inside `Compiler`, the `owner::Nothing` method of
# `Core.OptimizedGenerics.CompilerPlugins.typeinf` that forwards to it.
function activate_codegen!()
ccall(:jl_set_typeinf_func, Cvoid, (Any,), typeinf_ext_toplevel)
Core.eval(Compiler, quote
# Capture the world age at the moment this block is evaluated; the method defined
# below always calls `typeinf_ext_toplevel` in that world via `invoke_in_world`.
let typeinf_world_age = Base.tls_world_age()
# The `Expr(:$, ...)` spliced into the outer `quote` leaves a literal
# `$typeinf_world_age` for the inner `@eval`, which then interpolates the value
# captured by the `let` into the method body.
@eval Core.OptimizedGenerics.CompilerPlugins.typeinf(::Nothing, mi::MethodInstance, source_mode::UInt8) =
Base.invoke_in_world($(Expr(:$, :typeinf_world_age)), typeinf_ext_toplevel, mi, Base.tls_world_age(), source_mode)
end
end)
end

function bootstrap!()
let time() = ccall(:jl_clock_now, Float64, ())
Expand Down
11 changes: 11 additions & 0 deletions Compiler/src/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,17 @@ end
add_edges_impl(edges::Vector{Any}, info::UnionSplitApplyCallInfo) =
for split in info.infos; add_edges!(edges, split); end

"""
    info::InvokeCICallInfo

Call info for a `Core.invoke` call whose target has been resolved to a
`Core.CodeInstance`; the targeted instance is stored in `edge`.
"""
struct InvokeCICallInfo <: CallInfo
    edge::CodeInstance
end
function add_edges_impl(edges::Vector{Any}, info::InvokeCICallInfo)
    # The only edge of this call is the CodeInstance itself.
    return add_one_edge!(edges, info.edge)
end

"""
info::InvokeCallInfo
Expand Down
4 changes: 2 additions & 2 deletions Compiler/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ function count_const_size(@nospecialize(x), count_self::Bool = true)
# No definite size
(isa(x, GenericMemory) || isa(x, String) || isa(x, SimpleVector)) &&
return MAX_INLINE_CONST_SIZE + 1
if isa(x, Module) || isa(x, Method)
# We allow modules and methods, because we already assume they are externally
if isa(x, Module) || isa(x, Method) || isa(x, CodeInstance)
# We allow modules, methods and CodeInstance, because we already assume they are externally
# rooted, so we count their contents as 0 size.
return sizeof(Ptr{Cvoid})
end
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ New library features
* New `ltruncate`, `rtruncate` and `ctruncate` functions for truncating strings to text width, accounting for char widths ([#55351])
* `isless` (and thus `cmp`, sorting, etc.) is now supported for zero-dimensional `AbstractArray`s ([#55772])
* `invoke` now supports passing a Method instead of a type signature making this interface somewhat more flexible for certain uncommon use cases ([#56692]).
* `invoke` now supports passing a CodeInstance instead of a type, which can enable
certain compiler plugin workflows ([#56660]).

Standard library changes
------------------------
Expand Down
17 changes: 17 additions & 0 deletions base/docs/basedocs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,7 @@ applicable
"""
invoke(f, argtypes::Type, args...; kwargs...)
invoke(f, argtypes::Method, args...; kwargs...)
invoke(f, argtypes::CodeInstance, args...; kwargs...)
Invoke a method for the given generic function `f` matching the specified types `argtypes` on the
specified arguments `args` and passing the keyword arguments `kwargs`. The arguments `args` must
Expand All @@ -2056,6 +2057,22 @@ Note in particular that the specified `Method` may be entirely unreachable from
If the method is part of the ordinary method table, this call behaves similar
to `invoke(f, method.sig, args...)`.
!!! compat "Julia 1.12"
Passing a `Method` requires Julia 1.12.
# Passing a `CodeInstance` instead of a signature
The `argtypes` argument may be a `CodeInstance`, bypassing both method lookup and specialization.
The semantics of this invocation are similar to a function pointer call of the `CodeInstance`'s
`invoke` pointer. It is an error to invoke a `CodeInstance` with arguments that do not match its
parent MethodInstance or from a world age not included in the `min_world`/`max_world` range.
It is undefined behavior to invoke a CodeInstance whose behavior does not match the constraints
specified in its fields. For some code instances with `owner !== nothing` (i.e. those generated
by external compilers), it may be an error to invoke them after passing through precompilation.
This is an advanced interface intended for use with external compiler plugins.
!!! compat "Julia 1.12"
Passing a `CodeInstance` requires Julia 1.12.
# Examples
```jldoctest
julia> f(x::Real) = x^2;
Expand Down
27 changes: 27 additions & 0 deletions base/optimized_generics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,31 @@ module KeyValue
function get end
end

# Compiler-recognized intrinsics for compiler plugins
"""
    module CompilerPlugins

Implements a pair of functions `typeinf`/`typeinf_edge`. When the optimizer sees
a call to `typeinf`, it has license to instead call `typeinf_edge`, supplying the
current inference stack in `parent_frame` (but otherwise supplying the arguments
to `typeinf`). `typeinf_edge` will return the `CodeInstance` that `typeinf` would
have returned at runtime. The optimizer may perform a non-IPO replacement of
the call to `typeinf` by the result of `typeinf_edge`. In addition, the IPO-safe
fields of the `CodeInstance` may be propagated in IPO mode.
"""
module CompilerPlugins
"""
    typeinf(owner, mi, source_mode)::CodeInstance

Return a `CodeInstance` for the given `mi` whose valid world range includes at
least the current tls world and which satisfies the requirements of `source_mode`.
"""
function typeinf end

"""
    typeinf_edge(owner, mi, parent_frame, world, abi_mode)::CodeInstance

Inference-time companion of `typeinf`. It is called with the current inference
stack in `parent_frame` and returns the `CodeInstance` that the corresponding
`typeinf` call would have returned at runtime.
"""
function typeinf_edge end
end

end
22 changes: 22 additions & 0 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1587,6 +1587,28 @@ JL_CALLABLE(jl_f_invoke)
if (!jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)m->sig))
jl_type_error("invoke: argument type error", argtypes, arg_tuple(args[0], &args[2], nargs - 1));
return jl_gf_invoke_by_method(m, args[0], &args[2], nargs - 1);
} else if (jl_is_code_instance(argtypes)) {
    // invoke(f, ci::CodeInstance, args...): call the CodeInstance directly,
    // bypassing method lookup and specialization.
    jl_code_instance_t *codeinst = (jl_code_instance_t*)args[1];
    jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
    // The arguments (`f` plus `args[2..]`, i.e. `nargs - 1` values, counted the
    // same way as in the `Method` branch) must match the MethodInstance's
    // specTypes; raise a TypeError when they do NOT.
    if (!jl_tuple1_isa(args[0], &args[2], nargs - 1, (jl_datatype_t*)codeinst->def->specTypes)) {
        jl_type_error("invoke: argument type error", codeinst->def->specTypes, arg_tuple(args[0], &args[2], nargs - 1));
    }
    // The CodeInstance may only be invoked from a world inside [min_world, max_world].
    if (jl_atomic_load_relaxed(&codeinst->min_world) > jl_current_task->world_age ||
        jl_current_task->world_age > jl_atomic_load_relaxed(&codeinst->max_world)) {
        jl_error("invoke: CodeInstance not valid for this world");
    }
    // No invoke pointer yet: try to compile, then reload the pointer.
    if (!invoke) {
        jl_compile_codeinst(codeinst);
        invoke = jl_atomic_load_acquire(&codeinst->invoke);
    }
    if (invoke) {
        // The call pointer takes the argument count without `f`, hence `nargs - 2`.
        return invoke(args[0], &args[2], nargs - 2, codeinst);
    } else {
        // Only native CodeInstances (owner == nothing) defined by a Method can
        // fall back to an ordinary invoke of that method.
        if (codeinst->owner != jl_nothing || !jl_is_method(codeinst->def->def.value)) {
            jl_error("Failed to invoke or compile external codeinst");
        }
        return jl_gf_invoke_by_method(codeinst->def->def.method, args[0], &args[2], nargs - 1);
    }
}
if (!jl_is_tuple_type(jl_unwrap_unionall(argtypes)))
jl_type_error("invoke", (jl_value_t*)jl_anytuple_type_type, argtypes);
Expand Down
24 changes: 22 additions & 2 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,28 @@ static jl_value_t *do_invoke(jl_value_t **args, size_t nargs, interpreter_state
argv[i-1] = eval_value(args[i], s);
jl_value_t *c = args[0];
assert(jl_is_code_instance(c) || jl_is_method_instance(c));
jl_method_instance_t *meth = jl_is_method_instance(c) ? (jl_method_instance_t*)c : ((jl_code_instance_t*)c)->def;
jl_value_t *result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, meth);
jl_value_t *result = NULL;
if (jl_is_code_instance(c)) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)c;
assert(jl_atomic_load_relaxed(&codeinst->min_world) <= jl_current_task->world_age &&
jl_current_task->world_age <= jl_atomic_load_relaxed(&codeinst->max_world));
jl_callptr_t invoke = jl_atomic_load_acquire(&codeinst->invoke);
if (!invoke) {
jl_compile_codeinst(codeinst);
invoke = jl_atomic_load_acquire(&codeinst->invoke);
}
if (invoke) {
result = invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst);

} else {
if (codeinst->owner != jl_nothing) {
jl_error("Failed to invoke or compile external codeinst");
}
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, codeinst->def);
}
} else {
result = jl_invoke(argv[0], nargs == 2 ? NULL : &argv[1], nargs - 2, (jl_method_instance_t*)c);
}
JL_GC_POP();
return result;
}
Expand Down
14 changes: 14 additions & 0 deletions test/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8353,9 +8353,23 @@ end
@test eval(Expr(:toplevel, :(@define_call(f_macro_defined1)))) == 1
@test @define_call(f_macro_defined2) == 1

# `invoke` of `Method`
let m = which(+, (Int, Int))
@eval f56692(i) = invoke(+, $m, i, 4)
global g56692() = f56692(5) == 9 ? "true" : false
end
@test @inferred(f56692(3)) == 7
@test @inferred(g56692()) == "true"

# `invoke` of `CodeInstance`
# `f_invoke_me` calls `f_invalidate_me`, which is redefined further down; the
# statements below are order-dependent.
f_invalidate_me() = return 1
f_invoke_me() = return f_invalidate_me()
# Run `f_invoke_me` once before reading the `.cache` of its specialization
# (presumably so that a CodeInstance exists there — confirm).
@test f_invoke_me() == 1
const f_invoke_me_ci = Base.specialize_method(Base._which(Tuple{typeof(f_invoke_me)})).cache
# Wrapper that performs the same `invoke` from inside a function body.
f_call_me() = invoke(f_invoke_me, f_invoke_me_ci)
@test invoke(f_invoke_me, f_invoke_me_ci) == 1
@test f_call_me() == 1
# An extra argument does not match the CodeInstance's signature.
@test_throws TypeError invoke(f_invoke_me, f_invoke_me_ci, 1)
# After redefining the callee, invoking the captured CodeInstance is expected to
# throw an `ErrorException`, both directly and through the wrapper.
f_invalidate_me() = 2
@test_throws ErrorException invoke(f_invoke_me, f_invoke_me_ci)
@test_throws ErrorException f_call_me()

0 comments on commit efa917e

Please sign in to comment.