diff --git a/prompttools/prompttest/prompttest.py b/prompttools/prompttest/prompttest.py
index b551e1cd..332f06de 100644
--- a/prompttools/prompttest/prompttest.py
+++ b/prompttools/prompttest/prompttest.py
@@ -45,13 +45,12 @@ def prompttest(
     This enables developers to create a prompt test suite from their evaluations.
     """
-    model_arguments["experiment_classname"] = experiment_classname
-
     def prompttest_decorator(eval_fn: Callable):
         @wraps(eval_fn)
         def runs_test():
             if prompt_template_file and user_input_file:
                 return run_prompt_template_test_from_files(
+                    experiment_classname,
                     model_name,
                     metric_name,
                     eval_fn,
@@ -65,6 +64,7 @@ def runs_test():
                 )
             elif prompt_template and user_input:
                 return run_prompt_template_test(
+                    experiment_classname,
                     model_name,
                     metric_name,
                     eval_fn,
@@ -78,6 +78,7 @@ def runs_test():
                 )
             elif system_prompt_file and human_messages_file:
                 return run_system_prompt_test_from_files(
+                    experiment_classname,
                     model_name,
                     metric_name,
                     eval_fn,
@@ -91,6 +92,7 @@ def runs_test():
                 )
             elif system_prompt and human_messages:
                 return run_system_prompt_test(
+                    experiment_classname,
                     model_name,
                     metric_name,
                     eval_fn,
diff --git a/prompttools/prompttest/runner/prompt_template_runner.py b/prompttools/prompttest/runner/prompt_template_runner.py
index 371b8a70..37926ea0 100644
--- a/prompttools/prompttest/runner/prompt_template_runner.py
+++ b/prompttools/prompttest/runner/prompt_template_runner.py
@@ -51,13 +51,14 @@ def read(
 
     @staticmethod
     def _get_harness(
+        experiment_classname,
         model_name: str,
         prompt_template: str,
         user_inputs: List[Dict[str, str]],
         model_args: Dict[str, object],
     ) -> PromptTemplateExperimentationHarness:
         return PromptTemplateExperimentationHarness(
-            model_name, [prompt_template], user_inputs, model_arguments=model_args
+            experiment_classname, model_name, [prompt_template], user_inputs, model_arguments=model_args
         )
 
 
@@ -65,6 +66,7 @@ def _get_harness(
 
 
 def run_prompt_template_test(
+    experiment_classname,
     model_name: str,
     metric_name: str,
     eval_fn: Callable,
@@ -115,6 +117,7 @@ def run_prompt_template_test(
 
 
 def run_prompt_template_test_from_files(
+    experiment_classname,
     model_name: str,
     metric_name: str,
     eval_fn: Callable,
@@ -133,6 +136,7 @@ def run_prompt_template_test_from_files(
         prompt_template_file, user_input_file
     )
     return run_prompt_template_test(
+        experiment_classname,
         model_name,
         metric_name,
         eval_fn,
diff --git a/prompttools/prompttest/runner/runner.py b/prompttools/prompttest/runner/runner.py
index 8ec0a466..df1eb44f 100644
--- a/prompttools/prompttest/runner/runner.py
+++ b/prompttools/prompttest/runner/runner.py
@@ -51,6 +51,7 @@ def rank(self, key: str, metric_name: str, is_average: bool) -> Dict[str, float]
 
     @staticmethod
     def _get_harness(
+        experiment_classname,
         model_name: str,
         prompt_template: str,
         user_inputs: List[Dict[str, str]],
diff --git a/prompttools/prompttest/runner/system_prompt_runner.py b/prompttools/prompttest/runner/system_prompt_runner.py
index 262e9362..bdf79ce1 100644
--- a/prompttools/prompttest/runner/system_prompt_runner.py
+++ b/prompttools/prompttest/runner/system_prompt_runner.py
@@ -49,13 +49,14 @@ def read(
 
     @staticmethod
     def _get_harness(
+        experiment_classname,
         model_name: str,
         system_prompt: str,
         human_messages: List[str],
         model_args: Dict[str, object],
     ) -> SystemPromptExperimentationHarness:
         return SystemPromptExperimentationHarness(
-            model_name, [system_prompt], human_messages, model_arguments=model_args
+            experiment_classname, model_name, [system_prompt], human_messages, model_arguments=model_args
         )
 
 
@@ -113,6 +114,7 @@ def run_system_prompt_test(
 
 
 def run_system_prompt_test_from_files(
+    experiment_classname,
     model_name: str,
     metric_name: str,
     eval_fn: Callable,
@@ -131,6 +133,7 @@ def run_system_prompt_test_from_files(
         system_prompt_file, human_messages_file
     )
     return run_system_prompt_test(
+        experiment_classname,
        model_name,
        metric_name,
        eval_fn,
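
For reviewers, a minimal usage sketch (not part of this diff) of what the change means for callers: the experiment class is now threaded through the runners as an explicit leading argument instead of being stuffed into model_arguments. The parameter names experiment_classname, model_name, metric_name, prompt_template, and user_input come from the hunks above; the import paths, OpenAIChatExperiment as the experiment class, and the eval function's signature are assumptions for illustration only.

# Hypothetical sketch; only experiment_classname's position as the new
# leading/keyword argument is taken from the diff itself.
from prompttools.experiment import OpenAIChatExperiment  # assumed experiment class
from prompttools.prompttest.prompttest import prompttest  # import path is an assumption


@prompttest(
    experiment_classname=OpenAIChatExperiment,  # now forwarded to the runner/harness
    model_name="gpt-3.5-turbo",
    metric_name="is_nonempty",
    prompt_template="Answer the question: {{input}}",
    user_input=[{"input": "Who was the first president of the USA?"}],
)
def evaluate(response: str) -> float:
    # Toy metric; the eval function's exact signature is assumed here.
    return 1.0 if response.strip() else 0.0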