Modules got silently "reused" with hk.vmap
#740
Comments
Hi @jjyyxx, this is indeed confusing behaviour. Changing this would be backwards incompatible with all existing usages of
If you need to use

def vmap_with_reuse(f, *, name: str | None = None):
  f = hk.vmap(f, split_rng=(not hk.running_init()))
  f = hk.to_module(f)
  return lambda *a, **k: f(name=name)(*a, **k)

def f3(x):
  def g(x):
    return hk.Linear(2)(x)
  x = vmap_with_reuse(g)(x)
  x = vmap_with_reuse(g)(x)
  return x

# w3: dict_keys(['g/linear', 'g_1/linear'])
Thanks for your suggestion! Indeed, I found that However, you mentioned that
So, if only
I have to admit that I do not fully understand the necessity of hk.vmap instead of jax.vmap. Nevertheless, when I need to vmap something, I use hk.vmap whenever the inner function contains calls to Haiku modules. This worked fine until I debugged the poor performance of a transformer model. The problem boils down to the following snippet.

It turns out that when g is vmapped, modules created inside g reuse a previously created module. In some cases an error is raised immediately because of incompatible shapes, but in other cases (for me, the transformer layers have quite consistent shapes) things go wrong silently.

My question: Is this behavior intended? Could the documentation be improved on this topic? Or am I missing something?