Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add basic code for binding partition revalidation #56649

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ include("ordering.jl")
using .Order

include("coreir.jl")
include("invalidation.jl")

# For OS specific stuff
# We need to strcat things here, before strings are really defined
Expand Down
112 changes: 112 additions & 0 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct GlobalRefIterator
mod::Module
end
IteratorSize(::Type{GlobalRefIterator}) = SizeUnknown()
globalrefs(mod::Module) = GlobalRefIterator(mod)

function iterate(gri::GlobalRefIterator, i = 1)
m = gri.mod
table = ccall(:jl_module_get_bindings, Ref{SimpleVector}, (Any,), m)
i == length(table) && return nothing
b = table[i]
b === nothing && return iterate(gri, i+1)
return ((b::Core.Binding).globalref, i+1)
end

const TYPE_TYPE_MT = Type.body.name.mt
const NONFUNCTION_MT = Core.MethodTable.name.mt
function foreach_module_mtable(visit, m::Module, world::UInt)
for gb in globalrefs(m)
binding = gb.binding
bpart = lookup_binding_partition(world, binding)
if is_some_const_binding(binding_kind(bpart))
isdefined(bpart, :restriction) || continue
v = partition_restriction(bpart)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to CI, apparently this function is unsound currently? (may return some invalid reference, possibly a NULL)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, there's a bit of an awkward state where you have a something declared constant but the value is undefined. Originally that was what the isdefined check on the line before this was for, but on current master, it doesn't work, because restriction is the magic pointer-value union. I'm planning to fix it up by making that case its own binding_kind that's considered a guard kind, not a constant kind, so that constant kinds always have values.

uw = unwrap_unionall(v)
name = gb.name
if isa(uw, DataType)
tn = uw.name
if tn.module === m && tn.name === name && tn.wrapper === v && isdefined(tn, :mt)
# this is the original/primary binding for the type (name/wrapper)
mt = tn.mt
if mt !== nothing && mt !== TYPE_TYPE_MT && mt !== NONFUNCTION_MT
@assert mt.module === m
visit(mt) || return false
end
end
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
# this is the original/primary binding for the submodule
foreach_module_mtable(visit, v, world) || return false
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
# this is probably an external method table here, so let's
# assume so as there is no way to precisely distinguish them
visit(v) || return false
end
end
end
return true
end

function foreach_reachable_mtable(visit, world::UInt)
visit(TYPE_TYPE_MT) || return
visit(NONFUNCTION_MT) || return
for mod in loaded_modules_array()
foreach_module_mtable(visit, mod, world)
end
end

function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
found_any = false
labelchangemap = nothing
stmts = src.code
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
isgr(g) = false
for i = 1:length(stmts)
stmt = stmts[i]
if isgr(stmt)
found_any = true
continue
end
for ur in Compiler.userefs(stmt)
arg = ur[]
# If any of the GlobalRefs in this stmt match the one that
# we are about, we need to move out all GlobalRefs to preserve
# effect order, in case we later invalidate a different GR
if isa(arg, GlobalRef)
if isgr(arg)
@assert !isa(stmt, PhiNode)
found_any = true
break
end
end
end
end
return found_any
end

function invalidate_code_for_globalref!(gr::GlobalRef, new_max_world::UInt)
valid_in_valuepos = false
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
for method in MethodList(mt)
if isdefined(method, :source)
src = _uncompressed_ir(method)
old_stmts = src.code
if should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
ci = mi.cache
while true
if ci.max_world > new_max_world
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end
end
end
return true
end
end
5 changes: 5 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1785,6 +1785,11 @@ static void invalidate_code_instance(jl_code_instance_t *replaced, size_t max_wo
JL_UNLOCK(&replaced->def->def.method->writelock);
}

JL_DLLEXPORT void jl_invalidate_code_instance(jl_code_instance_t *replaced, size_t max_world)
{
invalidate_code_instance(replaced, max_world, 1);
}

static void _invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_world, int depth) {
jl_array_t *backedges = replaced_mi->backedges;
if (backedges) {
Expand Down
31 changes: 29 additions & 2 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,21 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var
jl_gc_wb(bpart, val);
}

void jl_invalidate_binding_refs(jl_globalref_t *ref, size_t new_world)
{
static jl_value_t *invalidate_code_for_globalref = NULL;
if (invalidate_code_for_globalref == NULL && jl_base_module != NULL)
invalidate_code_for_globalref = jl_get_global(jl_base_module, jl_symbol("invalidate_code_for_globalref!"));
if (!invalidate_code_for_globalref)
jl_error("Binding invalidation is not permitted during bootstrap.");
if (jl_generating_output())
jl_error("Binding invalidation is not permitted during image generation.");
jl_value_t *boxed_world = jl_box_ulong(new_world);
JL_GC_PUSH1(&boxed_world);
jl_call2((jl_function_t*)invalidate_code_for_globalref, (jl_value_t*)ref, boxed_world);
JL_GC_POP();
}

extern jl_mutex_t world_counter_lock;
JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
{
Expand All @@ -1039,9 +1054,16 @@ JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)

JL_LOCK(&world_counter_lock);
jl_task_t *ct = jl_current_task;
size_t last_world = ct->world_age;
size_t new_max_world = jl_atomic_load_acquire(&jl_world_counter);
// TODO: Trigger invalidation here
(void)ct;
JL_TRY {
ct->world_age = jl_typeinf_world;
jl_invalidate_binding_refs(gr, new_max_world);
} JL_CATCH {
JL_UNLOCK(&world_counter_lock);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are forbidden from unlocking a lock that you didn't acquire in the same scope

Suggested change
JL_UNLOCK(&world_counter_lock);

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so the try/catch will automatically unlock it?

jl_rethrow();
}
ct->world_age = last_world;
jl_atomic_store_release(&bpart->max_world, new_max_world);
jl_atomic_store_release(&jl_world_counter, new_max_world + 1);
JL_UNLOCK(&world_counter_lock);
Expand Down Expand Up @@ -1327,6 +1349,11 @@ JL_DLLEXPORT void jl_add_to_module_init_list(jl_value_t *mod)
jl_array_ptr_1d_push(jl_module_init_order, mod);
}

JL_DLLEXPORT jl_svec_t *jl_module_get_bindings(jl_module_t *m)
{
return jl_atomic_load_relaxed(&m->bindings);
}

JL_DLLEXPORT void jl_init_restored_module(jl_value_t *mod)
{
if (!jl_generating_output() || jl_options.incremental) {
Expand Down
7 changes: 7 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,11 @@ module Rebinding
@test Base.@world(Foo, defined_world_age) == typeof(x)
@test Base.@world(Rebinding.Foo, defined_world_age) == typeof(x)
@test Base.@world((@__MODULE__).Foo, defined_world_age) == typeof(x)

# Test invalidation (const -> undefined)
const delete_me = 1
f_return_delete_me() = delete_me
@test f_return_delete_me() == 1
Base.delete_binding(@__MODULE__, :delete_me)
@test_throws UndefVarError f_return_delete_me()
end
Loading