Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update libraries to use JAX's limited (and ill-advised) trace-state-querying APIs rather than depending on JAX's deeper internals, which are about to change. #4219

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions flax/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,8 @@
from .tracers import (
check_trace_level as check_trace_level,
current_trace as current_trace,
trace_level as trace_level,
)

from flax.typing import (
Array as Array,
)
)
2 changes: 1 addition & 1 deletion flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def __init__(
self.flags = freeze({} if flags is None else flags)

self._root = parent.root if parent else None
self.trace_level = tracers.trace_level(tracers.current_trace())
self.trace_level = tracers.current_trace()

self.rng_counters = {key: 0 for key in self.rngs}
self.reservations = collections.defaultdict(set)
Expand Down
19 changes: 9 additions & 10 deletions flax/core/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@


def current_trace():
"""Returns the innermost Jax tracer."""
return jax.core.find_top_trace(())


def trace_level(main):
"""Returns the level of the trace of -infinity if it is None."""
if main:
return main.level
return float('-inf')
"""Returns the current JAX state tracer."""
if jax.__version_info__ <= (0, 4, 33):
top = jax.core.find_top_trace(())
if top:
return top.level
else:
return float('-inf')

return jax.core.get_opaque_trace_state(convention="flax")

def check_trace_level(base_level):
level = trace_level(current_trace())
level = current_trace()
if level != base_level:
raise errors.JaxTransformError()
20 changes: 14 additions & 6 deletions flax/nnx/tracers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,16 @@
# Taken from flax/core/tracer.py 🏴‍☠️


from jax.core import MainTrace, thread_local_state
from jax.core import get_opaque_trace_state, OpaqueTraceState

from flax.nnx import reprlib


def current_jax_trace() -> MainTrace:
"""Returns the innermost Jax tracer."""
return thread_local_state.trace_state.trace_stack.dynamic
def current_jax_trace() -> OpaqueTraceState:
"""Returns the Jax tracing state."""
if jax.__version_info__ <= (0, 4, 33):
return thread_local_state.trace_state.trace_stack.dynamic
return get_opaque_trace_state(convention="nnx")


class TraceState(reprlib.Representable):
Expand All @@ -36,7 +38,10 @@ def jax_trace(self):
return self._jax_trace

def is_valid(self) -> bool:
return self._jax_trace is current_jax_trace()
if jax.__version_info__ <= (0, 4, 33):
return self._jax_trace is current_jax_trace()

return self._jax_trace == current_jax_trace()

def __nnx_repr__(self):
yield reprlib.Object(f'{type(self).__name__}')
Expand All @@ -52,4 +57,7 @@ def __treescope_repr__(self, path, subtree_renderer):
)

def __eq__(self, other):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace
if jax.__version_info__ <= (0, 4, 33):
return isinstance(other, TraceState) and self._jax_trace is other._jax_trace

return isinstance(other, TraceState) and self._jax_trace == other._jax_trace
Loading