Skip to content

Commit

Permalink
Fix Replicate kwargs and example
Browse files Browse the repository at this point in the history
  • Loading branch information
NivekT committed Aug 29, 2023
1 parent 2964081 commit 4b2e0f2
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 26 deletions.
44 changes: 23 additions & 21 deletions examples/notebooks/ReplicateExperiment.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
"source": [
"import os\n",
"\n",
"os.environ[\"DEBUG\"] = \"\" # Set this to \"\" to call Replicate's API, otherwise a mock function is used\n",
"os.environ[\"REPLICATE_API_TOKEN\"] = \"r8_AlGz9ofN7OAwmiIcCqhxa6b4Duqgk4T3IyaaI\""
"os.environ[\"DEBUG\"] = \"\" # Set this to \"\" to call Replicate's API, \"1\" to call the mock function\n",
"os.environ[\"REPLICATE_API_TOKEN\"] = \"\" # Set your API token here"
]
},
{
Expand Down Expand Up @@ -86,29 +86,39 @@
"id": "3babfe5a",
"metadata": {},
"source": [
"Next, we create our test inputs. We can iterate over models, inputs, and configurations like temperature."
"Next, we create our test inputs. We can iterate over models, inputs, and configurations."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"id": "347590cf",
"metadata": {},
"outputs": [],
"source": [
"sd1 = \"stability-ai/stable-diffusion:27b93a2413e7f36cd83da926f3656280b2931564ff050bf9575f1fdf9bcd7478\"\n",
"sd1 = \"stability-ai/stable-diffusion:ac732df83cea7fff18b8472768c88ad041fa750ff7682a21affe81863cbe77e4\"\n",
"models = [sd1] # You can specify multiple models here\n",
"input_kwargs = {\"prompt\": [\"a 19th century portrait of a wombat gentleman\", \"a 22nd century portrait of a robotic dog\"]}\n",
"input_kwargs = {\"prompt\": [\"a 19th century portrait of a wombat gentleman\", \"a 22nd century portrait of a wombat gentleman\"]}\n",
"model_specific_kwargs = {sd1: {}}\n",
"\n",
"experiment = ReplicateExperiment(models, input_kwargs, model_specific_kwargs)"
]
},
{
"cell_type": "markdown",
"id": "e3bbccfa",
"metadata": {},
"source": [
"We can then run the experiment to get results."
]
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "ca01ff10",
"metadata": {},
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
Expand All @@ -126,14 +136,14 @@
" <tr>\n",
" <th>0</th>\n",
" <td>{'prompt': 'a 19th century portrait of a wombat gentleman'}</td>\n",
" <td>5.001268</td>\n",
" <td><img src=\"https://pbxt.replicate.delivery/3xcGYj0qhCZyD5TiJ9W1zcjbCh4OkjggUYBf9M6c2zyY9avIA/out-0.png\" width=\"100\"/></td>\n",
" <td>4.154633</td>\n",
" <td><img src=\"https://pbxt.replicate.delivery/WkGqWEOKhlKVN5tnfn1Dk5X7GIspINVTCiXSifFODTrtQ2eiA/out-0.png\" width=\"300\"/></td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>{'prompt': 'a 22nd century portrait of a robotic dog'}</td>\n",
" <td>5.024777</td>\n",
" <td><img src=\"https://pbxt.replicate.delivery/gmUW2zllT0baHVnQ0cwsCyAEuDgETzm9MvTRId6IiXpteavIA/out-0.png\" width=\"100\"/></td>\n",
" <td>{'prompt': 'a 22nd century portrait of a wombat gentleman'}</td>\n",
" <td>6.877061</td>\n",
" <td><img src=\"https://pbxt.replicate.delivery/eiqUWyZ7qp3mUKI1QEac7ZXcxSwBKN9YNGh1DElB8ZOaIbvIA/out-0.png\" width=\"300\"/></td>\n",
" </tr>\n",
" </tbody>\n",
"</table>"
Expand All @@ -151,14 +161,6 @@
"experiment.visualize()"
]
},
{
"cell_type": "markdown",
"id": "f3fa5450",
"metadata": {},
"source": [
"We can then run the experiment to get results."
]
},
{
"cell_type": "markdown",
"id": "266c13eb",
Expand Down
13 changes: 8 additions & 5 deletions prompttools/experiment/experiments/replicate_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import logging
import itertools
from functools import partial

from prompttools.mock.mock import mock_replicate_stable_diffusion_completion_fn
from IPython.display import display, HTML
Expand Down Expand Up @@ -76,23 +77,25 @@ def prepare(self):

@staticmethod
def replicate_completion_fn(model_version: str, **kwargs):
return replicate.run(model_version, input=kwargs)
return replicate.run(model_version, **kwargs)

@staticmethod
def _extract_responses(output: dict) -> list[str]:
return output[0]

@staticmethod
def _image_tag(url):
return f'<img src="{url}" width="100"/>'
def _image_tag(url, image_width):
return f'<img src="{url}" width="{image_width}"/>'

def visualize(self, get_all_cols: bool = False, pivot: bool = False, pivot_columns: list = []) -> None:
def visualize(
self, get_all_cols: bool = False, pivot: bool = False, pivot_columns: list = [], image_width=300
) -> None:
if pivot:
table = self.pivot_table(pivot_columns, get_all_cols=get_all_cols)
else:
table = self.get_table(get_all_cols)

images = table["response"].apply(self._image_tag)
images = table["response"].apply(partial(self._image_tag, image_width=image_width))
table["images"] = images

if is_interactive():
Expand Down

0 comments on commit 4b2e0f2

Please sign in to comment.