Skip to content

Commit

Permalink
feat(test): add command
Browse files: browse the repository at this point in the history
Closes #1
  • Loading branch information
Caceresenzo committed Mar 26, 2023
1 parent 024dda3 commit 035c5b7
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 4 deletions.
2 changes: 1 addition & 1 deletion crunch_cli/__version__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = 'crunch-cli'
__description__ = 'crunch-cli - CLI of the CrunchDAO Platform'
__version__ = '0.1.1'
__version__ = '0.2.0'
__author__ = 'Enzo CACERES'
__author_email__ = '[email protected]'
__url__ = 'https://github.com/crunchdao/crunch-cli'
1 change: 1 addition & 0 deletions crunch_cli/command/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .clone import clone
from .push import push
from .test import test
3 changes: 0 additions & 3 deletions crunch_cli/command/push.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@
from .. import utils
from .. import constants

session = None
debug = False


def push(
session: requests.Session,
Expand Down
116 changes: 116 additions & 0 deletions crunch_cli/command/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
import tempfile
import importlib
import sys
import logging
import pandas
import coloredlogs
import click

from .. import utils


def ensure_function(module, name: str):
    """Return the callable named ``name`` from ``module`` or abort the command.

    Args:
        module: Object (normally the imported user module) on which the
            handler is looked up.
        name: Attribute name of the required handler.

    Returns:
        The attribute ``module.<name>``.

    Raises:
        click.Abort: If the attribute is missing or is not callable. The
            reason is logged before aborting.
    """
    handler = getattr(module, name, None)

    # A same-named constant (or None) is not a usable handler: reject it here
    # instead of failing later with an opaque "object is not callable" error.
    if not callable(handler):
        logging.error("no `%s` function found", name)
        raise click.Abort()

    return handler


def ensure_tuple(input):
    """Abort the command unless ``input`` is a tuple holding exactly 3 items.

    Args:
        input: Value returned by a user handler.

    Raises:
        click.Abort: If ``input`` is not a tuple, or is a tuple whose length
            is not 3. The reason is logged before aborting.
    """
    problem = None
    if not isinstance(input, tuple):
        problem = "result is not a tuple"
    elif len(input) != 3:
        problem = "result tuple must be of length 3"

    if problem is not None:
        logging.error(problem)
        raise click.Abort()


def ensure_dataframe(input, name: str):
    """Abort the command unless ``input`` is a ``pandas.DataFrame``.

    Args:
        input: Value to validate.
        name: Label used for the value in the error message.

    Raises:
        click.Abort: If ``input`` is not a ``pandas.DataFrame``. The reason
            is logged before aborting.
    """
    if isinstance(input, pandas.DataFrame):
        return

    # Plain %-style template: the logging module interpolates `name` itself,
    # so the string must not be an f-string.
    logging.error("`%s` must be a dataframe", name)
    raise click.Abort()


def read(path: str) -> pandas.DataFrame:
    """Load the file at ``path`` into a dataframe.

    Paths ending in ``.parquet`` are read with ``pandas.read_parquet``; every
    other path is read with ``pandas.read_csv``.
    """
    loader = pandas.read_parquet if path.endswith(".parquet") else pandas.read_csv
    return loader(path)


def write(dataframe: pandas.DataFrame, path: str) -> None:
    """Store ``dataframe`` at ``path``.

    Paths ending in ``.parquet`` are written with ``DataFrame.to_parquet``;
    every other path is written with ``DataFrame.to_csv``.
    """
    is_parquet = path.endswith(".parquet")
    if not is_parquet:
        dataframe.to_csv(path)
        return

    dataframe.to_parquet(path)


def test(
    main_file: str
):
    """Run the user's ``data_process``, ``train`` and ``infer`` handlers locally.

    The file ``main_file`` is executed as a module named ``user_code``, then
    its three handlers are called in sequence on dummy (empty) dataframes
    that are first written to, and read back from, a temporary directory.
    Each handler result is validated; the model and the prediction are
    written into the same temporary directory and logged.

    Args:
        main_file: Path of the python file that defines the handlers.

    Raises:
        click.Abort: If ``main_file`` cannot be loaded as a module, if a
            handler is missing, or if a handler returns an unexpected value.
    """
    # `import importlib` alone does not guarantee that the `util` submodule
    # is loaded, so import it explicitly before using it.
    import importlib.util

    coloredlogs.install(
        level=logging.DEBUG,
        fmt='%(asctime)s %(message)s',
        datefmt='%H:%M:%S',
    )

    utils.change_root()

    logging.info('running local test')
    logging.warning("internet access isn't restricted, no check will be done")
    logging.info("")

    # The context manager removes the directory on exit, including when a
    # check below aborts; the logged paths are only valid during the run.
    with tempfile.TemporaryDirectory(prefix="test-") as tmp_dir:
        logging.info('tmp=%s', tmp_dir)

        x_train_path = os.path.join(tmp_dir, "x_train.csv")
        y_train_path = os.path.join(tmp_dir, "y_train.csv")
        x_test_path = os.path.join(tmp_dir, "x_test.csv")
        model_path = os.path.join(tmp_dir, "model.csv")
        prediction_path = os.path.join(tmp_dir, "prediction.csv")

        spec = importlib.util.spec_from_file_location("user_code", main_file)
        if spec is None or spec.loader is None:
            # No loader could be determined for this path.
            logging.error("cannot load `%s` as a python module", main_file)
            raise click.Abort()

        module = importlib.util.module_from_spec(spec)

        # Let the user code import its sibling modules from the working
        # directory.
        sys.path.insert(0, os.getcwd())
        spec.loader.exec_module(module)

        data_process_handler = ensure_function(module, "data_process")
        train_handler = ensure_function(module, "train")
        infer_handler = ensure_function(module, "infer")

        # No real dataset is used: every input is an empty dataframe that
        # makes a round trip through a csv file.
        dummy = pandas.DataFrame()
        for path in (x_train_path, y_train_path, x_test_path):
            dummy.to_csv(path)

        x_train = read(x_train_path)
        y_train = read(y_train_path)
        x_test = read(x_test_path)

        logging.warning('handler: data_process(%s, %s, %s)', x_train, y_train, x_test)
        result = data_process_handler(x_train, y_train, x_test)
        ensure_tuple(result)

        x_train, y_train, x_test = result
        ensure_dataframe(x_train, "x_train")
        ensure_dataframe(y_train, "y_train")
        ensure_dataframe(x_test, "x_test")

        logging.warning('handler: train(%s, %s)', x_train, y_train)
        model = train_handler(x_train, y_train)
        ensure_dataframe(model, "model")
        write(model, model_path)

        logging.warning('model_path=%s', model_path)
        logging.warning('model=%s', model)

        # NOTE(review): the log line shows `model_path` while the handler
        # receives the in-memory `model` - confirm which one is intended.
        logging.warning('handler: infer(%s, %s)', model_path, x_test)
        prediction = infer_handler(model, x_test)
        ensure_dataframe(prediction, "prediction")
        write(prediction, prediction_path)

        logging.warning('prediction_path=%s', prediction_path)
        logging.warning('prediction=%s', prediction)
10 changes: 10 additions & 0 deletions crunch_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,15 @@ def push(
)


# Registers a `test` sub-command on `cli` that forwards the `-m/--main-file`
# option (default "main.py") to `command.test`.
# Kept as `#` comments: click would surface a docstring as the command help.
@cli.command()
@click.option("-m", "--main-file", default="main.py")
def test(main_file: str):
    command.test(main_file=main_file)


if __name__ == '__main__':
cli()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
click
requests
gitignorefile
coloredlogs

0 comments on commit 035c5b7

Please sign in to comment.