From ae19646a6ee0fc450884da0f7764383cb26f12d3 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Fri, 22 Nov 2024 07:22:52 +0000 Subject: [PATCH] Add basic code for binding partition revalidation This adds the binding partition revalidation code from #54654. This is the last piece of that PR that hasn't been merged yet - however the TODO in that PR still stands for future work. This PR itself adds a callback that gets triggered by deleting a binding. It will then walk all code in the system and invalidate code instances of Methods whose lowered source referenced the given global. This walk is quite slow. Future work will add backedges and optimizations to make this faster, but the basic functionality should be in place with this PR. --- base/Base_compiler.jl | 1 + base/invalidation.jl | 112 ++++++++++++++++++++++++++++++++++++++++++ src/gf.c | 5 ++ src/module.c | 31 +++++++++++- test/rebinding.jl | 7 +++ 5 files changed, 154 insertions(+), 2 deletions(-) create mode 100644 base/invalidation.jl diff --git a/base/Base_compiler.jl b/base/Base_compiler.jl index 6014a6b7c9dd0..fba442bd7b3f2 100644 --- a/base/Base_compiler.jl +++ b/base/Base_compiler.jl @@ -253,6 +253,7 @@ include("ordering.jl") using .Order include("coreir.jl") +include("invalidation.jl") # For OS specific stuff # We need to strcat things here, before strings are really defined diff --git a/base/invalidation.jl b/base/invalidation.jl new file mode 100644 index 0000000000000..66ddce49d0c47 --- /dev/null +++ b/base/invalidation.jl @@ -0,0 +1,112 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +struct GlobalRefIterator + mod::Module +end +IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown() +globalrefs(mod::Module) = GlobalRefIterator(mod) + +function iterate(gri::GlobalRefIterator, i = 1) + m = gri.mod + table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m) + i == length(table) && return nothing + b = table[i] + b === nothing && return iterate(gri, i+1) + return ((b::Core.Binding).globalref, i+1) +end + +const TYPE_TYPE_MT = Type.body.name.mt +const NONFUNCTION_MT = Core.MethodTable.name.mt +function foreach_module_mtable(visit, m::Module, world::UInt) + for gb in globalrefs(m) + binding = gb.binding + bpart = lookup_binding_partition(world, binding) + if is_some_const_binding(binding_kind(bpart)) + isdefined(bpart, :restriction) || continue + v = partition_restriction(bpart) + uw = unwrap_unionall(v) + name = gb.name + if isa(uw, DataType) + tn = uw.name + if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt) + # this is the original/primary binding for the type (name/wrapper) + mt = tn.mt + if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT + @assert mt.module === m + visit(mt) || return false + end + end + elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name + # this is the original/primary binding for the submodule + foreach_module_mtable(visit, v, world) || return false + elseif isa(v, Core.MethodTable) && v.module === m && v.name === name + # this is probably an external method table here, so let's + # assume so as there is no way to precisely distinguish them + visit(v) || return false + end + end + end + return true +end + +function foreach_reachable_mtable(visit, world::UInt) + visit(TYPE_TYPE_MT) || return + visit(NONFUNCTION_MT) || return + for mod in loaded_modules_array() + foreach_module_mtable(visit, mod, world) + end +end + +function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo) + found_any = false + labelchangemap = nothing + stmts = src.code + isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name + isgr(g) = false + for i = 1:length(stmts) + stmt = stmts[i] + if isgr(stmt) + found_any = true + continue + end + for ur in Compiler.userefs(stmt) + arg = ur[] + # If any of the GlobalRefs in this stmt match the one that + # we are about, we need to move out all GlobalRefs to preserve + # effect order, in case we later invalidate a different GR + if isa(arg, GlobalRef) + if isgr(arg) + @assert !isa(stmt, PhiNode) + found_any = true + break + end + end + end + end + return found_any +end + +function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt) + valid_in_valuepos = false + foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable + for method in MethodList(mt) + if isdefined(method, :source) + src = _uncompressed_ir(method) + old_stmts = src.code + if should_invalidate_code_for_globalref(gr, src) + for mi in specializations(method) + ci = mi.cache + while true + if ci.max_world > new_max_world + ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) + end + isdefined(ci, :next) || break + ci = ci.next + end + end + end + end + end + return true + end +end diff --git a/src/gf.c b/src/gf.c index bbf065a4fac0d..d736705b71ef9 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1785,6 +1785,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo JL_UNLOCK(&replaced->def->def.method->writelock); } +JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world) +{ + invalidate_code_instance(replaced, max_world, 1); +} + static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) { jl_array_t *backedges = replaced_mi->backedges; if (backedges) { diff --git a/src/module.c b/src/module.c index 38f4b980a72fd..839f4deabfa16 100644 --- a/src/module.c +++ b/src/module.c @@ -1025,6 +1025,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var jl_gc_wb(bpart, val); } +void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world) +{ + static jl_value_t *invalidate_code_for_globalref = NULL; + if (invalidate_code_for_globalref == NULL && jl_base_module != NULL) + invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!")); + if (!invalidate_code_for_globalref) + jl_error("Binding invalidation is not permitted during bootstrap."); + if (jl_generating_output()) + jl_error("Binding invalidation is not permitted during image generation."); + jl_value_t *boxed_world = jl_box_ulong(new_world); + JL_GC_PUSH1(&boxed_world); + jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world); + JL_GC_POP(); +} + extern jl_mutex_t world_counter_lock; JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) { @@ -1039,9 +1054,16 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) JL_LOCK(&world_counter_lock); jl_task_t *ct = jl_current_task; + size_t last_world = ct->world_age; size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter); - // TODO: Trigger invalidation here - (void)ct; + JL_TRY { + ct->world_age = jl_typeinf_world; + jl_invalidate_binding_refs(gr, new_max_world); + } JL_CATCH { + JL_UNLOCK(&world_counter_lock); + jl_rethrow(); + } + ct->world_age = last_world; jl_atomic_store_release(&bpart->max_world, new_max_world); jl_atomic_store_release(&jl_world_counter, new_max_world + 1); JL_UNLOCK(&world_counter_lock); @@ -1327,6 +1349,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod) jl_array_ptr_1d_push(jl_module_init_order, mod); } +JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m) +{ + return jl_atomic_load_relaxed(&m->bindings); +} + JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod) { if (!jl_generating_output() || jl_options.incremental) { diff --git a/test/rebinding.jl b/test/rebinding.jl index c93c34be7a75c..ad0ad1fc1643d 100644 --- a/test/rebinding.jl +++ b/test/rebinding.jl @@ -33,4 +33,11 @@ module Rebinding @test Base.@world(Foo, defined_world_age) == typeof(x) @test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x) @test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x) + + # Test invalidation (const -> undefined) + const delete_me = 1 + f_return_delete_me() = delete_me + @test f_return_delete_me() == 1 + Base.delete_binding(@__MODULE__, :delete_me) + @test_throws UndefVarError f_return_delete_me() end