From 54a1ebaa1c8387cd34114f4e0901e628bec92504 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 1 Aug 2024 23:59:53 -0700 Subject: [PATCH] Remove uses of deprecated xb, xc, and xe abbreviation from jax.interpreters.xla PiperOrigin-RevId: 658688048 --- haiku/_src/transform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/haiku/_src/transform.py b/haiku/_src/transform.py index 2365159b8..3c48da325 100644 --- a/haiku/_src/transform.py +++ b/haiku/_src/transform.py @@ -338,8 +338,8 @@ def transform(f, *, apply_rng=True) -> Transformed: return without_state(transform_with_state(f)) -COMPILED_FN_TYPES = (jax.interpreters.xla.xe.PjitFunction, - jax.interpreters.xla.xe.PmapFunction) # pytype: disable=name-error +COMPILED_FN_TYPES = (jax.lib.xla_extension.PjitFunction, + jax.lib.xla_extension.PmapFunction) # pytype: disable=name-error def check_not_jax_transformed(f):