diff --git a/haiku/_src/transform.py b/haiku/_src/transform.py index 2365159b8..3c48da325 100644 --- a/haiku/_src/transform.py +++ b/haiku/_src/transform.py @@ -338,8 +338,8 @@ def transform(f, *, apply_rng=True) -> Transformed: return without_state(transform_with_state(f)) -COMPILED_FN_TYPES = (jax.interpreters.xla.xe.PjitFunction, - jax.interpreters.xla.xe.PmapFunction) # pytype: disable=name-error +COMPILED_FN_TYPES = (jax.lib.xla_extension.PjitFunction, + jax.lib.xla_extension.PmapFunction) # pytype: disable=name-error def check_not_jax_transformed(f):