Skip to content

Commit

Permalink
Merge pull request #439 from Routhleck/docs
Browse files Browse the repository at this point in the history
[docs] Update custom saving and loading
  • Loading branch information
chaoming0625 authored Aug 21, 2023
2 parents 25502e8 + a903ddd commit e9c0298
Showing 1 changed file with 41 additions and 2 deletions.
43 changes: 41 additions & 2 deletions docs/tutorial_toolbox/saving_and_loading.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,48 @@
}
},
"source": [
"You can make your own saving and loading functions easily. Beacause all variables in the model can be easily collected through ``.vars()``. Therefore, saving variables is just transforming these variables to numpy.ndarray and then storing them into the disk. Similarly, to load variables, you just need read the numpy arrays from the disk and then transform these arrays as instances of [Variables](../tutorial_math/variables.ipynb). \n",
"You can make your own saving and loading functions easily.\n",
"\n",
"The only gotcha to pay attention to is to avoid saving duplicated variables. "
"For customizing the saving and loading, users can overwrite ``__save_state__`` and ``__load_state__`` functions.\n",
"\n",
"Here is an example to customize:\n",
"```python\n",
"class YourClass(bp.DynamicSystem):\n",
" def __init__(self):\n",
" self.a = 1\n",
" self.b = bm.random.rand(10)\n",
" self.c = bm.Variable(bm.random.rand(3))\n",
" self.d = bm.var_list([bm.Variable(bm.random.rand(3)),\n",
" bm.Variable(bm.random.rand(3))])\n",
"\n",
" def __save_state__(self) -> dict:\n",
" state_dict = {'a': self.a,\n",
" 'b': self.b,\n",
" 'c': self.c}\n",
" for i, elem in enumerate(self.d):\n",
" state_dict[f'd{i}'] = elem.value\n",
"\n",
" return state_dict\n",
"\n",
" def __load_state__(self, state_dict):\n",
" self.a = state_dict['a']\n",
" self.b = bm.asarray(state_dict['b'])\n",
" self.c = bm.asarray(state_dict['c'])\n",
"\n",
" for i in range(len(self.d)):\n",
" self.d[i].value = bm.asarray(state_dict[f'd{i}'])\n",
"```\n",
"\n",
"\n",
"- ``__save_state__(self)`` function saves the state of the object's variables and returns a dictionary where the keys are the names of the variables and the values are the variables' contents.\n",
"\n",
"- ``__load_state__(self, state_dict: Dict)`` function loads the state of the object's variables from a provided dictionary (``state_dict``). \n",
"At firstly it gets the current variables of the object.\n",
"Then, it determines the intersection of keys from the provided state_dict and the object's variables.\n",
"For each intersecting key, it updates the value of the object's variable with the value from state_dict.\n",
"Finally, returns A tuple containing two lists:\n",
" - ``unexpected_keys``: Keys in state_dict that were not found in the object's variables.\n",
" - ``missing_keys``: Keys that are in the object's variables but were not found in state_dict."
]
}
],
Expand Down

0 comments on commit e9c0298

Please sign in to comment.