diff --git a/crunch/__version__.py b/crunch/__version__.py index 6a1e045..3fd8277 100644 --- a/crunch/__version__.py +++ b/crunch/__version__.py @@ -1,6 +1,6 @@ __title__ = 'crunch-cli' __description__ = 'crunch-cli - CLI of the CrunchDAO Platform' -__version__ = '0.16.0' +__version__ = '0.17.0' __author__ = 'Enzo CACERES' __author_email__ = 'enzo.caceres@crunchdao.com' __url__ = 'https://github.com/crunchdao/crunch-cli' diff --git a/crunch/ensure.py b/crunch/ensure.py index 1ff405e..d56f5f3 100644 --- a/crunch/ensure.py +++ b/crunch/ensure.py @@ -12,33 +12,12 @@ def is_function(module, name: str): return getattr(module, name) -def is_tuple_3(input): - if not isinstance(input, tuple): - logging.error("result is not a tuple") - raise click.Abort() - - if len(input) != 3: - logging.error("result tuple must be of length 3") - raise click.Abort() - - def is_dataframe(input, name: str): if not isinstance(input, pandas.DataFrame): logging.error(f"`%s` must be a dataframe", name) raise click.Abort() -def return_data_process(result) -> typing.Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]: - is_tuple_3(result) - - x_train, y_train, x_test = result - is_dataframe(x_train, "x_train") - is_dataframe(y_train, "y_train") - is_dataframe(x_test, "x_test") - - return result - - def return_infer(result) -> pandas.DataFrame: is_dataframe(result, "prediction") diff --git a/crunch/tester.py b/crunch/tester.py index f0404dc..c4d9646 100644 --- a/crunch/tester.py +++ b/crunch/tester.py @@ -45,7 +45,6 @@ def run( memory_before = _get_process_memory() start = time.time() - data_process_handler = ensure.is_function(module, "data_process") train_handler = ensure.is_function(module, "train") infer_handler = ensure.is_function(module, "infer") @@ -77,19 +76,14 @@ def run( logging.warn('---') logging.warn('loop: moon=%s train=%s (%s/%s)', moon, train, index + 1, len(moons)) - x_train_loop = x_train[x_train.index < moon - embargo].reset_index() - y_train_loop = y_train[y_train.index < moon - embargo].reset_index() - x_test_loop = x_test[x_test.index == moon].reset_index() - - logging.warn('handler: data_process(%s, %s, %s)', x_train_path, y_train_path, x_test_path) - result = data_process_handler(x_train_loop, y_train_loop, x_test_loop) - x_train_loop, y_train_loop, x_test_loop = ensure.return_data_process(result) - if train: logging.warn('handler: train(%s, %s, %s)', x_train_path, y_train_path, model_directory_path) + x_train_loop = x_train[x_train.index < moon - embargo].reset_index() + y_train_loop = y_train[y_train.index < moon - embargo].reset_index() train_handler(x_train_loop, y_train_loop, model_directory_path) logging.warn('handler: infer(%s, %s)', model_directory_path, x_test_path) + x_test_loop = x_test[x_test.index == moon].reset_index() prediction = infer_handler(model_directory_path, x_test_loop) prediction = ensure.return_infer(prediction)