Add optimization callbacks that fire on a marker function #621

Closed · wants to merge 7 commits
58 changes: 36 additions & 22 deletions examples/jit.jl
@@ -116,30 +116,38 @@ function get_trampoline(job)
return addr
end

import GPUCompiler: deferred_codegen_jobs
@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
# manual version of native_job because we have a function type
source = methodinstance(F, Base.to_tuple_type(tt), world)
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# XXX: do we actually require the Julia runtime?
# with jlruntime=false, we reach an unreachable.
params = TestCompilerParams()
config = CompilerConfig(target, params; kernel=false)
job = CompilerJob(source, config, world)
# XXX: invoking GPUCompiler from a generated function is not allowed!
# for things to work, we need to forward the correct world, at least.
const runtime_cache = Dict{Any, Ptr{Cvoid}}()

function compiler(job)
JuliaContext() do _
ir, meta = GPUCompiler.compile(:llvm, job; validate=false)
# 1. serialize the module
buf = convert(MemoryBuffer, ir)
buf, LLVM.name(meta.entry)
end
end

addr = get_trampoline(job)
trampoline = pointer(addr)
id = Base.reinterpret(Int, trampoline)
function linker(_, (buf, entry_fn))
compiler = jit[]
lljit = compiler.jit
jd = JITDylib(lljit)

deferred_codegen_jobs[id] = job
# 2. deserialize and wrap it in a ThreadSafeModule
ThreadSafeContext() do ts_ctx
tsm = context!(context(ts_ctx)) do
mod = parse(LLVM.Module, buf)
ThreadSafeModule(mod)
end

quote
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
assume(ptr != C_NULL)
return ptr
LLVM.add!(lljit, jd, tsm)
end
addr = LLVM.lookup(lljit, entry_fn)
pointer(addr)
end

function GPUCompiler.var"gpuc.deferred.with"(config::GPUCompiler.CompilerConfig{<:NativeCompilerTarget}, f::F, args...) where F
source = methodinstance(F, Base.to_tuple_type(typeof(args)))
GPUCompiler.cached_compilation(runtime_cache, source, config, compiler, linker)::Ptr{Cvoid}
end

@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
@@ -224,8 +232,14 @@ end
@inline function call_delayed(f::F, args...) where F
tt = Tuple{map(Core.Typeof, args)...}
rt = Core.Compiler.return_type(f, tt)
world = GPUCompiler.tls_world_age()
ptr = deferred_codegen(f, Val(tt), Val(world))
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
# But that will only be needed here, and in Enzyme...
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# XXX: do we actually require the Julia runtime?
# with jlruntime=false, we reach an unreachable.
params = TestCompilerParams()
config = CompilerConfig(target, params; kernel=false)
ptr = GPUCompiler.var"gpuc.deferred.with"(config, f, args...)
abi_call(ptr, rt, tt, f, args...)
end

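For context, a minimal sketch of how the reworked example is meant to be driven (hypothetical usage; `call_delayed` comes from the hunk above, and the test values are illustrative):

# Hypothetical driver code: `call_delayed` compiles `square` lazily through
# the ORC JIT and invokes it via the returned function pointer and the
# `abi_call` shim defined in this file.
square(x) = x * x
@assert call_delayed(square, 3.0) === 9.0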
178 changes: 82 additions & 96 deletions src/driver.jl
@@ -39,6 +39,22 @@ function JuliaContext(f; kwargs...)
end


## deferred compilation

"""
var"gpuc.deferred"(f, args...)::Ptr{Cvoid}

As if we were to call `f(args...)`, but instead puts down a marker and
returns a function pointer that can be called later.
"""
function var"gpuc.deferred" end

"""
var"gpuc.deferred.with"(config::CompilerConfig, f, args...)::Ptr{Cvoid}

Like `var"gpuc.deferred"`, but uses the given compiler `config` to drive the
deferred compilation.
"""
function var"gpuc.deferred.with" end

## compiler entrypoint

export compile
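To make the markers' contract concrete, a sketch under the semantics described in the docstrings above (`lazy_fptr` is an illustrative name, not part of this PR):

# Note: the marker has no regular method body; calling it is only meaningful
# inside code that is itself compiled through GPUCompiler, where codegen
# lowers it to a `gpuc.lookup` call and resolves it to a function pointer.
function lazy_fptr(f::F, args...) where {F}
    ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
    # `ptr::Ptr{Cvoid}` points at the compiled specialization of `f` for
    # these argument types; invoking it requires an ABI-aware shim such as
    # `abi_call` from examples/jit.jl.
    return ptr
end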
@@ -127,33 +143,6 @@ function codegen(output::Symbol, @nospecialize(job::CompilerJob); toplevel::Bool
error("Unknown compilation output $output")
end

# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
# this could both be generalized (e.g. supporting actual function calls, instead of
# returning a function pointer), and be integrated with the nonrecursive codegen.
const deferred_codegen_jobs = Dict{Int, Any}()

# We make this function explicitly callable so that we can drive OrcJIT's
# lazy compilation from it, while also enabling recursive compilation.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
ptr
end

@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
id = length(deferred_codegen_jobs) + 1
deferred_codegen_jobs[id] = (; ft, tt)
# don't bother looking up the method instance, as we'll do so again during codegen
# using the world age of the parent.
#
# this also works around an issue on <1.10, where we don't know the world age of
# generated functions so use the current world counter, which may be too new
# for the world we're compiling for.

quote
# TODO: add an edge to this method instance to support method redefinitions
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
end
end

const __llvm_initialized = Ref(false)

@locked function emit_llvm(@nospecialize(job::CompilerJob); toplevel::Bool,
@@ -183,73 +172,70 @@ const __llvm_initialized = Ref(false)
entry = finish_module!(job, ir, entry)

# deferred code generation
has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
jobs = Dict{CompilerJob, String}(job => entry_fn)
if has_deferred_jobs
dyn_marker = functions(ir)["deferred_codegen"]

# iterative compilation (non-recursive)
changed = true
while changed
changed = false

# find deferred compiler
# TODO: recover this information earlier, from the Julia IR
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
id = convert(Int, first(operands(call)))

global deferred_codegen_jobs
dyn_val = deferred_codegen_jobs[id]

# get a job in the appropriate world
dyn_job = if dyn_val isa CompilerJob
# trust that the user knows what they're doing
dyn_val
run_optimization_for_deferred = false
if haskey(functions(ir), "gpuc.lookup")
run_optimization_for_deferred = true
dyn_marker = functions(ir)["gpuc.lookup"]

# gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
# target method instance from the LLVM IR
function find_base_object(val)
while true
if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
opcode(val) == LLVM.API.LLVMBitCast ||
opcode(val) == LLVM.API.LLVMAddrSpaceCast)
val = first(operands(val))
elseif val isa LLVM.IntToPtrInst ||
val isa LLVM.BitCastInst ||
val isa LLVM.AddrSpaceCastInst
val = first(operands(val))
elseif val isa LLVM.LoadInst
# In 1.11+ we no longer embed integer constants directly.
gv = first(operands(val))
if gv isa LLVM.GlobalValue
val = LLVM.initializer(gv)
continue
end
break
else
ft, tt = dyn_val
dyn_src = methodinstance(ft, tt, tls_world_age())
CompilerJob(dyn_src, job.config)
break
end

push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
end
return val
end

# compile and link
for dyn_job in keys(worklist)
# cached compilation
dyn_entry_fn = get!(jobs, dyn_job) do
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
parent_job=job)
dyn_entry_fn = LLVM.name(dyn_meta.entry)
merge!(compiled, dyn_meta.compiled)
@assert context(dyn_ir) == context(ir)
link!(ir, dyn_ir)
changed = true
dyn_entry_fn
end
dyn_entry = functions(ir)[dyn_entry_fn]

# insert a pointer to the function everywhere the entry is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_job]
@dispose builder=IRBuilder() begin
position!(builder, call)
fptr = if LLVM.version() >= v"17"
T_ptr = LLVM.PointerType()
bitcast!(builder, dyn_entry, T_ptr)
elseif VERSION >= v"1.12.0-DEV.225"
T_ptr = LLVM.PointerType(LLVM.Int8Type())
bitcast!(builder, dyn_entry, T_ptr)
else
ptrtoint!(builder, dyn_entry, T_ptr)
end
replace_uses!(call, fptr)
worklist = Dict{Any, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
dyn_mi_inst = find_base_object(operands(call)[1])
@compiler_assert isa(dyn_mi_inst, LLVM.ConstantInt) job
dyn_mi = Base.unsafe_pointer_to_objref(
convert(Ptr{Cvoid}, convert(Int, dyn_mi_inst)))
push!(get!(worklist, dyn_mi, LLVM.CallInst[]), call)
end

for dyn_mi in keys(worklist)
dyn_fn_name = compiled[dyn_mi].specfunc
dyn_fn = functions(ir)[dyn_fn_name]

# insert a pointer to the function everywhere the entry is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_mi]
@dispose builder=IRBuilder() begin
position!(builder, call)
fptr = if LLVM.version() >= v"17"
T_ptr = LLVM.PointerType()
bitcast!(builder, dyn_fn, T_ptr)
elseif VERSION >= v"1.12.0-DEV.225"
T_ptr = LLVM.PointerType(LLVM.Int8Type())
bitcast!(builder, dyn_fn, T_ptr)
else
ptrtoint!(builder, dyn_fn, T_ptr)
end
unsafe_delete!(LLVM.parent(call), call)
replace_uses!(call, fptr)
end
unsafe_delete!(LLVM.parent(call), call)
end
end

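For intuition, the shape of IR this loop rewrites, sketched in comments (the constant and symbol names are illustrative):

# Before rewriting, each use of the marker looks roughly like
#     %ptr = call ptr @gpuc.lookup(ptr inttoptr (i64 … to ptr), …)
# where the integer constant encodes the target MethodInstance (on 1.11+
# it is loaded from a global instead, hence `find_base_object`).
# After rewriting, every such call is replaced by a cast of the compiled
# specialization, e.g. a ptrtoint/bitcast of @julia_f_1234, and the
# `gpuc.lookup` call itself is deleted.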
@@ -285,7 +271,7 @@ const __llvm_initialized = Ref(false)
# global variables. this makes sure that the optimizer can, e.g.,
# rewrite function signatures.
if toplevel
preserved_gvs = collect(values(jobs))
preserved_gvs = [entry_fn]
for gvar in globals(ir)
if linkage(gvar) == LLVM.API.LLVMExternalLinkage
push!(preserved_gvs, LLVM.name(gvar))
Expand Down Expand Up @@ -317,7 +303,7 @@ const __llvm_initialized = Ref(false)
# deferred codegen has some special optimization requirements,
# which also need to happen _after_ regular optimization.
# XXX: make these part of the optimizer pipeline?
if has_deferred_jobs
if run_optimization_for_deferred
@dispose pb=NewPMPassBuilder() begin
add!(pb, NewPMFunctionPassManager()) do fpm
add!(fpm, InstCombinePass())
Expand Down Expand Up @@ -353,15 +339,15 @@ const __llvm_initialized = Ref(false)
# finish the module
#
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. instead, process the deferred jobs
# here.
# during deferred code generation. Instead, process the merged module
# from all the jobs here.
if toplevel
entry = finish_ir!(job, ir, entry)

for (job′, fn′) in jobs
job′ == job && continue
finish_ir!(job′, ir, functions(ir)[fn′])
end
# for (job′, fn′) in jobs
# job′ == job && continue
# finish_ir!(job′, ir, functions(ir)[fn′])
# end
end

# replace non-entry function definitions with a declaration
11 changes: 11 additions & 0 deletions src/irgen.jl
@@ -80,6 +80,17 @@ function irgen(@nospecialize(job::CompilerJob))
compiled[job.source] =
(; compiled[job.source].ci, func, specfunc)

# Earlier we sanitized global names, which invalidates the func/specfunc
# names saved in `compiled`. Update the names now, so that when we use the
# compiled mappings to look up the LLVM function for a method instance
# (deferred codegen), we have valid targets.
for mi in keys(compiled)
mi == job.source && continue
ci, func, specfunc = compiled[mi]
compiled[mi] = (; ci, func=safe_name(func), specfunc=safe_name(specfunc))
end

# minimal required optimization
@timeit_debug to "rewrite" begin
if job.config.kernel && needs_byval(job)
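A short sketch of the failure mode the hunk above guards against (assuming `safe_name` rewrites characters that are invalid in the target's symbol names; the symbol names below are made up):

# Without the update, `compiled` could still record a pre-sanitization name,
#     compiled[mi].specfunc == "julia_#inner#42_1234"
# while the module only contains the sanitized symbol
#     "julia__inner_42_1234"
# so the deferred-codegen lookup `functions(ir)[dyn_fn_name]` in the driver
# would throw a KeyError. Re-sanitizing the saved names keeps them in sync.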