Skip to content

Commit

Permalink
Hide JAX's internal tracing state and update libraries to use limited…
Browse files Browse the repository at this point in the history
… trace-state-querying APIs as needed. This is prep work for stackless which will change those internals while preserving the API.

PiperOrigin-RevId: 677843398
  • Loading branch information
dougalm authored and copybara-github committed Sep 24, 2024
1 parent 4773949 commit 2f4bf3c
Showing 1 changed file with 2 additions and 6 deletions.
8 changes: 2 additions & 6 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,8 @@ class JaxTraceLevel(NamedTuple):

@classmethod
def current(cls):
# TODO(tomhennigan): Remove once a version of JAX is released incl PR#9423.
trace_stack = jax_core.thread_local_state.trace_state.trace_stack.stack
top_type = trace_stack[0].trace_type
level = trace_stack[-1].level
sublevel = jax_core.cur_sublevel()
return JaxTraceLevel(opaque=(top_type, level, sublevel))
ts = jax_core.get_opaque_trace_state(convention="haiku")
return JaxTraceLevel(opaque=ts)

frame_ids = it.count()

Expand Down

0 comments on commit 2f4bf3c

Please sign in to comment.