From bcd313cd4a18ea7f15792908d6bde27bbe5ec09c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=A9my=20Morel?= Date: Mon, 27 Jan 2020 15:08:00 +0100 Subject: [PATCH] remove train predictions --- docs/api.md | 10 ++++------ substratools/algo.py | 32 ++++++++++++-------------------- tests/test_algo.py | 23 +++++------------------ tests/test_compositealgo.py | 19 +++++-------------- tests/test_workflow.py | 12 ++++-------- 5 files changed, 30 insertions(+), 66 deletions(-) diff --git a/docs/api.md b/docs/api.md index f220a02..29bea7f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -27,9 +27,8 @@ import substratools as tools class DummyAlgo(tools.Algo): def train(self, X, y, models, rank): - predictions = 0 new_model = None - return predictions, new_model + return new_model def predict(self, X, model): predictions = 0 @@ -84,7 +83,7 @@ __Arguments__ __Returns__ -`tuple`: (predictions, model). +`model`: model object. ## predict ```python @@ -171,10 +170,9 @@ import substratools as tools class DummyCompositeAlgo(tools.CompositeAlgo): def train(self, X, y, head_model, trunk_model, rank): - predictions = 0 new_head_model = None new_trunk_model = None - return predictions, new_head_model, new_trunk_model + return new_head_model, new_trunk_model def predict(self, X, head_model, trunk_model): predictions = 0 @@ -218,7 +216,7 @@ __Arguments__ __Returns__ -`tuple`: (predictions, head_model, trunk_model). +`tuple`: (head_model, trunk_model). ## predict ```python diff --git a/substratools/algo.py b/substratools/algo.py index 1314308..dcfa723 100644 --- a/substratools/algo.py +++ b/substratools/algo.py @@ -40,9 +40,8 @@ class Algo(abc.ABC): class DummyAlgo(tools.Algo): def train(self, X, y, models, rank): - predictions = 0 new_model = None - return predictions, new_model + return new_model def predict(self, X, model): predictions = 0 @@ -93,7 +92,7 @@ def train(self, X, y, models, rank): # Returns - tuple: (predictions, model). + model: model object. """ raise NotImplementedError @@ -123,7 +122,7 @@ def _train_fake_data(self, *args, **kwargs): replaced by the opener fake data. By default, it only calls directly `Algo.train()` method. Override this - method if you want to implement a different behaviour. + method if you want to implement a different behavior. """ return self.train(*args, **kwargs) @@ -135,7 +134,7 @@ def _predict_fake_data(self, *args, **kwargs): the opener fake data. By default, it only calls directly `Algo.predict()` method. Override - this method if you want to implement a different behaviour. + this method if you want to implement a different behavior. """ return self.predict(*args, **kwargs) @@ -208,17 +207,14 @@ def train(self, model_names, rank=0, fake_data=False): logger.info("launching training task") method = (self._interface.train if not fake_data else self._interface._train_fake_data) - pred, model = method(X, y, models, rank) + model = method(X, y, models, rank) # serialize output model and save it to workspace logger.info("saving output model to '{}'".format( self._workspace.output_model_path)) self._interface.save_model(model, self._workspace.output_model_path) - # save predictions - self._opener_wrapper.save_predictions(pred) - - return pred, model + return model def predict(self, model_name, fake_data=False): """Predict method wrapper.""" @@ -362,10 +358,9 @@ class CompositeAlgo(abc.ABC): class DummyCompositeAlgo(tools.CompositeAlgo): def train(self, X, y, head_model, trunk_model, rank): - predictions = 0 new_head_model = None new_trunk_model = None - return predictions, new_head_model, new_trunk_model + return new_head_model, new_trunk_model def predict(self, X, head_model, trunk_model): predictions = 0 @@ -406,7 +401,7 @@ def train(self, X, y, head_model, trunk_model, rank): # Returns - tuple: (predictions, head_model, trunk_model). + tuple: (head_model, trunk_model). """ raise NotImplementedError @@ -437,7 +432,7 @@ def _train_fake_data(self, *args, **kwargs): replaced by the opener fake data. By default, it only calls directly `Algo.train()` method. Override this - method if you want to implement a different behaviour. + method if you want to implement a different behavior. """ return self.train(*args, **kwargs) @@ -449,7 +444,7 @@ def _predict_fake_data(self, *args, **kwargs): the opener fake data. By default, it only calls directly `Algo.predict()` method. Override - this method if you want to implement a different behaviour. + this method if you want to implement a different behavior. """ return self.predict(*args, **kwargs) @@ -550,7 +545,7 @@ def train(self, input_head_model_filename=None, input_trunk_model_filename=None, logger.info("launching training task") method = (self._interface.train if not fake_data else self._interface._train_fake_data) - pred, head_model, trunk_model = method(X, y, head_model, trunk_model, rank) + head_model, trunk_model = method(X, y, head_model, trunk_model, rank) # serialize output head and trunk models and save them to workspace output_head_model_path = self._workspace.output_head_model_path @@ -561,10 +556,7 @@ def train(self, input_head_model_filename=None, input_trunk_model_filename=None, logger.info("saving output trunk model to '{}'".format(output_trunk_model_path)) self._interface.save_trunk_model(trunk_model, output_trunk_model_path) - # save predictions - self._opener_wrapper.save_predictions(pred) - - return pred, head_model, trunk_model + return head_model, trunk_model def predict(self, input_head_model_filename, input_trunk_model_filename, fake_data=False): diff --git a/tests/test_algo.py b/tests/test_algo.py index 815d2bc..bea813c 100644 --- a/tests/test_algo.py +++ b/tests/test_algo.py @@ -19,9 +19,7 @@ def train(self, X, y, models, rank): assert isinstance(m, dict) assert 'value' in m new_model['value'] += m['value'] - new_value = new_model['value'] - pred = list(range(new_value, new_value + 3)) - return pred, new_model + return new_model def predict(self, X, model): pred = model['value'] @@ -65,8 +63,7 @@ def test_create(): def test_train_no_model(): a = DummyAlgo() wp = algo.AlgoWrapper(a) - pred, model = wp.train([]) - assert pred == [0, 1, 2] + model = wp.train([]) assert model['value'] == 0 @@ -76,16 +73,14 @@ def test_train_multiple_models(workdir, create_models): a = DummyAlgo() wp = algo.AlgoWrapper(a) - pred, model = wp.train(model_filenames) - assert pred == [3, 4, 5] + model = wp.train(model_filenames) assert model['value'] == 3 def test_train_fake_data(): a = DummyAlgo() wp = algo.AlgoWrapper(a) - pred, model = wp.train([], fake_data=True) - assert pred == [0, 1, 2] + model = wp.train([], fake_data=True) assert model['value'] == 0 @@ -134,10 +129,7 @@ def test_execute_train_multiple_models(workdir, create_models): model = json.load(f) assert model['value'] == 3 - assert pred_path.exists() - with open(pred_path, 'r') as f: - pred = json.load(f) - assert pred == [3, 4, 5] + assert not pred_path.exists() def test_execute_predict(workdir, create_models): @@ -152,11 +144,6 @@ def test_execute_predict(workdir, create_models): command.extend(model_filenames) algo.execute(DummyAlgo(), sysargs=command) assert output_model_path.exists() - assert pred_path.exists() - with open(pred_path, 'r') as f: - pred = json.load(f) - assert pred == [3, 4, 5] - pred_path.unlink() # do predict on output model assert not pred_path.exists() diff --git a/tests/test_compositealgo.py b/tests/test_compositealgo.py index 3bae273..fef9c04 100644 --- a/tests/test_compositealgo.py +++ b/tests/test_compositealgo.py @@ -27,10 +27,7 @@ def train(self, X, y, head_model, trunk_model, rank): new_head_model['value'] += 1 new_trunk_model['value'] -= 1 - # get predictions - pred = list(range(new_head_model['value'], new_trunk_model['value'])) - - return pred, new_head_model, new_trunk_model + return new_head_model, new_trunk_model def predict(self, X, head_model, trunk_model): pred = list(range(head_model['value'], trunk_model['value'])) @@ -97,8 +94,7 @@ def test_create(): def test_train_no_model(dummy_wrapper): - pred, head_model, trunk_model = dummy_wrapper.train() - assert pred == [] + head_model, trunk_model = dummy_wrapper.train() assert head_model['value'] == 1 assert trunk_model['value'] == -1 @@ -106,15 +102,13 @@ def test_train_no_model(dummy_wrapper): def test_train_input_head_trunk_models(create_models, dummy_wrapper): _, _, head_filename, trunk_filename = create_models - pred, head_model, trunk_model = dummy_wrapper.train(head_filename, trunk_filename) - assert pred == [] + head_model, trunk_model = dummy_wrapper.train(head_filename, trunk_filename) assert head_model['value'] == 2 assert trunk_model['value'] == -2 def test_train_fake_data(dummy_wrapper): - pred, head_model, trunk_model = dummy_wrapper.train(fake_data=True) - assert pred == [] + head_model, trunk_model = dummy_wrapper.train(fake_data=True) assert head_model['value'] == 1 assert trunk_model['value'] == -1 @@ -188,10 +182,7 @@ def test_execute_train_multiple_models(workdir, create_models): trunk_model = json.load(f) assert trunk_model['value'] == -2 - assert pred_path.exists() - with open(pred_path, 'r') as f: - pred = json.load(f) - assert pred == [] + assert not pred_path.exists() def test_execute_predict(workdir, create_models): diff --git a/tests/test_workflow.py b/tests/test_workflow.py index a8ed862..0680338 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -43,9 +43,8 @@ def save_predictions(self, pred, path): class DummyAlgo(Algo): def train(self, X, y, models, rank): total = sum([m['i'] for m in models]) - pred = {'sum': len(models)} new_model = {'i': len(models) + 1, 'total': total} - return pred, new_model + return new_model def predict(self, X, model): return {'sum': model['i']} @@ -70,8 +69,7 @@ def test_workflow(workdir, dummy_opener): models_path = algo_wp._workspace.input_models_folder_path # loop 1 (no input) - pred, model = algo_wp.train([]) - assert pred == {'sum': 0} + model = algo_wp.train([]) assert model == {'i': 1, 'total': 0} output_model_path = os.path.join(models_path, 'model') assert os.path.exists(output_model_path) @@ -79,8 +77,7 @@ def test_workflow(workdir, dummy_opener): # loop 2 (one model as input) model_1_name = 'model1' shutil.move(output_model_path, os.path.join(models_path, model_1_name)) - pred, model = algo_wp.train([model_1_name]) - assert pred == {'sum': 1} + model = algo_wp.train([model_1_name]) assert model == {'i': 2, 'total': 1} output_model_path = os.path.join(models_path, 'model') assert os.path.exists(output_model_path) @@ -88,8 +85,7 @@ def test_workflow(workdir, dummy_opener): # loop 3 (two models as input) model_2_name = 'model2' shutil.move(output_model_path, os.path.join(models_path, model_2_name)) - pred, model = algo_wp.train([model_1_name, model_2_name]) - assert pred == {'sum': 2} + model = algo_wp.train([model_1_name, model_2_name]) assert model == {'i': 3, 'total': 3} output_model_path = os.path.join(models_path, 'model') assert os.path.exists(output_model_path)