InitVar and pickle break PyTreeDef equality #859

Open
HGangloff opened this issue Sep 20, 2024 · 1 comment

Hi,

This issue is similar to #857: there appears to be a bad interaction between InitVar and pickle. The documentation recommends better practices than pickling, and I have since found a better workaround that solves the problem on my side, but I am opening this issue for the record.

Pickling model parameters containing an InitVar breaks PyTreeDef equality:

import pickle
from dataclasses import InitVar
from copy import deepcopy

import jax
import equinox as eqx
from jaxtyping import Key

class MLP2(eqx.Module):

    key: InitVar[Key] = eqx.field(kw_only=True)
    layers: list = eqx.field(init=False)

    def __post_init__(self, key):
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t


key = jax.random.PRNGKey(0)
mlp = MLP2(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

with open("parameters.pkl", "wb") as f:
    pickle.dump(params, f)
with open("parameters.pkl", "rb") as f:
    reloaded_params = pickle.load(f)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(reloaded_params)[1])  # prints False!

The same check passes once the InitVar is removed:

class MLP1(eqx.Module):

    layers: list = eqx.field(init=False)

    def __init__(self, key):        
        self.layers = [eqx.nn.Linear(1, 50, key=key), jax.nn.relu]

    def __call__(self, t):
        for layer in self.layers:
            t = layer(t)
        return t

key = jax.random.PRNGKey(0)
mlp = MLP1(key=key)
params, static = eqx.partition(mlp, eqx.is_inexact_array)

with open("parameters.pkl", "wb") as f:
    pickle.dump(params, f)
with open("parameters.pkl", "rb") as f:
    reloaded_params = pickle.load(f)

print(jax.tree.flatten(params)[1] == jax.tree.flatten(reloaded_params)[1])  # prints True
HGangloff added a commit to mia-jinns/jinns that referenced this issue Sep 20, 2024
…ave the nn_params in the pickle anymore but only the eq_params. See patrick-kidger/equinox#859