Skip to content

Commit

Permalink
remove test that leaked jax tracers
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 690815980
  • Loading branch information
mattjj authored and OptaxDev committed Oct 29, 2024
1 parent b8c2e13 commit e22f71b
Showing 1 changed file with 0 additions and 19 deletions.
19 changes: 0 additions & 19 deletions optax/_src/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# ==============================================================================
"""Tests for `utils.py`."""

from unittest import mock

from absl.testing import absltest
from absl.testing import parameterized
import chex
Expand All @@ -38,23 +36,6 @@ def _shape_to_tuple(shape):

class ScaleGradientTest(parameterized.TestCase):

@parameterized.product(
inputs=[-1.0, 0.0, 1.0], scale=[-0.5, 0.0, 0.5, 1.0, 2.0]
)
@mock.patch.object(jax.lax, 'stop_gradient', wraps=jax.lax.stop_gradient)
def test_scale_gradient(self, mock_sg, inputs, scale):
def fn(inputs):
outputs = utils.scale_gradient(inputs, scale)
return outputs**2

grad = jax.grad(fn)
self.assertEqual(grad(inputs), 2 * inputs * scale)
if scale == 0.0:
mock_sg.assert_called_once_with(inputs)
else:
self.assertFalse(mock_sg.called)
self.assertEqual(fn(inputs), inputs**2)

@parameterized.product(scale=[-0.5, 0.0, 0.5, 1.0, 2.0])
def test_scale_gradient_pytree(self, scale):
def fn(inputs):
Expand Down

0 comments on commit e22f71b

Please sign in to comment.