diff --git a/base/Base_compiler.jl b/base/Base_compiler.jl index 6014a6b7c9dd0..fba442bd7b3f2 100644 --- a/base/Base_compiler.jl +++ b/base/Base_compiler.jl @@ -253,6 +253,7 @@ include("ordering.jl") using .Order include("coreir.jl") +include("invalidation.jl") # For OS specific stuff # We need to strcat things here, before strings are really defined diff --git a/base/invalidation.jl b/base/invalidation.jl new file mode 100644 index 0000000000000..66ddce49d0c47 --- /dev/null +++ b/base/invalidation.jl @@ -0,0 +1,112 @@ +# This file is a part of Julia. License is MIT: https://julialang.org/license + +struct GlobalRefIterator + mod::Module +end +IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown() +globalrefs(mod::Module) = GlobalRefIterator(mod) + +function iterate(gri::GlobalRefIterator, i = 1) + m = gri.mod + table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m) + i == length(table) && return nothing + b = table[i] + b === nothing && return iterate(gri, i+1) + return ((b::Core.Binding).globalref, i+1) +end + +const TYPE_TYPE_MT = Type.body.name.mt +const NONFUNCTION_MT = Core.MethodTable.name.mt +function foreach_module_mtable(visit, m::Module, world::UInt) + for gb in globalrefs(m) + binding = gb.binding + bpart = lookup_binding_partition(world, binding) + if is_some_const_binding(binding_kind(bpart)) + isdefined(bpart, :restriction) || continue + v = partition_restriction(bpart) + uw = unwrap_unionall(v) + name = gb.name + if isa(uw, DataType) + tn = uw.name + if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt) + # this is the original/primary binding for the type (name/wrapper) + mt = tn.mt + if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT + @assert mt.module === m + visit(mt) || return false + end + end + elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name + # this is the original/primary binding for the submodule + foreach_module_mtable(visit, v, world) || return false + elseif isa(v, Core.MethodTable) && v.module === m && v.name === name + # this is probably an external method table here, so let's + # assume so as there is no way to precisely distinguish them + visit(v) || return false + end + end + end + return true +end + +function foreach_reachable_mtable(visit, world::UInt) + visit(TYPE_TYPE_MT) || return + visit(NONFUNCTION_MT) || return + for mod in loaded_modules_array() + foreach_module_mtable(visit, mod, world) + end +end + +function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo) + found_any = false + labelchangemap = nothing + stmts = src.code + isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name + isgr(g) = false + for i = 1:length(stmts) + stmt = stmts[i] + if isgr(stmt) + found_any = true + continue + end + for ur in Compiler.userefs(stmt) + arg = ur[] + # If any of the GlobalRefs in this stmt match the one that + # we are about, we need to move out all GlobalRefs to preserve + # effect order, in case we later invalidate a different GR + if isa(arg, GlobalRef) + if isgr(arg) + @assert !isa(stmt, PhiNode) + found_any = true + break + end + end + end + end + return found_any +end + +function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt) + valid_in_valuepos = false + foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable + for method in MethodList(mt) + if isdefined(method, :source) + src = _uncompressed_ir(method) + old_stmts = src.code + if should_invalidate_code_for_globalref(gr, src) + for mi in specializations(method) + ci = mi.cache + while true + if ci.max_world > new_max_world + ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) + end + isdefined(ci, :next) || break + ci = ci.next + end + end + end + end + end + return true + end +end diff --git a/src/gf.c b/src/gf.c index bbf065a4fac0d..d736705b71ef9 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1785,6 +1785,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo JL_UNLOCK(&replaced->def->def.method->writelock); } +JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world) +{ + invalidate_code_instance(replaced, max_world, 1); +} + static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) { jl_array_t *backedges = replaced_mi->backedges; if (backedges) { diff --git a/src/module.c b/src/module.c index 38f4b980a72fd..839f4deabfa16 100644 --- a/src/module.c +++ b/src/module.c @@ -1025,6 +1025,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var jl_gc_wb(bpart, val); } +void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world) +{ + static jl_value_t *invalidate_code_for_globalref = NULL; + if (invalidate_code_for_globalref == NULL && jl_base_module != NULL) + invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!")); + if (!invalidate_code_for_globalref) + jl_error("Binding invalidation is not permitted during bootstrap."); + if (jl_generating_output()) + jl_error("Binding invalidation is not permitted during image generation."); + jl_value_t *boxed_world = jl_box_ulong(new_world); + JL_GC_PUSH1(&boxed_world); + jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world); + JL_GC_POP(); +} + extern jl_mutex_t world_counter_lock; JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) { @@ -1039,9 +1054,16 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) JL_LOCK(&world_counter_lock); jl_task_t *ct = jl_current_task; + size_t last_world = ct->world_age; size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter); - // TODO: Trigger invalidation here - (void)ct; + JL_TRY { + ct->world_age = jl_typeinf_world; + jl_invalidate_binding_refs(gr, new_max_world); + } JL_CATCH { + JL_UNLOCK(&world_counter_lock); + jl_rethrow(); + } + ct->world_age = last_world; jl_atomic_store_release(&bpart->max_world, new_max_world); jl_atomic_store_release(&jl_world_counter, new_max_world + 1); JL_UNLOCK(&world_counter_lock); @@ -1327,6 +1349,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod) jl_array_ptr_1d_push(jl_module_init_order, mod); } +JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m) +{ + return jl_atomic_load_relaxed(&m->bindings); +} + JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod) { if (!jl_generating_output() || jl_options.incremental) { diff --git a/test/rebinding.jl b/test/rebinding.jl index c93c34be7a75c..ad0ad1fc1643d 100644 --- a/test/rebinding.jl +++ b/test/rebinding.jl @@ -33,4 +33,11 @@ module Rebinding @test Base.@world(Foo, defined_world_age) == typeof(x) @test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x) @test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x) + + # Test invalidation (const -> undefined) + const delete_me = 1 + f_return_delete_me() = delete_me + @test f_return_delete_me() == 1 + Base.delete_binding(@__MODULE__, :delete_me) + @test_throws UndefVarError f_return_delete_me() end