[Triplet Margin Loss] Issue 1118 #1120
base: main
File: optax/losses/_self_supervised.py
```diff
@@ -115,6 +115,72 @@ def ntxent(
   denom = jnp.sum(jnp.exp(xcs_shift_diffs), axis=1, keepdims=True)
   denom += numer_exp
   log_softm = numer - jnp.log(denom)
-  loss = -jnp.where(matches == 1, log_softm, 0.0).sum() / matches.sum()
+  loss = -jnp.where(matches == 1, log_softm, 0.0).sum()/matches.sum()

   return loss


+def triplet_loss(
+    anchors: chex.Array,
+    positives: chex.Array,
+    negatives: chex.Array,
+    axis: chex.Numeric = -1,
+    p: chex.Numeric = 2,
```
> **Review comment** (on `p: chex.Numeric = 2`): You may want to include the case …
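The comment above is truncated, but it attaches to the `p` argument, and the implementation below applies `jnp.sqrt` to the summed powers regardless of `p`, so only `p = 2` yields a true Minkowski distance. A hedged sketch of a distance helper that would honor arbitrary `p` (hypothetical naming, not code from this PR):

```python
import jax.numpy as jnp

def minkowski_distance(x, y, p=2, axis=-1, eps=1e-6):
  # General p-norm distance: (sum |x - y|^p)^(1/p).
  # For p = 2 this matches the Euclidean distance the PR computes.
  return jnp.power(jnp.power(jnp.abs(x - y), p).sum(axis) + eps, 1.0 / p)
```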
```diff
+    margin: chex.Numeric = 1.0,
+    eps: chex.Numeric = 1e-6,
+    reduction: str = 'none',
```
> **Review comment** (on `reduction: str = 'none'`): Remove the `reduction` option. No losses in optax reduce the losses after computation.
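To illustrate the convention the reviewer is invoking: optax losses return one value per example, and any reduction is applied at the call site. A minimal sketch, using the PR's function as written (with its default `reduction='none'`):

```python
import jax.numpy as jnp
from optax.losses import _self_supervised

anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])

# Per-triplet losses, shape [batch]; the caller reduces explicitly.
per_example = _self_supervised.triplet_loss(anchors, positives, negatives)
mean_loss = per_example.mean()  # what reduction='mean' would have done
sum_loss = per_example.sum()    # what reduction='sum' would have done
```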
```diff
+) -> chex.Array:
+  """Computes the triplet loss for a batch of embeddings.
+
+  Examples:
+    >>> import jax.numpy as jnp
+    >>> import optax
+    >>> anchors = jnp.array([[0.0, 0.0], [1.0, 1.0]])
+    >>> positives = jnp.array([[0.1, 0.1], [1.1, 1.1]])
+    >>> negatives = jnp.array([[1.0, 0.0], [0.0, 1.0]])
+    >>> output = optax.triplet_loss(anchors, positives, negatives, margin=1.0)
+    >>> print(output)
+    [0.14142442 0.14142442]
+
+  Args:
+    anchors: An array of anchor embeddings, with shape [batch, feature_dim].
```
> **Review comment** (on the `Args:` section): Add indents appropriately, like that: …
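The reviewer's example is cut off; presumably it shows Google-style continuation indentation, where wrapped lines of an argument description are indented two spaces past the argument name. A reconstruction of what that looks like:

```python
def example(positives):
  """Short summary line.

  Args:
    positives: An array of positive embeddings (similar to anchors),
      with shape [batch, feature_dim].
  """
```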
```diff
+    positives: An array of positive embeddings
+      (similar to anchors), with shape [batch, feature_dim].
+    negatives: An array of negative embeddings
+      (dissimilar to anchors), with shape [batch, feature_dim].
+    axis: The axis along which to compute the distances
+      (default is -1).
+    p: The norm degree for distance calculation
+      (default is 2, for Euclidean distance).
+    margin: The minimum margin by which the positive distance
+      should be smaller than the negative distance.
+    eps: A small epsilon value to ensure numerical stability
+      in the distance calculation.
+    reduction: Specifies the reduction to apply to the
+      output: 'none' | 'mean' | 'sum'.
+
+  Returns:
+    The computed triplet loss as an array or scalar,
+    depending on the reduction parameter.
+    If reduction is 'mean' or 'sum', returns a scalar.
+
+  References:
+    Learning shallow convolutional feature descriptors with triplet losses
```
> **Review comment** (on the `References:` section): Use the following formatting for references: …
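The reviewer's template is also cut off. Elsewhere in optax, references follow an author, linked-title, year pattern, so the intended form is presumably something like the following (reusing the PR's own title and URL):

```python
"""
References:
  V. Balntas et al., `Learning shallow convolutional feature descriptors
  with triplet losses
  <https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf>`_, 2016
"""
```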
```diff
+    by V. Balntas, E. Riba et al.
+    <https://bmva-archive.org.uk/bmvc/2016/papers/paper119/abstract119.pdf>
+  """
+  chex.assert_type([anchors], float)
```
> **Review comment** (on the `chex.assert_type` lines): Remove the three …
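If some type checking were kept at all, `chex.assert_type` also accepts a list of inputs, so the three calls could collapse into one. A hedged sketch (my reading of the truncated comment, not necessarily the reviewer's intent):

```python
import chex
import jax.numpy as jnp

anchors = jnp.zeros((2, 2))
positives = jnp.ones((2, 2))
negatives = jnp.ones((2, 2)) * 2

# One call checks all three arrays at once.
chex.assert_type([anchors, positives, negatives], float)
```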
```diff
+  chex.assert_type([positives], float)
+  chex.assert_type([negatives], float)
+  positive_distance = jnp.sqrt(
+      jnp.power(anchors - positives, p).sum(axis) + eps
+  )
+  negative_distance = jnp.sqrt(
+      jnp.power(anchors - negatives, p).sum(axis) + eps
+  )
+  loss = jnp.maximum(positive_distance - negative_distance + margin, 0)
+  if reduction == 'mean':
```
> **Review comment** (on the `reduction` branches): As said above, remove the reduction options.
```diff
+    return loss.mean()
+  elif reduction == 'sum':
+    return loss.sum()
+  else:
+    return loss
```
File: optax/losses/_self_supervised_test.py
```diff
@@ -14,12 +14,13 @@
 # ==============================================================================
 """Tests for self-supervised losses in `optax.losses._self_supervised.py`."""

-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 import chex
 import jax
 import jax.numpy as jnp
 import numpy as np
-from optax.losses import _self_supervised
+
+from optax.losses import _self_supervised


 class NtxentTest(chex.TestCase):
```
```diff
@@ -46,7 +47,6 @@ def setUp(self):

   @chex.all_variants
   def test_batched(self):
-    """Tests for a full batch."""
     np.testing.assert_allclose(
         self.variant(_self_supervised.ntxent)(self.ys, self.ts_1),
         self.exp_1,
```
```diff
@@ -65,6 +65,67 @@ def test_batched(self):
         atol=1e-4,
     )

+
+class TripletMarginLossTest(chex.TestCase, parameterized.TestCase):
+
+  def setUp(self):
```
> **Review comment** (on `TripletMarginLossTest`): Avoid using numerical values as expected returns. You may also add a test for some specific behaviors (like using `swap` here). Also, you should test this function under jit/vmap etc. (see the `chex.all_variants` utility in some other tests).
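For context on the `swap` behavior the reviewer mentions: in the distance-swap trick from Balntas et al. (2016), the negative distance is taken as the minimum of d(anchor, negative) and d(positive, negative), i.e. the harder negative. A hedged sketch of how the PR's loss might support it (hypothetical extension, not code from this PR):

```python
import jax.numpy as jnp

def triplet_loss_with_swap(anchors, positives, negatives,
                           margin=1.0, eps=1e-6, swap=False):
  # Euclidean distance along the feature axis, as in the PR.
  d = lambda x, y: jnp.sqrt(jnp.power(x - y, 2).sum(-1) + eps)
  positive_distance = d(anchors, positives)
  negative_distance = d(anchors, negatives)
  if swap:
    # Distance swap: use whichever of the two negatives is harder.
    negative_distance = jnp.minimum(negative_distance, d(positives, negatives))
  return jnp.maximum(positive_distance - negative_distance + margin, 0)
```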
```diff
+    super().setUp()
+    self.a1 = jnp.ones((2, 2))
+    self.p1 = jnp.zeros((2, 2))
+    self.n1 = jnp.ones((2, 2)) * 2
+    self.a2 = jnp.zeros((2, 2))
+    self.p2 = jnp.ones((2, 2))
+    self.n2 = jnp.ones((2, 2)) * 2
+
+  @chex.all_variants
+  @parameterized.parameters([
+      {
+          'anchor': jnp.ones((2, 2)),
+          'positive': jnp.zeros((2, 2)),
+          'negative': jnp.ones((2, 2)) * 2,
+          'margin': 1.0,
+      },
+      {
+          'anchor': jnp.zeros((2, 2)),
+          'positive': jnp.ones((2, 2)),
+          'negative': jnp.ones((2, 2)) * 2,
+          'margin': 1.0,
+      },
+  ])
```
```diff
+  def test_batched(self, anchor, positive, negative, margin):
+    def testing_triplet_loss(a, p, n, margin=1.0, p_norm=2, eps=1e-6):
+      ap_distance = jnp.sqrt(jnp.sum(jnp.power(a - p, p_norm)) + eps)
+      an_distance = jnp.sqrt(jnp.sum(jnp.power(a - n, p_norm)) + eps)
+      return jnp.maximum(ap_distance - an_distance + margin, 0)
+
+    handmade_result = testing_triplet_loss(
+        a=anchor, p=positive, n=negative, margin=margin
+    )
+    result = self.variant(_self_supervised.triplet_loss)(
+        anchor, positive, negative
+    )
+    np.testing.assert_allclose(result, handmade_result, atol=1e-4)
+
+  @chex.all_variants
+  @parameterized.parameters([
+      {
+          'anchor': jnp.ones((2, 2)),
+          'positive': jnp.zeros((2, 2)),
+          'negative': jnp.ones((2, 2)) * 2,
+      },
+  ])
+  def test_vmap(self, anchor, positive, negative):
+    original_loss = _self_supervised.triplet_loss(
+        anchor, positive, negative, reduction='none'
+    )
+    anchor_batched = anchor.reshape(1, *anchor.shape)
+    positive_batched = positive.reshape(1, *positive.shape)
+    negative_batched = negative.reshape(1, *negative.shape)
+    vmap_loss = self.variant(
+        jax.vmap(_self_supervised.triplet_loss, in_axes=(0, 0, 0))
+    )(anchor_batched, positive_batched, negative_batched)
+    np.testing.assert_allclose(
+        vmap_loss.flatten(), original_loss.flatten(), atol=1e-4
+    )
+
+
 if __name__ == '__main__':
   absltest.main()
```
> **Review comment** (suggested change): `axis: int = -1`
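Putting the review threads together, a hedged sketch of what the function might look like once the comments are addressed: `reduction` removed so callers reduce, and `axis` typed as a plain `int`. This is my synthesis, not the merged code:

```python
import chex
import jax.numpy as jnp

def triplet_loss_revised(
    anchors: chex.Array,
    positives: chex.Array,
    negatives: chex.Array,
    axis: int = -1,
    p: chex.Numeric = 2,
    margin: chex.Numeric = 1.0,
    eps: chex.Numeric = 1e-6,
) -> chex.Array:
  # Returns one loss per triplet; the caller applies mean/sum as needed.
  positive_distance = jnp.sqrt(jnp.power(anchors - positives, p).sum(axis) + eps)
  negative_distance = jnp.sqrt(jnp.power(anchors - negatives, p).sum(axis) + eps)
  return jnp.maximum(positive_distance - negative_distance + margin, 0)
```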