Skip to content

Commit

Permalink
remove old deferred implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Aug 9, 2024
1 parent 2fa7871 commit b54b5e4
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 135 deletions.
55 changes: 28 additions & 27 deletions examples/jit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,31 @@ function get_trampoline(job)
return addr
end

import GPUCompiler: deferred_codegen_jobs
# Generated function that sets up deferred compilation for a callable of type `F` with
# argument types `tt` in world age `world`, and expands to code returning a `Ptr{Cvoid}`.
#
# At generation time it builds a native `CompilerJob`, asks `get_trampoline` for a
# trampoline for that job, and records the job in `deferred_codegen_jobs` under the
# trampoline pointer reinterpreted as an `Int`. The expression it returns calls the
# `deferred_codegen` extern with the trampoline pointer and returns the result.
@generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
# manual version of native_job because we have a function type
source = methodinstance(F, Base.to_tuple_type(tt), world)
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# XXX: do we actually require the Julia runtime?
# with jlruntime=false, we reach an unreachable.
params = TestCompilerParams()
# kernel=false: the job is configured as a non-kernel compilation
config = CompilerConfig(target, params; kernel=false)
job = CompilerJob(source, config, world)
# XXX: invoking GPUCompiler from a generated function is not allowed!
# for things to work, we need to forward the correct world, at least.

addr = get_trampoline(job)
trampoline = pointer(addr)
# the integer value of the trampoline pointer doubles as the key in the job registry
id = Base.reinterpret(Int, trampoline)

# side effect at generation time: make the job discoverable through the registry
deferred_codegen_jobs[id] = job

quote
ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
# NOTE(review): presumably an optimizer hint that the result is never null — confirm
assume(ptr != C_NULL)
return ptr
end
end
# import GPUCompiler: deferred_codegen_jobs
# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
# # manual version of native_job because we have a function type
# source = methodinstance(F, Base.to_tuple_type(tt), world)
# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# # XXX: do we actually require the Julia runtime?
# # with jlruntime=false, we reach an unreachable.
# params = TestCompilerParams()
# config = CompilerConfig(target, params; kernel=false)
# job = CompilerJob(source, config, world)
# # XXX: invoking GPUCompiler from a generated function is not allowed!
# # for things to work, we need to forward the correct world, at least.

# addr = get_trampoline(job)
# trampoline = pointer(addr)
# id = Base.reinterpret(Int, trampoline)

# deferred_codegen_jobs[id] = job

# quote
# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
# assume(ptr != C_NULL)
# return ptr
# end
# end

@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
argtt = tt.parameters[1]
Expand Down Expand Up @@ -224,8 +224,9 @@ end
# Call `f` on `args` through a pointer obtained from deferred compilation.
#
# The argument tuple type is built from the runtime types of `args`, the return type is
# taken from inference (`Core.Compiler.return_type`), and both are passed to `abi_call`
# together with the pointer, `f` and the arguments.
#
# NOTE(review): `ptr` is assigned twice below and only the second assignment (from
# `GPUCompiler.var"gpuc.deferred"`) reaches `abi_call`. This looks like a diff view that
# shows the replaced lines next to their replacement — confirm against the actual file.
@inline function call_delayed(f::F, args...) where F
# concrete tuple type of the arguments as passed at this call
tt = Tuple{map(Core.Typeof, args)...}
rt = Core.Compiler.return_type(f, tt)
world = GPUCompiler.tls_world_age()
ptr = deferred_codegen(f, Val(tt), Val(world))
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
# But that will only be needed here, and in Enzyme...
ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
abi_call(ptr, rt, tt, f, args...)
end

Expand Down
110 changes: 2 additions & 108 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,37 +43,6 @@ end

# Declares the `gpuc.deferred` function without attaching any methods at this point.
# The `var"..."` form is needed because the name contains a dot.
# NOTE(review): a comment further down in this file says the call is lowered to a
# `gpuc.lookup` foreigncall — confirm where that lowering is implemented.
function var"gpuc.deferred" end

# old, deprecated mechanism slated for removal once Enzyme is updated to the new intrinsic
# Legacy deferred-compilation support: a job registry, a C-callable marker function, and
# a generated function that records a job and emits a call to the marker.
begin
# primitive mechanism for deferred compilation, for implementing CUDA dynamic parallelism.
# this could both be generalized (e.g. supporting actual function calls, instead of
# returning a function pointer), and be integrated with the nonrecursive codegen.
# Registry of deferred jobs keyed by integer id; values are untyped (`Any`).
const deferred_codegen_jobs = Dict{Int, Any}()

# We make this function explicitly callable so that we can drive OrcJIT's
# lazy compilation from it, while also enabling recursive compilation.
# The function itself returns its pointer argument unchanged.
Base.@ccallable Ptr{Cvoid} function deferred_codegen(ptr::Ptr{Cvoid})
ptr
end

# Record a deferred job for function type `ft` and argument types `tt`, and expand to an
# `llvmcall` of the `deferred_codegen` extern that carries the job's id.
@generated function deferred_codegen(::Val{ft}, ::Val{tt}) where {ft,tt}
# the id is derived from the current number of registry entries
id = length(deferred_codegen_jobs) + 1
# only the types are stored (as a named tuple); no method instance is looked up here
deferred_codegen_jobs[id] = (; ft, tt)
# don't bother looking up the method instance, as we'll do so again during codegen
# using the world age of the parent.
#
# this also works around an issue on <1.10, where we don't know the world age of
# generated functions so use the current world counter, which may be too new
# for the world we're compiling for.

quote
# TODO: add an edge to this method instance to support method redefinitions
ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Int,), $id)
end
end
end


## compiler entrypoint

export compile
Expand Down Expand Up @@ -198,7 +167,6 @@ const __llvm_initialized = Ref(false)

# gpuc.deferred is lowered to a gpuc.lookup foreigncall, so we need to extract the
# target method instance from the LLVM IR
# TODO: drive deferred compilation from the Julia IR instead
function find_base_object(val)
while true
if val isa ConstantExpr && (opcode(val) == LLVM.API.LLVMIntToPtr ||
Expand Down Expand Up @@ -263,80 +231,6 @@ const __llvm_initialized = Ref(false)
@compiler_assert isempty(uses(dyn_marker)) job
unsafe_delete!(ir, dyn_marker)
end
## old, deprecated implementation
jobs = Dict{CompilerJob, String}(job => entry_fn)
if toplevel && !only_entry && haskey(functions(ir), "deferred_codegen")
run_optimization_for_deferred = true
dyn_marker = functions(ir)["deferred_codegen"]

# iterative compilation (non-recursive)
changed = true
while changed
changed = false

# find deferred compiler
worklist = Dict{CompilerJob, Vector{LLVM.CallInst}}()
for use in uses(dyn_marker)
# decode the call
call = user(use)::LLVM.CallInst
id = convert(Int, first(operands(call)))

global deferred_codegen_jobs
dyn_val = deferred_codegen_jobs[id]

# get a job in the appropriate world
dyn_job = if dyn_val isa CompilerJob
# trust that the user knows what they're doing
dyn_val
else
ft, tt = dyn_val
dyn_src = methodinstance(ft, tt, tls_world_age())
CompilerJob(dyn_src, job.config)
end

push!(get!(worklist, dyn_job, LLVM.CallInst[]), call)
end

# compile and link
for dyn_job in keys(worklist)
# cached compilation
dyn_entry_fn = get!(jobs, dyn_job) do
dyn_ir, dyn_meta = codegen(:llvm, dyn_job; toplevel=false,
parent_job=job)
dyn_entry_fn = LLVM.name(dyn_meta.entry)
merge!(compiled, dyn_meta.compiled)
@assert context(dyn_ir) == context(ir)
link!(ir, dyn_ir)
changed = true
dyn_entry_fn
end
dyn_entry = functions(ir)[dyn_entry_fn]

# insert a pointer to the function everywhere the entry is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_job]
@dispose builder=IRBuilder() begin
position!(builder, call)
fptr = if LLVM.version() >= v"17"
T_ptr = LLVM.PointerType()
bitcast!(builder, dyn_entry, T_ptr)
elseif VERSION >= v"1.12.0-DEV.225"
T_ptr = LLVM.PointerType(LLVM.Int8Type())
bitcast!(builder, dyn_entry, T_ptr)
else
ptrtoint!(builder, dyn_entry, T_ptr)
end
replace_uses!(call, fptr)
end
unsafe_delete!(LLVM.parent(call), call)
end
end
end

# all deferred compilations should have been resolved
@compiler_assert isempty(uses(dyn_marker)) job
unsafe_delete!(ir, dyn_marker)
end

if libraries
# load the runtime outside of a timing block (because it recurses into the compiler)
Expand Down Expand Up @@ -433,8 +327,8 @@ const __llvm_initialized = Ref(false)
# finish the module
#
# we want to finish the module after optimization, so we cannot do so
# during deferred code generation. instead, process the deferred jobs
# here.
# during deferred code generation. Instead, process the merged module
# from all the jobs here.
if toplevel
entry = finish_ir!(job, ir, entry)

Expand Down

0 comments on commit b54b5e4

Please sign in to comment.