Skip to content

Commit

Permalink
feat!(test): remove call of data_process
Browse files Browse the repository at this point in the history
  • Loading branch information
Caceresenzo committed May 4, 2023
1 parent 0bcdc87 commit 78bd750
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 31 deletions.
2 changes: 1 addition & 1 deletion crunch/__version__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
__title__ = 'crunch-cli'
__description__ = 'crunch-cli - CLI of the CrunchDAO Platform'
__version__ = '0.16.0'
__version__ = '0.17.0'
__author__ = 'Enzo CACERES'
__author_email__ = '[email protected]'
__url__ = 'https://github.com/crunchdao/crunch-cli'
21 changes: 0 additions & 21 deletions crunch/ensure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,12 @@ def is_function(module, name: str):
return getattr(module, name)


def is_tuple_3(input):
if not isinstance(input, tuple):
logging.error("result is not a tuple")
raise click.Abort()

if len(input) != 3:
logging.error("result tuple must be of length 3")
raise click.Abort()


def is_dataframe(input, name: str):
if not isinstance(input, pandas.DataFrame):
logging.error(f"`%s` must be a dataframe", name)
raise click.Abort()


def return_data_process(result) -> typing.Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]:
is_tuple_3(result)

x_train, y_train, x_test = result
is_dataframe(x_train, "x_train")
is_dataframe(y_train, "y_train")
is_dataframe(x_test, "x_test")

return result


def return_infer(result) -> pandas.DataFrame:
is_dataframe(result, "prediction")

Expand Down
12 changes: 3 additions & 9 deletions crunch/tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def run(
memory_before = _get_process_memory()
start = time.time()

data_process_handler = ensure.is_function(module, "data_process")
train_handler = ensure.is_function(module, "train")
infer_handler = ensure.is_function(module, "infer")

Expand Down Expand Up @@ -77,19 +76,14 @@ def run(
logging.warn('---')
logging.warn('loop: moon=%s train=%s (%s/%s)', moon, train, index + 1, len(moons))

x_train_loop = x_train[x_train.index < moon - embargo].reset_index()
y_train_loop = y_train[y_train.index < moon - embargo].reset_index()
x_test_loop = x_test[x_test.index == moon].reset_index()

logging.warn('handler: data_process(%s, %s, %s)', x_train_path, y_train_path, x_test_path)
result = data_process_handler(x_train_loop, y_train_loop, x_test_loop)
x_train_loop, y_train_loop, x_test_loop = ensure.return_data_process(result)

if train:
logging.warn('handler: train(%s, %s, %s)', x_train_path, y_train_path, model_directory_path)
x_train_loop = x_train[x_train.index < moon - embargo].reset_index()
y_train_loop = y_train[y_train.index < moon - embargo].reset_index()
train_handler(x_train_loop, y_train_loop, model_directory_path)

logging.warn('handler: infer(%s, %s)', model_directory_path, x_test_path)
x_test_loop = x_test[x_test.index == moon].reset_index()
prediction = infer_handler(model_directory_path, x_test_loop)
prediction = ensure.return_infer(prediction)

Expand Down

0 comments on commit 78bd750

Please sign in to comment.