-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add triplet margin for distance functions in TripletEvaluator (#2862)
* Add triplet margin distance metrics * Adjust triplet_margin logic to improved TripletEvaluator implementation * Check if triplet margins are a dictinary * Use similarity instead of distance; use 'margin' instead of 'triplet_margins' + tests * Rename main_disistance_function to main_similarity_function * Avoid backwards incompatibility due to main_distance_function --------- Co-authored-by: Milos Zivic <[email protected]> Co-authored-by: Tom Aarsen <[email protected]>
- Loading branch information
1 parent
af39619
commit b055b5d
Showing
2 changed files
with
92 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
Tests the correct computation of evaluation scores from TripletEvaluator | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
from sentence_transformers import SentenceTransformer | ||
from sentence_transformers.evaluation import TripletEvaluator | ||
|
||
|
||
def test_TripletEvaluator(stsb_bert_tiny_model_reused: SentenceTransformer) -> None: | ||
"""Tests that the TripletEvaluator can be loaded & used""" | ||
model = stsb_bert_tiny_model_reused | ||
anchors = [ | ||
"A person on a horse jumps over a broken down airplane.", | ||
"Children smiling and waving at camera", | ||
"A boy is jumping on skateboard in the middle of a red bridge.", | ||
] | ||
positives = [ | ||
"A person is outdoors, on a horse.", | ||
"There are children looking at the camera.", | ||
"The boy does a skateboarding trick.", | ||
] | ||
negatives = [ | ||
"A person is at a diner, ordering an omelette.", | ||
"The kids are frowning", | ||
"The boy skates down the sidewalk.", | ||
] | ||
evaluator = TripletEvaluator(anchors, positives, negatives, name="all_nli_dev") | ||
metrics = evaluator(model) | ||
assert evaluator.primary_metric == "all_nli_dev_cosine_accuracy" | ||
assert metrics[evaluator.primary_metric] == 1.0 | ||
|
||
evaluator_with_margin = TripletEvaluator(anchors, positives, negatives, margin=0.7, name="all_nli_dev") | ||
metrics = evaluator_with_margin(model) | ||
assert metrics[evaluator.primary_metric] == 0.0 |