From 82139a89cc353f1f6ab99420b88d0af077795113 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 20 Jul 2023 16:01:30 +0200 Subject: [PATCH 01/20] Add prototype for import_hooks --- src/comet_llm/__init__.py | 1 + src/comet_llm/import_hooks/__init__.py | 33 ++++++++++++ src/comet_llm/import_hooks/callback_loader.py | 47 ++++++++++++++++ src/comet_llm/import_hooks/subverter.py | 53 +++++++++++++++++++ 4 files changed, 134 insertions(+) create mode 100644 src/comet_llm/import_hooks/__init__.py create mode 100644 src/comet_llm/import_hooks/callback_loader.py create mode 100644 src/comet_llm/import_hooks/subverter.py diff --git a/src/comet_llm/__init__.py b/src/comet_llm/__init__.py index 203873084..e80c7a042 100644 --- a/src/comet_llm/__init__.py +++ b/src/comet_llm/__init__.py @@ -12,6 +12,7 @@ # permission of Comet ML Inc. # ******************************************************* +from . import import_hooks # keep it the first one from . import app, logging from .api import log_prompt diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py new file mode 100644 index 000000000..9f520854d --- /dev/null +++ b/src/comet_llm/import_hooks/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +# type: ignore + +from . import subverter + + +def patch_psutil(psutil): + original = psutil.cpu_count + + def patched(*args, **kwargs): + print("Before psutil.cpu_count() call!") + return original(*args, **kwargs) + + psutil.cpu_count = patched + + +_subverter = subverter.Subverter() +_subverter.register_import_callback("psutil", patch_psutil) + +_subverter.hook_into_import_system() diff --git a/src/comet_llm/import_hooks/callback_loader.py b/src/comet_llm/import_hooks/callback_loader.py new file mode 100644 index 000000000..b93dae08e --- /dev/null +++ b/src/comet_llm/import_hooks/callback_loader.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import importlib.abc +from types import ModuleType +from typing import TYPE_CHECKING, Callable, Optional + +if TYPE_CHECKING: + from importlib import machinery + + +class CallbackLoader(importlib.abc.Loader): + def __init__( + self, + module_name: str, + original_loader: importlib.abc.Loader, + alert_callback: Callable, + ) -> None: + self._module_name = module_name + self._original_loader = original_loader + self._alert_callback = alert_callback + + def create_module(self, spec: "machinery.ModuleSpec") -> Optional[ModuleType]: + if hasattr(self._original_loader, "create_module"): + return self._original_loader.create_module(spec) + + LET_PYTHON_HANDLE_THIS = None + return LET_PYTHON_HANDLE_THIS + + def exec_module(self, module: ModuleType) -> None: + if hasattr(self._original_loader, "exec_module"): + self._original_loader.exec_module(module) + else: + module = self._original_loader.load_module(self._module_name) + + self._alert_callback(module) diff --git a/src/comet_llm/import_hooks/subverter.py b/src/comet_llm/import_hooks/subverter.py new file mode 100644 index 000000000..fee7d746b --- /dev/null +++ b/src/comet_llm/import_hooks/subverter.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import sys +from importlib import machinery +from types import ModuleType +from typing import Callable, Dict, List, Optional + +from . import callback_loader + + +class Subverter: + def __init__(self) -> None: + self._alert_callbacks: Dict[str, Callable] = {} + self._pathfinder = machinery.PathFinder() + + def hook_into_import_system(self) -> None: + if self not in sys.meta_path: + sys.meta_path.insert(0, self) # type: ignore + + def find_spec( + self, fullname: str, path: Optional[List[str]], target: Optional[ModuleType] + ) -> Optional[machinery.ModuleSpec]: + if fullname not in self._alert_callbacks: + return None + + original_spec = self._pathfinder.find_spec(fullname, path, target) + + if original_spec is None: + return None + + return self._wrap_spec_loader(fullname, original_spec) + + def _wrap_spec_loader( + self, fullname: str, spec: machinery.ModuleSpec + ) -> machinery.ModuleSpec: + callback = self._alert_callbacks[fullname] + spec.loader = callback_loader.CallbackLoader(fullname, spec.loader, callback) # type: ignore + return spec + + def register_import_callback(self, module_name: str, callback: Callable) -> None: + self._alert_callbacks[module_name] = callback From 9e38964694aa0c08a0341082b3806a55487165d1 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Wed, 2 Aug 2023 19:54:23 +0200 Subject: [PATCH 02/20] Add new entities --- src/comet_llm/import_hooks/__init__.py | 4 +- .../import_hooks/callable_extenders.py | 15 ++++ .../import_hooks/callback_runners.py | 79 +++++++++++++++++++ ...lback_loader.py => comet_module_loader.py} | 10 ++- .../import_hooks/{subverter.py => finder.py} | 19 ++--- .../import_hooks/module_extension.py | 22 ++++++ src/comet_llm/import_hooks/patcher.py | 51 ++++++++++++ src/comet_llm/import_hooks/registry.py | 41 ++++++++++ src/comet_llm/import_hooks/types.py | 7 ++ src/comet_llm/import_hooks/wrapper.py | 25 ++++++ 10 files changed, 256 insertions(+), 17 deletions(-) create mode 100644 src/comet_llm/import_hooks/callable_extenders.py create mode 100644 src/comet_llm/import_hooks/callback_runners.py rename src/comet_llm/import_hooks/{callback_loader.py => comet_module_loader.py} (85%) rename src/comet_llm/import_hooks/{subverter.py => finder.py} (71%) create mode 100644 src/comet_llm/import_hooks/module_extension.py create mode 100644 src/comet_llm/import_hooks/patcher.py create mode 100644 src/comet_llm/import_hooks/registry.py create mode 100644 src/comet_llm/import_hooks/types.py create mode 100644 src/comet_llm/import_hooks/wrapper.py diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py index 9f520854d..3673d7f54 100644 --- a/src/comet_llm/import_hooks/__init__.py +++ b/src/comet_llm/import_hooks/__init__.py @@ -14,7 +14,7 @@ # type: ignore -from . import subverter +from . import finder def patch_psutil(psutil): @@ -27,7 +27,7 @@ def patched(*args, **kwargs): psutil.cpu_count = patched -_subverter = subverter.Subverter() +_subverter = finder.Finder() _subverter.register_import_callback("psutil", patch_psutil) _subverter.hook_into_import_system() diff --git a/src/comet_llm/import_hooks/callable_extenders.py b/src/comet_llm/import_hooks/callable_extenders.py new file mode 100644 index 000000000..1e184a805 --- /dev/null +++ b/src/comet_llm/import_hooks/callable_extenders.py @@ -0,0 +1,15 @@ +import dataclasses +from typing import List + +from .types import AfterCallback, AfterExceptionCallback, BeforeCallback + + +@dataclasses.dataclass +class CallableExtenders: + before: List[BeforeCallback] + after: List[AfterCallback] + after_exception: List[AfterExceptionCallback] + + +def get() -> CallableExtenders: + return CallableExtenders([], [], []) diff --git a/src/comet_llm/import_hooks/callback_runners.py b/src/comet_llm/import_hooks/callback_runners.py new file mode 100644 index 000000000..2a8ef8dfe --- /dev/null +++ b/src/comet_llm/import_hooks/callback_runners.py @@ -0,0 +1,79 @@ +import logging +from typing import Any, Callable, Dict, List, Tuple, Union + +from .types import AfterCallback, AfterExceptionCallback, BeforeCallback + +LOGGER = logging.getLogger(__name__) + +Args = Union[Tuple[Any, ...], List[Any]] +ArgsKwargs = Tuple[Args, Dict[str, Any]] + + +def run_before( + callbacks: List[BeforeCallback], original: Callable, *args, **kwargs +) -> ArgsKwargs: + for callback in callbacks: + try: + callback_return = callback(original, *args, **kwargs) + + if _valid_new_args_kwargs(callback_return): + LOGGER.debug("New args %r", callback_return) + args, kwargs = callback_return + except Exception: + LOGGER.debug( + "Exception calling before callback %r", callback, exc_info=True + ) + + return args, kwargs + + +def run_after( + callbacks: List[AfterCallback], + original: Callable, + return_value: Any, + *args, + **kwargs +) -> Any: + for callback in callbacks: + try: + new_return_value = callback(original, return_value, *args, **kwargs) + if new_return_value is not None: + return_value = new_return_value + except Exception: + LOGGER.debug("Exception calling after callback %r", callback, exc_info=True) + + return return_value + + +def run_after_exception( + callbacks: List[AfterExceptionCallback], + original: Callable, + exception: Exception, + *args, + **kwargs +) -> None: + for callback in callbacks: + try: + callback(original, exception, *args, **kwargs) + except Exception: + LOGGER.debug( + "Exception calling after-exception callback %r", callback, exc_info=True + ) + + +def _valid_new_args_kwargs(callback_return: Any) -> bool: + if callback_return is None: + return False + + try: + args, kwargs = callback_return + except (ValueError, TypeError): + return False + + if not isinstance(args, (list, tuple)): + return False + + if not isinstance(kwargs, dict): + return False + + return True diff --git a/src/comet_llm/import_hooks/callback_loader.py b/src/comet_llm/import_hooks/comet_module_loader.py similarity index 85% rename from src/comet_llm/import_hooks/callback_loader.py rename to src/comet_llm/import_hooks/comet_module_loader.py index b93dae08e..eabc25e9b 100644 --- a/src/comet_llm/import_hooks/callback_loader.py +++ b/src/comet_llm/import_hooks/comet_module_loader.py @@ -16,20 +16,22 @@ from types import ModuleType from typing import TYPE_CHECKING, Callable, Optional +from . import module_extension, patcher + if TYPE_CHECKING: from importlib import machinery -class CallbackLoader(importlib.abc.Loader): +class CometModuleLoader(importlib.abc.Loader): def __init__( self, module_name: str, original_loader: importlib.abc.Loader, - alert_callback: Callable, + module_extension: module_extension.ModuleExtension, ) -> None: self._module_name = module_name self._original_loader = original_loader - self._alert_callback = alert_callback + self._module_extension = module_extension def create_module(self, spec: "machinery.ModuleSpec") -> Optional[ModuleType]: if hasattr(self._original_loader, "create_module"): @@ -44,4 +46,4 @@ def exec_module(self, module: ModuleType) -> None: else: module = self._original_loader.load_module(self._module_name) - self._alert_callback(module) + patcher.patch(module, self._module_extension) diff --git a/src/comet_llm/import_hooks/subverter.py b/src/comet_llm/import_hooks/finder.py similarity index 71% rename from src/comet_llm/import_hooks/subverter.py rename to src/comet_llm/import_hooks/finder.py index fee7d746b..e1cee8f12 100644 --- a/src/comet_llm/import_hooks/subverter.py +++ b/src/comet_llm/import_hooks/finder.py @@ -15,14 +15,14 @@ import sys from importlib import machinery from types import ModuleType -from typing import Callable, Dict, List, Optional +from typing import List, Optional -from . import callback_loader +from . import comet_module_loader, registry -class Subverter: - def __init__(self) -> None: - self._alert_callbacks: Dict[str, Callable] = {} +class Finder: + def __init__(self, extensions_registry: registry.Registry) -> None: + self._registry = extensions_registry self._pathfinder = machinery.PathFinder() def hook_into_import_system(self) -> None: @@ -32,7 +32,7 @@ def hook_into_import_system(self) -> None: def find_spec( self, fullname: str, path: Optional[List[str]], target: Optional[ModuleType] ) -> Optional[machinery.ModuleSpec]: - if fullname not in self._alert_callbacks: + if fullname not in self._registry.module_names: return None original_spec = self._pathfinder.find_spec(fullname, path, target) @@ -45,9 +45,6 @@ def find_spec( def _wrap_spec_loader( self, fullname: str, spec: machinery.ModuleSpec ) -> machinery.ModuleSpec: - callback = self._alert_callbacks[fullname] - spec.loader = callback_loader.CallbackLoader(fullname, spec.loader, callback) # type: ignore + module_extension = self._registry.get_extension(fullname) + spec.loader = comet_module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore return spec - - def register_import_callback(self, module_name: str, callback: Callable) -> None: - self._alert_callbacks[module_name] = callback diff --git a/src/comet_llm/import_hooks/module_extension.py b/src/comet_llm/import_hooks/module_extension.py new file mode 100644 index 000000000..53885201b --- /dev/null +++ b/src/comet_llm/import_hooks/module_extension.py @@ -0,0 +1,22 @@ +from typing import Any, Dict, List + +from . import callable_extenders + + +class ModuleExtension: + def __init__(self) -> None: + self._callable_names_extenders: Dict[ + str, callable_extenders.CallableExtenders + ] = {} + + def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders: + if callable_name not in self._callable_names_extenders: + self._callable_names_extenders[callable_name] = callable_extenders.get() + + return self._callable_names_extenders[callable_name] + + def callable_names(self) -> List[str]: + return self._callable_names_extenders.keys() + + def items(self): + return self._callable_names_extenders.items() diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py new file mode 100644 index 000000000..c0db1eacc --- /dev/null +++ b/src/comet_llm/import_hooks/patcher.py @@ -0,0 +1,51 @@ +import inspect +from types import ModuleType +from typing import TYPE_CHECKING, Any + +from . import wrapper + +if TYPE_CHECKING: + from . import module_extension + +# _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes. + + +def _get_object(module: ModuleType, object_name: str) -> Any: + object_path = object_name.split(".") + current_object = module + + for part in object_path: + try: + current_object = getattr(current_object, part) + except AttributeError: + return None + + return current_object + + +def _set_object( + module: ModuleType, object_name: str, original: Any, new_object: Any +) -> None: + object_path = object_name.split(".") + object_to_patch = _get_object(module, object_path[:-1]) + + original_self = getattr(original, "__self__", None) + + # Support classmethod + if original_self and inspect.isclass(original_self): + new_object = classmethod(new_object) + + setattr(object_to_patch, object_path[-1], new_object) + + +def patch(module: ModuleType, module_extension: "module_extension.ModuleExtension"): + for callable_name, callable_extenders in module_extension.items(): + original = _get_object(module, callable_name) + + if original is None: + continue + + new_callable = wrapper.wrap(original, callable_extenders) + _set_object(module, callable_name, original, new_callable) + + return module diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py new file mode 100644 index 000000000..8c44f60ab --- /dev/null +++ b/src/comet_llm/import_hooks/registry.py @@ -0,0 +1,41 @@ +from typing import Any, Callable, Dict, List + +from . import callable_extenders, module_extension + + +class Registry: + def __init__(self): + self._modules_extensions: Dict[str, module_extension.ModuleExtension] = {} + + @property + def module_names(self) -> List[str]: + return self._modules_extensions.keys() + + def get_extension(self, module_name: str) -> module_extension.ModuleExtension: + return self._modules_extensions[module_name] + + def _get_callable_extenders( + self, module_name: str, callable_name: str + ) -> callable_extenders.CallableExtenders: + extension = self._modules_extensions.setdefault( + module_name, module_extension.ModuleExtension() + ) + return extension[callable_name] + + def register_before( + self, module_name: str, callable_name: str, patcher_function: Callable + ) -> None: + extenders = self._get_callable_extenders(module_name, callable_name) + extenders.before.append(patcher_function) + + def register_after( + self, module_name: str, callable_name: str, patcher_function: Callable + ) -> None: + extenders = self._get_callable_extenders(module_name, callable_name) + extenders.after.append(patcher_function) + + def register_after_exception( + self, module_name: str, callable_name: str, patcher_function: Callable + ) -> None: + extenders = self._get_callable_extenders(module_name, callable_name) + extenders.after_exception.append(patcher_function) diff --git a/src/comet_llm/import_hooks/types.py b/src/comet_llm/import_hooks/types.py new file mode 100644 index 000000000..ee3e61a23 --- /dev/null +++ b/src/comet_llm/import_hooks/types.py @@ -0,0 +1,7 @@ +from typing import Callable + +# to-do: better description for callbacks signatures + +BeforeCallback = Callable +AfterCallback = Callable +AfterExceptionCallback = Callable diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py new file mode 100644 index 000000000..db404a47d --- /dev/null +++ b/src/comet_llm/import_hooks/wrapper.py @@ -0,0 +1,25 @@ +import functools +from typing import Callable + +from . import callable_extenders, callback_runners + + +def wrap(original: Callable, callbacks: callable_extenders.CallableExtenders): + def wrapped(*args, **kwargs): + args, kwargs = callback_runners.run_before(callbacks.before, original) + try: + result = original(*args, **kwargs) + except Exception as exception: + callback_runners.run_after_exception( + callbacks.after_exception, original, exception, *args, **kwargs + ) + raise exception + + callback_runners.run_after(callbacks.after, original, result, *args, **kwargs) + + # Simulate functools.wraps behavior but make it working with mocks + for attr in functools.WRAPPER_ASSIGNMENTS: + if hasattr(original, attr): + setattr(wrapped, attr, getattr(original, attr)) + + return wrapped From b87549574f3c6b444785b5b74f77cbb4880560f2 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Wed, 2 Aug 2023 19:55:48 +0200 Subject: [PATCH 03/20] Rename module --- src/comet_llm/import_hooks/finder.py | 4 ++-- .../import_hooks/{comet_module_loader.py => module_loader.py} | 0 2 files changed, 2 insertions(+), 2 deletions(-) rename src/comet_llm/import_hooks/{comet_module_loader.py => module_loader.py} (100%) diff --git a/src/comet_llm/import_hooks/finder.py b/src/comet_llm/import_hooks/finder.py index e1cee8f12..8dce4adec 100644 --- a/src/comet_llm/import_hooks/finder.py +++ b/src/comet_llm/import_hooks/finder.py @@ -17,7 +17,7 @@ from types import ModuleType from typing import List, Optional -from . import comet_module_loader, registry +from . import module_loader, registry class Finder: @@ -46,5 +46,5 @@ def _wrap_spec_loader( self, fullname: str, spec: machinery.ModuleSpec ) -> machinery.ModuleSpec: module_extension = self._registry.get_extension(fullname) - spec.loader = comet_module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore + spec.loader = module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore return spec diff --git a/src/comet_llm/import_hooks/comet_module_loader.py b/src/comet_llm/import_hooks/module_loader.py similarity index 100% rename from src/comet_llm/import_hooks/comet_module_loader.py rename to src/comet_llm/import_hooks/module_loader.py From 29deb4f919eb3414c9ae00da0b05e3920bd81a49 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Wed, 2 Aug 2023 20:28:47 +0200 Subject: [PATCH 04/20] Make it work on psutil example.Hooray --- src/comet_llm/import_hooks/__init__.py | 17 +++++++---------- src/comet_llm/import_hooks/patcher.py | 17 ++++++++--------- src/comet_llm/import_hooks/registry.py | 2 +- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py index 3673d7f54..c58f23d53 100644 --- a/src/comet_llm/import_hooks/__init__.py +++ b/src/comet_llm/import_hooks/__init__.py @@ -14,20 +14,17 @@ # type: ignore -from . import finder +from . import finder, registry -def patch_psutil(psutil): - original = psutil.cpu_count +def print_message(original, return_value, *args, **kwargs): + print("Before psutil.cpu_count() call!") - def patched(*args, **kwargs): - print("Before psutil.cpu_count() call!") - return original(*args, **kwargs) - psutil.cpu_count = patched +_registry = registry.Registry() +_registry.register_after("psutil", "cpu_count", print_message) +_finder = finder.Finder(_registry) -_subverter = finder.Finder() -_subverter.register_import_callback("psutil", patch_psutil) -_subverter.hook_into_import_system() +_finder.hook_into_import_system() diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py index c0db1eacc..2f77130c4 100644 --- a/src/comet_llm/import_hooks/patcher.py +++ b/src/comet_llm/import_hooks/patcher.py @@ -10,11 +10,10 @@ # _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes. -def _get_object(module: ModuleType, object_name: str) -> Any: - object_path = object_name.split(".") +def _get_object(module: ModuleType, callable_path: str) -> Any: current_object = module - for part in object_path: + for part in callable_path: try: current_object = getattr(current_object, part) except AttributeError: @@ -24,10 +23,9 @@ def _get_object(module: ModuleType, object_name: str) -> Any: def _set_object( - module: ModuleType, object_name: str, original: Any, new_object: Any + module: ModuleType, callable_path: str, original: Any, new_object: Any ) -> None: - object_path = object_name.split(".") - object_to_patch = _get_object(module, object_path[:-1]) + object_to_patch = _get_object(module, callable_path[:-1]) original_self = getattr(original, "__self__", None) @@ -35,17 +33,18 @@ def _set_object( if original_self and inspect.isclass(original_self): new_object = classmethod(new_object) - setattr(object_to_patch, object_path[-1], new_object) + setattr(object_to_patch, callable_path[-1], new_object) def patch(module: ModuleType, module_extension: "module_extension.ModuleExtension"): for callable_name, callable_extenders in module_extension.items(): - original = _get_object(module, callable_name) + callable_path = callable_name.split(".") + original = _get_object(module, callable_path) if original is None: continue new_callable = wrapper.wrap(original, callable_extenders) - _set_object(module, callable_name, original, new_callable) + _set_object(module, callable_path, original, new_callable) return module diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py index 8c44f60ab..7354f36db 100644 --- a/src/comet_llm/import_hooks/registry.py +++ b/src/comet_llm/import_hooks/registry.py @@ -20,7 +20,7 @@ def _get_callable_extenders( extension = self._modules_extensions.setdefault( module_name, module_extension.ModuleExtension() ) - return extension[callable_name] + return extension.extenders(callable_name) def register_before( self, module_name: str, callable_name: str, patcher_function: Callable From f6035858196d55ee90027dbda9d71278ac28f002 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 3 Aug 2023 12:25:05 +0200 Subject: [PATCH 05/20] Fix wrapper --- src/comet_llm/import_hooks/__init__.py | 13 ++++++++++--- src/comet_llm/import_hooks/finder.py | 2 +- src/comet_llm/import_hooks/wrapper.py | 8 ++++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py index c58f23d53..0c8ee0921 100644 --- a/src/comet_llm/import_hooks/__init__.py +++ b/src/comet_llm/import_hooks/__init__.py @@ -17,14 +17,21 @@ from . import finder, registry -def print_message(original, return_value, *args, **kwargs): +def print_message1(original, *args, **kwargs): print("Before psutil.cpu_count() call!") +def print_message2(original, return_value, *args, **kwargs): + print("After psutil.cpu_count() call!") + + _registry = registry.Registry() -_registry.register_after("psutil", "cpu_count", print_message) -_finder = finder.Finder(_registry) +_registry.register_before("psutil", "cpu_count", print_message1) +_registry.register_after("psutil", "cpu_count", print_message2) + +_registry +_finder = finder.CometFinder(_registry) _finder.hook_into_import_system() diff --git a/src/comet_llm/import_hooks/finder.py b/src/comet_llm/import_hooks/finder.py index 8dce4adec..6c5daea6f 100644 --- a/src/comet_llm/import_hooks/finder.py +++ b/src/comet_llm/import_hooks/finder.py @@ -20,7 +20,7 @@ from . import module_loader, registry -class Finder: +class CometFinder: def __init__(self, extensions_registry: registry.Registry) -> None: self._registry = extensions_registry self._pathfinder = machinery.PathFinder() diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py index db404a47d..fb07454eb 100644 --- a/src/comet_llm/import_hooks/wrapper.py +++ b/src/comet_llm/import_hooks/wrapper.py @@ -6,7 +6,7 @@ def wrap(original: Callable, callbacks: callable_extenders.CallableExtenders): def wrapped(*args, **kwargs): - args, kwargs = callback_runners.run_before(callbacks.before, original) + args, kwargs = callback_runners.run_before(callbacks.before, original, *args, **kwargs) try: result = original(*args, **kwargs) except Exception as exception: @@ -15,7 +15,11 @@ def wrapped(*args, **kwargs): ) raise exception - callback_runners.run_after(callbacks.after, original, result, *args, **kwargs) + result = callback_runners.run_after( + callbacks.after, original, result, *args, **kwargs + ) + + return result # Simulate functools.wraps behavior but make it working with mocks for attr in functools.WRAPPER_ASSIGNMENTS: From 40d096875923927e6af0ed8ae0823ce8c1561803 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 3 Aug 2023 13:21:21 +0200 Subject: [PATCH 06/20] Add copyrights --- .../import_hooks/callable_extenders.py | 14 ++++++++++++++ src/comet_llm/import_hooks/callback_runners.py | 14 ++++++++++++++ src/comet_llm/import_hooks/module_extension.py | 14 ++++++++++++++ src/comet_llm/import_hooks/patcher.py | 14 ++++++++++++++ src/comet_llm/import_hooks/registry.py | 14 ++++++++++++++ src/comet_llm/import_hooks/types.py | 14 ++++++++++++++ src/comet_llm/import_hooks/wrapper.py | 18 +++++++++++++++++- 7 files changed, 101 insertions(+), 1 deletion(-) diff --git a/src/comet_llm/import_hooks/callable_extenders.py b/src/comet_llm/import_hooks/callable_extenders.py index 1e184a805..eab786dab 100644 --- a/src/comet_llm/import_hooks/callable_extenders.py +++ b/src/comet_llm/import_hooks/callable_extenders.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + import dataclasses from typing import List diff --git a/src/comet_llm/import_hooks/callback_runners.py b/src/comet_llm/import_hooks/callback_runners.py index 2a8ef8dfe..dfcd4afd6 100644 --- a/src/comet_llm/import_hooks/callback_runners.py +++ b/src/comet_llm/import_hooks/callback_runners.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + import logging from typing import Any, Callable, Dict, List, Tuple, Union diff --git a/src/comet_llm/import_hooks/module_extension.py b/src/comet_llm/import_hooks/module_extension.py index 53885201b..d61b58f02 100644 --- a/src/comet_llm/import_hooks/module_extension.py +++ b/src/comet_llm/import_hooks/module_extension.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + from typing import Any, Dict, List from . import callable_extenders diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py index 2f77130c4..ba130b071 100644 --- a/src/comet_llm/import_hooks/patcher.py +++ b/src/comet_llm/import_hooks/patcher.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + import inspect from types import ModuleType from typing import TYPE_CHECKING, Any diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py index 7354f36db..858d5ef31 100644 --- a/src/comet_llm/import_hooks/registry.py +++ b/src/comet_llm/import_hooks/registry.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + from typing import Any, Callable, Dict, List from . import callable_extenders, module_extension diff --git a/src/comet_llm/import_hooks/types.py b/src/comet_llm/import_hooks/types.py index ee3e61a23..9de78bda8 100644 --- a/src/comet_llm/import_hooks/types.py +++ b/src/comet_llm/import_hooks/types.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + from typing import Callable # to-do: better description for callbacks signatures diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py index fb07454eb..5d639997c 100644 --- a/src/comet_llm/import_hooks/wrapper.py +++ b/src/comet_llm/import_hooks/wrapper.py @@ -1,3 +1,17 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + import functools from typing import Callable @@ -6,7 +20,9 @@ def wrap(original: Callable, callbacks: callable_extenders.CallableExtenders): def wrapped(*args, **kwargs): - args, kwargs = callback_runners.run_before(callbacks.before, original, *args, **kwargs) + args, kwargs = callback_runners.run_before( + callbacks.before, original, *args, **kwargs + ) try: result = original(*args, **kwargs) except Exception as exception: From d33293fcb3152338dd15452c77c45ed144d0f1e5 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 3 Aug 2023 13:36:07 +0200 Subject: [PATCH 07/20] Fix lint errors --- src/comet_llm/import_hooks/callback_runners.py | 6 +++--- src/comet_llm/import_hooks/module_extension.py | 7 ++++--- src/comet_llm/import_hooks/patcher.py | 6 +++--- src/comet_llm/import_hooks/registry.py | 4 ++-- src/comet_llm/import_hooks/wrapper.py | 6 ++++-- 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/src/comet_llm/import_hooks/callback_runners.py b/src/comet_llm/import_hooks/callback_runners.py index dfcd4afd6..f495af252 100644 --- a/src/comet_llm/import_hooks/callback_runners.py +++ b/src/comet_llm/import_hooks/callback_runners.py @@ -23,7 +23,7 @@ ArgsKwargs = Tuple[Args, Dict[str, Any]] -def run_before( +def run_before( # type: ignore callbacks: List[BeforeCallback], original: Callable, *args, **kwargs ) -> ArgsKwargs: for callback in callbacks: @@ -41,7 +41,7 @@ def run_before( return args, kwargs -def run_after( +def run_after( # type: ignore callbacks: List[AfterCallback], original: Callable, return_value: Any, @@ -59,7 +59,7 @@ def run_after( return return_value -def run_after_exception( +def run_after_exception( # type: ignore callbacks: List[AfterExceptionCallback], original: Callable, exception: Exception, diff --git a/src/comet_llm/import_hooks/module_extension.py b/src/comet_llm/import_hooks/module_extension.py index d61b58f02..3bc0024f5 100644 --- a/src/comet_llm/import_hooks/module_extension.py +++ b/src/comet_llm/import_hooks/module_extension.py @@ -12,7 +12,8 @@ # permission of Comet ML Inc. # ******************************************************* -from typing import Any, Dict, List +from collections.abc import ItemsView, KeysView +from typing import Dict from . import callable_extenders @@ -29,8 +30,8 @@ def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders: return self._callable_names_extenders[callable_name] - def callable_names(self) -> List[str]: + def callable_names(self): # type: ignore return self._callable_names_extenders.keys() - def items(self): + def items(self): # type: ignore return self._callable_names_extenders.items() diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py index ba130b071..70963c69d 100644 --- a/src/comet_llm/import_hooks/patcher.py +++ b/src/comet_llm/import_hooks/patcher.py @@ -50,7 +50,9 @@ def _set_object( setattr(object_to_patch, callable_path[-1], new_object) -def patch(module: ModuleType, module_extension: "module_extension.ModuleExtension"): +def patch( + module: ModuleType, module_extension: "module_extension.ModuleExtension" +) -> None: for callable_name, callable_extenders in module_extension.items(): callable_path = callable_name.split(".") original = _get_object(module, callable_path) @@ -60,5 +62,3 @@ def patch(module: ModuleType, module_extension: "module_extension.ModuleExtensio new_callable = wrapper.wrap(original, callable_extenders) _set_object(module, callable_path, original, new_callable) - - return module diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py index 858d5ef31..27ef7184a 100644 --- a/src/comet_llm/import_hooks/registry.py +++ b/src/comet_llm/import_hooks/registry.py @@ -18,11 +18,11 @@ class Registry: - def __init__(self): + def __init__(self) -> None: self._modules_extensions: Dict[str, module_extension.ModuleExtension] = {} @property - def module_names(self) -> List[str]: + def module_names(self): # type: ignore return self._modules_extensions.keys() def get_extension(self, module_name: str) -> module_extension.ModuleExtension: diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py index 5d639997c..3922ad158 100644 --- a/src/comet_llm/import_hooks/wrapper.py +++ b/src/comet_llm/import_hooks/wrapper.py @@ -18,8 +18,10 @@ from . import callable_extenders, callback_runners -def wrap(original: Callable, callbacks: callable_extenders.CallableExtenders): - def wrapped(*args, **kwargs): +def wrap( + original: Callable, callbacks: callable_extenders.CallableExtenders +) -> Callable: + def wrapped(*args, **kwargs): # type: ignore args, kwargs = callback_runners.run_before( callbacks.before, original, *args, **kwargs ) From 49fb232743c43d3e00afaf148a5830ecd98a1359 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 3 Aug 2023 17:25:21 +0200 Subject: [PATCH 08/20] Add docstrings to registry --- src/comet_llm/import_hooks/__init__.py | 24 +++++++++++--------- src/comet_llm/import_hooks/registry.py | 31 ++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py index 0c8ee0921..0e2aa5a23 100644 --- a/src/comet_llm/import_hooks/__init__.py +++ b/src/comet_llm/import_hooks/__init__.py @@ -16,22 +16,26 @@ from . import finder, registry +# def extract_cpu_count_args(logical): +# return logical -def print_message1(original, *args, **kwargs): - print("Before psutil.cpu_count() call!") +# def print_message1(original, *args, **kwargs): +# logical = extract_cpu_count_args(*args, **kwargs) +# print(f"Before psutil.cpu_count(logical={logical}) call!") -def print_message2(original, return_value, *args, **kwargs): - print("After psutil.cpu_count() call!") +# def print_message2(original, return_value, *args, **kwargs): +# logical = extract_cpu_count_args(*args, **kwargs) +# print(f"After psutil.cpu_count(logical={logical}) call!") -_registry = registry.Registry() +# _registry = registry.Registry() -_registry.register_before("psutil", "cpu_count", print_message1) -_registry.register_after("psutil", "cpu_count", print_message2) +# _registry.register_before("psutil", "cpu_count", print_message1) +# _registry.register_after("psutil", "cpu_count", print_message2) -_registry -_finder = finder.CometFinder(_registry) +# _registry +# _finder = finder.CometFinder(_registry) -_finder.hook_into_import_system() +# _finder.hook_into_import_system() diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py index 27ef7184a..d86709bdd 100644 --- a/src/comet_llm/import_hooks/registry.py +++ b/src/comet_llm/import_hooks/registry.py @@ -39,17 +39,48 @@ def _get_callable_extenders( def register_before( self, module_name: str, callable_name: str, patcher_function: Callable ) -> None: + """ + patcher_function: Callable with the following signature + func( + original, # original callable to patch + *args, + **kwargs + ) + Return value of patcher function is expected to be either None + or [Args,Kwargs] tuple to overwrite original args and kwargs + """ extenders = self._get_callable_extenders(module_name, callable_name) extenders.before.append(patcher_function) def register_after( self, module_name: str, callable_name: str, patcher_function: Callable ) -> None: + """ + patcher_function: Callable with the following signature + func( + original, # original callable to patch + return_value, # value returned by original callable + *args, + **kwargs + ) + Return value of patcher function will overwrite return_value of + patched function if not None + """ extenders = self._get_callable_extenders(module_name, callable_name) extenders.after.append(patcher_function) def register_after_exception( self, module_name: str, callable_name: str, patcher_function: Callable ) -> None: + """ + patcher_function: Callable with the following signature + func( + original, # original callable to patch + exception, # exception thrown from original callable + *args, + **kwargs + ) + Expected to return None. + """ extenders = self._get_callable_extenders(module_name, callable_name) extenders.after_exception.append(patcher_function) From 22e22545170d50b784148f9f901266596a8efa0c Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 3 Aug 2023 18:06:22 +0200 Subject: [PATCH 09/20] Update docstrings --- src/comet_llm/import_hooks/registry.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py index d86709bdd..fe18a968f 100644 --- a/src/comet_llm/import_hooks/registry.py +++ b/src/comet_llm/import_hooks/registry.py @@ -42,10 +42,12 @@ def register_before( """ patcher_function: Callable with the following signature func( - original, # original callable to patch + original, *args, **kwargs ) + original - original callable to patch + Return value of patcher function is expected to be either None or [Args,Kwargs] tuple to overwrite original args and kwargs """ @@ -58,11 +60,14 @@ def register_after( """ patcher_function: Callable with the following signature func( - original, # original callable to patch - return_value, # value returned by original callable + original, + return_value, *args, **kwargs ) + original - original callable to patch + return_value - value returned by original callable + Return value of patcher function will overwrite return_value of patched function if not None """ @@ -75,11 +80,14 @@ def register_after_exception( """ patcher_function: Callable with the following signature func( - original, # original callable to patch - exception, # exception thrown from original callable + original, + exception, *args, **kwargs ) + original - original callable to patch + exception - exception thrown from original callable + Expected to return None. """ extenders = self._get_callable_extenders(module_name, callable_name) From 6987aeeebfd33cff147d31f35f14ef1f7889ab29 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Thu, 3 Aug 2023 18:36:48 +0200 Subject: [PATCH 10/20] Fix lint errors --- src/comet_llm/import_hooks/__init__.py | 28 ------ src/comet_llm/import_hooks/registry.py | 2 +- tests/unit/import_hooks/__init__.py | 0 .../import_hooks/fake_package/__init__.py | 0 .../import_hooks/fake_package/fake_module.py | 44 +++++++++ tests/unit/import_hooks/test_import_hooks.py | 89 +++++++++++++++++++ 6 files changed, 134 insertions(+), 29 deletions(-) create mode 100644 tests/unit/import_hooks/__init__.py create mode 100644 tests/unit/import_hooks/fake_package/__init__.py create mode 100644 tests/unit/import_hooks/fake_package/fake_module.py create mode 100644 tests/unit/import_hooks/test_import_hooks.py diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py index 0e2aa5a23..95c64cfdf 100644 --- a/src/comet_llm/import_hooks/__init__.py +++ b/src/comet_llm/import_hooks/__init__.py @@ -11,31 +11,3 @@ # This file can not be copied and/or distributed without the express # permission of Comet ML Inc. # ******************************************************* - -# type: ignore - -from . import finder, registry - -# def extract_cpu_count_args(logical): -# return logical - -# def print_message1(original, *args, **kwargs): -# logical = extract_cpu_count_args(*args, **kwargs) -# print(f"Before psutil.cpu_count(logical={logical}) call!") - - -# def print_message2(original, return_value, *args, **kwargs): -# logical = extract_cpu_count_args(*args, **kwargs) -# print(f"After psutil.cpu_count(logical={logical}) call!") - - -# _registry = registry.Registry() - -# _registry.register_before("psutil", "cpu_count", print_message1) -# _registry.register_after("psutil", "cpu_count", print_message2) - -# _registry -# _finder = finder.CometFinder(_registry) - - -# _finder.hook_into_import_system() diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py index fe18a968f..e268fb9d0 100644 --- a/src/comet_llm/import_hooks/registry.py +++ b/src/comet_llm/import_hooks/registry.py @@ -47,7 +47,7 @@ def register_before( **kwargs ) original - original callable to patch - + Return value of patcher function is expected to be either None or [Args,Kwargs] tuple to overwrite original args and kwargs """ diff --git a/tests/unit/import_hooks/__init__.py b/tests/unit/import_hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/import_hooks/fake_package/__init__.py b/tests/unit/import_hooks/fake_package/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/import_hooks/fake_package/fake_module.py b/tests/unit/import_hooks/fake_package/fake_module.py new file mode 100644 index 000000000..f53c94416 --- /dev/null +++ b/tests/unit/import_hooks/fake_package/fake_module.py @@ -0,0 +1,44 @@ +# Fake module + +from unittest.mock import Mock + +FUNCTION_1_MOCK = Mock(return_value="function-1-return-value") +FUNCTION_2_MOCK = Mock(return_value="function-2-return-value") + +STATIC_METHOD_MOCK = Mock(return_value="static-method-return-value") + + +def function1(*args, **kwargs): + return FUNCTION_1_MOCK(*args, **kwargs) + + +def function2(*args, **kwargs): + return FUNCTION_2_MOCK(*args, **kwargs) + + +class Klass: + clsmethodmock = Mock() + + def __init__(self): + self.mock = Mock() + self.mock2 = Mock() + + def method(self, *args, **kwargs): + return self.mock(*args, **kwargs) + + def method2(self, *args, **kwargs): + return self.mock2(*args, **kwargs) + + @classmethod + def clsmethod(cls, *args, **kwargs): + print("Locals", locals()) + return cls.clsmethodmock(*args, **kwargs) + + @staticmethod + def statikmethod(*args, **kwargs): + return STATIC_METHOD_MOCK(*args, **kwargs) + + +class Child(Klass): + def method(self, *args, **kwargs): + return super(Child, self).method(*args, **kwargs) diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py new file mode 100644 index 000000000..d2e3dd5f9 --- /dev/null +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -0,0 +1,89 @@ +from unittest import mock + +import pytest + +from comet_llm.import_hooks import finder, registry + + +@pytest.fixture +def fake_module_path(): + parent_name = '.'.join(__name__.split('.')[:-1]) + + return f"{parent_name}.fake_package.fake_module" + +@pytest.mark.forked +def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock1 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function1", mock1) + + mock2 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function2", mock2) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1() + fake_module.function2() + + # Check function1 + mock1.assert_called_once_with(mock.ANY, "function-1-return-value") + original = mock1.call_args[0][0] + assert original.__name__ == "function1" + + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with() + + # Check function2 + mock2.assert_called_once_with(mock.ANY, "function-2-return-value") + original = mock2.call_args[0][0] + assert original.__name__ == "function2" + + assert original is not fake_module.function2 + + fake_module.FUNCTION_2_MOCK.assert_called_once_with() + + +@pytest.mark.forked +def test_patch_functions_in_module__register_before__without_arguments(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock1 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock1) + + mock2 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function2", mock2) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1() + fake_module.function2() + + # Check function1 + mock1.assert_called_once_with(mock.ANY) + original = mock1.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with() + + # Check function2 + mock1.assert_called_once_with(mock.ANY) + original = mock2.call_args[0][0] + assert original.__name__ == "function2" + assert original is not fake_module.function2 + + fake_module.FUNCTION_2_MOCK.assert_called_once_with() From d2f19e38cb47fc8867886e81122448198a1607a1 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 00:24:00 +0200 Subject: [PATCH 11/20] Add new tests for patching functions --- .../import_hooks/fake_package/fake_module.py | 5 + tests/unit/import_hooks/test_import_hooks.py | 293 +++++++++++++++++- 2 files changed, 281 insertions(+), 17 deletions(-) diff --git a/tests/unit/import_hooks/fake_package/fake_module.py b/tests/unit/import_hooks/fake_package/fake_module.py index f53c94416..3014ef751 100644 --- a/tests/unit/import_hooks/fake_package/fake_module.py +++ b/tests/unit/import_hooks/fake_package/fake_module.py @@ -4,6 +4,7 @@ FUNCTION_1_MOCK = Mock(return_value="function-1-return-value") FUNCTION_2_MOCK = Mock(return_value="function-2-return-value") +FUNCTION_3_MOCK = Mock() STATIC_METHOD_MOCK = Mock(return_value="static-method-return-value") @@ -15,6 +16,10 @@ def function1(*args, **kwargs): def function2(*args, **kwargs): return FUNCTION_2_MOCK(*args, **kwargs) +def function3(*args, **kwargs): + FUNCTION_3_MOCK(*args, **kwargs) + raise Exception("raising-function-exception-message") + class Klass: clsmethodmock = Mock() diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py index d2e3dd5f9..946afdfad 100644 --- a/tests/unit/import_hooks/test_import_hooks.py +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -11,16 +11,19 @@ def fake_module_path(): return f"{parent_name}.fake_package.fake_module" +ORIGINAL = mock.ANY +EXCEPTION = mock.ANY + @pytest.mark.forked def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): extensions_registry = registry.Registry() # Prepare - mock1 = mock.Mock() - extensions_registry.register_after(fake_module_path, "function1", mock1) + mock_callback1 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function1", mock_callback1) - mock2 = mock.Mock() - extensions_registry.register_after(fake_module_path, "function2", mock2) + mock_callback2 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function2", mock_callback2) comet_finder = finder.CometFinder(extensions_registry) comet_finder.hook_into_import_system() @@ -33,8 +36,8 @@ def test_patch_functions_in_module__register_after__without_arguments(fake_modul fake_module.function2() # Check function1 - mock1.assert_called_once_with(mock.ANY, "function-1-return-value") - original = mock1.call_args[0][0] + mock_callback1.assert_called_once_with(ORIGINAL, "function-1-return-value") + original = mock_callback1.call_args[0][0] assert original.__name__ == "function1" assert original is not fake_module.function1 @@ -42,10 +45,9 @@ def test_patch_functions_in_module__register_after__without_arguments(fake_modul fake_module.FUNCTION_1_MOCK.assert_called_once_with() # Check function2 - mock2.assert_called_once_with(mock.ANY, "function-2-return-value") - original = mock2.call_args[0][0] + mock_callback2.assert_called_once_with(ORIGINAL, "function-2-return-value") + original = mock_callback2.call_args[0][0] assert original.__name__ == "function2" - assert original is not fake_module.function2 fake_module.FUNCTION_2_MOCK.assert_called_once_with() @@ -56,11 +58,11 @@ def test_patch_functions_in_module__register_before__without_arguments(fake_modu extensions_registry = registry.Registry() # Prepare - mock1 = mock.Mock() - extensions_registry.register_before(fake_module_path, "function1", mock1) + mock_callback1 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock_callback1) - mock2 = mock.Mock() - extensions_registry.register_before(fake_module_path, "function2", mock2) + mock_callback2 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function2", mock_callback2) comet_finder = finder.CometFinder(extensions_registry) comet_finder.hook_into_import_system() @@ -73,17 +75,274 @@ def test_patch_functions_in_module__register_before__without_arguments(fake_modu fake_module.function2() # Check function1 - mock1.assert_called_once_with(mock.ANY) - original = mock1.call_args[0][0] + mock_callback1.assert_called_once_with(ORIGINAL) + original = mock_callback1.call_args[0][0] assert original.__name__ == "function1" assert original is not fake_module.function1 fake_module.FUNCTION_1_MOCK.assert_called_once_with() # Check function2 - mock1.assert_called_once_with(mock.ANY) - original = mock2.call_args[0][0] + mock_callback1.assert_called_once_with(ORIGINAL) + original = mock_callback2.call_args[0][0] assert original.__name__ == "function2" assert original is not fake_module.function2 fake_module.FUNCTION_2_MOCK.assert_called_once_with() + + +@pytest.mark.forked +def test_patch_functions_in_module__register_before_and_after__without_arguments(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback1 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock_callback1) + + mock_callback2 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function2", mock_callback2) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1() + fake_module.function2() + + # Check function1 + mock_callback1.assert_called_once_with(ORIGINAL) + original = mock_callback1.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with() + + # Check function2 + mock_callback2.assert_called_once_with(ORIGINAL, "function-2-return-value") + original = mock_callback2.call_args[0][0] + assert original.__name__ == "function2" + assert original is not fake_module.function2 + + fake_module.FUNCTION_2_MOCK.assert_called_once_with() + + +@pytest.mark.forked +def test_patch_function_in_module__register_before__happyflow(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_before__callback_changes_input_arguments(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock( + return_value=( + ("new-arg-1", "new-arg-2"), + {"kwarg1":"new-kwarg-1", "kwarg2":"new-kwarg-2"} + ) + ) + extensions_registry.register_before(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("new-arg-1", "new-arg-2", kwarg1="new-kwarg-1", kwarg2="new-kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_before__error_in_callback__original_function_worked(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock(side_effect=Exception) + extensions_registry.register_before(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + + +@pytest.mark.forked +def test_patch_function_in_module__register_after__happyflow(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock() + extensions_registry.register_after(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_after__error_in_callback__original_function_worked(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock(side_effect=Exception) + extensions_registry.register_after(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_after__callback_changes_return_value(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock(return_value="new-return-value") + extensions_registry.register_after(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + result = fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + assert result == "new-return-value" + + +@pytest.mark.forked +def test_patch_raising_function_in_module__register_after_exception__happyflow(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock() + extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + with pytest.raises(Exception): + fake_module.function3("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + ORIGINAL = mock.ANY + EXCEPTION = mock.ANY + mock_callback.assert_called_once_with(ORIGINAL, EXCEPTION, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function3" + assert original is not fake_module.function3 + + fake_module.FUNCTION_3_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_raising_function_in_module__register_after_exception__error_in_callback__original_function_worked(fake_module_path): + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock(side_effect=Exception) + extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + with pytest.raises(Exception): + fake_module.function3("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, EXCEPTION, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function3" + assert original is not fake_module.function3 + + fake_module.FUNCTION_3_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") From 599dd4aefc5bfbc7758a5f748e1b12f77f98cca9 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 00:55:24 +0200 Subject: [PATCH 12/20] Add tests --- .../import_hooks/module_extension.py | 16 +++------ tests/unit/import_hooks/test_import_hooks.py | 33 +++++++++++++++++++ 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/src/comet_llm/import_hooks/module_extension.py b/src/comet_llm/import_hooks/module_extension.py index 3bc0024f5..1d109e8ee 100644 --- a/src/comet_llm/import_hooks/module_extension.py +++ b/src/comet_llm/import_hooks/module_extension.py @@ -12,7 +12,6 @@ # permission of Comet ML Inc. # ******************************************************* -from collections.abc import ItemsView, KeysView from typing import Dict from . import callable_extenders @@ -20,18 +19,13 @@ class ModuleExtension: def __init__(self) -> None: - self._callable_names_extenders: Dict[ - str, callable_extenders.CallableExtenders - ] = {} + self._callables_extenders: Dict[str, callable_extenders.CallableExtenders] = {} def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders: - if callable_name not in self._callable_names_extenders: - self._callable_names_extenders[callable_name] = callable_extenders.get() + if callable_name not in self._callables_extenders: + self._callables_extenders[callable_name] = callable_extenders.get() - return self._callable_names_extenders[callable_name] - - def callable_names(self): # type: ignore - return self._callable_names_extenders.keys() + return self._callables_extenders[callable_name] def items(self): # type: ignore - return self._callable_names_extenders.items() + return self._callables_extenders.items() diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py index 946afdfad..c9b1075e6 100644 --- a/tests/unit/import_hooks/test_import_hooks.py +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -13,6 +13,7 @@ def fake_module_path(): ORIGINAL = mock.ANY EXCEPTION = mock.ANY +SELF = mock.ANY @pytest.mark.forked def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): @@ -346,3 +347,35 @@ def test_patch_raising_function_in_module__register_after_exception__error_in_ca assert original is not fake_module.function3 fake_module.FUNCTION_3_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +def test_patch_method_in_module__happyflow(fake_module_path): + # Prepare + extensions_registry = registry.Registry() + + # Prepare + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.method", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + instance = fake_module.Klass() + + # Set the return + instance.mock.return_value = mock.sentinel.METHOD + + assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, instance, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "method" + + assert original is not fake_module.Klass.method + + instance.mock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") From 028c10d1596190436ec7267da0685a1a0f7e872c Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 01:15:05 +0200 Subject: [PATCH 13/20] Add validate.py and tests for it --- .../import_hooks/callback_runners.py | 21 ++--------------- src/comet_llm/import_hooks/validate.py | 18 +++++++++++++++ tests/unit/import_hooks/test_validate.py | 23 +++++++++++++++++++ 3 files changed, 43 insertions(+), 19 deletions(-) create mode 100644 src/comet_llm/import_hooks/validate.py create mode 100644 tests/unit/import_hooks/test_validate.py diff --git a/src/comet_llm/import_hooks/callback_runners.py b/src/comet_llm/import_hooks/callback_runners.py index f495af252..3bd784f2a 100644 --- a/src/comet_llm/import_hooks/callback_runners.py +++ b/src/comet_llm/import_hooks/callback_runners.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List, Tuple, Union from .types import AfterCallback, AfterExceptionCallback, BeforeCallback +from . import validate LOGGER = logging.getLogger(__name__) @@ -30,7 +31,7 @@ def run_before( # type: ignore try: callback_return = callback(original, *args, **kwargs) - if _valid_new_args_kwargs(callback_return): + if validate.args_kwargs(callback_return): LOGGER.debug("New args %r", callback_return) args, kwargs = callback_return except Exception: @@ -73,21 +74,3 @@ def run_after_exception( # type: ignore LOGGER.debug( "Exception calling after-exception callback %r", callback, exc_info=True ) - - -def _valid_new_args_kwargs(callback_return: Any) -> bool: - if callback_return is None: - return False - - try: - args, kwargs = callback_return - except (ValueError, TypeError): - return False - - if not isinstance(args, (list, tuple)): - return False - - if not isinstance(kwargs, dict): - return False - - return True diff --git a/src/comet_llm/import_hooks/validate.py b/src/comet_llm/import_hooks/validate.py new file mode 100644 index 000000000..1290c38f7 --- /dev/null +++ b/src/comet_llm/import_hooks/validate.py @@ -0,0 +1,18 @@ +from typing import Any + +def args_kwargs(obj: Any) -> bool: + if obj is None: + return False + + try: + args, kwargs = obj + except (ValueError, TypeError): + return False + + if not isinstance(args, (list, tuple)): + return False + + if not isinstance(kwargs, dict): + return False + + return True diff --git a/tests/unit/import_hooks/test_validate.py b/tests/unit/import_hooks/test_validate.py new file mode 100644 index 000000000..e49edc071 --- /dev/null +++ b/tests/unit/import_hooks/test_validate.py @@ -0,0 +1,23 @@ +from comet_llm.import_hooks import validate + +def test_args_kwargs__happyflow(): + args_kwargs = ([1], {"foo": "bar"}) + assert validate.args_kwargs(args_kwargs) is True + + +def test_args_kwargs__input_is_None__return_False(): + assert validate.args_kwargs(None) is False + + +def test_args_kwargs__input_is_not_tuple_or_list_of_length_2__return_False(): + assert validate.args_kwargs(42) is False + + +def test_args_kwargs__args_cant_be_parsed__return_False(): + args_kwargs = (42, {}) + assert validate.args_kwargs(args_kwargs) is False + + +def test_args_kwargs__kwargs_cant_be_parsed__return_False(): + args_kwargs = ([1], 42) + assert validate.args_kwargs(args_kwargs) is False \ No newline at end of file From 183fa29aecbc8f998ad4b0dfbe60005fb222d2f9 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 01:15:51 +0200 Subject: [PATCH 14/20] Fix lint errors --- src/comet_llm/import_hooks/callback_runners.py | 2 +- src/comet_llm/import_hooks/types.py | 2 +- src/comet_llm/import_hooks/validate.py | 15 +++++++++++++++ tests/unit/import_hooks/test_validate.py | 1 + 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/comet_llm/import_hooks/callback_runners.py b/src/comet_llm/import_hooks/callback_runners.py index 3bd784f2a..155537a04 100644 --- a/src/comet_llm/import_hooks/callback_runners.py +++ b/src/comet_llm/import_hooks/callback_runners.py @@ -15,8 +15,8 @@ import logging from typing import Any, Callable, Dict, List, Tuple, Union -from .types import AfterCallback, AfterExceptionCallback, BeforeCallback from . import validate +from .types import AfterCallback, AfterExceptionCallback, BeforeCallback LOGGER = logging.getLogger(__name__) diff --git a/src/comet_llm/import_hooks/types.py b/src/comet_llm/import_hooks/types.py index 9de78bda8..612f8cb68 100644 --- a/src/comet_llm/import_hooks/types.py +++ b/src/comet_llm/import_hooks/types.py @@ -14,7 +14,7 @@ from typing import Callable -# to-do: better description for callbacks signatures +# to-do: better description for callbacks signatures? BeforeCallback = Callable AfterCallback = Callable diff --git a/src/comet_llm/import_hooks/validate.py b/src/comet_llm/import_hooks/validate.py index 1290c38f7..936306354 100644 --- a/src/comet_llm/import_hooks/validate.py +++ b/src/comet_llm/import_hooks/validate.py @@ -1,5 +1,20 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + from typing import Any + def args_kwargs(obj: Any) -> bool: if obj is None: return False diff --git a/tests/unit/import_hooks/test_validate.py b/tests/unit/import_hooks/test_validate.py index e49edc071..dfbfeebf3 100644 --- a/tests/unit/import_hooks/test_validate.py +++ b/tests/unit/import_hooks/test_validate.py @@ -1,5 +1,6 @@ from comet_llm.import_hooks import validate + def test_args_kwargs__happyflow(): args_kwargs = ([1], {"foo": "bar"}) assert validate.args_kwargs(args_kwargs) is True From a6147b3b6970a02cfbfcaadf2946bef0d41f5a47 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 01:33:45 +0200 Subject: [PATCH 15/20] Add tests --- tests/unit/import_hooks/test_import_hooks.py | 50 +++++++++++--------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py index c9b1075e6..1ffc3b6d9 100644 --- a/tests/unit/import_hooks/test_import_hooks.py +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -13,13 +13,27 @@ def fake_module_path(): ORIGINAL = mock.ANY EXCEPTION = mock.ANY -SELF = mock.ANY + @pytest.mark.forked -def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): +def test_patch_function_in_module__name_to_patch_not_found__no_failure(fake_module_path): + # Prepare extensions_registry = registry.Registry() + extensions_registry.register_after(fake_module_path, "non_existing_function", "any-callback") + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + # Import + from .fake_package import fake_module + + assert hasattr(fake_module, "non_existing_function") is False + +@pytest.mark.forked +def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): # Prepare + extensions_registry = registry.Registry() + mock_callback1 = mock.Mock() extensions_registry.register_after(fake_module_path, "function1", mock_callback1) @@ -56,9 +70,9 @@ def test_patch_functions_in_module__register_after__without_arguments(fake_modul @pytest.mark.forked def test_patch_functions_in_module__register_before__without_arguments(fake_module_path): + # Prepare extensions_registry = registry.Registry() - # Prepare mock_callback1 = mock.Mock() extensions_registry.register_before(fake_module_path, "function1", mock_callback1) @@ -94,9 +108,9 @@ def test_patch_functions_in_module__register_before__without_arguments(fake_modu @pytest.mark.forked def test_patch_functions_in_module__register_before_and_after__without_arguments(fake_module_path): + # Prepare extensions_registry = registry.Registry() - # Prepare mock_callback1 = mock.Mock() extensions_registry.register_before(fake_module_path, "function1", mock_callback1) @@ -132,9 +146,8 @@ def test_patch_functions_in_module__register_before_and_after__without_arguments @pytest.mark.forked def test_patch_function_in_module__register_before__happyflow(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_before(fake_module_path, "function1", mock_callback) @@ -158,9 +171,8 @@ def test_patch_function_in_module__register_before__happyflow(fake_module_path): @pytest.mark.forked def test_patch_function_in_module__register_before__callback_changes_input_arguments(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock( return_value=( ("new-arg-1", "new-arg-2"), @@ -189,9 +201,8 @@ def test_patch_function_in_module__register_before__callback_changes_input_argum @pytest.mark.forked def test_patch_function_in_module__register_before__error_in_callback__original_function_worked(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock(side_effect=Exception) extensions_registry.register_before(fake_module_path, "function1", mock_callback) @@ -216,9 +227,8 @@ def test_patch_function_in_module__register_before__error_in_callback__original_ @pytest.mark.forked def test_patch_function_in_module__register_after__happyflow(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_after(fake_module_path, "function1", mock_callback) @@ -242,9 +252,8 @@ def test_patch_function_in_module__register_after__happyflow(fake_module_path): @pytest.mark.forked def test_patch_function_in_module__register_after__error_in_callback__original_function_worked(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock(side_effect=Exception) extensions_registry.register_after(fake_module_path, "function1", mock_callback) @@ -268,9 +277,8 @@ def test_patch_function_in_module__register_after__error_in_callback__original_f @pytest.mark.forked def test_patch_function_in_module__register_after__callback_changes_return_value(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock(return_value="new-return-value") extensions_registry.register_after(fake_module_path, "function1", mock_callback) @@ -295,9 +303,8 @@ def test_patch_function_in_module__register_after__callback_changes_return_value @pytest.mark.forked def test_patch_raising_function_in_module__register_after_exception__happyflow(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) @@ -324,9 +331,8 @@ def test_patch_raising_function_in_module__register_after_exception__happyflow(f @pytest.mark.forked def test_patch_raising_function_in_module__register_after_exception__error_in_callback__original_function_worked(fake_module_path): - extensions_registry = registry.Registry() - # Prepare + extensions_registry = registry.Registry() mock_callback = mock.Mock(side_effect=Exception) extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) @@ -352,8 +358,6 @@ def test_patch_raising_function_in_module__register_after_exception__error_in_ca def test_patch_method_in_module__happyflow(fake_module_path): # Prepare extensions_registry = registry.Registry() - - # Prepare mock_callback = mock.Mock() extensions_registry.register_before(fake_module_path, "Klass.method", mock_callback) @@ -379,3 +383,5 @@ def test_patch_method_in_module__happyflow(fake_module_path): assert original is not fake_module.Klass.method instance.mock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") + + From 471eafa3afef0d7e61e8ecb0c5f2eddbdb807b37 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 14:06:35 +0200 Subject: [PATCH 16/20] Update tests --- src/comet_llm/import_hooks/wrapper.py | 11 ++- .../import_hooks/fake_package/fake_module.py | 8 +- tests/unit/import_hooks/test_import_hooks.py | 95 ++++++++++++++++--- 3 files changed, 95 insertions(+), 19 deletions(-) diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py index 3922ad158..6e5e42b6d 100644 --- a/src/comet_llm/import_hooks/wrapper.py +++ b/src/comet_llm/import_hooks/wrapper.py @@ -12,7 +12,7 @@ # permission of Comet ML Inc. # ******************************************************* -import functools +import functools, inspect from typing import Callable from . import callable_extenders, callback_runners @@ -21,6 +21,8 @@ def wrap( original: Callable, callbacks: callable_extenders.CallableExtenders ) -> Callable: + original = _unbound_if_classmethod(original) + def wrapped(*args, **kwargs): # type: ignore args, kwargs = callback_runners.run_before( callbacks.before, original, *args, **kwargs @@ -45,3 +47,10 @@ def wrapped(*args, **kwargs): # type: ignore setattr(wrapped, attr, getattr(original, attr)) return wrapped + + +def _unbound_if_classmethod(original: Callable) -> Callable: + if hasattr(original, "__self__") and inspect.isclass(original.__self__): + original = original.__func__ + + return original \ No newline at end of file diff --git a/tests/unit/import_hooks/fake_package/fake_module.py b/tests/unit/import_hooks/fake_package/fake_module.py index 3014ef751..8490ff5e3 100644 --- a/tests/unit/import_hooks/fake_package/fake_module.py +++ b/tests/unit/import_hooks/fake_package/fake_module.py @@ -2,11 +2,11 @@ from unittest.mock import Mock -FUNCTION_1_MOCK = Mock(return_value="function-1-return-value") -FUNCTION_2_MOCK = Mock(return_value="function-2-return-value") +FUNCTION_1_MOCK = Mock() +FUNCTION_2_MOCK = Mock() FUNCTION_3_MOCK = Mock() -STATIC_METHOD_MOCK = Mock(return_value="static-method-return-value") +STATIC_METHOD_MOCK = Mock() def function1(*args, **kwargs): @@ -18,7 +18,7 @@ def function2(*args, **kwargs): def function3(*args, **kwargs): FUNCTION_3_MOCK(*args, **kwargs) - raise Exception("raising-function-exception-message") + raise Exception() class Klass: diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py index 1ffc3b6d9..40fec3650 100644 --- a/tests/unit/import_hooks/test_import_hooks.py +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -47,6 +47,8 @@ def test_patch_functions_in_module__register_after__without_arguments(fake_modul from .fake_package import fake_module # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + fake_module.FUNCTION_2_MOCK.return_value = "function-2-return-value" fake_module.function1() fake_module.function2() @@ -86,8 +88,10 @@ def test_patch_functions_in_module__register_before__without_arguments(fake_modu from .fake_package import fake_module # Call - fake_module.function1() - fake_module.function2() + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + fake_module.FUNCTION_2_MOCK.return_value = "function-2-return-value" + assert fake_module.function1() == "function-1-return-value" + assert fake_module.function2() == "function-2-return-value" # Check function1 mock_callback1.assert_called_once_with(ORIGINAL) @@ -114,7 +118,7 @@ def test_patch_functions_in_module__register_before_and_after__without_arguments mock_callback1 = mock.Mock() extensions_registry.register_before(fake_module_path, "function1", mock_callback1) - mock_callback2 = mock.Mock() + mock_callback2 = mock.Mock(return_value=None) extensions_registry.register_after(fake_module_path, "function2", mock_callback2) comet_finder = finder.CometFinder(extensions_registry) @@ -124,8 +128,10 @@ def test_patch_functions_in_module__register_before_and_after__without_arguments from .fake_package import fake_module # Call - fake_module.function1() - fake_module.function2() + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + fake_module.FUNCTION_2_MOCK.return_value = "function-2-return-value" + assert fake_module.function1() == "function-1-return-value" + assert fake_module.function2() == "function-2-return-value" # Check function1 mock_callback1.assert_called_once_with(ORIGINAL) @@ -158,7 +164,8 @@ def test_patch_function_in_module__register_before__happyflow(fake_module_path): from .fake_package import fake_module # Call - fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" # Check mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") @@ -188,7 +195,8 @@ def test_patch_function_in_module__register_before__callback_changes_input_argum from .fake_package import fake_module # Call - fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" # Check mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") @@ -213,7 +221,8 @@ def test_patch_function_in_module__register_before__error_in_callback__original_ from .fake_package import fake_module # Call - fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" # Check mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") @@ -229,7 +238,7 @@ def test_patch_function_in_module__register_before__error_in_callback__original_ def test_patch_function_in_module__register_after__happyflow(fake_module_path): # Prepare extensions_registry = registry.Registry() - mock_callback = mock.Mock() + mock_callback = mock.Mock(return_value=None) extensions_registry.register_after(fake_module_path, "function1", mock_callback) comet_finder = finder.CometFinder(extensions_registry) @@ -239,7 +248,8 @@ def test_patch_function_in_module__register_after__happyflow(fake_module_path): from .fake_package import fake_module # Call - fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" # Check mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") @@ -264,7 +274,8 @@ def test_patch_function_in_module__register_after__error_in_callback__original_f from .fake_package import fake_module # Call - fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" # Check mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") @@ -289,7 +300,8 @@ def test_patch_function_in_module__register_after__callback_changes_return_value from .fake_package import fake_module # Call - result = fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "new-return-value" # Check mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") @@ -298,7 +310,6 @@ def test_patch_function_in_module__register_after__callback_changes_return_value assert original is not fake_module.function1 fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") - assert result == "new-return-value" @pytest.mark.forked @@ -355,6 +366,7 @@ def test_patch_raising_function_in_module__register_after_exception__error_in_ca fake_module.FUNCTION_3_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") +@pytest.mark.forked def test_patch_method_in_module__happyflow(fake_module_path): # Prepare extensions_registry = registry.Registry() @@ -373,15 +385,70 @@ def test_patch_method_in_module__happyflow(fake_module_path): # Set the return instance.mock.return_value = mock.sentinel.METHOD + # Call assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD # Check method mock_callback.assert_called_once_with(ORIGINAL, instance, "arg-1", "arg-2", kwarg1="kwarg-1") original = mock_callback.call_args[0][0] assert original.__name__ == "method" - assert original is not fake_module.Klass.method instance.mock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") +@pytest.mark.forked +def test_patch_class_method_in_module__happyflow(fake_module_path): + # Prepare + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.clsmethod", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Set the return + fake_module.Klass.clsmethodmock.return_value = mock.sentinel.METHOD + + # Call + assert fake_module.Klass.clsmethod("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, fake_module.Klass, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "clsmethod" + + assert original is not fake_module.Klass.clsmethod + + fake_module.Klass.clsmethodmock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") + + +@pytest.mark.forked +def test_patch_static_method_in_module__happyflow(fake_module_path): + # Prepare + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.statikmethod", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Set the return + fake_module.STATIC_METHOD_MOCK.return_value = mock.sentinel.METHOD + + # Call + assert fake_module.Klass.statikmethod("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "statikmethod" + assert original is not fake_module.Klass.statikmethod + + fake_module.STATIC_METHOD_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") \ No newline at end of file From 19d88c34a4399bffabb0d6fa53d56f2d482cc6e6 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 14:12:06 +0200 Subject: [PATCH 17/20] Update tests --- src/comet_llm/import_hooks/wrapper.py | 12 +++++++----- tests/unit/import_hooks/test_import_hooks.py | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py index 6e5e42b6d..00950738d 100644 --- a/src/comet_llm/import_hooks/wrapper.py +++ b/src/comet_llm/import_hooks/wrapper.py @@ -12,8 +12,9 @@ # permission of Comet ML Inc. # ******************************************************* -import functools, inspect -from typing import Callable +import functools +import inspect +from typing import Any, Callable from . import callable_extenders, callback_runners @@ -22,7 +23,7 @@ def wrap( original: Callable, callbacks: callable_extenders.CallableExtenders ) -> Callable: original = _unbound_if_classmethod(original) - + def wrapped(*args, **kwargs): # type: ignore args, kwargs = callback_runners.run_before( callbacks.before, original, *args, **kwargs @@ -51,6 +52,7 @@ def wrapped(*args, **kwargs): # type: ignore def _unbound_if_classmethod(original: Callable) -> Callable: if hasattr(original, "__self__") and inspect.isclass(original.__self__): - original = original.__func__ + # when original is classmethod, mypy doesn't consider it as a callable. + original = original.__func__ # type: ignore - return original \ No newline at end of file + return original diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py index 40fec3650..108012ed2 100644 --- a/tests/unit/import_hooks/test_import_hooks.py +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -33,7 +33,7 @@ def test_patch_function_in_module__name_to_patch_not_found__no_failure(fake_modu def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): # Prepare extensions_registry = registry.Registry() - + mock_callback1 = mock.Mock() extensions_registry.register_after(fake_module_path, "function1", mock_callback1) From f913fa677478eb056610a2a13cef59e347165785 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 14:26:16 +0200 Subject: [PATCH 18/20] Add test for child class method, refactor tests a bit --- src/comet_llm/config.py | 2 +- src/comet_llm/import_hooks/module_loader.py | 4 +- src/comet_llm/import_hooks/patcher.py | 2 +- tests/unit/import_hooks/test_import_hooks.py | 76 ++++++++++++++------ 4 files changed, 57 insertions(+), 27 deletions(-) diff --git a/src/comet_llm/config.py b/src/comet_llm/config.py index 33adb53e6..1c09bc3c0 100644 --- a/src/comet_llm/config.py +++ b/src/comet_llm/config.py @@ -30,7 +30,7 @@ def _muted_import_comet_ml() -> ModuleType: comet_ml = _muted_import_comet_ml() -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import comet_ml.config as comet_ml_config diff --git a/src/comet_llm/import_hooks/module_loader.py b/src/comet_llm/import_hooks/module_loader.py index eabc25e9b..d3871f62e 100644 --- a/src/comet_llm/import_hooks/module_loader.py +++ b/src/comet_llm/import_hooks/module_loader.py @@ -14,11 +14,11 @@ import importlib.abc from types import ModuleType -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Optional from . import module_extension, patcher -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from importlib import machinery diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py index 70963c69d..5cc4e5a49 100644 --- a/src/comet_llm/import_hooks/patcher.py +++ b/src/comet_llm/import_hooks/patcher.py @@ -18,7 +18,7 @@ from . import wrapper -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover from . import module_extension # _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes. diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py index 108012ed2..b8071522f 100644 --- a/tests/unit/import_hooks/test_import_hooks.py +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -17,7 +17,7 @@ def fake_module_path(): @pytest.mark.forked def test_patch_function_in_module__name_to_patch_not_found__no_failure(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() extensions_registry.register_after(fake_module_path, "non_existing_function", "any-callback") @@ -31,7 +31,7 @@ def test_patch_function_in_module__name_to_patch_not_found__no_failure(fake_modu @pytest.mark.forked def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback1 = mock.Mock() @@ -72,7 +72,7 @@ def test_patch_functions_in_module__register_after__without_arguments(fake_modul @pytest.mark.forked def test_patch_functions_in_module__register_before__without_arguments(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback1 = mock.Mock() @@ -112,7 +112,7 @@ def test_patch_functions_in_module__register_before__without_arguments(fake_modu @pytest.mark.forked def test_patch_functions_in_module__register_before_and_after__without_arguments(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback1 = mock.Mock() @@ -152,7 +152,7 @@ def test_patch_functions_in_module__register_before_and_after__without_arguments @pytest.mark.forked def test_patch_function_in_module__register_before__happyflow(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_before(fake_module_path, "function1", mock_callback) @@ -178,7 +178,7 @@ def test_patch_function_in_module__register_before__happyflow(fake_module_path): @pytest.mark.forked def test_patch_function_in_module__register_before__callback_changes_input_arguments(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock( return_value=( @@ -209,7 +209,7 @@ def test_patch_function_in_module__register_before__callback_changes_input_argum @pytest.mark.forked def test_patch_function_in_module__register_before__error_in_callback__original_function_worked(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock(side_effect=Exception) extensions_registry.register_before(fake_module_path, "function1", mock_callback) @@ -236,7 +236,7 @@ def test_patch_function_in_module__register_before__error_in_callback__original_ @pytest.mark.forked def test_patch_function_in_module__register_after__happyflow(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock(return_value=None) extensions_registry.register_after(fake_module_path, "function1", mock_callback) @@ -262,7 +262,7 @@ def test_patch_function_in_module__register_after__happyflow(fake_module_path): @pytest.mark.forked def test_patch_function_in_module__register_after__error_in_callback__original_function_worked(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock(side_effect=Exception) extensions_registry.register_after(fake_module_path, "function1", mock_callback) @@ -288,7 +288,7 @@ def test_patch_function_in_module__register_after__error_in_callback__original_f @pytest.mark.forked def test_patch_function_in_module__register_after__callback_changes_return_value(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock(return_value="new-return-value") extensions_registry.register_after(fake_module_path, "function1", mock_callback) @@ -314,7 +314,7 @@ def test_patch_function_in_module__register_after__callback_changes_return_value @pytest.mark.forked def test_patch_raising_function_in_module__register_after_exception__happyflow(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) @@ -342,7 +342,7 @@ def test_patch_raising_function_in_module__register_after_exception__happyflow(f @pytest.mark.forked def test_patch_raising_function_in_module__register_after_exception__error_in_callback__original_function_worked(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock(side_effect=Exception) extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) @@ -368,7 +368,7 @@ def test_patch_raising_function_in_module__register_after_exception__error_in_ca @pytest.mark.forked def test_patch_method_in_module__happyflow(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_before(fake_module_path, "Klass.method", mock_callback) @@ -383,10 +383,9 @@ def test_patch_method_in_module__happyflow(fake_module_path): instance = fake_module.Klass() # Set the return - instance.mock.return_value = mock.sentinel.METHOD - + instance.mock.return_value = "method-return-value" # Call - assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD + assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == "method-return-value" # Check method mock_callback.assert_called_once_with(ORIGINAL, instance, "arg-1", "arg-2", kwarg1="kwarg-1") @@ -399,7 +398,7 @@ def test_patch_method_in_module__happyflow(fake_module_path): @pytest.mark.forked def test_patch_class_method_in_module__happyflow(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_before(fake_module_path, "Klass.clsmethod", mock_callback) @@ -411,10 +410,10 @@ def test_patch_class_method_in_module__happyflow(fake_module_path): from .fake_package import fake_module # Set the return - fake_module.Klass.clsmethodmock.return_value = mock.sentinel.METHOD + fake_module.Klass.clsmethodmock.return_value = "classmethod-return-value" # Call - assert fake_module.Klass.clsmethod("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD + assert fake_module.Klass.clsmethod("arg-1", "arg-2", kwarg1="kwarg-1") == "classmethod-return-value" # Check method mock_callback.assert_called_once_with(ORIGINAL, fake_module.Klass, "arg-1", "arg-2", kwarg1="kwarg-1") @@ -428,7 +427,7 @@ def test_patch_class_method_in_module__happyflow(fake_module_path): @pytest.mark.forked def test_patch_static_method_in_module__happyflow(fake_module_path): - # Prepare + # Prepare hooks extensions_registry = registry.Registry() mock_callback = mock.Mock() extensions_registry.register_before(fake_module_path, "Klass.statikmethod", mock_callback) @@ -440,10 +439,10 @@ def test_patch_static_method_in_module__happyflow(fake_module_path): from .fake_package import fake_module # Set the return - fake_module.STATIC_METHOD_MOCK.return_value = mock.sentinel.METHOD + fake_module.STATIC_METHOD_MOCK.return_value = "staticmethod-return-value" # Call - assert fake_module.Klass.statikmethod("arg-1", "arg-2", kwarg1="kwarg-1") == mock.sentinel.METHOD + assert fake_module.Klass.statikmethod("arg-1", "arg-2", kwarg1="kwarg-1") == "staticmethod-return-value" # Check method mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1") @@ -451,4 +450,35 @@ def test_patch_static_method_in_module__happyflow(fake_module_path): assert original.__name__ == "statikmethod" assert original is not fake_module.Klass.statikmethod - fake_module.STATIC_METHOD_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") \ No newline at end of file + fake_module.STATIC_METHOD_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") + + +@pytest.mark.forked +def test_patch_subclass_method_in_module__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.method", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + instance = fake_module.Child() + + # Set the return + instance.mock.return_value = "method-return-value" + + # Call + assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == "method-return-value" + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, instance, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "method" + assert original is not fake_module.Klass.method + + instance.mock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") \ No newline at end of file From 287bc4867c848fe7c0b9ddd307c653740a1fb193 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 14:36:25 +0200 Subject: [PATCH 19/20] Remove currently odd import --- src/comet_llm/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/comet_llm/__init__.py b/src/comet_llm/__init__.py index e80c7a042..203873084 100644 --- a/src/comet_llm/__init__.py +++ b/src/comet_llm/__init__.py @@ -12,7 +12,6 @@ # permission of Comet ML Inc. # ******************************************************* -from . import import_hooks # keep it the first one from . import app, logging from .api import log_prompt From a0d84a939aa89eecef712c2cb68f45fee7ed4f13 Mon Sep 17 00:00:00 2001 From: Alexander Kuzmik Date: Fri, 4 Aug 2023 15:04:27 +0200 Subject: [PATCH 20/20] Replace wraps workaround with actual wraps --- src/comet_llm/import_hooks/wrapper.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py index 00950738d..e1b261304 100644 --- a/src/comet_llm/import_hooks/wrapper.py +++ b/src/comet_llm/import_hooks/wrapper.py @@ -14,7 +14,7 @@ import functools import inspect -from typing import Any, Callable +from typing import Callable from . import callable_extenders, callback_runners @@ -24,6 +24,7 @@ def wrap( ) -> Callable: original = _unbound_if_classmethod(original) + @functools.wraps(original) def wrapped(*args, **kwargs): # type: ignore args, kwargs = callback_runners.run_before( callbacks.before, original, *args, **kwargs @@ -42,11 +43,6 @@ def wrapped(*args, **kwargs): # type: ignore return result - # Simulate functools.wraps behavior but make it working with mocks - for attr in functools.WRAPPER_ASSIGNMENTS: - if hasattr(original, attr): - setattr(wrapped, attr, getattr(original, attr)) - return wrapped