From 7de34a2069bcda5d3e1a2c6d78a83d920466d3e3 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 11 May 2024 21:02:43 +0200 Subject: [PATCH] Fixes for compatibility with JAX version 0.4.28. - callback inputs are now `Array`s rather than `array`s. - `emit_python_callback` now uses a keyword argument for what was once a positional argument. Other nice-to-haves whilst I'm here: - Removed `filter_jit` passing function default arguments across the JIT boundary. In particular this is incompatible with `filter_jit(filter_vmap(eqx.nn.MLP(...)))`, as then it tries to pass the `MLP.__call__(..., *, key=...)` default argument through, but `filter_vmap` does not allow keyword arguments. - Added some `stacklevel`s to some warnings. --- equinox/_jit.py | 1 - equinox/_module.py | 6 ++++-- equinox/internal/_noinline.py | 2 +- tests/test_debug.py | 2 +- tests/test_noinline.py | 4 ++-- 5 files changed, 8 insertions(+), 7 deletions(-) diff --git a/equinox/_jit.py b/equinox/_jit.py index c2f8d409..9df51d72 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -58,7 +58,6 @@ def fun_wrapped(dynamic_donate, dynamic_nodonate, static): def _bind(signature, args, kwargs): bound = signature.bind(*args, **kwargs) - bound.apply_defaults() args = bound.args kwargs = bound.kwargs return args, kwargs diff --git a/equinox/_module.py b/equinox/_module.py index f2d0200c..4d420338 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -277,7 +277,8 @@ def __new__( "hold, then consider using `__check_init__` instead. This is an " "Equinox-specific extension that is always ran. See here for more " "details: " - "https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants" # noqa: E501 + "https://docs.kidger.site/equinox/api/module/advanced_fields/#checking-invariants", # noqa: E501 + stacklevel=2, ) # Add support for `eqx.field(converter=...)` when using `__post_init__`. @@ -747,7 +748,8 @@ def __init__(self, ...): def __call__(self, ...): ... = self.vmap_linear(...) ``` -""" +""", + stacklevel=3, ) break diff --git a/equinox/internal/_noinline.py b/equinox/internal/_noinline.py index 74d112eb..848d9819 100644 --- a/equinox/internal/_noinline.py +++ b/equinox/internal/_noinline.py @@ -362,7 +362,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs): vals_in, avals_in, ctx.avals_out, - False, + has_side_effect=False, sharding=None, ) ctx.module_context.add_keepalive(keepalive) diff --git a/tests/test_debug.py b/tests/test_debug.py index b8c338a1..0778a5b6 100644 --- a/tests/test_debug.py +++ b/tests/test_debug.py @@ -32,7 +32,7 @@ def f(x, terminate): text, _ = capfd.readouterr() assert ( text - == "foo:\n primals=array(1., dtype=float32)\ncotangents=array(nan, dtype=float32)\n" # noqa: E501 + == "foo:\n primals=Array(1., dtype=float32)\ncotangents=Array(nan, dtype=float32)\n" # noqa: E501 ) with pytest.raises(Exception): diff --git a/tests/test_noinline.py b/tests/test_noinline.py index f454fda5..74e19bd7 100644 --- a/tests/test_noinline.py +++ b/tests/test_noinline.py @@ -34,9 +34,9 @@ def test_mlp(getkey): def test_vmap(getkey): mlp = eqx.nn.MLP(2, 2, 512, 2, key=getkey()) mlp_noinline = eqx.internal.noinline(mlp) - mlp_vmap = jax.vmap(mlp) + mlp_vmap = eqx.filter_vmap(mlp) mlp_jit_vmap = eqx.filter_jit(mlp_vmap, donate="none") - mlp_vmap_noinline = jax.vmap(mlp_noinline) + mlp_vmap_noinline = eqx.filter_vmap(mlp_noinline) mlp_jit_vmap_noinline = eqx.filter_jit(mlp_vmap_noinline, donate="none") x = jr.normal(getkey(), (5, 2)) o1 = mlp_vmap(x)