Skip to content

Commit

Permalink
Merge pull request #18 from SubstraFoundation/composite-separate-methods
Browse files Browse the repository at this point in the history
Algo: update serializers method
  • Loading branch information
samlesu authored Dec 12, 2019
2 parents fa50eb4 + c026d17 commit 1c7bec7
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 24 deletions.
92 changes: 84 additions & 8 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,10 @@ following abstract methods:

- `CompositeAlgo.train()`
- `CompositeAlgo.predict()`
- `CompositeAlgo.load_model()`
- `CompositeAlgo.save_model()`
- `CompositeAlgo.load_head_model()`
- `CompositeAlgo.save_head_model()`
- `CompositeAlgo.load_trunk_model()`
- `CompositeAlgo.save_trunk_model()`

To add a composite algo to the Substra Platform, the line
`tools.algo.execute(<CompositeAlgoClass>())` must be added to the main of the algo
Expand All @@ -178,10 +180,16 @@ class DummyCompositeAlgo(tools.CompositeAlgo):
predictions = 0
return predictions

def load_model(self, path):
def load_head_model(self, path):
return json.load(path)

def save_model(self, model, path):
def save_head_model(self, model, path):
json.dump(model, path)

def load_trunk_model(self, path):
return json.load(path)

def save_trunk_model(self, model, path):
json.dump(model, path)


Expand All @@ -203,8 +211,8 @@ __Arguments__

- __X__: training data samples loaded with `Opener.get_X()`.
- __y__: training data samples labels loaded with `Opener.get_y()`.
- __head_model__: head model loaded with `CompositeAlgo.load_model()` (may be None).
- __trunk_model__: trunk model loaded with `CompositeAlgo.load_model()` (may be None).
- __head_model__: head model loaded with `CompositeAlgo.load_head_model()` (may be None).
- __trunk_model__: trunk model loaded with `CompositeAlgo.load_trunk_model()` (may be None).
- __rank__: rank of the training task.

__Returns__
Expand All @@ -225,14 +233,82 @@ __Arguments__


- __X__: testing data samples loaded with `Opener.get_X()`.
- __head_model__: head model loaded with `CompositeAlgo.load_model()`.
- __trunk_model__: trunk model loaded with `CompositeAlgo.load_model()`.
- __head_model__: head model loaded with `CompositeAlgo.load_head_model()`.
- __trunk_model__: trunk model loaded with `CompositeAlgo.load_trunk_model()`.

__Returns__


`predictions`: predictions object.

## load_head_model
```python
CompositeAlgo.load_head_model(self, path)
```
Deserialize the head model from a file.

This method will be executed before the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to deserialize the head model.

__Arguments__


- __path__: path of the model to load.

__Returns__


`model`: the deserialized model object.

## save_head_model
```python
CompositeAlgo.save_head_model(self, model, path)
```
Serialize the head model to a file.

This method will be executed after the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to save the head model.

__Arguments__


- __path__: path of file to write.
- __model__: the model to serialize.

## load_trunk_model
```python
CompositeAlgo.load_trunk_model(self, path)
```
Deserialize the trunk model from a file.

This method will be executed before the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to deserialize the trunk model.

__Arguments__


- __path__: path of the model to load.

__Returns__


`model`: the deserialized model object.

## save_trunk_model
```python
CompositeAlgo.save_trunk_model(self, model, path)
```
Serialize the trunk model to a file.

This method will be executed after the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to save the trunk model.

__Arguments__


- __path__: path of file to write.
- __model__: the model to serialize.

# AggregateAlgo
```python
AggregateAlgo(self, /, *args, **kwargs)
Expand Down
124 changes: 110 additions & 14 deletions substratools/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ def save_model(self, model, path):

class AlgoWrapper(object):
"""Algo wrapper to execute an algo instance on the platform."""
_INTERFACE_CLASS = Algo
_DEFAULT_WORKSPACE_CLASS = AlgoWorkspace

def __init__(self, interface, workspace=None, opener_wrapper=None):
assert isinstance(interface, Algo)
assert isinstance(interface, self._INTERFACE_CLASS)
self._workspace = workspace or self._DEFAULT_WORKSPACE_CLASS()
self._opener_wrapper = opener_wrapper or \
opener.load_from_module(workspace=self._workspace)
Expand Down Expand Up @@ -334,16 +335,18 @@ def _predict(args):
return parser


class CompositeAlgo(Algo):
class CompositeAlgo(abc.ABC):
"""Abstract base class for defining a composite algo to run on the platform.
To define a new composite algo script, subclass this class and implement the
following abstract methods:
- #CompositeAlgo.train()
- #CompositeAlgo.predict()
- #CompositeAlgo.load_model()
- #CompositeAlgo.save_model()
- #CompositeAlgo.load_head_model()
- #CompositeAlgo.save_head_model()
- #CompositeAlgo.load_trunk_model()
- #CompositeAlgo.save_trunk_model()
To add a composite algo to the Substra Platform, the line
`tools.algo.execute(<CompositeAlgoClass>())` must be added to the main of the algo
Expand All @@ -368,10 +371,16 @@ def predict(self, X, head_model, trunk_model):
predictions = 0
return predictions
def load_model(self, path):
def load_head_model(self, path):
return json.load(path)
def save_model(self, model, path):
def save_head_model(self, model, path):
json.dump(model, path)
def load_trunk_model(self, path):
return json.load(path)
def save_trunk_model(self, model, path):
json.dump(model, path)
Expand All @@ -391,8 +400,8 @@ def train(self, X, y, head_model, trunk_model, rank):
X: training data samples loaded with `Opener.get_X()`.
y: training data samples labels loaded with `Opener.get_y()`.
head_model: head model loaded with `CompositeAlgo.load_model()` (may be None).
trunk_model: trunk model loaded with `CompositeAlgo.load_model()` (may be None).
head_model: head model loaded with `CompositeAlgo.load_head_model()` (may be None).
trunk_model: trunk model loaded with `CompositeAlgo.load_trunk_model()` (may be None).
rank: rank of the training task.
# Returns
Expand All @@ -411,18 +420,105 @@ def predict(self, X, head_model, trunk_model):
# Arguments
X: testing data samples loaded with `Opener.get_X()`.
head_model: head model loaded with `CompositeAlgo.load_model()`.
trunk_model: trunk model loaded with `CompositeAlgo.load_model()`.
head_model: head model loaded with `CompositeAlgo.load_head_model()`.
trunk_model: trunk model loaded with `CompositeAlgo.load_trunk_model()`.
# Returns
predictions: predictions object.
"""
raise NotImplementedError

def _train_fake_data(self, *args, **kwargs):
"""Train model fake data mode.
This method is called by the algorithm wrapper when the fake data mode
is enabled. In fake data mode, `X` and `y` input args have been
replaced by the opener fake data.
By default, it directly calls the `CompositeAlgo.train()` method. Override
this method if you want to implement a different behaviour.
"""
return self.train(*args, **kwargs)

def _predict_fake_data(self, *args, **kwargs):
"""Predict model fake data mode.
This method is called by the algorithm wrapper when the fake data mode
is enabled. In fake data mode, `X` input arg has been replaced by
the opener fake data.
By default, it directly calls the `CompositeAlgo.predict()` method. Override
this method if you want to implement a different behaviour.
"""
return self.predict(*args, **kwargs)

@abc.abstractmethod
def load_head_model(self, path):
"""Deserialize head model from file.
This method will be executed before the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to deserialize the head model.
# Arguments
path: path of the model to load.
# Returns
model: the deserialized model object.
"""
raise NotImplementedError

@abc.abstractmethod
def save_head_model(self, model, path):
"""Serialize head model in file.
This method will be executed after the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to save the head model.
# Arguments
path: path of file to write.
model: the model to serialize.
"""
raise NotImplementedError

@abc.abstractmethod
def load_trunk_model(self, path):
"""Deserialize trunk model from file.
This method will be executed before the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to deserialize the trunk model.
# Arguments
path: path of the model to load.
# Returns
model: the deserialized model object.
"""
raise NotImplementedError

@abc.abstractmethod
def save_trunk_model(self, model, path):
"""Serialize trunk model in file.
This method will be executed after the calls to the methods
`CompositeAlgo.train()` and `CompositeAlgo.predict()` to save the trunk model.
# Arguments
path: path of file to write.
model: the model to serialize.
"""
raise NotImplementedError


class CompositeAlgoWrapper(AlgoWrapper):
"""Algo wrapper to execute an algo instance on the platform."""
_INTERFACE_CLASS = CompositeAlgo
_DEFAULT_WORKSPACE_CLASS = CompositeAlgoWorkspace

def _load_head_trunk_models(self, head_filename, trunk_filename):
Expand All @@ -431,12 +527,12 @@ def _load_head_trunk_models(self, head_filename, trunk_filename):
if head_filename:
head_model_path = os.path.join(self._workspace.input_models_folder_path,
head_filename)
head_model = self._interface.load_model(head_model_path)
head_model = self._interface.load_head_model(head_model_path)
trunk_model = None
if trunk_filename:
trunk_model_path = os.path.join(self._workspace.input_models_folder_path,
trunk_filename)
trunk_model = self._interface.load_model(trunk_model_path)
trunk_model = self._interface.load_trunk_model(trunk_model_path)
return head_model, trunk_model

def train(self, input_head_model_filename=None, input_trunk_model_filename=None,
Expand All @@ -459,11 +555,11 @@ def train(self, input_head_model_filename=None, input_trunk_model_filename=None,
# serialize output head and trunk models and save them to workspace
output_head_model_path = self._workspace.output_head_model_path
logger.info("saving output head model to '{}'".format(output_head_model_path))
self._interface.save_model(head_model, output_head_model_path)
self._interface.save_head_model(head_model, output_head_model_path)

output_trunk_model_path = self._workspace.output_trunk_model_path
logger.info("saving output trunk model to '{}'".format(output_trunk_model_path))
self._interface.save_model(trunk_model, output_trunk_model_path)
self._interface.save_trunk_model(trunk_model, output_trunk_model_path)

# save predictions
self._opener_wrapper.save_predictions(pred)
Expand Down
16 changes: 14 additions & 2 deletions tests/test_compositealgo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,23 @@ def predict(self, X, head_model, trunk_model):
pred = list(range(head_model['value'], trunk_model['value']))
return pred

def load_model(self, path):
def load_head_model(self, path):
return self._load_model(path)

def save_head_model(self, model, path):
return self._save_model(model, path)

def load_trunk_model(self, path):
return self._load_model(path)

def save_trunk_model(self, model, path):
return self._save_model(model, path)

def _load_model(self, path):
with open(path, 'r') as f:
return json.load(f)

def save_model(self, model, path):
def _save_model(self, model, path):
with open(path, 'w') as f:
json.dump(model, f)

Expand Down

0 comments on commit 1c7bec7

Please sign in to comment.