Skip to content

Commit

Permalink
Errors can now be used before init_google if they only use literal Py…
Browse files Browse the repository at this point in the history
…thon scalars. Added `eqxi.assert_dce`.
  • Loading branch information
patrick-kidger committed Jul 26, 2023
1 parent 4e706e0 commit 31fbf9f
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 9 deletions.
41 changes: 34 additions & 7 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import warnings
from collections.abc import Sequence
from typing import cast, Literal, Union
from typing import Literal, Union

import jax
import jax._src.traceback_util as traceback_util
Expand Down Expand Up @@ -86,8 +86,9 @@ def to_nan(_index):
return jtu.tree_map(_nan_like, struct)

def handle_error():
index_struct = jax.eval_shape(lambda: index)
_index = jax.pure_callback( # pyright: ignore
display_msg, index, index, vectorized=True
display_msg, index_struct, index, vectorized=True
)
# Support JAX with and without DCE behaviour on breakpoints.
if "token" in inspect.signature(jax.debug.breakpoint).parameters.keys():
Expand Down Expand Up @@ -155,6 +156,7 @@ def error_if(
x: PyTree,
pred: Bool[ArrayLike, "..."],
msg: str,
*,
on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
) -> PyTree:
"""Throws an error based on runtime values. Works even under JIT.
Expand Down Expand Up @@ -199,7 +201,7 @@ def f(x):
f(jax.numpy.array(-1))
```
"""
return branched_error_if(x, pred, 0, [msg], on_error)
return branched_error_if(x, pred, 0, [msg], on_error=on_error)


@doc_remove_args("on_error")
Expand All @@ -208,6 +210,7 @@ def branched_error_if(
pred: Bool[ArrayLike, "..."],
index: Int[ArrayLike, "..."],
msgs: Sequence[str],
*,
on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
) -> PyTree:
"""As [`equinox.error_if`][], but will raise one of
Expand All @@ -222,14 +225,22 @@ def branched_error_if(
if on_error not in ("raise", "breakpoint", "nan"):
raise RuntimeError("Unrecognised value for `on_error`.")
with jax.ensure_compile_time_eval():
pred = unvmap_any(pred)
index = unvmap_max(index)
# This carefully does not perform any JAX operations if `pred` and `index` are
# a bool and an int.
# This ensures we can use `error_if` before init_google.
if not isinstance(pred, bool):
pred = unvmap_any(pred)
if not isinstance(index, int):
index = unvmap_max(index)
if not isinstance(pred, jax.core.Tracer):
if pred.item():
if isinstance(pred, Array):
pred = pred.item()
assert type(pred) is bool
if pred:
if not isinstance(index, jax.core.Tracer):
if isinstance(index, Array):
index = index.item()
index = cast(int, index)
assert type(index) is int
warnings.warn(
"`Error can be resolved statically. Handling at trace-time "
"rather than waiting until runtime."
Expand All @@ -253,3 +264,19 @@ def branched_error_if(
raise ValueError("No arrays to thread error on to.")
dynamic_x = _error(dynamic_x, pred, index, msgs=msgs, on_error=on_error)
return combine(dynamic_x, static_x)


def assert_dce(
x: PyTree,
msg: str,
*,
on_error: Literal["default", "raise", "breakpoint", "nan"] = "default",
) -> PyTree:
"""Asserts that a particular array (or PyTree of arrays) is DCE'd."""

if _currently_jitting():
pred = jnp.invert(False) # Prevent the trace-time error-raising from running.
return error_if(x, pred, msg, on_error=on_error)
else:
# Don't run if not JIT'ing, as without the compiler nothing will be DCE'd.
return x
3 changes: 1 addition & 2 deletions equinox/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Backward compatibility: expose via `equinox.internal`. Now available under `equinox`.
from .._errors import (
assert_dce as assert_dce,
branched_error_if as branched_error_if,
error_if as error_if,
)
Expand All @@ -39,8 +40,6 @@
inspect_dce as inspect_dce,
store_dce as store_dce,
)

# Note that `announce_jaxpr_p` should be exposed here regardless.
from ..debug._announce_transform import announce_jaxpr_p as announce_jaxpr_p
from ._finalise_jaxpr import (
finalise_eval_jaxpr as finalise_eval_jaxpr,
Expand Down
19 changes: 19 additions & 0 deletions tests/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,22 @@ def f(x, pred):

y = f(1.0, True)
assert jnp.isnan(y)


def test_assert_dce():
@jax.jit
def f(x):
x = x + 1
eqxi.assert_dce(x, msg="oh no")
return x

f(1.0)

@jax.jit
def g(x):
x = x + 1
eqxi.assert_dce(x, msg="oh no")
return x

with jax.disable_jit():
g(1.0)

0 comments on commit 31fbf9f

Please sign in to comment.