From e22f71b8fbc817a4e95ad88b061937bfe6b5a0ce Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Mon, 28 Oct 2024 18:09:33 -0700 Subject: [PATCH] remove test that leaked jax tracers PiperOrigin-RevId: 690815980 --- optax/_src/utils_test.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/optax/_src/utils_test.py b/optax/_src/utils_test.py index 658dc9a2..619f21a2 100644 --- a/optax/_src/utils_test.py +++ b/optax/_src/utils_test.py @@ -14,8 +14,6 @@ # ============================================================================== """Tests for `utils.py`.""" -from unittest import mock - from absl.testing import absltest from absl.testing import parameterized import chex @@ -38,23 +36,6 @@ def _shape_to_tuple(shape): class ScaleGradientTest(parameterized.TestCase): - @parameterized.product( - inputs=[-1.0, 0.0, 1.0], scale=[-0.5, 0.0, 0.5, 1.0, 2.0] - ) - @mock.patch.object(jax.lax, 'stop_gradient', wraps=jax.lax.stop_gradient) - def test_scale_gradient(self, mock_sg, inputs, scale): - def fn(inputs): - outputs = utils.scale_gradient(inputs, scale) - return outputs**2 - - grad = jax.grad(fn) - self.assertEqual(grad(inputs), 2 * inputs * scale) - if scale == 0.0: - mock_sg.assert_called_once_with(inputs) - else: - self.assertFalse(mock_sg.called) - self.assertEqual(fn(inputs), inputs**2) - @parameterized.product(scale=[-0.5, 0.0, 0.5, 1.0, 2.0]) def test_scale_gradient_pytree(self, scale): def fn(inputs):