From edad55d731f9da14b1b931a6c884b8d3c3c77c67 Mon Sep 17 00:00:00 2001 From: William de Vazelhes <31916524+wdevazelhes@users.noreply.github.com> Date: Mon, 15 Apr 2019 11:04:41 +0200 Subject: [PATCH] [MRG+1] Threshold for pairs learners (#168) * add some tests for testing that different scores work using the scoring function * ENH: Add tests and basic threshold implementation * Add support for LSML and more generally quadruplets * Make CalibratedClassifierCV work (for preprocessor case) thanks to classes_ * Fix some tests and PEP8 errors * change the sign in decision function * Add docstring for threshold_ and classes_ in the base _PairsClassifier class * remove quadruplets from the test with scikit learn custom scorings * Remove argument y in quadruplets learners and lsml * FIX fix docstrings of decision functions * FIX the threshold by taking the opposite (to be adapted to the decision function) * Fix tests to have no y for quadruplets' estimator fit * Remove isin to be compatible with old numpy versions * Fix threshold so that it has a positive value and add small test * Fix threshold for itml * FEAT: Add calibrate_threshold and tests * MAINT: remove starred syntax for compatibility with older versions of python * Remove debugging prints and make tests for ITML pass, while waiting for #175 to be solved * FIX: from __future__ import division to pass tests for python 2.7 * Add some documentation for calibration * DOC: fix style * Address most comments from aurelien's reviews * Remove classes_ attribute and test for CalibratedClassifierCV * Rename make_args_inc_quadruplets into remove_y_quadruplets * TST: Fix remaining threshold into min_rate * Remove default_threshold and put calibrate_threshold instead * Use calibrate_threshold for ITML, and remove description * ENH: use calibrate_threshold by default and display its parameters from the fit method * Add a small test to test automatic calibration * Update documentation of the default threshold * Inverse sense for threshold comparison to be more intuitive * Address remaining review comments * MAINT: Rename threshold_params into calibration_params * TST: Add test for extreme cases * MAINT: rename threshold_params into calibration_params * MAINT: rename threshold_params into calibration_params * FIX: Make tests work, and add the right threshold (mean between lowest accepted value and highest rejected value), and max + 1 or min - 1 for extreme points * Go back to previous version of finding the threshold * Extract method for validating calibration parameters * Validate calibration params before fit * Address https://github.com/metric-learn/metric-learn/pull/168#discussion_r268109180 --- doc/weakly_supervised.rst | 117 +++++-- metric_learn/base_metric.py | 235 ++++++++++++- metric_learn/itml.py | 21 +- metric_learn/lsml.py | 2 +- metric_learn/mmc.py | 26 +- metric_learn/sdml.py | 25 +- test/test_mahalanobis_mixin.py | 38 ++- test/test_pairs_classifiers.py | 491 +++++++++++++++++++++++++++ test/test_quadruplets_classifiers.py | 42 +++ test/test_sklearn_compat.py | 182 ++++++---- test/test_utils.py | 35 +- 11 files changed, 1066 insertions(+), 148 deletions(-) create mode 100644 test/test_pairs_classifiers.py create mode 100644 test/test_quadruplets_classifiers.py diff --git a/doc/weakly_supervised.rst b/doc/weakly_supervised.rst index deae9b40..6bf6f993 100644 --- a/doc/weakly_supervised.rst +++ b/doc/weakly_supervised.rst @@ -148,8 +148,47 @@ tuples you're working with (pairs, triplets...). 
See the docstring of the `score` method of the estimator you use. +Learning on pairs +================= + +Some metric learning algorithms learn on pairs of samples. In this case, one +should provide the algorithm with ``n_samples`` pairs of points, with a +corresponding target containing ``n_samples`` values being either +1 or -1. +These values indicate whether the given pairs are similar points or +dissimilar points. + + +.. _calibration: + +Thresholding +------------ +In order to predict whether a new pair represents similar or dissimilar +samples, we need to set a distance threshold, so that points closer (in the +learned space) than this threshold are predicted as similar, and points further +away are predicted as dissimilar. Several methods are possible for this +thresholding. + +- **At fit time**: The threshold is set with `calibrate_threshold` (see + below) on the trainset. You can specify the calibration parameters directly + in the `fit` method with the `threshold_params` parameter (see the + documentation of the `fit` method of any metric learner that learns on pairs + of points for more information). This method can cause a little bit of + overfitting. If you want to avoid that, calibrate the threshold after + fitting, on a validation set. + +- **Manual**: calling `set_threshold` will set the threshold to a + particular value. + +- **Calibration**: calling `calibrate_threshold` will calibrate the + threshold to achieve a particular score on a validation set, the score + being among the classical scores for classification (accuracy, f1 score...). + + +See also: `sklearn.calibration`. + + Algorithms -================== +========== ITML ---- @@ -192,39 +231,6 @@ programming. .. [2] Adapted from Matlab code at http://www.cs.utexas.edu/users/pjain/ itml/ - -LSML ----- - -`LSML`: Metric Learning from Relative Comparisons by Minimizing Squared -Residual - -.. topic:: Example Code: - -:: - - from metric_learn import LSML - - quadruplets = [[[1.2, 7.5], [1.3, 1.5], [6.4, 2.6], [6.2, 9.7]], - [[1.3, 4.5], [3.2, 4.6], [6.2, 5.5], [5.4, 5.4]], - [[3.2, 7.5], [3.3, 1.5], [8.4, 2.6], [8.2, 9.7]], - [[3.3, 4.5], [5.2, 4.6], [8.2, 5.5], [7.4, 5.4]]] - - # we want to make closer points where the first feature is close, and - # further if the second feature is close - - lsml = LSML() - lsml.fit(quadruplets) - -.. topic:: References: - - .. [1] Liu et al. - "Metric Learning from Relative Comparisons by Minimizing Squared - Residual". ICDM 2012. http://www.cs.ucla.edu/~weiwang/paper/ICDM12.pdf - - .. [2] Adapted from https://gist.github.com/kcarnold/5439917 - - SDML ---- @@ -343,3 +349,46 @@ method. However, it is one of the earliest and a still often cited technique. -with-side-information.pdf>`_ Xing, Jordan, Russell, Ng. .. [2] Adapted from Matlab code `here `_. + +Learning on quadruplets +======================= + +A type of information even weaker than pairs is information about relative +comparisons between pairs. The user should provide the algorithm with a +quadruplet of points, where the two first points are closer than the two +last points. No target vector (``y``) is needed, since the supervision is +already in the order that points are given in the quadruplet. + +Algorithms +========== + +LSML +---- + +`LSML`: Metric Learning from Relative Comparisons by Minimizing Squared +Residual + +.. 
topic:: Example Code: + +:: + + from metric_learn import LSML + + quadruplets = [[[1.2, 7.5], [1.3, 1.5], [6.4, 2.6], [6.2, 9.7]], + [[1.3, 4.5], [3.2, 4.6], [6.2, 5.5], [5.4, 5.4]], + [[3.2, 7.5], [3.3, 1.5], [8.4, 2.6], [8.2, 9.7]], + [[3.3, 4.5], [5.2, 4.6], [8.2, 5.5], [7.4, 5.4]]] + + # we want to make closer points where the first feature is close, and + # further if the second feature is close + + lsml = LSML() + lsml.fit(quadruplets) + +.. topic:: References: + + .. [1] Liu et al. + "Metric Learning from Relative Comparisons by Minimizing Squared + Residual". ICDM 2012. http://www.cs.ucla.edu/~weiwang/paper/ICDM12.pdf + + .. [2] Adapted from https://gist.github.com/kcarnold/5439917 diff --git a/metric_learn/base_metric.py b/metric_learn/base_metric.py index 58b8cc5d..9f127f58 100644 --- a/metric_learn/base_metric.py +++ b/metric_learn/base_metric.py @@ -1,8 +1,7 @@ -from numpy.linalg import cholesky -from scipy.spatial.distance import euclidean from sklearn.base import BaseEstimator -from sklearn.utils.validation import _is_arraylike -from sklearn.metrics import roc_auc_score +from sklearn.utils.extmath import stable_cumsum +from sklearn.utils.validation import _is_arraylike, check_is_fitted +from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_curve import numpy as np from abc import ABCMeta, abstractmethod import six @@ -138,6 +137,7 @@ def get_metric(self): use the metric learner's preprocessor, and works on concatenated arrays. """ + class MetricTransformer(six.with_metaclass(ABCMeta)): @abstractmethod @@ -295,6 +295,14 @@ def get_mahalanobis_matrix(self): class _PairsClassifierMixin(BaseMetricLearner): + """ + Attributes + ---------- + threshold_ : `float` + If the distance metric between two points is lower than this threshold, + points will be classified as similar, otherwise they will be + classified as dissimilar. + """ _tuple_size = 2 # number of points in a tuple, 2 for pairs @@ -317,13 +325,17 @@ def predict(self, pairs): y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) The predicted learned metric value between samples in every pair. """ - return self.decision_function(pairs) + check_is_fitted(self, ['threshold_', 'transformer_']) + return 2 * (- self.decision_function(pairs) <= self.threshold_) - 1 def decision_function(self, pairs): - """Returns the learned metric between input pairs. + """Returns the decision function used to classify the pairs. - Returns the learned metric value between samples in every pair. It should - ideally be low for similar samples and high for dissimilar samples. + Returns the opposite of the learned metric value between samples in every + pair, to be consistent with scikit-learn conventions. Hence it should + ideally be low for dissimilar samples and high for similar samples. + This is the decision function that is used to classify pairs as similar + (+1), or dissimilar (-1). Parameters ---------- @@ -335,12 +347,12 @@ def decision_function(self, pairs): Returns ------- y_predicted : `numpy.ndarray` of floats, shape=(n_constraints,) - The predicted learned metric value between samples in every pair. + The predicted decision function value for each pair. """ pairs = check_input(pairs, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) - return self.score_pairs(pairs) + return - self.score_pairs(pairs) def score(self, pairs, y): """Computes score of pairs similarity prediction. 
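A minimal sketch of the resulting convention for pairs classifiers, assuming a hypothetical ``model`` (any already-fitted pairs learner) and ``pairs`` (a 3D array of pairs)::

    distances = model.score_pairs(pairs)      # learned distances, non-negative
    scores = model.decision_function(pairs)   # == -distances (scikit-learn
                                              # convention: higher = more similar)
    predictions = model.predict(pairs)        # +1 where distances <= model.threshold_,
                                              # -1 elsewhere
    assert ((predictions == 1) == (distances <= model.threshold_)).all()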
@@ -369,6 +381,190 @@ def score(self, pairs, y): """ return roc_auc_score(y, self.decision_function(pairs)) + def set_threshold(self, threshold): + """Sets the threshold of the metric learner to the given value `threshold`. + + See more in the :ref:`User Guide <calibration>`. + + Parameters + ---------- + threshold : float + The threshold value we want to set. It is the value to which the + predicted distance for test pairs will be compared. If this distance is + lower than or equal to the threshold, the pair will be classified as + similar (+1), and as dissimilar (-1) otherwise. + + Returns + ------- + self : `_PairsClassifier` + The pairs classifier with the new threshold set. + """ + self.threshold_ = threshold + return self + + def calibrate_threshold(self, pairs_valid, y_valid, strategy='accuracy', + min_rate=None, beta=1.): + """Decision threshold calibration for pairwise binary classification + + Method that calibrates the decision threshold (cutoff point) of the metric + learner. This threshold will then be used when calling the method + `predict`. The methods for picking cutoff points make use of traditional + binary classification evaluation statistics such as the true positive and + true negative rates and F-scores. The threshold will be found to maximize + the chosen score on the validation set ``(pairs_valid, y_valid)``. + + See more in the :ref:`User Guide <calibration>`. + + Parameters + ---------- + strategy : str, optional (default='accuracy') + The strategy to use for choosing the cutoff threshold. + + 'accuracy' + Selects a decision threshold that maximizes the accuracy. + 'f_beta' + Selects a decision threshold that maximizes the f_beta score, + with beta given by the parameter `beta`. + 'max_tpr' + Selects a decision threshold that yields the highest true positive + rate with true negative rate at least equal to the value of the + parameter `min_rate`. + 'max_tnr' + Selects a decision threshold that yields the highest true negative + rate with true positive rate at least equal to the value of the + parameter `min_rate`. + + beta : float, optional (default=1.) + Beta value to be used in case strategy == 'f_beta'. + + min_rate : float in [0, 1] or None, optional (default=None) + In case strategy is 'max_tpr' or 'max_tnr' this parameter must be set + to specify the minimal value for the true negative rate or true positive + rate respectively that needs to be achieved. + + pairs_valid : array-like, shape=(n_pairs_valid, 2, n_features) + The validation set of pairs to use to set the threshold. + + y_valid : array-like, shape=(n_pairs_valid,) + The labels of the pairs of the validation set to use to set the + threshold. They must be +1 for positive pairs and -1 for negative pairs. + + References + ---------- + .. [1] Receiver-operating characteristic (ROC) plots: a fundamental + evaluation tool in clinical medicine, MH Zweig, G Campbell - + Clinical chemistry, 1993 + + ..
[2] most of the code of this function is from scikit-learn's PR #10117 + + See Also + -------- + sklearn.calibration : scikit-learn's module for calibrating classifiers + """ + + self._validate_calibration_params(strategy, min_rate, beta) + + pairs_valid, y_valid = self._prepare_inputs(pairs_valid, y_valid, + type_of_inputs='tuples') + + n_samples = pairs_valid.shape[0] + if strategy == 'accuracy': + scores = self.decision_function(pairs_valid) + scores_sorted_idces = np.argsort(scores)[::-1] + scores_sorted = scores[scores_sorted_idces] + # true labels ordered by decision_function value: (higher first) + y_ordered = y_valid[scores_sorted_idces] + # we need to add a threshold that will reject all points + scores_sorted = np.concatenate([[scores_sorted[0] + 1], scores_sorted]) + + # finds the threshold that maximizes the accuracy: + cum_tp = stable_cumsum(y_ordered == 1) # cumulative number of true + # positives + # we need to add the point where all samples are rejected: + cum_tp = np.concatenate([[0.], cum_tp]) + cum_tn_inverted = stable_cumsum(y_ordered[::-1] == -1) + cum_tn = np.concatenate([[0.], cum_tn_inverted])[::-1] + cum_accuracy = (cum_tp + cum_tn) / n_samples + imax = np.argmax(cum_accuracy) + # we set the threshold to the lowest accepted score + # note: we are working with negative distances but we want the threshold + # to be with respect to the actual distances so we take minus sign + self.threshold_ = - scores_sorted[imax] + # note: if the best is to reject all points it's already one of the + # thresholds (scores_sorted[0]) + return self + + if strategy == 'f_beta': + precision, recall, thresholds = precision_recall_curve( + y_valid, self.decision_function(pairs_valid), pos_label=1) + + # here the thresholds are decreasing + # We ignore the warnings here, in the same taste as + # https://github.com/scikit-learn/scikit-learn/blob/62d205980446a1abc1065 + # f4332fd74eee57fcf73/sklearn/metrics/classification.py#L1284 + with np.errstate(divide='ignore', invalid='ignore'): + f_beta = ((1 + beta**2) * (precision * recall) / + (beta**2 * precision + recall)) + # We need to set nans to zero otherwise they will be considered higher + # than the others (also discussed in https://github.com/scikit-learn/ + # scikit-learn/pull/10117/files#r262115773) + f_beta[np.isnan(f_beta)] = 0. + imax = np.argmax(f_beta) + # we set the threshold to the lowest accepted score + # note: we are working with negative distances but we want the threshold + # to be with respect to the actual distances so we take minus sign + self.threshold_ = - thresholds[imax] + # Note: we don't need to deal with rejecting all points (i.e. 
threshold = + # max_scores + 1), since this can never happen to be optimal + # (see a more detailed discussion in test_calibrate_threshold_extreme) + return self + + fpr, tpr, thresholds = roc_curve(y_valid, + self.decision_function(pairs_valid), + pos_label=1) + # here the thresholds are decreasing + fpr, tpr, thresholds = fpr, tpr, thresholds + + if strategy in ['max_tpr', 'max_tnr']: + if strategy == 'max_tpr': + indices = np.where(1 - fpr >= min_rate)[0] + imax = np.argmax(tpr[indices]) + + if strategy == 'max_tnr': + indices = np.where(tpr >= min_rate)[0] + imax = np.argmax(1 - fpr[indices]) + + imax_valid = indices[imax] + # note: we are working with negative distances but we want the threshold + # to be with respect to the actual distances so we take minus sign + if indices[imax] == len(thresholds): # we want to accept everything + self.threshold_ = - (thresholds[imax_valid] - 1) + else: + # thanks to roc_curve, the first point will always be max_scores + # + 1, see: https://github.com/scikit-learn/scikit-learn/pull/13523 + self.threshold_ = - thresholds[imax_valid] + return self + + @staticmethod + def _validate_calibration_params(strategy='accuracy', min_rate=None, + beta=1.): + """Ensure that calibration parameters have allowed values""" + if strategy not in ('accuracy', 'f_beta', 'max_tpr', + 'max_tnr'): + raise ValueError('Strategy can either be "accuracy", "f_beta" or ' + '"max_tpr" or "max_tnr". Got "{}" instead.' + .format(strategy)) + if strategy == 'max_tpr' or strategy == 'max_tnr': + if (min_rate is None or not isinstance(min_rate, (int, float)) or + not min_rate >= 0 or not min_rate <= 1): + raise ValueError('Parameter min_rate must be a number in' + '[0, 1]. ' + 'Got {} instead.'.format(min_rate)) + if strategy == 'f_beta': + if beta is None or not isinstance(beta, (int, float)): + raise ValueError('Parameter beta must be a real number. ' + 'Got {} instead.'.format(type(beta))) + class _QuadrupletsClassifierMixin(BaseMetricLearner): @@ -393,6 +589,7 @@ def predict(self, quadruplets): prediction : `numpy.ndarray` of floats, shape=(n_constraints,) Predictions of the ordering of pairs, for each quadruplet. """ + check_is_fitted(self, 'transformer_') quadruplets = check_input(quadruplets, type_of_inputs='tuples', preprocessor=self.preprocessor_, estimator=self, tuple_size=self._tuple_size) @@ -401,8 +598,12 @@ def predict(self, quadruplets): def decision_function(self, quadruplets): """Predicts differences between sample distances in input quadruplets. - For each quadruplet of samples, computes the difference between the learned - metric of the first pair minus the learned metric of the second pair. + For each quadruplet in the samples, computes the difference between the + learned metric of the second pair minus the learned metric of the first + pair. The higher it is, the more probable it is that the pairs in the + quadruplet are presented in the right order, i.e. that the label of the + quadruplet is 1. The lower it is, the more probable it is that the label of + the quadruplet is -1. Parameters ---------- @@ -417,10 +618,10 @@ def decision_function(self, quadruplets): decision_function : `numpy.ndarray` of floats, shape=(n_constraints,) Metric differences. 
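+
+    Examples
+    --------
+    A sketch of the sign convention, assuming ``lsml`` is an already-fitted
+    quadruplets learner (e.g. `LSML`) and ``quadruplets`` is a 3D array of
+    quadruplets: the value returned is the learned metric of the last two
+    points minus the learned metric of the first two points, so a positive
+    value means the first pair is predicted to be the closer one.
+
+    >>> import numpy as np
+    >>> values = lsml.decision_function(quadruplets)
+    >>> reference = (lsml.score_pairs(quadruplets[:, 2:]) -
+    ...              lsml.score_pairs(quadruplets[:, :2]))
+    >>> np.allclose(values, reference)
+    True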
""" - return (self.score_pairs(quadruplets[:, :2]) - - self.score_pairs(quadruplets[:, 2:])) + return (self.score_pairs(quadruplets[:, 2:]) - + self.score_pairs(quadruplets[:, :2])) - def score(self, quadruplets, y=None): + def score(self, quadruplets): """Computes score on input quadruplets Returns the accuracy score of the following classification task: a record @@ -435,11 +636,9 @@ def score(self, quadruplets, y=None): points, or 2D array of indices of quadruplets if the metric learner uses a preprocessor. - y : Ignored, for scikit-learn compatibility. - Returns ------- score : float The quadruplets score. """ - return -np.mean(self.predict(quadruplets)) + return - np.mean(self.predict(quadruplets)) diff --git a/metric_learn/itml.py b/metric_learn/itml.py index a0ff05f9..9b6dccb2 100644 --- a/metric_learn/itml.py +++ b/metric_learn/itml.py @@ -148,11 +148,19 @@ class ITML(_BaseITML, _PairsClassifierMixin): transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + + threshold_ : `float` + If the distance metric between two points is lower than this threshold, + points will be classified as similar, otherwise they will be + classified as dissimilar. """ - def fit(self, pairs, y, bounds=None): + def fit(self, pairs, y, bounds=None, calibration_params=None): """Learn the ITML model. + The threshold will be calibrated on the trainset using the parameters + `calibration_params`. + Parameters ---------- pairs: array-like, shape=(n_constraints, 2, n_features) or @@ -170,13 +178,22 @@ def fit(self, pairs, y, bounds=None): If not provided at initialization, bounds_[0] and bounds_[1] will be set to the 5th and 95th percentile of the pairwise distances among all points present in the input `pairs`. + calibration_params : `dict` or `None` + Dictionary of parameters to give to `calibrate_threshold` for the + threshold calibration step done at the end of `fit`. If `None` is + given, `calibrate_threshold` will use the default parameters. Returns ------- self : object Returns the instance. """ - return self._fit(pairs, y, bounds=bounds) + calibration_params = (calibration_params if calibration_params is not + None else dict()) + self._validate_calibration_params(**calibration_params) + self._fit(pairs, y) + self.calibrate_threshold(pairs, y, **calibration_params) + return self class ITML_Supervised(_BaseITML, TransformerMixin): diff --git a/metric_learn/lsml.py b/metric_learn/lsml.py index 312990ab..536719ba 100644 --- a/metric_learn/lsml.py +++ b/metric_learn/lsml.py @@ -45,7 +45,7 @@ def __init__(self, tol=1e-3, max_iter=1000, prior=None, verbose=False, self.verbose = verbose super(_BaseLSML, self).__init__(preprocessor) - def _fit(self, quadruplets, y=None, weights=None): + def _fit(self, quadruplets, weights=None): quadruplets = self._prepare_inputs(quadruplets, type_of_inputs='tuples') diff --git a/metric_learn/mmc.py b/metric_learn/mmc.py index f9d3690b..346db2f8 100644 --- a/metric_learn/mmc.py +++ b/metric_learn/mmc.py @@ -359,27 +359,43 @@ class MMC(_BaseMMC, _PairsClassifierMixin): transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + + threshold_ : `float` + If the distance metric between two points is lower than this threshold, + points will be classified as similar, otherwise they will be + classified as dissimilar. 
""" - def fit(self, pairs, y): + def fit(self, pairs, y, calibration_params=None): """Learn the MMC model. + The threshold will be calibrated on the trainset using the parameters + `calibration_params`. + Parameters ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) or + pairs : array-like, shape=(n_constraints, 2, n_features) or (n_constraints, 2) 3D Array of pairs with each row corresponding to two points, or 2D array of indices of pairs if the metric learner uses a preprocessor. - y: array-like, of shape (n_constraints,) + y : array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. - + calibration_params : `dict` or `None` + Dictionary of parameters to give to `calibrate_threshold` for the + threshold calibration step done at the end of `fit`. If `None` is + given, `calibrate_threshold` will use the default parameters. Returns ------- self : object Returns the instance. """ - return self._fit(pairs, y) + calibration_params = (calibration_params if calibration_params is not + None else dict()) + self._validate_calibration_params(**calibration_params) + self._fit(pairs, y) + self.calibrate_threshold(pairs, y, **calibration_params) + return self class MMC_Supervised(_BaseMMC, TransformerMixin): diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index 6fd29d38..e9828d07 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -138,27 +138,44 @@ class SDML(_BaseSDML, _PairsClassifierMixin): transformer_ : `numpy.ndarray`, shape=(num_dims, n_features) The linear transformation ``L`` deduced from the learned Mahalanobis metric (See function `transformer_from_metric`.) + + threshold_ : `float` + If the distance metric between two points is lower than this threshold, + points will be classified as similar, otherwise they will be + classified as dissimilar. """ - def fit(self, pairs, y): + def fit(self, pairs, y, calibration_params=None): """Learn the SDML model. + The threshold will be calibrated on the trainset using the parameters + `calibration_params`. + Parameters ---------- - pairs: array-like, shape=(n_constraints, 2, n_features) or + pairs : array-like, shape=(n_constraints, 2, n_features) or (n_constraints, 2) 3D Array of pairs with each row corresponding to two points, or 2D array of indices of pairs if the metric learner uses a preprocessor. - y: array-like, of shape (n_constraints,) + y : array-like, of shape (n_constraints,) Labels of constraints. Should be -1 for dissimilar pair, 1 for similar. + calibration_params : `dict` or `None` + Dictionary of parameters to give to `calibrate_threshold` for the + threshold calibration step done at the end of `fit`. If `None` is + given, `calibrate_threshold` will use the default parameters. Returns ------- self : object Returns the instance. 
""" - return self._fit(pairs, y) + calibration_params = (calibration_params if calibration_params is not + None else dict()) + self._validate_calibration_params(**calibration_params) + self._fit(pairs, y) + self.calibrate_threshold(pairs, y, **calibration_params) + return self class SDML_Supervised(_BaseSDML, TransformerMixin): diff --git a/test/test_mahalanobis_mixin.py b/test/test_mahalanobis_mixin.py index a0bf3b9d..15bf1aed 100644 --- a/test/test_mahalanobis_mixin.py +++ b/test/test_mahalanobis_mixin.py @@ -13,7 +13,8 @@ from metric_learn.base_metric import (_QuadrupletsClassifierMixin, _PairsClassifierMixin) -from test.test_utils import ids_metric_learners, metric_learners +from test.test_utils import (ids_metric_learners, metric_learners, + remove_y_quadruplets) RNG = check_random_state(0) @@ -27,7 +28,7 @@ def test_score_pairs_pairwise(estimator, build_dataset): X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) pairwise = model.score_pairs(np.array(list(product(X, X))))\ .reshape(n_samples, n_samples) @@ -51,7 +52,7 @@ def test_score_pairs_toy_example(estimator, build_dataset): X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) pairs = np.stack([X[:10], X[10:20]], axis=1) embedded_pairs = pairs.dot(model.transformer_.T) distances = np.sqrt(np.sum((embedded_pairs[:, 1] - @@ -67,7 +68,7 @@ def test_score_pairs_finite(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) pairs = np.array(list(product(X, X))) assert np.isfinite(model.score_pairs(pairs)).all() @@ -81,7 +82,7 @@ def test_score_pairs_dim(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) tuples = np.array(list(product(X, X))) assert model.score_pairs(tuples).shape == (tuples.shape[0],) context = make_context(estimator) @@ -112,7 +113,7 @@ def test_embed_toy_example(estimator, build_dataset): X = X[:n_samples] model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) embedded_points = X.dot(model.transformer_.T) assert_array_almost_equal(model.transform(X), embedded_points) @@ -124,7 +125,7 @@ def test_embed_dim(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) assert model.transform(X).shape == X.shape # assert that ValueError is thrown if input shape is 1D @@ -137,8 +138,11 @@ def test_embed_dim(estimator, build_dataset): assert str(raised_error.value) == err_msg # we test that the shape is also OK when doing dimensionality reduction if type(model).__name__ in {'LFDA', 'MLKR', 'NCA', 'RCA'}: + # TODO: + # avoid this enumeration and rather test if hasattr n_components + # as soon as we have made the arguments names as such (issue #167) model.set_params(num_dims=2) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) assert model.transform(X).shape == (X.shape[0], 2) # assert that 
ValueError is thrown if input shape is 1D with pytest.raises(ValueError) as raised_error: @@ -153,7 +157,7 @@ def test_embed_finite(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) assert np.isfinite(model.transform(X)).all() @@ -164,7 +168,7 @@ def test_embed_is_linear(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) assert_array_almost_equal(model.transform(X[:10] + X[10:20]), model.transform(X[:10]) + model.transform(X[10:20])) @@ -183,7 +187,7 @@ def test_get_metric_equivalent_to_explicit_mahalanobis(estimator, input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) metric = model.get_metric() n_features = X.shape[1] a, b = (rng.randn(n_features), rng.randn(n_features)) @@ -202,7 +206,7 @@ def test_get_metric_is_pseudo_metric(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) metric = model.get_metric() n_features = X.shape[1] @@ -228,7 +232,7 @@ def test_metric_raises_deprecation_warning(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) with pytest.warns(DeprecationWarning) as raised_warning: model.metric() @@ -245,7 +249,7 @@ def test_get_metric_compatible_with_scikit_learn(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) clustering = DBSCAN(metric=model.get_metric()) clustering.fit(X) @@ -258,7 +262,7 @@ def test_get_squared_metric(estimator, build_dataset): input_data, labels, _, X = build_dataset() model = clone(estimator) set_random_state(model) - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) metric = model.get_metric() n_features = X.shape[1] @@ -278,7 +282,7 @@ def test_transformer_is_2D(estimator, build_dataset): model = clone(estimator) set_random_state(model) # test that it works for X.shape[1] features - model.fit(input_data, labels) + model.fit(*remove_y_quadruplets(estimator, input_data, labels)) assert model.transformer_.shape == (X.shape[1], X.shape[1]) # test that it works for 1 feature @@ -297,5 +301,5 @@ def test_transformer_is_2D(estimator, build_dataset): to_keep = np.where(np.abs(diffs.ravel()) > 1e-9) trunc_data = trunc_data[to_keep] labels = labels[to_keep] - model.fit(trunc_data, labels) + model.fit(*remove_y_quadruplets(estimator, trunc_data, labels)) assert model.transformer_.shape == (1, 1) # the transformer must be 2D diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py new file mode 100644 index 00000000..828181cb --- /dev/null +++ b/test/test_pairs_classifiers.py @@ -0,0 +1,491 @@ +from __future__ import division + +from functools import partial + +import pytest +from numpy.testing import assert_array_equal + +from 
metric_learn.base_metric import _PairsClassifierMixin, MahalanobisMixin +from sklearn.exceptions import NotFittedError +from sklearn.metrics import (f1_score, accuracy_score, fbeta_score, + precision_score) +from sklearn.model_selection import train_test_split + +from test.test_utils import pairs_learners, ids_pairs_learners +from sklearn.utils.testing import set_random_state +from sklearn import clone +import numpy as np +from itertools import product + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_predict_only_one_or_minus_one(estimator, build_dataset, + with_preprocessor): + """Test that all predicted values are either +1 or -1""" + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + pairs_train, pairs_test, y_train, y_test = train_test_split(input_data, + labels) + estimator.fit(pairs_train, y_train) + predictions = estimator.predict(pairs_test) + not_valid = [e for e in predictions if e not in [-1, 1]] + assert len(not_valid) == 0 + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_predict_monotonous(estimator, build_dataset, + with_preprocessor): + """Test that there is a threshold distance separating points labeled as + similar and points labeled as dissimilar """ + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + pairs_train, pairs_test, y_train, y_test = train_test_split(input_data, + labels) + estimator.fit(pairs_train, y_train) + distances = estimator.score_pairs(pairs_test) + predictions = estimator.predict(pairs_test) + min_dissimilar = np.min(distances[predictions == -1]) + max_similar = np.max(distances[predictions == 1]) + assert max_similar <= min_dissimilar + separator = np.mean([min_dissimilar, max_similar]) + assert (predictions[distances > separator] == -1).all() + assert (predictions[distances < separator] == 1).all() + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, + with_preprocessor): + """Test that a NotFittedError is raised if someone tries to predict and + the metric learner has not been fitted.""" + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + with pytest.raises(NotFittedError): + estimator.predict(input_data) + + +@pytest.mark.parametrize('calibration_params', + [None, {}, dict(), {'strategy': 'accuracy'}] + + [{'strategy': strategy, 'min_rate': min_rate} + for (strategy, min_rate) in product( + ['max_tpr', 'max_tnr'], [0., 0.2, 0.8, 1.])] + + [{'strategy': 'f_beta', 'beta': beta} + for beta in [0., 0.1, 0.2, 1., 5.]] + ) +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_fit_with_valid_threshold_params(estimator, build_dataset, + with_preprocessor, + calibration_params): + """Tests that fitting `calibration_params` with 
appropriate parameters works + as expected""" + pairs, y, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + estimator.fit(pairs, y, calibration_params=calibration_params) + estimator.predict(pairs) + + +@pytest.mark.parametrize('kwargs', + [{'strategy': 'accuracy'}] + + [{'strategy': strategy, 'min_rate': min_rate} + for (strategy, min_rate) in product( + ['max_tpr', 'max_tnr'], [0., 0.2, 0.8, 1.])] + + [{'strategy': 'f_beta', 'beta': beta} + for beta in [0., 0.1, 0.2, 1., 5.]] + ) +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_threshold_different_scores_is_finite(estimator, build_dataset, + with_preprocessor, kwargs): + # test that calibrating the threshold works for every metric learner + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + estimator.fit(input_data, labels) + with pytest.warns(None) as record: + estimator.calibrate_threshold(input_data, labels, **kwargs) + assert len(record) == 0 + + +class IdentityPairsClassifier(MahalanobisMixin, _PairsClassifierMixin): + """A simple pairs classifier for testing purposes, that will just have + identity as transformer_, and a string threshold so that it returns an + error if not explicitely set. + """ + def fit(self, pairs, y): + pairs, y = self._prepare_inputs(pairs, y, + type_of_inputs='tuples') + self.transformer_ = np.atleast_2d(np.identity(pairs.shape[2])) + self.threshold_ = 'I am not set.' + return self + + +def test_set_threshold(): + # test that set_threshold indeed sets the threshold + identity_pairs_classifier = IdentityPairsClassifier() + pairs = np.array([[[0.], [1.]], [[1.], [3.]], [[2.], [5.]], [[3.], [7.]]]) + y = np.array([1, 1, -1, -1]) + identity_pairs_classifier.fit(pairs, y) + identity_pairs_classifier.set_threshold(0.5) + assert identity_pairs_classifier.threshold_ == 0.5 + + +def test_f_beta_1_is_f_1(): + # test that putting beta to 1 indeed finds the best threshold to optimize + # the f1_score + rng = np.random.RandomState(42) + n_samples = 100 + pairs, y = rng.randn(n_samples, 2, 5), rng.choice([-1, 1], size=n_samples) + pairs_learner = IdentityPairsClassifier() + pairs_learner.fit(pairs, y) + pairs_learner.calibrate_threshold(pairs, y, strategy='f_beta', beta=1) + best_f1_score = f1_score(y, pairs_learner.predict(pairs)) + for threshold in - pairs_learner.decision_function(pairs): + pairs_learner.set_threshold(threshold) + assert f1_score(y, pairs_learner.predict(pairs)) <= best_f1_score + + +def true_pos_true_neg_rates(y_true, y_pred): + """A function that returns the true positive rates and the true negatives + rate. For testing purposes (optimized for readability not performance).""" + assert y_pred.shape[0] == y_true.shape[0] + tp = np.sum((y_pred == 1) * (y_true == 1)) + tn = np.sum((y_pred == -1) * (y_true == -1)) + fn = np.sum((y_pred == -1) * (y_true == 1)) + fp = np.sum((y_pred == 1) * (y_true == -1)) + tpr = tp / (tp + fn) + tnr = tn / (tn + fp) + tpr = tpr if not np.isnan(tpr) else 0. + tnr = tnr if not np.isnan(tnr) else 0. + return tpr, tnr + + +def tpr_threshold(y_true, y_pred, tnr_threshold=0.): + """A function that returns the true positive rate if the true negative + rate is higher or equal than `threshold`, and -1 otherwise. 
For testing + purposes""" + tpr, tnr = true_pos_true_neg_rates(y_true, y_pred) + if tnr < tnr_threshold: + return -1 + else: + return tpr + + +def tnr_threshold(y_true, y_pred, tpr_threshold=0.): + """A function that returns the true negative rate if the true positive + rate is higher or equal than `threshold`, and -1 otherwise. For testing + purposes""" + tpr, tnr = true_pos_true_neg_rates(y_true, y_pred) + if tpr < tpr_threshold: + return -1 + else: + return tnr + + +@pytest.mark.parametrize('kwargs, scoring', + [({'strategy': 'accuracy'}, accuracy_score)] + + [({'strategy': 'f_beta', 'beta': b}, + partial(fbeta_score, beta=b)) + for b in [0.1, 0.5, 1.]] + + [({'strategy': 'f_beta', 'beta': 0}, + precision_score)] + + [({'strategy': 'max_tpr', 'min_rate': t}, + partial(tpr_threshold, tnr_threshold=t)) + for t in [0., 0.1, 0.5, 0.8, 1.]] + + [({'strategy': 'max_tnr', 'min_rate': t}, + partial(tnr_threshold, tpr_threshold=t)) + for t in [0., 0.1, 0.5, 0.8, 1.]], + ) +def test_found_score_is_best_score(kwargs, scoring): + # test that when we use calibrate threshold, it will indeed be the + # threshold that have the best score + rng = np.random.RandomState(42) + n_samples = 50 + pairs, y = rng.randn(n_samples, 2, 5), rng.choice([-1, 1], size=n_samples) + pairs_learner = IdentityPairsClassifier() + pairs_learner.fit(pairs, y) + pairs_learner.calibrate_threshold(pairs, y, **kwargs) + best_score = scoring(y, pairs_learner.predict(pairs)) + scores = [] + predicted_scores = pairs_learner.decision_function(pairs) + predicted_scores = np.hstack([[np.min(predicted_scores) - 1], + predicted_scores, + [np.max(predicted_scores) + 1]]) + for threshold in - predicted_scores: + pairs_learner.set_threshold(threshold) + score = scoring(y, pairs_learner.predict(pairs)) + assert score <= best_score + scores.append(score) + assert len(set(scores)) > 1 # assert that we didn't always have the same + # value for the score (which could be a hint for some bug, but would still + # silently pass the test)) + + +@pytest.mark.parametrize('kwargs, scoring', + [({'strategy': 'accuracy'}, accuracy_score)] + + [({'strategy': 'f_beta', 'beta': b}, + partial(fbeta_score, beta=b)) + for b in [0.1, 0.5, 1.]] + + [({'strategy': 'f_beta', 'beta': 0}, + precision_score)] + + [({'strategy': 'max_tpr', 'min_rate': t}, + partial(tpr_threshold, tnr_threshold=t)) + for t in [0., 0.1, 0.5, 0.8, 1.]] + + [({'strategy': 'max_tnr', 'min_rate': t}, + partial(tnr_threshold, tpr_threshold=t)) + for t in [0., 0.1, 0.5, 0.8, 1.]] + ) +def test_found_score_is_best_score_duplicates(kwargs, scoring): + # test that when we use calibrate threshold, it will indeed be the + # threshold that have the best score. It's the same as the previous test + # except this time we test that the scores are coherent even if there are + # duplicates (i.e. points that have the same score returned by + # `decision_function`). 
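+  # (scikit-learn's roc_curve and precision_recall_curve collapse tied
+  # scores into a single candidate threshold, so duplicated scores exercise
+  # a slightly different path than in the previous test)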
+ rng = np.random.RandomState(42) + n_samples = 50 + pairs, y = rng.randn(n_samples, 2, 5), rng.choice([-1, 1], size=n_samples) + # we create some duplicates points, which will also have the same score + # predicted + pairs[6:10] = pairs[10:14] + y[6:10] = y[10:14] + pairs_learner = IdentityPairsClassifier() + pairs_learner.fit(pairs, y) + pairs_learner.calibrate_threshold(pairs, y, **kwargs) + best_score = scoring(y, pairs_learner.predict(pairs)) + scores = [] + predicted_scores = pairs_learner.decision_function(pairs) + predicted_scores = np.hstack([[np.min(predicted_scores) - 1], + predicted_scores, + [np.max(predicted_scores) + 1]]) + for threshold in - predicted_scores: + pairs_learner.set_threshold(threshold) + score = scoring(y, pairs_learner.predict(pairs)) + assert score <= best_score + scores.append(score) + assert len(set(scores)) > 1 # assert that we didn't always have the same + # value for the score (which could be a hint for some bug, but would still + # silently pass the test)) + + +@pytest.mark.parametrize('invalid_args, expected_msg', + [({'strategy': 'weird'}, + ('Strategy can either be "accuracy", "f_beta" or ' + '"max_tpr" or "max_tnr". Got "weird" instead.'))] + + [({'strategy': strategy, 'min_rate': min_rate}, + 'Parameter min_rate must be a number in' + '[0, 1]. Got {} instead.'.format(min_rate)) + for (strategy, min_rate) in product( + ['max_tpr', 'max_tnr'], + [None, 'weird', -0.2, 1.2, 3 + 2j])] + + [({'strategy': 'f_beta', 'beta': beta}, + 'Parameter beta must be a real number. ' + 'Got {} instead.'.format(type(beta))) + for beta in [None, 'weird', 3 + 2j]] + ) +def test_calibrate_threshold_invalid_parameters_right_error(invalid_args, + expected_msg): + # test that the right error message is returned if invalid arguments are + # given to calibrate_threshold + rng = np.random.RandomState(42) + pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20) + pairs_learner = IdentityPairsClassifier() + pairs_learner.fit(pairs, y) + with pytest.raises(ValueError) as raised_error: + pairs_learner.calibrate_threshold(pairs, y, **invalid_args) + assert str(raised_error.value) == expected_msg + + +@pytest.mark.parametrize('valid_args', + [{'strategy': 'accuracy'}] + + [{'strategy': strategy, 'min_rate': min_rate} + for (strategy, min_rate) in product( + ['max_tpr', 'max_tnr'], + [0., 0.2, 0.8, 1.])] + + [{'strategy': 'f_beta', 'beta': beta} + for beta in [-5., -1., 0., 0.1, 0.2, 1., 5.]] + # Note that we authorize beta < 0 (even if + # in fact it will be squared, so it would be useless + # to do that) + ) +def test_calibrate_threshold_valid_parameters(valid_args): + # test that no warning message is returned if valid arguments are given to + # calibrate threshold + rng = np.random.RandomState(42) + pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20) + pairs_learner = IdentityPairsClassifier() + pairs_learner.fit(pairs, y) + with pytest.warns(None) as record: + pairs_learner.calibrate_threshold(pairs, y, **valid_args) + assert len(record) == 0 + + +def test_calibrate_threshold_extreme(): + """Test that in the (rare) case where we should accept all points or + reject all points, this is effectively what + is done""" + + class MockBadPairsClassifier(MahalanobisMixin, _PairsClassifierMixin): + """A pairs classifier that returns bad scores (i.e. 
in the inverse order + of what we would expect from a good pairs classifier + """ + + def fit(self, pairs, y, calibration_params=None): + self.transformer_ = 'not used' + self.calibrate_threshold(pairs, y, **(calibration_params if + calibration_params is not None else + dict())) + return self + + def decision_function(self, pairs): + return np.arange(pairs.shape[0], dtype=float) + + rng = np.random.RandomState(42) + pairs = rng.randn(7, 2, 5) # the info in X is not used, it's just for the + # API + + y = [1., 1., 1., -1., -1., -1., -1.] + mock_clf = MockBadPairsClassifier() + # case of bad scoring with more negative than positives. In + # this case, when: + # optimizing for accuracy we should reject all points + mock_clf.fit(pairs, y, calibration_params={'strategy': 'accuracy'}) + assert_array_equal(mock_clf.predict(pairs), - np.ones(7)) + + # optimizing for max_tpr we should accept all points if min_rate == 0. ( + # because by convention then tnr=0/0=0) + mock_clf.fit(pairs, y, calibration_params={'strategy': 'max_tpr', + 'min_rate': 0.}) + assert_array_equal(mock_clf.predict(pairs), np.ones(7)) + # optimizing for max_tnr we should reject all points if min_rate = 0. ( + # because by convention then tpr=0/0=0) + mock_clf.fit(pairs, y, calibration_params={'strategy': 'max_tnr', + 'min_rate': 0.}) + assert_array_equal(mock_clf.predict(pairs), - np.ones(7)) + + y = [1., 1., 1., 1., -1., -1., -1.] + # case of bad scoring with more positives than negatives. In + # this case, when: + # optimizing for accuracy we should accept all points + mock_clf.fit(pairs, y, calibration_params={'strategy': 'accuracy'}) + assert_array_equal(mock_clf.predict(pairs), np.ones(7)) + # optimizing for max_tpr we should accept all points if min_rate == 0. ( + # because by convention then tnr=0/0=0) + mock_clf.fit(pairs, y, calibration_params={'strategy': 'max_tpr', + 'min_rate': 0.}) + assert_array_equal(mock_clf.predict(pairs), np.ones(7)) + # optimizing for max_tnr we should reject all points if min_rate = 0. ( + # because by convention then tpr=0/0=0) + mock_clf.fit(pairs, y, calibration_params={'strategy': 'max_tnr', + 'min_rate': 0.}) + assert_array_equal(mock_clf.predict(pairs), - np.ones(7)) + + # Note: we'll never find a case where we would reject all points for + # maximizing tpr (we can always accept more points), and accept all + # points for maximizing tnr (we can always reject more points) + + # case of alternated scores: for optimizing the f_1 score we should accept + # all points (because this way we have max recall (1) and max precision ( + # here: 0.5)) + y = [1., -1., 1., -1., 1., -1.] + mock_clf.fit(pairs[:6], y, calibration_params={'strategy': 'f_beta', + 'beta': 1.}) + assert_array_equal(mock_clf.predict(pairs[:6]), np.ones(6)) + + # Note: for optimizing f_1 score, we will never find an optimal case where we + # reject all points because in this case we would have 0 precision (by + # convention, because it's 0/0), and 0 recall (and we could always decrease + # the threshold to increase the recall, and we couldn't do worse for + # precision so it would be better) + + +@pytest.mark.parametrize('estimator, _', + pairs_learners + [(IdentityPairsClassifier(), None), + (_PairsClassifierMixin, None)], + ids=ids_pairs_learners + ['mock', 'class']) +@pytest.mark.parametrize('invalid_args, expected_msg', + [({'strategy': 'weird'}, + ('Strategy can either be "accuracy", "f_beta" or ' + '"max_tpr" or "max_tnr". 
Got "weird" instead.'))] + + [({'strategy': strategy, 'min_rate': min_rate}, + 'Parameter min_rate must be a number in' + '[0, 1]. Got {} instead.'.format(min_rate)) + for (strategy, min_rate) in product( + ['max_tpr', 'max_tnr'], + [None, 'weird', -0.2, 1.2, 3 + 2j])] + + [({'strategy': 'f_beta', 'beta': beta}, + 'Parameter beta must be a real number. ' + 'Got {} instead.'.format(type(beta))) + for beta in [None, 'weird', 3 + 2j]] + ) +def test_validate_calibration_params_invalid_parameters_right_error( + estimator, _, invalid_args, expected_msg): + # test that the right error message is returned if invalid arguments are + # given to _validate_calibration_params, for all pairs metric learners as + # well as a mocking general identity pairs classifier and the class itself + with pytest.raises(ValueError) as raised_error: + estimator._validate_calibration_params(**invalid_args) + assert str(raised_error.value) == expected_msg + + +@pytest.mark.parametrize('estimator, _', + pairs_learners + [(IdentityPairsClassifier(), None), + (_PairsClassifierMixin, None)], + ids=ids_pairs_learners + ['mock', 'class']) +@pytest.mark.parametrize('valid_args', + [{}, {'strategy': 'accuracy'}] + + [{'strategy': strategy, 'min_rate': min_rate} + for (strategy, min_rate) in product( + ['max_tpr', 'max_tnr'], + [0., 0.2, 0.8, 1.])] + + [{'strategy': 'f_beta', 'beta': beta} + for beta in [-5., -1., 0., 0.1, 0.2, 1., 5.]] + # Note that we authorize beta < 0 (even if + # in fact it will be squared, so it would be useless + # to do that) + ) +def test_validate_calibration_params_valid_parameters( + estimator, _, valid_args): + # test that no warning message is returned if valid arguments are given to + # _validate_calibration_params for all pairs metric learners, as well as + # a mocking example, and the class itself + with pytest.warns(None) as record: + estimator._validate_calibration_params(**valid_args) + assert len(record) == 0 + + +@pytest.mark.parametrize('estimator, build_dataset', + pairs_learners, + ids=ids_pairs_learners) +def test_validate_calibration_params_invalid_parameters_error_before__fit( + estimator, build_dataset): + """For all pairs metric learners (which currently all have a _fit method), + make sure that calibration parameters are validated before fitting""" + estimator = clone(estimator) + input_data, labels, _, _ = build_dataset() + + def breaking_fun(**args): # a function that fails so that we will miss + # the calibration at the end and therefore the right error message from + # validating params should be thrown before + raise RuntimeError('Game over.') + estimator._fit = breaking_fun + expected_msg = ('Strategy can either be "accuracy", "f_beta" or ' + '"max_tpr" or "max_tnr". 
Got "weird" instead.') + with pytest.raises(ValueError) as raised_error: + estimator.fit(input_data, labels, calibration_params={'strategy': 'weird'}) + assert str(raised_error.value) == expected_msg diff --git a/test/test_quadruplets_classifiers.py b/test/test_quadruplets_classifiers.py new file mode 100644 index 00000000..2bf36b3f --- /dev/null +++ b/test/test_quadruplets_classifiers.py @@ -0,0 +1,42 @@ +import pytest +from sklearn.exceptions import NotFittedError +from sklearn.model_selection import train_test_split + +from test.test_utils import quadruplets_learners, ids_quadruplets_learners +from sklearn.utils.testing import set_random_state +from sklearn import clone +import numpy as np + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners, + ids=ids_quadruplets_learners) +def test_predict_only_one_or_minus_one(estimator, build_dataset, + with_preprocessor): + """Test that all predicted values are either +1 or -1""" + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + (quadruplets_train, + quadruplets_test, y_train, y_test) = train_test_split(input_data, labels) + estimator.fit(quadruplets_train) + predictions = estimator.predict(quadruplets_test) + not_valid = [e for e in predictions if e not in [-1, 1]] + assert len(not_valid) == 0 + + +@pytest.mark.parametrize('with_preprocessor', [True, False]) +@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners, + ids=ids_quadruplets_learners) +def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset, + with_preprocessor): + """Test that a NotFittedError is raised if someone tries to predict and + the metric learner has not been fitted.""" + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + with pytest.raises(NotFittedError): + estimator.predict(input_data) + diff --git a/test/test_sklearn_compat.py b/test/test_sklearn_compat.py index f1248c9a..5d6c5d77 100644 --- a/test/test_sklearn_compat.py +++ b/test/test_sklearn_compat.py @@ -15,9 +15,13 @@ import numpy as np from sklearn.model_selection import (cross_val_score, cross_val_predict, train_test_split, KFold) +from sklearn.metrics.scorer import get_scorer from sklearn.utils.testing import _get_args from test.test_utils import (metric_learners, ids_metric_learners, - mock_preprocessor) + mock_preprocessor, tuples_learners, + ids_tuples_learners, pairs_learners, + ids_pairs_learners, remove_y_quadruplets, + quadruplets_learners) # Wrap the _Supervised methods with a deterministic wrapper for testing. @@ -97,22 +101,62 @@ def stable_init(self, sparsity_param=0.01, num_labeled='deprecated', @pytest.mark.parametrize('with_preprocessor', [True, False]) -@pytest.mark.parametrize('estimator, build_dataset', metric_learners, - ids=ids_metric_learners) -def test_cross_validation_is_finite(estimator, build_dataset, - with_preprocessor): +@pytest.mark.parametrize('estimator, build_dataset', pairs_learners, + ids=ids_pairs_learners) +def test_various_scoring_on_tuples_learners(estimator, build_dataset, + with_preprocessor): + """Tests that scikit-learn's scoring returns something finite, + for other scoring than default scoring. (List of scikit-learn's scores can be + found in sklearn.metrics.scorer). 
For each type of output (predict, + predict_proba, decision_function), we test a bunch of scores. + We only test on pairs learners because quadruplets don't have a y argument. + """ + input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + + # scores that need a predict function: every tuples learner should have a + # predict function (whether the pair is of positive samples or negative + # samples) + for scoring in ['accuracy', 'f1']: + check_score_is_finite(scoring, estimator, input_data, labels) + # scores that need a predict_proba: + if hasattr(estimator, "predict_proba"): + for scoring in ['neg_log_loss', 'brier_score']: + check_score_is_finite(scoring, estimator, input_data, labels) + # scores that need a decision_function: every tuples learner should have a + # decision function (the metric between points) + for scoring in ['roc_auc', 'average_precision', 'precision', 'recall']: + check_score_is_finite(scoring, estimator, input_data, labels) + + +def check_score_is_finite(scoring, estimator, input_data, labels): + estimator = clone(estimator) + assert np.isfinite(cross_val_score(estimator, input_data, labels, + scoring=scoring)).all() + estimator.fit(input_data, labels) + assert np.isfinite(get_scorer(scoring)(estimator, input_data, labels)) + + +@pytest.mark.parametrize('estimator, build_dataset', tuples_learners, + ids=ids_tuples_learners) +def test_cross_validation_is_finite(estimator, build_dataset): """Tests that validation on metric-learn estimators returns something finite """ - if any(hasattr(estimator, method) for method in ["predict", "score"]): - input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) - estimator = clone(estimator) - estimator.set_params(preprocessor=preprocessor) - set_random_state(estimator) - if hasattr(estimator, "score"): - assert np.isfinite(cross_val_score(estimator, input_data, labels)).all() - if hasattr(estimator, "predict"): - assert np.isfinite(cross_val_predict(estimator, - input_data, labels)).all() + input_data, labels, preprocessor, _ = build_dataset() + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + set_random_state(estimator) + assert np.isfinite(cross_val_score(estimator, + *remove_y_quadruplets(estimator, + input_data, + labels))).all() + assert np.isfinite(cross_val_predict(estimator, + *remove_y_quadruplets(estimator, + input_data, + labels) + )).all() @pytest.mark.parametrize('with_preprocessor', [True, False]) @@ -143,23 +187,28 @@ def test_cross_validation_manual_vs_scikit(estimator, build_dataset, train_mask = np.ones(input_data.shape[0], bool) train_mask[test_slice] = False y_train, y_test = labels[train_mask], labels[test_slice] - estimator.fit(input_data[train_mask], y_train) + estimator.fit(*remove_y_quadruplets(estimator, + input_data[train_mask], + y_train)) if hasattr(estimator, "score"): - scores.append(estimator.score(input_data[test_slice], y_test)) + scores.append(estimator.score(*remove_y_quadruplets( + estimator, input_data[test_slice], y_test))) if hasattr(estimator, "predict"): predictions[test_slice] = estimator.predict(input_data[test_slice]) if hasattr(estimator, "score"): - assert all(scores == cross_val_score(estimator, input_data, labels, - cv=kfold)) + assert all(scores == cross_val_score( + estimator, *remove_y_quadruplets(estimator, input_data, labels), + cv=kfold)) if hasattr(estimator, "predict"): - assert all(predictions == 
cross_val_predict(estimator, input_data, - labels, - cv=kfold)) + assert all(predictions == cross_val_predict( + estimator, + *remove_y_quadruplets(estimator, input_data, labels), + cv=kfold)) def check_score(estimator, tuples, y): if hasattr(estimator, "score"): - score = estimator.score(tuples, y) + score = estimator.score(*remove_y_quadruplets(estimator, tuples, y)) assert np.isfinite(score) @@ -183,7 +232,7 @@ def test_simple_estimator(estimator, build_dataset, with_preprocessor): estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) - estimator.fit(tuples_train, y_train) + estimator.fit(*remove_y_quadruplets(estimator, tuples_train, y_train)) check_score(estimator, tuples_test, y_test) check_predict(estimator, tuples_test) @@ -230,7 +279,9 @@ def test_estimators_fit_returns_self(estimator, build_dataset, input_data, labels, preprocessor, _ = build_dataset(with_preprocessor) estimator = clone(estimator) estimator.set_params(preprocessor=preprocessor) - assert estimator.fit(input_data, labels) is estimator + assert estimator.fit(*remove_y_quadruplets(estimator, + input_data, + labels)) is estimator @pytest.mark.parametrize('with_preprocessor', [True, False]) @@ -240,42 +291,53 @@ def test_pipeline_consistency(estimator, build_dataset, with_preprocessor): # Adapted from scikit learn # check that make_pipeline(est) gives same score as est - input_data, y, preprocessor, _ = build_dataset(with_preprocessor) - - def make_random_state(estimator, in_pipeline): - rs = {} - name_estimator = estimator.__class__.__name__ - if name_estimator[-11:] == '_Supervised': - name_param = 'random_state' - if in_pipeline: - name_param = name_estimator.lower() + '__' + name_param - rs[name_param] = check_random_state(0) - return rs + # we do this test on all except quadruplets (since they don't have a y + # in fit): + if estimator.__class__.__name__ not in [e.__class__.__name__ + for (e, _) in + quadruplets_learners]: + input_data, y, preprocessor, _ = build_dataset(with_preprocessor) + + def make_random_state(estimator, in_pipeline): + rs = {} + name_estimator = estimator.__class__.__name__ + if name_estimator[-11:] == '_Supervised': + name_param = 'random_state' + if in_pipeline: + name_param = name_estimator.lower() + '__' + name_param + rs[name_param] = check_random_state(0) + return rs - estimator = clone(estimator) - estimator.set_params(preprocessor=preprocessor) - pipeline = make_pipeline(estimator) - estimator.fit(input_data, y, **make_random_state(estimator, False)) - pipeline.fit(input_data, y, **make_random_state(estimator, True)) - - if hasattr(estimator, 'score'): - result = estimator.score(input_data, y) - result_pipe = pipeline.score(input_data, y) - assert_allclose_dense_sparse(result, result_pipe) - - if hasattr(estimator, 'predict'): - result = estimator.predict(input_data) - result_pipe = pipeline.predict(input_data) - assert_allclose_dense_sparse(result, result_pipe) - - if issubclass(estimator.__class__, TransformerMixin): - if hasattr(estimator, 'transform'): - result = estimator.transform(input_data) - result_pipe = pipeline.transform(input_data) + estimator = clone(estimator) + estimator.set_params(preprocessor=preprocessor) + pipeline = make_pipeline(estimator) + estimator.fit(*remove_y_quadruplets(estimator, input_data, y), + **make_random_state(estimator, False)) + pipeline.fit(*remove_y_quadruplets(estimator, input_data, y), + **make_random_state(estimator, True)) + + if hasattr(estimator, 'score'): + result = estimator.score(*remove_y_quadruplets(estimator, 
+ input_data, + y)) + result_pipe = pipeline.score(*remove_y_quadruplets(estimator, + input_data, + y)) assert_allclose_dense_sparse(result, result_pipe) + if hasattr(estimator, 'predict'): + result = estimator.predict(input_data) + result_pipe = pipeline.predict(input_data) + assert_allclose_dense_sparse(result, result_pipe) + + if issubclass(estimator.__class__, TransformerMixin): + if hasattr(estimator, 'transform'): + result = estimator.transform(input_data) + result_pipe = pipeline.transform(input_data) + assert_allclose_dense_sparse(result, result_pipe) -@pytest.mark.parametrize('with_preprocessor',[True, False]) + +@pytest.mark.parametrize('with_preprocessor', [True, False]) @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) def test_dict_unchanged(estimator, build_dataset, with_preprocessor): @@ -286,7 +348,7 @@ def test_dict_unchanged(estimator, build_dataset, with_preprocessor): estimator.set_params(preprocessor=preprocessor) if hasattr(estimator, "num_dims"): estimator.num_dims = 1 - estimator.fit(input_data, labels) + estimator.fit(*remove_y_quadruplets(estimator, input_data, labels)) def check_dict(): assert estimator.__dict__ == dict_before, ( @@ -303,7 +365,7 @@ def check_dict(): check_dict() -@pytest.mark.parametrize('with_preprocessor',[True, False]) +@pytest.mark.parametrize('with_preprocessor', [True, False]) @pytest.mark.parametrize('estimator, build_dataset', metric_learners, ids=ids_metric_learners) def test_dont_overwrite_parameters(estimator, build_dataset, @@ -317,7 +379,7 @@ def test_dont_overwrite_parameters(estimator, build_dataset, estimator.num_dims = 1 dict_before_fit = estimator.__dict__.copy() - estimator.fit(input_data, labels) + estimator.fit(*remove_y_quadruplets(estimator, input_data, labels)) dict_after_fit = estimator.__dict__ public_keys_after_fit = [key for key in dict_after_fit.keys() diff --git a/test/test_utils.py b/test/test_utils.py index f1df4098..cfadfd32 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -100,8 +100,11 @@ def build_quadruplets(with_preprocessor=False): [learner for (learner, _) in quadruplets_learners])) -pairs_learners = [(ITML(), build_pairs), - (MMC(max_iter=2), build_pairs), # max_iter=2 for faster +pairs_learners = [(ITML(max_iter=2), build_pairs), # max_iter=2 to be + # faster, also make tests pass while waiting for #175 to + # be solved + # TODO: remove this comment when #175 is solved + (MMC(max_iter=2), build_pairs), # max_iter=2 to be faster (SDML(use_cov=False, balance_param=1e-5), build_pairs)] ids_pairs_learners = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in @@ -117,7 +120,7 @@ def build_quadruplets(with_preprocessor=False): (MMC_Supervised(max_iter=5), build_classification), (RCA_Supervised(num_chunks=10), build_classification), (SDML_Supervised(use_cov=False, balance_param=1e-5), - build_classification)] + build_classification)] ids_classifiers = list(map(lambda x: x.__class__.__name__, [learner for (learner, _) in classifiers])) @@ -139,6 +142,18 @@ def build_quadruplets(with_preprocessor=False): ids_metric_learners = ids_tuples_learners + ids_supervised_learners +def remove_y_quadruplets(estimator, X, y): + """Quadruplets learners have no y in fit, but to write test for all + estimators, it is convenient to have this function, that will return X and y + if the estimator needs a y to fit on, and just X otherwise.""" + if estimator.__class__.__name__ in [e.__class__.__name__ + for (e, _) in + quadruplets_learners]: + return (X,) + 
else: + return (X, y) + + def mock_preprocessor(indices): """A preprocessor for testing purposes that returns an all ones 3D array """ @@ -840,7 +855,7 @@ def test_error_message_tuple_size(estimator, _): [[1.9, 5.3], [1., 7.8], [3.2, 1.2]]]) y = [1, 1] with pytest.raises(ValueError) as raised_err: - estimator.fit(invalid_pairs, y) + estimator.fit(*remove_y_quadruplets(estimator, invalid_pairs, y)) expected_msg = ("Tuples of {} element(s) expected{}. Got tuples of 3 " "element(s) instead (shape=(2, 3, 2)):\ninput={}.\n" .format(estimator._tuple_size, make_context(estimator), @@ -925,19 +940,25 @@ def make_random_state(estimator): estimator_with_preprocessor = clone(estimator) set_random_state(estimator_with_preprocessor) estimator_with_preprocessor.set_params(preprocessor=X) - estimator_with_preprocessor.fit(indices_train, y_train, + estimator_with_preprocessor.fit(*remove_y_quadruplets(estimator, + indices_train, + y_train), **make_random_state(estimator)) estimator_without_preprocessor = clone(estimator) set_random_state(estimator_without_preprocessor) estimator_without_preprocessor.set_params(preprocessor=None) - estimator_without_preprocessor.fit(formed_train, y_train, + estimator_without_preprocessor.fit(*remove_y_quadruplets(estimator, + formed_train, + y_train), **make_random_state(estimator)) estimator_with_prep_formed = clone(estimator) set_random_state(estimator_with_prep_formed) estimator_with_prep_formed.set_params(preprocessor=X) - estimator_with_prep_formed.fit(indices_train, y_train, + estimator_with_prep_formed.fit(*remove_y_quadruplets(estimator, + indices_train, + y_train), **make_random_state(estimator)) # test prediction methods
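
A minimal standalone sketch of the usage pattern the scoring and fitting checks above exercise (illustrative only, not part of the patch: it assumes metric-learn's ITML and scikit-learn are installed, and the random pairs, labels, and sizes are invented for the example):

    # Sketch: fit a pairs learner on (n_pairs, 2, n_features) data with +1/-1
    # labels and let scikit-learn's scorers drive it, as the tests above do.
    import numpy as np
    from sklearn.model_selection import cross_val_score
    from metric_learn import ITML

    rng = np.random.RandomState(42)
    pairs = rng.randn(40, 2, 5)             # 40 pairs of 5-dimensional points
    y_pairs = rng.choice([-1, 1], size=40)  # +1 = similar pair, -1 = dissimilar

    itml = ITML(max_iter=2)                 # small max_iter, as in the fixtures
    scores = cross_val_score(itml, pairs, y_pairs, scoring='roc_auc', cv=3)
    assert np.isfinite(scores).all()        # what the tests check

The 'roc_auc' scorer works here because, as the comments in the scoring test note, every pairs learner exposes a decision_function (the learned score between the two points of a pair).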