Skip to content

Commit

Permalink
Fixes for compatibility with JAX version 0.4.28.
Browse files Browse the repository at this point in the history
- callback inputs are now `Array`s rather than `array`s.
- `emit_python_callback` now uses a keyword argument for what was once a positional argument.

Other nice-to-haves whilst I'm here:

- Removed `filter_jit` passing function default arguments across the JIT boundary. In particular this is incompatible with `filter_jit(filter_vmap(eqx.nn.MLP(...)))`, as then it tries to pass the `MLP.__call__(..., *, key=...)` default argument through, but `filter_vmap` does not allow keyword arguments.
- Added some `stacklevel`s to some warnings.
  • Loading branch information
patrick-kidger committed May 12, 2024
1 parent 35d7454 commit 7de34a2
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
1 change: 0 additions & 1 deletion equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def fun_wrapped(dynamic_donate, dynamic_nodonate, static):

def _bind(signature, args, kwargs):
bound = signature.bind(*args, **kwargs)
bound.apply_defaults()
args = bound.args
kwargs = bound.kwargs
return args, kwargs
Expand Down
6 changes: 4 additions & 2 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ def __new__(
"hold, then consider using `__check_init__` instead. This is an "
"Equinox-specific extension that is always ran. See here for more "
"details: "
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501
"https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants", # noqa: E501
stacklevel=2,
)

# Add support for `eqx.field(converter=...)` when using `__post_init__`.
Expand Down Expand Up @@ -747,7 +748,8 @@ def __init__(self, ...):
def __call__(self, ...):
... = self.vmap_linear(...)
```
"""
""",
stacklevel=3,
)
break

Expand Down
2 changes: 1 addition & 1 deletion equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):
vals_in,
avals_in,
ctx.avals_out,
False,
has_side_effect=False,
sharding=None,
)
ctx.module_context.add_keepalive(keepalive)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def f(x, terminate):
text, _ = capfd.readouterr()
assert (
text
== "foo:\n primals=array(1., dtype=float32)\ncotangents=array(nan, dtype=float32)\n" # noqa: E501
== "foo:\n primals=Array(1., dtype=float32)\ncotangents=Array(nan, dtype=float32)\n" # noqa: E501
)

with pytest.raises(Exception):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_noinline.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def test_mlp(getkey):
def test_vmap(getkey):
mlp = eqx.nn.MLP(2, 2, 512, 2, key=getkey())
mlp_noinline = eqx.internal.noinline(mlp)
mlp_vmap = jax.vmap(mlp)
mlp_vmap = eqx.filter_vmap(mlp)
mlp_jit_vmap = eqx.filter_jit(mlp_vmap, donate="none")
mlp_vmap_noinline = jax.vmap(mlp_noinline)
mlp_vmap_noinline = eqx.filter_vmap(mlp_noinline)
mlp_jit_vmap_noinline = eqx.filter_jit(mlp_vmap_noinline, donate="none")
x = jr.normal(getkey(), (5, 2))
o1 = mlp_vmap(x)
Expand Down

0 comments on commit 7de34a2

Please sign in to comment.