diff --git a/doc/introduction.rst b/doc/introduction.rst index 7d9f52d0..e9ff0015 100644 --- a/doc/introduction.rst +++ b/doc/introduction.rst @@ -123,26 +123,3 @@ to the following resources: Survey `_ (2012) - **Book:** `Metric Learning `_ (2015) - -.. Methods [TO MOVE TO SUPERVISED/WEAK SECTIONS] -.. ============================================= - -.. Currently, each metric learning algorithm supports the following methods: - -.. - ``fit(...)``, which learns the model. -.. - ``get_mahalanobis_matrix()``, which returns a Mahalanobis matrix -.. - ``get_metric()``, which returns a function that takes as input two 1D - arrays and outputs the learned metric score on these two points -.. :math:`M = L^{\top}L` such that distance between vectors ``x`` and -.. ``y`` can be computed as :math:`\sqrt{\left(x-y\right)M\left(x-y\right)}`. -.. - ``components_from_metric(metric)``, which returns a transformation matrix -.. :math:`L \in \mathbb{R}^{D \times d}`, which can be used to convert a -.. data matrix :math:`X \in \mathbb{R}^{n \times d}` to the -.. :math:`D`-dimensional learned metric space :math:`X L^{\top}`, -.. in which standard Euclidean distances may be used. -.. - ``transform(X)``, which applies the aforementioned transformation. -.. - ``score_pairs(pairs)`` which returns the distance between pairs of -.. points. ``pairs`` should be a 3D array-like of pairs of shape ``(n_pairs, -.. 2, n_features)``, or it can be a 2D array-like of pairs indicators of -.. shape ``(n_pairs, 2)`` (see section :ref:`preprocessor_section` for more -.. details). \ No newline at end of file diff --git a/doc/supervised.rst b/doc/supervised.rst index c6d8b68b..e27b58ec 100644 --- a/doc/supervised.rst +++ b/doc/supervised.rst @@ -69,10 +69,10 @@ Also, as explained before, our metric learners has learn a distance between points. 
You can use this distance in two main ways: - You can either return the distance between pairs of points using the - `score_pairs` function: + `pair_distance` function: ->>> nca.score_pairs([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]]]) -array([0.49627072, 3.65287282]) +>>> nca.pair_distance([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) +array([0.49627072, 3.65287282, 6.06079877]) - Or you can return a function that will return the distance (in the new space) between two 1D arrays (the coordinates of the points in the original @@ -82,6 +82,18 @@ array([0.49627072, 3.65287282]) >>> metric_fun([3.5, 3.6], [5.6, 2.4]) 0.4962707194621285 +- Alternatively, you can use `pair_score` to return the **score** between + pairs of points (the larger the score, the more similar the pair). + For Mahalanobis learners, it is equal to the opposite of the distance. + +>>> score = nca.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) +>>> score +array([-0.49627072, -3.65287282, -6.06079877]) + +This is useful because `pair_score` matches the **score** semantic of +scikit-learn's `Classification metrics +`_. + .. note:: If the metric learner that you use learns a :ref:`Mahalanobis distance @@ -93,7 +105,6 @@ array([0.49627072, 3.65287282]) array([[0.43680409, 0.89169412], [0.89169412, 1.9542479 ]]) -.. TODO: remove the "like it is the case etc..." if it's not the case anymore Scikit-learn compatibility -------------------------- @@ -105,6 +116,7 @@ All supervised algorithms are scikit-learn estimators scikit-learn model selection routines (`sklearn.model_selection.cross_val_score`, `sklearn.model_selection.GridSearchCV`, etc). +You can also use some of the scoring functions from `sklearn.metrics`. 
Algorithms ========== diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index 174210b8..02ea4ef6 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -160,9 +160,9 @@ Also, as explained before, our metric learner has learned a distance between points. You can use this distance in two main ways: - You can either return the distance between pairs of points using the - `score_pairs` function: + `pair_distance` function: ->>> mmc.score_pairs([[[3.5, 3.6, 5.2], [5.6, 2.4, 6.7]], +>>> mmc.pair_distance([[[3.5, 3.6, 5.2], [5.6, 2.4, 6.7]], ... [[1.2, 4.2, 7.7], [2.1, 6.4, 0.9]]]) array([7.27607365, 0.88853014]) @@ -175,6 +175,18 @@ array([7.27607365, 0.88853014]) >>> metric_fun([3.5, 3.6, 5.2], [5.6, 2.4, 6.7]) 7.276073646278203 +- Alternatively, you can use `pair_score` to return the **score** between + pairs of points (the larger the score, the more similar the pair). + For Mahalanobis learners, it is equal to the opposite of the distance. + +>>> score = mmc.pair_score([[[3.5, 3.6, 5.2], [5.6, 2.4, 6.7]], [[1.2, 4.2, 7.7], [2.1, 6.4, 0.9]]]) +>>> score +array([-7.27607365, -0.88853014]) + + This is useful because `pair_score` matches the **score** semantic of + scikit-learn's `Classification metrics + `_. + .. note:: If the metric learner that you use learns a :ref:`Mahalanobis distance @@ -187,8 +199,6 @@ array([[ 0.58603894, -5.69883982, -1.66614919], [-5.69883982, 55.41743549, 16.20219519], [-1.66614919, 16.20219519, 4.73697721]]) -.. TODO: remove the "like it is the case etc..." if it's not the case anymore - .. _sklearn_compat_ws: Prediction and scoring @@ -344,8 +354,8 @@ returns the `sklearn.metrics.roc_auc_score` (which is threshold-independent). .. note:: See :ref:`fit_ws` for more details on metric learners functions that are - not specific to learning on pairs, like `transform`, `score_pairs`, - `get_metric` and `get_mahalanobis_matrix`. 
+ not specific to learning on pairs, like `transform`, `pair_distance`, + `pair_score`, `get_metric` and `get_mahalanobis_matrix`. Algorithms ---------- @@ -691,8 +701,8 @@ of triplets that have the right predicted ordering. .. note:: See :ref:`fit_ws` for more details on metric learners functions that are - not specific to learning on pairs, like `transform`, `score_pairs`, - `get_metric` and `get_mahalanobis_matrix`. + not specific to learning on pairs, like `transform`, `pair_distance`, + `pair_score`, `get_metric` and `get_mahalanobis_matrix`. @@ -859,8 +869,8 @@ of quadruplets have the right predicted ordering. .. note:: See :ref:`fit_ws` for more details on metric learners functions that are - not specific to learning on pairs, like `transform`, `score_pairs`, - `get_metric` and `get_mahalanobis_matrix`. + not specific to learning on pairs, like `transform`, `pair_distance`, + `pair_score`, `get_metric` and `get_mahalanobis_matrix`. diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 21506011..e7dbd608 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -9,6 +9,7 @@ import numpy as np from abc import ABCMeta, abstractmethod from ._util import ArrayIndexer, check_input, validate_vector +import warnings class BaseMetricLearner(BaseEstimator, metaclass=ABCMeta): @@ -27,13 +28,54 @@ def __init__(self, preprocessor=None): @abstractmethod def score_pairs(self, pairs): - """Returns the score between pairs + """ + .. deprecated:: 0.7.0 Refer to `pair_distance` and `pair_score`. + + .. warning:: + This method will be removed in 0.8.0. Please refer to `pair_distance` + or `pair_score`. This change will occur in order to add learners + that don't necessarily learn a Mahalanobis distance. + + Returns the score between pairs (can be a similarity, or a distance/metric depending on the algorithm) Parameters ---------- - pairs : `numpy.ndarray`, shape=(n_samples, 2, n_features) - 3D array of pairs. 
+ pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. + + Returns + ------- + scores : `numpy.ndarray` of shape=(n_pairs,) + The score of every pair. + + See Also + -------- + get_metric : a method that returns a function to compute the metric between + two points. The difference with `score_pairs` is that it works on two + 1D arrays and cannot use a preprocessor. Besides, the returned function + is independent of the metric learner and hence is not modified if the + metric learner is. + """ + + @abstractmethod + def pair_score(self, pairs): + """ + .. versionadded:: 0.7.0 Compute the similarity score between pairs + + Returns the similarity score between pairs of points (the larger the score, + the more similar the pair). For metric learners that learn a distance, + the score is simply the opposite of the distance between pairs. All + learners have access to this method. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. Returns ------- @@ -43,10 +85,40 @@ def score_pairs(self, pairs): See Also -------- get_metric : a method that returns a function to compute the metric between - two points. The difference with `score_pairs` is that it works on two 1D - arrays and cannot use a preprocessor. Besides, the returned function is - independent of the metric learner and hence is not modified if the metric - learner is. + two points. The difference with `pair_score` is that it works on two + 1D arrays and cannot use a preprocessor. Besides, the returned function + is independent of the metric learner and hence is not modified if the + metric learner is. 
+ """ + + @abstractmethod + def pair_distance(self, pairs): + """ + .. versionadded:: 0.7.0 Compute the distance between pairs + + Returns the (pseudo) distance between pairs, when available. For metric + learners that do not learn a (pseudo) distance, an error is raised + instead. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs for which to compute the distance, with each + row corresponding to two points, or 2D array of indices of pairs + if the metric learner uses a preprocessor. + + Returns + ------- + scores : `numpy.ndarray` of shape=(n_pairs,) + The distance between every pair. + + See Also + -------- + get_metric : a method that returns a function to compute the metric between + two points. The difference with `pair_distance` is that it works on two + 1D arrays and cannot use a preprocessor. Besides, the returned function + is independent of the metric learner and hence is not modified if the + metric learner is. """ def _check_preprocessor(self): @@ -102,8 +174,10 @@ def _prepare_inputs(self, X, y=None, type_of_inputs='classic', @abstractmethod def get_metric(self): - """Returns a function that takes as input two 1D arrays and outputs the - learned metric score on these two points. + """Returns a function that takes as input two 1D arrays and outputs + the value of the learned metric on these two points. Depending on the + algorithm, it can return a distance or a similarity function between + pairs. This function will be independent from the metric learner that learned it (it will not be modified if the initial metric learner is modified), @@ -136,10 +210,17 @@ def get_metric(self): See Also -------- - score_pairs : a method that returns the metric score between several pairs - of points. Unlike `get_metric`, this is a method of the metric learner - and therefore can change if the metric learner changes. 
Besides, it can - use the metric learner's preprocessor, and works on concatenated arrays. + pair_distance : a method that returns the distance between several + pairs of points. Unlike `get_metric`, this is a method of the metric + learner and therefore can change if the metric learner changes. Besides, + it can use the metric learner's preprocessor, and works on concatenated + arrays. + + pair_score : a method that returns the similarity score between + several pairs of points. Unlike `get_metric`, this is a method of the + metric learner and therefore can change if the metric learner changes. + Besides, it can use the metric learner's preprocessor, and works on + concatenated arrays. """ @@ -182,13 +263,92 @@ class MahalanobisMixin(BaseMetricLearner, MetricTransformer, """ def score_pairs(self, pairs): - r"""Returns the learned Mahalanobis distance between pairs. + r""" + .. deprecated:: 0.7.0 + This method is deprecated. Please use `pair_distance` instead. + + .. warning:: + This method will be removed in 0.8.0. Please refer to `pair_distance` + or `pair_score`. This change will occur in order to add learners + that don't necessarily learn a Mahalanobis distance. + + Returns the learned Mahalanobis distance between pairs. + + This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}` + where ``M`` is the learned Mahalanobis matrix, for every pair of points + ``x`` and ``x'``. This corresponds to the euclidean distance between + embeddings of the points in a new space, obtained through a linear + transformation. Indeed, we have also: :math:`d_M(x, x') = \sqrt{(x_e - + x_e')^T (x_e- x_e')}`, with :math:`x_e = L x` (See + :class:`MahalanobisMixin`). + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. 
+ + Returns + ------- + scores : `numpy.ndarray` of shape=(n_pairs,) + The learned Mahalanobis distance for every pair. + + See Also + -------- + get_metric : a method that returns a function to compute the metric between + two points. The difference with `score_pairs` is that it works on two + 1D arrays and cannot use a preprocessor. Besides, the returned function + is independent of the metric learner and hence is not modified if the + metric learner is. + + :ref:`mahalanobis_distances` : The section of the project documentation + that describes Mahalanobis Distances. + """ + dpr_msg = ("score_pairs will be deprecated in release 0.7.0. " + "Use pair_score to compute similarity scores, or " + "pair_distances to compute distances.") + warnings.warn(dpr_msg, category=FutureWarning) + return self.pair_distance(pairs) + + def pair_score(self, pairs): + """ + Returns the opposite of the learned Mahalanobis distance between pairs. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + or 2D array of indices of pairs if the metric learner uses a + preprocessor. + + Returns + ------- + scores : `numpy.ndarray` of shape=(n_pairs,) + The opposite of the learned Mahalanobis distance for every pair. + + See Also + -------- + get_metric : a method that returns a function to compute the metric between + two points. The difference with `pair_score` is that it works on two + 1D arrays and cannot use a preprocessor. Besides, the returned function + is independent of the metric learner and hence is not modified if the + metric learner is. + + :ref:`mahalanobis_distances` : The section of the project documentation + that describes Mahalanobis Distances. + """ + return -self.pair_distance(pairs) + + def pair_distance(self, pairs): + """ + Returns the learned Mahalanobis distance between pairs. 
- This distance is defined as: :math:`d_M(x, x') = \sqrt{(x-x')^T M (x-x')}` + This distance is defined as: :math:`d_M(x, x') = \\sqrt{(x-x')^T M (x-x')}` where ``M`` is the learned Mahalanobis matrix, for every pair of points ``x`` and ``x'``. This corresponds to the euclidean distance between embeddings of the points in a new space, obtained through a linear - transformation. Indeed, we have also: :math:`d_M(x, x') = \sqrt{(x_e - + transformation. Indeed, we have also: :math:`d_M(x, x') = \\sqrt{(x_e - x_e')^T (x_e- x_e')}`, with :math:`x_e = L x` (See :class:`MahalanobisMixin`). @@ -207,10 +367,10 @@ def score_pairs(self, pairs): See Also -------- get_metric : a method that returns a function to compute the metric between - two points. The difference with `score_pairs` is that it works on two 1D - arrays and cannot use a preprocessor. Besides, the returned function is - independent of the metric learner and hence is not modified if the metric - learner is. + two points. The difference with `pair_distance` is that it works on two + 1D arrays and cannot use a preprocessor. Besides, the returned function + is independent of the metric learner and hence is not modified if the + metric learner is. :ref:`mahalanobis_distances` : The section of the project documentation that describes Mahalanobis Distances. @@ -361,7 +521,7 @@ def decision_function(self, pairs): pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) - return - self.score_pairs(pairs) + return self.pair_score(pairs) def score(self, pairs, y): """Computes score of pairs similarity prediction. 
@@ -631,8 +791,8 @@ def decision_function(self, triplets): triplets = check_input(triplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) - return (self.score_pairs(triplets[:, [0, 2]]) - - self.score_pairs(triplets[:, :2])) + return (self.pair_score(triplets[:, :2]) - + self.pair_score(triplets[:, [0, 2]])) def score(self, triplets): """Computes score on input triplets. @@ -716,8 +876,8 @@ def decision_function(self, quadruplets): quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) - return (self.score_pairs(quadruplets[:, 2:]) - - self.score_pairs(quadruplets[:, :2])) + return (self.pair_score(quadruplets[:, :2]) - + self.pair_score(quadruplets[:, 2:])) def score(self, quadruplets): """Computes score on input quadruplets diff --git a/test/test_base_metric.py b/test/test_base_metric.py index 67f9b6a0..baa585b9 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -1,3 +1,4 @@ +from numpy import array_equal import pytest import re import unittest @@ -274,5 +275,28 @@ def test_n_components(estimator, build_dataset): 'Invalid n_components, must be in [1, {}]'.format(X.shape[1])) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_score_pairs_warning(estimator, build_dataset): + """Tests that score_pairs emits a FutureWarning regarding deprecation, + and that score_pairs and pair_distance have the same behaviour""" + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + + # We fit the metric learner on it and then we call score_pairs on some + # points + model.fit(*remove_y(model, input_data, labels)) + + msg = ("score_pairs will be deprecated in release 0.7.0. 
" + "Use pair_score to compute similarity scores, or " + "pair_distances to compute distances.") + with pytest.warns(FutureWarning) as raised_warning: + score = model.score_pairs([[X[0], X[1]], ]) + dist = model.pair_distance([[X[0], X[1]], ]) + assert array_equal(score, dist) + assert any([str(warning.message) == msg for warning in raised_warning]) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index e3d981a4..e2aa1e4d 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -3,7 +3,8 @@ import pytest import numpy as np from numpy.linalg import LinAlgError -from numpy.testing import assert_array_almost_equal, assert_allclose +from numpy.testing import assert_array_almost_equal, assert_allclose, \ + assert_array_equal from scipy.spatial.distance import pdist, squareform, mahalanobis from scipy.stats import ortho_group from sklearn import clone @@ -27,7 +28,27 @@ @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) -def test_score_pairs_pairwise(estimator, build_dataset): +def test_pair_distance_pair_score_equivalent(estimator, build_dataset): + """ + For Mahalanobis learners, pair_score should be equivalent to the + opposite of the pair_distance result. + """ + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(estimator, input_data, labels)) + + distances = model.pair_distance(np.array(list(product(X, X)))) + scores = model.pair_score(np.array(list(product(X, X)))) + + assert_array_equal(distances, -1 * scores) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_pair_distance_pairwise(estimator, build_dataset): # Computing pairwise scores should return a euclidean distance matrix. 
input_data, labels, _, X = build_dataset() n_samples = 20 @@ -36,7 +57,7 @@ def test_score_pairs_pairwise(estimator, build_dataset): set_random_state(model) model.fit(*remove_y(estimator, input_data, labels)) - pairwise = model.score_pairs(np.array(list(product(X, X))))\ + pairwise = model.pair_distance(np.array(list(product(X, X))))\ .reshape(n_samples, n_samples) check_is_distance_matrix(pairwise) @@ -51,8 +72,8 @@ def test_score_pairs_pairwise(estimator, build_dataset): @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) -def test_score_pairs_toy_example(estimator, build_dataset): - # Checks that score_pairs works on a toy example +def test_pair_distance_toy_example(estimator, build_dataset): + # Checks that pair_distance works on a toy example input_data, labels, _, X = build_dataset() n_samples = 20 X = X[:n_samples] @@ -64,24 +85,24 @@ def test_score_pairs_toy_example(estimator, build_dataset): distances = np.sqrt(np.sum((embedded_pairs[:, 1] - embedded_pairs[:, 0])**2, axis=-1)) - assert_array_almost_equal(model.score_pairs(pairs), distances) + assert_array_almost_equal(model.pair_distance(pairs), distances) @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) -def test_score_pairs_finite(estimator, build_dataset): +def test_pair_distance_finite(estimator, build_dataset): # tests that the score is finite input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) model.fit(*remove_y(estimator, input_data, labels)) pairs = np.array(list(product(X, X))) - assert np.isfinite(model.score_pairs(pairs)).all() + assert np.isfinite(model.pair_distance(pairs)).all() @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) -def test_score_pairs_dim(estimator, build_dataset): +def test_pair_distance_dim(estimator, build_dataset): # scoring of 3D arrays should return 1D array (several tuples), # and scoring of 2D 
arrays (one tuple) should return an error (like # scikit-learn's error when scoring 1D arrays) @@ -90,13 +111,13 @@ def test_score_pairs_dim(estimator, build_dataset): set_random_state(model) model.fit(*remove_y(estimator, input_data, labels)) tuples = np.array(list(product(X, X))) - assert model.score_pairs(tuples).shape == (tuples.shape[0],) + assert model.pair_distance(tuples).shape == (tuples.shape[0],) context = make_context(estimator) msg = ("3D array of formed tuples expected{}. Found 2D array " "instead:\ninput={}. Reshape your data and/or use a preprocessor.\n" .format(context, tuples[1])) with pytest.raises(ValueError) as raised_error: - model.score_pairs(tuples[1]) + model.pair_distance(tuples[1]) assert str(raised_error.value) == msg @@ -140,7 +161,7 @@ def test_embed_dim(estimator, build_dataset): "instead:\ninput={}. Reshape your data and/or use a " "preprocessor.\n".format(context, X[0])) with pytest.raises(ValueError) as raised_error: - model.score_pairs(model.transform(X[0, :])) + model.pair_distance(model.transform(X[0, :])) assert str(raised_error.value) == err_msg # we test that the shape is also OK when doing dimensionality reduction if hasattr(model, 'n_components'): @@ -625,7 +646,7 @@ def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset): 'preprocessing step.') with pytest.warns(UserWarning) as raised_warning: model.fit(input_data, labels) - assert np.any([str(warning.message) == msg for warning in raised_warning]) + assert any([str(warning.message) == msg for warning in raised_warning]) M, _ = _initialize_metric_mahalanobis(X, init='covariance', random_state=RNG, return_inverse=True, diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py index 824bb622..714cbd08 100644 --- a/test/test_pairs_classifiers.py +++ b/test/test_pairs_classifiers.py @@ -49,14 +49,14 @@ def test_predict_monotonous(estimator, build_dataset, pairs_train, pairs_test, y_train, y_test = train_test_split(input_data, labels) 
estimator.fit(pairs_train, y_train) - distances = estimator.score_pairs(pairs_test) + scores = estimator.pair_score(pairs_test) predictions = estimator.predict(pairs_test) - min_dissimilar = np.min(distances[predictions == -1]) - max_similar = np.max(distances[predictions == 1]) - assert max_similar <= min_dissimilar - separator = np.mean([min_dissimilar, max_similar]) - assert (predictions[distances > separator] == -1).all() - assert (predictions[distances < separator] == 1).all() + max_dissimilar = np.max(scores[predictions == -1]) + min_similar = np.min(scores[predictions == 1]) + assert max_dissimilar <= min_similar + separator = np.mean([max_dissimilar, min_similar]) + assert (predictions[scores < separator] == -1).all() + assert (predictions[scores > separator] == 1).all() @pytest.mark.parametrize('with_preprocessor', [True, False]) @@ -65,15 +65,17 @@ def test_predict_monotonous(estimator, build_dataset, def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, with_preprocessor): """Test that a NotFittedError is raised if someone tries to use - score_pairs, decision_function, get_metric, transform or + pair_score, score_pairs, decision_function, get_metric, transform or get_mahalanobis_matrix on input data and the metric learner has not been fitted.""" input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) - with pytest.raises(NotFittedError): + with pytest.raises(NotFittedError): # Remove in 0.8.0 estimator.score_pairs(input_data) + with pytest.raises(NotFittedError): + estimator.pair_score(input_data) with pytest.raises(NotFittedError): estimator.decision_function(input_data) with pytest.raises(NotFittedError): diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 3ad69712..b08fcf25 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -147,8 +147,19 @@ def 
test_array_like_inputs(estimator, build_dataset, with_preprocessor): pairs = np.array([[X[0], X[1]], [X[0], X[2]]]) pairs_variants, _ = generate_array_like(pairs) + + not_implemented_msg = "" + # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says + # "This learner does not have pair_distance" + for pairs_variant in pairs_variants: - estimator.score_pairs(pairs_variant) + estimator.pair_score(pairs_variant) # All learners have pair_score + + # But not all of them will have pair_distance + try: + estimator.pair_distance(pairs_variant) + except Exception as raised_exception: + assert raised_exception.args[0] == not_implemented_msg @pytest.mark.parametrize('with_preprocessor', [True, False]) diff --git a/test/test_utils.py b/test/test_utils.py index 072b94c5..83bdd86a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -834,9 +834,9 @@ def test_error_message_tuple_size(estimator, _): @pytest.mark.parametrize('estimator, _', metric_learners, ids=ids_metric_learners) -def test_error_message_t_score_pairs(estimator, _): - """tests that if you want to score_pairs on triplets for instance, it returns - the right error message +def test_error_message_t_pair_distance_or_score(estimator, _): + """Tests that if you want to pair_distance or pair_score on triplets + for instance, it returns the right error message """ estimator = clone(estimator) set_random_state(estimator) @@ -844,12 +844,22 @@ def test_error_message_t_score_pairs(estimator, _): triplets = np.array([[[1.3, 6.3], [3., 6.8], [6.5, 4.4]], [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]]) with pytest.raises(ValueError) as raised_err: - estimator.score_pairs(triplets) + estimator.pair_score(triplets) expected_msg = ("Tuples of 2 element(s) expected{}. 
Got tuples of 3 " "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n" .format(make_context(estimator), triplets)) assert str(raised_err.value) == expected_msg + not_implemented_msg = "" + # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says + # "This learner does not have pair_distance" + + # One exception will trigger for sure + with pytest.raises(Exception) as raised_exception: + estimator.pair_distance(triplets) + err_value = raised_exception.value.args[0] + assert err_value == expected_msg or err_value == not_implemented_msg + def test_preprocess_tuples_simple_example(): """Test the preprocessor on a very simple example of tuples to ensure the @@ -930,32 +940,59 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): method)(formed_test) assert np.array(output_with_prep == output_with_prep_formed).all() - # test score_pairs + # Test pair_score, all learners have it. idx1 = np.array([[0, 2], [5, 3]], dtype=int) - output_with_prep = estimator_with_preprocessor.score_pairs( + output_with_prep = estimator_with_preprocessor.pair_score( indicators_to_transform[idx1]) - output_without_prep = estimator_without_preprocessor.score_pairs( + output_without_prep = estimator_without_preprocessor.pair_score( formed_points_to_transform[idx1]) assert np.array(output_with_prep == output_without_prep).all() - output_with_prep = estimator_with_preprocessor.score_pairs( + output_with_prep = estimator_with_preprocessor.pair_score( indicators_to_transform[idx1]) - output_without_prep = estimator_with_prep_formed.score_pairs( + output_without_prep = estimator_with_prep_formed.pair_score( formed_points_to_transform[idx1]) assert np.array(output_with_prep == output_without_prep).all() - # test transform - output_with_prep = estimator_with_preprocessor.transform( - indicators_to_transform) - output_without_prep = estimator_without_preprocessor.transform( - formed_points_to_transform) - assert np.array(output_with_prep == output_without_prep).all() - - 
output_with_prep = estimator_with_preprocessor.transform( - indicators_to_transform) - output_without_prep = estimator_with_prep_formed.transform( - formed_points_to_transform) - assert np.array(output_with_prep == output_without_prep).all() + # Test pair_distance + not_implemented_msg = "" + # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says + # "This learner does not have pair_distance" + try: + output_with_prep = estimator_with_preprocessor.pair_distance( + indicators_to_transform[idx1]) + output_without_prep = estimator_without_preprocessor.pair_distance( + formed_points_to_transform[idx1]) + assert np.array(output_with_prep == output_without_prep).all() + + output_with_prep = estimator_with_preprocessor.pair_distance( + indicators_to_transform[idx1]) + output_without_prep = estimator_with_prep_formed.pair_distance( + formed_points_to_transform[idx1]) + assert np.array(output_with_prep == output_without_prep).all() + + except Exception as raised_exception: + assert raised_exception.args[0] == not_implemented_msg + + # Test transform + not_implemented_msg = "" + # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says + # "This learner does not have transform" + try: + output_with_prep = estimator_with_preprocessor.transform( + indicators_to_transform) + output_without_prep = estimator_without_preprocessor.transform( + formed_points_to_transform) + assert np.array(output_with_prep == output_without_prep).all() + + output_with_prep = estimator_with_preprocessor.transform( + indicators_to_transform) + output_without_prep = estimator_with_prep_formed.transform( + formed_points_to_transform) + assert np.array(output_with_prep == output_without_prep).all() + + except Exception as raised_exception: + assert raised_exception.args[0] == not_implemented_msg def test_check_collapsed_pairs_raises_no_error():