Skip to content

Commit

Permalink
indentation changes
Browse files Browse the repository at this point in the history
  • Loading branch information
cvnad1 committed Oct 26, 2024
1 parent 34a8610 commit 1778d05
Showing 1 changed file with 63 additions and 64 deletions.
127 changes: 63 additions & 64 deletions optax/losses/_self_supervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,70 +89,69 @@ def ntxent(


def _pairwise_distance(x: chex.Array, y: chex.Array, p: int = 2, eps: float = 1e-6) -> chex.Array:
diff = x - y
dist = jnp.sum(jnp.abs(diff) ** p + eps, axis=-1) ** (1.0 / p)
return dist
diff = x - y
dist = jnp.sum(jnp.abs(diff) ** p + eps, axis=-1) ** (1.0 / p)
return dist


def triplet_margin_loss(
anchor: chex.Array,
positive: chex.Array,
negative: chex.Array,
*,
margin: float = 1.0,
p: int = 2,
eps: float = 1e-6,
swap: bool = False,
reduction: str = 'mean',
anchor: chex.Array,
positive: chex.Array,
negative: chex.Array,
*,
margin: float = 1.0,
p: int = 2,
eps: float = 1e-6,
swap: bool = False,
reduction: str = 'mean',
) -> chex.Array:
"""Triplet margin loss function.
Measures the relative similarity between an anchor point, a positive point, and
a negative point using the distance metric specified by p-norm. The loss encourages
the distance between the anchor and positive points to be smaller than the distance
between the anchor and negative points by at least the margin amount.
Args:
anchor: The anchor embeddings. Shape: [batch_size, feature_dim].
positive: The positive embeddings. Shape: [batch_size, feature_dim].
negative: The negative embeddings. Shape: [batch_size, feature_dim].
margin: The margin value. Default: 1.0.
p: The norm degree for pairwise distance. Default: 2.
eps: Small epsilon value to avoid numerical issues. Default: 1e-6.
swap: Use the distance swap optimization from "Learning shallow convolutional
feature descriptors with triplet losses" by V. Balntas et al. Default: False.
reduction: Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'. Default: 'mean'.
Returns:
The triplet margin loss value.
If reduction is 'none': tensor of shape [batch_size]
If reduction is 'mean' or 'sum': scalar tensor.
"""
chex.assert_equal_shape([anchor, positive, negative])

if not(anchor.ndim == 2 and positive.ndim == 2 and negative.ndim == 2):
raise ValueError("Inputs must be 2D tensors")

# Calculate distances between pairs
dist_pos = _pairwise_distance(anchor, positive, p, eps)
dist_neg = _pairwise_distance(anchor, negative, p, eps)

# Implement distance swap if enabled
if swap:
dist_swap = _pairwise_distance(positive, negative)
dist_neg = jnp.minimum(dist_neg, dist_swap)

# Calculate loss with margin
losses = jnp.maximum(margin + dist_pos - dist_neg, 0.0)

# Apply reduction
if reduction == 'none':
return losses
elif reduction == 'mean':
return jnp.mean(losses)
elif reduction == 'sum':
return jnp.sum(losses)
else:
raise ValueError(f"Invalid reduction mode: {reduction}")

"""Triplet margin loss function.
Measures the relative similarity between an anchor point, a positive point, and
a negative point using the distance metric specified by p-norm. The loss encourages
the distance between the anchor and positive points to be smaller than the distance
between the anchor and negative points by at least the margin amount.
Args:
anchor: The anchor embeddings. Shape: [batch_size, feature_dim].
positive: The positive embeddings. Shape: [batch_size, feature_dim].
negative: The negative embeddings. Shape: [batch_size, feature_dim].
margin: The margin value. Default: 1.0.
p: The norm degree for pairwise distance. Default: 2.
eps: Small epsilon value to avoid numerical issues. Default: 1e-6.
swap: Use the distance swap optimization from "Learning shallow convolutional
feature descriptors with triplet losses" by V. Balntas et al. Default: False.
reduction: Specifies the reduction to apply to the output:
'none' | 'mean' | 'sum'. Default: 'mean'.
Returns:
The triplet margin loss value.
If reduction is 'none': tensor of shape [batch_size]
If reduction is 'mean' or 'sum': scalar tensor.
"""
chex.assert_equal_shape([anchor, positive, negative])

if not(anchor.ndim == 2 and positive.ndim == 2 and negative.ndim == 2):
raise ValueError("Inputs must be 2D tensors")

# Calculate distances between pairs
dist_pos = _pairwise_distance(anchor, positive, p, eps)
dist_neg = _pairwise_distance(anchor, negative, p, eps)

# Implement distance swap if enabled
if swap:
dist_swap = _pairwise_distance(positive, negative)
dist_neg = jnp.minimum(dist_neg, dist_swap)

# Calculate loss with margin
losses = jnp.maximum(margin + dist_pos - dist_neg, 0.0)

# Apply reduction
if reduction == 'none':
return losses
elif reduction == 'mean':
return jnp.mean(losses)
elif reduction == 'sum':
return jnp.sum(losses)
else:
raise ValueError(f"Invalid reduction mode: {reduction}")

0 comments on commit 1778d05

Please sign in to comment.