Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgjs committed Dec 5, 2023
1 parent fcc744d commit 4f8d900
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 2 deletions.
6 changes: 5 additions & 1 deletion src/autora/doc/pipelines/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from timeit import default_timer as timer
from typing import List

import jsonlines
import mlflow
Expand All @@ -18,7 +19,9 @@


@app.command()
def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts) -> None:
def predict(
data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id: InstructionPrompts
) -> List[str]:
run = mlflow.active_run()

sys_prompt = SYS[sys_id]
Expand Down Expand Up @@ -51,6 +54,7 @@ def predict(data_file: str, model_path: str, sys_id: SystemPrompts, instruc_id:
total_tokens = sum([len(token) for token in tokens])
mlflow.log_metric("total_tokens", total_tokens)
mlflow.log_metric("tokens/sec", total_tokens / pred_time)
return predictions


@app.command()
Expand Down
1 change: 0 additions & 1 deletion src/autora/doc/runtime/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ class SystemPrompts(Enum):


class InstructionPrompts(Enum):
SYS_1 = "SYS_1"
INSTR_SWEETP_1 = "INSTR_SWEETP_1"


Expand Down
15 changes: 15 additions & 0 deletions tests/test_main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from pathlib import Path

from autora.doc.pipelines.main import predict
from autora.doc.runtime.prompts import InstructionPrompts, SystemPrompts

# dummy HF model for testing
TEST_HF_MODEL = "hf-internal-testing/tiny-random-FalconForCausalLM"


def test_predict() -> None:
data = Path(__file__).parent.joinpath("../data/data.jsonl").resolve()
outputs = predict(str(data), TEST_HF_MODEL, SystemPrompts.SYS_1, InstructionPrompts.INSTR_SWEETP_1)
assert len(outputs) == 3, "Expected 3 outputs"
for output in outputs:
assert len(output) > 0, "Expected non-empty output"
File renamed without changes.

0 comments on commit 4f8d900

Please sign in to comment.