From 035c5b73d18d69010c343b7e631cd2eb8f7fab0f Mon Sep 17 00:00:00 2001
From: Caceresenzo
Date: Sun, 26 Mar 2023 19:34:42 +0200
Subject: [PATCH] feat(test): add command

Closes #1
---
 crunch_cli/__version__.py      |   2 +-
 crunch_cli/command/__init__.py |   1 +
 crunch_cli/command/push.py     |   3 -
 crunch_cli/command/test.py     | 130 +++++++++++++++++++++++++++++++++
 crunch_cli/main.py             |  10 +++
 requirements.txt               |   2 ++
 6 files changed, 144 insertions(+), 4 deletions(-)
 create mode 100644 crunch_cli/command/test.py

diff --git a/crunch_cli/__version__.py b/crunch_cli/__version__.py
index 95fb198..e2d2147 100644
--- a/crunch_cli/__version__.py
+++ b/crunch_cli/__version__.py
@@ -1,6 +1,6 @@
 __title__ = 'crunch-cli'
 __description__ = 'crunch-cli - CLI of the CrunchDAO Platform'
-__version__ = '0.1.1'
+__version__ = '0.2.0'
 __author__ = 'Enzo CACERES'
 __author_email__ = 'enzo.caceres@crunchdao.com'
 __url__ = 'https://github.com/crunchdao/crunch-cli'
diff --git a/crunch_cli/command/__init__.py b/crunch_cli/command/__init__.py
index 16caf79..3ceba8e 100644
--- a/crunch_cli/command/__init__.py
+++ b/crunch_cli/command/__init__.py
@@ -1,2 +1,3 @@
 from .clone import clone
 from .push import push
+from .test import test
diff --git a/crunch_cli/command/push.py b/crunch_cli/command/push.py
index 5d98395..271a902 100644
--- a/crunch_cli/command/push.py
+++ b/crunch_cli/command/push.py
@@ -7,9 +7,6 @@ from .. import utils
 from .. import constants
 
 
-session = None
-debug = False
-
 
 def push(
     session: requests.Session,
diff --git a/crunch_cli/command/test.py b/crunch_cli/command/test.py
new file mode 100644
index 0000000..af7033a
--- /dev/null
+++ b/crunch_cli/command/test.py
@@ -0,0 +1,130 @@
+import os
+import tempfile
+import importlib.util
+import sys
+import logging
+import pandas
+import coloredlogs
+import click
+
+from .. import utils
+
+
+def ensure_function(module, name: str):
+    """Return attribute `name` of `module`, aborting if it is missing."""
+    if not hasattr(module, name):
+        logging.error("no `%s` function found", name)
+        raise click.Abort()
+
+    return getattr(module, name)
+
+
+def ensure_tuple(input):
+    """Abort unless `input` is a tuple of exactly three elements."""
+    if not isinstance(input, tuple):
+        logging.error("result is not a tuple")
+        raise click.Abort()
+
+    if len(input) != 3:
+        logging.error("result tuple must be of length 3")
+        raise click.Abort()
+
+
+def ensure_dataframe(input, name: str):
+    """Abort unless `input` is a pandas DataFrame; `name` labels the error."""
+    if not isinstance(input, pandas.DataFrame):
+        logging.error("`%s` must be a dataframe", name)
+        raise click.Abort()
+
+
+def read(path: str) -> pandas.DataFrame:
+    """Load `path` as parquet when it ends with `.parquet`, else as CSV."""
+    if path.endswith(".parquet"):
+        return pandas.read_parquet(path)
+    return pandas.read_csv(path)
+
+
+def write(dataframe: pandas.DataFrame, path: str) -> None:
+    """Save `dataframe` to `path` as parquet or CSV, chosen by extension."""
+    if path.endswith(".parquet"):
+        dataframe.to_parquet(path)
+    else:
+        dataframe.to_csv(path)
+
+
+def test(
+    main_file: str
+):
+    """Run `data_process`, `train` and `infer` from `main_file` locally.
+
+    The handlers are fed placeholder dataframes read back from empty CSV
+    files; each result is type-checked and a mismatch aborts the command.
+    """
+    coloredlogs.install(
+        level=logging.DEBUG,
+        fmt='%(asctime)s %(message)s',
+        datefmt='%H:%M:%S',
+    )
+
+    utils.change_root()
+
+    logging.info('running local test')
+    logging.warning("internet access isn't restricted, no check will be done")
+    logging.info("")
+
+    with tempfile.TemporaryDirectory(prefix="test-") as tmp:
+        logging.info('tmp=%s', tmp)
+
+        x_train_path = os.path.join(tmp, "x_train.csv")
+        y_train_path = os.path.join(tmp, "y_train.csv")
+        x_test_path = os.path.join(tmp, "x_test.csv")
+        model_path = os.path.join(tmp, "model.csv")
+        prediction_path = os.path.join(tmp, "prediction.csv")
+
+        spec = importlib.util.spec_from_file_location("user_code", main_file)
+        module = importlib.util.module_from_spec(spec)
+
+        # presumably lets the user code import modules from the working dir
+        sys.path.insert(0, os.getcwd())
+        spec.loader.exec_module(module)
+
+        data_process_handler = ensure_function(module, "data_process")
+        train_handler = ensure_function(module, "train")
+        infer_handler = ensure_function(module, "infer")
+
+        # Placeholder branch: always writes empty CSV files as inputs.
+        if True:
+            dummy = pandas.DataFrame()
+            for path in [x_train_path, y_train_path, x_test_path]:
+                dummy.to_csv(path)
+
+        x_train = read(x_train_path)
+        y_train = read(y_train_path)
+        x_test = read(x_test_path)
+
+        logging.warning('handler: data_process(%s, %s, %s)', x_train, y_train, x_test)
+        result = data_process_handler(x_train, y_train, x_test)
+        ensure_tuple(result)
+
+        x_train, y_train, x_test = result
+        ensure_dataframe(x_train, "x_train")
+        ensure_dataframe(y_train, "y_train")
+        ensure_dataframe(x_test, "x_test")
+
+        logging.warning('handler: train(%s, %s)', x_train, y_train)
+        model = train_handler(x_train, y_train)
+        ensure_dataframe(model, "model")
+        write(model, model_path)
+
+        logging.warning('model_path=%s', model_path)
+        logging.warning('model=%s', model)
+
+        # NOTE(review): the log shows model_path but infer() receives the
+        # in-memory model -- confirm which one is intended.
+        logging.warning('handler: infer(%s, %s)', model_path, x_test)
+        prediction = infer_handler(model, x_test)
+        ensure_dataframe(prediction, "prediction")
+        write(prediction, prediction_path)
+
+        logging.warning('prediction_path=%s', prediction_path)
+        logging.warning('prediction=%s', prediction)
diff --git a/crunch_cli/main.py b/crunch_cli/main.py
index fc2e656..cb4ce63 100644
--- a/crunch_cli/main.py
+++ b/crunch_cli/main.py
@@ -52,5 +52,15 @@ def push(
     )
 
 
+@cli.command()
+@click.option("-m", "--main-file", default="main.py")
+def test(
+    main_file: str
+):
+    command.test(
+        main_file=main_file
+    )
+
+
 if __name__ == '__main__':
     cli()
diff --git a/requirements.txt b/requirements.txt
index d9e53de..13b2176 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,5 @@
 click
 requests
 gitignorefile
+coloredlogs
+pandas