diff --git a/examples/notebooks/LangChainRouterChainExperiment.ipynb b/examples/notebooks/LangChainRouterChainExperiment.ipynb index 1d0b7bb2..56105687 100644 --- a/examples/notebooks/LangChainRouterChainExperiment.ipynb +++ b/examples/notebooks/LangChainRouterChainExperiment.ipynb @@ -56,7 +56,7 @@ " \"name\": \"restaurant\",\n", " \"description\": \"Good for building a restaurant\",\n", " \"prompt_template\": restaurant_template,\n", - " }\n", + " },\n", " ],\n", "]\n", "\n", @@ -77,7 +77,7 @@ "\n", "experiment.evaluate(\"similar_to_expected\", semantic_similarity, expected=[expected] * 2)\n", "\n", - "experiment.visualize()\n" + "experiment.visualize()" ] } ], diff --git a/examples/notebooks/PaLM2Experiment.ipynb b/examples/notebooks/PaLM2Experiment.ipynb index 2900c3a8..6480d096 100644 --- a/examples/notebooks/PaLM2Experiment.ipynb +++ b/examples/notebooks/PaLM2Experiment.ipynb @@ -80,7 +80,7 @@ "\n", "\n", "palm.configure(api_key=os.environ[\"GOOGLE_PALM_API_KEY\"])\n", - "[m.name for m in palm.list_models() if 'generateText' in m.supported_generation_methods]" + "[m.name for m in palm.list_models() if \"generateText\" in m.supported_generation_methods]" ] }, { @@ -208,7 +208,7 @@ "metadata": {}, "outputs": [], "source": [ - "from prompttools.utils import semantic_similarity\n" + "from prompttools.utils import semantic_similarity" ] }, { diff --git a/prompttools/experiment/experiments/google_palm_experiment.py b/prompttools/experiment/experiments/google_palm_experiment.py index c9c1f1b8..bac292cb 100644 --- a/prompttools/experiment/experiments/google_palm_experiment.py +++ b/prompttools/experiment/experiments/google_palm_experiment.py @@ -9,6 +9,7 @@ except ImportError: palm = None +from prompttools.selector.prompt_selector import PromptSelector from prompttools.mock.mock import mock_palm_completion_fn from .experiment import Experiment from typing import Optional, Union, Iterable @@ -77,6 +78,14 @@ def __init__( else: self.completion_fn = 
self.palm_completion_fn palm.configure(api_key=os.environ["GOOGLE_PALM_API_KEY"]) + + # If we are using a prompt selector, we need to + # render the prompts from the selector + if isinstance(prompt[0], PromptSelector): + prompt = [selector.for_palm() for selector in prompt] + else: + prompt = prompt + self.all_args = dict( model=model, prompt=prompt, @@ -96,7 +105,7 @@ def palm_completion_fn(self, **input_args): @staticmethod def _extract_responses(completion_response: "palm.text.text_types.Completion") -> list[str]: # `# completion_response.result` will return the top response - return [candidate["output"] for candidate in completion_response.candidates] + return [candidate["output"] for candidate in completion_response.candidates][0] def _get_model_names(self): return [combo["model"] for combo in self.argument_combos] diff --git a/prompttools/playground/data_loader.py b/prompttools/playground/data_loader.py index 3467b4c1..4a96eeb6 100644 --- a/prompttools/playground/data_loader.py +++ b/prompttools/playground/data_loader.py @@ -57,6 +57,9 @@ def load_data( experiment = EXPERIMENTS[model_type]([model], selectors, temperature=[temperature]) elif model_type == "Anthropic": experiment = EXPERIMENTS[model_type]([model], selectors, temperature=[temperature]) + elif model_type == "Google PaLM": + experiment = EXPERIMENTS[model_type]([model], selectors, temperature=[temperature]) + return experiment.to_pandas_df() diff --git a/prompttools/playground/requirements.txt b/prompttools/playground/requirements.txt index e0a46cfd..ffbcaf82 100644 --- a/prompttools/playground/requirements.txt +++ b/prompttools/playground/requirements.txt @@ -3,4 +3,5 @@ jinja2 huggingface_hub llama-cpp-python anthropic -pyperclip \ No newline at end of file +pyperclip +google-generativeai \ No newline at end of file diff --git a/prompttools/selector/prompt_selector.py b/prompttools/selector/prompt_selector.py index a617e348..6e6b3f5c 100644 --- a/prompttools/selector/prompt_selector.py +++ 
b/prompttools/selector/prompt_selector.py @@ -17,6 +17,11 @@ RESPONSE: """ +PALM_TEMPLATE = """{instruction} + +{user_input} +""" + LLAMA_TEMPLATE = """[INST] <<SYS>> {instruction} <</SYS>> @@ -58,3 +63,6 @@ def for_anthropic(self): return ANTHROPIC_TEMPLATE.format( HUMAN_PROMPT=HUMAN_PROMPT, instruction=self.instruction, user_input=self.user_input, AI_PROMPT=AI_PROMPT ) + + def for_palm(self): + return PALM_TEMPLATE.format(instruction=self.instruction, user_input=self.user_input)