LIP0005
LIP | 5 |
---|---|
Title | Parameter Value Loading |
Author | A. Pronobis |
Status | Draft |
Type | Standard |
Discussion | Issue #29 |
PR | |
Created | Sep 1, 2017 |
LibSPN offers model saving and loading. It is designed to save a model once and then instantiate a whole new SPN graph from the saved data. However, one often wants to save the model multiple times (e.g. every n-th epoch). Similarly, it would be useful to load model parameters without instantiating a new graph or adding new TF operations. Here, we propose the changes necessary to achieve both.
This proposal depends on LIP3, which introduces an interface for parameter assignment, and on LIP2, which introduces parameter snapshots.
We propose a modification to the `Saver` class. The interface should remain:

- `__init__(path, ...)`
- `save(root, ...)`

However, `path` should now accept a replacement field `{}` which, if present, is replaced with a number that is incremented for each new file created. If the replacement field is missing, any existing file is overwritten. The extension in `path` should be optional; if it is missing, the extension `.spn` should be used. Relative paths and home folder expansion (`~`) should be supported.
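The path rules above can be sketched in plain Python, independently of LibSPN. `SavePathTemplate` and its `next_path()` method are hypothetical names used only for illustration, not part of LibSPN's API. The proposal does not say where the counter starts, so starting at 0 for each instance is an assumption.

```python
import os


class SavePathTemplate:
    """Hypothetical helper sketching the proposed Saver path rules.

    Assumption: the number substituted for "{}" starts at 0 and is
    tracked per instance; the proposal leaves the starting value open.
    """

    def __init__(self, path, default_ext=".spn"):
        # Expand "~" and turn relative paths into absolute ones.
        path = os.path.abspath(os.path.expanduser(path))
        # Append the default extension only if the file name has none.
        if not os.path.splitext(path)[1]:
            path += default_ext
        self._template = path
        self._counter = 0

    def next_path(self):
        """Return the path of the next file to be written."""
        if "{}" not in self._template:
            # No replacement field: every save overwrites the same file.
            return self._template
        path = self._template.replace("{}", str(self._counter))
        self._counter += 1
        return path
```

For example, `SavePathTemplate("~/my_model{}")` yields `<home>/my_model0.spn`, then `<home>/my_model1.spn`, while `SavePathTemplate("model.ckpt")` keeps its extension and returns the same path on every call.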
We propose a modification to the `Loader` class. The interface should be:

- `__init__(path, ...)`
- `load(...)` - Instantiate a new SPN graph from the saved data (the existing behavior).
- `load_params(root)` - Load only the parameter values into an existing graph without adding any new TF operations. This assumes that the saved and the instantiated graphs have the same structure (node names do not have to match, only the structure).
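To illustrate what "same structure, different names" could mean for `load_params(root)`, the sketch below matches parameters by their position in a deterministic depth-first traversal and refuses to load when the structure signatures differ. Everything here (`Node`, `traverse`, `structure_signature`, `get_params`, `load_params`) is hypothetical plain Python, not LibSPN's implementation. A real `load_params` would presumably assign values to existing TF variables (e.g. through the assignment interface of LIP3) rather than overwrite Python lists.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass(eq=False)
class Node:
    """Hypothetical stand-in for an SPN node (not LibSPN's node classes)."""
    name: str
    kind: str  # e.g. "Sum", "Product", "Leaf"
    children: List["Node"] = field(default_factory=list)
    weights: Optional[List[float]] = None  # parameter values, if any


def traverse(root):
    """Yield each node once, in deterministic depth-first pre-order."""
    seen = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue  # shared child already visited (SPNs are DAGs)
        seen.add(id(node))
        yield node
        # Push children reversed so they are visited left to right.
        stack.extend(reversed(node.children))


def structure_signature(root):
    """Describe the graph by node kinds, child positions and weight counts.

    Node names are deliberately ignored.
    """
    nodes = list(traverse(root))
    index = {id(n): i for i, n in enumerate(nodes)}
    return [
        (n.kind,
         tuple(index[id(c)] for c in n.children),
         None if n.weights is None else len(n.weights))
        for n in nodes
    ]


def get_params(root):
    """Collect parameter values in traversal order (what a saver would store)."""
    return [list(n.weights) for n in traverse(root) if n.weights is not None]


def load_params(root, saved_signature, saved_params):
    """Overwrite the parameters of an existing graph by structural position."""
    if structure_signature(root) != saved_signature:
        raise ValueError("saved and current graph structures differ")
    targets = [n for n in traverse(root) if n.weights is not None]
    for node, values in zip(targets, saved_params):
        node.weights = list(values)
```

Encoding each node's children as traversal indices makes the signature sensitive to which children are shared in the DAG, not only to node kinds and child counts.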
Furthermore, `path` should again accept a replacement field `{}` which, if present, is replaced with the largest number associated with a matching existing file.
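One possible way to resolve the loader's `{}` field is sketched below. `latest_matching_path` is a hypothetical helper, not LibSPN's API. It assumes a single `{}` field and that the loader applies the same `~` expansion and default `.spn` extension as the saver. Numbers are compared as integers, so `my_model10.spn` is preferred over `my_model9.spn`.

```python
import glob
import os
import re


def latest_matching_path(path, default_ext=".spn"):
    """Hypothetical sketch: resolve "{}" to the largest existing number."""
    path = os.path.abspath(os.path.expanduser(path))
    if not os.path.splitext(path)[1]:
        path += default_ext
    if "{}" not in path:
        return path
    prefix, suffix = path.split("{}", 1)
    # Only accept digits in place of "{}", e.g. reject "my_model_old.spn".
    pattern = re.compile(re.escape(prefix) + r"(\d+)" + re.escape(suffix) + r"\Z")
    best_num, best_path = None, None
    # glob narrows the candidates; the regex checks the exact form.
    for candidate in glob.glob(glob.escape(prefix) + "*" + glob.escape(suffix)):
        match = pattern.match(candidate)
        if match:
            num = int(match.group(1))
            if best_num is None or num > best_num:
                best_num, best_path = num, candidate
    if best_path is None:
        raise FileNotFoundError("no file matches " + path)
    return best_path
```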
Create and initialize an SPN graph:

```python
import libspn as spn

model = spn.Poon11NaiveMixtureModel()
root = model.build()
init = spn.initialize_weights(root)
```
Save the graph multiple times (e.g. every epoch):

```python
saver = spn.Saver("~/my_model{}")
for i in range(num_epochs):
    ...  # train for one epoch
    saver.save(root)  # "{}" makes each call write a new, numbered file
```
Load a new graph instance:

```python
loader = spn.Loader("~/my_model{}")  # "{}" resolves to the highest-numbered file
new_root = loader.load()
```
Load parameters only:

```python
loader.load_params(new_root)  # updates the existing graph; adds no new TF operations
```