diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index 97e30d7..e481a92 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -766,7 +766,9 @@ def _keep_in_algo(k, v): masked_grads = grads else: masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()} - flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads) + flat_grads, treedef = jax.tree_util.tree_flatten( + masked_grads, is_leaf=lambda x: x is None + ) flat_opt_state = jax.tree_util.tree_map( lambda _, x: x # pylint:disable=g-long-lambda if isinstance(x, (np.ndarray, jax.Array))