diff --git a/doc/introduction.rst b/doc/introduction.rst index e9ff0015..8c8e40fb 100644 --- a/doc/introduction.rst +++ b/doc/introduction.rst @@ -4,15 +4,14 @@ What is Metric Learning? ======================== -Many approaches in machine learning require a measure of distance between data -points. Traditionally, practitioners would choose a standard distance metric +Many approaches in machine learning require a measure of distance (or similarity) +between data points. Traditionally, practitioners would choose a standard metric (Euclidean, City-Block, Cosine, etc.) using a priori knowledge of the domain. However, it is often difficult to design metrics that are well-suited to the particular data and task of interest. -Distance metric learning (or simply, metric learning) aims at -automatically constructing task-specific distance metrics from (weakly) -supervised data, in a machine learning manner. The learned distance metric can +Metric learning aims at automatically constructing task-specific metrics from +(weakly) supervised data, in a machine learning manner. The learned metric can then be used to perform various tasks (e.g., k-NN classification, clustering, information retrieval). @@ -25,19 +24,19 @@ of supervision available about the training data: - :doc:`Supervised learning `: the algorithm has access to a set of data points, each of them belonging to a class (label) as in a standard classification problem. - Broadly speaking, the goal in this setting is to learn a distance metric + Broadly speaking, the goal in this setting is to learn a metric that puts points with the same label close together while pushing away points with different labels. - :doc:`Weakly supervised learning `: the algorithm has access to a set of data points with supervision only at the tuple level (typically pairs, triplets, or quadruplets of data points). A classic example of such weaker supervision is a set of - positive and negative pairs: in this case, the goal is to learn a distance + positive and negative pairs: in this case, the goal is to learn a metric that puts positive pairs close together and negative pairs far away. Based on the above (weakly) supervised data, the metric learning problem is generally formulated as an optimization problem where one seeks to find the -parameters of a distance function that optimize some objective function +parameters of a metric that optimize some objective function measuring the agreement with the training data. .. _mahalanobis_distances: @@ -45,7 +44,7 @@ measuring the agreement with the training data. Mahalanobis Distances ===================== -In the metric-learn package, all algorithms currently implemented learn +In the metric-learn package, most algorithms currently implemented learn so-called Mahalanobis distances. Given a real-valued parameter matrix :math:`L` of shape ``(num_dims, n_features)`` where ``n_features`` is the number features describing the data, the Mahalanobis distance associated with @@ -60,8 +59,8 @@ Mahalanobis distance metric learning can thus be seen as learning a new embedding space of dimension ``num_dims``. Note that when ``num_dims`` is smaller than ``n_features``, this achieves dimensionality reduction. -Strictly speaking, Mahalanobis distances are "pseudo-metrics": they satisfy -three of the `properties of a metric `_ (non-negativity, symmetry, triangle inequality) but not necessarily the identity of indiscernibles. @@ -79,6 +78,35 @@ necessarily the identity of indiscernibles. parameterizations are equivalent. In practice, an algorithm may thus solve the metric learning problem with respect to either :math:`M` or :math:`L`. +.. __bilinear_similarities: + +Bilinear Similarities +===================== + +Some algorithms in the package learn bilinear similarity functions. These +similarity functions are not pseudo-distances: they simply output real values +such that the larger the similarity value, the more similar the two examples. +Given a real-valued parameter matrix :math:`W` of shape +``(n_features, n_features)`` where ``n_features`` is the number features +describing the data, the bilinear similarity associated with :math:`W` is +defined as follows: + +.. math:: S_W(x, x') = x^\top W x' + +The matrix :math:`W` is not required to be positive semi-definite (PSD) or +even symmetric, so the distance properties (nonnegativity, identity of +indiscernibles, symmetry and triangle inequality) do not hold in general. + +This allows some algorithms to optimize :math:`S_W` in an online manner using a +simple and efficient procedure, and thus can be applied to problems with +millions of training instances and achieves state-of-the-art performance +on an image search task using :math:`k`-NN. + +The absence of PSD constraint can enable the design of more efficient +algorithms. It is also relevant in applications where the underlying notion +of similarity does not satisfy the triangle inequality, as known to be the +case for visual judgments. + .. _use_cases: Use-cases @@ -99,9 +127,9 @@ examples (for code illustrating some of these use-cases, see the elements of a database that are semantically closest to a query element. - Dimensionality reduction: metric learning may be seen as a way to reduce the data dimension in a (weakly) supervised setting. -- More generally, the learned transformation :math:`L` can be used to project - the data into a new embedding space before feeding it into another machine - learning algorithm. +- More generally with Mahalanobis distances, the learned transformation :math:`L` + can be used to project the data into a new embedding space before feeding it + into another machine learning algorithm. The API of metric-learn is compatible with `scikit-learn `_, the leading library for machine diff --git a/doc/supervised.rst b/doc/supervised.rst index a847a33c..d576554c 100644 --- a/doc/supervised.rst +++ b/doc/supervised.rst @@ -41,11 +41,13 @@ two numbers. Fit, transform, and so on ------------------------- -The goal of supervised metric-learning algorithms is to transform -points in a new space, in which the distance between two points from the -same class will be small, and the distance between two points from different -classes will be large. To do so, we fit the metric learner (example: -`NCA`). +The goal of supervised metric learning algorithms is to learn a (distance or +similarity) metric such that two points from the same class will be similar +(e.g., have small distance) and points from different classes will be dissimilar +(e.g., have large distance). + +To do so, we first need to fit the supervised metric learner on a labeled dataset, +as in the example below with ``NCA``. >>> from metric_learn import NCA >>> nca = NCA(random_state=42) @@ -53,58 +55,79 @@ classes will be large. To do so, we fit the metric learner (example: NCA(init='auto', max_iter=100, n_components=None, preprocessor=None, random_state=42, tol=None, verbose=False) - Now that the estimator is fitted, you can use it on new data for several purposes. -First, you can transform the data in the learned space, using `transform`: -Here we transform two points in the new embedding space. +We can use the learned metric to **score** new pairs of points with ``pair_score`` +(the larger the score, the more similar the pair). For Mahalanobis learners, +it is equal to the opposite of the distance. ->>> X_new = np.array([[9.4, 4.1], [2.1, 4.4]]) ->>> nca.transform(X_new) -array([[ 5.91884732, 10.25406973], - [ 3.1545886 , 6.80350083]]) +>>> score = nca.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) +>>> score +array([-0.49627072, -3.65287282, -6.06079877]) -Also, as explained before, our metric learners has learn a distance between -points. You can use this distance in two main ways: +This is useful because ``pair_score`` matches the **score** semantic of +scikit-learn's `Classification metrics +`_. -- You can either return the distance between pairs of points using the - `pair_distance` function: +For metric learners that learn a distance metric, there is also the ``pair_distance`` +method. >>> nca.pair_distance([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) array([0.49627072, 3.65287282, 6.06079877]) -- Or you can return a function that will return the distance (in the new - space) between two 1D arrays (the coordinates of the points in the original - space), similarly to distance functions in `scipy.spatial.distance`. +.. warning:: + + If you try to use ``pair_distance`` with a bilinear similarity learner, an error + will be thrown, as it does not learn a distance. + +You can also return a function that will return the metric learned. It can +compute the metric between two 1D arrays, similarly to distance functions in +`scipy.spatial.distance`. To do that, use the ``get_metric`` method. >>> metric_fun = nca.get_metric() >>> metric_fun([3.5, 3.6], [5.6, 2.4]) 0.4962707194621285 -- Alternatively, you can use `pair_score` to return the **score** between - pairs of points (the larger the score, the more similar the pair). - For Mahalanobis learners, it is equal to the opposite of the distance. +You can also call ``get_metric`` with bilinear similarity learners, and you will get +a function that will return the similarity between 1D arrays. ->>> score = nca.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) ->>> score -array([-0.49627072, -3.65287282, -6.06079877]) +>>> similarity_fun = algorithm.get_metric() +>>> similarity_fun([3.5, 3.6], [5.6, 2.4]) +-0.04752 -This is useful because `pair_score` matches the **score** semantic of -scikit-learn's `Classification metrics -`_. +Finally, as explained in :ref:`mahalanobis_distances`, these are equivalent to the Euclidean +distance in a transformed space, and can thus be used to transform data points in +a new embedding space. You can use ``transform`` to do so. + +>>> X_new = np.array([[9.4, 4.1], [2.1, 4.4]]) +>>> nca.transform(X_new) +array([[ 5.91884732, 10.25406973], + [ 3.1545886 , 6.80350083]]) + +.. warning:: + + If you try to use ``transform`` with a bilinear similarity learner, an error will + be thrown, as you cannot transform the data using them. .. note:: If the metric learner that you use learns a :ref:`Mahalanobis distance - ` (like it is the case for all algorithms - currently in metric-learn), you can get the plain learned Mahalanobis - matrix using `get_mahalanobis_matrix`. + `, you can get the learned Mahalanobis + matrix :math:`M` using `get_mahalanobis_matrix`. >>> nca.get_mahalanobis_matrix() array([[0.43680409, 0.89169412], [0.89169412, 1.9542479 ]]) + If the metric learner that you use learns a :ref:`bilinear similarity + <_bilinear_similarities>`, you can get the plain learned Bilinear + matrix :math:`W` using `get_bilinear_matrix`. + + >>> algorithm.get_bilinear_matrix() + array([[-0.72680409, -0.153213], + [1.45542269, 7.8135546 ]]) + Scikit-learn compatibility -------------------------- @@ -116,7 +139,7 @@ All supervised algorithms are scikit-learn estimators scikit-learn model selection routines (`sklearn.model_selection.cross_val_score`, `sklearn.model_selection.GridSearchCV`, etc). -You can also use some of the scoring functions from `sklearn.metrics`. +You can also use some scoring functions from `sklearn.metrics`. Algorithms ========== @@ -249,12 +272,12 @@ the sum of probability of being correctly classified: Local Fisher Discriminant Analysis (:py:class:`LFDA `) `LFDA` is a linear supervised dimensionality reduction method which effectively combines the ideas of `Linear Discriminant Analysis ` and Locality-Preserving Projection . It is -particularly useful when dealing with multi-modality, where one ore more classes +particularly useful when dealing with multi-modality, where one or more classes consist of separate clusters in input space. The core optimization problem of LFDA is solved as a generalized eigenvalue problem. -The algorithm define the Fisher local within-/between-class scatter matrix +The algorithm defines the Fisher local within-/between-class scatter matrix :math:`\mathbf{S}^{(w)}/ \mathbf{S}^{(b)}` in a pairwise fashion: .. math:: @@ -410,7 +433,7 @@ method will look at all the samples from a different class and sample randomly a pair among them. The method will try to build `n_constraints` positive pairs and `n_constraints` negative pairs, but sometimes it cannot find enough of one of those, so forcing `same_length=True` will return both times the -minimum of the two lenghts. +minimum of the two lengths. For using quadruplets learners (see :ref:`learning_on_quadruplets`) in a supervised way, positive and negative pairs are sampled as above and diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index 76f7c14e..e4c6a7f4 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -80,11 +80,13 @@ Here is an artificial dataset of 4 pairs of 2 points of 3 features each: >>> [-2.16, +0.11, -0.02]]]) # same as tuples[1, 0, :] >>> y = np.array([-1, 1, 1, -1]) -.. warning:: This way of specifying pairs is not recommended for a large number - of tuples, as it is redundant (see the comments in the example) and hence - takes a lot of memory. Indeed each feature vector of a point will be - replicated as many times as a point is involved in a tuple. The second way - to specify pairs is more efficient +.. warning:: + + This way of specifying pairs is not recommended for a large number + of tuples, as it is redundant (see the comments in the example) and hence + takes a lot of memory. Indeed, each feature vector of a point will be + replicated as many times as a point is involved in a tuple. The second way + to specify pairs is more efficient 2D array of indicators + preprocessor @@ -130,9 +132,12 @@ through the argument `preprocessor` (see below :ref:`fit_ws`) Fit, transform, and so on ------------------------- -The goal of weakly-supervised metric-learning algorithms is to transform -points in a new space, in which the tuple-wise constraints between points -are respected. +The goal of weakly supervised metric learning algorithms is to learn a (distance +or similarity) metric such that the tuple-wise constraints between points are +respected. + +To do so, we first need to fit the weakly supervised metric learner on a dataset +of tuples, as in the example below with ``MMC``. >>> from metric_learn import MMC >>> mmc = MMC(random_state=42) @@ -145,62 +150,82 @@ Or alternatively (using a preprocessor): >>> from metric_learn import MMC >>> mmc = MMC(preprocessor=X, random_state=42) ->>> mmc.fit(pairs_indice, y) - +>>> mmc.fit(pairs_indices, y) Now that the estimator is fitted, you can use it on new data for several purposes. -First, you can transform the data in the learned space, using `transform`: -Here we transform two points in the new embedding space. +We can use the learned metric to **score** new pairs of points with ``pair_score`` +(the larger the score, the more similar the pair). For Mahalanobis learners, +it is equal to the opposite of the distance. ->>> X_new = np.array([[9.4, 4.1, 4.2], [2.1, 4.4, 2.3]]) ->>> mmc.transform(X_new) -array([[-3.24667162e+01, 4.62622348e-07, 3.88325421e-08], - [-3.61531114e+01, 4.86778289e-07, 2.12654397e-08]]) +>>> score = mmc.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) +>>> score +array([-0.49627072, -3.65287282, -6.06079877]) -Also, as explained before, our metric learner has learned a distance between -points. You can use this distance in two main ways: +This is useful because ``pair_score`` matches the **score** semantic of +scikit-learn's `Classification metrics +`_. -- You can either return the distance between pairs of points using the - `pair_distance` function: +For metric learners that learn a distance metric, there is also the ``pair_distance`` +method. >>> mmc.pair_distance([[[3.5, 3.6, 5.2], [5.6, 2.4, 6.7]], ... [[1.2, 4.2, 7.7], [2.1, 6.4, 0.9]]]) array([7.27607365, 0.88853014]) -- Or you can return a function that will return the distance - (in the new space) between two 1D arrays (the coordinates of the points in - the original space), similarly to distance functions in - `scipy.spatial.distance`. To do that, use the `get_metric` method. +.. warning:: + + If you try to use ``pair_distance`` with a bilinear similarity learner, an error + will be thrown, as it does not learn a distance. + +You can also return a function that will return the metric learned. It can +compute the metric between two 1D arrays, similarly to distance functions in +`scipy.spatial.distance`. To do that, use the ``get_metric`` method. >>> metric_fun = mmc.get_metric() >>> metric_fun([3.5, 3.6, 5.2], [5.6, 2.4, 6.7]) 7.276073646278203 -- Alternatively, you can use `pair_score` to return the **score** between - pairs of points (the larger the score, the more similar the pair). - For Mahalanobis learners, it is equal to the opposite of the distance. +You can also call ``get_metric``` with bilinear similarity learners, and you will get +a function that will return the similarity between 1D arrays. ->>> score = mmc.pair_score([[[3.5, 3.6], [5.6, 2.4]], [[1.2, 4.2], [2.1, 6.4]], [[3.3, 7.8], [10.9, 0.1]]]) ->>> score -array([-0.49627072, -3.65287282, -6.06079877]) +>>> similarity_fun = algorithm.get_metric() +>>> similarity_fun([3.5, 3.6], [5.6, 2.4]) +-0.04752 + +Finally, as explained in :ref:`mahalanobis_distances`, these are equivalent to the Euclidean +distance in a transformed space, and can thus be used to transform data points in +a new embedding space. You can use ``transform`` to do so. + +>>> X_new = np.array([[9.4, 4.1, 4.2], [2.1, 4.4, 2.3]]) +>>> mmc.transform(X_new) +array([[-3.24667162e+01, 4.62622348e-07, 3.88325421e-08], + [-3.61531114e+01, 4.86778289e-07, 2.12654397e-08]]) - This is useful because `pair_score` matches the **score** semantic of - scikit-learn's `Classification metrics - `_. +.. warning:: + + If you try to use ``transform`` with a bilinear similarity learner, an error will + be thrown, as you cannot transform the data using them. .. note:: If the metric learner that you use learns a :ref:`Mahalanobis distance - ` (like it is the case for all algorithms - currently in metric-learn), you can get the plain Mahalanobis matrix using - `get_mahalanobis_matrix`. + `, you can get the plain learned Mahalanobis + matrix :math:`M` using `get_mahalanobis_matrix`. + + >>> mmc.get_mahalanobis_matrix() + array([[ 0.58603894, -5.69883982, -1.66614919], + [-5.69883982, 55.41743549, 16.20219519], + [-1.66614919, 16.20219519, 4.73697721]]) + + If the metric learner that you use learns a :ref:`bilinear similarity + <_bilinear_similarities>`, you can get the learned bilinear + matrix :math:`W` using `get_bilinear_matrix`. ->>> mmc.get_mahalanobis_matrix() -array([[ 0.58603894, -5.69883982, -1.66614919], - [-5.69883982, 55.41743549, 16.20219519], - [-1.66614919, 16.20219519, 4.73697721]]) + >>> algorithm.get_bilinear_matrix() + array([[-0.72680409, -0.153213], + [1.45542269, 7.8135546 ]]) .. _sklearn_compat_ws: @@ -457,7 +482,7 @@ Mahalanobis matrix :math:`\mathbf{M}`, and a log-determinant divergence between or :math:`\mathbf{\Omega}^{-1}`, where :math:`\mathbf{\Omega}` is the covariance matrix). -The formulated optimization on the semidefinite matrix :math:`\mathbf{M}` +The formulated optimization on the semi-definite matrix :math:`\mathbf{M}` is convex: .. math:: diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index d0ba1ef9..0a5f115b 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -464,6 +464,156 @@ def get_mahalanobis_matrix(self): return self.components_.T.dot(self.components_) +class BilinearMixin(BaseMetricLearner, metaclass=ABCMeta): + r"""Bilinear similarity learning algorithms. + + Algorithm that learns a bilinear similarity :math:`s_W(x, x')`, + defined between two column vectors :math:`x` and :math:`x'` by: :math: + `s_W(x, x') = x W x'`, where :math:`W` is a learned matrix. This matrix + is not guaranteed to be symmetric nor positive semi-definite (PSD). Thus + it cannot be seen as learning a linear transformation of the original + space like Mahalanobis learning algorithms. + + Attributes + ---------- + components_ : `numpy.ndarray`, shape=(n_components, n_features) + The learned bilinear matrix ``W``. + """ + + def score_pairs(self, pairs): + r""" + .. deprecated:: 0.7.0 + This method is deprecated. + + .. warning:: + This method will be removed in 0.8.0. Please refer to `pair_distance` + or `pair_score`. This change will occur in order to add learners + that don't necessarily learn a Mahalanobis distance. + + Returns the learned bilinear similarity between pairs. + + This similarity is defined as: :math:`s_W(x, x') = x^T W x'` + where ``W`` is the learned bilinear matrix, for every pair of points + ``x`` and ``x'``. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + for 2D array of indices of pairs if the metric learner uses a + preprocessor. + + Returns + ------- + scores : `numpy.ndarray` of shape=(n_pairs,) + The learned bilinear similarity for every pair. + + See Also + -------- + get_metric : a method that returns a function to compute the similarity + between two points. The difference with `pair_score` is that it + works on two 1D arrays and cannot use a preprocessor. Besides, the + returned function is independent of the similarity learner and hence + is not modified if the similarity learner is. + + :ref:`_bilinear_similarities` : The section of the project documentation + that describes bilinear similarity. + """ + dpr_msg = ("score_pairs will be deprecated in release 0.7.0. " + "Use pair_score to compute similarity scores, or " + "pair_distances to compute distances.") + warnings.warn(dpr_msg, category=FutureWarning) + return self.pair_score(pairs) + + def pair_distance(self, pairs): + """ + Returns an error, as bilinear similarity learners do not learn a + pseudo-distance nor a distance. In consecuence, the additive inverse + of the bilinear similarity cannot be used as distance by construction. + """ + msg = ("This learner does not learn a distance, thus ", + "this method is not implemented. Use pair_score instead") + raise Exception(msg) + + def pair_score(self, pairs): + r"""Returns the learned bilinear similarity between pairs. + + This similarity is defined as: :math:`s_W(x, x') = x^T W x'` + where ``W`` is the learned bilinear matrix, for every pair of points + ``x`` and ``x'``. + + Parameters + ---------- + pairs : array-like, shape=(n_pairs, 2, n_features) or (n_pairs, 2) + 3D Array of pairs to score, with each row corresponding to two points, + for 2D array of indices of pairs if the similarity learner uses a + preprocessor. + + Returns + ------- + scores : `numpy.ndarray` of shape=(n_pairs,) + The learned bilinear similarity for every pair. + + See Also + -------- + get_metric : a method that returns a function to compute the similarity + between two points. The difference with `pair_score` is that it + works on two 1D arrays and cannot use a preprocessor. Besides, the + returned function is independent of the similarity learner and hence + is not modified if the similarity learner is. + + :ref:`_bilinear_similarities` : The section of the project documentation + that describes bilinear similarity. + """ + check_is_fitted(self, ['preprocessor_']) + pairs = check_input(pairs, type_of_inputs='tuples', + preprocessor=self.preprocessor_, + estimator=self, tuple_size=2) + # Note: For bilinear order matters, dist(a,b) != dist(b,a) + # We always choose first pair first, then second pair + # (In contrast with Mahalanobis implementation) + return np.sum(np.dot(pairs[:, 0, :], self.components_) * pairs[:, 1, :], + axis=-1) + + def get_metric(self): + check_is_fitted(self, 'components_') + components = self.components_.copy() + + def similarity_fun(u, v): + """This function computes the bilinear similarity between u and v, + according to the previously learned bilinear similarity. + + Parameters + ---------- + u : array-like, shape=(n_features,) + The first point involved in the similarity computation. + + v : array-like, shape=(n_features,) + The second point involved in the similarity computation. + + Returns + ------- + similarity : float + The similarity between u and v according to the new similarity. + """ + u = validate_vector(u) + v = validate_vector(v) + return np.dot(np.dot(u.T, components), v) + + return similarity_fun + + def get_bilinear_matrix(self): + """Returns a copy of the bilinear matrix learned by the similarity learner. + + Returns + ------- + M : `numpy.ndarray`, shape=(n_features, n_features) + The copy of the learned bilinear matrix. + """ + check_is_fitted(self, 'components_') + return self.components_ + + class _PairsClassifierMixin(BaseMetricLearner): """Base class for pairs learners. diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index a39c7b3c..79400ae5 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -1,3 +1,6 @@ +""" +Tests that are specific for each learner. +""" import unittest import re import pytest diff --git a/test/test_base_metric.py b/test/test_base_metric.py index fa641526..c9b098f7 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -1,12 +1,19 @@ -from numpy.core.numeric import array_equal +""" +Tests general things from the API: String parsing, methods like get_metric, +and deprecation warnings. +""" import pytest import re import unittest import metric_learn import numpy as np +from numpy.testing import assert_array_equal +from itertools import product from sklearn import clone from test.test_utils import ids_metric_learners, metric_learners, remove_y from metric_learn.sklearn_shims import set_random_state, SKLEARN_AT_LEAST_0_22 +from metric_learn._util import make_context +from metric_learn.base_metric import MahalanobisMixin, BilinearMixin def remove_spaces(s): @@ -279,25 +286,71 @@ def test_n_components(estimator, build_dataset): @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) def test_score_pairs_warning(estimator, build_dataset): - """Tests that score_pairs returns a FutureWarning regarding deprecation. - Also that score_pairs and pair_distance have the same behaviour""" + """Tests that score_pairs returns a FutureWarning regarding + deprecation for all learners""" input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - - # We fit the metric learner on it and then we call score_pairs on some - # points model.fit(*remove_y(model, input_data, labels)) msg = ("score_pairs will be deprecated in release 0.7.0. " "Use pair_score to compute similarity scores, or " "pair_distances to compute distances.") with pytest.warns(FutureWarning) as raised_warning: - score = model.score_pairs([[X[0], X[1]], ]) - dist = model.pair_distance([[X[0], X[1]], ]) - assert array_equal(score, dist) + _ = model.score_pairs([[X[0], X[1]], ]) assert any([str(warning.message) == msg for warning in raised_warning]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_pair_score_dim(estimator, build_dataset): + """ + Scoring of 3D arrays should return 1D array (several tuples), + and scoring of 2D arrays (one tuple) should return an error (like + scikit-learn's error when scoring 1D arrays) + """ + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(estimator, input_data, labels)) + tuples = np.array(list(product(X, X))) + assert model.pair_score(tuples).shape == (tuples.shape[0],) + context = make_context(model) + msg = ("3D array of formed tuples expected{}. Found 2D array " + "instead:\ninput={}. Reshape your data and/or use a preprocessor.\n" + .format(context, tuples[1])) + with pytest.raises(ValueError) as raised_error: + model.pair_score(tuples[1]) + assert str(raised_error.value) == msg + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_deprecated_score_pairs_same_result(estimator, build_dataset): + """ + Test that `pair_distance` gives the same result as `score_pairs` for + Mahalanobis learnes, and the same for `pair_score` and `score_pairs` + for Bilinear learners. It also checks that the deprecation warning of + `score_pairs` is being shown. + """ + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(model, input_data, labels)) + random_pairs = np.array(list(product(X, X))) + + msg = ("score_pairs will be deprecated in release 0.7.0. " + "Use pair_score to compute similarity scores, or " + "pair_distances to compute distances.") + with pytest.warns(FutureWarning) as raised_warnings: + s1 = model.score_pairs(random_pairs) + if isinstance(model, BilinearMixin): + s2 = model.pair_score(random_pairs) + elif isinstance(model, MahalanobisMixin): + s2 = model.pair_distance(random_pairs) + assert_array_equal(s1, s2) + assert any(str(w.message) == msg for w in raised_warnings) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_bilinear_mixin.py b/test/test_bilinear_mixin.py new file mode 100644 index 00000000..0053a631 --- /dev/null +++ b/test/test_bilinear_mixin.py @@ -0,0 +1,152 @@ +""" +Tests all functionality for Bilinear learners. Correctness, use cases, +warnings, etc. +""" +from itertools import product +import numpy as np +from numpy.testing import assert_array_almost_equal +import pytest +from sklearn import clone +from sklearn.datasets import make_spd_matrix +from sklearn.utils import check_random_state +from metric_learn.sklearn_shims import set_random_state +from test.test_utils import metric_learners_b, ids_metric_learners_b, \ + remove_y, IdentityBilinearLearner, build_classification + +RNG = check_random_state(0) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_b, + ids=ids_metric_learners_b) +def test_same_similarity_with_two_methods(estimator, build_dataset): + """" + Tests that pair_score() and get_metric() give consistent results. + In both cases, the results must match for the same input. + Tests it for 'n_pairs' sampled from 'n' d-dimentional arrays. + """ + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(estimator, input_data, labels)) + random_pairs = np.array(list(product(X, X))) + + dist1 = model.pair_score(random_pairs) + dist2 = [model.get_metric()(p[0], p[1]) for p in random_pairs] + + assert_array_almost_equal(dist1, dist2) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_b, + ids=ids_metric_learners_b) +def test_check_correctness_similarity(estimator, build_dataset): + """ + Tests the correctness of the results made from socre_paris(), + get_metric() and get_bilinear_matrix. Results are compared with + the real bilinear similarity calculated in-place. + """ + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(estimator, input_data, labels)) + random_pairs = np.array(list(product(X, X))) + + dist1 = model.pair_score(random_pairs) + dist2 = [model.get_metric()(p[0], p[1]) for p in random_pairs] + dist3 = [np.dot(np.dot(p[0].T, model.get_bilinear_matrix()), p[1]) + for p in random_pairs] + desired = [np.dot(np.dot(p[0].T, model.components_), p[1]) + for p in random_pairs] + + assert_array_almost_equal(dist1, desired) # pair_score + assert_array_almost_equal(dist2, desired) # get_metric + assert_array_almost_equal(dist3, desired) # get_metric + + +# This is a `hardcoded` handmade tests, to make sure the computation +# made at BilinearMixin is correct. +def test_check_handmade_example(): + """ + Checks that pair_score() result is correct comparing it with a + handmade example. + """ + u = np.array([0, 1, 2]) + v = np.array([3, 4, 5]) + mixin = IdentityBilinearLearner() + mixin.fit([u, v], [0, 0]) # Identity fit + c = np.array([[2, 4, 6], [6, 4, 2], [1, 2, 3]]) + mixin.components_ = c # Force components_ + dists = mixin.pair_score([[u, v], [v, u]]) + assert_array_almost_equal(dists, [96, 120]) + + +# Note: This test needs to be `hardcoded` as the similarity martix must +# be symmetric. Running on all Bilinear learners will throw an error as +# the matrix can be non-symmetric. +def test_check_handmade_symmetric_example(): + """ + When the bilinear matrix is the identity. The similarity + between two arrays must be equal: S(u,v) = S(v,u). Also + checks the random case: when the matrix is spd and symetric. + """ + input_data, labels, _, X = build_classification() + n_samples = 20 + X = X[:n_samples] + model = clone(IdentityBilinearLearner()) # Identity matrix + set_random_state(model) + model.fit(*remove_y(IdentityBilinearLearner(), input_data, labels)) + random_pairs = np.array(list(product(X, X))) + + pairs_reverse = [[p[1], p[0]] for p in random_pairs] + dist1 = model.pair_score(random_pairs) + dist2 = model.pair_score(pairs_reverse) + assert_array_almost_equal(dist1, dist2) + + # Random pairs for M = spd Matrix + spd_matrix = make_spd_matrix(X[0].shape[-1], random_state=RNG) + model.components_ = spd_matrix + dist1 = model.pair_score(random_pairs) + dist2 = model.pair_score(pairs_reverse) + assert_array_almost_equal(dist1, dist2) + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_b, + ids=ids_metric_learners_b) +def test_pair_score_finite(estimator, build_dataset): + """ + Checks for 'n' pair_score() of 'd' dimentions, that all + similarities are finite numbers: not NaN, +inf or -inf. + Considers a random M for bilinear similarity. + """ + input_data, labels, _, X = build_dataset() + n_samples = 20 + X = X[:n_samples] + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(estimator, input_data, labels)) + random_pairs = np.array(list(product(X, X))) + dist1 = model.pair_score(random_pairs) + assert np.isfinite(dist1).all() + + +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_b, + ids=ids_metric_learners_b) +def test_check_error_with_pair_distance(estimator, build_dataset): + """ + Check that calling `pair_distance` is not possible with a Bilinear learner. + An Exception must be shown instead. + """ + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(model, input_data, labels)) + random_pairs = np.array(list(product(X, X))) + + msg = ("This learner does not learn a distance, thus ", + "this method is not implemented. Use pair_score instead") + with pytest.raises(Exception) as e: + _ = model.pair_distance(random_pairs) + assert e.value.args[0] == msg diff --git a/test/test_components_metric_conversion.py b/test/test_components_metric_conversion.py index c6113957..d2c66e49 100644 --- a/test/test_components_metric_conversion.py +++ b/test/test_components_metric_conversion.py @@ -1,3 +1,7 @@ +""" +Tests for Mahalanobis learners, that the transormation matrix (L) squared +is equivalent to the Mahalanobis matrix, even in edge cases. +""" import unittest import numpy as np import pytest diff --git a/test/test_constraints.py b/test/test_constraints.py index 3429d9cc..228702d1 100644 --- a/test/test_constraints.py +++ b/test/test_constraints.py @@ -1,3 +1,7 @@ +""" +Test Contrains generation for positive_negative_pairs and knn_triplets. +Also tests warnings. +""" import pytest import numpy as np from sklearn.utils import shuffle diff --git a/test/test_fit_transform.py b/test/test_fit_transform.py index 246223b0..41f7498e 100644 --- a/test/test_fit_transform.py +++ b/test/test_fit_transform.py @@ -1,3 +1,7 @@ +""" +For each lerner that has `fit` and `transform`, checks that calling them +sequeatially is the same as calling fit_transform from scikit-learn. +""" import unittest import numpy as np from sklearn.datasets import load_iris diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index b5dbc248..5dffbb65 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -1,3 +1,7 @@ +""" +Tests all functionality for Mahalanobis Learners. Correctness, use cases, +warnings, distance properties, transform, dimentions, init, etc. +""" from itertools import product import pytest @@ -8,7 +12,6 @@ from scipy.spatial.distance import pdist, squareform, mahalanobis from scipy.stats import ortho_group from sklearn import clone -from sklearn.cluster import DBSCAN from sklearn.datasets import make_spd_matrix, make_blobs from sklearn.utils import check_random_state, shuffle from sklearn.utils.multiclass import type_of_target @@ -20,14 +23,16 @@ _PairsClassifierMixin) from metric_learn.exceptions import NonPSDError -from test.test_utils import (ids_metric_learners, metric_learners, - remove_y, ids_classifiers) +from test.test_utils import (ids_metric_learners_m, metric_learners_m, + remove_y, ids_classifiers_m, + pairs_learners_m, ids_pairs_learners_m) +from sklearn.exceptions import NotFittedError RNG = check_random_state(0) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_pair_distance_pair_score_equivalent(estimator, build_dataset): """ For Mahalanobis learners, pair_score should be equivalent to the @@ -46,10 +51,11 @@ def test_pair_distance_pair_score_equivalent(estimator, build_dataset): assert_array_equal(distances, -1 * scores) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_pair_distance_pairwise(estimator, build_dataset): - # Computing pairwise scores should return a euclidean distance matrix. + """Computing pairwise scores should return a euclidean distance + matrix.""" input_data, labels, _, X = build_dataset() n_samples = 20 X = X[:n_samples] @@ -70,10 +76,10 @@ def test_pair_distance_pairwise(estimator, build_dataset): assert_array_almost_equal(squareform(pairwise), pdist(model.transform(X))) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_pair_distance_toy_example(estimator, build_dataset): - # Checks that pair_distance works on a toy example + """Checks that `pair_distance` works on a toy example.""" input_data, labels, _, X = build_dataset() n_samples = 20 X = X[:n_samples] @@ -88,10 +94,10 @@ def test_pair_distance_toy_example(estimator, build_dataset): assert_array_almost_equal(model.pair_distance(pairs), distances) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_pair_distance_finite(estimator, build_dataset): - # tests that the score is finite + """Tests that the distance from `pair_distance` is finite""" input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) @@ -100,28 +106,9 @@ def test_pair_distance_finite(estimator, build_dataset): assert np.isfinite(model.pair_distance(pairs)).all() -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) -def test_pair_distance_dim(estimator, build_dataset): - # scoring of 3D arrays should return 1D array (several tuples), - # and scoring of 2D arrays (one tuple) should return an error (like - # scikit-learn's error when scoring 1D arrays) - input_data, labels, _, X = build_dataset() - model = clone(estimator) - set_random_state(model) - model.fit(*remove_y(estimator, input_data, labels)) - tuples = np.array(list(product(X, X))) - assert model.pair_distance(tuples).shape == (tuples.shape[0],) - context = make_context(estimator) - msg = ("3D array of formed tuples expected{}. Found 2D array " - "instead:\ninput={}. Reshape your data and/or use a preprocessor.\n" - .format(context, tuples[1])) - with pytest.raises(ValueError) as raised_error: - model.pair_distance(tuples[1]) - assert str(raised_error.value) == msg - - def check_is_distance_matrix(pairwise): + """Returns True if the matrix is positive, symmetrc, the diagonal is zero, + and if it fullfills the triangular inequality for all pairs""" assert (pairwise >= 0).all() # positivity assert np.array_equal(pairwise, pairwise.T) # symmetry assert (pairwise.diagonal() == 0).all() # identity @@ -131,10 +118,11 @@ def check_is_distance_matrix(pairwise): pairwise[:, np.newaxis, :] + tol).all() -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_embed_toy_example(estimator, build_dataset): - # Checks that embed works on a toy example + """Checks that embed works on a toy example. That using `transform` + is equivalent to manually multiplying Lx""" input_data, labels, _, X = build_dataset() n_samples = 20 X = X[:n_samples] @@ -145,10 +133,10 @@ def test_embed_toy_example(estimator, build_dataset): assert_array_almost_equal(model.transform(X), embedded_points) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_embed_dim(estimator, build_dataset): - # Checks that the the dimension of the output space is as expected + """Checks that the the dimension of the output space is as expected""" input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) @@ -174,10 +162,10 @@ def test_embed_dim(estimator, build_dataset): assert str(raised_error.value) == err_msg -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_embed_finite(estimator, build_dataset): - # Checks that embed returns vectors with finite values + """Checks that embed (transform) returns vectors with finite values""" input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) @@ -185,10 +173,11 @@ def test_embed_finite(estimator, build_dataset): assert np.isfinite(model.transform(X)).all() -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_embed_is_linear(estimator, build_dataset): - # Checks that the embedding is linear + """Checks that the embedding is linear, i.e. linear properties of + using `tranform`""" input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) @@ -200,8 +189,8 @@ def test_embed_is_linear(estimator, build_dataset): 5 * model.transform(X[:10])) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_get_metric_equivalent_to_explicit_mahalanobis(estimator, build_dataset): """Tests that using the get_metric method of mahalanobis metric learners is @@ -220,8 +209,8 @@ def test_get_metric_equivalent_to_explicit_mahalanobis(estimator, assert_allclose(metric(a, b), expected_dist, rtol=1e-13) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_get_metric_is_pseudo_metric(estimator, build_dataset): """Tests that the get_metric method of mahalanobis metric learners returns a pseudo-metric (metric but without one side of the equivalence of @@ -247,21 +236,8 @@ def test_get_metric_is_pseudo_metric(estimator, build_dataset): np.isclose(metric(a, c), metric(a, b) + metric(b, c), rtol=1e-20)) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) -def test_get_metric_compatible_with_scikit_learn(estimator, build_dataset): - """Check that the metric returned by get_metric is compatible with - scikit-learn's algorithms using a custom metric, DBSCAN for instance""" - input_data, labels, _, X = build_dataset() - model = clone(estimator) - set_random_state(model) - model.fit(*remove_y(estimator, input_data, labels)) - clustering = DBSCAN(metric=model.get_metric()) - clustering.fit(X) - - -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_get_squared_metric(estimator, build_dataset): """Test that the squared metric returned is indeed the square of the metric""" @@ -280,8 +256,8 @@ def test_get_squared_metric(estimator, build_dataset): rtol=1e-15) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_components_is_2D(estimator, build_dataset): """Tests that the transformation matrix of metric learners is 2D""" input_data, labels, _, X = build_dataset() @@ -318,13 +294,13 @@ def test_components_is_2D(estimator, build_dataset): @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if hasattr(ml, 'n_components') and hasattr(ml, 'init')], ids=[idml for idml, (ml, _) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if hasattr(ml, 'n_components') and hasattr(ml, 'init')]) def test_init_transformation(estimator, build_dataset): @@ -411,13 +387,13 @@ def test_init_transformation(estimator, build_dataset): @pytest.mark.parametrize('n_components', [3, 5, 7, 11]) @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if hasattr(ml, 'n_components') and hasattr(ml, 'init')], ids=[idml for idml, (ml, _) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if hasattr(ml, 'n_components') and hasattr(ml, 'init')]) def test_auto_init_transformation(n_samples, n_features, n_classes, @@ -460,7 +436,7 @@ def test_auto_init_transformation(n_samples, n_features, n_classes, input_data = input_data[:n_samples, ..., :n_features] assert input_data.shape[0] == n_samples assert input_data.shape[-1] == n_features - has_classes = model_base.__class__.__name__ in ids_classifiers + has_classes = model_base.__class__.__name__ in ids_classifiers_m if has_classes: labels = np.tile(range(n_classes), n_samples // n_classes + 1)[:n_samples] @@ -481,13 +457,13 @@ def test_auto_init_transformation(n_samples, n_features, n_classes, @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if not hasattr(ml, 'n_components') and hasattr(ml, 'init')], ids=[idml for idml, (ml, _) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if not hasattr(ml, 'n_components') and hasattr(ml, 'init')]) def test_init_mahalanobis(estimator, build_dataset): @@ -571,12 +547,12 @@ def test_init_mahalanobis(estimator, build_dataset): @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if idml[:4] in ['ITML', 'SDML', 'LSML']], ids=[idml for idml, (ml, _) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if idml[:4] in ['ITML', 'SDML', 'LSML']]) def test_singular_covariance_init_or_prior_strictpd(estimator, build_dataset): """Tests that when using the 'covariance' init or prior, it returns the @@ -615,12 +591,12 @@ def test_singular_covariance_init_or_prior_strictpd(estimator, build_dataset): @pytest.mark.integration @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if idml[:3] in ['MMC']], ids=[idml for idml, (ml, _) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if idml[:3] in ['MMC']]) def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset): """Tests that when using the 'covariance' init or prior, it returns the @@ -657,12 +633,12 @@ def test_singular_covariance_init_of_non_strict_pd(estimator, build_dataset): @pytest.mark.integration @pytest.mark.parametrize('estimator, build_dataset', [(ml, bd) for idml, (ml, bd) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if idml[:4] in ['ITML', 'SDML', 'LSML']], ids=[idml for idml, (ml, _) - in zip(ids_metric_learners, - metric_learners) + in zip(ids_metric_learners_m, + metric_learners_m) if idml[:4] in ['ITML', 'SDML', 'LSML']]) @pytest.mark.parametrize('w0', [1e-20, 0., -1e-20]) def test_singular_array_init_or_prior_strictpd(estimator, build_dataset, w0): @@ -731,8 +707,8 @@ def test_singular_array_init_of_non_strict_pd(w0): @pytest.mark.integration -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) def test_deterministic_initialization(estimator, build_dataset): """Test that estimators that have a prior or an init are deterministic when it is set to to random and when the random_state is fixed.""" @@ -750,3 +726,35 @@ def test_deterministic_initialization(estimator, build_dataset): model2 = model2.fit(*remove_y(model, input_data, labels)) np.testing.assert_allclose(model1.get_mahalanobis_matrix(), model2.get_mahalanobis_matrix()) + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners_m, + ids=ids_pairs_learners_m) +def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, + with_preprocessor): + """Test that a NotFittedError is raised if someone tries to use + pair_score, pair_distance, score_pairs, get_metric, transform or + get_mahalanobis_matrix on input data and the metric learner + has not been fitted.""" + input_data, _, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + with pytest.raises(NotFittedError): # TODO: Remove in 0.8.0 + msg = ("score_pairs will be deprecated in release 0.7.0. " + "Use pair_score to compute similarity scores, or " + "pair_distances to compute distances.") + with pytest.warns(FutureWarning) as raised_warning: + estimator.score_pairs(input_data) + assert any([str(warning.message) == msg for warning in raised_warning]) + with pytest.raises(NotFittedError): + estimator.pair_score(input_data) + with pytest.raises(NotFittedError): + estimator.pair_distance(input_data) + with pytest.raises(NotFittedError): + estimator.get_metric() + with pytest.raises(NotFittedError): + estimator.get_mahalanobis_matrix() + with pytest.raises(NotFittedError): + estimator.transform(input_data) diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py index 6a725f23..2aac2b3d 100644 --- a/test/test_pairs_classifiers.py +++ b/test/test_pairs_classifiers.py @@ -1,3 +1,7 @@ +""" +Tests all functionality for PairClassifiers. Methods, threshold, calibration, +warnings, correctness, use cases, etc. +""" from functools import partial import pytest @@ -10,7 +14,8 @@ precision_score) from sklearn.model_selection import train_test_split -from test.test_utils import pairs_learners, ids_pairs_learners +from test.test_utils import pairs_learners, ids_pairs_learners, \ + pairs_learners_m, ids_pairs_learners_m from metric_learn.sklearn_shims import set_random_state from sklearn import clone import numpy as np @@ -40,7 +45,7 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset, ids=ids_pairs_learners) def test_predict_monotonous(estimator, build_dataset, with_preprocessor): - """Test that there is a threshold distance separating points labeled as + """Test that there is a threshold value separating points labeled as similar and points labeled as dissimilar """ input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) @@ -65,32 +70,22 @@ def test_predict_monotonous(estimator, build_dataset, def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, with_preprocessor): """Test that a NotFittedError is raised if someone tries to use - pair_score, score_pairs, decision_function, get_metric, transform or - get_mahalanobis_matrix on input data and the metric learner - has not been fitted.""" + decision_function, calibrate_threshold, set_threshold, predict + on input data and the metric learner has not been fitted.""" input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) - with pytest.raises(NotFittedError): # Remove in 0.8.0 - estimator.score_pairs(input_data) with pytest.raises(NotFittedError): - estimator.pair_score(input_data) + estimator.predict(input_data) with pytest.raises(NotFittedError): estimator.decision_function(input_data) with pytest.raises(NotFittedError): - estimator.get_metric() - with pytest.raises(NotFittedError): - estimator.transform(input_data) - with pytest.raises(NotFittedError): - estimator.get_mahalanobis_matrix() - with pytest.raises(NotFittedError): - estimator.calibrate_threshold(input_data, labels) - + estimator.score(input_data, labels) with pytest.raises(NotFittedError): estimator.set_threshold(0.5) with pytest.raises(NotFittedError): - estimator.predict(input_data) + estimator.calibrate_threshold(input_data, labels) @pytest.mark.parametrize('calibration_params', @@ -130,7 +125,7 @@ def test_fit_with_valid_threshold_params(estimator, build_dataset, ids=ids_pairs_learners) def test_threshold_different_scores_is_finite(estimator, build_dataset, with_preprocessor, kwargs): - # test that calibrating the threshold works for every metric learner + """Test that calibrating the threshold works for every metric learner""" input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) @@ -171,7 +166,7 @@ def test_unset_threshold(): def test_set_threshold(): - # test that set_threshold indeed sets the threshold + """Test that set_threshold indeed sets the threshold""" identity_pairs_classifier = IdentityPairsClassifier() pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]]) y = np.array([1, 1, -1, -1]) @@ -200,8 +195,8 @@ def test_set_wrong_type_threshold(value): def test_f_beta_1_is_f_1(): - # test that putting beta to 1 indeed finds the best threshold to optimize - # the f1_score + """Test that putting beta to 1 indeed finds the best threshold to optimize + the f1_score""" rng = np.random.RandomState(42) n_samples = 100 pairs, y = rng.randn(n_samples, 2, 5), rng.choice([-1, 1], size=n_samples) @@ -266,8 +261,8 @@ def tnr_threshold(y_true, y_pred, tpr_threshold=0.): for t in [0., 0.1, 0.5, 0.8, 1.]], ) def test_found_score_is_best_score(kwargs, scoring): - # test that when we use calibrate threshold, it will indeed be the - # threshold that have the best score + """Test that when we use calibrate threshold, it will indeed be the + threshold that have the best score""" rng = np.random.RandomState(42) n_samples = 50 pairs, y = rng.randn(n_samples, 2, 5), rng.choice([-1, 1], size=n_samples) @@ -305,11 +300,11 @@ def test_found_score_is_best_score(kwargs, scoring): for t in [0., 0.1, 0.5, 0.8, 1.]] ) def test_found_score_is_best_score_duplicates(kwargs, scoring): - # test that when we use calibrate threshold, it will indeed be the - # threshold that have the best score. It's the same as the previous test - # except this time we test that the scores are coherent even if there are - # duplicates (i.e. points that have the same score returned by - # `decision_function`). + """Test that when we use calibrate threshold, it will indeed be the + threshold that have the best score. It's the same as the previous test + except this time we test that the scores are coherent even if there are + duplicates (i.e. points that have the same score returned by + `decision_function`).""" rng = np.random.RandomState(42) n_samples = 50 pairs, y = rng.randn(n_samples, 2, 5), rng.choice([-1, 1], size=n_samples) @@ -353,8 +348,8 @@ def test_found_score_is_best_score_duplicates(kwargs, scoring): ) def test_calibrate_threshold_invalid_parameters_right_error(invalid_args, expected_msg): - # test that the right error message is returned if invalid arguments are - # given to calibrate_threshold + """Test that the right error message is returned if invalid arguments are + given to `calibrate_threshold`""" rng = np.random.RandomState(42) pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20) pairs_learner = IdentityPairsClassifier() @@ -377,8 +372,8 @@ def test_calibrate_threshold_invalid_parameters_right_error(invalid_args, # to do that) ) def test_calibrate_threshold_valid_parameters(valid_args): - # test that no warning message is returned if valid arguments are given to - # calibrate threshold + """Test that no warning message is returned if valid arguments are given to + `calibrate threshold`""" rng = np.random.RandomState(42) pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20) pairs_learner = IdentityPairsClassifier() @@ -390,8 +385,7 @@ def test_calibrate_threshold_valid_parameters(valid_args): def test_calibrate_threshold_extreme(): """Test that in the (rare) case where we should accept all points or - reject all points, this is effectively what - is done""" + reject all points, this is effectively what is done""" class MockBadPairsClassifier(MahalanobisMixin, _PairsClassifierMixin): """A pairs classifier that returns bad scores (i.e. in the inverse order @@ -489,9 +483,9 @@ def decision_function(self, pairs): ) def test_validate_calibration_params_invalid_parameters_right_error( estimator, _, invalid_args, expected_msg): - # test that the right error message is returned if invalid arguments are - # given to _validate_calibration_params, for all pairs metric learners as - # well as a mocking general identity pairs classifier and the class itself + """Test that the right error message is returned if invalid arguments are + given to `_validate_calibration_params`, for all pairs metric learners as + well as a mocking general identity pairs classifier and the class itself""" with pytest.raises(ValueError) as raised_error: estimator._validate_calibration_params(**invalid_args) assert str(raised_error.value) == expected_msg @@ -515,9 +509,9 @@ def test_validate_calibration_params_invalid_parameters_right_error( ) def test_validate_calibration_params_valid_parameters( estimator, _, valid_args): - # test that no warning message is returned if valid arguments are given to - # _validate_calibration_params for all pairs metric learners, as well as - # a mocking example, and the class itself + """Test that no warning message is returned if valid arguments are given to + `_validate_calibration_params` for all pairs metric learners, as well as + a mocking example, and the class itself""" with pytest.warns(None) as record: estimator._validate_calibration_params(**valid_args) assert len(record) == 0 @@ -528,7 +522,7 @@ def test_validate_calibration_params_valid_parameters( ids=ids_pairs_learners) def test_validate_calibration_params_invalid_parameters_error_before__fit( estimator, build_dataset): - """For all pairs metric learners (which currently all have a _fit method), + """For all pairs metric learners (which currently all have a `_fit` method), make sure that calibration parameters are validated before fitting""" estimator = clone(estimator) input_data, labels, _, _ = build_dataset() @@ -545,11 +539,12 @@ def breaking_fun(**args): # a function that fails so that we will miss assert str(raised_error.value) == expected_msg -@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, - ids=ids_pairs_learners) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners_m, + ids=ids_pairs_learners_m) def test_accuracy_toy_example(estimator, build_dataset): """Test that the accuracy works on some toy example (hence that the - prediction is OK)""" + prediction is OK). This test is designed for Mahalanobis learners only, + as the toy example uses the notion of distance.""" input_data, labels, preprocessor, X = build_dataset(with_preprocessor=False) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) diff --git a/test/test_quadruplets_classifiers.py b/test/test_quadruplets_classifiers.py index a8319961..65aa9538 100644 --- a/test/test_quadruplets_classifiers.py +++ b/test/test_quadruplets_classifiers.py @@ -1,8 +1,13 @@ +""" +Tests all functionality for QuadrupletsClassifiers. Methods, warrnings, +correctness, use cases, etc. +""" import pytest from sklearn.exceptions import NotFittedError from sklearn.model_selection import train_test_split -from test.test_utils import quadruplets_learners, ids_quadruplets_learners +from test.test_utils import quadruplets_learners, ids_quadruplets_learners, \ + quadruplets_learners_m, ids_quadruplets_learners_m from metric_learn.sklearn_shims import set_random_state from sklearn import clone import numpy as np @@ -31,21 +36,27 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset, ids=ids_quadruplets_learners) def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, with_preprocessor): - """Test that a NotFittedError is raised if someone tries to predict and - the metric learner has not been fitted.""" + """Test that a NotFittedError is raised if someone tries to use the + methods: predict, decision_function and score when the metric learner + has not been fitted.""" input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) with pytest.raises(NotFittedError): estimator.predict(input_data) + with pytest.raises(NotFittedError): + estimator.decision_function(input_data) + with pytest.raises(NotFittedError): + estimator.score(input_data) -@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners, - ids=ids_quadruplets_learners) +@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners_m, + ids=ids_quadruplets_learners_m) def test_accuracy_toy_example(estimator, build_dataset): """Test that the default scoring for quadruplets (accuracy) works on some - toy example""" + toy example. This test is designed for Mahalanobis learners only, + as the toy example uses the notion of distance.""" input_data, labels, preprocessor, X = build_dataset(with_preprocessor=False) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index 798d9036..1c9ae6cd 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -12,10 +12,12 @@ MMC_Supervised, RCA_Supervised, SDML_Supervised, SCML_Supervised) from sklearn import clone +from sklearn.cluster import DBSCAN import numpy as np from sklearn.model_selection import (cross_val_score, cross_val_predict, train_test_split, KFold) from test.test_utils import (metric_learners, ids_metric_learners, + metric_learners_m, ids_metric_learners_m, mock_preprocessor, tuples_learners, ids_tuples_learners, pairs_learners, ids_pairs_learners, remove_y, @@ -110,6 +112,54 @@ def generate_array_like(input_data, labels=None): return input_data_changed, labels_changed +# TODO: Find a better way to run this test and the next one, to avoid +# duplicated code. +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners_m, + ids=ids_metric_learners_m) +def test_array_like_inputs_mahalanobis(estimator, build_dataset, + with_preprocessor): + """Test that metric-learners can have as input any array-like object. + This in particular tests `transform` and `pair_distance` for Mahalanobis + learners.""" + input_data, labels, preprocessor, X = build_dataset(with_preprocessor) + # we subsample the data for the test to be more efficient + input_data, _, labels, _ = train_test_split(input_data, labels, + train_size=40, + random_state=42) + X = X[:10] + + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + input_variants, label_variants = generate_array_like(input_data, labels) + for input_variant in input_variants: + for label_variant in label_variants: + estimator.fit(*remove_y(estimator, input_variant, label_variant)) + if hasattr(estimator, "predict"): + estimator.predict(input_variant) + if hasattr(estimator, "predict_proba"): + estimator.predict_proba(input_variant) # anticipation in case some + # time we have that, or if ppl want to contribute with new algorithms + # it will be checked automatically + if hasattr(estimator, "decision_function"): + estimator.decision_function(input_variant) + if hasattr(estimator, "score"): + for label_variant in label_variants: + estimator.score(*remove_y(estimator, input_variant, label_variant)) + + # Transform + X_variants, _ = generate_array_like(X) + for X_variant in X_variants: + estimator.transform(X_variant) + + # Pair distance + pairs = np.array([[X[0], X[1]], [X[0], X[2]]]) + pairs_variants, _ = generate_array_like(pairs) + for pairs_variant in pairs_variants: + estimator.pair_distance(pairs_variant) + + @pytest.mark.integration @pytest.mark.parametrize('with_preprocessor', [True, False]) @pytest.mark.parametrize('estimator, build_dataset', metric_learners, @@ -144,25 +194,12 @@ def test_array_like_inputs(estimator, build_dataset, with_preprocessor): for label_variant in label_variants: estimator.score(*remove_y(estimator, input_variant, label_variant)) - X_variants, _ = generate_array_like(X) - for X_variant in X_variants: - estimator.transform(X_variant) - pairs = np.array([[X[0], X[1]], [X[0], X[2]]]) pairs_variants, _ = generate_array_like(pairs) - not_implemented_msg = "" - # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says - # "This learner does not have pair_distance" - + # Pair score for pairs_variant in pairs_variants: - estimator.pair_score(pairs_variant) # All learners have pair_score - - # But not all of them will have pair_distance - try: - estimator.pair_distance(pairs_variant) - except Exception as raised_exception: - assert raised_exception.value.args[0] == not_implemented_msg + estimator.pair_score(pairs_variant) @pytest.mark.parametrize('with_preprocessor', [True, False]) @@ -461,5 +498,18 @@ def test_dont_overwrite_parameters(estimator, build_dataset, " %s changed" % ', '.join(attrs_changed_by_fit)) +@pytest.mark.parametrize('estimator, build_dataset', metric_learners, + ids=ids_metric_learners) +def test_get_metric_compatible_with_scikit_learn(estimator, build_dataset): + """Check that the metric returned by get_metric is compatible with + scikit-learn's algorithms using a custom metric, DBSCAN for instance""" + input_data, labels, _, X = build_dataset() + model = clone(estimator) + set_random_state(model) + model.fit(*remove_y(estimator, input_data, labels)) + clustering = DBSCAN(metric=model.get_metric()) + clustering.fit(X) + + if __name__ == '__main__': unittest.main() diff --git a/test/test_triplets_classifiers.py b/test/test_triplets_classifiers.py index 515a0a33..684a49cc 100644 --- a/test/test_triplets_classifiers.py +++ b/test/test_triplets_classifiers.py @@ -1,3 +1,7 @@ +""" +Tests all functionality for TripletsClassifiers. Methods, warrnings, +correctness, use cases, etc. +""" import pytest from sklearn.exceptions import NotFittedError from sklearn.model_selection import train_test_split @@ -5,7 +9,9 @@ from metric_learn import SCML from test.test_utils import ( triplets_learners, + triplets_learners_m, ids_triplets_learners, + ids_triplets_learners_m, build_triplets ) from metric_learn.sklearn_shims import set_random_state @@ -32,13 +38,14 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset, assert len(not_valid) == 0 -@pytest.mark.parametrize('estimator, build_dataset', triplets_learners, - ids=ids_triplets_learners) +@pytest.mark.parametrize('estimator, build_dataset', triplets_learners_m, + ids=ids_triplets_learners_m) def test_no_zero_prediction(estimator, build_dataset): """ Test that all predicted values are not zero, even when the distance d(x,y) and d(x,z) is the same for a triplet of the - form (x, y, z). i.e border cases. + form (x, y, z). i.e border cases for Mahalanobis distance + learners. """ triplets, _, _, X = build_dataset(with_preprocessor=False) # Force 3 dimentions only, to use cross product and get easy orthogonal vec. @@ -61,7 +68,7 @@ def test_no_zero_prediction(estimator, build_dataset): assert_array_equal(X[1], x) with pytest.raises(AssertionError): assert_array_equal(X[1], y) - # Assert the distance is the same for both + # Assert the distance is the same for both -> Wont work for b. similarity assert estimator.get_metric()(X[1], x) == estimator.get_metric()(X[1], y) # Form the three scenarios where predict() gives 0 with numpy.sign @@ -80,21 +87,27 @@ def test_no_zero_prediction(estimator, build_dataset): ids=ids_triplets_learners) def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, with_preprocessor): - """Test that a NotFittedError is raised if someone tries to predict and - the metric learner has not been fitted.""" + """Test that a NotFittedError is raised if someone tries to use the + methods: predict, decision_function and score when the metric learner + has not been fitted.""" input_data, _, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) with pytest.raises(NotFittedError): estimator.predict(input_data) + with pytest.raises(NotFittedError): + estimator.decision_function(input_data) + with pytest.raises(NotFittedError): + estimator.score(input_data) -@pytest.mark.parametrize('estimator, build_dataset', triplets_learners, - ids=ids_triplets_learners) +@pytest.mark.parametrize('estimator, build_dataset', triplets_learners_m, + ids=ids_triplets_learners_m) def test_accuracy_toy_example(estimator, build_dataset): """Test that the default scoring for triplets (accuracy) works on some - toy example""" + toy example. This test is designed for Mahalanobis learners only, + as the toy example uses the notion of distance.""" triplets, _, _, X = build_dataset(with_preprocessor=False) estimator = clone(estimator) set_random_state(estimator) diff --git a/test/test_utils.py b/test/test_utils.py index 43d67111..908ae50d 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,8 @@ +""" +Tests preprocesor, warnings, errors. Also made util functions to build datasets +in a general way for each learner. Here is also the list of learners of each +kind that are used as a parameters in tests in other files. Util functions. +""" import pytest from scipy.linalg import eigh, pinvh from collections import namedtuple @@ -18,6 +23,7 @@ MMC_Supervised, RCA_Supervised, SDML_Supervised, SCML, SCML_Supervised, Constraints) from metric_learn.base_metric import (ArrayIndexer, MahalanobisMixin, + BilinearMixin, _PairsClassifierMixin, _TripletsClassifierMixin, _QuadrupletsClassifierMixin) @@ -28,6 +34,92 @@ SEED = 42 RNG = check_random_state(SEED) + +# -------------------- Mock classes for testing ------------------------ + + +class RandomBilinearLearner(BilinearMixin): + """A simple Random bilinear mixin that returns an random matrix + M as learned. Class for testing purposes. + """ + def __init__(self, preprocessor=None, random_state=33): + super().__init__(preprocessor=preprocessor) + self.random_state = random_state + + def fit(self, X, y): + """ + Checks input's format. A random (d,d) matrix is set. + """ + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) + self.d_ = np.shape(X[0])[-1] + rng = check_random_state(self.random_state) + self.components_ = rng.rand(self.d_, self.d_) + return self + + +class IdentityBilinearLearner(BilinearMixin): + """A simple Identity bilinear mixin that returns an identity matrix + M as learned. Class for testing purposes. + """ + def __init__(self, preprocessor=None): + super().__init__(preprocessor=preprocessor) + + def fit(self, X, y): + """ + Checks input's format. Sets M matrix to identity of shape (d,d) + where d is the dimension of the input. + """ + X, y = self._prepare_inputs(X, y, ensure_min_samples=2) + self.d_ = np.shape(X[0])[-1] + self.components_ = np.identity(self.d_) + return self + + +class MockPairIdentityBilinearLearner(BilinearMixin, + _PairsClassifierMixin): + + def __init__(self, preprocessor=None): + super().__init__(preprocessor=preprocessor) + + def fit(self, pairs, y, calibration_params=None): + calibration_params = (calibration_params if calibration_params is not + None else dict()) + self._validate_calibration_params(**calibration_params) + pairs = self._prepare_inputs(pairs, type_of_inputs='tuples') + self.d_ = np.shape(pairs[0][0])[-1] + self.components_ = np.identity(self.d_) + self.calibrate_threshold(pairs, y, **calibration_params) + return self + + +class MockTripletsIdentityBilinearLearner(BilinearMixin, + _TripletsClassifierMixin): + + def __init__(self, preprocessor=None): + super().__init__(preprocessor=preprocessor) + + def fit(self, triplets): + triplets = self._prepare_inputs(triplets, type_of_inputs='tuples') + self.d_ = np.shape(triplets[0][0])[-1] + self.components_ = np.identity(self.d_) + return self + + +class MockQuadrpletsIdentityBilinearLearner(BilinearMixin, + _QuadrupletsClassifierMixin): + + def __init__(self, preprocessor=None): + super().__init__(preprocessor=preprocessor) + + def fit(self, quadruplets): + quadruplets = self._prepare_inputs(quadruplets, type_of_inputs='tuples') + self.d_ = np.shape(quadruplets[0][0])[-1] + self.components_ = np.identity(self.d_) + return self + + +# ------------------ Building dummy data for learners ------------------ + Dataset = namedtuple('Dataset', ('data target preprocessor to_transform')) # Data and target are what we will fit on. Preprocessor is the additional # data if we use a preprocessor (which should be the default ArrayIndexer), @@ -35,7 +127,16 @@ def build_classification(with_preprocessor=False): - """Basic array for testing when using a preprocessor""" + """ + Basic array 'X, y' for testing when using a preprocessor, for instance, + for clustering. For supervised learners. + + If no preprocesor: 'data' are raw points, 'target' are dummy labels, + 'preprocesor' is None, and 'to_transform' are points. + + If preprocessor: 'data' are point indices, 'target' are dummy labels, + 'preprocessor' are unique points, 'to_transform' are points. + """ X, y = shuffle(*make_blobs(random_state=SEED), random_state=SEED) indices = shuffle(np.arange(X.shape[0]), random_state=SEED).astype(int) @@ -46,7 +147,16 @@ def build_classification(with_preprocessor=False): def build_regression(with_preprocessor=False): - """Basic array for testing when using a preprocessor""" + """ + Basic array 'X, y' for testing when using a preprocessor, for regression. + For supervised learners. + + If no preprocesor: 'data' are raw points, 'target' are dummy labels, + 'preprocesor' is None, and 'to_transform' are points. + + If preprocessor: 'data' are point indices, 'target' are dummy labels, + 'preprocessor' are unique points, 'to_transform' are points. + """ X, y = shuffle(*make_regression(n_samples=100, n_features=5, random_state=SEED), random_state=SEED) @@ -58,6 +168,8 @@ def build_regression(with_preprocessor=False): def build_data(): + """Aux function: Returns 'X, pairs' taken from the iris dataset, where + pairs are positive and negative pairs for PairClassifiers.""" input_data, labels = load_iris(return_X_y=True) X, y = shuffle(input_data, labels, random_state=SEED) n_constraints = 50 @@ -70,7 +182,17 @@ def build_data(): def build_pairs(with_preprocessor=False): - # builds a toy pairs problem + """ + For all pair weakly-supervised learners. + + Returns: data, target, preprocessor, to_transform. + + If no preprocesor: 'data' are raw pairs, 'target' are dummy labels, + 'preprocesor' is None, and 'to_transform' are points. + + If preprocessor: 'data' are pair indices, 'target' are dummy labels, + 'preprocessor' are unique points, 'to_transform' are points. + """ X, indices = build_data() c = np.vstack([np.column_stack(indices[:2]), np.column_stack(indices[2:])]) target = np.concatenate([np.ones(indices[0].shape[0]), @@ -85,6 +207,17 @@ def build_pairs(with_preprocessor=False): def build_triplets(with_preprocessor=False): + """ + For all triplet weakly-supervised learners. + + Returns: data, target, preprocessor, to_transform. + + If no preprocesor: 'data' are raw triplets, 'target' are dummy labels, + 'preprocesor' is None, and 'to_transform' are points. + + If preprocessor: 'data' are triplets indices, 'target' are dummy labels, + 'preprocessor' are unique points, 'to_transform' are points. + """ input_data, labels = load_iris(return_X_y=True) X, y = shuffle(input_data, labels, random_state=SEED) constraints = Constraints(y) @@ -98,7 +231,17 @@ def build_triplets(with_preprocessor=False): def build_quadruplets(with_preprocessor=False): - # builds a toy quadruplets problem + """ + For all Quadruplets weakly-supervised learners. + + Returns: data, target, preprocessor, to_transform. + + If no preprocesor: 'data' are raw quadruplets, 'target' are dummy labels, + 'preprocesor' is None, and 'to_transform' are points. + + If preprocessor: 'data' are quadruplets indices, 'target' are dummy labels, + 'preprocessor' are unique points, 'to_transform' are points. + """ X, indices = build_data() c = np.column_stack(indices) target = np.ones(c.shape[0]) # quadruplets targets are not used @@ -112,59 +255,130 @@ def build_quadruplets(with_preprocessor=False): return Dataset(X[c], target, None, X[c[:, 0]]) -quadruplets_learners = [(LSML(), build_quadruplets)] -ids_quadruplets_learners = list(map(lambda x: x.__class__.__name__, - [learner for (learner, _) in - quadruplets_learners])) +# ------------- List of learners, separating them by kind ------------- -triplets_learners = [(SCML(n_basis=320), build_triplets)] -ids_triplets_learners = list(map(lambda x: x.__class__.__name__, - [learner for (learner, _) in - triplets_learners])) - -pairs_learners = [(ITML(max_iter=2), build_pairs), # max_iter=2 to be faster - (MMC(max_iter=2), build_pairs), # max_iter=2 to be faster - (SDML(prior='identity', balance_param=1e-5), build_pairs)] -ids_pairs_learners = list(map(lambda x: x.__class__.__name__, - [learner for (learner, _) in - pairs_learners])) - -classifiers = [(Covariance(), build_classification), - (LFDA(), build_classification), - (LMNN(), build_classification), - (NCA(), build_classification), - (RCA(), build_classification), - (ITML_Supervised(max_iter=5), build_classification), - (LSML_Supervised(), build_classification), - (MMC_Supervised(max_iter=5), build_classification), - (RCA_Supervised(n_chunks=5), build_classification), - (SDML_Supervised(prior='identity', balance_param=1e-5), - build_classification), - (SCML_Supervised(n_basis=80), build_classification)] -ids_classifiers = list(map(lambda x: x.__class__.__name__, - [learner for (learner, _) in - classifiers])) - -regressors = [(MLKR(init='pca'), build_regression)] -ids_regressors = list(map(lambda x: x.__class__.__name__, - [learner for (learner, _) in regressors])) +# Mahalanobis learners +# -- Weakly Supervised +quadruplets_learners_m = [(LSML(), build_quadruplets)] +ids_quadruplets_learners_m = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + quadruplets_learners_m])) + +triplets_learners_m = [(SCML(n_basis=320), build_triplets)] +ids_triplets_learners_m = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + triplets_learners_m])) +pairs_learners_m = [(ITML(max_iter=2), build_pairs), # max_iter=2 to be faster + (MMC(max_iter=2), build_pairs), # max_iter=2 to be faster + (SDML(prior='identity', balance_param=1e-5), build_pairs)] +ids_pairs_learners_m = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + pairs_learners_m])) + +# -- Supervised +classifiers_m = [(Covariance(), build_classification), + (LFDA(), build_classification), + (LMNN(), build_classification), + (NCA(), build_classification), + (RCA(), build_classification), + (ITML_Supervised(max_iter=5), build_classification), + (LSML_Supervised(), build_classification), + (MMC_Supervised(max_iter=5), build_classification), + (RCA_Supervised(n_chunks=5), build_classification), + (SDML_Supervised(prior='identity', balance_param=1e-5), + build_classification), + (SCML_Supervised(n_basis=80), build_classification)] +ids_classifiers_m = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + classifiers_m])) + +regressors_m = [(MLKR(init='pca'), build_regression)] +ids_regressors_m = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in regressors_m])) + +# -- Mahalanobis sets +tuples_learners_m = pairs_learners_m + triplets_learners_m + \ + quadruplets_learners_m +ids_tuples_learners_m = ids_pairs_learners_m + ids_triplets_learners_m \ + + ids_quadruplets_learners_m + +supervised_learners_m = classifiers_m + regressors_m +ids_supervised_learners_m = ids_classifiers_m + ids_regressors_m + +metric_learners_m = tuples_learners_m + supervised_learners_m +ids_metric_learners_m = ids_tuples_learners_m + ids_supervised_learners_m + +# Bilinear learners +# -- Weakly Supervised +quadruplets_learners_b = [(MockQuadrpletsIdentityBilinearLearner(), + build_quadruplets)] +ids_quadruplets_learners_b = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + quadruplets_learners_b])) + +triplets_learners_b = [(MockTripletsIdentityBilinearLearner(), build_triplets)] +ids_triplets_learners_b = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + triplets_learners_b])) + +pairs_learners_b = [(MockPairIdentityBilinearLearner(), build_pairs)] +ids_pairs_learners_b = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + pairs_learners_b])) +# -- Supervised +classifiers_b = [(RandomBilinearLearner(), build_classification), + (IdentityBilinearLearner(), build_classification)] +ids_classifiers_b = list(map(lambda x: x.__class__.__name__, + [learner for (learner, _) in + classifiers_b])) +# -- Bilinear sets +tuples_learners_b = pairs_learners_b + triplets_learners_b + \ + quadruplets_learners_b +ids_tuples_learners_b = ids_pairs_learners_b + ids_triplets_learners_b \ + + ids_quadruplets_learners_b + +supervised_learners_b = classifiers_b +ids_supervised_learners_b = ids_classifiers_b + +metric_learners_b = tuples_learners_b + supervised_learners_b +ids_metric_learners_b = ids_tuples_learners_b + ids_supervised_learners_b + +# General sets (Mahalanobis + Bilinear) +# -- Weakly Supervised learners individually +pairs_learners = pairs_learners_m + pairs_learners_b +ids_pairs_learners = ids_pairs_learners_m + ids_pairs_learners_b +triplets_learners = triplets_learners_m + triplets_learners_b +ids_triplets_learners = ids_triplets_learners_m + ids_triplets_learners_b +quadruplets_learners = quadruplets_learners_m + quadruplets_learners_b +ids_quadruplets_learners = ids_quadruplets_learners_m + \ + ids_quadruplets_learners_b + +# -- All weakly supervised learners +tuples_learners = tuples_learners_m + tuples_learners_b +ids_tuples_learners = ids_tuples_learners_m + ids_tuples_learners_b + +# -- Supervised learners +supervised_learners = supervised_learners_m + supervised_learners_b +ids_supervised_learners = ids_supervised_learners_m + ids_supervised_learners_b + +# -- Weakly Supervised + Supervised learners +metric_learners = metric_learners_m + metric_learners_b +ids_metric_learners = ids_metric_learners_m + ids_metric_learners_b + +# -- For sklearn pipeline: Pair + Supervised learners +metric_learners_pipeline = pairs_learners_m + pairs_learners_b + \ + supervised_learners_m + supervised_learners_b +ids_metric_learners_pipeline = ids_pairs_learners_m + ids_pairs_learners_b +\ + ids_supervised_learners_m + \ + ids_supervised_learners_b + +# Not used WeaklySupervisedClasses = (_PairsClassifierMixin, _TripletsClassifierMixin, _QuadrupletsClassifierMixin) -tuples_learners = pairs_learners + triplets_learners + quadruplets_learners -ids_tuples_learners = ids_pairs_learners + ids_triplets_learners \ - + ids_quadruplets_learners - -supervised_learners = classifiers + regressors -ids_supervised_learners = ids_classifiers + ids_regressors - -metric_learners = tuples_learners + supervised_learners -ids_metric_learners = ids_tuples_learners + ids_supervised_learners - -metric_learners_pipeline = pairs_learners + supervised_learners -ids_metric_learners_pipeline = ids_pairs_learners + ids_supervised_learners +# ------------- Useful methods ------------- def remove_y(estimator, X, y): @@ -850,15 +1064,14 @@ def test_error_message_t_pair_distance_or_score(estimator, _): .format(make_context(estimator), triplets)) assert str(raised_err.value) == expected_msg - not_implemented_msg = "" - # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says - # "This learner does not have pair_distance" + msg = ("This learner does not learn a distance, thus ", + "this method is not implemented. Use pair_score instead") # One exception will trigger for sure with pytest.raises(Exception) as raised_exception: estimator.pair_distance(triplets) err_value = raised_exception.value.args[0] - assert err_value == expected_msg or err_value == not_implemented_msg + assert err_value == expected_msg or err_value == msg def test_preprocess_tuples_simple_example(): @@ -897,7 +1110,8 @@ def fun(row): ids=ids_metric_learners) def test_same_with_or_without_preprocessor(estimator, build_dataset): """Test that algorithms using a preprocessor behave consistently -# with their no-preprocessor equivalent + with their no-preprocessor equivalent. Methods `pair_score`, + `score_pairs` (deprecated), `predict` and `decision_function`. """ dataset_indices = build_dataset(with_preprocessor=True) dataset_formed = build_dataset(with_preprocessor=False) @@ -926,7 +1140,7 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): estimator_with_prep_formed.set_params(preprocessor=X) estimator_with_prep_formed.fit(*remove_y(estimator, indices_train, y_train)) - # test prediction methods + # Test prediction methods for Weakly supervised algorithms. for method in ["predict", "decision_function"]: if hasattr(estimator, method): output_with_prep = getattr(estimator_with_preprocessor, @@ -940,8 +1154,9 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): method)(formed_test) assert np.array(output_with_prep == output_with_prep_formed).all() - # Test pair_score, all learners have it. - idx1 = np.array([[0, 2], [5, 3]], dtype=int) + idx1 = np.array([[0, 2], [5, 3]], dtype=int) # Sample + + # Pair score output_with_prep = estimator_with_preprocessor.pair_score( indicators_to_transform[idx1]) output_without_prep = estimator_without_preprocessor.pair_score( @@ -954,11 +1169,26 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): formed_points_to_transform[idx1]) assert np.array(output_with_prep == output_without_prep).all() - # Test pair_distance - not_implemented_msg = "" - # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says - # "This learner does not have pair_distance" - try: + # Score pairs. TODO: Delete in 0.8.0 + msg = ("score_pairs will be deprecated in release 0.7.0. " + "Use pair_score to compute similarity scores, or " + "pair_distances to compute distances.") + with pytest.warns(FutureWarning) as raised_warning: + output_with_prep = estimator_with_preprocessor.score_pairs( + indicators_to_transform[idx1]) + output_without_prep = estimator_without_preprocessor.score_pairs( + formed_points_to_transform[idx1]) + assert np.array(output_with_prep == output_without_prep).all() + + output_with_prep = estimator_with_preprocessor.score_pairs( + indicators_to_transform[idx1]) + output_without_prep = estimator_with_prep_formed.score_pairs( + formed_points_to_transform[idx1]) + assert np.array(output_with_prep == output_without_prep).all() + assert any([str(warning.message) == msg for warning in raised_warning]) + + if isinstance(estimator, MahalanobisMixin): + # Pair distance output_with_prep = estimator_with_preprocessor.pair_distance( indicators_to_transform[idx1]) output_without_prep = estimator_without_preprocessor.pair_distance( @@ -971,14 +1201,7 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): formed_points_to_transform[idx1]) assert np.array(output_with_prep == output_without_prep).all() - except Exception as raised_exception: - assert raised_exception.value.args[0] == not_implemented_msg - - # Test transform - not_implemented_msg = "" - # Todo in 0.7.0: Change 'not_implemented_msg' for the message that says - # "This learner does not have transform" - try: + # Transform output_with_prep = estimator_with_preprocessor.transform( indicators_to_transform) output_without_prep = estimator_without_preprocessor.transform( @@ -991,9 +1214,6 @@ def test_same_with_or_without_preprocessor(estimator, build_dataset): formed_points_to_transform) assert np.array(output_with_prep == output_without_prep).all() - except Exception as raised_exception: - assert raised_exception.value.args[0] == not_implemented_msg - def test_check_collapsed_pairs_raises_no_error(): """Checks that check_collapsed_pairs raises no error if no collapsed pairs