
Commit

feat: created eval_on_prompts_file() to run and compare multiple prompts on single data file input
anujsinha3 committed Jan 30, 2024
1 parent d20c6d9 commit 905ae5f
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion src/autora/doc/pipelines/main.py
@@ -10,7 +10,8 @@
from nltk.translate.meteor_score import single_meteor_score

from autora.doc.runtime.predict_hf import Predictor
from autora.doc.runtime.prompts import PROMPTS, PromptIds
from autora.doc.runtime.prompts import PROMPTS, PromptBuilder, PromptIds
from autora.doc.util import get_eval_result_from_prediction, get_prompts_from_file, load_file

app = typer.Typer()
logging.basicConfig(
@@ -47,6 +48,44 @@ def evaluate_documentation(predictions: List[str], references: List[str]) -> Tup
return (bleu, meteor)


@app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
def eval_on_prompts_file(
data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
prompts_file: str = typer.Argument(..., help="JSON file with a list of dictionary of prompts"),
param: List[str] = typer.Option(
[], help="Additional float parameters to pass to the model as name=float pairs"
),
) -> List[Dict[str, str]]:
import mlflow

results_list = []

mlflow.autolog()
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
run = mlflow.active_run()

prompts_list = get_prompts_from_file(prompts_file)

if run is None:
run = mlflow.start_run()
with run:
logger.info(f"Active run_id: {run.info.run_id}")
logger.info(f"running predict with {data_file}")
logger.info(f"model path: {model_path}")
mlflow.log_params(param_dict)
mlflow.log_param("model_path", model_path)
mlflow.log_param("data_file", data_file)
predictor = Predictor(model_path)
for i in range(len(prompts_list)):
logger.info(f"Starting to run model on prompt {i}: {prompts_list[i]}")
prediction_with_scores = eval_prompt(data_file, predictor, prompts_list[i], param_dict)
logger.info(f"Model run completed on prompt {i}: {prompts_list[i]}")
eval_result = get_eval_result_from_prediction(prediction_with_scores, prompts_list[i])
results_list.append(eval_result)
return results_list


@app.command(help="Evaluate model on a data file")
def eval(
    data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
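
Note on the --param option: eval_on_prompts_file turns the name=float pairs into a dict of floats with the comprehension shown in the diff. A minimal worked example of that parsing, using hypothetical parameter names and values:

# Worked example of the name=float parsing used by eval_on_prompts_file.
# The parameter names and values here are hypothetical.
param = ["temperature=0.7", "max_length=500.0"]
param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
assert param_dict == {"temperature": 0.7, "max_length": 500.0}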

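A minimal sketch of how the new command could be exercised through Typer's test runner, assuming Typer's default command naming (underscores become hyphens); the file paths and the temperature value are hypothetical placeholders, not values from this commit, and the run loads the default meta-llama/Llama-2-7b-chat-hf model, so it is an expensive end-to-end call rather than a unit test.

# Sketch only: invokes the new command via Typer's CliRunner.
# File paths and the temperature value are hypothetical placeholders.
from typer.testing import CliRunner

from autora.doc.pipelines.main import app

runner = CliRunner()
result = runner.invoke(
    app,
    [
        "eval-on-prompts-file",        # Typer's default name for eval_on_prompts_file
        "data/data.jsonl",             # positional data_file (hypothetical path)
        "data/prompts.json",           # positional prompts_file (hypothetical path)
        "--param", "temperature=0.7",  # parsed into {"temperature": 0.7}
    ],
)
print(result.exit_code)
print(result.output)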