diff --git a/src/autora/doc/pipelines/main.py b/src/autora/doc/pipelines/main.py
index 2c122c7..f9ae7c8 100644
--- a/src/autora/doc/pipelines/main.py
+++ b/src/autora/doc/pipelines/main.py
@@ -10,7 +10,8 @@
 from nltk.translate.meteor_score import single_meteor_score
 
 from autora.doc.runtime.predict_hf import Predictor
-from autora.doc.runtime.prompts import PROMPTS, PromptIds
+from autora.doc.runtime.prompts import PROMPTS, PromptBuilder, PromptIds
+from autora.doc.util import get_eval_result_from_prediction, get_prompts_from_file, load_file
 
 app = typer.Typer()
 logging.basicConfig(
@@ -47,6 +48,44 @@ def evaluate_documentation(predictions: List[str], references: List[str]) -> Tup
     return (bleu, meteor)
 
 
+@app.command(help="Evaluate a model for code-to-documentation generation for all prompts in the prompts_file")
+def eval_on_prompts_file(
+    data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
+    model_path: str = typer.Option("meta-llama/Llama-2-7b-chat-hf", help="Path to HF model"),
+    prompts_file: str = typer.Argument(..., help="JSON file with a list of prompt dictionaries"),
+    param: List[str] = typer.Option(
+        [], help="Additional float parameters to pass to the model as name=float pairs"
+    ),
+) -> List[Dict[str, str]]:
+    import mlflow
+
+    results_list = []
+
+    mlflow.autolog()
+    param_dict = {pair[0]: float(pair[1]) for pair in [pair.split("=") for pair in param]}
+    run = mlflow.active_run()
+
+    prompts_list = get_prompts_from_file(prompts_file)
+
+    if run is None:
+        run = mlflow.start_run()
+    with run:
+        logger.info(f"Active run_id: {run.info.run_id}")
+        logger.info(f"running predict with {data_file}")
+        logger.info(f"model path: {model_path}")
+        mlflow.log_params(param_dict)
+        mlflow.log_param("model_path", model_path)
+        mlflow.log_param("data_file", data_file)
+        predictor = Predictor(model_path)
+        for i in range(len(prompts_list)):
+            logger.info(f"Starting to run model on prompt {i}: {prompts_list[i]}")
+            prediction_with_scores = eval_prompt(data_file, predictor, prompts_list[i], param_dict)
+            logger.info(f"Model run completed on prompt {i}: {prompts_list[i]}")
+            eval_result = get_eval_result_from_prediction(prediction_with_scores, prompts_list[i])
+            results_list.append(eval_result)
+    return results_list
+
+
 @app.command(help="Evaluate model on a data file")
 def eval(
     data_file: str = typer.Argument(..., help="JSONL Data file to evaluate on"),
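
For reference, a minimal sketch of how the new command might be exercised. The prompts-file layout, the dictionary keys, and the command-line invocation below are assumptions inferred from the help strings and from Typer's default command naming; they are not defined by this change, and eval_prompt / get_prompts_from_file / get_eval_result_from_prediction are implemented elsewhere in the PR.

# Illustrative only: write a prompts file matching the help text
# ("JSON file with a list of prompt dictionaries"). The keys used
# here are hypothetical, not prescribed by this diff.
import json

prompts = [
    {"sys": "You are a documentation assistant.", "instr": "Describe what this AutoRA experiment code does."},
    {"sys": "You are a technical writer.", "instr": "Summarize the pipeline defined in this code."},
]
with open("prompts.json", "w") as f:
    json.dump(prompts, f)

# Typer exposes commands with underscores converted to dashes, so -- assuming
# main.py runs the app when executed as a module -- an invocation could look
# like this (paths and parameter values are hypothetical):
#
#   python -m autora.doc.pipelines.main eval-on-prompts-file \
#       data/data.jsonl prompts.json --param temperature=0.01 --param max_length=500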