Skip to content

Commit

Permalink
Move compiler APIs out of dispatch.py and xla_bridge.py into a new ja…
Browse files Browse the repository at this point in the history
…x._src.compiler module.

Refactoring only, no user-visible changes intended.

PiperOrigin-RevId: 554845555
  • Loading branch information
hawkinsp authored and ChexDev committed Aug 11, 2023
1 parent 4288c51 commit 5845d27
Showing 1 changed file with 12 additions and 5 deletions.
17 changes: 12 additions & 5 deletions chex/_src/restrict_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@
import functools
from typing import Optional, Sequence

import jax._src.dispatch as jax_dispatch
# pylint: disable=g-import-not-at-top
try:
from jax._src import compiler
except ImportError:
# TODO(phawkins): remove this path after jax>=0.4.15 is the minimum version
# required by chex.
from jax._src import dispatch as compiler # type: ignore
# pylint: enable=g-import-not-at-top


class RestrictedBackendError(RuntimeError):
Expand Down Expand Up @@ -73,7 +80,7 @@ def is_allowed(backend_platform):
return ((backend_platform in allowed) if allowed is not None else
(backend_platform not in forbidden))

inner_backend_compile = jax_dispatch.backend_compile
inner_backend_compile = compiler.backend_compile

@functools.wraps(inner_backend_compile)
def wrapper(backend, *args, **kwargs):
Expand All @@ -84,9 +91,9 @@ def wrapper(backend, *args, **kwargs):
return inner_backend_compile(backend, *args, **kwargs)

try:
jax_dispatch.backend_compile = wrapper
compiler.backend_compile = wrapper
yield
finally:
backend_compile = jax_dispatch.backend_compile
backend_compile = compiler.backend_compile
assert backend_compile is wrapper, backend_compile
jax_dispatch.backend_compile = inner_backend_compile
compiler.backend_compile = inner_backend_compile

0 comments on commit 5845d27

Please sign in to comment.