Skip to content

Commit

Permalink
[CM-7218] introduce new monkey patching to comet llm (#30)
Browse files Browse the repository at this point in the history
* Add prototype for import_hooks

* Add new entities

* Rename module

* Make it work on psutil example.Hooray

* Fix wrapper

* Add copyrights

* Fix lint errors

* Add docstrings to registry

* Update docstrings

* Fix lint errors

* Add new tests for patching functions

* Add tests

* Add validate.py and tests for it

* Fix lint errors

* Add tests

* Update tests

* Update tests

* Add test for child class method, refactor tests a bit

* Remove currently odd import

* Replace wraps workaround with actual wraps
  • Loading branch information
alexkuzmik committed Aug 9, 2023
1 parent 9faf8fc commit 42f7009
Show file tree
Hide file tree
Showing 17 changed files with 1,072 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/comet_llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _muted_import_comet_ml() -> ModuleType:

comet_ml = _muted_import_comet_ml()

if TYPE_CHECKING:
if TYPE_CHECKING: # pragma: no cover
import comet_ml.config as comet_ml_config


Expand Down
13 changes: 13 additions & 0 deletions src/comet_llm/import_hooks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************
29 changes: 29 additions & 0 deletions src/comet_llm/import_hooks/callable_extenders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import dataclasses
from typing import List

from .types import AfterCallback, AfterExceptionCallback, BeforeCallback


@dataclasses.dataclass
class CallableExtenders:
before: List[BeforeCallback]
after: List[AfterCallback]
after_exception: List[AfterExceptionCallback]


def get() -> CallableExtenders:
return CallableExtenders([], [], [])
76 changes: 76 additions & 0 deletions src/comet_llm/import_hooks/callback_runners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import logging
from typing import Any, Callable, Dict, List, Tuple, Union

from . import validate
from .types import AfterCallback, AfterExceptionCallback, BeforeCallback

LOGGER = logging.getLogger(__name__)

Args = Union[Tuple[Any, ...], List[Any]]
ArgsKwargs = Tuple[Args, Dict[str, Any]]


def run_before( # type: ignore
callbacks: List[BeforeCallback], original: Callable, *args, **kwargs
) -> ArgsKwargs:
for callback in callbacks:
try:
callback_return = callback(original, *args, **kwargs)

if validate.args_kwargs(callback_return):
LOGGER.debug("New args %r", callback_return)
args, kwargs = callback_return
except Exception:
LOGGER.debug(
"Exception calling before callback %r", callback, exc_info=True
)

return args, kwargs


def run_after( # type: ignore
callbacks: List[AfterCallback],
original: Callable,
return_value: Any,
*args,
**kwargs
) -> Any:
for callback in callbacks:
try:
new_return_value = callback(original, return_value, *args, **kwargs)
if new_return_value is not None:
return_value = new_return_value
except Exception:
LOGGER.debug("Exception calling after callback %r", callback, exc_info=True)

return return_value


def run_after_exception( # type: ignore
callbacks: List[AfterExceptionCallback],
original: Callable,
exception: Exception,
*args,
**kwargs
) -> None:
for callback in callbacks:
try:
callback(original, exception, *args, **kwargs)
except Exception:
LOGGER.debug(
"Exception calling after-exception callback %r", callback, exc_info=True
)
50 changes: 50 additions & 0 deletions src/comet_llm/import_hooks/finder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import sys
from importlib import machinery
from types import ModuleType
from typing import List, Optional

from . import module_loader, registry


class CometFinder:
def __init__(self, extensions_registry: registry.Registry) -> None:
self._registry = extensions_registry
self._pathfinder = machinery.PathFinder()

def hook_into_import_system(self) -> None:
if self not in sys.meta_path:
sys.meta_path.insert(0, self) # type: ignore

def find_spec(
self, fullname: str, path: Optional[List[str]], target: Optional[ModuleType]
) -> Optional[machinery.ModuleSpec]:
if fullname not in self._registry.module_names:
return None

original_spec = self._pathfinder.find_spec(fullname, path, target)

if original_spec is None:
return None

return self._wrap_spec_loader(fullname, original_spec)

def _wrap_spec_loader(
self, fullname: str, spec: machinery.ModuleSpec
) -> machinery.ModuleSpec:
module_extension = self._registry.get_extension(fullname)
spec.loader = module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore
return spec
31 changes: 31 additions & 0 deletions src/comet_llm/import_hooks/module_extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

from typing import Dict

from . import callable_extenders


class ModuleExtension:
def __init__(self) -> None:
self._callables_extenders: Dict[str, callable_extenders.CallableExtenders] = {}

def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders:
if callable_name not in self._callables_extenders:
self._callables_extenders[callable_name] = callable_extenders.get()

return self._callables_extenders[callable_name]

def items(self): # type: ignore
return self._callables_extenders.items()
49 changes: 49 additions & 0 deletions src/comet_llm/import_hooks/module_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import importlib.abc
from types import ModuleType
from typing import TYPE_CHECKING, Optional

from . import module_extension, patcher

if TYPE_CHECKING: # pragma: no cover
from importlib import machinery


class CometModuleLoader(importlib.abc.Loader):
def __init__(
self,
module_name: str,
original_loader: importlib.abc.Loader,
module_extension: module_extension.ModuleExtension,
) -> None:
self._module_name = module_name
self._original_loader = original_loader
self._module_extension = module_extension

def create_module(self, spec: "machinery.ModuleSpec") -> Optional[ModuleType]:
if hasattr(self._original_loader, "create_module"):
return self._original_loader.create_module(spec)

LET_PYTHON_HANDLE_THIS = None
return LET_PYTHON_HANDLE_THIS

def exec_module(self, module: ModuleType) -> None:
if hasattr(self._original_loader, "exec_module"):
self._original_loader.exec_module(module)
else:
module = self._original_loader.load_module(self._module_name)

patcher.patch(module, self._module_extension)
64 changes: 64 additions & 0 deletions src/comet_llm/import_hooks/patcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# *******************************************************
# ____ _ _
# / ___|___ _ __ ___ ___| |_ _ __ ___ | |
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| |
# | |__| (_) | | | | | | __/ |_ _| | | | | | |
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_|
#
# Sign up for free at https://www.comet.com
# Copyright (C) 2015-2023 Comet ML INC
# This file can not be copied and/or distributed without the express
# permission of Comet ML Inc.
# *******************************************************

import inspect
from types import ModuleType
from typing import TYPE_CHECKING, Any

from . import wrapper

if TYPE_CHECKING: # pragma: no cover
from . import module_extension

# _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes.


def _get_object(module: ModuleType, callable_path: str) -> Any:
current_object = module

for part in callable_path:
try:
current_object = getattr(current_object, part)
except AttributeError:
return None

return current_object


def _set_object(
module: ModuleType, callable_path: str, original: Any, new_object: Any
) -> None:
object_to_patch = _get_object(module, callable_path[:-1])

original_self = getattr(original, "__self__", None)

# Support classmethod
if original_self and inspect.isclass(original_self):
new_object = classmethod(new_object)

setattr(object_to_patch, callable_path[-1], new_object)


def patch(
module: ModuleType, module_extension: "module_extension.ModuleExtension"
) -> None:
for callable_name, callable_extenders in module_extension.items():
callable_path = callable_name.split(".")
original = _get_object(module, callable_path)

if original is None:
continue

new_callable = wrapper.wrap(original, callable_extenders)
_set_object(module, callable_path, original, new_callable)
Loading

0 comments on commit 42f7009

Please sign in to comment.