PyTorch Hooks in Equinox? #864

Open
samuelstevens opened this issue Sep 26, 2024 · 2 comments
Labels
question User queries

Comments

@samuelstevens

I would like to record some model activations in an architecture-agnostic way.
In PyTorch, we can use forward hooks to do this, by registering a hook on modules that match some criteria (maybe all modules that are an MLP class, for example).
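
For reference, the PyTorch pattern looks roughly like the sketch below. The toy nn.Sequential model, the make_hook helper and the activations dict are illustrative choices, not part of the original question. The hook is attached to every module that matches a predicate (here, isinstance(m, nn.Linear)).

```python
import torch
import torch.nn as nn

# Toy model standing in for a real architecture (illustrative only).
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

activations = {}  # module name -> recorded output


def make_hook(name):
    # Forward hooks are called as hook(module, inputs, output) after forward().
    def hook(module, inputs, output):
        activations[name] = output.detach()
    return hook


# Register a hook on every module matching the criterion (here: nn.Linear).
handles = [
    m.register_forward_hook(make_hook(name))
    for name, m in model.named_modules()
    if isinstance(m, nn.Linear)
]

model(torch.randn(3, 4))

for h in handles:
    h.remove()  # hooks stay registered until explicitly removed
```

Nothing about the model's own code changes. The hook is attached from outside, and that is the property the question is looking for in Equinox.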

Is there an equivalent strategy in Equinox?

One idea is to create a class Wrapper(eqx.Module) that wraps a module and, in __call__, passes the wrapped module's activations to a callback. I would then somehow swap the matching submodules of an Equinox model for wrapped versions.

from typing import Callable

import equinox as eqx


class Wrapper(eqx.Module):
    wrapped: eqx.Module
    callback: Callable

    def __init__(self, module, callback):
        self.wrapped = module
        self.callback = callback

    def __call__(self, *args, **kwargs):
        outs = self.wrapped(*args, **kwargs)
        self.callback(outs)  # this would save to disk or something
        return outs  # pass the activations through unchanged

Then in the main script, I could do something like:

model = MyViT()
for i in range(n_layers):
    # tree_at returns a new model; the pytree to modify is its second argument
    model = eqx.tree_at(lambda m: m.layers[i].mlp, model, replace_fn=lambda m: Wrapper(m, my_callback))

Is there a better/more obvious way to do this?

@nasyxx
Contributor

nasyxx commented Sep 26, 2024

You can use jax.tree.leaves with an is_leaf predicate to collect all the submodules you want.

For example, to find every eqx.nn.Linear layer:

import equinox as eqx
import jax

is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_linear = lambda m: [x
                        for x in jax.tree.leaves(m, is_leaf=is_linear)
                        if is_linear(x)]
linears = get_linear(model)
wrapped = [Wrapper(m, callback) for m in linears]
model = eqx.tree_at(get_linear, model, wrapped)  # tree_at returns a new model

However, I'm not sure whether your callback will run inside a jitted module.

@samuelstevens
Author

Wow, that's really neat; I'll try that. I think I can use jax.debug.callback or jax.experimental.io_callback, but I'm not sure which is better.

@patrick-kidger patrick-kidger added the question User queries label Sep 26, 2024
3 participants