diff --git a/src/comet_llm/config.py b/src/comet_llm/config.py index 33adb53e6..1c09bc3c0 100644 --- a/src/comet_llm/config.py +++ b/src/comet_llm/config.py @@ -30,7 +30,7 @@ def _muted_import_comet_ml() -> ModuleType: comet_ml = _muted_import_comet_ml() -if TYPE_CHECKING: +if TYPE_CHECKING: # pragma: no cover import comet_ml.config as comet_ml_config diff --git a/src/comet_llm/import_hooks/__init__.py b/src/comet_llm/import_hooks/__init__.py new file mode 100644 index 000000000..95c64cfdf --- /dev/null +++ b/src/comet_llm/import_hooks/__init__.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* diff --git a/src/comet_llm/import_hooks/callable_extenders.py b/src/comet_llm/import_hooks/callable_extenders.py new file mode 100644 index 000000000..eab786dab --- /dev/null +++ b/src/comet_llm/import_hooks/callable_extenders.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import dataclasses +from typing import List + +from .types import AfterCallback, AfterExceptionCallback, BeforeCallback + + +@dataclasses.dataclass +class CallableExtenders: + before: List[BeforeCallback] + after: List[AfterCallback] + after_exception: List[AfterExceptionCallback] + + +def get() -> CallableExtenders: + return CallableExtenders([], [], []) diff --git a/src/comet_llm/import_hooks/callback_runners.py b/src/comet_llm/import_hooks/callback_runners.py new file mode 100644 index 000000000..155537a04 --- /dev/null +++ b/src/comet_llm/import_hooks/callback_runners.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import logging +from typing import Any, Callable, Dict, List, Tuple, Union + +from . import validate +from .types import AfterCallback, AfterExceptionCallback, BeforeCallback + +LOGGER = logging.getLogger(__name__) + +Args = Union[Tuple[Any, ...], List[Any]] +ArgsKwargs = Tuple[Args, Dict[str, Any]] + + +def run_before( # type: ignore + callbacks: List[BeforeCallback], original: Callable, *args, **kwargs +) -> ArgsKwargs: + for callback in callbacks: + try: + callback_return = callback(original, *args, **kwargs) + + if validate.args_kwargs(callback_return): + LOGGER.debug("New args %r", callback_return) + args, kwargs = callback_return + except Exception: + LOGGER.debug( + "Exception calling before callback %r", callback, exc_info=True + ) + + return args, kwargs + + +def run_after( # type: ignore + callbacks: List[AfterCallback], + original: Callable, + return_value: Any, + *args, + **kwargs +) -> Any: + for callback in callbacks: + try: + new_return_value = callback(original, return_value, *args, **kwargs) + if new_return_value is not None: + return_value = new_return_value + except Exception: + LOGGER.debug("Exception calling after callback %r", callback, exc_info=True) + + return return_value + + +def run_after_exception( # type: ignore + callbacks: List[AfterExceptionCallback], + original: Callable, + exception: Exception, + *args, + **kwargs +) -> None: + for callback in callbacks: + try: + callback(original, exception, *args, **kwargs) + except Exception: + LOGGER.debug( + "Exception calling after-exception callback %r", callback, exc_info=True + ) diff --git a/src/comet_llm/import_hooks/finder.py b/src/comet_llm/import_hooks/finder.py new file mode 100644 index 000000000..6c5daea6f --- /dev/null +++ b/src/comet_llm/import_hooks/finder.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import sys +from importlib import machinery +from types import ModuleType +from typing import List, Optional + +from . import module_loader, registry + + +class CometFinder: + def __init__(self, extensions_registry: registry.Registry) -> None: + self._registry = extensions_registry + self._pathfinder = machinery.PathFinder() + + def hook_into_import_system(self) -> None: + if self not in sys.meta_path: + sys.meta_path.insert(0, self) # type: ignore + + def find_spec( + self, fullname: str, path: Optional[List[str]], target: Optional[ModuleType] + ) -> Optional[machinery.ModuleSpec]: + if fullname not in self._registry.module_names: + return None + + original_spec = self._pathfinder.find_spec(fullname, path, target) + + if original_spec is None: + return None + + return self._wrap_spec_loader(fullname, original_spec) + + def _wrap_spec_loader( + self, fullname: str, spec: machinery.ModuleSpec + ) -> machinery.ModuleSpec: + module_extension = self._registry.get_extension(fullname) + spec.loader = module_loader.CometModuleLoader(fullname, spec.loader, module_extension) # type: ignore + return spec diff --git a/src/comet_llm/import_hooks/module_extension.py b/src/comet_llm/import_hooks/module_extension.py new file mode 100644 index 000000000..1d109e8ee --- /dev/null +++ b/src/comet_llm/import_hooks/module_extension.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +from typing import Dict + +from . import callable_extenders + + +class ModuleExtension: + def __init__(self) -> None: + self._callables_extenders: Dict[str, callable_extenders.CallableExtenders] = {} + + def extenders(self, callable_name: str) -> callable_extenders.CallableExtenders: + if callable_name not in self._callables_extenders: + self._callables_extenders[callable_name] = callable_extenders.get() + + return self._callables_extenders[callable_name] + + def items(self): # type: ignore + return self._callables_extenders.items() diff --git a/src/comet_llm/import_hooks/module_loader.py b/src/comet_llm/import_hooks/module_loader.py new file mode 100644 index 000000000..d3871f62e --- /dev/null +++ b/src/comet_llm/import_hooks/module_loader.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import importlib.abc +from types import ModuleType +from typing import TYPE_CHECKING, Optional + +from . import module_extension, patcher + +if TYPE_CHECKING: # pragma: no cover + from importlib import machinery + + +class CometModuleLoader(importlib.abc.Loader): + def __init__( + self, + module_name: str, + original_loader: importlib.abc.Loader, + module_extension: module_extension.ModuleExtension, + ) -> None: + self._module_name = module_name + self._original_loader = original_loader + self._module_extension = module_extension + + def create_module(self, spec: "machinery.ModuleSpec") -> Optional[ModuleType]: + if hasattr(self._original_loader, "create_module"): + return self._original_loader.create_module(spec) + + LET_PYTHON_HANDLE_THIS = None + return LET_PYTHON_HANDLE_THIS + + def exec_module(self, module: ModuleType) -> None: + if hasattr(self._original_loader, "exec_module"): + self._original_loader.exec_module(module) + else: + module = self._original_loader.load_module(self._module_name) + + patcher.patch(module, self._module_extension) diff --git a/src/comet_llm/import_hooks/patcher.py b/src/comet_llm/import_hooks/patcher.py new file mode 100644 index 000000000..5cc4e5a49 --- /dev/null +++ b/src/comet_llm/import_hooks/patcher.py @@ -0,0 +1,64 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import inspect +from types import ModuleType +from typing import TYPE_CHECKING, Any + +from . import wrapper + +if TYPE_CHECKING: # pragma: no cover + from . import module_extension + +# _get_object and _set_object copied from comet_ml.monkeypatching almost without any changes. + + +def _get_object(module: ModuleType, callable_path: str) -> Any: + current_object = module + + for part in callable_path: + try: + current_object = getattr(current_object, part) + except AttributeError: + return None + + return current_object + + +def _set_object( + module: ModuleType, callable_path: str, original: Any, new_object: Any +) -> None: + object_to_patch = _get_object(module, callable_path[:-1]) + + original_self = getattr(original, "__self__", None) + + # Support classmethod + if original_self and inspect.isclass(original_self): + new_object = classmethod(new_object) + + setattr(object_to_patch, callable_path[-1], new_object) + + +def patch( + module: ModuleType, module_extension: "module_extension.ModuleExtension" +) -> None: + for callable_name, callable_extenders in module_extension.items(): + callable_path = callable_name.split(".") + original = _get_object(module, callable_path) + + if original is None: + continue + + new_callable = wrapper.wrap(original, callable_extenders) + _set_object(module, callable_path, original, new_callable) diff --git a/src/comet_llm/import_hooks/registry.py b/src/comet_llm/import_hooks/registry.py new file mode 100644 index 000000000..e268fb9d0 --- /dev/null +++ b/src/comet_llm/import_hooks/registry.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +from typing import Any, Callable, Dict, List + +from . import callable_extenders, module_extension + + +class Registry: + def __init__(self) -> None: + self._modules_extensions: Dict[str, module_extension.ModuleExtension] = {} + + @property + def module_names(self): # type: ignore + return self._modules_extensions.keys() + + def get_extension(self, module_name: str) -> module_extension.ModuleExtension: + return self._modules_extensions[module_name] + + def _get_callable_extenders( + self, module_name: str, callable_name: str + ) -> callable_extenders.CallableExtenders: + extension = self._modules_extensions.setdefault( + module_name, module_extension.ModuleExtension() + ) + return extension.extenders(callable_name) + + def register_before( + self, module_name: str, callable_name: str, patcher_function: Callable + ) -> None: + """ + patcher_function: Callable with the following signature + func( + original, + *args, + **kwargs + ) + original - original callable to patch + + Return value of patcher function is expected to be either None + or [Args,Kwargs] tuple to overwrite original args and kwargs + """ + extenders = self._get_callable_extenders(module_name, callable_name) + extenders.before.append(patcher_function) + + def register_after( + self, module_name: str, callable_name: str, patcher_function: Callable + ) -> None: + """ + patcher_function: Callable with the following signature + func( + original, + return_value, + *args, + **kwargs + ) + original - original callable to patch + return_value - value returned by original callable + + Return value of patcher function will overwrite return_value of + patched function if not None + """ + extenders = self._get_callable_extenders(module_name, callable_name) + extenders.after.append(patcher_function) + + def register_after_exception( + self, module_name: str, callable_name: str, patcher_function: Callable + ) -> None: + """ + patcher_function: Callable with the following signature + func( + original, + exception, + *args, + **kwargs + ) + original - original callable to patch + exception - exception thrown from original callable + + Expected to return None. + """ + extenders = self._get_callable_extenders(module_name, callable_name) + extenders.after_exception.append(patcher_function) diff --git a/src/comet_llm/import_hooks/types.py b/src/comet_llm/import_hooks/types.py new file mode 100644 index 000000000..612f8cb68 --- /dev/null +++ b/src/comet_llm/import_hooks/types.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +from typing import Callable + +# to-do: better description for callbacks signatures? + +BeforeCallback = Callable +AfterCallback = Callable +AfterExceptionCallback = Callable diff --git a/src/comet_llm/import_hooks/validate.py b/src/comet_llm/import_hooks/validate.py new file mode 100644 index 000000000..936306354 --- /dev/null +++ b/src/comet_llm/import_hooks/validate.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +from typing import Any + + +def args_kwargs(obj: Any) -> bool: + if obj is None: + return False + + try: + args, kwargs = obj + except (ValueError, TypeError): + return False + + if not isinstance(args, (list, tuple)): + return False + + if not isinstance(kwargs, dict): + return False + + return True diff --git a/src/comet_llm/import_hooks/wrapper.py b/src/comet_llm/import_hooks/wrapper.py new file mode 100644 index 000000000..e1b261304 --- /dev/null +++ b/src/comet_llm/import_hooks/wrapper.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# ******************************************************* +# ____ _ _ +# / ___|___ _ __ ___ ___| |_ _ __ ___ | | +# | | / _ \| '_ ` _ \ / _ \ __| | '_ ` _ \| | +# | |__| (_) | | | | | | __/ |_ _| | | | | | | +# \____\___/|_| |_| |_|\___|\__(_)_| |_| |_|_| +# +# Sign up for free at https://www.comet.com +# Copyright (C) 2015-2023 Comet ML INC +# This file can not be copied and/or distributed without the express +# permission of Comet ML Inc. +# ******************************************************* + +import functools +import inspect +from typing import Callable + +from . import callable_extenders, callback_runners + + +def wrap( + original: Callable, callbacks: callable_extenders.CallableExtenders +) -> Callable: + original = _unbound_if_classmethod(original) + + @functools.wraps(original) + def wrapped(*args, **kwargs): # type: ignore + args, kwargs = callback_runners.run_before( + callbacks.before, original, *args, **kwargs + ) + try: + result = original(*args, **kwargs) + except Exception as exception: + callback_runners.run_after_exception( + callbacks.after_exception, original, exception, *args, **kwargs + ) + raise exception + + result = callback_runners.run_after( + callbacks.after, original, result, *args, **kwargs + ) + + return result + + return wrapped + + +def _unbound_if_classmethod(original: Callable) -> Callable: + if hasattr(original, "__self__") and inspect.isclass(original.__self__): + # when original is classmethod, mypy doesn't consider it as a callable. + original = original.__func__ # type: ignore + + return original diff --git a/tests/unit/import_hooks/__init__.py b/tests/unit/import_hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/import_hooks/fake_package/__init__.py b/tests/unit/import_hooks/fake_package/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/import_hooks/fake_package/fake_module.py b/tests/unit/import_hooks/fake_package/fake_module.py new file mode 100644 index 000000000..8490ff5e3 --- /dev/null +++ b/tests/unit/import_hooks/fake_package/fake_module.py @@ -0,0 +1,49 @@ +# Fake module + +from unittest.mock import Mock + +FUNCTION_1_MOCK = Mock() +FUNCTION_2_MOCK = Mock() +FUNCTION_3_MOCK = Mock() + +STATIC_METHOD_MOCK = Mock() + + +def function1(*args, **kwargs): + return FUNCTION_1_MOCK(*args, **kwargs) + + +def function2(*args, **kwargs): + return FUNCTION_2_MOCK(*args, **kwargs) + +def function3(*args, **kwargs): + FUNCTION_3_MOCK(*args, **kwargs) + raise Exception() + + +class Klass: + clsmethodmock = Mock() + + def __init__(self): + self.mock = Mock() + self.mock2 = Mock() + + def method(self, *args, **kwargs): + return self.mock(*args, **kwargs) + + def method2(self, *args, **kwargs): + return self.mock2(*args, **kwargs) + + @classmethod + def clsmethod(cls, *args, **kwargs): + print("Locals", locals()) + return cls.clsmethodmock(*args, **kwargs) + + @staticmethod + def statikmethod(*args, **kwargs): + return STATIC_METHOD_MOCK(*args, **kwargs) + + +class Child(Klass): + def method(self, *args, **kwargs): + return super(Child, self).method(*args, **kwargs) diff --git a/tests/unit/import_hooks/test_import_hooks.py b/tests/unit/import_hooks/test_import_hooks.py new file mode 100644 index 000000000..b8071522f --- /dev/null +++ b/tests/unit/import_hooks/test_import_hooks.py @@ -0,0 +1,484 @@ +from unittest import mock + +import pytest + +from comet_llm.import_hooks import finder, registry + + +@pytest.fixture +def fake_module_path(): + parent_name = '.'.join(__name__.split('.')[:-1]) + + return f"{parent_name}.fake_package.fake_module" + +ORIGINAL = mock.ANY +EXCEPTION = mock.ANY + + +@pytest.mark.forked +def test_patch_function_in_module__name_to_patch_not_found__no_failure(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + extensions_registry.register_after(fake_module_path, "non_existing_function", "any-callback") + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + assert hasattr(fake_module, "non_existing_function") is False + +@pytest.mark.forked +def test_patch_functions_in_module__register_after__without_arguments(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + + mock_callback1 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function1", mock_callback1) + + mock_callback2 = mock.Mock() + extensions_registry.register_after(fake_module_path, "function2", mock_callback2) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + fake_module.FUNCTION_2_MOCK.return_value = "function-2-return-value" + fake_module.function1() + fake_module.function2() + + # Check function1 + mock_callback1.assert_called_once_with(ORIGINAL, "function-1-return-value") + original = mock_callback1.call_args[0][0] + assert original.__name__ == "function1" + + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with() + + # Check function2 + mock_callback2.assert_called_once_with(ORIGINAL, "function-2-return-value") + original = mock_callback2.call_args[0][0] + assert original.__name__ == "function2" + assert original is not fake_module.function2 + + fake_module.FUNCTION_2_MOCK.assert_called_once_with() + + +@pytest.mark.forked +def test_patch_functions_in_module__register_before__without_arguments(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + + mock_callback1 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock_callback1) + + mock_callback2 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function2", mock_callback2) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + fake_module.FUNCTION_2_MOCK.return_value = "function-2-return-value" + assert fake_module.function1() == "function-1-return-value" + assert fake_module.function2() == "function-2-return-value" + + # Check function1 + mock_callback1.assert_called_once_with(ORIGINAL) + original = mock_callback1.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with() + + # Check function2 + mock_callback1.assert_called_once_with(ORIGINAL) + original = mock_callback2.call_args[0][0] + assert original.__name__ == "function2" + assert original is not fake_module.function2 + + fake_module.FUNCTION_2_MOCK.assert_called_once_with() + + +@pytest.mark.forked +def test_patch_functions_in_module__register_before_and_after__without_arguments(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + + mock_callback1 = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock_callback1) + + mock_callback2 = mock.Mock(return_value=None) + extensions_registry.register_after(fake_module_path, "function2", mock_callback2) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + fake_module.FUNCTION_2_MOCK.return_value = "function-2-return-value" + assert fake_module.function1() == "function-1-return-value" + assert fake_module.function2() == "function-2-return-value" + + # Check function1 + mock_callback1.assert_called_once_with(ORIGINAL) + original = mock_callback1.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with() + + # Check function2 + mock_callback2.assert_called_once_with(ORIGINAL, "function-2-return-value") + original = mock_callback2.call_args[0][0] + assert original.__name__ == "function2" + assert original is not fake_module.function2 + + fake_module.FUNCTION_2_MOCK.assert_called_once_with() + + +@pytest.mark.forked +def test_patch_function_in_module__register_before__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_before__callback_changes_input_arguments(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock( + return_value=( + ("new-arg-1", "new-arg-2"), + {"kwarg1":"new-kwarg-1", "kwarg2":"new-kwarg-2"} + ) + ) + extensions_registry.register_before(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("new-arg-1", "new-arg-2", kwarg1="new-kwarg-1", kwarg2="new-kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_before__error_in_callback__original_function_worked(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock(side_effect=Exception) + extensions_registry.register_before(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + + +@pytest.mark.forked +def test_patch_function_in_module__register_after__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock(return_value=None) + extensions_registry.register_after(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_after__error_in_callback__original_function_worked(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock(side_effect=Exception) + extensions_registry.register_after(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "function-1-return-value" + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_function_in_module__register_after__callback_changes_return_value(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock(return_value="new-return-value") + extensions_registry.register_after(fake_module_path, "function1", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + fake_module.FUNCTION_1_MOCK.return_value = "function-1-return-value" + assert fake_module.function1("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") == "new-return-value" + + # Check + mock_callback.assert_called_once_with(ORIGINAL, "function-1-return-value", "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function1" + assert original is not fake_module.function1 + + fake_module.FUNCTION_1_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_raising_function_in_module__register_after_exception__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + with pytest.raises(Exception): + fake_module.function3("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + ORIGINAL = mock.ANY + EXCEPTION = mock.ANY + mock_callback.assert_called_once_with(ORIGINAL, EXCEPTION, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function3" + assert original is not fake_module.function3 + + fake_module.FUNCTION_3_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_raising_function_in_module__register_after_exception__error_in_callback__original_function_worked(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock(side_effect=Exception) + extensions_registry.register_after_exception(fake_module_path, "function3", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + with pytest.raises(Exception): + fake_module.function3("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + # Check + mock_callback.assert_called_once_with(ORIGINAL, EXCEPTION, "arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + original = mock_callback.call_args[0][0] + assert original.__name__ == "function3" + assert original is not fake_module.function3 + + fake_module.FUNCTION_3_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1", kwarg2="kwarg-2") + + +@pytest.mark.forked +def test_patch_method_in_module__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.method", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + instance = fake_module.Klass() + + # Set the return + instance.mock.return_value = "method-return-value" + # Call + assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == "method-return-value" + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, instance, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "method" + assert original is not fake_module.Klass.method + + instance.mock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") + + +@pytest.mark.forked +def test_patch_class_method_in_module__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.clsmethod", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Set the return + fake_module.Klass.clsmethodmock.return_value = "classmethod-return-value" + + # Call + assert fake_module.Klass.clsmethod("arg-1", "arg-2", kwarg1="kwarg-1") == "classmethod-return-value" + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, fake_module.Klass, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "clsmethod" + + assert original is not fake_module.Klass.clsmethod + + fake_module.Klass.clsmethodmock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") + + +@pytest.mark.forked +def test_patch_static_method_in_module__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.statikmethod", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Set the return + fake_module.STATIC_METHOD_MOCK.return_value = "staticmethod-return-value" + + # Call + assert fake_module.Klass.statikmethod("arg-1", "arg-2", kwarg1="kwarg-1") == "staticmethod-return-value" + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "statikmethod" + assert original is not fake_module.Klass.statikmethod + + fake_module.STATIC_METHOD_MOCK.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") + + +@pytest.mark.forked +def test_patch_subclass_method_in_module__happyflow(fake_module_path): + # Prepare hooks + extensions_registry = registry.Registry() + mock_callback = mock.Mock() + extensions_registry.register_before(fake_module_path, "Klass.method", mock_callback) + + comet_finder = finder.CometFinder(extensions_registry) + comet_finder.hook_into_import_system() + + # Import + from .fake_package import fake_module + + # Call + instance = fake_module.Child() + + # Set the return + instance.mock.return_value = "method-return-value" + + # Call + assert instance.method("arg-1", "arg-2", kwarg1="kwarg-1") == "method-return-value" + + # Check method + mock_callback.assert_called_once_with(ORIGINAL, instance, "arg-1", "arg-2", kwarg1="kwarg-1") + original = mock_callback.call_args[0][0] + assert original.__name__ == "method" + assert original is not fake_module.Klass.method + + instance.mock.assert_called_once_with("arg-1", "arg-2", kwarg1="kwarg-1") \ No newline at end of file diff --git a/tests/unit/import_hooks/test_validate.py b/tests/unit/import_hooks/test_validate.py new file mode 100644 index 000000000..dfbfeebf3 --- /dev/null +++ b/tests/unit/import_hooks/test_validate.py @@ -0,0 +1,24 @@ +from comet_llm.import_hooks import validate + + +def test_args_kwargs__happyflow(): + args_kwargs = ([1], {"foo": "bar"}) + assert validate.args_kwargs(args_kwargs) is True + + +def test_args_kwargs__input_is_None__return_False(): + assert validate.args_kwargs(None) is False + + +def test_args_kwargs__input_is_not_tuple_or_list_of_length_2__return_False(): + assert validate.args_kwargs(42) is False + + +def test_args_kwargs__args_cant_be_parsed__return_False(): + args_kwargs = (42, {}) + assert validate.args_kwargs(args_kwargs) is False + + +def test_args_kwargs__kwargs_cant_be_parsed__return_False(): + args_kwargs = ([1], 42) + assert validate.args_kwargs(args_kwargs) is False \ No newline at end of file