Skip to content

Commit

Permalink
Make code that parses callback errors more tolerant of message format.
Browse files Browse the repository at this point in the history
JAX is switching from pybind11 to nanobind, and nanobind formats this
error differently, with the traceback preceding the EqxRuntimeError
message. If we change the "At:" split to accept zero or more "At:"
pieces, things work as intended with both versions.
  • Loading branch information
hawkinsp authored and patrick-kidger committed Mar 1, 2024
1 parent 3425298 commit 1e60167
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __call__(self, /, *args, **kwargs):
(msg,) = e.args
if "EqxRuntimeError: " in msg:
_, msg = msg.split("EqxRuntimeError: ", 1)
msg, _ = msg.rsplit("\n\nAt:\n", 1)
msg, *_ = msg.rsplit("\n\nAt:\n", 1)
msg = msg + _eqx_on_error_msg
e.args = (msg,)
if jax.config.jax_traceback_filtering in ( # pyright: ignore
Expand Down

0 comments on commit 1e60167

Please sign in to comment.