Skip to content

Commit

Permalink
Merge pull request #24 from SubstraFoundation/remove_train_predictions
Browse files Browse the repository at this point in the history
Remove train predictions
  • Loading branch information
jmorel authored Jan 28, 2020
2 parents 1696431 + bcd313c commit f04e88e
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 66 deletions.
10 changes: 4 additions & 6 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,8 @@ import substratools as tools

class DummyAlgo(tools.Algo):
def train(self, X, y, models, rank):
predictions = 0
new_model = None
return predictions, new_model
return new_model

def predict(self, X, model):
predictions = 0
Expand Down Expand Up @@ -84,7 +83,7 @@ __Arguments__
__Returns__


`tuple`: (predictions, model).
`model`: model object.

## predict
```python
Expand Down Expand Up @@ -171,10 +170,9 @@ import substratools as tools

class DummyCompositeAlgo(tools.CompositeAlgo):
def train(self, X, y, head_model, trunk_model, rank):
predictions = 0
new_head_model = None
new_trunk_model = None
return predictions, new_head_model, new_trunk_model
return new_head_model, new_trunk_model

def predict(self, X, head_model, trunk_model):
predictions = 0
Expand Down Expand Up @@ -218,7 +216,7 @@ __Arguments__
__Returns__


`tuple`: (predictions, head_model, trunk_model).
`tuple`: (head_model, trunk_model).

## predict
```python
Expand Down
32 changes: 12 additions & 20 deletions substratools/algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ class Algo(abc.ABC):
class DummyAlgo(tools.Algo):
def train(self, X, y, models, rank):
predictions = 0
new_model = None
return predictions, new_model
return new_model
def predict(self, X, model):
predictions = 0
Expand Down Expand Up @@ -93,7 +92,7 @@ def train(self, X, y, models, rank):
# Returns
tuple: (predictions, model).
model: model object.
"""
raise NotImplementedError

Expand Down Expand Up @@ -123,7 +122,7 @@ def _train_fake_data(self, *args, **kwargs):
replaced by the opener fake data.
By default, it only calls directly `Algo.train()` method. Override this
method if you want to implement a different behaviour.
method if you want to implement a different behavior.
"""
return self.train(*args, **kwargs)

Expand All @@ -135,7 +134,7 @@ def _predict_fake_data(self, *args, **kwargs):
the opener fake data.
By default, it only calls directly `Algo.predict()` method. Override
this method if you want to implement a different behaviour.
this method if you want to implement a different behavior.
"""
return self.predict(*args, **kwargs)

Expand Down Expand Up @@ -208,17 +207,14 @@ def train(self, model_names, rank=0, fake_data=False):
logger.info("launching training task")
method = (self._interface.train if not fake_data else
self._interface._train_fake_data)
pred, model = method(X, y, models, rank)
model = method(X, y, models, rank)

# serialize output model and save it to workspace
logger.info("saving output model to '{}'".format(
self._workspace.output_model_path))
self._interface.save_model(model, self._workspace.output_model_path)

# save predictions
self._opener_wrapper.save_predictions(pred)

return pred, model
return model

def predict(self, model_name, fake_data=False):
"""Predict method wrapper."""
Expand Down Expand Up @@ -362,10 +358,9 @@ class CompositeAlgo(abc.ABC):
class DummyCompositeAlgo(tools.CompositeAlgo):
def train(self, X, y, head_model, trunk_model, rank):
predictions = 0
new_head_model = None
new_trunk_model = None
return predictions, new_head_model, new_trunk_model
return new_head_model, new_trunk_model
def predict(self, X, head_model, trunk_model):
predictions = 0
Expand Down Expand Up @@ -406,7 +401,7 @@ def train(self, X, y, head_model, trunk_model, rank):
# Returns
tuple: (predictions, head_model, trunk_model).
tuple: (head_model, trunk_model).
"""
raise NotImplementedError

Expand Down Expand Up @@ -437,7 +432,7 @@ def _train_fake_data(self, *args, **kwargs):
replaced by the opener fake data.
By default, it only calls directly `Algo.train()` method. Override this
method if you want to implement a different behaviour.
method if you want to implement a different behavior.
"""
return self.train(*args, **kwargs)

Expand All @@ -449,7 +444,7 @@ def _predict_fake_data(self, *args, **kwargs):
the opener fake data.
By default, it only calls directly `Algo.predict()` method. Override
this method if you want to implement a different behaviour.
this method if you want to implement a different behavior.
"""
return self.predict(*args, **kwargs)

Expand Down Expand Up @@ -550,7 +545,7 @@ def train(self, input_head_model_filename=None, input_trunk_model_filename=None,
logger.info("launching training task")
method = (self._interface.train if not fake_data else
self._interface._train_fake_data)
pred, head_model, trunk_model = method(X, y, head_model, trunk_model, rank)
head_model, trunk_model = method(X, y, head_model, trunk_model, rank)

# serialize output head and trunk models and save them to workspace
output_head_model_path = self._workspace.output_head_model_path
Expand All @@ -561,10 +556,7 @@ def train(self, input_head_model_filename=None, input_trunk_model_filename=None,
logger.info("saving output trunk model to '{}'".format(output_trunk_model_path))
self._interface.save_trunk_model(trunk_model, output_trunk_model_path)

# save predictions
self._opener_wrapper.save_predictions(pred)

return pred, head_model, trunk_model
return head_model, trunk_model

def predict(self, input_head_model_filename, input_trunk_model_filename,
fake_data=False):
Expand Down
23 changes: 5 additions & 18 deletions tests/test_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def train(self, X, y, models, rank):
assert isinstance(m, dict)
assert 'value' in m
new_model['value'] += m['value']
new_value = new_model['value']
pred = list(range(new_value, new_value + 3))
return pred, new_model
return new_model

def predict(self, X, model):
pred = model['value']
Expand Down Expand Up @@ -65,8 +63,7 @@ def test_create():
def test_train_no_model():
a = DummyAlgo()
wp = algo.AlgoWrapper(a)
pred, model = wp.train([])
assert pred == [0, 1, 2]
model = wp.train([])
assert model['value'] == 0


Expand All @@ -76,16 +73,14 @@ def test_train_multiple_models(workdir, create_models):
a = DummyAlgo()
wp = algo.AlgoWrapper(a)

pred, model = wp.train(model_filenames)
assert pred == [3, 4, 5]
model = wp.train(model_filenames)
assert model['value'] == 3


def test_train_fake_data():
a = DummyAlgo()
wp = algo.AlgoWrapper(a)
pred, model = wp.train([], fake_data=True)
assert pred == [0, 1, 2]
model = wp.train([], fake_data=True)
assert model['value'] == 0


Expand Down Expand Up @@ -134,10 +129,7 @@ def test_execute_train_multiple_models(workdir, create_models):
model = json.load(f)
assert model['value'] == 3

assert pred_path.exists()
with open(pred_path, 'r') as f:
pred = json.load(f)
assert pred == [3, 4, 5]
assert not pred_path.exists()


def test_execute_predict(workdir, create_models):
Expand All @@ -152,11 +144,6 @@ def test_execute_predict(workdir, create_models):
command.extend(model_filenames)
algo.execute(DummyAlgo(), sysargs=command)
assert output_model_path.exists()
assert pred_path.exists()
with open(pred_path, 'r') as f:
pred = json.load(f)
assert pred == [3, 4, 5]
pred_path.unlink()

# do predict on output model
assert not pred_path.exists()
Expand Down
19 changes: 5 additions & 14 deletions tests/test_compositealgo.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ def train(self, X, y, head_model, trunk_model, rank):
new_head_model['value'] += 1
new_trunk_model['value'] -= 1

# get predictions
pred = list(range(new_head_model['value'], new_trunk_model['value']))

return pred, new_head_model, new_trunk_model
return new_head_model, new_trunk_model

def predict(self, X, head_model, trunk_model):
pred = list(range(head_model['value'], trunk_model['value']))
Expand Down Expand Up @@ -97,24 +94,21 @@ def test_create():


def test_train_no_model(dummy_wrapper):
pred, head_model, trunk_model = dummy_wrapper.train()
assert pred == []
head_model, trunk_model = dummy_wrapper.train()
assert head_model['value'] == 1
assert trunk_model['value'] == -1


def test_train_input_head_trunk_models(create_models, dummy_wrapper):
_, _, head_filename, trunk_filename = create_models

pred, head_model, trunk_model = dummy_wrapper.train(head_filename, trunk_filename)
assert pred == []
head_model, trunk_model = dummy_wrapper.train(head_filename, trunk_filename)
assert head_model['value'] == 2
assert trunk_model['value'] == -2


def test_train_fake_data(dummy_wrapper):
pred, head_model, trunk_model = dummy_wrapper.train(fake_data=True)
assert pred == []
head_model, trunk_model = dummy_wrapper.train(fake_data=True)
assert head_model['value'] == 1
assert trunk_model['value'] == -1

Expand Down Expand Up @@ -188,10 +182,7 @@ def test_execute_train_multiple_models(workdir, create_models):
trunk_model = json.load(f)
assert trunk_model['value'] == -2

assert pred_path.exists()
with open(pred_path, 'r') as f:
pred = json.load(f)
assert pred == []
assert not pred_path.exists()


def test_execute_predict(workdir, create_models):
Expand Down
12 changes: 4 additions & 8 deletions tests/test_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ def save_predictions(self, pred, path):
class DummyAlgo(Algo):
def train(self, X, y, models, rank):
total = sum([m['i'] for m in models])
pred = {'sum': len(models)}
new_model = {'i': len(models) + 1, 'total': total}
return pred, new_model
return new_model

def predict(self, X, model):
return {'sum': model['i']}
Expand All @@ -70,26 +69,23 @@ def test_workflow(workdir, dummy_opener):
models_path = algo_wp._workspace.input_models_folder_path

# loop 1 (no input)
pred, model = algo_wp.train([])
assert pred == {'sum': 0}
model = algo_wp.train([])
assert model == {'i': 1, 'total': 0}
output_model_path = os.path.join(models_path, 'model')
assert os.path.exists(output_model_path)

# loop 2 (one model as input)
model_1_name = 'model1'
shutil.move(output_model_path, os.path.join(models_path, model_1_name))
pred, model = algo_wp.train([model_1_name])
assert pred == {'sum': 1}
model = algo_wp.train([model_1_name])
assert model == {'i': 2, 'total': 1}
output_model_path = os.path.join(models_path, 'model')
assert os.path.exists(output_model_path)

# loop 3 (two models as input)
model_2_name = 'model2'
shutil.move(output_model_path, os.path.join(models_path, model_2_name))
pred, model = algo_wp.train([model_1_name, model_2_name])
assert pred == {'sum': 2}
model = algo_wp.train([model_1_name, model_2_name])
assert model == {'i': 3, 'total': 3}
output_model_path = os.path.join(models_path, 'model')
assert os.path.exists(output_model_path)
Expand Down

0 comments on commit f04e88e

Please sign in to comment.