diff --git a/haiku/_src/base.py b/haiku/_src/base.py index f44c9d5b5..66b86ef6f 100644 --- a/haiku/_src/base.py +++ b/haiku/_src/base.py @@ -70,12 +70,7 @@ class JaxTraceLevel(NamedTuple): @classmethod def current(cls): - # TODO(tomhennigan): Remove once a version of JAX is released incl PR#9423. - trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack - top_type = trace_stack[0].trace_type - level = trace_stack[-1].level - sublevel = jax_core.cur_sublevel() - return JaxTraceLevel(opaque=(top_type, level, sublevel)) + return JaxTraceLevel(opaque=jax_core.get_opaque_trace_state()) frame_ids = it.count()