Replicating jax's export feature to cache filter_jitted functions #879

Open
aeftimia opened this issue Oct 12, 2024 · 4 comments
Labels: feature (New feature)

Comments

@aeftimia

JAX allows you to serialize and deserialize a jitted function, as described here:

https://jax.readthedocs.io/en/latest/_autosummary/jax.export.export.html#jax.export.export
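For reference, a minimal round trip with a plain `jax.jit` function looks roughly like the sketch below. It assumes a recent JAX release that ships the public `jax.export` module; the function `f` is just a placeholder. Note that `export.export` expects the result of `jax.jit`, which is why handing it an `eqx.filter_jit` wrapper is rejected.

```python
import jax
import jax.numpy as jnp
from jax import export


@jax.jit
def f(x):
    return jnp.sin(x) * 2.0


# Trace and lower `f` for a given input shape/dtype; this yields an `Exported`.
exp = export.export(f)(jax.ShapeDtypeStruct((3,), jnp.float32))

# Serialize to a bytearray that can be written to disk and cached.
blob = exp.serialize()

# Later, possibly in another process: rebuild the `Exported` and call it.
restored = export.deserialize(blob)
y = restored.call(jnp.arange(3, dtype=jnp.float32))
```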

I tried this with a filter_jit-wrapped function, but received this error:

```
Function to be exported must be the result of `jit` but is: _JitWrapper(
  fn='simulate',
  filter_warning=False,
  donate_first=False,
  donate_rest=False
)
```

Is it possible to replicate serialization like this for filter_jit? I wasn't sure if that was even theoretically possible, but figured it was at least worth asking.

@patrick-kidger
Owner

This should definitely be possible; we'd just have to write an API that wraps the existing jax.export.{export, deserialize}.

Under the hood, eqx.filter_jit is basically just a wrapper around jax.jit with nicer behaviour. We could arrange to unwrap that, run jax.export.export, and then save any additional metadata. Conversely, an analogous filter_deserialize function could then read the result and package things back up.

I'd be happy to take a PR on this.

patrick-kidger added the feature (New feature) label on Oct 12, 2024
@aeftimia
Author

aeftimia commented Oct 12, 2024

When I try

```python
import equinox as eqx
from jax import export

class MyClass(eqx.Module):
    @eqx.filter_jit
    def fn(self, ...):
        ...

model = MyClass(...)
exported = export.export(model.fn._cached)
exported(*args, **kwargs).serialize()
```

I get this error:

```
Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jaxlib.xla_extension.ArrayImpl'> for function simulate is non-hashable.
```

I'm guessing all the extra stuff that filter_jit turns into static args wouldn't be supported for serialization? Do you think there is any way around that?

@aeftimia
Author

aeftimia commented Oct 12, 2024

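Following up with a minimal example that reproduces the problem:

```python
import equinox as eqx
import jax.numpy as jnp
from jax import export


class Test(eqx.Module):

    @eqx.filter_jit
    def fn(self, data):
        return data


obj = Test()
fn = obj.fn
x = jnp.array(3.0)
fn(x)  # calling the filter_jit-wrapped method directly works
exported = export.export(fn.func._cached)
# Expected to fail with a non-hashable static argument error like the one above.
exported(*fn.args, obj, x, **fn.keywords).serialize()
```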

@patrick-kidger
Owner

Right, so this is because the internal JIT'd function (fn.func._cached) does not have the same signature as the original function.

This is kind of the whole point of filter_jit: it automatically inspects your arguments, splits them into three groups ('arguments that should be traced and donated', 'arguments that should be traced and not donated', and 'arguments that are static'), and then passes them across the JIT boundary in those groups.

This is the function that's actually JIT'd:

```python
def fun_wrapped(dynamic_donate, dynamic_nodonate, static):
```

and here is where they are split up in this way, in equinox/_jit.py of the equinox repository (lines 220 to 222 at commit d9b3ffd):

```python
dynamic_donate, dynamic_nodonate, static = _preprocess(  # pyright: ignore
    info, args, kwargs, return_static=True
)
```
