From cbac0f92585e971469938bd162ab7a09385f2e29 Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Fri, 1 Dec 2023 10:14:44 -0800 Subject: [PATCH] Use only 'jax_disable_jit' in `fake_jit`. PiperOrigin-RevId: 587044786 --- chex/_src/fake.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/chex/_src/fake.py b/chex/_src/fake.py index b9c3f9ba..a6fdc74e 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -88,11 +88,6 @@ def convert_to_varargs(sig, *args, **kwargs): return bound_args.args -@functools.wraps(jax.jit) -def _fake_jit(fn, *unused_args, **unused_kwargs): - return fn - - def _ignore_axis_index_groups(fn): """Wrapper that forces axis_index_groups to be None. @@ -256,23 +251,7 @@ def foo(x): such as `jax.lax.scan`, etc. """ stack = FakeContext() - if enable_patching: - stack.enter_context(mock.patch('jax.jit', _fake_jit)) - - # Some functions like jax.lax.scan also internally use jit. Most respect - # the config setting `jax_disable_jit` and replace its implementation - # with a dummy, jit-free one if the setting is one. Use this mechanism too. - @contextlib.contextmanager - def _jax_disable_jit(): - original_value = jax.config.jax_disable_jit - jax.config.update('jax_disable_jit', True) - try: - yield - finally: - jax.config.update('jax_disable_jit', original_value) - - stack.enter_context(_jax_disable_jit()) - + stack.enter_context(jax.disable_jit(disable=enable_patching)) return stack