-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[CM-7218] introduce new monkey patching to comet llm (#30)
* Add prototype for import_hooks * Add new entities * Rename module * Make it work on psutil example.Hooray * Fix wrapper * Add copyrights * Fix lint errors * Add docstrings to registry * Update docstrings * Fix lint errors * Add new tests for patching functions * Add tests * Add validate.py and tests for it * Fix lint errors * Add tests * Update tests * Update tests * Add test for child class method, refactor tests a bit * Remove currently odd import * Replace wraps workaround with actual wraps
- Loading branch information
1 parent
9faf8fc
commit 42f7009
Showing
17 changed files
with
1,072 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* | ||
|
||
import dataclasses | ||
from typing import List | ||
|
||
from .types import AfterCallback, AfterExceptionCallback, BeforeCallback | ||
|
||
|
||
@dataclasses.dataclass | ||
class CallableExtenders: | ||
before: List[BeforeCallback] | ||
after: List[AfterCallback] | ||
after_exception: List[AfterExceptionCallback] | ||
|
||
|
||
def get() -> CallableExtenders: | ||
return CallableExtenders([], [], []) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* | ||
|
||
import logging | ||
from typing import Any, Callable, Dict, List, Tuple, Union | ||
|
||
from . import validate | ||
from .types import AfterCallback, AfterExceptionCallback, BeforeCallback | ||
|
||
LOGGER = logging.getLogger(__name__) | ||
|
||
Args = Union[Tuple[Any, ...], List[Any]] | ||
ArgsKwargs = Tuple[Args, Dict[str, Any]] | ||
|
||
|
||
def run_before( # type: ignore | ||
callbacks: List[BeforeCallback], original: Callable, *args, **kwargs | ||
) -> ArgsKwargs: | ||
for callback in callbacks: | ||
try: | ||
callback_return = callback(original, *args, **kwargs) | ||
|
||
if validate.args_kwargs(callback_return): | ||
LOGGER.debug("New args %r", callback_return) | ||
args, kwargs = callback_return | ||
except Exception: | ||
LOGGER.debug( | ||
"Exception calling before callback %r", callback, exc_info=True | ||
) | ||
|
||
return args, kwargs | ||
|
||
|
||
def run_after( # type: ignore | ||
callbacks: List[AfterCallback], | ||
original: Callable, | ||
return_value: Any, | ||
*args, | ||
**kwargs | ||
) -> Any: | ||
for callback in callbacks: | ||
try: | ||
new_return_value = callback(original, return_value, *args, **kwargs) | ||
if new_return_value is not None: | ||
return_value = new_return_value | ||
except Exception: | ||
LOGGER.debug("Exception calling after callback %r", callback, exc_info=True) | ||
|
||
return return_value | ||
|
||
|
||
def run_after_exception( # type: ignore | ||
callbacks: List[AfterExceptionCallback], | ||
original: Callable, | ||
exception: Exception, | ||
*args, | ||
**kwargs | ||
) -> None: | ||
for callback in callbacks: | ||
try: | ||
callback(original, exception, *args, **kwargs) | ||
except Exception: | ||
LOGGER.debug( | ||
"Exception calling after-exception callback %r", callback, exc_info=True | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* | ||
|
||
import sys | ||
from importlib import machinery | ||
from types import ModuleType | ||
from typing import List, Optional | ||
|
||
from . import module_loader, registry | ||
|
||
|
||
class CometFinder: | ||
def __init__(self, extensions_registry: registry.Registry) -> None: | ||
self._registry = extensions_registry | ||
self._pathfinder = machinery.PathFinder() | ||
|
||
def hook_into_import_system(self) -> None: | ||
if self not in sys.meta_path: | ||
sys.meta_path.insert(0, self) # type: ignore | ||
|
||
def find_spec( | ||
self, fullname: str, path: Optional[List[str]], target: Optional[ModuleType] | ||
) -> Optional[machinery.ModuleSpec]: | ||
if fullname not in self._registry.module_names: | ||
return None | ||
|
||
original_spec = self._pathfinder.find_spec(fullname, path, target) | ||
|
||
if original_spec is None: | ||
return None | ||
|
||
return self._wrap_spec_loader(fullname, original_spec) | ||
|
||
def _wrap_spec_loader( | ||
self, fullname: str, spec: machinery.ModuleSpec | ||
) -> machinery.ModuleSpec: | ||
module_extension = self._registry.get_extension(fullname) | ||
spec.loader = module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore | ||
return spec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* | ||
|
||
from typing import Dict | ||
|
||
from . import callable_extenders | ||
|
||
|
||
class ModuleExtension: | ||
def __init__(self) -> None: | ||
self._callables_extenders: Dict[str, callable_extenders.CallableExtenders] = {} | ||
|
||
def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders: | ||
if callable_name not in self._callables_extenders: | ||
self._callables_extenders[callable_name] = callable_extenders.get() | ||
|
||
return self._callables_extenders[callable_name] | ||
|
||
def items(self): # type: ignore | ||
return self._callables_extenders.items() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* | ||
|
||
import importlib.abc | ||
from types import ModuleType | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from . import module_extension, patcher | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from importlib import machinery | ||
|
||
|
||
class CometModuleLoader(importlib.abc.Loader): | ||
def __init__( | ||
self, | ||
module_name: str, | ||
original_loader: importlib.abc.Loader, | ||
module_extension: module_extension.ModuleExtension, | ||
) -> None: | ||
self._module_name = module_name | ||
self._original_loader = original_loader | ||
self._module_extension = module_extension | ||
|
||
def create_module(self, spec: "machinery.ModuleSpec") -> Optional[ModuleType]: | ||
if hasattr(self._original_loader, "create_module"): | ||
return self._original_loader.create_module(spec) | ||
|
||
LET_PYTHON_HANDLE_THIS = None | ||
return LET_PYTHON_HANDLE_THIS | ||
|
||
def exec_module(self, module: ModuleType) -> None: | ||
if hasattr(self._original_loader, "exec_module"): | ||
self._original_loader.exec_module(module) | ||
else: | ||
module = self._original_loader.load_module(self._module_name) | ||
|
||
patcher.patch(module, self._module_extension) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
# -*- coding: utf-8 -*- | ||
# ******************************************************* | ||
# ____ _ _ | ||
# / ___|___ _ __ ___ ___| |_ _ __ ___ | | | ||
# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | | ||
# | |__| (_) | | | | | | __/ |_ _| | | | | | | | ||
# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| | ||
# | ||
# Sign up for free at https://www.comet.com | ||
# Copyright (C) 2015-2023 Comet ML INC | ||
# This file can not be copied and/or distributed without the express | ||
# permission of Comet ML Inc. | ||
# ******************************************************* | ||
|
||
import inspect | ||
from types import ModuleType | ||
from typing import TYPE_CHECKING, Any | ||
|
||
from . import wrapper | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
from . import module_extension | ||
|
||
# _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes. | ||
|
||
|
||
def _get_object(module: ModuleType, callable_path: str) -> Any: | ||
current_object = module | ||
|
||
for part in callable_path: | ||
try: | ||
current_object = getattr(current_object, part) | ||
except AttributeError: | ||
return None | ||
|
||
return current_object | ||
|
||
|
||
def _set_object( | ||
module: ModuleType, callable_path: str, original: Any, new_object: Any | ||
) -> None: | ||
object_to_patch = _get_object(module, callable_path[:-1]) | ||
|
||
original_self = getattr(original, "__self__", None) | ||
|
||
# Support classmethod | ||
if original_self and inspect.isclass(original_self): | ||
new_object = classmethod(new_object) | ||
|
||
setattr(object_to_patch, callable_path[-1], new_object) | ||
|
||
|
||
def patch( | ||
module: ModuleType, module_extension: "module_extension.ModuleExtension" | ||
) -> None: | ||
for callable_name, callable_extenders in module_extension.items(): | ||
callable_path = callable_name.split(".") | ||
original = _get_object(module, callable_path) | ||
|
||
if original is None: | ||
continue | ||
|
||
new_callable = wrapper.wrap(original, callable_extenders) | ||
_set_object(module, callable_path, original, new_callable) |
Oops, something went wrong.