From 4f8d9005e47caca493ee839ad6d93f87e7d7959a Mon Sep 17 00:00:00 2001
From: Carlos Garcia Jurado Suarez
Date: Tue, 5 Dec 2023 11:19:51 -0800
Subject: [PATCH] add tests

---
 src/autora/doc/pipelines/main.py      |  6 +++++-
 src/autora/doc/runtime/prompts.py     |  1 -
 tests/test_main.py                    | 15 +++++++++++++++
 tests/{test.py => test_predict_hf.py} |  0
 4 files changed, 20 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_main.py
 rename tests/{test.py => test_predict_hf.py} (100%)

diff --git a/src/autora/doc/pipelines/main.py b/src/autora/doc/pipelines/main.py
index b74bf4b..aacb809 100644
--- a/src/autora/doc/pipelines/main.py
+++ b/src/autora/doc/pipelines/main.py
@@ -1,5 +1,6 @@
 import logging
 from timeit import default_timer as timer
+from typing import List
 
 import jsonlines
 import mlflow
@@ -18,7 +19,9 @@
 
 
 @app.command()
-def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> None:
+def predict(
+    data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts
+) -> List[str]:
     run = mlflow.active_run()
 
     sys_prompt = SYS[sys_id]
@@ -51,6 +54,7 @@ def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id:
         total_tokens = sum([len(token) for token in tokens])
         mlflow.log_metric("total_tokens", total_tokens)
         mlflow.log_metric("tokens/sec", total_tokens / pred_time)
+    return predictions
 
 
 @app.command()
diff --git a/src/autora/doc/runtime/prompts.py b/src/autora/doc/runtime/prompts.py
index 5875127..19f905a 100644
--- a/src/autora/doc/runtime/prompts.py
+++ b/src/autora/doc/runtime/prompts.py
@@ -29,7 +29,6 @@ class SystemPrompts(Enum):
 
 
 class InstructionPrompts(Enum):
-    SYS_1 = "SYS_1"
     INSTR_SWEETP_1 = "INSTR_SWEETP_1"
 
 
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..e2d12d3
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,15 @@
+from pathlib import Path
+
+from autora.doc.pipelines.main import predict
+from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts
+
+# dummy HF model for testing
+TEST_HF_MODEL = "hf-internal-testing/tiny-random-FalconForCausalLM"
+
+
+def test_predict() -> None:
+    data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
+    outputs = predict(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
+    assert len(outputs) == 3, "Expected 3 outputs"
+    for output in outputs:
+        assert len(output) > 0, "Expected non-empty output"
diff --git a/tests/test.py b/tests/test_predict_hf.py
similarity index 100%
rename from tests/test.py
rename to tests/test_predict_hf.py
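
For reviewers who want to exercise the new return value outside pytest, a minimal sketch, assuming the package is installed and a local copy of the data/data.jsonl fixture exists; the model id is the same dummy Hugging Face checkpoint the test uses:

    from autora.doc.pipelines.main import predict
    from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts

    # predict() now returns the generated strings (in addition to logging
    # MLflow metrics), so callers can inspect the outputs directly.
    predictions = predict(
        "data/data.jsonl",  # assumed local path to the test fixture
        "hf-internal-testing/tiny-random-FalconForCausalLM",  # dummy model, as in the test
        SystemPrompts.SYS_1,
        InstructionPrompts.INSTR_SWEETP_1,
    )
    for prediction in predictions:
        print(prediction)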