diff --git a/examples/notebooks/LanceDBExperiment.ipynb b/examples/notebooks/LanceDBExperiment.ipynb
index 8d06b936..e6f598ea 100644
--- a/examples/notebooks/LanceDBExperiment.ipynb
+++ b/examples/notebooks/LanceDBExperiment.ipynb
@@ -46,7 +46,7 @@
},
"outputs": [],
"source": [
- "from prompttools.experiment import LanceDBExperiment\n"
+ "from prompttools.experiment import LanceDBExperiment"
]
},
{
@@ -69,10 +69,23 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 2,
"id": "821bbb21-292c-44e5-bdf0-ab05350acb36",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/kevin/miniconda3/envs/prompttools/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
+ " if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n",
+ "/Users/kevin/miniconda3/envs/prompttools/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n",
+ " ) < LooseVersion(\"1.15\"):\n",
+ "/Users/kevin/miniconda3/envs/prompttools/lib/python3.11/site-packages/tensorflow/python/debug/cli/debugger_cli_common.py:19: DeprecationWarning: module 'sre_constants' is deprecated\n",
+ " import sre_constants\n"
+ ]
+ }
+ ],
"source": [
"from sentence_transformers import SentenceTransformer\n",
"\n",
@@ -100,7 +113,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 3,
"id": "9114cfbf",
"metadata": {
"ExecuteTime": {
@@ -171,7 +184,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 4,
"id": "83b33130",
"metadata": {
"ExecuteTime": {
@@ -184,21 +197,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
+ "WARNING: rate limit only support up to 3.10, proceeding without rate limiter\n",
"WARNING: rate limit only support up to 3.10, proceeding without rate limiter\n"
]
- },
- {
- "ename": "TypeError",
- "evalue": "query_builder() got an unexpected keyword argument 'emb_fn'",
- "output_type": "error",
- "traceback": [
- "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
- "\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)",
- "Cell \u001B[0;32mIn[5], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[43mexperiment\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n",
- "File \u001B[0;32m~/hegel/prompttools/prompttools/experiment/experiments/lancedb_experiment.py:125\u001B[0m, in \u001B[0;36mLanceDBExperiment.run\u001B[0;34m(self, runs)\u001B[0m\n\u001B[1;32m 123\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(runs):\n\u001B[1;32m 124\u001B[0m input_args\u001B[38;5;241m.\u001B[39mappend(query_args)\n\u001B[0;32m--> 125\u001B[0m results\u001B[38;5;241m.\u001B[39mappend(\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlancedb_completion_fn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtable\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membedding_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43memb_fn\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mquery_args\u001B[49m\u001B[43m)\u001B[49m)\n\u001B[1;32m 127\u001B[0m \u001B[38;5;66;03m# Clean up (remove table)\u001B[39;00m\n\u001B[1;32m 128\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mclean_up:\n",
- "File \u001B[0;32m~/hegel/prompttools/prompttools/experiment/experiments/lancedb_experiment.py:133\u001B[0m, in \u001B[0;36mLanceDBExperiment.lancedb_completion_fn\u001B[0;34m(self, table, embedding_fn, **kwargs)\u001B[0m\n\u001B[1;32m 132\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mlancedb_completion_fn\u001B[39m(\u001B[38;5;28mself\u001B[39m, table, embedding_fn, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 133\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mquery_builder\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtable\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membedding_fn\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n",
- "\u001B[0;31mTypeError\u001B[0m: query_builder() got an unexpected keyword argument 'emb_fn'"
- ]
}
],
"source": [
@@ -219,7 +220,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 5,
"id": "01c7e682",
"metadata": {},
"outputs": [
@@ -246,7 +247,7 @@
"
| \n",
" text | \n",
" metric | \n",
- " embed_fn | \n",
+ " emb_fn | \n",
" top doc ids | \n",
" distances | \n",
" documents | \n",
@@ -265,15 +266,6 @@
" \n",
" 1 | \n",
" This is a query document | \n",
- " cosine | \n",
- " default | \n",
- " [id1, id3, id2] | \n",
- " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] | \n",
- " [This is a document, This is the document., This is another document] | \n",
- "
\n",
- " \n",
- " 2 | \n",
- " This is a query document | \n",
" l2 | \n",
" openai-ada-002 | \n",
" [id3, id1, id2] | \n",
@@ -281,16 +273,7 @@
" [This is the document., This is a document, This is another document] | \n",
"
\n",
" \n",
- " 3 | \n",
- " This is a query document | \n",
- " l2 | \n",
- " default | \n",
- " [id1, id3, id2] | \n",
- " [1.619940996170044, 1.6578971147537231, 1.6617801189422607] | \n",
- " [This is a document, This is the document., This is another document] | \n",
- "
\n",
- " \n",
- " 4 | \n",
+ " 2 | \n",
" This is a another query document | \n",
" cosine | \n",
" openai-ada-002 | \n",
@@ -299,8 +282,17 @@
" [This is another document, This is the document., This is a document] | \n",
"
\n",
" \n",
- " 5 | \n",
+ " 3 | \n",
" This is a another query document | \n",
+ " l2 | \n",
+ " openai-ada-002 | \n",
+ " [id3, id1, id2] | \n",
+ " [45.84406280517578, 49.12738037109375, 49.839256286621094] | \n",
+ " [This is the document., This is a document, This is another document] | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " This is a query document | \n",
" cosine | \n",
" default | \n",
" [id1, id3, id2] | \n",
@@ -308,13 +300,22 @@
" [This is a document, This is the document., This is another document] | \n",
"
\n",
" \n",
+ " 5 | \n",
+ " This is a query document | \n",
+ " l2 | \n",
+ " default | \n",
+ " [id1, id3, id2] | \n",
+ " [1.619940996170044, 1.6578971147537231, 1.6617801189422607] | \n",
+ " [This is a document, This is the document., This is another document] | \n",
+ "
\n",
+ " \n",
" 6 | \n",
" This is a another query document | \n",
- " l2 | \n",
- " openai-ada-002 | \n",
- " [id3, id1, id2] | \n",
- " [45.84406280517578, 49.12738037109375, 49.839256286621094] | \n",
- " [This is the document., This is a document, This is another document] | \n",
+ " cosine | \n",
+ " default | \n",
+ " [id1, id3, id2] | \n",
+ " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] | \n",
+ " [This is a document, This is the document., This is another document] | \n",
"
\n",
" \n",
" 7 | \n",
@@ -330,34 +331,34 @@
""
],
"text/plain": [
- " text metric embed_fn top doc ids \\\n",
+ " text metric emb_fn top doc ids \\\n",
"0 This is a query document cosine openai-ada-002 [id2, id3, id1] \n",
- "1 This is a query document cosine default [id1, id3, id2] \n",
- "2 This is a query document l2 openai-ada-002 [id3, id1, id2] \n",
- "3 This is a query document l2 default [id1, id3, id2] \n",
- "4 This is a another query document cosine openai-ada-002 [id2, id3, id1] \n",
- "5 This is a another query document cosine default [id1, id3, id2] \n",
- "6 This is a another query document l2 openai-ada-002 [id3, id1, id2] \n",
+ "1 This is a query document l2 openai-ada-002 [id3, id1, id2] \n",
+ "2 This is a another query document cosine openai-ada-002 [id2, id3, id1] \n",
+ "3 This is a another query document l2 openai-ada-002 [id3, id1, id2] \n",
+ "4 This is a query document cosine default [id1, id3, id2] \n",
+ "5 This is a query document l2 default [id1, id3, id2] \n",
+ "6 This is a another query document cosine default [id1, id3, id2] \n",
"7 This is a another query document l2 default [id1, id3, id2] \n",
"\n",
" distances \\\n",
"0 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n",
- "1 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
- "2 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
- "3 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n",
- "4 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n",
- "5 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
- "6 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
+ "1 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
+ "2 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n",
+ "3 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
+ "4 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
+ "5 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n",
+ "6 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
"7 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n",
"\n",
" documents \n",
"0 [This is another document, This is the document., This is a document] \n",
- "1 [This is a document, This is the document., This is another document] \n",
- "2 [This is the document., This is a document, This is another document] \n",
- "3 [This is a document, This is the document., This is another document] \n",
- "4 [This is another document, This is the document., This is a document] \n",
+ "1 [This is the document., This is a document, This is another document] \n",
+ "2 [This is another document, This is the document., This is a document] \n",
+ "3 [This is the document., This is a document, This is another document] \n",
+ "4 [This is a document, This is the document., This is another document] \n",
"5 [This is a document, This is the document., This is another document] \n",
- "6 [This is the document., This is a document, This is another document] \n",
+ "6 [This is a document, This is the document., This is another document] \n",
"7 [This is a document, This is the document., This is another document] "
]
},
@@ -387,7 +388,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 16,
"id": "8ddbb951",
"metadata": {},
"outputs": [],
@@ -401,12 +402,13 @@
"}\n",
"\n",
"\n",
- "def measure_correlation(input_query: str, results: dict, metadata: dict) -> float:\n",
- " \"\"\"\n",
+ "def measure_correlation(row: \"pandas.core.series.Series\", ranking_column_name: str = \"top doc ids\") -> float:\n",
+ " r\"\"\"\n",
" A simple test that compares the expected ranking for a given query with the actual ranking produced\n",
" by the embedding function being tested.\n",
" \"\"\"\n",
- " correlation, _ = stats.spearmanr(results[\"ids\"], EXPECTED_RANKING[input_query])\n",
+ " input_query = row[\"text\"]\n",
+ " correlation, _ = stats.spearmanr(row[ranking_column_name], EXPECTED_RANKING[input_query])\n",
" return correlation"
]
},
@@ -420,19 +422,19 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 17,
"id": "e80dfeec",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
- "experiment.evaluate(\"ranking_correlation\", measure_correlation, input_key=\"text\")"
+ "experiment.evaluate(\"ranking_correlation\", measure_correlation)"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 18,
"id": "4d09c18e",
"metadata": {
"scrolled": true
@@ -459,140 +461,140 @@
" \n",
" \n",
" | \n",
- " ranking_correlation | \n",
" text | \n",
" metric | \n",
- " embed_fn | \n",
+ " emb_fn | \n",
" top doc ids | \n",
" distances | \n",
" documents | \n",
+ " ranking_correlation | \n",
"
\n",
" \n",
"
\n",
" \n",
" 0 | \n",
- " 0.5 | \n",
" This is a query document | \n",
" cosine | \n",
" openai-ada-002 | \n",
" [id2, id3, id1] | \n",
" [0.7633732557296753, 0.773878812789917, 0.7882261872291565] | \n",
" [This is another document, This is the document., This is a document] | \n",
+ " 0.5 | \n",
"
\n",
" \n",
" 1 | \n",
- " 1.0 | \n",
" This is a query document | \n",
- " cosine | \n",
- " default | \n",
- " [id1, id3, id2] | \n",
- " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] | \n",
- " [This is a document, This is the document., This is another document] | \n",
+ " l2 | \n",
+ " openai-ada-002 | \n",
+ " [id3, id1, id2] | \n",
+ " [45.84406280517578, 49.12738037109375, 49.839256286621094] | \n",
+ " [This is the document., This is a document, This is another document] | \n",
+ " -1.0 | \n",
"
\n",
" \n",
" 2 | \n",
- " -1.0 | \n",
- " This is a query document | \n",
+ " This is a another query document | \n",
+ " cosine | \n",
+ " openai-ada-002 | \n",
+ " [id2, id3, id1] | \n",
+ " [0.7633732557296753, 0.773878812789917, 0.7882261872291565] | \n",
+ " [This is another document, This is the document., This is a document] | \n",
+ " 1.0 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " This is a another query document | \n",
" l2 | \n",
" openai-ada-002 | \n",
" [id3, id1, id2] | \n",
" [45.84406280517578, 49.12738037109375, 49.839256286621094] | \n",
" [This is the document., This is a document, This is another document] | \n",
+ " -0.5 | \n",
"
\n",
" \n",
- " 3 | \n",
+ " 4 | \n",
+ " This is a query document | \n",
+ " cosine | \n",
+ " default | \n",
+ " [id1, id3, id2] | \n",
+ " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] | \n",
+ " [This is a document, This is the document., This is another document] | \n",
" 1.0 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
" This is a query document | \n",
" l2 | \n",
" default | \n",
" [id1, id3, id2] | \n",
" [1.619940996170044, 1.6578971147537231, 1.6617801189422607] | \n",
" [This is a document, This is the document., This is another document] | \n",
- "
\n",
- " \n",
- " 4 | \n",
" 1.0 | \n",
- " This is a another query document | \n",
- " cosine | \n",
- " openai-ada-002 | \n",
- " [id2, id3, id1] | \n",
- " [0.7633732557296753, 0.773878812789917, 0.7882261872291565] | \n",
- " [This is another document, This is the document., This is a document] | \n",
"
\n",
" \n",
- " 5 | \n",
- " 0.5 | \n",
+ " 6 | \n",
" This is a another query document | \n",
" cosine | \n",
" default | \n",
" [id1, id3, id2] | \n",
" [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] | \n",
" [This is a document, This is the document., This is another document] | \n",
- "
\n",
- " \n",
- " 6 | \n",
- " -0.5 | \n",
- " This is a another query document | \n",
- " l2 | \n",
- " openai-ada-002 | \n",
- " [id3, id1, id2] | \n",
- " [45.84406280517578, 49.12738037109375, 49.839256286621094] | \n",
- " [This is the document., This is a document, This is another document] | \n",
+ " 0.5 | \n",
"
\n",
" \n",
" 7 | \n",
- " 0.5 | \n",
" This is a another query document | \n",
" l2 | \n",
" default | \n",
" [id1, id3, id2] | \n",
" [1.619940996170044, 1.6578971147537231, 1.6617801189422607] | \n",
" [This is a document, This is the document., This is another document] | \n",
+ " 0.5 | \n",
"
\n",
" \n",
"\n",
""
],
"text/plain": [
- " ranking_correlation text metric \\\n",
- "0 0.5 This is a query document cosine \n",
- "1 1.0 This is a query document cosine \n",
- "2 -1.0 This is a query document l2 \n",
- "3 1.0 This is a query document l2 \n",
- "4 1.0 This is a another query document cosine \n",
- "5 0.5 This is a another query document cosine \n",
- "6 -0.5 This is a another query document l2 \n",
- "7 0.5 This is a another query document l2 \n",
- "\n",
- " embed_fn top doc ids \\\n",
- "0 openai-ada-002 [id2, id3, id1] \n",
- "1 default [id1, id3, id2] \n",
- "2 openai-ada-002 [id3, id1, id2] \n",
- "3 default [id1, id3, id2] \n",
- "4 openai-ada-002 [id2, id3, id1] \n",
- "5 default [id1, id3, id2] \n",
- "6 openai-ada-002 [id3, id1, id2] \n",
- "7 default [id1, id3, id2] \n",
+ " text metric emb_fn top doc ids \\\n",
+ "0 This is a query document cosine openai-ada-002 [id2, id3, id1] \n",
+ "1 This is a query document l2 openai-ada-002 [id3, id1, id2] \n",
+ "2 This is a another query document cosine openai-ada-002 [id2, id3, id1] \n",
+ "3 This is a another query document l2 openai-ada-002 [id3, id1, id2] \n",
+ "4 This is a query document cosine default [id1, id3, id2] \n",
+ "5 This is a query document l2 default [id1, id3, id2] \n",
+ "6 This is a another query document cosine default [id1, id3, id2] \n",
+ "7 This is a another query document l2 default [id1, id3, id2] \n",
"\n",
" distances \\\n",
"0 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n",
- "1 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
- "2 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
- "3 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n",
- "4 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n",
- "5 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
- "6 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
+ "1 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
+ "2 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n",
+ "3 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n",
+ "4 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
+ "5 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n",
+ "6 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n",
"7 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n",
"\n",
- " documents \n",
- "0 [This is another document, This is the document., This is a document] \n",
- "1 [This is a document, This is the document., This is another document] \n",
- "2 [This is the document., This is a document, This is another document] \n",
- "3 [This is a document, This is the document., This is another document] \n",
- "4 [This is another document, This is the document., This is a document] \n",
- "5 [This is a document, This is the document., This is another document] \n",
- "6 [This is the document., This is a document, This is another document] \n",
- "7 [This is a document, This is the document., This is another document] "
+ " documents \\\n",
+ "0 [This is another document, This is the document., This is a document] \n",
+ "1 [This is the document., This is a document, This is another document] \n",
+ "2 [This is another document, This is the document., This is a document] \n",
+ "3 [This is the document., This is a document, This is another document] \n",
+ "4 [This is a document, This is the document., This is another document] \n",
+ "5 [This is a document, This is the document., This is another document] \n",
+ "6 [This is a document, This is the document., This is another document] \n",
+ "7 [This is a document, This is the document., This is another document] \n",
+ "\n",
+ " ranking_correlation \n",
+ "0 0.5 \n",
+ "1 -1.0 \n",
+ "2 1.0 \n",
+ "3 -0.5 \n",
+ "4 1.0 \n",
+ "5 1.0 \n",
+ "6 0.5 \n",
+ "7 0.5 "
]
},
"metadata": {},