[MRG] Fix test failures due to updated packages: deprecated pytest.warns(None) syntax + GLasso update in sklearn #357

Merged
14 changes: 10 additions & 4 deletions metric_learn/sdml.py
@@ -6,7 +6,13 @@
 import numpy as np
 from sklearn.base import TransformerMixin
 from scipy.linalg import pinvh
-from sklearn.covariance import graphical_lasso
+try:
+  from sklearn.covariance._graph_lasso import (
+    _graphical_lasso as graphical_lasso
+  )
+except ImportError:
+  from sklearn.covariance import graphical_lasso
+
 from sklearn.exceptions import ConvergenceWarning
 
 from .base_metric import MahalanobisMixin, _PairsClassifierMixin
@@ -79,9 +85,9 @@ def _fit(self, pairs, y):
                               msg=self.verbose,
                               Theta0=theta0, Sigma0=sigma0)
     else:
-      _, M = graphical_lasso(emp_cov, alpha=self.sparsity_param,
-                             verbose=self.verbose,
-                             cov_init=sigma0)
+      _, M, *_ = graphical_lasso(emp_cov, alpha=self.sparsity_param,
+                                 verbose=self.verbose,
+                                 cov_init=sigma0)
     raised_error = None
     w_mahalanobis, _ = np.linalg.eigh(M)
     not_spd = any(w_mahalanobis < 0.)
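Note on the sdml.py change above: newer scikit-learn releases changed the graphical lasso API, so the fix imports the private _graphical_lasso solver when it exists (it still accepts cov_init, but returns more than the (covariance, precision) pair) and falls back to the public function on older versions, absorbing any extra return values with a starred target. Below is a minimal, self-contained sketch of the same pattern; the random emp_cov and alpha=0.1 are illustrative only, and the private module path sklearn.covariance._graph_lasso is an implementation detail that may move again in future releases.

import numpy as np

# Compatibility shim from this PR: prefer the private solver on newer
# scikit-learn, fall back to the public function on older versions.
try:
  from sklearn.covariance._graph_lasso import (
    _graphical_lasso as graphical_lasso
  )
except ImportError:
  from sklearn.covariance import graphical_lasso

# Illustrative empirical covariance of some random data.
rng = np.random.RandomState(42)
emp_cov = np.cov(rng.randn(50, 4), rowvar=False)

# Starred unpacking tolerates both return shapes: the public API's
# (covariance, precision) pair and the private solver's longer tuple
# (which also carries costs / iteration counts).
_, precision, *_ = graphical_lasso(emp_cov, alpha=0.1,
                                   cov_init=emp_cov.copy())
assert precision.shape == (4, 4)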
9 changes: 5 additions & 4 deletions test/metric_learn_test.py
@@ -1,3 +1,4 @@
+import warnings
 import unittest
 import re
 import pytest
@@ -734,12 +735,12 @@ def test_raises_no_warning_installed_skggm(self):
     pairs = np.array([[[-10., 0.], [10., 0.]], [[0., -55.], [0., -60]]])
     y_pairs = [1, -1]
     X, y = make_classification(random_state=42)
-    with pytest.warns(None) as records:
+    with warnings.catch_warnings(record=True) as records:
       sdml = SDML(prior='covariance')
       sdml.fit(pairs, y_pairs)
     for record in records:
       assert record.category is not ConvergenceWarning
-    with pytest.warns(None) as records:
+    with warnings.catch_warnings(record=True) as records:
       sdml_supervised = SDML_Supervised(prior='identity', balance_param=1e-5)
       sdml_supervised.fit(X, y)
     for record in records:
@@ -999,7 +1000,7 @@ def test_rank_deficient_returns_warning(self):
            'for instance using `sklearn.decomposition.PCA` as a '
            'preprocessing step.')
 
-    with pytest.warns(None) as raised_warnings:
+    with warnings.catch_warnings(record=True) as raised_warnings:
       rca.fit(X, y)
     assert any(str(w.message) == msg for w in raised_warnings)

@@ -1034,7 +1035,7 @@ def test_bad_parameters(self):
         'Increase the number or size of the chunks to correct '
         'this problem.'
     )
-    with pytest.warns(None) as raised_warning:
+    with warnings.catch_warnings(record=True) as raised_warning:
       rca.fit(X, y)
     assert any(str(w.message) == msg for w in raised_warning)

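The pytest change is mechanical but worth spelling out: pytest.warns(None) was deprecated in pytest 7.0, and its replacement for "record warnings without requiring any" is the standard library's warnings.catch_warnings(record=True). A minimal sketch of the two assertion styles used in the tests above; the flaky_fit helper and its message are made up for illustration, and the simplefilter('always') calls are an extra safeguard that pytest.warns applied implicitly but this PR's diff does not.

import warnings
from sklearn.exceptions import ConvergenceWarning

def flaky_fit():
  # Stand-in for an estimator fit() that may warn; purely illustrative.
  warnings.warn('chunks may be too small', UserWarning)

# Style 1: assert that a specific category was NOT emitted.
with warnings.catch_warnings(record=True) as records:
  warnings.simplefilter('always')  # record everything, even duplicates
  flaky_fit()
for record in records:
  assert record.category is not ConvergenceWarning

# Style 2: assert that an expected message WAS emitted.
with warnings.catch_warnings(record=True) as raised_warnings:
  warnings.simplefilter('always')
  flaky_fit()
assert any(str(w.message) == 'chunks may be too small'
           for w in raised_warnings)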
5 changes: 3 additions & 2 deletions test/test_base_metric.py
@@ -1,4 +1,5 @@
 from numpy.core.numeric import array_equal
+import warnings
 import pytest
 import re
 import unittest
@@ -226,15 +227,15 @@ def test_get_metric_works_does_not_raise(estimator, build_dataset):
                                        (X[0][None], X[1][None])]
 
   for u, v in list_test_get_metric_doesnt_raise:
-    with pytest.warns(None) as record:
+    with warnings.catch_warnings(record=True) as record:
       metric(u, v)
     assert len(record) == 0
 
   # Test that the scalar case works
   model.components_ = np.array([3.1])
   metric = model.get_metric()
   for u, v in [(5, 6.7), ([5], [6.7]), ([[5]], [[6.7]])]:
-    with pytest.warns(None) as record:
+    with warnings.catch_warnings(record=True) as record:
       metric(u, v)
     assert len(record) == 0

7 changes: 4 additions & 3 deletions test/test_pairs_classifiers.py
@@ -1,5 +1,6 @@
 from functools import partial
 
+import warnings
 import pytest
 from numpy.testing import assert_array_equal
 from scipy.spatial.distance import euclidean
@@ -136,7 +137,7 @@ def test_threshold_different_scores_is_finite(estimator, build_dataset,
   estimator.set_params(preprocessor=preprocessor)
   set_random_state(estimator)
   estimator.fit(input_data, labels)
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     estimator.calibrate_threshold(input_data, labels, **kwargs)
   assert len(record) == 0

@@ -383,7 +384,7 @@ def test_calibrate_threshold_valid_parameters(valid_args):
   pairs, y = rng.randn(20, 2, 5), rng.choice([-1, 1], size=20)
   pairs_learner = IdentityPairsClassifier()
   pairs_learner.fit(pairs, y)
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     pairs_learner.calibrate_threshold(pairs, y, **valid_args)
   assert len(record) == 0

@@ -518,7 +519,7 @@ def test_validate_calibration_params_valid_parameters(
   # test that no warning message is returned if valid arguments are given to
   # _validate_calibration_params for all pairs metric learners, as well as
   # a mocking example, and the class itself
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     estimator._validate_calibration_params(**valid_args)
   assert len(record) == 0

19 changes: 10 additions & 9 deletions test/test_utils.py
@@ -1,3 +1,4 @@
+import warnings
 import pytest
 from scipy.linalg import eigh, pinvh
 from collections import namedtuple
@@ -353,7 +354,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
   checks that checking the number of tuples (pairs, quadruplets, etc) raises
   no warning if there is the right number of points in a tuple.
   """
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples_prep(), type_of_inputs='tuples',
                 preprocessor=mock_preprocessor, tuple_size=tuple_size)
     check_input(tuples_no_prep(), type_of_inputs='tuples', preprocessor=None,
@@ -378,7 +379,7 @@ def test_check_tuples_valid_tuple_size(tuple_size):
                                      [[2.6, 2.3], [3.4, 5.0]]])])
 def test_check_tuples_valid_with_preprocessor(tuples):
   """Test that valid inputs when using a preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples, type_of_inputs='tuples',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
@@ -399,7 +400,7 @@ def test_check_tuples_valid_with_preprocessor(tuples):
                           ((3, 1), (4, 4), (29, 4)))])
 def test_check_tuples_valid_without_preprocessor(tuples):
   """Test that valid inputs when using no preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples, type_of_inputs='tuples', preprocessor=None)
   assert len(record) == 0

@@ -408,12 +409,12 @@ def test_check_tuples_behaviour_auto_dtype():
   """Checks that check_tuples allows by default every type if using a
   preprocessor, and numeric types if using no preprocessor"""
   tuples_prep = [['img1.png', 'img2.png'], ['img3.png', 'img5.png']]
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples_prep, type_of_inputs='tuples',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
 
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(tuples_no_prep(), type_of_inputs='tuples')  # numeric type
   assert len(record) == 0

@@ -549,7 +550,7 @@ def test_check_classic_invalid_dtype_not_convertible(preprocessor, points):
                           [2.6, 2.3]])])
 def test_check_classic_valid_with_preprocessor(points):
   """Test that valid inputs when using a preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points, type_of_inputs='classic',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
@@ -570,7 +571,7 @@ def test_check_classic_valid_with_preprocessor(points):
                           (3, 1, 4, 4, 29, 4))])
 def test_check_classic_valid_without_preprocessor(points):
   """Test that valid inputs when using no preprocessor raises no warning"""
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points, type_of_inputs='classic', preprocessor=None)
   assert len(record) == 0

@@ -585,12 +586,12 @@ def test_check_classic_behaviour_auto_dtype():
   """Checks that check_input (for points) allows by default every type if
   using a preprocessor, and numeric types if using no preprocessor"""
   points_prep = ['img1.png', 'img2.png', 'img3.png', 'img5.png']
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points_prep, type_of_inputs='classic',
                 preprocessor=mock_preprocessor)
   assert len(record) == 0
 
-  with pytest.warns(None) as record:
+  with warnings.catch_warnings(record=True) as record:
     check_input(points_no_prep(), type_of_inputs='classic')  # numeric type
   assert len(record) == 0

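One caveat on the many len(record) == 0 assertions: pytest.warns(None) reset the warning filters to 'always' on entry, while warnings.catch_warnings(record=True) only saves and restores the existing filter state, so a warning already deduplicated by a 'default' or 'once' filter may not be re-recorded. A minimal sketch of the "no warnings" idiom with an explicit filter; the simplefilter call and the quiet_check helper are my additions, not part of this diff.

import warnings

def quiet_check(points):
  # Stand-in for check_input(); emits no warning on valid input.
  return points

with warnings.catch_warnings(record=True) as record:
  warnings.simplefilter('always')  # make recording deterministic
  quiet_check([[1., 2.], [3., 4.]])
assert len(record) == 0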