Skip to content

Commit

Permalink
Use only 'jax_disable_jit' in fake_jit.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 587044786
  • Loading branch information
hbq1 authored and ChexDev committed Dec 1, 2023
1 parent c781a9d commit cbac0f9
Showing 1 changed file with 1 addition and 22 deletions.
23 changes: 1 addition & 22 deletions chex/_src/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,6 @@ def convert_to_varargs(sig, *args, **kwargs):
return bound_args.args


@functools.wraps(jax.jit)
def _fake_jit(fn, *unused_args, **unused_kwargs):
return fn


def _ignore_axis_index_groups(fn):
"""Wrapper that forces axis_index_groups to be None.
Expand Down Expand Up @@ -256,23 +251,7 @@ def foo(x):
such as `jax.lax.scan`, etc.
"""
stack = FakeContext()
if enable_patching:
stack.enter_context(mock.patch('jax.jit', _fake_jit))

# Some functions like jax.lax.scan also internally use jit. Most respect
# the config setting `jax_disable_jit` and replace its implementation
# with a dummy, jit-free one if the setting is one. Use this mechanism too.
@contextlib.contextmanager
def _jax_disable_jit():
original_value = jax.config.jax_disable_jit
jax.config.update('jax_disable_jit', True)
try:
yield
finally:
jax.config.update('jax_disable_jit', original_value)

stack.enter_context(_jax_disable_jit())

stack.enter_context(jax.disable_jit(disable=enable_patching))
return stack


Expand Down

0 comments on commit cbac0f9

Please sign in to comment.