PyTorch Hooks in Equinox? #864
Labels: question (User queries)
Comments
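You can use … For example, if you need the `eqx.nn.Linear` layers:

```python
import jax
import equinox as eqx

# `Wrapper` is the wrapper class from the question; `model` and `callback` are assumed to exist.
is_linear = lambda x: isinstance(x, eqx.nn.Linear)
get_linear = lambda m: [x
                        for x in jax.tree.leaves(m, is_leaf=is_linear)
                        if is_linear(x)]
linears = get_linear(model)
wrapped = [Wrapper(m, callback) for m in linears]
# Equinox modules are immutable: tree_at returns a new model, so keep the result.
model = eqx.tree_at(get_linear, model, wrapped)
```

However, I'm not sure whether your callback would run inside a jitted module.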
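On the jit concern: a minimal sketch, assuming plain `jax.jit`, of the difference between an ordinary Python side effect and one staged with `jax.debug.callback`. The names `f`, `record`, `trace_time_calls` and `run_time_calls` are hypothetical. Ordinary Python code inside a jitted function runs only while JAX traces it, whereas `jax.debug.callback` is compiled into the program and invoked on the host on every execution.

```python
import jax
import jax.numpy as jnp

trace_time_calls = []  # appended by ordinary Python code inside the jitted function
run_time_calls = []    # appended by a host callback staged with jax.debug.callback


def record(y):
    # Receives concrete array values on the host, once per execution.
    run_time_calls.append(y)


@jax.jit
def f(x):
    y = jnp.tanh(x)
    # Plain Python side effect: runs only while JAX traces `f` (y is an abstract tracer here).
    trace_time_calls.append(y.shape)
    # Staged side effect: compiled into the program and executed on every call.
    jax.debug.callback(record, y)
    return y


for _ in range(3):
    f(jnp.ones(2)).block_until_ready()
jax.effects_barrier()  # make sure all pending callbacks have run

print(len(trace_time_calls))  # 1 (traced once, then served from the compilation cache)
print(len(run_time_calls))    # 3 (one callback per execution)
```

So a wrapper that calls the user's callback directly would fire only once per compilation, while routing it through `jax.debug.callback` makes it fire on every forward pass.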
Wow, that's really neat, I can try that. I think I can use …
I would like to record some model activations in an architecture-invariant way.
In PyTorch, we can do this with forward hooks: register a hook on every module that matches some criterion (for example, every module that is an instance of an MLP class).
Is there an equivalent strategy in Equinox?
One idea is to create a `class Wrapper(eqx.Module)` that simply wraps a module and, in its `__call__`, passes the underlying module's activations to some callback, and then somehow swap the matching submodules of an Equinox model for wrapped versions. Then in the main script, I could do something like:
Is there a better/more obvious way to do this?