Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[docs] Update custom saving and loading #439

Merged
merged 4 commits into from
Aug 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions docs/tutorial_toolbox/saving_and_loading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,48 @@
}
},
"source": [
"You can make your own saving and loading functions easily. Beacause all variables in the model can be easily collected through ``.vars()``. Therefore, saving variables is just transforming these variables to numpy.ndarray and then storing them into the disk. Similarly, to load variables, you just need read the numpy arrays from the disk and then transform these arrays as instances of [Variables](../tutorial_math/variables.ipynb). \n",
"You can make your own saving and loading functions easily.\n",
"\n",
"The only gotcha to pay attention to is to avoid saving duplicated variables. "
"For customizing the saving and loading, users can overwrite ``__save_state__`` and ``__load_state__`` functions.\n",
"\n",
"Here is an example to customize:\n",
"```python\n",
"class YourClass(bp.DynamicSystem):\n",
" def __init__(self):\n",
" self.a = 1\n",
" self.b = bm.random.rand(10)\n",
" self.c = bm.Variable(bm.random.rand(3))\n",
" self.d = bm.var_list([bm.Variable(bm.random.rand(3)),\n",
" bm.Variable(bm.random.rand(3))])\n",
"\n",
" def __save_state__(self) -> dict:\n",
" state_dict = {'a': self.a,\n",
" 'b': self.b,\n",
" 'c': self.c}\n",
" for i, elem in enumerate(self.d):\n",
" state_dict[f'd{i}'] = elem.value\n",
"\n",
" return state_dict\n",
"\n",
" def __load_state__(self, state_dict):\n",
" self.a = state_dict['a']\n",
" self.b = bm.asarray(state_dict['b'])\n",
" self.c = bm.asarray(state_dict['c'])\n",
"\n",
" for i in range(len(self.d)):\n",
" self.d[i].value = bm.asarray(state_dict[f'd{i}'])\n",
"```\n",
"\n",
"\n",
"- ``__save_state__(self)`` function saves the state of the object's variables and returns a dictionary where the keys are the names of the variables and the values are the variables' contents.\n",
"\n",
"- ``__load_state__(self, state_dict: Dict)`` function loads the state of the object's variables from a provided dictionary (``state_dict``). \n",
"At firstly it gets the current variables of the object.\n",
"Then, it determines the intersection of keys from the provided state_dict and the object's variables.\n",
"For each intersecting key, it updates the value of the object's variable with the value from state_dict.\n",
"Finally, returns A tuple containing two lists:\n",
" - ``unexpected_keys``: Keys in state_dict that were not found in the object's variables.\n",
" - ``missing_keys``: Keys that are in the object's variables but were not found in state_dict."
]
}
],
Expand Down