Skip to content

Commit

Permalink
Add palm to streamlit
Browse files Browse the repository at this point in the history
  • Loading branch information
steventkrawczyk committed Aug 10, 2023
1 parent 75d9533 commit b7a2a8a
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 6 deletions.
4 changes: 2 additions & 2 deletions examples/notebooks/LangChainRouterChainExperiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
" \"name\": \"restaurant\",\n",
" \"description\": \"Good for building a restaurant\",\n",
" \"prompt_template\": restaurant_template,\n",
" }\n",
" },\n",
" ],\n",
"]\n",
"\n",
Expand All @@ -77,7 +77,7 @@
"\n",
"experiment.evaluate(\"similar_to_expected\", semantic_similarity, expected=[expected] * 2)\n",
"\n",
"experiment.visualize()\n"
"experiment.visualize()"
]
}
],
Expand Down
4 changes: 2 additions & 2 deletions examples/notebooks/PaLM2Experiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
"\n",
"\n",
"palm.configure(api_key=os.environ[\"GOOGLE_PALM_API_KEY\"])\n",
"[m.name for m in palm.list_models() if 'generateText' in m.supported_generation_methods]"
"[m.name for m in palm.list_models() if \"generateText\" in m.supported_generation_methods]"
]
},
{
Expand Down Expand Up @@ -208,7 +208,7 @@
"metadata": {},
"outputs": [],
"source": [
"from prompttools.utils import semantic_similarity\n"
"from prompttools.utils import semantic_similarity"
]
},
{
Expand Down
11 changes: 10 additions & 1 deletion prompttools/experiment/experiments/google_palm_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
except ImportError:
palm = None

from prompttools.selector.prompt_selector import PromptSelector
from prompttools.mock.mock import mock_palm_completion_fn
from .experiment import Experiment
from typing import Optional, Union, Iterable
Expand Down Expand Up @@ -77,6 +78,14 @@ def __init__(
else:
self.completion_fn = self.palm_completion_fn
palm.configure(api_key=os.environ["GOOGLE_PALM_API_KEY"])

# If we are using a prompt selector, we need to
# render the prompts from the selector
if isinstance(prompt[0], PromptSelector):
prompt = [selector.for_palm() for selector in prompt]
else:
prompt = prompt

self.all_args = dict(
model=model,
prompt=prompt,
Expand All @@ -96,7 +105,7 @@ def palm_completion_fn(self, **input_args):
@staticmethod
def _extract_responses(completion_response: "palm.text.text_types.Completion") -> list[str]:
# `# completion_response.result` will return the top response
return [candidate["output"] for candidate in completion_response.candidates]
return [candidate["output"] for candidate in completion_response.candidates][0]

def _get_model_names(self):
return [combo["model"] for combo in self.argument_combos]
Expand Down
3 changes: 3 additions & 0 deletions prompttools/playground/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def load_data(
experiment = EXPERIMENTS[model_type]([model], selectors, temperature=[temperature])
elif model_type == "Anthropic":
experiment = EXPERIMENTS[model_type]([model], selectors, temperature=[temperature])
elif model_type == "Google PaLM":
experiment = EXPERIMENTS[model_type]([model], selectors, temperature=[temperature])

return experiment.to_pandas_df()


Expand Down
3 changes: 2 additions & 1 deletion prompttools/playground/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@ jinja2
huggingface_hub
llama-cpp-python
anthropic
pyperclip
pyperclip
google-generativeai
8 changes: 8 additions & 0 deletions prompttools/selector/prompt_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
RESPONSE:
"""

PALM_TEMPLATE = """{instruction}
{user_input}
"""

LLAMA_TEMPLATE = """<s>[INST] <<SYS>>
{instruction}
<</SYS>
Expand Down Expand Up @@ -58,3 +63,6 @@ def for_anthropic(self):
return ANTHROPIC_TEMPLATE.format(
HUMAN_PROMPT=HUMAN_PROMPT, instruction=self.instruction, user_input=self.user_input, AI_PROMPT=AI_PROMPT
)

def for_palm(self):
return PALM_TEMPLATE.format(instruction=self.instruction, user_input=self.user_input)

0 comments on commit b7a2a8a

Please sign in to comment.