Skip to content

Commit

Permalink
[MRG] Fix test failures due to updated packages: deprecated pytest.wa…
Browse files Browse the repository at this point in the history
…rns(None) syntax + GLasso update in sklearn (#357)

* Fix GLasso import for SDML for newer sklearn versions

* fix import and argument issue

* also fix deprecated pytest.warns(None) syntex

* fix flake8
  • Loading branch information
bellet authored Aug 3, 2024
1 parent 8fb6872 commit dc7e449
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 22 deletions.
14 changes: 10 additions & 4 deletions metric_learn/sdml.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,13 @@
import numpy as np
from sklearn.base import TransformerMixin
from scipy.linalg import pinvh
from sklearn.covariance import graphical_lasso
try:
from sklearn.covariance._graph_lasso import (
_graphical_lasso as graphical_lasso
)
except ImportError:
from sklearn.covariance import graphical_lasso

from sklearn.exceptions import ConvergenceWarning

from .base_metric import MahalanobisMixin, _PairsClassifierMixin
Expand Down Expand Up @@ -79,9 +85,9 @@ def _fit(self, pairs, y):
msg=self.verbose,
Theta0=theta0, Sigma0=sigma0)
else:
_, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
verbose=self.verbose,
cov_init=sigma0)
_, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param,
verbose=self.verbose,
cov_init=sigma0)
raised_error = None
w_mahalanobis, _ = np.linalg.eigh(M)
not_spd = any(w_mahalanobis < 0.)
Expand Down
9 changes: 5 additions & 4 deletions test/metric_learn_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
import unittest
import re
import pytest
Expand Down Expand Up @@ -734,12 +735,12 @@ def test_raises_no_warning_installed_skggm(self):
pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
y_pairs = [1, -1]
X, y = make_classification(random_state=42)
with pytest.warns(None) as records:
with warnings.catch_warnings(record=True) as records:
sdml = SDML(prior='covariance')
sdml.fit(pairs, y_pairs)
for record in records:
assert record.category is not ConvergenceWarning
with pytest.warns(None) as records:
with warnings.catch_warnings(record=True) as records:
sdml_supervised = SDML_Supervised(prior='identity', balance_param=1e-5)
sdml_supervised.fit(X, y)
for record in records:
Expand Down Expand Up @@ -999,7 +1000,7 @@ def test_rank_deficient_returns_warning(self):
'for instance using `sklearn.decomposition.PCA` as a '
'preprocessing step.')

with pytest.warns(None) as raised_warnings:
with warnings.catch_warnings(record=True) as raised_warnings:
rca.fit(X, y)
assert any(str(w.message) == msg for w in raised_warnings)

Expand Down Expand Up @@ -1034,7 +1035,7 @@ def test_bad_parameters(self):
'Increase the number or size of the chunks to correct '
'this problem.'
)
with pytest.warns(None) as raised_warning:
with warnings.catch_warnings(record=True) as raised_warning:
rca.fit(X, y)
assert any(str(w.message) == msg for w in raised_warning)

Expand Down
5 changes: 3 additions & 2 deletions test/test_base_metric.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from numpy.core.numeric import array_equal
import warnings
import pytest
import re
import unittest
Expand Down Expand Up @@ -226,15 +227,15 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
(X[0][None], X[1][None])]

for u, v in list_test_get_metric_doesnt_raise:
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
metric(u, v)
assert len(record) == 0

# Test that the scalar case works
model.components_ = np.array([3.1])
metric = model.get_metric()
for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]:
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
metric(u, v)
assert len(record) == 0

Expand Down
7 changes: 4 additions & 3 deletions test/test_pairs_classifiers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from functools import partial

import warnings
import pytest
from numpy.testing import assert_array_equal
from scipy.spatial.distance import euclidean
Expand Down Expand Up @@ -136,7 +137,7 @@ def test_threshold_different_scores_is_finite(estimator, build_dataset,
estimator.set_params(preprocessor=preprocessor)
set_random_state(estimator)
estimator.fit(input_data, labels)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
estimator.calibrate_threshold(input_data, labels, **kwargs)
assert len(record) == 0

Expand Down Expand Up @@ -383,7 +384,7 @@ def test_calibrate_threshold_valid_parameters(valid_args):
pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20)
pairs_learner = IdentityPairsClassifier()
pairs_learner.fit(pairs, y)
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
pairs_learner.calibrate_threshold(pairs, y, **valid_args)
assert len(record) == 0

Expand Down Expand Up @@ -518,7 +519,7 @@ def test_validate_calibration_params_valid_parameters(
# test that no warning message is returned if valid arguments are given to
# _validate_calibration_params for all pairs metric learners, as well as
# a mocking example, and the class itself
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
estimator._validate_calibration_params(**valid_args)
assert len(record) == 0

Expand Down
19 changes: 10 additions & 9 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
import pytest
from scipy.linalg import eigh, pinvh
from collections import namedtuple
Expand Down Expand Up @@ -353,7 +354,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
checks that checking the number of tuples (pairs, quadruplets, etc) raises
no warning if there is the right number of points in a tuple.
"""
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(tuples_prep(), type_of_inputs='tuples',
preprocessor=mock_preprocessor, tuple_size=tuple_size)
check_input(tuples_no_prep(), type_of_inputs='tuples', preprocessor=None,
Expand All @@ -378,7 +379,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
[[2.6, 2.3], [3.4, 5.0]]])])
def test_check_tuples_valid_with_preprocessor(tuples):
"""Test that valid inputs when using a preprocessor raises no warning"""
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(tuples, type_of_inputs='tuples',
preprocessor=mock_preprocessor)
assert len(record) == 0
Expand All @@ -399,7 +400,7 @@ def test_check_tuples_valid_with_preprocessor(tuples):
((3, 1), (4, 4), (29, 4)))])
def test_check_tuples_valid_without_preprocessor(tuples):
"""Test that valid inputs when using no preprocessor raises no warning"""
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(tuples, type_of_inputs='tuples', preprocessor=None)
assert len(record) == 0

Expand All @@ -408,12 +409,12 @@ def test_check_tuples_behaviour_auto_dtype():
"""Checks that check_tuples allows by default every type if using a
preprocessor, and numeric types if using no preprocessor"""
tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']]
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(tuples_prep, type_of_inputs='tuples',
preprocessor=mock_preprocessor)
assert len(record) == 0

with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(tuples_no_prep(), type_of_inputs='tuples') # numeric type
assert len(record) == 0

Expand Down Expand Up @@ -549,7 +550,7 @@ def test_check_classic_invalid_dtype_not_convertible(preprocessor, points):
[2.6, 2.3]])])
def test_check_classic_valid_with_preprocessor(points):
"""Test that valid inputs when using a preprocessor raises no warning"""
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(points, type_of_inputs='classic',
preprocessor=mock_preprocessor)
assert len(record) == 0
Expand All @@ -570,7 +571,7 @@ def test_check_classic_valid_with_preprocessor(points):
(3, 1, 4, 4, 29, 4))])
def test_check_classic_valid_without_preprocessor(points):
"""Test that valid inputs when using no preprocessor raises no warning"""
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(points, type_of_inputs='classic', preprocessor=None)
assert len(record) == 0

Expand All @@ -585,12 +586,12 @@ def test_check_classic_behaviour_auto_dtype():
"""Checks that check_input (for points) allows by default every type if
using a preprocessor, and numeric types if using no preprocessor"""
points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png']
with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(points_prep, type_of_inputs='classic',
preprocessor=mock_preprocessor)
assert len(record) == 0

with pytest.warns(None) as record:
with warnings.catch_warnings(record=True) as record:
check_input(points_no_prep(), type_of_inputs='classic') # numeric type
assert len(record) == 0

Expand Down

0 comments on commit dc7e449

Please sign in to comment.