def _pairwise_distance(
    x: chex.Array, y: chex.Array, p: int = 2, eps: float = 1e-6
) -> chex.Array:
  """Computes an eps-smoothed p-norm distance over the last axis.

  Args:
    x: First array.
    y: Second array; ``x - y`` must be a valid (broadcastable) subtraction.
    p: The norm degree.
    eps: Small constant added to every ``|x - y| ** p`` term before the sum,
      so the summed value is offset by ``eps`` times the size of the last
      axis.

  Returns:
    ``sum(|x - y| ** p + eps, axis=-1) ** (1 / p)``, i.e. the input shape
    with the last axis reduced away.
  """
  diff = x - y
  dist = jnp.sum(jnp.abs(diff) ** p + eps, axis=-1) ** (1.0 / p)
  return dist


def triplet_margin_loss(
    anchor: chex.Array,
    positive: chex.Array,
    negative: chex.Array,
    *,
    margin: float = 1.0,
    p: int = 2,
    eps: float = 1e-6,
    swap: bool = False,
    reduction: str = 'mean',
) -> chex.Array:
  """Triplet margin loss function.

  Measures the relative similarity between an anchor point, a positive point,
  and a negative point using the distance metric specified by p-norm. The loss
  encourages the distance between the anchor and positive points to be smaller
  than the distance between the anchor and negative points by at least the
  margin amount.

  Args:
    anchor: The anchor embeddings. Shape: [batch_size, feature_dim].
    positive: The positive embeddings. Shape: [batch_size, feature_dim].
    negative: The negative embeddings. Shape: [batch_size, feature_dim].
    margin: The margin value. Default: 1.0.
    p: The norm degree for pairwise distance. Default: 2.
    eps: Small epsilon value to avoid numerical issues. Default: 1e-6.
    swap: Use the distance swap optimization from "Learning shallow
      convolutional feature descriptors with triplet losses" by V. Balntas
      et al. When enabled, the negative distance is the smaller of the
      anchor-negative and positive-negative distances. Default: False.
    reduction: Specifies the reduction to apply to the output:
      'none' | 'mean' | 'sum'. Default: 'mean'.

  Returns:
    The triplet margin loss value.
    If reduction is 'none': tensor of shape [batch_size]
    If reduction is 'mean' or 'sum': scalar tensor.

  Raises:
    ValueError: If any input is not 2D, or if ``reduction`` is not one of
      'none', 'mean' or 'sum'.
  """
  chex.assert_equal_shape([anchor, positive, negative])

  if any(arr.ndim != 2 for arr in (anchor, positive, negative)):
    raise ValueError("Inputs must be 2D tensors")

  # Reject an unknown reduction mode before doing any numerical work.
  if reduction not in ('none', 'mean', 'sum'):
    raise ValueError(f"Invalid reduction mode: {reduction}")

  # Calculate distances between pairs
  dist_pos = _pairwise_distance(anchor, positive, p, eps)
  dist_neg = _pairwise_distance(anchor, negative, p, eps)

  # Implement distance swap if enabled
  if swap:
    # The swap distance must use the same norm degree and epsilon as the
    # other two distances; otherwise the minimum below compares distances
    # measured under different metrics.
    dist_swap = _pairwise_distance(positive, negative, p, eps)
    dist_neg = jnp.minimum(dist_neg, dist_swap)

  # Calculate loss with margin
  losses = jnp.maximum(margin + dist_pos - dist_neg, 0.0)

  # Apply reduction
  if reduction == 'none':
    return losses
  if reduction == 'mean':
    return jnp.mean(losses)
  return jnp.sum(losses)