diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py index c76de99b..c4c427b9 100644 --- a/metric_learn/sdml.py +++ b/metric_learn/sdml.py @@ -6,7 +6,13 @@ import numpy as np from sklearn.base import TransformerMixin from scipy.linalg import pinvh -from sklearn.covariance import graphical_lasso +try: + from sklearn.covariance._graph_lasso import ( + _graphical_lasso as graphical_lasso + ) +except ImportError: + from sklearn.covariance import graphical_lasso + from sklearn.exceptions import ConvergenceWarning from .base_metric import MahalanobisMixin, _PairsClassifierMixin @@ -79,9 +85,9 @@ def _fit(self, pairs, y): msg=self.verbose, Theta0=theta0, Sigma0=sigma0) else: - _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param, - verbose=self.verbose, - cov_init=sigma0) + _, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param, + verbose=self.verbose, + cov_init=sigma0) raised_error = None w_mahalanobis, _ = np.linalg.eigh(M) not_spd = any(w_mahalanobis < 0.) diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py index f109a667..d457b52d 100644 --- a/test/metric_learn_test.py +++ b/test/metric_learn_test.py @@ -1,3 +1,4 @@ +import warnings import unittest import re import pytest @@ -734,12 +735,12 @@ def test_raises_no_warning_installed_skggm(self): pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]]) y_pairs = [1, -1] X, y = make_classification(random_state=42) - with pytest.warns(None) as records: + with warnings.catch_warnings(record=True) as records: sdml = SDML(prior='covariance') sdml.fit(pairs, y_pairs) for record in records: assert record.category is not ConvergenceWarning - with pytest.warns(None) as records: + with warnings.catch_warnings(record=True) as records: sdml_supervised = SDML_Supervised(prior='identity', balance_param=1e-5) sdml_supervised.fit(X, y) for record in records: @@ -999,7 +1000,7 @@ def test_rank_deficient_returns_warning(self): 'for instance using `sklearn.decomposition.PCA` as a ' 'preprocessing step.') - with pytest.warns(None) as raised_warnings: + with warnings.catch_warnings(record=True) as raised_warnings: rca.fit(X, y) assert any(str(w.message) == msg for w in raised_warnings) @@ -1034,7 +1035,7 @@ def test_bad_parameters(self): 'Increase the number or size of the chunks to correct ' 'this problem.' ) - with pytest.warns(None) as raised_warning: + with warnings.catch_warnings(record=True) as raised_warning: rca.fit(X, y) assert any(str(w.message) == msg for w in raised_warning) diff --git a/test/test_base_metric.py b/test/test_base_metric.py index fa641526..b1e71020 100644 --- a/test/test_base_metric.py +++ b/test/test_base_metric.py @@ -1,4 +1,5 @@ from numpy.core.numeric import array_equal +import warnings import pytest import re import unittest @@ -226,7 +227,7 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset): (X[0][None], X[1][None])] for u, v in list_test_get_metric_doesnt_raise: - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: metric(u, v) assert len(record) == 0 @@ -234,7 +235,7 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset): model.components_ = np.array([3.1]) metric = model.get_metric() for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]: - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: metric(u, v) assert len(record) == 0 diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py index 6a725f23..bfedefea 100644 --- a/test/test_pairs_classifiers.py +++ b/test/test_pairs_classifiers.py @@ -1,5 +1,6 @@ from functools import partial +import warnings import pytest from numpy.testing import assert_array_equal from scipy.spatial.distance import euclidean @@ -136,7 +137,7 @@ def test_threshold_different_scores_is_finite(estimator, build_dataset, estimator.set_params(preprocessor=preprocessor) set_random_state(estimator) estimator.fit(input_data, labels) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: estimator.calibrate_threshold(input_data, labels, **kwargs) assert len(record) == 0 @@ -383,7 +384,7 @@ def test_calibrate_threshold_valid_parameters(valid_args): pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20) pairs_learner = IdentityPairsClassifier() pairs_learner.fit(pairs, y) - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: pairs_learner.calibrate_threshold(pairs, y, **valid_args) assert len(record) == 0 @@ -518,7 +519,7 @@ def test_validate_calibration_params_valid_parameters( # test that no warning message is returned if valid arguments are given to # _validate_calibration_params for all pairs metric learners, as well as # a mocking example, and the class itself - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: estimator._validate_calibration_params(**valid_args) assert len(record) == 0 diff --git a/test/test_utils.py b/test/test_utils.py index 43d67111..c0383792 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,3 +1,4 @@ +import warnings import pytest from scipy.linalg import eigh, pinvh from collections import namedtuple @@ -353,7 +354,7 @@ def test_check_tuples_valid_tuple_size(tuple_size): checks that checking the number of tuples (pairs, quadruplets, etc) raises no warning if there is the right number of points in a tuple. """ - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(tuples_prep(), type_of_inputs='tuples', preprocessor=mock_preprocessor, tuple_size=tuple_size) check_input(tuples_no_prep(), type_of_inputs='tuples', preprocessor=None, @@ -378,7 +379,7 @@ def test_check_tuples_valid_tuple_size(tuple_size): [[2.6, 2.3], [3.4, 5.0]]])]) def test_check_tuples_valid_with_preprocessor(tuples): """Test that valid inputs when using a preprocessor raises no warning""" - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(tuples, type_of_inputs='tuples', preprocessor=mock_preprocessor) assert len(record) == 0 @@ -399,7 +400,7 @@ def test_check_tuples_valid_with_preprocessor(tuples): ((3, 1), (4, 4), (29, 4)))]) def test_check_tuples_valid_without_preprocessor(tuples): """Test that valid inputs when using no preprocessor raises no warning""" - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(tuples, type_of_inputs='tuples', preprocessor=None) assert len(record) == 0 @@ -408,12 +409,12 @@ def test_check_tuples_behaviour_auto_dtype(): """Checks that check_tuples allows by default every type if using a preprocessor, and numeric types if using no preprocessor""" tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']] - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(tuples_prep, type_of_inputs='tuples', preprocessor=mock_preprocessor) assert len(record) == 0 - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(tuples_no_prep(), type_of_inputs='tuples') # numeric type assert len(record) == 0 @@ -549,7 +550,7 @@ def test_check_classic_invalid_dtype_not_convertible(preprocessor, points): [2.6, 2.3]])]) def test_check_classic_valid_with_preprocessor(points): """Test that valid inputs when using a preprocessor raises no warning""" - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(points, type_of_inputs='classic', preprocessor=mock_preprocessor) assert len(record) == 0 @@ -570,7 +571,7 @@ def test_check_classic_valid_with_preprocessor(points): (3, 1, 4, 4, 29, 4))]) def test_check_classic_valid_without_preprocessor(points): """Test that valid inputs when using no preprocessor raises no warning""" - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(points, type_of_inputs='classic', preprocessor=None) assert len(record) == 0 @@ -585,12 +586,12 @@ def test_check_classic_behaviour_auto_dtype(): """Checks that check_input (for points) allows by default every type if using a preprocessor, and numeric types if using no preprocessor""" points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png'] - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(points_prep, type_of_inputs='classic', preprocessor=mock_preprocessor) assert len(record) == 0 - with pytest.warns(None) as record: + with warnings.catch_warnings(record=True) as record: check_input(points_no_prep(), type_of_inputs='classic') # numeric type assert len(record) == 0