From d9ea833b462e1300d5f377c70eefa816e35a6322 Mon Sep 17 00:00:00 2001
From: Aurelien Bellet
Date: Fri, 2 Aug 2024 09:50:44 +0200
Subject: [PATCH 1/4] Fix GLasso import for SDML for newer sklearn versions

---
 metric_learn/sdml.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index c76de99b..45fefcf9 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -6,7 +6,11 @@
 import numpy as np
 from sklearn.base import TransformerMixin
 from scipy.linalg import pinvh
-from sklearn.covariance import graphical_lasso
+try:
+  from sklearn.covariance import _graphical_lasso as graphical_lasso
+except ImportError:
+  from sklearn.covariance import graphical_lasso
+
 from sklearn.exceptions import ConvergenceWarning
 
 from .base_metric import MahalanobisMixin, _PairsClassifierMixin

From af94810924a47ea86cd3b010476c2f42360c5b7a Mon Sep 17 00:00:00 2001
From: Aurelien Bellet
Date: Fri, 2 Aug 2024 11:42:19 +0200
Subject: [PATCH 2/4] fix import and argument issue

---
 metric_learn/sdml.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index 45fefcf9..83663e3d 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -7,7 +7,9 @@
 from sklearn.base import TransformerMixin
 from scipy.linalg import pinvh
 try:
-  from sklearn.covariance import _graphical_lasso as graphical_lasso
+  from sklearn.covariance._graph_lasso import (
+      _graphical_lasso as graphical_lasso
+  )
 except ImportError:
   from sklearn.covariance import graphical_lasso
 
@@ -83,7 +85,7 @@ def _fit(self, pairs, y):
                                 msg=self.verbose,
                                 Theta0=theta0, Sigma0=sigma0)
       else:
-        _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
+        _, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param,
                                verbose=self.verbose,
                                cov_init=sigma0)
       raised_error = None

From 60e74be186d86e8b308b473c3ca073a56d38abe9 Mon Sep 17 00:00:00 2001
From: Aurelien Bellet
Date: Fri, 2 Aug 2024 12:20:35 +0200
Subject: [PATCH 3/4] also fix deprecated pytest.warns(None) syntax

---
 test/metric_learn_test.py      |  9 +++++----
 test/test_base_metric.py       |  5 +++--
 test/test_pairs_classifiers.py |  7 ++++---
 test/test_utils.py             | 19 ++++++++++---------
 4 files changed, 22 insertions(+), 18 deletions(-)

diff --git a/test/metric_learn_test.py b/test/metric_learn_test.py
index f109a667..d457b52d 100644
--- a/test/metric_learn_test.py
+++ b/test/metric_learn_test.py
@@ -1,3 +1,4 @@
+import warnings
 import unittest
 import re
 import pytest
@@ -734,12 +735,12 @@ def test_raises_no_warning_installed_skggm(self):
     pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
     y_pairs = [1, -1]
     X, y = make_classification(random_state=42)
-    with pytest.warns(None) as records:
+    with warnings.catch_warnings(record=True) as records:
       sdml = SDML(prior='covariance')
       sdml.fit(pairs, y_pairs)
     for record in records:
       assert record.category is not ConvergenceWarning
-    with pytest.warns(None) as records:
+    with warnings.catch_warnings(record=True) as records:
       sdml_supervised = SDML_Supervised(prior='identity', balance_param=1e-5)
       sdml_supervised.fit(X, y)
     for record in records:
       assert record.category is not ConvergenceWarning
@@ -999,7 +1000,7 @@ def test_rank_deficient_returns_warning(self):
            'for instance using `sklearn.decomposition.PCA` as a '
           'preprocessing step.')
 
-    with pytest.warns(None) as raised_warnings:
+    with warnings.catch_warnings(record=True) as raised_warnings:
       rca.fit(X, y)
     assert any(str(w.message) == msg for w in raised_warnings)
 
@@ -1034,7 +1035,7 @@ def test_bad_parameters(self):
           'Increase the number or size of the chunks to correct '
           'this problem.'
           )
-    with pytest.warns(None) as raised_warning:
+    with warnings.catch_warnings(record=True) as raised_warning:
       rca.fit(X, y)
     assert any(str(w.message) == msg for w in raised_warning)
 
diff --git a/test/test_base_metric.py b/test/test_base_metric.py
index fa641526..b1e71020 100644
--- a/test/test_base_metric.py
+++ b/test/test_base_metric.py
@@ -1,4 +1,5 @@
 from numpy.core.numeric import array_equal
+import warnings
 import pytest
 import re
 import unittest
@@ -226,7 +227,7 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
                                        (X[0][None], X[1][None])]
 
   for u, v in list_test_get_metric_doesnt_raise:
-    with pytest.warns(None) as record:
+    with warnings.catch_warnings(record=True) as record:
       metric(u, v)
     assert len(record) == 0
 
@@ -234,7 +235,7 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
   model.components_ = np.array([3.1])
   metric = model.get_metric()
   for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]:
-    with pytest.warns(None) as record:
+    with warnings.catch_warnings(record=True) as record:
       metric(u, v)
     assert len(record) == 0
 
diff --git a/test/test_pairs_classifiers.py b/test/test_pairs_classifiers.py
index 6a725f23..bfedefea 100644
--- a/test/test_pairs_classifiers.py
+++ b/test/test_pairs_classifiers.py
@@ -1,5 +1,6 @@
 from functools import partial
 
+import warnings
 import pytest
 from numpy.testing import assert_array_equal
 from scipy.spatial.distance import euclidean
@@ -136,7 +137,7 @@ def test_threshold_different_scores_is_finite(estimator, build_dataset,
   estimator.set_params(preprocessor=preprocessor)
   set_random_state(estimator)
   estimator.fit(input_data, labels)
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     estimator.calibrate_threshold(input_data, labels, **kwargs)
   assert len(record) == 0
 
@@ -383,7 +384,7 @@ def test_calibrate_threshold_valid_parameters(valid_args):
   pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20)
   pairs_learner = IdentityPairsClassifier()
   pairs_learner.fit(pairs, y)
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     pairs_learner.calibrate_threshold(pairs, y, **valid_args)
   assert len(record) == 0
 
@@ -518,7 +519,7 @@ def test_validate_calibration_params_valid_parameters(
   # test that no warning message is returned if valid arguments are given to
   # _validate_calibration_params for all pairs metric learners, as well as
   # a mocking example, and the class itself
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     estimator._validate_calibration_params(**valid_args)
   assert len(record) == 0
 
diff --git a/test/test_utils.py b/test/test_utils.py
index 43d67111..c0383792 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -1,3 +1,4 @@
+import warnings
 import pytest
 from scipy.linalg import eigh, pinvh
 from collections import namedtuple
@@ -353,7 +354,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
   checks that checking the number of tuples (pairs, quadruplets, etc) raises
   no warning if there is the right number of points in a tuple.
   """
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples_prep(), type_of_inputs='tuples',
                 preprocessor=mock_preprocessor, tuple_size=tuple_size)
     check_input(tuples_no_prep(), type_of_inputs='tuples', preprocessor=None,
@@ -378,7 +379,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
                                     [[2.6, 2.3], [3.4, 5.0]]])])
 def test_check_tuples_valid_with_preprocessor(tuples):
   """Test that valid inputs when using a preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples, type_of_inputs='tuples',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
@@ -399,7 +400,7 @@ def test_check_tuples_valid_with_preprocessor(tuples):
                           ((3, 1), (4, 4), (29, 4)))])
 def test_check_tuples_valid_without_preprocessor(tuples):
   """Test that valid inputs when using no preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples, type_of_inputs='tuples', preprocessor=None)
   assert len(record) == 0
 
@@ -408,12 +409,12 @@ def test_check_tuples_behaviour_auto_dtype():
   """Checks that check_tuples allows by default every type if using a
   preprocessor, and numeric types if using no preprocessor"""
   tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']]
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples_prep, type_of_inputs='tuples',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
 
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples_no_prep(), type_of_inputs='tuples')  # numeric type
   assert len(record) == 0
 
@@ -549,7 +550,7 @@ def test_check_classic_invalid_dtype_not_convertible(preprocessor, points):
                                     [2.6, 2.3]])])
 def test_check_classic_valid_with_preprocessor(points):
   """Test that valid inputs when using a preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points, type_of_inputs='classic',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
@@ -570,7 +571,7 @@ def test_check_classic_valid_with_preprocessor(points):
                           (3, 1, 4, 4, 29, 4))])
 def test_check_classic_valid_without_preprocessor(points):
   """Test that valid inputs when using no preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points, type_of_inputs='classic', preprocessor=None)
   assert len(record) == 0
 
@@ -585,12 +586,12 @@ def test_check_classic_behaviour_auto_dtype():
   """Checks that check_input (for points) allows by default every type if
   using a preprocessor, and numeric types if using no preprocessor"""
   points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png']
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points_prep, type_of_inputs='classic',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
 
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points_no_prep(), type_of_inputs='classic')  # numeric type
   assert len(record) == 0

From cda9e6e183d73bf7b57414f88b3f3581bf9e3ee8 Mon Sep 17 00:00:00 2001
From: Aurelien Bellet
Date: Fri, 2 Aug 2024 13:02:21 +0200
Subject: [PATCH 4/4] fix flake8

---
 metric_learn/sdml.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/metric_learn/sdml.py b/metric_learn/sdml.py
index 83663e3d..c4c427b9 100644
--- a/metric_learn/sdml.py
+++ b/metric_learn/sdml.py
@@ -86,8 +86,8 @@ def _fit(self, pairs, y):
                                 Theta0=theta0, Sigma0=sigma0)
       else:
         _, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param,
-                               verbose=self.verbose,
-                               cov_init=sigma0)
+                                   verbose=self.verbose,
+                                   cov_init=sigma0)
       raised_error = None
       w_mahalanobis, _ = np.linalg.eigh(M)
       not_spd = any(w_mahalanobis < 0.)
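
Note on patches 1, 2 and 4 (illustrative sketch, not part of the series): newer scikit-learn releases no longer accept the cov_init argument in the public graphical_lasso, while the private helper in sklearn.covariance._graph_lasso still does but returns extra values, which is why patch 2 switches the import and star-unpacks the result. The snippet below exercises the same fallback-import plus unpacking pattern outside of SDML; the toy data, the alpha value and the exact contents of the private helper's return tuple are assumptions made for this example.

import numpy as np

# Prefer the private helper on newer scikit-learn (assumed >= 1.3); fall back
# to the public function on older versions, mirroring the try/except in patch 2.
try:
    from sklearn.covariance._graph_lasso import (
        _graphical_lasso as graphical_lasso
    )
except ImportError:
    from sklearn.covariance import graphical_lasso

rng = np.random.RandomState(42)
X = rng.randn(100, 5)               # toy data for the example
emp_cov = np.cov(X, rowvar=False)   # 5 x 5 empirical covariance

# The second return value is the precision matrix on both code paths;
# star-unpacking absorbs any extra values returned by the private helper,
# which is the same trick as the `_, M, *_` line in patch 2.
_, precision, *_ = graphical_lasso(emp_cov, alpha=0.1, verbose=False)
print(precision.shape)  # (5, 5)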
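
Note on patch 3 (illustrative sketch, not part of the series): pytest deprecated pytest.warns(None) in version 7 and newer releases reject it, so the tests now record warnings with warnings.catch_warnings(record=True) and keep their existing assertions (len(record) == 0, or checking record.category). The self-contained example below shows both sides of the idiom; the compute helper, its RuntimeWarning and the simplefilter call are invented for this example and do not appear in the patch.

import warnings

import pytest


def compute(x):
    """Toy helper, invented for this sketch."""
    if x < 0:
        warnings.warn("negative input", RuntimeWarning)
    return 2 * x


def test_emits_no_warning():
    # Replacement for the old `with pytest.warns(None)` idiom: record every
    # warning and assert that nothing was captured.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")  # make sure no warning is filtered out
        compute(3)
    assert len(record) == 0


def test_emits_expected_warning():
    # Asserting that a specific warning is raised still uses pytest.warns,
    # now with an explicit category instead of None.
    with pytest.warns(RuntimeWarning, match="negative input"):
        compute(-1)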