[MRG+1] Threshold for pairs learners #168
```diff
@@ -374,8 +374,10 @@ def score(self, pairs, y):
   def set_default_threshold(self, pairs, y):
     """Returns a threshold that is the mean between the similar metrics
     mean, and the dissimilar metrics mean"""
-    similar_threshold = np.mean(self.decision_function(pairs[y==1]))
-    dissimilar_threshold = np.mean(self.decision_function(pairs[y==1]))
+    similar_threshold = np.mean(self.decision_function(
+        pairs[(y == 1).ravel()]))
+    dissimilar_threshold = np.mean(self.decision_function(
+        pairs[(y == -1).ravel()]))
     self.threshold_ = np.mean([similar_threshold, dissimilar_threshold])
```
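Besides flattening the boolean masks with `ravel()`, the new lines fix a copy-paste bug: both means were previously computed over `pairs[y == 1]`. A minimal sketch of the resulting default-threshold rule, with `scores` standing in for `self.decision_function(pairs)` (names here are illustrative, not the library's API):

```python
import numpy as np

def default_threshold(scores, y):
    """Midpoint between the mean score of similar pairs (y == 1) and the
    mean score of dissimilar pairs (y == -1), mimicking the patched
    set_default_threshold."""
    y = np.asarray(y).ravel()  # flatten, as in the fix, so the mask is 1-D
    similar_mean = np.mean(scores[y == 1])
    dissimilar_mean = np.mean(scores[y == -1])
    return np.mean([similar_mean, dissimilar_mean])

# Toy scores: similar pairs score high, dissimilar pairs score low.
scores = np.array([2.0, 1.5, -1.0, -2.5])
y = np.array([[1], [1], [-1], [-1]])   # column vector, hence the ravel()
print(default_threshold(scores, y))    # (1.75 + -1.75) / 2 = 0.0
```

Without the `ravel()`, a column-vector `y` would produce a 2-D boolean mask that cannot index the 1-D score array, which is exactly what the patch guards against.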
```diff
@@ -458,9 +460,14 @@ def score(self, quadruplets, y=None):
     score : float
       The quadruplets score.
     """
-    quadruplets = check_input(quadruplets, y, type_of_inputs='tuples',
-                              preprocessor=self.preprocessor_,
-                              estimator=self, tuple_size=self._tuple_size)
+    checked_input = check_input(quadruplets, y, type_of_inputs='tuples',
+                                preprocessor=self.preprocessor_,
+                                estimator=self, tuple_size=self._tuple_size)
+    # checked_input will be of the form `(checked_quadruplets, checked_y)` if
+    # `y` is not None, or just `checked_quadruplets` if `y` is None
+    quadruplets = checked_input if y is None else checked_input[0]
     if y is None:
       y = np.ones(quadruplets.shape[0])
+    else:
+      y = checked_input[1]
     return accuracy_score(y, self.predict(quadruplets))
```

Reviewer comment on the `checked_input` handling: I find this a bit ugly, but I couldn't find another way to do it. Maybe refactor
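The ugliness the reviewer points at stems from `check_input` returning either a bare array (when `y` is `None`) or an `(arrays, y)` tuple (when `y` is given). A minimal sketch of that dual-return pattern and how the scoring code unpacks it, using a hypothetical `_check_tuples` stand-in rather than metric-learn's real `check_input`:

```python
import numpy as np

def _check_tuples(tuples, y=None):
    """Hypothetical stand-in for check_input: returns (tuples, y) when y is
    given, and just tuples otherwise."""
    tuples = np.asarray(tuples)
    return tuples if y is None else (tuples, np.asarray(y))

def score_quadruplets(quadruplets, y=None):
    checked = _check_tuples(quadruplets, y)
    # `checked` is `(quadruplets, y)` if y was given, else just `quadruplets`
    quadruplets = checked if y is None else checked[0]
    y = np.ones(len(quadruplets)) if y is None else checked[1]
    return quadruplets, y  # downstream: accuracy_score(y, predict(quadruplets))

quads = np.zeros((3, 4, 2))                     # 3 quadruplets of 2-D points
print(score_quadruplets(quads)[1])              # -> [1. 1. 1.]
print(score_quadruplets(quads, [1, -1, 1])[1])  # -> [ 1 -1  1]
```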
```diff
@@ -25,35 +25,11 @@ def test_predict_only_one_or_minus_one(estimator, build_dataset,
   assert np.isin(predictions, [-1, 1]).all()
 
 
-@pytest.mark.parametrize('with_preprocessor', [True, False])
-@pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
-                         ids=ids_quadruplets_learners)
-def test_predict_monotonous(estimator, build_dataset,
-                            with_preprocessor):
-  """Test that there is a threshold distance separating points labeled as
-  similar and points labeled as dissimilar """
-  input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
-  estimator = clone(estimator)
-  estimator.set_params(preprocessor=preprocessor)
-  set_random_state(estimator)
-  (quadruplets_train,
-   quadruplets_test, y_train, y_test) = train_test_split(input_data, labels)
-  estimator.fit(quadruplets_train, y_train)
-  distances = estimator.score_quadruplets(quadruplets_test)
-  predictions = estimator.predict(quadruplets_test)
-  min_dissimilar = np.min(distances[predictions == -1])
-  max_similar = np.max(distances[predictions == 1])
-  assert max_similar <= min_dissimilar
-  separator = np.mean([min_dissimilar, max_similar])
-  assert (predictions[distances > separator] == -1).all()
-  assert (predictions[distances < separator] == 1).all()
-
-
 @pytest.mark.parametrize('with_preprocessor', [True, False])
 @pytest.mark.parametrize('estimator, build_dataset', quadruplets_learners,
                          ids=ids_quadruplets_learners)
 def test_raise_not_fitted_error_if_not_fitted(estimator, build_dataset,
-                                             with_preprocessor):
+                                              with_preprocessor):
   """Test that a NotFittedError is raised if someone tries to predict and
   the metric learner has not been fitted."""
   input_data, labels, preprocessor, _ = build_dataset(with_preprocessor)
```

Reviewer comment on the removed `test_predict_monotonous`: I removed this test because it makes no sense for quadruplets since the threshold should be always 0
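For context on why a learned threshold is meaningless in the quadruplets case: under the usual quadruplets convention, a prediction simply compares the distance of the first pair with that of the second, so the decision boundary sits at 0 by construction. A minimal sketch with illustrative names and a plain Euclidean distance:

```python
import numpy as np

def predict_quadruplets(dist, quadruplets):
    """Predict +1 when the first pair (a, b) is closer than the second
    pair (c, d); the score's sign is the prediction, so the threshold
    is 0 by construction. `dist` is any distance function."""
    scores = np.array([dist(c, d) - dist(a, b) for a, b, c, d in quadruplets])
    return np.sign(scores)

euclidean = lambda u, v: np.linalg.norm(np.asarray(u) - np.asarray(v))
quads = [([0, 0], [0, 1], [0, 0], [5, 5]),   # first pair closer -> +1
         ([0, 0], [9, 9], [1, 1], [1, 2])]   # second pair closer -> -1
print(predict_quadruplets(euclidean, quads))  # [ 1. -1.]
```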
Reviewer comment: to remove if indeed we choose the accuracy-calibrated threshold as default
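For reference, one plausible reading of an "accuracy-calibrated threshold" is to scan candidate thresholds and keep the one that maximizes training accuracy. The sketch below is an assumption about that idea (assuming higher scores mean more similar), not the PR's actual implementation:

```python
import numpy as np

def calibrate_threshold_by_accuracy(scores, y):
    """Scan the observed scores as candidate thresholds and return the one
    maximizing the accuracy of `predict = +1 if score >= t else -1`."""
    y = np.asarray(y).ravel()
    best_t, best_acc = 0.0, -np.inf
    for t in np.sort(scores):
        preds = np.where(scores >= t, 1, -1)
        acc = np.mean(preds == y)
        if acc > best_acc:
            best_t, best_acc = t, acc
    return best_t

scores = np.array([-2.0, -0.5, 0.3, 1.8])
y = np.array([-1, -1, 1, 1])
print(calibrate_threshold_by_accuracy(scores, y))  # 0.3 separates perfectly
```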