Skip to content

Commit

Permalink
🔀 Merge pull request #28 from alvarobartt/objax-variable-fix
Browse files Browse the repository at this point in the history
🐛 Fix issue when restoring a `VarCollection`
  • Loading branch information
alvarobartt authored Jan 5, 2023
2 parents ecf6a3e + a12c3d4 commit 67e77de
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 69 deletions.
35 changes: 13 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,25 @@ pip install safejax --upgrade
* Convert `params` to `bytes` in memory

```python
from safejax import serialize, deserialize
from safejax.flax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes, freeze_dict=True)
decoded_params = deserialize(encoded_bytes)

model.apply(decoded_params, ...)
```

* Convert `params` to `bytes` in `params.safetensors` file

```python
from safejax import serialize, deserialize
from safejax.flax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", freeze_dict=True)
decoded_params = deserialize("./params.safetensors")

model.apply(decoded_params, ...)
```
Expand All @@ -54,7 +54,7 @@ pip install safejax --upgrade
* Just contains `params`

```python
from safejax import serialize, deserialize
from safejax.haiku import serialize, deserialize

params = model.init(...)

Expand All @@ -67,7 +67,7 @@ pip install safejax --upgrade
* If it contains `params` and `state` e.g. ExponentialMovingAverage in BatchNorm

```python
from safejax import serialize, deserialize
from safejax.haiku import serialize, deserialize

params, state = model.init(...)
params_state = {"params": params, "state": state}
Expand All @@ -81,7 +81,7 @@ pip install safejax --upgrade
* If it contains `params` and `state`, but we want to serialize those individually

```python
from safejax import serialize, deserialize
from safejax.haiku import serialize, deserialize

params, state = model.init(...)

Expand All @@ -101,35 +101,33 @@ pip install safejax --upgrade
* Convert `params` to `bytes` in memory, and convert back to `VarCollection`

```python
from safejax import serialize, deserialize
from safejax.objax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params)
decoded_params = deserialize(
encoded_bytes, requires_unflattening=False, to_var_collection=True
)
decoded_params = deserialize(encoded_bytes)

for key, value in decoded_params.items():
if key in model.vars():
model.vars()[key].assign(value)
model.vars()[key].assign(value.value)

model(...)
```

* Convert `params` to `bytes` in `params.safetensors` file

```python
from safejax import serialize, deserialize
from safejax.objax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", requires_unflattening=False)
decoded_params = deserialize("./params.safetensors")

for key, value in decoded_params.items():
if key in model.vars():
model.vars()[key].assign(value)
model.vars()[key].assign(value.value)

model(...)
```
Expand All @@ -149,13 +147,6 @@ pip install safejax --upgrade

---

📌 As you may have seen in the examples above, most of those codeblocks are imporing both
`serialize` and `deserialize` from `safejax`, but as some of those expect params with respect
to the JAX framework that we're using, we can just import those from their files to avoid
defining the params over and over e.g. instead of `from safejax import deserialize, serialize`,
we can just import `from safejax.flax import deserialize, serialize`, and skip the function
params, so that the only input param that we need to provide are the params themselves.

More in-detail examples can be found at [`examples/`](./examples) for `flax`, `dm-haiku`, and `objax`.

## 🤔 Why `safejax`?
Expand Down
35 changes: 13 additions & 22 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,25 @@
* Convert `params` to `bytes` in memory

```python
from safejax import serialize, deserialize
from safejax.flax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params)
decoded_params = deserialize(encoded_bytes, freeze_dict=True)
decoded_params = deserialize(encoded_bytes)

model.apply(decoded_params, ...)
```

* Convert `params` to `bytes` in `params.safetensors` file

```python
from safejax import serialize, deserialize
from safejax.flax import serialize, deserialize

params = model.init(...)

encoded_bytes = serialize(params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", freeze_dict=True)
decoded_params = deserialize("./params.safetensors")

model.apply(decoded_params, ...)
```
Expand All @@ -35,7 +35,7 @@
* Just contains `params`

```python
from safejax import serialize, deserialize
from safejax.haiku import serialize, deserialize

params = model.init(...)

Expand All @@ -48,7 +48,7 @@
* If it contains `params` and `state` e.g. ExponentialMovingAverage in BatchNorm

```python
from safejax import serialize, deserialize
from safejax.haiku import serialize, deserialize

params, state = model.init(...)
params_state = {"params": params, "state": state}
Expand All @@ -62,7 +62,7 @@
* If it contains `params` and `state`, but we want to serialize those individually

```python
from safejax import serialize, deserialize
from safejax.haiku import serialize, deserialize

params, state = model.init(...)

Expand All @@ -82,35 +82,33 @@
* Convert `params` to `bytes` in memory, and convert back to `VarCollection`

```python
from safejax import serialize, deserialize
from safejax.objax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params)
decoded_params = deserialize(
encoded_bytes, requires_unflattening=False, to_var_collection=True
)
decoded_params = deserialize(encoded_bytes)

for key, value in decoded_params.items():
if key in model.vars():
model.vars()[key].assign(value)
model.vars()[key].assign(value.value)

model(...)
```

* Convert `params` to `bytes` in `params.safetensors` file

```python
from safejax import serialize, deserialize
from safejax.objax import serialize, deserialize

params = model.vars()

encoded_bytes = serialize(params=params, filename="./params.safetensors")
decoded_params = deserialize("./params.safetensors", requires_unflattening=False)
decoded_params = deserialize("./params.safetensors")

for key, value in decoded_params.items():
if key in model.vars():
model.vars()[key].assign(value)
model.vars()[key].assign(value.value)

model(...)
```
Expand All @@ -130,12 +128,5 @@

---

📌 As you may have seen in the examples above, most of those codeblocks are imporing both
`serialize` and `deserialize` from `safejax`, but as some of those expect params with respect
to the JAX framework that we're using, we can just import those from their files to avoid
defining the params over and over e.g. instead of `from safejax import deserialize, serialize`,
we can just import `from safejax.flax import deserialize, serialize`, and skip the function
params, so that the only input param that we need to provide are the params themselves.

More in-detail examples can be found at [`examples/`](https://github.com/alvarobartt/safejax/examples)
for `flax`, `dm-haiku`, and `objax`.
11 changes: 8 additions & 3 deletions src/safejax/core/load.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from pathlib import Path
from typing import Union

Expand All @@ -7,7 +8,7 @@
from safetensors.flax import load, load_file

from safejax.typing import ParamsDictLike, PathLike
from safejax.utils import unflatten_dict
from safejax.utils import cast_objax_variables, unflatten_dict


def deserialize(
Expand Down Expand Up @@ -71,10 +72,14 @@ def deserialize(
"`path_or_buf` must be a `bytes` object or a file path (`str` or"
f" `pathlib.Path` object), not {type(path_or_buf)}."
)
if to_var_collection:
try:
return VarCollection(cast_objax_variables(params=decoded_params))
except ValueError as e:
warnings.warn(e)
return decoded_params
if requires_unflattening:
decoded_params = unflatten_dict(params=decoded_params)
if freeze_dict:
return freeze(decoded_params)
if to_var_collection:
return VarCollection(decoded_params)
return decoded_params
6 changes: 5 additions & 1 deletion src/safejax/core/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

def serialize(
params: ParamsDictLike,
include_objax_variables: bool = False,
filename: Union[PathLike, None] = None,
fs: Union[AbstractFileSystem, None] = None,
) -> Union[bytes, PathLike]:
Expand All @@ -23,13 +24,16 @@ def serialize(
Args:
params: A `FrozenDict`, a `Dict` or a `VarCollection` containing the model params.
include_objax_variables: Whether to include `objax.Variable` objects in the serialized model params.
filename: The path to the file where the model params will be saved.
fs: The filesystem to use to save the model params. Defaults to `None`.
Returns:
The serialized model params as a `bytes` object or the path to the file where the model params were saved.
"""
params = flatten_dict(params=params)
params = flatten_dict(
params=params, include_objax_variables=include_objax_variables
)

if filename:
if not isinstance(filename, (str, Path)):
Expand Down
34 changes: 28 additions & 6 deletions src/safejax/objax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,21 @@
from safejax.core.load import deserialize
from safejax.core.save import serialize # noqa: F401
from safejax.typing import PathLike
from safejax.utils import OBJAX_VARIABLE_SEPARATOR

# `objax` params are usually defined as a `VarCollection`, and that's basically a dictionary with
# key-value pairs where the value is either a `BaseVar` or a `StateVar`. The problem is that when
# serializing those variables by default we just keep the value which is a `jnp.DeviceArray`, so we
# need to provide `include_objax_variables=True` to store the variable type names as part of the key
# using `::` as the separator. This is useful when deserializing the params, as we can restore a
# `VarCollection` object instead of a `Dict[str, jnp.DeviceArray]`.
serialize = partial(serialize, include_objax_variables=True)

# `objax` expects either a `Dict[str, jnp.DeviceArray]` or a `VarCollection` as model params
# which means any other type of `Dict` will not work. This is why we need to set `requires_unflattening`
# to `False` and `to_var_collection` to `True` to restore a `VarCollection`, but the later could be skipped.
# which means any other type of `Dict` will not work. The only difference is that `VarCollection` can
# be assigned directly to `.vars()` while `Dict[str, jnp.DeviceArray]` needs to be manually assigned
# when looping over `.vars()`. Ideally, we want to restore the params from a `VarCollection`, that's why
# we've set the `to_var_collection` parameter to `True` by default.
deserialize = partial(deserialize, requires_unflattening=False, to_var_collection=True)


Expand All @@ -22,6 +33,12 @@
def deserialize_with_assignment(filename: PathLike, model_vars: VarCollection) -> None:
"""Deserialize a `VarCollection` from a file and assign it to a `VarCollection` object.
Note:
This function avoid some known issues related to the variable deserialization with `objax`,
since the params are stored in a `VarCollection` object, which contains some `objax.variable`
variables instead of key-tensor pais. So this way we avoid having to restore the `objax.variable`
type per each value.
Args:
filename: Path to the file containing the serialized `VarCollection` as a `Dict[str, jnp.DeviceArray]`.
model_vars: `VarCollection` object to which the deserialized tensors will be assigned.
Expand All @@ -40,7 +57,12 @@ def deserialize_with_assignment(filename: PathLike, model_vars: VarCollection) -
if not filename.exists or not filename.is_file:
raise ValueError(f"`filename` must be a valid file path, not {filename}.")
with safe_open(filename.as_posix(), framework="jax") as f:
for k in f.keys():
if k not in model_vars.keys():
raise ValueError(f"Variable with name {k} not found in model_vars.")
model_vars[k].assign(f.get_tensor(k))
for key in f.keys():
just_key = (
key.split(OBJAX_VARIABLE_SEPARATOR)[0]
if OBJAX_VARIABLE_SEPARATOR in key
else key
)
if just_key not in model_vars.keys():
raise ValueError(f"Variable with name {key} not found in model_vars.")
model_vars[just_key].assign(f.get_tensor(key))
3 changes: 2 additions & 1 deletion src/safejax/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@

JaxDeviceArrayDict = Dict[str, jnp.DeviceArray]
HaikuParams = Dict[str, JaxDeviceArrayDict]
ObjaxParams = Union[VarCollection, Dict[str, Union[BaseVar, StateVar]]]
ObjaxDict = Dict[str, Union[BaseVar, StateVar]]
ObjaxParams = Union[VarCollection, ObjaxDict]
FlaxParams = Union[Dict[str, Union[Dict, JaxDeviceArrayDict]], FrozenDict]

ParamsDictLike = Union[JaxDeviceArrayDict, HaikuParams, ObjaxParams, FlaxParams]
Loading

0 comments on commit 67e77de

Please sign in to comment.