Skip to content

Commit

Permalink
Fix in-process main module source loading
Browse files Browse the repository at this point in the history
  • Loading branch information
schustmi committed Oct 22, 2024
1 parent 7621f44 commit fd17b85
Showing 1 changed file with 87 additions and 55 deletions.
142 changes: 87 additions & 55 deletions src/zenml/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,17 @@
from distutils.sysconfig import get_python_lib
from pathlib import Path, PurePath
from types import BuiltinFunctionType, FunctionType, ModuleType
from typing import Any, Callable, Dict, Iterator, Optional, Type, Union, cast
from typing import (
Any,
Callable,
Dict,
Iterator,
Optional,
Set,
Type,
Union,
cast,
)
from uuid import UUID

from zenml.config.source import (
Expand Down Expand Up @@ -64,6 +74,7 @@
_SHARED_TEMPDIR: Optional[str] = None
_resolved_notebook_sources: Dict[str, str] = {}
_notebook_modules: Dict[str, UUID] = {}
_resolved_main_module_sources: Set[str] = set()


def load(source: Union[Source, str]) -> Any:
Expand Down Expand Up @@ -121,7 +132,17 @@ def load(source: Union[Source, str]) -> Any:
# root in python path just to be sure
import_root = get_source_root()

module = _load_module(module_name=source.module, import_root=import_root)
if source.import_path in _resolved_main_module_sources:
# We resolved this source from the __main__ module in this process.
# If we were to load the module here, we would load the same Python
# file with a different module name, which would rerun all top-level
# code. To avoid this, we instead load the source from the __main__
# module which is already loaded.
module = sys.modules["__main__"]
else:
module = _load_module(
module_name=source.module, import_root=import_root
)

if source.attribute:
obj = getattr(module, source.attribute)
Expand Down Expand Up @@ -187,73 +208,84 @@ def resolve(
"holds the object you want to resolve."
)

module_name = module.__name__
if module_name == "__main__":
module_name = _resolve_module(module)
def _resolve_helper() -> Source:
module_name = module.__name__
if module_name == "__main__":
module_name = _resolve_module(module)

source_type = get_source_type(module=module)
source_type = get_source_type(module=module)

if source_type == SourceType.USER:
from zenml.utils import code_repository_utils
if source_type == SourceType.USER:
from zenml.utils import code_repository_utils

local_repo_context = (
code_repository_utils.find_active_code_repository()
)
local_repo_context = (
code_repository_utils.find_active_code_repository()
)

if local_repo_context and not local_repo_context.has_local_changes:
module_name = _resolve_module(module)
if local_repo_context and not local_repo_context.has_local_changes:
module_name = _resolve_module(module)

source_root = get_source_root()
subdir = PurePath(source_root).relative_to(local_repo_context.root)
source_root = get_source_root()
subdir = PurePath(source_root).relative_to(
local_repo_context.root
)

return CodeRepositorySource(
repository_id=local_repo_context.code_repository_id,
commit=local_repo_context.current_commit,
subdirectory=subdir.as_posix(),
module=module_name,
attribute=attribute_name,
type=SourceType.CODE_REPOSITORY,
)
return CodeRepositorySource(
repository_id=local_repo_context.code_repository_id,
commit=local_repo_context.current_commit,
subdirectory=subdir.as_posix(),
module=module_name,
attribute=attribute_name,
type=SourceType.CODE_REPOSITORY,
)

module_name = _resolve_module(module)
elif source_type == SourceType.DISTRIBUTION_PACKAGE:
package_name = _get_package_for_module(module_name=module_name)
if package_name:
package_version = _get_package_version(package_name=package_name)
return DistributionPackageSource(
module=module_name,
module_name = _resolve_module(module)
elif source_type == SourceType.DISTRIBUTION_PACKAGE:
package_name = _get_package_for_module(module_name=module_name)
if package_name:
package_version = _get_package_version(
package_name=package_name
)
return DistributionPackageSource(
module=module_name,
attribute=attribute_name,
package_name=package_name,
version=package_version,
type=source_type,
)
else:
# Fall back to an unknown source if we can't find the package
source_type = SourceType.UNKNOWN
elif source_type == SourceType.NOTEBOOK:
source = NotebookSource(
module="__main__",
attribute=attribute_name,
package_name=package_name,
version=package_version,
type=source_type,
)
else:
# Fall back to an unknown source if we can't find the package
source_type = SourceType.UNKNOWN
elif source_type == SourceType.NOTEBOOK:
source = NotebookSource(
module="__main__",
attribute=attribute_name,
type=source_type,
)

if module_name in _notebook_modules:
source.replacement_module = module_name
source.artifact_store_id = _notebook_modules[module_name]
elif cell_code := notebook_utils.load_notebook_cell_code(obj):
replacement_module = (
notebook_utils.compute_cell_replacement_module_name(
cell_code=cell_code
if module_name in _notebook_modules:
source.replacement_module = module_name
source.artifact_store_id = _notebook_modules[module_name]
elif cell_code := notebook_utils.load_notebook_cell_code(obj):
replacement_module = (
notebook_utils.compute_cell_replacement_module_name(
cell_code=cell_code
)
)
)
source.replacement_module = replacement_module
_resolved_notebook_sources[source.import_path] = cell_code
source.replacement_module = replacement_module
_resolved_notebook_sources[source.import_path] = cell_code

return source
return source

return Source(
module=module_name, attribute=attribute_name, type=source_type
)

source = _resolve_helper()
if module.__name__ == "__main__":
_resolved_main_module_sources.add(source.import_path)

return Source(
module=module_name, attribute=attribute_name, type=source_type
)
return source


def get_source_root() -> str:
Expand Down

0 comments on commit fd17b85

Please sign in to comment.