Add triplet margin for distance functions in TripletEvaluator (#2862)
* Add triplet margin distance metrics

* Adjust triplet_margin logic to match the improved TripletEvaluator implementation

* Check if triplet margins are a dictionary

* Use similarity instead of distance; use 'margin' instead of 'triplet_margins' + tests

* Rename main_distance_function to main_similarity_function

* Avoid backwards incompatibility due to main_distance_function

---------

Co-authored-by: Milos Zivic <[email protected]>
Co-authored-by: Tom Aarsen <[email protected]>
3 people authored Nov 26, 2024
1 parent af39619 commit b055b5d
Showing 2 changed files with 92 additions and 23 deletions.
79 changes: 56 additions & 23 deletions sentence_transformers/evaluation/TripletEvaluator.py
@@ -6,12 +6,15 @@
from contextlib import nullcontext
from typing import TYPE_CHECKING, Literal

import numpy as np
from sklearn.metrics.pairwise import paired_cosine_distances, paired_euclidean_distances, paired_manhattan_distances

from sentence_transformers.evaluation.SentenceEvaluator import SentenceEvaluator
from sentence_transformers.readers import InputExample
from sentence_transformers.similarity_functions import SimilarityFunction
from sentence_transformers.util import (
pairwise_cos_sim,
pairwise_dot_score,
pairwise_euclidean_sim,
pairwise_manhattan_sim,
)

if TYPE_CHECKING:
from sentence_transformers.SentenceTransformer import SentenceTransformer
@@ -22,7 +25,7 @@
class TripletEvaluator(SentenceEvaluator):
"""
Evaluate a model based on a triplet: (sentence, positive_example, negative_example).
Checks if distance(sentence, positive_example) < distance(sentence, negative_example).
Checks if ``similarity(sentence, positive_example) > similarity(sentence, negative_example) + margin``.
Example:
::
@@ -47,7 +50,7 @@ class TripletEvaluator(SentenceEvaluator):
results = triplet_evaluator(model)
'''
TripletEvaluator: Evaluating the model on the all-nli-dev dataset:
Accuracy Cosine Distance: 95.60%
Accuracy Cosine Similarity: 95.60%
'''
print(triplet_evaluator.primary_metric)
# => "all_nli_dev_cosine_accuracy"
@@ -60,13 +63,15 @@ def __init__(
anchors: list[str],
positives: list[str],
negatives: list[str],
main_distance_function: str | SimilarityFunction | None = None,
main_similarity_function: str | SimilarityFunction | None = None,
margin: float | dict[str, float] | None = None,
name: str = "",
batch_size: int = 16,
show_progress_bar: bool = False,
write_csv: bool = True,
truncate_dim: int | None = None,
similarity_fn_names: list[Literal["cosine", "dot", "euclidean", "manhattan"]] | None = None,
main_distance_function: str | SimilarityFunction | None = "deprecated",
):
"""
Initializes a TripletEvaluator object.
@@ -75,17 +80,22 @@
anchors (List[str]): Sentences to check similarity to. (e.g. a query)
positives (List[str]): List of positive sentences
negatives (List[str]): List of negative sentences
main_distance_function (Union[str, SimilarityFunction], optional):
The distance function to use. If not specified, use cosine similarity,
dot product, Euclidean, and Manhattan. Defaults to None.
main_similarity_function (Union[str, SimilarityFunction], optional):
The similarity function to use. If not specified, use cosine similarity,
dot product, Euclidean, and Manhattan similarity. Defaults to None.
margin (Union[float, Dict[str, float]], optional): Margins for various similarity metrics.
If a float is provided, it will be used as the margin for all similarity metrics.
If a dictionary is provided, the keys should be 'cosine', 'dot', 'manhattan', and 'euclidean'.
The value specifies the minimum margin by which the negative sample should be further from
the anchor than the positive sample. Defaults to None.
name (str): Name for the output. Defaults to "".
batch_size (int): Batch size used to compute embeddings. Defaults to 16.
show_progress_bar (bool): If true, prints a progress bar. Defaults to False.
write_csv (bool): Write results to a CSV file. Defaults to True.
truncate_dim (int, optional): The dimension to truncate sentence embeddings to.
`None` uses the model's current truncation dimension. Defaults to None.
similarity_fn_names (List[str], optional): List of similarity function names to evaluate.
If not specified, evaluate using the ``similarity_fn_name``.
If not specified, evaluate using the ``model.similarity_fn_name``.
Defaults to None.
"""
super().__init__()
@@ -98,9 +108,32 @@ def __init__(
assert len(self.anchors) == len(self.positives)
assert len(self.anchors) == len(self.negatives)

self.main_distance_function = SimilarityFunction(main_distance_function) if main_distance_function else None
if main_distance_function != "deprecated" and main_similarity_function is None:
main_similarity_function = main_distance_function
logger.warning(
"The 'main_distance_function' parameter is deprecated. Please use 'main_similarity_function' instead. "
"'main_distance_function' will be removed in a future release."
)

self.main_similarity_function = (
SimilarityFunction(main_similarity_function) if main_similarity_function else None
)
self.similarity_fn_names = similarity_fn_names or []

if margin is None:
self.margin = {"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0}
elif isinstance(margin, (float, int)):
self.margin = {"cosine": margin, "dot": margin, "manhattan": margin, "euclidean": margin}
elif isinstance(margin, dict):
self.margin = {
**{"cosine": 0, "dot": 0, "manhattan": 0, "euclidean": 0},
**margin,
}
else:
raise ValueError(
"`margin` should be a float or a dictionary with keys 'cosine', 'dot', 'manhattan', and 'euclidean'"
)

self.batch_size = batch_size
if show_progress_bar is None:
show_progress_bar = (
@@ -171,20 +204,20 @@ def __call__(

similarity_functions = {
"cosine": lambda anchors, positives, negatives: (
paired_cosine_distances(anchors, positives),
paired_cosine_distances(anchors, negatives),
pairwise_cos_sim(anchors, positives),
pairwise_cos_sim(anchors, negatives),
),
"dot": lambda anchors, positives, negatives: (
np.sum(anchors * positives, axis=-1),
np.sum(anchors * negatives, axis=-1),
pairwise_dot_score(anchors, positives),
pairwise_dot_score(anchors, negatives),
),
"manhattan": lambda anchors, positives, negatives: (
paired_manhattan_distances(anchors, positives),
paired_manhattan_distances(anchors, negatives),
pairwise_manhattan_sim(anchors, positives),
pairwise_manhattan_sim(anchors, negatives),
),
"euclidean": lambda anchors, positives, negatives: (
paired_euclidean_distances(anchors, positives),
paired_euclidean_distances(anchors, negatives),
pairwise_euclidean_sim(anchors, positives),
pairwise_euclidean_sim(anchors, negatives),
),
}

@@ -194,9 +227,9 @@
positive_scores, negative_scores = similarity_functions[fn_name](
embeddings_anchors, embeddings_positives, embeddings_negatives
)
accuracy = np.mean(positive_scores < negative_scores)
accuracy = (positive_scores > negative_scores + self.margin[fn_name]).float().mean().item()
metrics[f"{fn_name}_accuracy"] = accuracy
logger.info(f"Accuracy {fn_name.capitalize()} Distance:\t{accuracy:.2%}")
logger.info(f"Accuracy {fn_name.capitalize()} Similarity:\t{accuracy:.2%}")

if output_path is not None and self.write_csv:
csv_path = os.path.join(output_path, self.csv_file)
@@ -214,13 +247,13 @@ def __call__(
if len(self.similarity_fn_names) > 1:
metrics["max_accuracy"] = max(metrics.values())

if self.main_distance_function:
if self.main_similarity_function:
self.primary_metric = {
SimilarityFunction.COSINE: "cosine_accuracy",
SimilarityFunction.DOT_PRODUCT: "dot_accuracy",
SimilarityFunction.EUCLIDEAN: "euclidean_accuracy",
SimilarityFunction.MANHATTAN: "manhattan_accuracy",
}.get(self.main_distance_function)
}.get(self.main_similarity_function)
else:
if len(self.similarity_fn_names) > 1:
self.primary_metric = "max_accuracy"
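For reference, a minimal usage sketch of the parameters this diff introduces (``margin``, ``main_similarity_function``, ``similarity_fn_names``). The model checkpoint and sentences below are placeholders, not part of the commit:

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator

# Placeholder model; any SentenceTransformer checkpoint works the same way.
model = SentenceTransformer("all-MiniLM-L6-v2")

anchors = ["A person on a horse jumps over a broken down airplane."]
positives = ["A person is outdoors, on a horse."]
negatives = ["A person is at a diner, ordering an omelette."]

# A single float applies the same margin to every similarity function.
evaluator = TripletEvaluator(anchors, positives, negatives, margin=0.1, name="demo")
print(evaluator(model))

# A partial dict only overrides the listed functions; the others default to 0.
evaluator = TripletEvaluator(
    anchors,
    positives,
    negatives,
    margin={"cosine": 0.2, "dot": 5.0},
    similarity_fn_names=["cosine", "dot"],
    main_similarity_function="cosine",  # replaces the deprecated main_distance_function
    name="demo",
)
print(evaluator(model))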
36 changes: 36 additions & 0 deletions tests/evaluation/test_triplet_evaluator.py
@@ -0,0 +1,36 @@
"""
Tests the correct computation of evaluation scores from TripletEvaluator
"""

from __future__ import annotations

from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import TripletEvaluator


def test_TripletEvaluator(stsb_bert_tiny_model_reused: SentenceTransformer) -> None:
"""Tests that the TripletEvaluator can be loaded & used"""
model = stsb_bert_tiny_model_reused
anchors = [
"A person on a horse jumps over a broken down airplane.",
"Children smiling and waving at camera",
"A boy is jumping on skateboard in the middle of a red bridge.",
]
positives = [
"A person is outdoors, on a horse.",
"There are children looking at the camera.",
"The boy does a skateboarding trick.",
]
negatives = [
"A person is at a diner, ordering an omelette.",
"The kids are frowning",
"The boy skates down the sidewalk.",
]
evaluator = TripletEvaluator(anchors, positives, negatives, name="all_nli_dev")
metrics = evaluator(model)
assert evaluator.primary_metric == "all_nli_dev_cosine_accuracy"
assert metrics[evaluator.primary_metric] == 1.0

evaluator_with_margin = TripletEvaluator(anchors, positives, negatives, margin=0.7, name="all_nli_dev")
metrics = evaluator_with_margin(model)
assert metrics[evaluator.primary_metric] == 0.0
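
The margin=0.7 assertion follows directly from the accuracy rule added in this diff: a triplet only counts as correct when the positive similarity beats the negative similarity by more than the margin. A small sketch with made-up similarity scores:

import torch

# Illustrative scores only, not model output.
pos_sim = torch.tensor([0.90, 0.75, 0.80])
neg_sim = torch.tensor([0.40, 0.55, 0.60])

for margin in (0.0, 0.7):
    accuracy = (pos_sim > neg_sim + margin).float().mean().item()
    print(margin, accuracy)  # 0.0 -> 1.0; 0.7 -> 0.0 (no gap exceeds 0.7)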
