Skip to content

Commit

Permalink
Improve way of addon detection
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer committed Sep 19, 2023
1 parent bdd45c3 commit 4203582
Show file tree
Hide file tree
Showing 4 changed files with 225 additions and 76 deletions.
76 changes: 72 additions & 4 deletions copulas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

import contextlib
import importlib
import sys
import warnings
from copy import deepcopy
from operator import attrgetter

import numpy as np
import pandas as pd

from copulas._addons import _find_addons

_find_addons(group='copulas_modules', parent_globals=globals())
from pkg_resources import iter_entry_points

EPSILON = np.finfo(np.float32).eps

Expand Down Expand Up @@ -262,3 +262,71 @@ def decorated(self, X, *args, **kwargs):
return function(self, X, *args, **kwargs)

return decorated


def _get_addon_target(addon_path_name):
"""Find the target object for the add-on.
Args:
addon_path_name (str):
The add-on's name. The add-on's name should be the full path of valid Python
identifiers (i.e. importable.module:object.attr).
Returns:
tuple:
* object:
The base module or object the add-on should be added to.
* str:
The name the add-on should be added to under the module or object.
"""
module_path, _, object_path = addon_path_name.partition(':')
module_path = module_path.split('.')

if module_path[0] != __name__:
msg = f"expected base module to be '{__name__}', found '{module_path[0]}'"
raise AttributeError(msg)

target_base = sys.modules[__name__]
for submodule in module_path[1:-1]:
target_base = getattr(target_base, submodule)

addon_name = module_path[-1]
if object_path:
if len(module_path) > 1 and not hasattr(target_base, module_path[-1]):
msg = f"cannot add '{object_path}' to unknown submodule '{'.'.join(module_path)}'"
raise AttributeError(msg)

if len(module_path) > 1:
target_base = getattr(target_base, module_path[-1])

split_object = object_path.split('.')
addon_name = split_object[-1]

if len(split_object) > 1:
target_base = attrgetter('.'.join(split_object[:-1]))(target_base)

return target_base, addon_name


def _find_addons():
"""Find and load all copulas add-ons."""
group = 'copulas_modules'
for entry_point in iter_entry_points(group=group):
try:
addon = entry_point.load()
except Exception: # pylint: disable=broad-exception-caught
msg = f'Failed to load "{entry_point.name}" from "{entry_point.module_name}".'
warnings.warn(msg)
continue

try:
addon_target, addon_name = _get_addon_target(entry_point.name)
except AttributeError as error:
msg = f"Failed to set '{entry_point.name}': {error}."
warnings.warn(msg)
continue

setattr(addon_target, addon_name, addon)


_find_addons()
26 changes: 0 additions & 26 deletions copulas/_addons.py

This file was deleted.

154 changes: 153 additions & 1 deletion tests/unit/test___init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from unittest import TestCase
from unittest.mock import MagicMock, call, patch

Expand All @@ -6,8 +7,10 @@
import pytest
from numpy.testing import assert_array_equal

import copulas
from copulas import (
check_valid_values, get_instance, random_state, scalarize, validate_random_state, vectorize)
_find_addons, check_valid_values, get_instance, random_state, scalarize, validate_random_state,
vectorize)
from copulas.multivariate import GaussianMultivariate


Expand Down Expand Up @@ -421,3 +424,152 @@ def test_get_instance_with_kwargs(self):
assert not instance.fitted
assert isinstance(instance, GaussianMultivariate)
assert instance.distribution == 'copulas.univariate.truncnorm.TruncNorm'


@pytest.fixture()
def mock_copulas():
copulas_module = sys.modules['copulas']
copulas_mock = MagicMock()
sys.modules['copulas'] = copulas_mock
yield copulas_mock
sys.modules['copulas'] = copulas_module


@patch.object(copulas, 'iter_entry_points')
def test__find_addons_module(entry_points_mock, mock_copulas):
"""Test loading an add-on."""
# Setup
entry_point = MagicMock()
entry_point.name = 'copulas.submodule.entry_name'
entry_point.load.return_value = 'entry_point'
entry_points_mock.return_value = [entry_point]

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
assert mock_copulas.submodule.entry_name == 'entry_point'


@patch.object(copulas, 'iter_entry_points')
def test__find_addons_object(entry_points_mock, mock_copulas):
"""Test loading an add-on."""
# Setup
entry_point = MagicMock()
entry_point.name = 'copulas.submodule:entry_object.entry_method'
entry_point.load.return_value = 'new_method'
entry_points_mock.return_value = [entry_point]

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
assert mock_copulas.submodule.entry_object.entry_method == 'new_method'


@patch('warnings.warn')
@patch('copulas.iter_entry_points')
def test__find_addons_bad_addon(entry_points_mock, warning_mock):
"""Test failing to load an add-on generates a warning."""
# Setup
def entry_point_error():
raise ValueError()

bad_entry_point = MagicMock()
bad_entry_point.name = 'bad_entry_point'
bad_entry_point.module_name = 'bad_module'
bad_entry_point.load.side_effect = entry_point_error
entry_points_mock.return_value = [bad_entry_point]
msg = 'Failed to load "bad_entry_point" from "bad_module".'

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
warning_mock.assert_called_once_with(msg)


@patch('warnings.warn')
@patch('copulas.iter_entry_points')
def test__find_addons_wrong_base(entry_points_mock, warning_mock):
"""Test incorrect add-on name generates a warning."""
# Setup
bad_entry_point = MagicMock()
bad_entry_point.name = 'bad_base.bad_entry_point'
entry_points_mock.return_value = [bad_entry_point]
msg = (
"Failed to set 'bad_base.bad_entry_point': expected base module to be 'copulas', found "
"'bad_base'."
)

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
warning_mock.assert_called_once_with(msg)


@patch('warnings.warn')
@patch('copulas.iter_entry_points')
def test__find_addons_missing_submodule(entry_points_mock, warning_mock):
"""Test incorrect add-on name generates a warning."""
# Setup
bad_entry_point = MagicMock()
bad_entry_point.name = 'copulas.missing_submodule.new_submodule'
entry_points_mock.return_value = [bad_entry_point]
msg = (
"Failed to set 'copulas.missing_submodule.new_submodule': module 'copulas' has no "
"attribute 'missing_submodule'."
)

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
warning_mock.assert_called_once_with(msg)


@patch('warnings.warn')
@patch('copulas.iter_entry_points')
def test__find_addons_module_and_object(entry_points_mock, warning_mock):
"""Test incorrect add-on name generates a warning."""
# Setup
bad_entry_point = MagicMock()
bad_entry_point.name = 'copulas.missing_submodule:new_object'
entry_points_mock.return_value = [bad_entry_point]
msg = (
"Failed to set 'copulas.missing_submodule:new_object': cannot add 'new_object' to unknown "
"submodule 'copulas.missing_submodule'."
)

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
warning_mock.assert_called_once_with(msg)


@patch('warnings.warn')
@patch.object(copulas, 'iter_entry_points')
def test__find_addons_missing_object(entry_points_mock, warning_mock, mock_copulas):
"""Test incorrect add-on name generates a warning."""
# Setup
bad_entry_point = MagicMock()
bad_entry_point.name = 'copulas.submodule:missing_object.new_method'
entry_points_mock.return_value = [bad_entry_point]
msg = ("Failed to set 'copulas.submodule:missing_object.new_method': missing_object.")

del mock_copulas.submodule.missing_object

# Run
_find_addons()

# Assert
entry_points_mock.assert_called_once_with(group='copulas_modules')
warning_mock.assert_called_once_with(msg)
45 changes: 0 additions & 45 deletions tests/unit/test__addons.py

This file was deleted.

0 comments on commit 4203582

Please sign in to comment.