Skip to content

Commit

Permalink
Fix prompttest
Browse files Browse the repository at this point in the history
  • Loading branch information
steventkrawczyk committed Jul 17, 2023
1 parent 8523246 commit 51f1165
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 4 deletions.
6 changes: 4 additions & 2 deletions prompttools/prompttest/prompttest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,12 @@ def prompttest(
This enables developers to create a prompt test suite from their evaluations.
"""

model_arguments["experiment_classname"] = experiment_classname

def prompttest_decorator(eval_fn: Callable):
@wraps(eval_fn)
def runs_test():
if prompt_template_file and user_input_file:
return run_prompt_template_test_from_files(
experiment_classname,
model_name,
metric_name,
eval_fn,
Expand All @@ -65,6 +64,7 @@ def runs_test():
)
elif prompt_template and user_input:
return run_prompt_template_test(
experiment_classname,
model_name,
metric_name,
eval_fn,
Expand All @@ -78,6 +78,7 @@ def runs_test():
)
elif system_prompt_file and human_messages_file:
return run_system_prompt_test_from_files(
experiment_classname,
model_name,
metric_name,
eval_fn,
Expand All @@ -91,6 +92,7 @@ def runs_test():
)
elif system_prompt and human_messages:
return run_system_prompt_test(
experiment_classname,
model_name,
metric_name,
eval_fn,
Expand Down
6 changes: 5 additions & 1 deletion prompttools/prompttest/runner/prompt_template_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,20 +51,22 @@ def read(

@staticmethod
def _get_harness(
    experiment_classname,
    model_name: str,
    prompt_template: str,
    user_inputs: List[Dict[str, str]],
    model_args: Dict[str, object],
) -> PromptTemplateExperimentationHarness:
    """Construct an experimentation harness for a single prompt template.

    Wraps the lone ``prompt_template`` in a list (the harness accepts many
    templates) and forwards the experiment class, model name, user inputs,
    and model arguments unchanged.

    Args:
        experiment_classname: Experiment class passed through to the harness.
        model_name: Name of the model under test.
        prompt_template: The single prompt template to evaluate.
        user_inputs: One dict of template variables per test case.
        model_args: Extra keyword arguments forwarded as ``model_arguments``.
    """
    templates = [prompt_template]
    return PromptTemplateExperimentationHarness(
        experiment_classname,
        model_name,
        templates,
        user_inputs,
        model_arguments=model_args,
    )


# Shared module-level runner instance; presumably reused by the run_* helper
# functions below rather than constructing a fresh runner per call — verify
# against the (truncated) callers.
prompt_template_test_runner = PromptTemplateTestRunner()


def run_prompt_template_test(
experiment_classname,
model_name: str,
metric_name: str,
eval_fn: Callable,
Expand Down Expand Up @@ -115,6 +117,7 @@ def run_prompt_template_test(


def run_prompt_template_test_from_files(
experiment_classname,
model_name: str,
metric_name: str,
eval_fn: Callable,
Expand All @@ -133,6 +136,7 @@ def run_prompt_template_test_from_files(
prompt_template_file, user_input_file
)
return run_prompt_template_test(
experiment_classname,
model_name,
metric_name,
eval_fn,
Expand Down
1 change: 1 addition & 0 deletions prompttools/prompttest/runner/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def rank(self, key: str, metric_name: str, is_average: bool) -> Dict[str, float]

@staticmethod
def _get_harness(
experiment_classname,
model_name: str,
prompt_template: str,
user_inputs: List[Dict[str, str]],
Expand Down
5 changes: 4 additions & 1 deletion prompttools/prompttest/runner/system_prompt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ def read(

@staticmethod
def _get_harness(
    experiment_classname,
    model_name: str,
    system_prompt: str,
    human_messages: List[str],
    model_args: Dict[str, object],
) -> SystemPromptExperimentationHarness:
    """Construct an experimentation harness for a single system prompt.

    The single ``system_prompt`` is wrapped in a list (the harness accepts
    several); everything else is forwarded as-is.

    Args:
        experiment_classname: Experiment class passed through to the harness.
        model_name: Name of the model under test.
        system_prompt: The system prompt to evaluate.
        human_messages: Human-turn messages to pair with the system prompt.
        model_args: Extra keyword arguments forwarded as ``model_arguments``.
    """
    prompts = [system_prompt]
    return SystemPromptExperimentationHarness(
        experiment_classname,
        model_name,
        prompts,
        human_messages,
        model_arguments=model_args,
    )


Expand Down Expand Up @@ -113,6 +114,7 @@ def run_system_prompt_test(


def run_system_prompt_test_from_files(
experiment_classname,
model_name: str,
metric_name: str,
eval_fn: Callable,
Expand All @@ -131,6 +133,7 @@ def run_system_prompt_test_from_files(
system_prompt_file, human_messages_file
)
return run_system_prompt_test(
experiment_classname,
model_name,
metric_name,
eval_fn,
Expand Down

0 comments on commit 51f1165

Please sign in to comment.