Skip to content

Commit

Permalink
[nnx] add checkify
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Nov 14, 2024
1 parent 480a196 commit d51a0ff
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 5 deletions.
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@
from .transforms.transforms import eval_shape as eval_shape
from .transforms.transforms import cond as cond
from .transforms.transforms import switch as switch
from .transforms.transforms import checkify as checkify
from .transforms.iteration import while_loop as while_loop
from .transforms.iteration import fori_loop as fori_loop
from .transforms.iteration import StateAxes as StateAxes
Expand Down
80 changes: 76 additions & 4 deletions flax/nnx/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@
from __future__ import annotations

from abc import abstractmethod
import dataclasses
import functools
import inspect
import typing as tp

import jax.experimental
from jax._src import checkify as checkify_lib

from flax.nnx import (
extract,
graph,
)
from flax.nnx.module import Module
from flax.nnx.proxy_caller import (
Expand Down Expand Up @@ -119,7 +124,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs):


# -------------------------------
# eval_shape
# simple transforms
# -------------------------------


Expand All @@ -140,9 +145,76 @@ def _eval_shape_fn(*args, **kwargs):
return extract.from_tree(out)


# -------------------------------
# cond and switch
# -------------------------------
@dataclasses.dataclass(eq=False)
class CheckifyFn:
f: tp.Callable[..., tp.Any]

def __post_init__(self):
functools.update_wrapper(self, self.f)

def __call__(self, *pure_args, **pure_kwargs):
args, kwargs = extract.from_tree(
(pure_args, pure_kwargs), ctxtag='checkify'
)
out = self.f(*args, **kwargs)

args_out, kwargs_out = extract.clear_non_graph_nodes((args, kwargs))
pure_args_out, pure_kwargs_out, pure_out = extract.to_tree(
(args, kwargs, out), ctxtag='checkify'
)
return pure_args_out, pure_kwargs_out, pure_out

def checkify(
f: tp.Callable[..., checkify_lib.Out],
errors: frozenset[checkify_lib.JaxException] = checkify_lib.user_checks, # type: ignore
) -> tp.Callable[..., tuple[checkify_lib.Error, checkify_lib.Out]]:
"""Reference-aware version of `jax.experimental.checkify
<https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-functional-api>`_.
Example::
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> import dataclasses
>>> from flax import nnx
...
>>> @dataclasses.dataclass
... class Foo(nnx.Module):
... a: nnx.Param
...
>>> @nnx.jit
... def f(m):
... y = jnp.sin(m.a.value) # error
... return m.a + y
...
>>> m = Foo(a=nnx.Param(jnp.inf))
>>> err, out = nnx.checkify(f, errors=checkify.float_checks)(m)
>>> # err.throw()
>>> print(err)
Error(nan generated by primitive: sin.)
"""
checkify_fn = checkify_lib.checkify(CheckifyFn(f), errors)

@functools.wraps(f)
@graph.update_context('checkify')
def jit_wrapper(*args, **kwargs):
pure_args, pure_kwargs = extract.to_tree(
(args, kwargs),
ctxtag='checkify',
)
error, (pure_args_out, pure_kwargs_out, pure_out) = checkify_fn(
*pure_args, **pure_kwargs
)

args_out, kwargs_out, out = extract.from_tree(
(pure_args_out, pure_kwargs_out, pure_out),
ctxtag='checkify',
)

return error, out

return jit_wrapper # type: ignore


@general.split_inputs(ctxtag='cond')
Expand Down
19 changes: 18 additions & 1 deletion tests/nnx/transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@
from flax import nnx
from flax.nnx.transforms import general
import jax
from jax.experimental import mesh_utils
from jax.experimental import mesh_utils, checkify
import jax.numpy as jnp
import numpy as np



class List(nnx.Module):
def __init__(self, items):
vars(self).update({str(i): item for i, item in enumerate(items)})
Expand Down Expand Up @@ -3024,6 +3025,22 @@ def no_nothing(env: Env):
env.step.value, np.array([1, 0, 1, 0, 1, 0, 1, 0], np.uint32)
)

class TestCheckify(absltest.TestCase):
def test_basic(self):
@dataclasses.dataclass
class Foo(nnx.Module):
a: nnx.Param

@nnx.jit
def f(m):
y = jnp.sin(m.a.value) # error
return m.a + y

m = Foo(a=nnx.Param(jnp.inf))
err, out = nnx.checkify(f, errors=checkify.float_checks)(m)

with self.assertRaisesRegex(ValueError, 'nan generated by primitive: sin'):
err.throw()

if __name__ == '__main__':
absltest.main()

0 comments on commit d51a0ff

Please sign in to comment.