Skip to content

Commit

Permalink
[nnx] support pure dicts
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 1, 2024
1 parent e4dad9c commit 843a3dc
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 63 deletions.
1 change: 1 addition & 0 deletions examples/nnx_toy_examples/02_lifted_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_step(model: MLP, batch):
total_steps = 10_000
for step, batch in enumerate(dataset(32)):
train_step(model, optimizer, batch)
print(nnx.graph.GRAPH_CONTEXT)

if step % 1000 == 0:
logs = test_step(model, (X, Y))
Expand Down
11 changes: 10 additions & 1 deletion flax/nnx/filterlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,21 @@ def filters_to_predicates(filters: tuple[Filter, ...]) -> tuple[Predicate, ...]:
)
return tuple(map(to_predicate, filters))


class HasTag(tp.Protocol):
  """Structural type for any object that exposes a string ``tag`` attribute."""

  # The tag value carried by the object.
  tag: str


def _has_tag(x: tp.Any) -> tp.TypeGuard[HasTag]:
return hasattr(x, 'tag')


@dataclasses.dataclass(frozen=True)
class WithTag:
tag: str

def __call__(self, path: PathParts, x: tp.Any):
return hasattr(x, 'tag') and x.tag == self.tag
return _has_tag(x) and x.tag == self.tag

def __repr__(self):
return f'WithTag({self.tag!r})'
Expand Down
142 changes: 90 additions & 52 deletions flax/nnx/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from flax.nnx.statelib import FlatState, State
from flax.nnx import variablelib
from flax.nnx.variablelib import Variable, VariableState
from flax.typing import Key, PathParts
from flax.typing import Key, PathParts, is_key_like

A = tp.TypeVar('A')
B = tp.TypeVar('B')
Expand All @@ -43,6 +43,7 @@

HA = tp.TypeVar('HA', bound=tp.Hashable)
HB = tp.TypeVar('HB', bound=tp.Hashable)
KeyT = tp.TypeVar('KeyT', bound=Key)

Index = int
Names = tp.Sequence[int]
Expand Down Expand Up @@ -241,6 +242,35 @@ def __treescope_repr__(self, path, subtree_renderer):

jax.tree_util.register_static(NodeRef)

@dataclasses.dataclass(frozen=True, repr=False)
class VariableDef(reprlib.Representable):
  """Frozen record describing one Variable leaf of a graph node.

  Holds the Variable class, the integer index assigned to the variable,
  and a frozen mapping of the variable's metadata.
  """

  type: type[Variable]
  index: int
  metadata: FrozenDict[str, tp.Any]

  def __nnx_repr__(self):
    """Yield the reprlib header followed by one ``Attr`` per field."""
    yield reprlib.Object(type=type(self))
    shown_fields = (
      # The class is shown by name only.
      ('type', self.type.__name__),
      ('index', self.index),
      ('metadata', reprlib.PrettyMapping(self.metadata)),
    )
    for field_name, field_value in shown_fields:
      yield reprlib.Attr(field_name, field_value)

  def __treescope_repr__(self, path, subtree_renderer):
    """Render this object through treescope's object-constructor renderer."""
    # Imported here rather than at module level, so treescope is only
    # needed when this method is actually called.
    import treescope  # type: ignore[import-not-found,import-untyped]

    rendered_attributes = {
      'type': self.type,
      'index': self.index,
      'metadata': self.metadata,
    }
    return treescope.repr_lib.render_object_constructor(
      object_type=type(self),
      attributes=rendered_attributes,
      path=path,
      subtree_renderer=subtree_renderer,
    )


jax.tree_util.register_static(VariableDef)


@dataclasses.dataclass(frozen=True, repr=False)
class NodeDef(GraphDef[Node], reprlib.Representable):
Expand All @@ -253,7 +283,7 @@ class NodeDef(GraphDef[Node], reprlib.Representable):
attributes: tuple[Key, ...]
subgraphs: _HashableMapping[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]
static_fields: _HashableMapping[Key, tp.Any]
leaves: _HashableMapping[Key, NodeRef[tp.Any] | None]
leaves: _HashableMapping[Key, VariableDef | NodeRef[tp.Any]]
metadata: tp.Any
index_mapping: FrozenDict[Index, Index] | None

Expand All @@ -265,7 +295,7 @@ def create(
attributes: tuple[Key, ...],
subgraphs: tp.Iterable[tuple[Key, NodeDef[tp.Any] | NodeRef[tp.Any]]],
static_fields: tp.Iterable[tuple[Key, tp.Any]],
leaves: tp.Iterable[tuple[Key, NodeRef[tp.Any] | None]],
leaves: tp.Iterable[tuple[Key, VariableDef | NodeRef[tp.Any]]],
metadata: tp.Any,
index_mapping: tp.Mapping[Index, Index] | None,
):
Expand Down Expand Up @@ -380,7 +410,7 @@ def _graph_flatten(

subgraphs: list[tuple[Key, NodeDef[Node] | NodeRef]] = []
static_fields: list[tuple[Key, tp.Any]] = []
leaves: list[tuple[Key, NodeRef | None]] = []
leaves: list[tuple[Key, VariableDef | NodeRef]] = []

values, metadata = node_impl.flatten(node)
for key, value in values:
Expand All @@ -393,10 +423,10 @@ def _graph_flatten(
else:
flat_state[(*path, key)] = value.to_state()
variable_index = ref_index[value] = len(ref_index)
leaves.append((key, NodeRef(type(value), variable_index)))
elif is_state_leaf(value):
flat_state[(*path, key)] = value
leaves.append((key, None))
variabledef = VariableDef(
type(value), variable_index, FrozenDict(value.get_metadata())
)
leaves.append((key, variabledef))
else:
if isinstance(value, (jax.Array, np.ndarray)):
path_str = '/'.join(map(str, (*path, key)))
Expand All @@ -420,7 +450,7 @@ def _graph_flatten(

def unflatten(
graphdef: GraphDef[Node],
state: GraphState,
state: tp.Mapping[KeyT, StateLeaf | tp.Mapping[Key, tp.Any]],
/,
*,
index_ref: dict[Index, tp.Any] | None = None,
Expand All @@ -441,12 +471,12 @@ def unflatten(
existing graph nodes are mutated to have the new content/topology
specified by the graphdef.
"""
if isinstance(state, State):
state = state.raw_mapping
if index_ref is None:
index_ref = {}
assert isinstance(graphdef, (NodeDef, NodeRef))
node = _graph_unflatten(
graphdef, state.raw_mapping, index_ref, index_ref_cache
)
node = _graph_unflatten(graphdef, state, index_ref, index_ref_cache)
return node

def _graph_unflatten(
Expand Down Expand Up @@ -480,7 +510,7 @@ def _graph_unflatten(
node_impl = get_node_impl_for_type(nodedef.type)

def _get_children():
children: dict[Key, StateLeaf | Node] = {}
children: dict[Key, NodeLeaf | Node] = {}

# NOTE: we could allow adding new StateLeafs here
if unkown_keys := set(state) - set(nodedef.attributes):
Expand All @@ -491,13 +521,13 @@ def _get_children():
# - (3) the key can be a subgraph, a leaf, or a static attribute
for key in nodedef.attributes:
if key not in state:
# TODO(cgarcia): maybe we shouldn't support unflattening with missing keys?
# if key is not present create an empty types
if key in nodedef.static_fields:
children[key] = nodedef.static_fields[key]
elif key in nodedef.subgraphs:
# if the key is a subgraph we create an empty node
subgraphdef = nodedef.subgraphs[key]
assert not isinstance(subgraphdef, VariableDef)
if isinstance(subgraphdef, NodeRef):
# subgraph exists, take it from the cache
children[key] = index_ref[subgraphdef.index]
Expand All @@ -511,10 +541,10 @@ def _get_children():
subgraphdef, substate, index_ref, index_ref_cache
)
elif key in nodedef.leaves:
noderef = nodedef.leaves[key]
if noderef is not None and noderef.index in index_ref:
variabledef = nodedef.leaves[key]
if variabledef.index in index_ref:
# variable exists, take it from the cache
children[key] = index_ref[noderef.index]
children[key] = index_ref[variabledef.index]
else:
# key for a variable is missing, raise an error
raise ValueError(
Expand Down Expand Up @@ -546,41 +576,47 @@ def _get_children():
)

elif key in nodedef.leaves:
if not is_state_leaf(value):
raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')

noderef = nodedef.leaves[key]

if noderef is None:
# if the leaf is None, it means that the value was originally
# a non-VariableState leaf, however we allow providing a
# VariableState presumbly created by modifying the State
if isinstance(value, VariableState):
value = value.to_variable()
children[key] = value
elif noderef.index in index_ref:
# if not is_state_leaf(value):
# raise ValueError(f'Expected a leaf for {key!r}, but got {value!r}')

variabledef = nodedef.leaves[key]

if variabledef.index in index_ref:
# add an existing variable
children[key] = index_ref[noderef.index]
assert isinstance(variabledef, NodeRef)
children[key] = index_ref[variabledef.index]
else:
# it's an unseen variable, create a new one
if not isinstance(value, VariableState):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(value)}.'
)
assert isinstance(variabledef, VariableDef)
# if not isinstance(value, VariableState):
# raise ValueError(
# f'Expected a Variable type for {key!r}, but got {type(value)}.'
# )
# when idxmap is present, check if the Variable exists there
# and update existing variables if it does
if index_ref_cache is not None and noderef.index in index_ref_cache:
variable = index_ref_cache[noderef.index]
if (
index_ref_cache is not None
and variabledef.index in index_ref_cache
):
# if variable exists, update it
variable = index_ref_cache[variabledef.index]
if not isinstance(variable, Variable):
raise ValueError(
f'Expected a Variable type for {key!r}, but got {type(variable)}.'
)
variable.update_from_state(value)
if isinstance(value, VariableState):
variable.update_from_state(value)
else:
variable.raw_value = value
else: # if it doesn't, create a new variable
assert isinstance(value, VariableState)
variable = value.to_variable()
if isinstance(value, VariableState):
variable = value.to_variable()
else:
variable = variabledef.type.from_metadata(
value, variabledef.metadata
)
children[key] = variable
index_ref[noderef.index] = variable
index_ref[variabledef.index] = variable
else:
raise RuntimeError(f'Unknown key: {key!r}, this is a bug.')

Expand Down Expand Up @@ -676,7 +712,7 @@ def _graph_pop(
pass


def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[Key, tp.Any]):
def _graph_update_dynamic(node: tp.Any, state: tp.Mapping[KeyT, tp.Any]):
if not is_node(node):
raise RuntimeError(f'Unsupported type: {type(node)}')

Expand Down Expand Up @@ -1251,12 +1287,11 @@ def split(
states = _split_state(state, filters)
return graphdef, *states


def merge(
graphdef: GraphDef[A],
state: GraphState,
state: tp.Mapping[KeyT, tp.Any],
/,
*states: GraphState,
*states: tp.Mapping[KeyT, tp.Any],
) -> A:
"""The inverse of :func:`split`.
Expand Down Expand Up @@ -1293,13 +1328,15 @@ def merge(
Returns:
The merged :class:`Module`.
"""
state = GraphState.merge(state, *states)
state = State.merge(state, *states)
node = unflatten(graphdef, state)
return node


def update(node, state: State, /, *states: State) -> None:
"""Update the given graph node with a new :class:`State` in-place.
def update(
node, state: tp.Mapping[KeyT, tp.Any], /, *states: tp.Mapping[KeyT, tp.Any]
) -> None:
"""Update the given graph node with a new state(s) in-place.
Example usage::
Expand All @@ -1325,9 +1362,10 @@ def update(node, state: State, /, *states: State) -> None:
*states: Additional :class:`State` objects.
"""
if states:
state = GraphState.merge(state, *states)

_graph_update_dynamic(node, state.raw_mapping)
state = State.merge(state, *states)
if isinstance(state, State):
state = state.raw_mapping
_graph_update_dynamic(node, state)

def _variables_generator(node) -> tp.Iterable[tuple[PathParts, Variable]]:
for path, value in iter_graph(node):
Expand Down Expand Up @@ -1722,7 +1760,7 @@ def _key_path_to_key(key: tp.Any) -> Key:
elif isinstance(
key, (jax.tree_util.DictKey, jax.tree_util.FlattenedIndexKey)
):
if not isinstance(key.key, Key):
if not is_key_like(key.key):
raise ValueError(
f'Invalid key: {key.key}. May be due to its type not being hashable or comparable.'
)
Expand Down
11 changes: 9 additions & 2 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,26 @@
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
PARTITION_NAME = 'partition_name'

class HasSharding(tp.Protocol):
  """Structural type for any object that exposes a ``sharding`` attribute."""

  # Either None, or a tuple whose entries are each a string or None.
  # NOTE(review): presumably one entry per array axis — confirm.
  sharding: tuple[str | None, ...] | None

def add_axis(tree: A, index: int, params: tp.Mapping[tp.Any, tp.Any]) -> A:

def _has_sharding(x: tp.Any) -> tp.TypeGuard[HasSharding]:
return hasattr(x, 'sharding') and x.sharding is not None

def add_axis(tree: A, index: int, params: tp.Mapping) -> A:
axis_name = _get_partition_name(params)

def _add_axis(x: tp.Any):
if isinstance(x, variablelib.VariableState):
if hasattr(x, 'sharding') and x.sharding is not None:
if _has_sharding(x) and x.sharding is not None:
sharding: list[str | None] = list(x.sharding)
while len(sharding) < index:
sharding.append(None)
sharding.insert(index, axis_name)
x.sharding = tuple(sharding) # type: ignore

assert isinstance(x, variablelib.VariableState)
x.add_axis(index, axis_name)
return x

Expand Down
10 changes: 7 additions & 3 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,9 @@ def filter(
return states # type: ignore[bad-return-type]

@staticmethod
def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]:
def merge(
state: tp.Mapping[K, V], /, *states: tp.Mapping[K, V]
) -> State[K, V]:
"""The inverse of :meth:`split() <flax.nnx.State.state.split>`.
``merge`` takes one or more ``State``'s and creates
Expand Down Expand Up @@ -353,14 +355,16 @@ def merge(state: State[K, V], /, *states: State[K, V]) -> State[K, V]:
The merged ``State``.
"""
if not states:
return state
if isinstance(state, State):
return state
return State(state)

states = (state, *states)

new_state: FlatState[V] = {}

for state in states:
new_state.update(state.flat_state()) # type: ignore[attribute-error] # pytype is wrong here
new_state.update(traversals.flatten_mapping(state)) # type: ignore[attribute-error] # pytype is wrong here

return State.from_flat_path(new_state)

Expand Down
Loading

0 comments on commit 843a3dc

Please sign in to comment.