Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add unsafe_graph_cache decorator #4356

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from .graph import split_context as split_context
from .graph import MergeContext as MergeContext
from .graph import merge_context as merge_context
from .graph import unsafe_graph_cache as unsafe_graph_cache
from .graph import variables as variables
from .nn import initializers as initializers
from .nn.activations import celu as celu
Expand Down
173 changes: 145 additions & 28 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import functools
import threading
import typing as tp
import weakref

import jax
import numpy as np
Expand Down Expand Up @@ -727,6 +728,11 @@ def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
# --------------------------------------------------------
# UpdateContext
# --------------------------------------------------------
# TODO(cgarciae): UpdateContext.split/merge are only used in `deprecated.py`,
# there might be some opportunities for simplification when its finally removed,
# either by simply deleting these methods or moving the implementation of
# SplitContext.split/MergeContext.merge into UpdateContext.split/merge. Maybe
# even removing Split/MergeContext entirely.

@dataclasses.dataclass
class GraphContext(threading.local):
Expand All @@ -735,15 +741,31 @@ class GraphContext(threading.local):
)
ref_index_stack: list[SplitContext] = dataclasses.field(default_factory=list)
index_ref_stack: list[MergeContext] = dataclasses.field(default_factory=list)
unsafe_graph_cache_stack: list[UnsafeGraphCacheContext] = dataclasses.field(
default_factory=list
)


GRAPH_CONTEXT = GraphContext()

def _get_variables_at_paths(variables: dict, node, state: State):
node_impl = get_node_impl(node)
node_children = node_impl.node_dict(node)
for key, child_state in state.items():
child_node = node_children[key]
if isinstance(child_node, Variable):
variables[key] = child_node
else:
variables[key] = _get_variables_at_paths({}, node, child_state)

return variables


@dataclasses.dataclass
class SplitContext:
ctxtag: str | None
ref_index: RefMap[tp.Any, Index]
ref_count: RefMap[tp.Any, int]

@tp.overload
def split(self, graph_node: A, /) -> tuple[GraphDef[A], GraphState]: ...
Expand All @@ -766,67 +788,120 @@ def split(
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
graphdef, state = flatten(node, self.ref_index)
states = _split_state(state, filters)
if ctx is not None:
if ctx.index_ref is not None and isinstance(graphdef, NodeDef):

def _split_node():
graphdef, state = flatten(node, self.ref_index)
states = _split_state(state, filters)

if (
ctx is not None
and ctx.index_ref is not None
and isinstance(graphdef, NodeDef)
):
index_to_index = compose_mapping(ctx.index_ref, self.ref_index)
graphdef = dataclasses.replace(
graphdef, index_mapping=FrozenDict(index_to_index)
)

return graphdef, states

if ctx is not None and ctx.graph_cache_context is not None:
current_ref_count = self.ref_count.setdefault(node, 0)
split_caches = ctx.graph_cache_context.split_caches.setdefault(node, [])
if len(split_caches) > current_ref_count:
graphdef, variable_states = split_caches[current_ref_count]
states = jax.tree.map(Variable.to_state, variable_states)
else:
graphdef, states = _split_node()
variable_states = tuple(
State(_get_variables_at_paths({}, node, state)) for state in states
)
split_caches[current_ref_count] = (graphdef, variable_states)
self.ref_count[node] = current_ref_count + 1
else:
graphdef, states = _split_node()

return graphdef, *states


@contextlib.contextmanager
def split_context(ctxtag: str | None = None):
index_ref: RefMap[tp.Any, Index] = RefMap()
flatten_ctx = SplitContext(ctxtag, index_ref)
GRAPH_CONTEXT.ref_index_stack.append(flatten_ctx)
split_ctx = SplitContext(ctxtag, index_ref, RefMap())
GRAPH_CONTEXT.ref_index_stack.append(split_ctx)

try:
yield flatten_ctx
yield split_ctx
finally:
GRAPH_CONTEXT.ref_index_stack.pop()
if ctxtag is not None:
ctx = current_update_context(ctxtag)
ctx.flatten_end(index_ref)
del flatten_ctx.ref_index
del flatten_ctx.ctxtag
del split_ctx.ref_index
del split_ctx.ctxtag


@dataclasses.dataclass
class MergeContext:
ctxtag: str | None
index_ref: dict[Index, tp.Any]
index_count: dict[Index, int]

def merge(
self, graphdef: GraphDef[A], state: GraphState, /, *states: GraphState
) -> A:
states = (state, *states)
ctx = (
current_update_context(self.ctxtag) if self.ctxtag is not None else None
)
if (
ctx is not None
and isinstance(graphdef, NodeDef)
and graphdef.index_mapping is not None
):
# outer merge (4), create index_ref_cache
assert ctx.ref_index is not None
index_ref_cache = compose_mapping_reversed(
ctx.ref_index, graphdef.index_mapping

def _merge_node():
if (
ctx is not None
and isinstance(graphdef, NodeDef)
and graphdef.index_mapping is not None
):
# outer merge (4), create index_ref_cache
assert ctx.ref_index is not None
index_ref_cache = compose_mapping_reversed(
ctx.ref_index, graphdef.index_mapping
)
else:
# inner merge (2)
index_ref_cache = None

state = State.merge(*states)
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
index_ref_cache=index_ref_cache,
)
return node

if ctx is not None and ctx.graph_cache_context is not None:
current_index_count = self.index_count.setdefault(graphdef.index, 0)
merge_index_ref = ctx.graph_cache_context.merge_index_ref
if graphdef.index in merge_index_ref:
node = merge_index_ref[graphdef.index]
else:
node = merge_index_ref[graphdef.index] = _merge_node()
merge_caches = ctx.graph_cache_context.merge_caches.setdefault(node, [])
if len(merge_caches) > current_index_count:
variable_states = merge_caches[current_index_count]
else:
variable_states = tuple(
State(_get_variables_at_paths({}, node, state)) for state in states
)

def _update_variable(variable: Variable, variable_state: VariableState):
variable.raw_value = variable_state.value

jax.tree.map(_update_variable, variable_states, states)
self.index_count[graphdef.index] = current_index_count + 1
else:
# inner merge (2)
index_ref_cache = None
node = _merge_node()

state = State.merge(state, *states)
node = unflatten(
graphdef,
state,
index_ref=self.index_ref,
index_ref_cache=index_ref_cache,
)
return node


Expand Down Expand Up @@ -855,6 +930,7 @@ class UpdateContext:
tag: str
ref_index: RefMap[tp.Any, Index] | None
index_ref: dict[Index, tp.Any] | None
graph_cache_context: UnsafeGraphCacheContext | None

# define hash and eq to make this an opaque object
def __hash__(self):
Expand Down Expand Up @@ -1012,7 +1088,17 @@ class UpdateContextManager:
tag: str

def __enter__(self):
ctx = UpdateContext(self.tag, None, None)
if GRAPH_CONTEXT.unsafe_graph_cache_stack:
if len(GRAPH_CONTEXT.unsafe_graph_cache_stack) > 1:
raise ValueError(
f'Found {len(GRAPH_CONTEXT.unsafe_graph_cache_stack)} unsafe_graph_cache contexts '
'but only expected 1.'
)
graph_cache_context = GRAPH_CONTEXT.unsafe_graph_cache_stack.pop()
else:
graph_cache_context = None

ctx = UpdateContext(self.tag, None, None, graph_cache_context)
if self.tag not in GRAPH_CONTEXT.update_context_stacks:
GRAPH_CONTEXT.update_context_stacks[self.tag] = [ctx]
else:
Expand Down Expand Up @@ -1145,6 +1231,37 @@ def current_update_context(tag: str) -> UpdateContext:
return GRAPH_CONTEXT.update_context_stacks[tag][-1]


@dataclasses.dataclass
class UnsafeGraphCacheContext:
split_caches: weakref.WeakKeyDictionary[
tp.Any, list[tuple[GraphDef, tuple[State, ...]]]
]
merge_index_ref: weakref.WeakValueDictionary[Index, tp.Any]
merge_caches: weakref.WeakKeyDictionary[tp.Any, list[tuple[State, ...]]]


def unsafe_graph_cache(f):
split_caches = weakref.WeakKeyDictionary()
merge_index_ref = weakref.WeakValueDictionary()
merge_caches = weakref.WeakKeyDictionary()

@functools.wraps(f)
def _unsafe_graph_wrapper(*args, **kwargs):
GRAPH_CONTEXT.unsafe_graph_cache_stack.append(
UnsafeGraphCacheContext(split_caches, merge_index_ref, merge_caches)
)
try:
return f(*args, **kwargs)
finally:
if GRAPH_CONTEXT.unsafe_graph_cache_stack:
raise ValueError(
"unsafe_graph_cache's context was not consumed by an underlying update_context. "
'This likely means no nnx transform is being wrapped'
)

return _unsafe_graph_wrapper


# --------------------------------------------------------
# Functional API
# --------------------------------------------------------
Expand Down
Loading