Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify and fix intermittent names import #60

Merged
merged 7 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/frompackage/code_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function modify_package_using!(ex::Expr, loc, package_dict::Dict, eval_module::M
extracted_package_name = first(package_expr_args)
if extracted_package_name === package_name
# We modify the specific using expression to point to the correct module path
prepend!(package_expr_args, modname_path(fromparent_module[]))
prepend!(package_expr_args, temp_module_path())
end
end
return true
Expand Down
94 changes: 51 additions & 43 deletions src/frompackage/helpers.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import ..PlutoDevMacros: hide_this_log

function get_temp_module()
@assert isassigned(fromparent_module) "You have to assing the parent module by calling `maybe_create_module` with a Pluto workspace module as input before you can use `get_temp_module`"
fromparent_module[]
isdefined(Main, TEMP_MODULE_NAME) || return nothing
return getproperty(Main, TEMP_MODULE_NAME)::Module
end

# Extract the module that is the target in dict
get_target_module(dict) = get_target_module(Symbol(dict["name"]))
get_target_module(mod_name::Symbol) = getfield(get_temp_module(), mod_name)
get_target_module(dict) = dict["Created Module"]

function get_target_uuid(dict)
uuid = get(dict, "uuid", nothing)
Expand Down Expand Up @@ -118,7 +117,7 @@ function get_package_data(packagepath::AbstractString)
return package_data
end

## getfirst
# Get the first element in itr that satisfies predicate p, or nothing if itr is empty or no elements satisfy p
function getfirst(p, itr)
for el in itr
p(el) && return el
Expand All @@ -127,45 +126,48 @@ function getfirst(p, itr)
end
getfirst(itr) = getfirst(x -> true, itr)

## filterednames
function filterednames(m::Module, caller_module = nothing; all = true, imported = true, explicit_names = nothing, package_dict = nothing)
## Similar to names but allows to exclude names and add explicit ones. It also filter names based on whether they are defined already in the caller module
function filterednames(m::Module; all = true, imported = true, explicit_names = Set{Symbol}(), caller_module::Module)
excluded = (:eval, :include, :_fromparent_dict_, Symbol("@bind"))
mod_names = names(m;all, imported)
filter_args = if explicit_names isa Set{Symbol}
for name in mod_names
push!(explicit_names, name)
end
collect(explicit_names)
else
mod_names
end
filter_func = filterednames_filter_func(m; excluded, caller_module, package_dict)
mod_names = names(m; all, imported)
filter_args = union(mod_names, explicit_names)
filter_func = filterednames_filter_func(;excluded, caller_module)
filter(filter_func, filter_args)
end

function filterednames_filter_func(m; excluded, caller_module, package_dict)
f(s) = let excluded = excluded, caller_module = caller_module, package_dict = package_dict
function has_ancestor_module(target::Module, ancestor_name::Symbol; previous = nothing, only_rootmodule = false)
has_ancestor_module(target, (ancestor_name,); previous, only_rootmodule)
end
function has_ancestor_module(target::Module, ancestor_names; previous = nothing, only_rootmodule = false)
nm = nameof(target)
ancestor_found = nm in ancestor_names
!only_rootmodule && ancestor_found && return true # Ancestor found, and no check on only_rootmodule
nm === previous && return ancestor_found # The target is the same as previous, so we reached a top-level module. We return whether the ancestor was found and is a parent of itself
return has_ancestor_module(parentmodule(target), ancestor_names; previous = nm, only_rootmodule)
end

# This returns two flags: whether the name can be included and whether a warning should be generated
function can_import_in_caller(name::Symbol, caller::Module)
isdefined(caller, name) || return true, false # If is not defined we can surely import it
owner = which(caller, name)
# Skip (and do not warn) for things defined in Base or Core
invalid_ancestor = has_ancestor_module(owner, (:Base, :Core, :Markdown, :InteractiveUtils))
invalid_ancestor && return false, false
# We check if the name is inside the list of symbols imported by the previous module
in_previous = name in PREVIOUS_CATCHALL_NAMES
return in_previous, !in_previous
end

function filterednames_filter_func(;excluded, caller_module)
f(s) = let excluded = excluded, caller = caller_module
Base.isgensym(s) && return false
s in excluded && return false
if caller_module isa Module
previous_target_module = get_stored_module(package_dict)
# We check and avoid putting in scope symbols which are already in the caller module
isdefined(caller_module, s) || return true
# Here we have to extract the symbols to compare them
mod_val = getfield(m, s)
caller_val = getfield(caller_module, s)
if caller_val !== mod_val
if isdefined(previous_target_module, s) && caller_val === getfield(previous_target_module, s)
# We are just replacing the previous implementation of this call's target package, so we want to overwrite
return true
else
@warn "Symbol `:$s`, is already defined in the caller module and points to a different object. Skipping"
end
end
return false
else # We don't check for names clashes with a caller module
return true
should_include, should_warn = can_import_in_caller(s, caller)
if should_warn
owner = which(caller, s)
@warn "The name `$s`, defined in $owner, is already present in the caller module and will not be imported."
end
return should_include
end
return f
end
Expand Down Expand Up @@ -280,7 +282,7 @@ end
# This relies on Base internals (and even the C API) but will allow make the loaded module behave more like if we simply did `using TargetPackage` in the REPL
function register_target_module_as_root(package_dict)
name_str = package_dict["name"]
m = get_target_module(Symbol(name_str))
m = get_target_module(package_dict)
id = get_target_pkgid(package_dict)
uuid = id.uuid
entry_point = package_dict["file"]
Expand Down Expand Up @@ -317,12 +319,18 @@ function try_load_extensions(package_dict::Dict)
end

# This function will get the module stored in the created_modules dict based on the entry point
get_stored_module(package_dict) = get_stored_module(package_dict["uuid"])
get_stored_module(key::String) = get(created_modules, key, nothing)
get_stored_module() = STORED_MODULE[]
# This will store in it
update_stored_module(key::String, m::Module) = created_modules[key] = m
update_stored_module(m::Module) = STORED_MODULE[] = m
function update_stored_module(package_dict::Dict)
uuid = package_dict["uuid"]
m = get_target_module(package_dict)
update_stored_module(uuid, m)
update_stored_module(m)
end

overwrite_imported_symbols(package_dict::Dict) = overwrite_imported_symbols(get(Set{Symbol}, package_dict, "Catchall Imported Symbols"))
# This overwrites the PREVIOUSLY_IMPORTED_SYMBOLS with the contents of new_symbols
function overwrite_imported_symbols(new_symbols)
empty!(PREVIOUS_CATCHALL_NAMES)
union!(PREVIOUS_CATCHALL_NAMES, new_symbols)
nothing
end
31 changes: 10 additions & 21 deletions src/frompackage/input_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,20 +145,8 @@ function import_type(args, dict)
error("The provided import expression is not supported, please look at @frompackage documentation to see the supported imports")
end

# Get the full path of the module as array of Symbols starting from Main
function modname_path(m::Module)
args = [nameof(m)]
m_old = m
_m = parentmodule(m)
while _m !== m_old && _m !== Main
m_old = _m
pushfirst!(args, nameof(_m))
_m = parentmodule(_m)
end
_m === Main || error("modname_path did not reach Main, this is not expected")
pushfirst!(args, nameof(_m))
return args
end
# This is the path of the temp module which will be prepended to the package name
temp_module_path() = (:Main, TEMP_MODULE_NAME)

## process outside pluto
# We parse all the expressions in the block provided as input to @fromparent
Expand Down Expand Up @@ -259,12 +247,12 @@ function process_imported_nameargs!(args, dict)
end
## Per-type versions
function process_imported_nameargs!(args, dict, t::FromPackageImport)
name_init = modname_path(fromparent_module[])
name_init = temp_module_path()
args[1] = t.mod_name
prepend!(args, name_init)
end
function process_imported_nameargs!(args, dict, t::Union{FromParentImport, RelativeImport})
name_init = modname_path(fromparent_module[])
name_init = temp_module_path()
# Here transform the relative module name to the one based on the full loaded module path
target_path = get(dict, "Target Path", []) |> reverse
isempty(target_path) && error("The current file was not found included in the loaded module $(t.mod_name), so you can't use relative path imports")
Expand All @@ -285,7 +273,7 @@ function process_imported_nameargs!(args, dict, t::FromDepsImport)
maybe_add_loaded_module(t.id)
deps_module_name = :_LoadedModules_
args[1] = deps_module_name # We replace `>` with _LoadedModules_
name_init = modname_path(fromparent_module[])
name_init = temp_module_path()
prepend!(args, name_init)
return nothing
end
Expand Down Expand Up @@ -317,7 +305,7 @@ function should_include_using_names!(ex)
end

## parseinput
function parseinput(ex, package_dict; caller_module = nothing)
function parseinput(ex, package_dict; caller_module)
include_using = should_include_using_names!(ex)
# We get the module
modname_expr, importednames_exprs = extract_import_args(ex)
Expand Down Expand Up @@ -346,10 +334,11 @@ function parseinput(ex, package_dict; caller_module = nothing)
explicit_names = if include_using
package_dict["Using Names"].explicit_names
else
nothing
Set{Symbol}()
end
# We extract the imported names either due to catchall or due to the standard using
imported_names = filterednames(_mod, caller_module; all = catchall, imported = catchall, explicit_names, package_dict)
imported_names = filterednames(_mod; all = catchall, imported = catchall, explicit_names, caller_module)
# We add the imported names to the set for tracking names imported by this macrocall
union!(get!(Set{Symbol}, package_dict, "Catchall Imported Symbols"), imported_names)
# At this point we have all the names and we just have to create the final expression
importednames_exprs = map(n -> Expr(:., n), imported_names)
return reconstruct_import_expr(modname_expr, importednames_exprs)
Expand Down
33 changes: 20 additions & 13 deletions src/frompackage/loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,17 +124,16 @@ function eval_module_expr(parent_module, ex, dict)
return out isa StopEval ? out : nothing
end

function maybe_create_module(m::Module)
if !isassigned(fromparent_module)
fromparent_m = Core.eval(m, :(module $(gensym(:frompackage))
end))
# We create the dummy module where all the direct dependencies will be loaded
Core.eval(fromparent_m, :(module _DirectDeps_ end))
# We also set a reference to LoadedModules for access from the notebook
Core.eval(fromparent_m, :(const _LoadedModules_ = $LoadedModules))
fromparent_module[] = fromparent_m
end
return fromparent_module[]
function maybe_create_module()
m = get_temp_module()
isnothing(m) || return m
fromparent_m = Core.eval(Main, :(module $TEMP_MODULE_NAME
end))
# We create the dummy module where all the direct dependencies will be loaded
Core.eval(fromparent_m, :(module _DirectDeps_ end))
# We also set a reference to LoadedModules for access from the notebook
Core.eval(fromparent_m, :(const _LoadedModules_ = $LoadedModules))
return fromparent_m
end

# This will explicitly import each direct dependency of the package inside the LoadedModules module. Loading all of the direct dependencies will help make every dependency available even if not directly loaded in the target source code.
Expand Down Expand Up @@ -168,9 +167,15 @@ function load_module_in_caller(mod_exp::Expr, package_dict::Dict, caller_module)
target_file = package_dict["target"]
ecg = default_ecg()
# If the module Reference inside fromparent_module is not assigned, we create the module in the calling workspace and assign it
_MODULE_ = maybe_create_module(caller_module)
_MODULE_ = maybe_create_module()
# We reset the module path in case it was not cleaned
mod_name = mod_exp.args[2]
# We reset the list of symbols if we loaded a different module
stored_module = get_stored_module()
if !isnothing(stored_module) && nameof(stored_module) !== mod_name
# We reset the list of previous symbols
empty!(PREVIOUS_CATCHALL_NAMES)
end
# We inject the project in the LOAD_PATH if it is not present already
add_loadpath(ecg; should_prepend = Settings.get_setting(package_dict, :SHOULD_PREPEND_LOAD_PATH))
# We start by loading each of the direct dependencies in the LoadedModules submodule
Expand All @@ -184,7 +189,9 @@ function load_module_in_caller(mod_exp::Expr, package_dict::Dict, caller_module)
rethrow(e)
end
# Get the moduleof the parent package
__module = getfield(_MODULE_, mod_name)
__module = getproperty(_MODULE_, mod_name)::Module
# We put the module in the dict
package_dict["Created Module"] = __module
# We put the dict inside the loaded module
Core.eval(__module, :(_fromparent_dict_ = $package_dict))
# Register this module as root module.
Expand Down
2 changes: 2 additions & 0 deletions src/frompackage/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ function frompackage(ex, target_file, caller, caller_module; macroname)
end
# We update th stored root module
update_stored_module(package_dict)
# We put the included names in PREVIOUS_CATCHALL_NAMES
overwrite_imported_symbols(package_dict)
# We call at runtime the function to trigger extensions loading
push!(args, :($try_load_extensions($package_dict)))
# We wrap the import expressions inside a try-catch, as those also correctly work from there.
Expand Down
4 changes: 3 additions & 1 deletion src/frompackage/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ const _stdlibs = first.(values(Pkg.Types.stdlibs()))

const default_pkg_io = Ref{IO}(devnull)

const fromparent_module = Ref{Module}()
const TEMP_MODULE_NAME = :_FromPackage_TempModule_
const STORED_MODULE = Ref{Union{Module, Nothing}}(nothing)
const PREVIOUS_CATCHALL_NAMES = Set{Symbol}()
const macro_cell = Ref("undefined")
const manifest_names = ("JuliaManifest.toml", "Manifest.toml")

Expand Down
1 change: 1 addition & 0 deletions test/TestUsingNames/src/TestUsingNames.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Base64
export top_level_func
top_level_func() = 1
clash_name = 5
rand_variable = rand()

module Test1
using Example
Expand Down
6 changes: 6 additions & 0 deletions test/TestUsingNames/test_notebook2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,14 @@ isdefined(@__MODULE__, :base64encode) || error("base64encode from Base64 should
clash_name === 0 || error("The clashed name was not handled correctly")
╠═╡ =#

# ╔═╡ 4492d516-2b23-45b7-bf76-7458e7352fea
#=╠═╡
rand_variable # This should change at every re-run
╠═╡ =#

# ╔═╡ Cell order:
# ╠═4f8def86-f90b-4f74-ac47-93fe6e437cee
# ╠═ac3d261a-86c9-453f-9d86-23a8f30ca583
# ╠═dd3f662f-e2ce-422d-a91a-487a4da359cc
# ╠═c72f2544-eb2e-4ed6-a89b-495ead20b5f6
# ╠═4492d516-2b23-45b7-bf76-7458e7352fea
Loading
Loading