diff --git a/examples/notebooks/LanceDBExperiment.ipynb b/examples/notebooks/LanceDBExperiment.ipynb index 8d06b936..e6f598ea 100644 --- a/examples/notebooks/LanceDBExperiment.ipynb +++ b/examples/notebooks/LanceDBExperiment.ipynb @@ -46,7 +46,7 @@ }, "outputs": [], "source": [ - "from prompttools.experiment import LanceDBExperiment\n" + "from prompttools.experiment import LanceDBExperiment" ] }, { @@ -69,10 +69,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "821bbb21-292c-44e5-bdf0-ab05350acb36", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/kevin/miniconda3/envs/prompttools/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:4: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n", + " if not hasattr(tensorboard, \"__version__\") or LooseVersion(\n", + "/Users/kevin/miniconda3/envs/prompttools/lib/python3.11/site-packages/torch/utils/tensorboard/__init__.py:6: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.\n", + " ) < LooseVersion(\"1.15\"):\n", + "/Users/kevin/miniconda3/envs/prompttools/lib/python3.11/site-packages/tensorflow/python/debug/cli/debugger_cli_common.py:19: DeprecationWarning: module 'sre_constants' is deprecated\n", + " import sre_constants\n" + ] + } + ], "source": [ "from sentence_transformers import SentenceTransformer\n", "\n", @@ -100,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "9114cfbf", "metadata": { "ExecuteTime": { @@ -171,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "83b33130", "metadata": { "ExecuteTime": { @@ -184,21 +197,9 @@ "name": "stdout", "output_type": "stream", "text": [ + "WARNING: rate limit only support up to 3.10, proceeding without rate limiter\n", "WARNING: rate limit only support up to 3.10, proceeding without rate limiter\n" ] - }, - { - "ename": "TypeError", - "evalue": "query_builder() got an unexpected keyword argument 'emb_fn'", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mTypeError\u001B[0m Traceback (most recent call last)", - "Cell \u001B[0;32mIn[5], line 1\u001B[0m\n\u001B[0;32m----> 1\u001B[0m \u001B[43mexperiment\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun\u001B[49m\u001B[43m(\u001B[49m\u001B[43m)\u001B[49m\n", - "File \u001B[0;32m~/hegel/prompttools/prompttools/experiment/experiments/lancedb_experiment.py:125\u001B[0m, in \u001B[0;36mLanceDBExperiment.run\u001B[0;34m(self, runs)\u001B[0m\n\u001B[1;32m 123\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m _ \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(runs):\n\u001B[1;32m 124\u001B[0m input_args\u001B[38;5;241m.\u001B[39mappend(query_args)\n\u001B[0;32m--> 125\u001B[0m results\u001B[38;5;241m.\u001B[39mappend(\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mlancedb_completion_fn\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mtable\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membedding_fn\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43memb_fn\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mquery_args\u001B[49m\u001B[43m)\u001B[49m)\n\u001B[1;32m 127\u001B[0m \u001B[38;5;66;03m# Clean up (remove table)\u001B[39;00m\n\u001B[1;32m 128\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mclean_up:\n", - "File \u001B[0;32m~/hegel/prompttools/prompttools/experiment/experiments/lancedb_experiment.py:133\u001B[0m, in \u001B[0;36mLanceDBExperiment.lancedb_completion_fn\u001B[0;34m(self, table, embedding_fn, **kwargs)\u001B[0m\n\u001B[1;32m 132\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mlancedb_completion_fn\u001B[39m(\u001B[38;5;28mself\u001B[39m, table, embedding_fn, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m--> 133\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[43mquery_builder\u001B[49m\u001B[43m(\u001B[49m\u001B[43mtable\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43membedding_fn\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[38;5;241;43m*\u001B[39;49m\u001B[43mkwargs\u001B[49m\u001B[43m)\u001B[49m\n", - "\u001B[0;31mTypeError\u001B[0m: query_builder() got an unexpected keyword argument 'emb_fn'" - ] } ], "source": [ @@ -219,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "id": "01c7e682", "metadata": {}, "outputs": [ @@ -246,7 +247,7 @@ " \n", " text\n", " metric\n", - " embed_fn\n", + " emb_fn\n", " top doc ids\n", " distances\n", " documents\n", @@ -265,15 +266,6 @@ " \n", " 1\n", " This is a query document\n", - " cosine\n", - " default\n", - " [id1, id3, id2]\n", - " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304]\n", - " [This is a document, This is the document., This is another document]\n", - " \n", - " \n", - " 2\n", - " This is a query document\n", " l2\n", " openai-ada-002\n", " [id3, id1, id2]\n", @@ -281,16 +273,7 @@ " [This is the document., This is a document, This is another document]\n", " \n", " \n", - " 3\n", - " This is a query document\n", - " l2\n", - " default\n", - " [id1, id3, id2]\n", - " [1.619940996170044, 1.6578971147537231, 1.6617801189422607]\n", - " [This is a document, This is the document., This is another document]\n", - " \n", - " \n", - " 4\n", + " 2\n", " This is a another query document\n", " cosine\n", " openai-ada-002\n", @@ -299,8 +282,17 @@ " [This is another document, This is the document., This is a document]\n", " \n", " \n", - " 5\n", + " 3\n", " This is a another query document\n", + " l2\n", + " openai-ada-002\n", + " [id3, id1, id2]\n", + " [45.84406280517578, 49.12738037109375, 49.839256286621094]\n", + " [This is the document., This is a document, This is another document]\n", + " \n", + " \n", + " 4\n", + " This is a query document\n", " cosine\n", " default\n", " [id1, id3, id2]\n", @@ -308,13 +300,22 @@ " [This is a document, This is the document., This is another document]\n", " \n", " \n", + " 5\n", + " This is a query document\n", + " l2\n", + " default\n", + " [id1, id3, id2]\n", + " [1.619940996170044, 1.6578971147537231, 1.6617801189422607]\n", + " [This is a document, This is the document., This is another document]\n", + " \n", + " \n", " 6\n", " This is a another query document\n", - " l2\n", - " openai-ada-002\n", - " [id3, id1, id2]\n", - " [45.84406280517578, 49.12738037109375, 49.839256286621094]\n", - " [This is the document., This is a document, This is another document]\n", + " cosine\n", + " default\n", + " [id1, id3, id2]\n", + " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304]\n", + " [This is a document, This is the document., This is another document]\n", " \n", " \n", " 7\n", @@ -330,34 +331,34 @@ "" ], "text/plain": [ - " text metric embed_fn top doc ids \\\n", + " text metric emb_fn top doc ids \\\n", "0 This is a query document cosine openai-ada-002 [id2, id3, id1] \n", - "1 This is a query document cosine default [id1, id3, id2] \n", - "2 This is a query document l2 openai-ada-002 [id3, id1, id2] \n", - "3 This is a query document l2 default [id1, id3, id2] \n", - "4 This is a another query document cosine openai-ada-002 [id2, id3, id1] \n", - "5 This is a another query document cosine default [id1, id3, id2] \n", - "6 This is a another query document l2 openai-ada-002 [id3, id1, id2] \n", + "1 This is a query document l2 openai-ada-002 [id3, id1, id2] \n", + "2 This is a another query document cosine openai-ada-002 [id2, id3, id1] \n", + "3 This is a another query document l2 openai-ada-002 [id3, id1, id2] \n", + "4 This is a query document cosine default [id1, id3, id2] \n", + "5 This is a query document l2 default [id1, id3, id2] \n", + "6 This is a another query document cosine default [id1, id3, id2] \n", "7 This is a another query document l2 default [id1, id3, id2] \n", "\n", " distances \\\n", "0 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n", - "1 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", - "2 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", - "3 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n", - "4 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n", - "5 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", - "6 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", + "1 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", + "2 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n", + "3 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", + "4 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", + "5 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n", + "6 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", "7 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n", "\n", " documents \n", "0 [This is another document, This is the document., This is a document] \n", - "1 [This is a document, This is the document., This is another document] \n", - "2 [This is the document., This is a document, This is another document] \n", - "3 [This is a document, This is the document., This is another document] \n", - "4 [This is another document, This is the document., This is a document] \n", + "1 [This is the document., This is a document, This is another document] \n", + "2 [This is another document, This is the document., This is a document] \n", + "3 [This is the document., This is a document, This is another document] \n", + "4 [This is a document, This is the document., This is another document] \n", "5 [This is a document, This is the document., This is another document] \n", - "6 [This is the document., This is a document, This is another document] \n", + "6 [This is a document, This is the document., This is another document] \n", "7 [This is a document, This is the document., This is another document] " ] }, @@ -387,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 16, "id": "8ddbb951", "metadata": {}, "outputs": [], @@ -401,12 +402,13 @@ "}\n", "\n", "\n", - "def measure_correlation(input_query: str, results: dict, metadata: dict) -> float:\n", - " \"\"\"\n", + "def measure_correlation(row: \"pandas.core.series.Series\", ranking_column_name: str = \"top doc ids\") -> float:\n", + " r\"\"\"\n", " A simple test that compares the expected ranking for a given query with the actual ranking produced\n", " by the embedding function being tested.\n", " \"\"\"\n", - " correlation, _ = stats.spearmanr(results[\"ids\"], EXPECTED_RANKING[input_query])\n", + " input_query = row[\"text\"]\n", + " correlation, _ = stats.spearmanr(row[ranking_column_name], EXPECTED_RANKING[input_query])\n", " return correlation" ] }, @@ -420,19 +422,19 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 17, "id": "e80dfeec", "metadata": { "scrolled": true }, "outputs": [], "source": [ - "experiment.evaluate(\"ranking_correlation\", measure_correlation, input_key=\"text\")" + "experiment.evaluate(\"ranking_correlation\", measure_correlation)" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 18, "id": "4d09c18e", "metadata": { "scrolled": true @@ -459,140 +461,140 @@ " \n", " \n", " \n", - " ranking_correlation\n", " text\n", " metric\n", - " embed_fn\n", + " emb_fn\n", " top doc ids\n", " distances\n", " documents\n", + " ranking_correlation\n", " \n", " \n", " \n", " \n", " 0\n", - " 0.5\n", " This is a query document\n", " cosine\n", " openai-ada-002\n", " [id2, id3, id1]\n", " [0.7633732557296753, 0.773878812789917, 0.7882261872291565]\n", " [This is another document, This is the document., This is a document]\n", + " 0.5\n", " \n", " \n", " 1\n", - " 1.0\n", " This is a query document\n", - " cosine\n", - " default\n", - " [id1, id3, id2]\n", - " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304]\n", - " [This is a document, This is the document., This is another document]\n", + " l2\n", + " openai-ada-002\n", + " [id3, id1, id2]\n", + " [45.84406280517578, 49.12738037109375, 49.839256286621094]\n", + " [This is the document., This is a document, This is another document]\n", + " -1.0\n", " \n", " \n", " 2\n", - " -1.0\n", - " This is a query document\n", + " This is a another query document\n", + " cosine\n", + " openai-ada-002\n", + " [id2, id3, id1]\n", + " [0.7633732557296753, 0.773878812789917, 0.7882261872291565]\n", + " [This is another document, This is the document., This is a document]\n", + " 1.0\n", + " \n", + " \n", + " 3\n", + " This is a another query document\n", " l2\n", " openai-ada-002\n", " [id3, id1, id2]\n", " [45.84406280517578, 49.12738037109375, 49.839256286621094]\n", " [This is the document., This is a document, This is another document]\n", + " -0.5\n", " \n", " \n", - " 3\n", + " 4\n", + " This is a query document\n", + " cosine\n", + " default\n", + " [id1, id3, id2]\n", + " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304]\n", + " [This is a document, This is the document., This is another document]\n", " 1.0\n", + " \n", + " \n", + " 5\n", " This is a query document\n", " l2\n", " default\n", " [id1, id3, id2]\n", " [1.619940996170044, 1.6578971147537231, 1.6617801189422607]\n", " [This is a document, This is the document., This is another document]\n", - " \n", - " \n", - " 4\n", " 1.0\n", - " This is a another query document\n", - " cosine\n", - " openai-ada-002\n", - " [id2, id3, id1]\n", - " [0.7633732557296753, 0.773878812789917, 0.7882261872291565]\n", - " [This is another document, This is the document., This is a document]\n", " \n", " \n", - " 5\n", - " 0.5\n", + " 6\n", " This is a another query document\n", " cosine\n", " default\n", " [id1, id3, id2]\n", " [0.8099705576896667, 0.8289484977722168, 0.8308900594711304]\n", " [This is a document, This is the document., This is another document]\n", - " \n", - " \n", - " 6\n", - " -0.5\n", - " This is a another query document\n", - " l2\n", - " openai-ada-002\n", - " [id3, id1, id2]\n", - " [45.84406280517578, 49.12738037109375, 49.839256286621094]\n", - " [This is the document., This is a document, This is another document]\n", + " 0.5\n", " \n", " \n", " 7\n", - " 0.5\n", " This is a another query document\n", " l2\n", " default\n", " [id1, id3, id2]\n", " [1.619940996170044, 1.6578971147537231, 1.6617801189422607]\n", " [This is a document, This is the document., This is another document]\n", + " 0.5\n", " \n", " \n", "\n", "" ], "text/plain": [ - " ranking_correlation text metric \\\n", - "0 0.5 This is a query document cosine \n", - "1 1.0 This is a query document cosine \n", - "2 -1.0 This is a query document l2 \n", - "3 1.0 This is a query document l2 \n", - "4 1.0 This is a another query document cosine \n", - "5 0.5 This is a another query document cosine \n", - "6 -0.5 This is a another query document l2 \n", - "7 0.5 This is a another query document l2 \n", - "\n", - " embed_fn top doc ids \\\n", - "0 openai-ada-002 [id2, id3, id1] \n", - "1 default [id1, id3, id2] \n", - "2 openai-ada-002 [id3, id1, id2] \n", - "3 default [id1, id3, id2] \n", - "4 openai-ada-002 [id2, id3, id1] \n", - "5 default [id1, id3, id2] \n", - "6 openai-ada-002 [id3, id1, id2] \n", - "7 default [id1, id3, id2] \n", + " text metric emb_fn top doc ids \\\n", + "0 This is a query document cosine openai-ada-002 [id2, id3, id1] \n", + "1 This is a query document l2 openai-ada-002 [id3, id1, id2] \n", + "2 This is a another query document cosine openai-ada-002 [id2, id3, id1] \n", + "3 This is a another query document l2 openai-ada-002 [id3, id1, id2] \n", + "4 This is a query document cosine default [id1, id3, id2] \n", + "5 This is a query document l2 default [id1, id3, id2] \n", + "6 This is a another query document cosine default [id1, id3, id2] \n", + "7 This is a another query document l2 default [id1, id3, id2] \n", "\n", " distances \\\n", "0 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n", - "1 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", - "2 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", - "3 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n", - "4 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n", - "5 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", - "6 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", + "1 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", + "2 [0.7633732557296753, 0.773878812789917, 0.7882261872291565] \n", + "3 [45.84406280517578, 49.12738037109375, 49.839256286621094] \n", + "4 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", + "5 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n", + "6 [0.8099705576896667, 0.8289484977722168, 0.8308900594711304] \n", "7 [1.619940996170044, 1.6578971147537231, 1.6617801189422607] \n", "\n", - " documents \n", - "0 [This is another document, This is the document., This is a document] \n", - "1 [This is a document, This is the document., This is another document] \n", - "2 [This is the document., This is a document, This is another document] \n", - "3 [This is a document, This is the document., This is another document] \n", - "4 [This is another document, This is the document., This is a document] \n", - "5 [This is a document, This is the document., This is another document] \n", - "6 [This is the document., This is a document, This is another document] \n", - "7 [This is a document, This is the document., This is another document] " + " documents \\\n", + "0 [This is another document, This is the document., This is a document] \n", + "1 [This is the document., This is a document, This is another document] \n", + "2 [This is another document, This is the document., This is a document] \n", + "3 [This is the document., This is a document, This is another document] \n", + "4 [This is a document, This is the document., This is another document] \n", + "5 [This is a document, This is the document., This is another document] \n", + "6 [This is a document, This is the document., This is another document] \n", + "7 [This is a document, This is the document., This is another document] \n", + "\n", + " ranking_correlation \n", + "0 0.5 \n", + "1 -1.0 \n", + "2 1.0 \n", + "3 -0.5 \n", + "4 1.0 \n", + "5 1.0 \n", + "6 0.5 \n", + "7 0.5 " ] }, "metadata": {},