Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CM-7218] introduce new monkey patching to comet llm #30

Merged
merged 20 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/comet_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _muted_import_comet_ml() -> ModuleType:

comet_ml = _muted_import_comet_ml()

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
import comet_ml.config as comet_ml_config


Expand Down
13 changes: 13 additions & 0 deletions src/comet_llm/import_hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************
29 changes: 29 additions & 0 deletions src/comet_llm/import_hooks/callable_extenders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import dataclasses
from typing import List

from .types import AfterCallback, AfterExceptionCallback, BeforeCallback


@dataclasses.dataclass
class CallableExtenders:
before: List[BeforeCallback]
after: List[AfterCallback]
after_exception: List[AfterExceptionCallback]


def get() -> CallableExtenders:
return CallableExtenders([], [], [])
76 changes: 76 additions & 0 deletions src/comet_llm/import_hooks/callback_runners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import logging
from typing import Any, Callable, Dict, List, Tuple, Union

from . import validate
from .types import AfterCallback, AfterExceptionCallback, BeforeCallback

LOGGER = logging.getLogger(__name__)

Args = Union[Tuple[Any, ...], List[Any]]
ArgsKwargs = Tuple[Args, Dict[str, Any]]


def run_before( # type: ignore
callbacks: List[BeforeCallback], original: Callable, *args, **kwargs
) -> ArgsKwargs:
for callback in callbacks:
try:
callback_return = callback(original, *args, **kwargs)

if validate.args_kwargs(callback_return):
LOGGER.debug("New args %r", callback_return)
args, kwargs = callback_return
except Exception:
LOGGER.debug(
"Exception calling before callback %r", callback, exc_info=True
)

return args, kwargs


def run_after( # type: ignore
callbacks: List[AfterCallback],
original: Callable,
return_value: Any,
*args,
**kwargs
) -> Any:
for callback in callbacks:
try:
new_return_value = callback(original, return_value, *args, **kwargs)
if new_return_value is not None:
return_value = new_return_value
except Exception:
LOGGER.debug("Exception calling after callback %r", callback, exc_info=True)

return return_value


def run_after_exception( # type: ignore
callbacks: List[AfterExceptionCallback],
original: Callable,
exception: Exception,
*args,
**kwargs
) -> None:
for callback in callbacks:
try:
callback(original, exception, *args, **kwargs)
except Exception:
LOGGER.debug(
"Exception calling after-exception callback %r", callback, exc_info=True
)
50 changes: 50 additions & 0 deletions src/comet_llm/import_hooks/finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import sys
from importlib import machinery
from types import ModuleType
from typing import List, Optional

from . import module_loader, registry


class CometFinder:
def __init__(self, extensions_registry: registry.Registry) -> None:
self._registry = extensions_registry
self._pathfinder = machinery.PathFinder()

def hook_into_import_system(self) -> None:
if self not in sys.meta_path:
sys.meta_path.insert(0, self) # type: ignore

def find_spec(
self, fullname: str, path: Optional[List[str]], target: Optional[ModuleType]
) -> Optional[machinery.ModuleSpec]:
if fullname not in self._registry.module_names:
return None

original_spec = self._pathfinder.find_spec(fullname, path, target)

if original_spec is None:
return None

return self._wrap_spec_loader(fullname, original_spec)

def _wrap_spec_loader(
self, fullname: str, spec: machinery.ModuleSpec
) -> machinery.ModuleSpec:
module_extension = self._registry.get_extension(fullname)
spec.loader = module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore
return spec
31 changes: 31 additions & 0 deletions src/comet_llm/import_hooks/module_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

from typing import Dict

from . import callable_extenders


class ModuleExtension:
def __init__(self) -> None:
self._callables_extenders: Dict[str, callable_extenders.CallableExtenders] = {}

def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders:
if callable_name not in self._callables_extenders:
self._callables_extenders[callable_name] = callable_extenders.get()

return self._callables_extenders[callable_name]

def items(self): # type: ignore
return self._callables_extenders.items()
49 changes: 49 additions & 0 deletions src/comet_llm/import_hooks/module_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import importlib.abc
from types import ModuleType
from typing import TYPE_CHECKING, Optional

from . import module_extension, patcher

if TYPE_CHECKING: # pragma: no cover
from importlib import machinery


class CometModuleLoader(importlib.abc.Loader):
def __init__(
self,
module_name: str,
original_loader: importlib.abc.Loader,
module_extension: module_extension.ModuleExtension,
) -> None:
self._module_name = module_name
self._original_loader = original_loader
self._module_extension = module_extension

def create_module(self, spec: "machinery.ModuleSpec") -> Optional[ModuleType]:
if hasattr(self._original_loader, "create_module"):
return self._original_loader.create_module(spec)

LET_PYTHON_HANDLE_THIS = None
return LET_PYTHON_HANDLE_THIS

def exec_module(self, module: ModuleType) -> None:
if hasattr(self._original_loader, "exec_module"):
self._original_loader.exec_module(module)
else:
module = self._original_loader.load_module(self._module_name)

patcher.patch(module, self._module_extension)
64 changes: 64 additions & 0 deletions src/comet_llm/import_hooks/patcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import inspect
from types import ModuleType
from typing import TYPE_CHECKING, Any

from . import wrapper

if TYPE_CHECKING: # pragma: no cover
from . import module_extension

# _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes.


def _get_object(module: ModuleType, callable_path: str) -> Any:
current_object = module

for part in callable_path:
try:
current_object = getattr(current_object, part)
except AttributeError:
return None

return current_object


def _set_object(
module: ModuleType, callable_path: str, original: Any, new_object: Any
) -> None:
object_to_patch = _get_object(module, callable_path[:-1])

original_self = getattr(original, "__self__", None)

# Support classmethod
if original_self and inspect.isclass(original_self):
new_object = classmethod(new_object)

setattr(object_to_patch, callable_path[-1], new_object)


def patch(
module: ModuleType, module_extension: "module_extension.ModuleExtension"
) -> None:
for callable_name, callable_extenders in module_extension.items():
callable_path = callable_name.split(".")
original = _get_object(module, callable_path)

if original is None:
continue

new_callable = wrapper.wrap(original, callable_extenders)
_set_object(module, callable_path, original, new_callable)
Loading