From c2963c33ee7ac79afb18d9ef136adc22a90479ed Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 9 Oct 2023 09:37:37 +0200 Subject: [PATCH 1/3] refactor API (#56) --- .github/workflows/test.yml | 2 +- examples/rest_api/rest_api.ipynb | 254 ++++++++++++++--------- examples/s3_documents/s3_documents.ipynb | 83 ++++++-- pyproject.toml | 3 +- ragna/_api.py | 247 ---------------------- ragna/_api/__init__.py | 1 + ragna/_api/core.py | 150 +++++++++++++ ragna/_api/schemas.py | 119 +++++++++++ ragna/core/_document.py | 2 +- ragna/core/_rag.py | 6 +- scripts/add_chats.py | 25 ++- tests/test_importable.py | 6 +- 12 files changed, 512 insertions(+), 386 deletions(-) delete mode 100644 ragna/_api.py create mode 100644 ragna/_api/__init__.py create mode 100644 ragna/_api/core.py create mode 100644 ragna/_api/schemas.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index cde3353e..3625d045 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ on: - ".github/**" - "ragna/**" - "tests/**" - - "environment.yaml" + - "environment.yml" - "pyproject.toml" jobs: diff --git a/examples/rest_api/rest_api.ipynb b/examples/rest_api/rest_api.ipynb index b481d1f3..aacb3f16 100644 --- a/examples/rest_api/rest_api.ipynb +++ b/examples/rest_api/rest_api.ipynb @@ -44,14 +44,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"GET /health HTTP/1.1\" 200 OK\n" + "INFO: 127.0.0.1:42394 - \"GET /health HTTP/1.1\" 200 OK\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "INFO: Started server process [19776]\n", + "INFO: Started server process [39690]\n", "INFO: Waiting for application startup.\n", "INFO: Application startup complete.\n", "INFO: Uvicorn running on http://127.0.0.1:31476 (Press CTRL+C to quit)\n" @@ -137,7 +137,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"GET /chats?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:42394 - \"GET /chats?user=Ragna HTTP/1.1\" 200 OK\n", "[]\n" ] } @@ -167,7 +167,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"GET /components?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:42394 - \"GET /components?user=Ragna HTTP/1.1\" 200 OK\n", "{'assistants': ['Ragna/DemoAssistant'],\n", " 'source_storages': ['Ragna/DemoSourceStorage']}\n" ] @@ -229,20 +229,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"GET /document/new?user=Ragna&name=document0.txt HTTP/1.1\" 200 OK\n", - "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6IjhiOWQ4MjVhLTQ3OTctNDQxMS1iODc2LWVjM2MxMTM0OThmYyIsImV4cCI6MTY5NjM2Nzk3OS42MjU3MTN9.xT-1YI0awv2hixcHJZaDGqvYP7r8nYFrKB9DygOyHdw'},\n", - " 'document': {'id': '8b9d825a-4797-4411-b876-ec3c113498fc',\n", + "INFO: 127.0.0.1:42394 - \"GET /document?user=Ragna&name=document0.txt HTTP/1.1\" 200 OK\n", + "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImRiZjdiOWE2LWEzMGItNGEyNS1iNGJhLTE3ODIwNjNhYzJhZiIsImV4cCI6MTY5NjQ1NTg2MS40MjkwMjE0fQ.gjMsP0qZ36hO9fAOFPxeKom-oYsyIyxggG82oP0svoc'},\n", + " 'document': {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", " 'name': 'document0.txt'},\n", - " 'url': 'http://127.0.0.1:31476/document/upload'}\n" + " 'url': 'http://127.0.0.1:31476/document'}\n" ] } ], "source": [ "path = paths[0]\n", "\n", - "response = await client.get(\n", - " f\"{URL}/document/new\", params={\"user\": USER, \"name\": path.name}\n", - ")\n", + "response = await client.get(f\"{URL}/document\", params={\"user\": USER, \"name\": path.name})\n", "document_info = response.json()\n", "document = document_info[\"document\"]\n", "pprint(document_info)" @@ -266,7 +264,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"POST /document/upload HTTP/1.1\" 200 OK\n" + "INFO: 127.0.0.1:42394 - \"POST /document HTTP/1.1\" 200 OK\n" ] } ], @@ -297,18 +295,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"GET /document/new?user=Ragna&name=document1.txt HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:51472 - \"POST /document/upload HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:51472 - \"GET /document/new?user=Ragna&name=document2.txt HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:51472 - \"POST /document/upload HTTP/1.1\" 200 OK\n" + "INFO: 127.0.0.1:42394 - \"GET /document?user=Ragna&name=document1.txt HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:42394 - \"POST /document HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:42394 - \"GET /document?user=Ragna&name=document2.txt HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:42394 - \"POST /document HTTP/1.1\" 200 OK\n" ] }, { "data": { "text/plain": [ - "[{'id': '8b9d825a-4797-4411-b876-ec3c113498fc', 'name': 'document0.txt'},\n", - " {'id': '2e0c644b-932b-431e-acd1-623fae885dcc', 'name': 'document1.txt'},\n", - " {'id': '63fd15fc-2818-4e10-bbb6-d2724e21afc1', 'name': 'document2.txt'}]" + "[{'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af', 'name': 'document0.txt'},\n", + " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3', 'name': 'document1.txt'},\n", + " {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04', 'name': 'document2.txt'}]" ] }, "execution_count": 9, @@ -321,9 +319,7 @@ "\n", "for path in paths[1:]:\n", " document_info = (\n", - " await client.get(\n", - " f\"{URL}/document/new\", params={\"user\": USER, \"name\": path.name}\n", - " )\n", + " await client.get(f\"{URL}/document\", params={\"user\": USER, \"name\": path.name})\n", " ).json()\n", " documents.append(document_info[\"document\"])\n", " await client.post(\n", @@ -353,31 +349,34 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"POST /chat/new?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:42394 - \"POST /chats?user=Ragna HTTP/1.1\" 200 OK\n", "{'closed': False,\n", - " 'id': 'c9acfda3-bfc5-422c-835d-09c17ba208aa',\n", + " 'id': 'f68e3999-444c-4a92-8d25-0f80dd3b2219',\n", " 'messages': [],\n", - " 'metadata': {'assistant': \"\",\n", - " 'document_ids': ['8b9d825a-4797-4411-b876-ec3c113498fc',\n", - " '2e0c644b-932b-431e-acd1-623fae885dcc',\n", - " '63fd15fc-2818-4e10-bbb6-d2724e21afc1'],\n", + " 'metadata': {'assistant': 'Ragna/DemoAssistant',\n", + " 'documents': [{'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " 'name': 'document0.txt'},\n", + " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " 'name': 'document1.txt'},\n", + " {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'name': 'document2.txt'}],\n", " 'name': 'Ragna REST API example',\n", " 'params': {},\n", - " 'source_storage': '\"},\n", + " 'source_storage': 'Ragna/DemoSourceStorage'},\n", " 'started': False}\n" ] } ], "source": [ "response = await client.post(\n", - " f\"{URL}/chat/new\",\n", + " f\"{URL}/chats\",\n", " params={\"user\": USER},\n", " json={\n", " \"name\": \"Ragna REST API example\",\n", " \"document_ids\": [d[\"id\"] for d in documents],\n", " \"source_storage\": SOURCE_STORAGE,\n", " \"assistant\": ASSISTANT,\n", + " \"params\": {},\n", " },\n", ")\n", "chat = response.json()\n", @@ -402,76 +401,32 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:51472 - \"POST /chat/c9acfda3-bfc5-422c-835d-09c17ba208aa/start?user=Ragna HTTP/1.1\" 500 Internal Server Error\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "ERROR: Exception in ASGI application\n", - "Traceback (most recent call last):\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/uvicorn/protocols/http/httptools_impl.py\", line 426, in run_asgi\n", - " result = await app( # type: ignore[func-returns-value]\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/uvicorn/middleware/proxy_headers.py\", line 84, in __call__\n", - " return await self.app(scope, receive, send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/fastapi/applications.py\", line 290, in __call__\n", - " await super().__call__(scope, receive, send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/applications.py\", line 122, in __call__\n", - " await self.middleware_stack(scope, receive, send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/middleware/errors.py\", line 184, in __call__\n", - " raise exc\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/middleware/errors.py\", line 162, in __call__\n", - " await self.app(scope, receive, _send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/middleware/exceptions.py\", line 79, in __call__\n", - " raise exc\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/middleware/exceptions.py\", line 68, in __call__\n", - " await self.app(scope, receive, sender)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/fastapi/middleware/asyncexitstack.py\", line 20, in __call__\n", - " raise e\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/fastapi/middleware/asyncexitstack.py\", line 17, in __call__\n", - " await self.app(scope, receive, send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/routing.py\", line 718, in __call__\n", - " await route.handle(scope, receive, send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/routing.py\", line 276, in handle\n", - " await self.app(scope, receive, send)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/starlette/routing.py\", line 66, in app\n", - " response = await func(request)\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/fastapi/routing.py\", line 231, in app\n", - " solved_result = await solve_dependencies(\n", - " File \"/home/philip/.conda/envs/ragna-dev/lib/python3.9/site-packages/fastapi/dependencies/utils.py\", line 622, in solve_dependencies\n", - " solved = await call(**sub_values)\n", - " File \"/home/philip/git/ora/ragna/_api.py\", line 220, in _get_chat\n", - " return rag._get_chat(user=user, id=id)\n", - " File \"/home/philip/git/ora/ragna/core/_rag.py\", line 174, in _get_chat\n", - " self._get_chats(user=user)\n", - " File \"/home/philip/git/ora/ragna/core/_rag.py\", line 123, in _get_chats\n", - " chats = [\n", - " File \"/home/philip/git/ora/ragna/core/_rag.py\", line 137, in \n", - " source_storage=self._parse_component(\n", - "AttributeError: 'Rag' object has no attribute '_parse_component'\n" - ] - }, - { - "ename": "JSONDecodeError", - "evalue": "Expecting value: line 1 column 1 (char 0)", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mJSONDecodeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[11], line 4\u001b[0m\n\u001b[1;32m 1\u001b[0m CHAT_ID \u001b[38;5;241m=\u001b[39m chat[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mid\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n\u001b[1;32m 3\u001b[0m response \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m client\u001b[38;5;241m.\u001b[39mpost(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mURL\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/chat/\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mCHAT_ID\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m/start\u001b[39m\u001b[38;5;124m\"\u001b[39m, params\u001b[38;5;241m=\u001b[39m{\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124muser\u001b[39m\u001b[38;5;124m\"\u001b[39m: USER})\n\u001b[0;32m----> 4\u001b[0m chat \u001b[38;5;241m=\u001b[39m \u001b[43mresponse\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjson\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m pprint(chat)\n", - "File \u001b[0;32m~/.conda/envs/ragna-dev/lib/python3.9/site-packages/httpx/_models.py:756\u001b[0m, in \u001b[0;36mResponse.json\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 754\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m encoding \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 755\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m jsonlib\u001b[38;5;241m.\u001b[39mloads(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontent\u001b[38;5;241m.\u001b[39mdecode(encoding), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 756\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mjsonlib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mloads\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtext\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/ragna-dev/lib/python3.9/json/__init__.py:346\u001b[0m, in \u001b[0;36mloads\u001b[0;34m(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)\u001b[0m\n\u001b[1;32m 341\u001b[0m s \u001b[38;5;241m=\u001b[39m s\u001b[38;5;241m.\u001b[39mdecode(detect_encoding(s), \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msurrogatepass\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 343\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 344\u001b[0m parse_int \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m parse_float \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m\n\u001b[1;32m 345\u001b[0m parse_constant \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m object_pairs_hook \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m kw):\n\u001b[0;32m--> 346\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_default_decoder\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdecode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 347\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 348\u001b[0m \u001b[38;5;28mcls\u001b[39m \u001b[38;5;241m=\u001b[39m JSONDecoder\n", - "File \u001b[0;32m~/.conda/envs/ragna-dev/lib/python3.9/json/decoder.py:337\u001b[0m, in \u001b[0;36mJSONDecoder.decode\u001b[0;34m(self, s, _w)\u001b[0m\n\u001b[1;32m 332\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdecode\u001b[39m(\u001b[38;5;28mself\u001b[39m, s, _w\u001b[38;5;241m=\u001b[39mWHITESPACE\u001b[38;5;241m.\u001b[39mmatch):\n\u001b[1;32m 333\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the Python representation of ``s`` (a ``str`` instance\u001b[39;00m\n\u001b[1;32m 334\u001b[0m \u001b[38;5;124;03m containing a JSON document).\u001b[39;00m\n\u001b[1;32m 335\u001b[0m \n\u001b[1;32m 336\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 337\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mraw_decode\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m_w\u001b[49m\u001b[43m(\u001b[49m\u001b[43ms\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mend\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 338\u001b[0m end \u001b[38;5;241m=\u001b[39m _w(s, end)\u001b[38;5;241m.\u001b[39mend()\n\u001b[1;32m 339\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m end \u001b[38;5;241m!=\u001b[39m \u001b[38;5;28mlen\u001b[39m(s):\n", - "File \u001b[0;32m~/.conda/envs/ragna-dev/lib/python3.9/json/decoder.py:355\u001b[0m, in \u001b[0;36mJSONDecoder.raw_decode\u001b[0;34m(self, s, idx)\u001b[0m\n\u001b[1;32m 353\u001b[0m obj, end \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mscan_once(s, idx)\n\u001b[1;32m 354\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mStopIteration\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n\u001b[0;32m--> 355\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m JSONDecodeError(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mExpecting value\u001b[39m\u001b[38;5;124m\"\u001b[39m, s, err\u001b[38;5;241m.\u001b[39mvalue) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 356\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m obj, end\n", - "\u001b[0;31mJSONDecodeError\u001b[0m: Expecting value: line 1 column 1 (char 0)" + "INFO: 127.0.0.1:42394 - \"POST /chats/f68e3999-444c-4a92-8d25-0f80dd3b2219/start?user=Ragna HTTP/1.1\" 200 OK\n", + "{'closed': False,\n", + " 'id': 'f68e3999-444c-4a92-8d25-0f80dd3b2219',\n", + " 'messages': [{'content': 'How can I help you with the documents?',\n", + " 'id': '09df1d1c-3c8f-4d42-ad68-fa8a4d833d85',\n", + " 'role': 'system',\n", + " 'sources': [],\n", + " 'timestamp': '2023-10-04T21:43:51.941124'}],\n", + " 'metadata': {'assistant': 'Ragna/DemoAssistant',\n", + " 'documents': [{'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'name': 'document2.txt'},\n", + " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " 'name': 'document1.txt'},\n", + " {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " 'name': 'document0.txt'}],\n", + " 'name': 'Ragna REST API example',\n", + " 'params': {},\n", + " 'source_storage': 'Ragna/DemoSourceStorage'},\n", + " 'started': True}\n" ] } ], "source": [ "CHAT_ID = chat[\"id\"]\n", "\n", - "response = await client.post(f\"{URL}/chat/{CHAT_ID}/start\", params={\"user\": USER})\n", + "response = await client.post(f\"{URL}/chats/{CHAT_ID}/start\", params={\"user\": USER})\n", "chat = response.json()\n", "pprint(chat)" ] @@ -486,13 +441,54 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "id": "2bf0c0fe-64f9-40d6-b1c7-bd9b17e859d1", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:42394 - \"POST /chats/f68e3999-444c-4a92-8d25-0f80dd3b2219/answer?user=Ragna&prompt=What%20is%20Ragna%3F HTTP/1.1\" 200 OK\n", + "{'content': \"I can't really help you with your prompt:\\n\"\n", + " '\\n'\n", + " '> What is Ragna?\\n'\n", + " '\\n'\n", + " 'I can at least show you the sources that I was given:\\n'\n", + " '\\n'\n", + " '- document2.txt: This is content of document 2\\n'\n", + " '- document1.txt: This is content of document 1\\n'\n", + " '- document0.txt: This is content of document 0',\n", + " 'id': '822615fa-c05c-4b03-8d25-60b8687ba568',\n", + " 'role': 'assistant',\n", + " 'sources': [{'document': {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'name': 'document2.txt'},\n", + " 'id': '6bfba599-7fdd-4a36-8f43-c1cb552f83eb',\n", + " 'location': ''},\n", + " {'document': {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " 'name': 'document1.txt'},\n", + " 'id': 'e7ebeac9-891b-4867-b35b-62d571a50837',\n", + " 'location': ''},\n", + " {'document': {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " 'name': 'document0.txt'},\n", + " 'id': '37b9a731-032c-4bd6-9286-3693b1149952',\n", + " 'location': ''}],\n", + " 'timestamp': '2023-10-04T21:43:51.953583'}\n", + "I can't really help you with your prompt:\n", + "\n", + "> What is Ragna?\n", + "\n", + "I can at least show you the sources that I was given:\n", + "\n", + "- document2.txt: This is content of document 2\n", + "- document1.txt: This is content of document 1\n", + "- document0.txt: This is content of document 0\n" + ] + } + ], "source": [ "response = await client.post(\n", - " f\"{URL}/chat/{CHAT_ID}/answer\", params={\"user\": USER, \"prompt\": \"What is Ragna?\"}\n", + " f\"{URL}/chats/{CHAT_ID}/answer\", params={\"user\": USER, \"prompt\": \"What is Ragna?\"}\n", ")\n", "answer = response.json()\n", "pprint(answer[\"message\"])\n", @@ -509,12 +505,68 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "f021fe9b-ae71-4101-aa18-22277ebab8d2", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:42394 - \"POST /chats/f68e3999-444c-4a92-8d25-0f80dd3b2219/close?user=Ragna HTTP/1.1\" 200 OK\n", + "{'closed': True,\n", + " 'id': 'f68e3999-444c-4a92-8d25-0f80dd3b2219',\n", + " 'messages': [{'content': 'How can I help you with the documents?',\n", + " 'id': '09df1d1c-3c8f-4d42-ad68-fa8a4d833d85',\n", + " 'role': 'system',\n", + " 'sources': [],\n", + " 'timestamp': '2023-10-04T21:43:51.941124'},\n", + " {'content': 'What is Ragna?',\n", + " 'id': '362e63d9-16d6-4c0e-a3d7-26cab7ddd980',\n", + " 'role': 'user',\n", + " 'sources': [],\n", + " 'timestamp': '2023-10-04T21:43:51.951584'},\n", + " {'content': \"I can't really help you with your prompt:\\n\"\n", + " '\\n'\n", + " '> What is Ragna?\\n'\n", + " '\\n'\n", + " 'I can at least show you the sources that I was '\n", + " 'given:\\n'\n", + " '\\n'\n", + " '- document2.txt: This is content of document 2\\n'\n", + " '- document1.txt: This is content of document 1\\n'\n", + " '- document0.txt: This is content of document 0',\n", + " 'id': '822615fa-c05c-4b03-8d25-60b8687ba568',\n", + " 'role': 'assistant',\n", + " 'sources': [{'document': {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'name': 'document2.txt'},\n", + " 'id': '6bfba599-7fdd-4a36-8f43-c1cb552f83eb',\n", + " 'location': ''},\n", + " {'document': {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " 'name': 'document1.txt'},\n", + " 'id': 'e7ebeac9-891b-4867-b35b-62d571a50837',\n", + " 'location': ''},\n", + " {'document': {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " 'name': 'document0.txt'},\n", + " 'id': '37b9a731-032c-4bd6-9286-3693b1149952',\n", + " 'location': ''}],\n", + " 'timestamp': '2023-10-04T21:43:51.953583'}],\n", + " 'metadata': {'assistant': 'Ragna/DemoAssistant',\n", + " 'documents': [{'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'name': 'document2.txt'},\n", + " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " 'name': 'document1.txt'},\n", + " {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " 'name': 'document0.txt'}],\n", + " 'name': 'Ragna REST API example',\n", + " 'params': {},\n", + " 'source_storage': 'Ragna/DemoSourceStorage'},\n", + " 'started': True}\n" + ] + } + ], "source": [ - "response = await client.post(f\"{URL}/chat/{CHAT_ID}/close\", params={\"user\": USER})\n", + "response = await client.post(f\"{URL}/chats/{CHAT_ID}/close\", params={\"user\": USER})\n", "chat = response.json()\n", "pprint(chat)" ] diff --git a/examples/s3_documents/s3_documents.ipynb b/examples/s3_documents/s3_documents.ipynb index 0342c9b6..f70df0ec 100644 --- a/examples/s3_documents/s3_documents.ipynb +++ b/examples/s3_documents/s3_documents.ipynb @@ -419,7 +419,25 @@ "execution_count": 5, "id": "ae31bfb0-01f6-47c8-9733-ba1204b153d8", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:56670 - \"GET /health HTTP/1.1\" 200 OK\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "INFO: Started server process [40064]\n", + "INFO: Waiting for application startup.\n", + "INFO: Application startup complete.\n", + "INFO: Uvicorn running on http://127.0.0.1:31476 (Press CTRL+C to quit)\n" + ] + } + ], "source": [ "import contextlib\n", "import subprocess\n", @@ -473,10 +491,11 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImQxYTgyYmM1LTFlODItNDllZi1hODI4LTNlMTIzZjE1OTIwZiIsImV4cCI6MTY5NTg0NDQ4NC43MTYyMDc1fQ._Q8sUpMZg0NeJ9MOJwZ6qZ5QfBm_jD1Ziwy7PK9Vmsw'},\n", - " 'document': {'id': 'd1a82bc5-1e82-49ef-a828-3e123f15920f',\n", + "INFO: 127.0.0.1:56670 - \"GET /document?user=Ragna&name=document0.txt HTTP/1.1\" 200 OK\n", + "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImViY2E1NTZhLTQxZTgtNDQ4Ni1iOGI3LThkMmM3OGFlOTc1OCIsImV4cCI6MTY5NjQ1NjA2NS41ODAyMDUyfQ.MwFbgwv2nTcetRl9osgjnsnplvNC8fIxHBSBadFX7tg'},\n", + " 'document': {'id': 'ebca556a-41e8-4486-b8b7-8d2c78ae9758',\n", " 'name': 'document0.txt'},\n", - " 'url': 'http://127.0.0.1:31476/document/upload'}\n" + " 'url': 'http://127.0.0.1:31476/document'}\n" ] } ], @@ -485,9 +504,7 @@ "\n", "path = paths[0]\n", "\n", - "response = await client.get(\n", - " f\"{URL}/document/new\", params={\"user\": USER, \"name\": path.name}\n", - ")\n", + "response = await client.get(f\"{URL}/document\", params={\"user\": USER, \"name\": path.name})\n", "document_info = response.json()\n", "document = document_info[\"document\"]\n", "pprint(document_info)" @@ -498,7 +515,15 @@ "execution_count": 8, "id": "fdc5c6f4-bcf8-4902-b441-768b6e3a66bc", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:56670 - \"POST /document HTTP/1.1\" 200 OK\n" + ] + } + ], "source": [ "response = await client.post(\n", " document_info[\"url\"],\n", @@ -514,12 +539,22 @@ "id": "f0d769ff-d721-460a-8888-5b033eb4a909", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:56670 - \"GET /document?user=Ragna&name=document1.txt HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:56670 - \"POST /document HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:56670 - \"GET /document?user=Ragna&name=document2.txt HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:56670 - \"POST /document HTTP/1.1\" 200 OK\n" + ] + }, { "data": { "text/plain": [ - "[{'id': 'd1a82bc5-1e82-49ef-a828-3e123f15920f', 'name': 'document0.txt'},\n", - " {'id': '4258a48a-1e79-4f8f-a393-1d91bdf989aa', 'name': 'document1.txt'},\n", - " {'id': 'a30ef295-d420-4076-82a0-145b4315f2c4', 'name': 'document2.txt'}]" + "[{'id': 'ebca556a-41e8-4486-b8b7-8d2c78ae9758', 'name': 'document0.txt'},\n", + " {'id': 'a723700a-8401-4b09-b727-ead6797f9725', 'name': 'document1.txt'},\n", + " {'id': '9028a6c3-2a07-436e-a953-7b92a393113b', 'name': 'document2.txt'}]" ] }, "execution_count": 9, @@ -532,9 +567,7 @@ "\n", "for path in paths[1:]:\n", " document_info = (\n", - " await client.get(\n", - " f\"{URL}/document/new\", params={\"user\": USER, \"name\": path.name}\n", - " )\n", + " await client.get(f\"{URL}/document\", params={\"user\": USER, \"name\": path.name})\n", " ).json()\n", " documents.append(document_info[\"document\"])\n", " await client.post(\n", @@ -552,10 +585,17 @@ "id": "ae364aec-e63b-4f00-8b55-217148d6df24", "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO: 127.0.0.1:56670 - \"POST /chats?user=Ragna HTTP/1.1\" 200 OK\n" + ] + }, { "data": { "text/plain": [ - "'http://127.0.0.1:31476/chat/0352c50a-cec0-4708-bb7c-4eec76e0a5f0'" + "'http://127.0.0.1:31476/chats/98a82c99-a221-4ff5-8b24-cc42ced36477'" ] }, "execution_count": 10, @@ -566,18 +606,19 @@ "source": [ "chat = (\n", " await client.post(\n", - " f\"{URL}/chat/new\",\n", + " f\"{URL}/chats\",\n", " params={\"user\": USER},\n", " json={\n", " \"name\": \"Ragna REST API example\",\n", " \"document_ids\": [d[\"id\"] for d in documents],\n", " \"source_storage\": \"Ragna/DemoSourceStorage\",\n", " \"assistant\": \"Ragna/DemoAssistant\",\n", + " \"params\": {},\n", " },\n", " )\n", ").json()\n", "\n", - "CHAT_URL = f\"{URL}/chat/{chat['id']}\"\n", + "CHAT_URL = f\"{URL}/chats/{chat['id']}\"\n", "CHAT_URL" ] }, @@ -591,14 +632,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "I just pretend to be an LLM. I can't actually help with your prompt:\n", + "INFO: 127.0.0.1:56670 - \"POST /chats/98a82c99-a221-4ff5-8b24-cc42ced36477/start?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:56670 - \"POST /chats/98a82c99-a221-4ff5-8b24-cc42ced36477/answer?user=Ragna&prompt=Hello%20World%21 HTTP/1.1\" 200 OK\n", + "I can't really help you with your prompt:\n", "\n", "> Hello World!\n", "\n", - "I was given the following sources:\n", + "I can at least show you the sources that I was given:\n", "\n", - "- document1.txt: This is content of document 1 located on S3\n", "- document2.txt: This is content of document 2 located on S3\n", + "- document1.txt: This is content of document 1 located on S3\n", "- document0.txt: This is content of document 0 located on S3\n" ] } diff --git a/pyproject.toml b/pyproject.toml index 80855718..7415b4f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,6 @@ classifiers = [ ] requires-python = ">=3.9" dependencies = [ - "cloudpickle", "huey", "importlib_metadata >=4.6; python_version < '3.10'", "packaging", @@ -34,6 +33,7 @@ console = [ "rich", ] api = [ + "aiofiles", "fastapi", "PyJWT", "python-multipart", @@ -52,6 +52,7 @@ builtin-components = [ ] complete = [ "PyJWT", + "aiofiles", "chromadb>=0.4", "fastapi", "httpx", diff --git a/ragna/_api.py b/ragna/_api.py deleted file mode 100644 index 914c8cca..00000000 --- a/ragna/_api.py +++ /dev/null @@ -1,247 +0,0 @@ -import datetime -import functools -from typing import Annotated -from uuid import UUID - -import aiofiles -from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile - -from pydantic import BaseModel, Field, HttpUrl, validator - -import ragna - -from ragna.core import Chat, LocalDocument, MessageRole, RagnaException, RagnaId - - -class DocumentModel(BaseModel): - id: RagnaId - name: str - - -class DocumentUploadInfoModel(BaseModel): - url: HttpUrl - data: dict - document: DocumentModel - - -class SourceModel(BaseModel): - document_id: RagnaId - document_name: str - location: str - - @classmethod - def from_source(cls, source): - return cls( - id=source.id, - document_id=source.document_id, - document_name=source.document_name, - location=source.location, - ) - - -class MessageModel(BaseModel): - id: RagnaId - role: MessageRole - content: str - sources: list[SourceModel] - timestamp: datetime.datetime - - @classmethod - def from_message(cls, message): - return cls( - id=message.id, - role=message.role, - content=message.content, - sources=[SourceModel.from_source(s) for s in message.sources], - timestamp=message.timestamp, - ) - - -class ChatMetadataModel(BaseModel): - name: str - # For some reason list[RagnaId] does not work and will get parsed into list[UUID]. - # Thus, we use a validator below to do the conversion. - document_ids: list[UUID] - source_storage: str - assistant: str - params: dict = Field(default_factory=dict) - - @validator("document_ids") - def uuid_to_ragna_id(cls, document_ids: list[UUID]) -> list[RagnaId]: - return [RagnaId.from_uuid(u) for u in document_ids] - - @classmethod - def from_chat(cls, chat): - return cls( - name=chat.name, - document_ids=[d.id for d in chat.documents], - source_storage=str(chat.source_storage), - assistant=str(chat.assistant), - params=chat.params, - ) - - -class ChatModel(BaseModel): - id: RagnaId - metadata: ChatMetadataModel - messages: list[MessageModel] - started: bool - closed: bool - - @classmethod - def from_chat(cls, chat): - return cls( - id=chat.id, - metadata=ChatMetadataModel.from_chat(chat), - messages=[MessageModel.from_message(m) for m in chat.messages], - started=chat._started, - closed=chat._closed, - ) - - -class AnswerOutputModel(BaseModel): - message: MessageModel - chat: ChatModel - - -class ComponentsModel(BaseModel): - source_storages: list[str] - assistants: list[str] - - -def process_ragna_exception(afn): - @functools.wraps(afn) - async def wrapper(*args, **kwargs): - try: - return await afn(*args, **kwargs) - except RagnaException as exc: - if exc.http_detail is RagnaException.EVENT: - detail = exc.event - elif exc.http_detail is RagnaException.MESSAGE: - detail = str(exc) - else: - detail = exc.http_detail - raise HTTPException( - status_code=exc.http_status_code, detail=detail - ) from None - except Exception: - raise - - return wrapper - - -def api(rag): - app = FastAPI() - - @app.get("/health") - @process_ragna_exception - async def health() -> str: - return ragna.__version__ - - async def _authorize_user(user: str) -> str: - # FIXME: implement auth here - return user - - UserDependency = Annotated[str, Depends(_authorize_user)] - - @app.get("/chats") - @process_ragna_exception - async def get_chats(user: UserDependency) -> list[ChatModel]: - return [ChatModel.from_chat(chat) for chat in rag._get_chats(user=user)] - - @app.get("/components") - @process_ragna_exception - async def get_components(_: UserDependency) -> ComponentsModel: - return ComponentsModel( - source_storages=list(rag.config.registered_source_storage_classes), - assistants=list(rag.config.registered_assistant_classes), - ) - - @app.get("/document/new") - @process_ragna_exception - async def get_document_upload_info( - user: UserDependency, - name: str, - ) -> DocumentUploadInfoModel: - id = RagnaId.make() - url, data, metadata = await rag.config.document_class.get_upload_info( - config=rag.config, user=user, id=id, name=name - ) - rag._add_document(user=user, id=id, name=name, metadata=metadata) - return DocumentUploadInfoModel( - url=url, data=data, document=DocumentModel(id=id, name=name) - ) - - @app.post("/document/upload") - @process_ragna_exception - async def upload_document( - token: Annotated[str, Form()], file: UploadFile - ) -> DocumentModel: - if not issubclass(rag.config.document_class, LocalDocument): - raise HTTPException( - status_code=400, - detail="Ragna configuration does not support local upload", - ) - - user, id = rag.config.document_class._decode_upload_token( - token, secret=rag.config.upload_token_secret - ) - document = rag._get_document(user=user, id=id) - - document.path.parent.mkdir(parents=True, exist_ok=True) - async with aiofiles.open(document.path, "wb") as document_file: - while content := await file.read(1024): - await document_file.write(content) - - return DocumentModel(id=id, name=document.name) - - @app.post("/chat/new") - @process_ragna_exception - async def new_chat( - *, user: UserDependency, chat_metadata: ChatMetadataModel - ) -> ChatModel: - return ChatModel.from_chat( - await rag.new_chat( - user=user, - name=chat_metadata.name, - documents=chat_metadata.document_ids, - source_storage=chat_metadata.source_storage, - assistant=chat_metadata.assistant, - **chat_metadata.params, - ) - ) - - async def _get_id(id: UUID) -> RagnaId: - return RagnaId.from_uuid(id) - - IdDependency = Annotated[RagnaId, Depends(_get_id)] - - async def _get_chat(*, user: UserDependency, id: IdDependency) -> Chat: - return rag._get_chat(user=user, id=id) - - ChatDependency = Annotated[Chat, Depends(_get_chat, use_cache=False)] - - @app.get("/chat/{id}") - @process_ragna_exception - async def get_chat(chat: ChatDependency) -> ChatModel: - return ChatModel.from_chat(chat) - - @app.post("/chat/{id}/start") - @process_ragna_exception - async def start_chat(chat: ChatDependency) -> ChatModel: - return ChatModel.from_chat(await chat.start()) - - @app.post("/chat/{id}/close") - @process_ragna_exception - async def close_chat(chat: ChatDependency) -> ChatModel: - return ChatModel.from_chat(await chat.close()) - - @app.post("/chat/{id}/answer") - @process_ragna_exception - async def answer(chat: ChatDependency, prompt: str) -> AnswerOutputModel: - return AnswerOutputModel( - message=MessageModel.from_message(await chat.answer(prompt)), - chat=ChatModel.from_chat(chat), - ) - - return app diff --git a/ragna/_api/__init__.py b/ragna/_api/__init__.py new file mode 100644 index 00000000..07f53c76 --- /dev/null +++ b/ragna/_api/__init__.py @@ -0,0 +1 @@ +from .core import api diff --git a/ragna/_api/core.py b/ragna/_api/core.py new file mode 100644 index 00000000..18ba376c --- /dev/null +++ b/ragna/_api/core.py @@ -0,0 +1,150 @@ +import functools +from typing import Annotated +from uuid import UUID + +import aiofiles +from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile + +import ragna + +from ragna.core import Chat, LocalDocument, RagnaException, RagnaId + +from . import schemas + + +def process_ragna_exception(afn): + @functools.wraps(afn) + async def wrapper(*args, **kwargs): + try: + return await afn(*args, **kwargs) + except RagnaException as exc: + if exc.http_detail is RagnaException.EVENT: + detail = exc.event + elif exc.http_detail is RagnaException.MESSAGE: + detail = str(exc) + else: + detail = exc.http_detail + raise HTTPException( + status_code=exc.http_status_code, detail=detail + ) from None + except Exception: + raise + + return wrapper + + +def api(rag): + app = FastAPI() + + @app.get("/health") + @process_ragna_exception + async def health() -> str: + return ragna.__version__ + + async def _authorize_user(user: str) -> str: + # FIXME: implement auth here + return user + + UserDependency = Annotated[str, Depends(_authorize_user)] + + @app.get("/components") + @process_ragna_exception + async def get_components(_: UserDependency) -> schemas.Components: + return schemas.Components( + source_storages=list(rag.config.registered_source_storage_classes), + assistants=list(rag.config.registered_assistant_classes), + ) + + @app.get("/document") + @process_ragna_exception + async def get_document_upload_info( + user: UserDependency, + name: str, + ) -> schemas.DocumentUploadInfo: + id = RagnaId.make() + url, data, metadata = await rag.config.document_class.get_upload_info( + config=rag.config, user=user, id=id, name=name + ) + rag._add_document(user=user, id=id, name=name, metadata=metadata) + return schemas.DocumentUploadInfo( + url=url, data=data, document=schemas.Document(id=id, name=name) + ) + + @app.post("/document") + @process_ragna_exception + async def upload_document( + token: Annotated[str, Form()], file: UploadFile + ) -> schemas.Document: + if not issubclass(rag.config.document_class, LocalDocument): + raise HTTPException( + status_code=400, + detail="Ragna configuration does not support local upload", + ) + + user, id = rag.config.document_class._decode_upload_token( + token, secret=rag.config.upload_token_secret + ) + document = rag._get_document(user=user, id=id) + + document.path.parent.mkdir(parents=True, exist_ok=True) + async with aiofiles.open(document.path, "wb") as document_file: + while content := await file.read(1024): + await document_file.write(content) + + return schemas.Document(id=id, name=document.name) + + @app.post("/chats") + @process_ragna_exception + async def create_chat( + *, user: UserDependency, chat_metadata: schemas.ChatMetadataCreate + ) -> schemas.Chat: + return schemas.Chat.from_core_chat( + await rag.new_chat( + user=user, + name=chat_metadata.name, + documents=chat_metadata.document_ids, + source_storage=chat_metadata.source_storage, + assistant=chat_metadata.assistant, + **chat_metadata.params, + ) + ) + + @app.get("/chats") + @process_ragna_exception + async def get_chats(user: UserDependency) -> list[schemas.Chat]: + return [schemas.Chat.from_core_chat(chat) for chat in rag._get_chats(user=user)] + + async def _get_id(id: UUID) -> RagnaId: + return RagnaId.from_uuid(id) + + IdDependency = Annotated[RagnaId, Depends(_get_id)] + + async def _get_chat(*, user: UserDependency, id: IdDependency) -> Chat: + return rag._get_chat(user=user, id=id) + + ChatDependency = Annotated[Chat, Depends(_get_chat, use_cache=False)] + + @app.get("/chats/{id}") + @process_ragna_exception + async def get_chat(chat: ChatDependency) -> schemas.Chat: + return schemas.Chat.from_core_chat(chat) + + @app.post("/chats/{id}/start") + @process_ragna_exception + async def start_chat(chat: ChatDependency) -> schemas.Chat: + return schemas.Chat.from_core_chat(await chat.start()) + + @app.post("/chats/{id}/close") + @process_ragna_exception + async def close_chat(chat: ChatDependency) -> schemas.Chat: + return schemas.Chat.from_core_chat(await chat.close()) + + @app.post("/chats/{id}/answer") + @process_ragna_exception + async def answer(chat: ChatDependency, prompt: str) -> schemas.AnswerOutput: + return schemas.AnswerOutput( + message=schemas.Message.from_core_message(await chat.answer(prompt)), + chat=schemas.Chat.from_core_chat(chat), + ) + + return app diff --git a/ragna/_api/schemas.py b/ragna/_api/schemas.py new file mode 100644 index 00000000..a6a16beb --- /dev/null +++ b/ragna/_api/schemas.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import datetime +from uuid import UUID + +from pydantic import BaseModel, HttpUrl, validator + +import ragna + +import ragna.core + + +class Document(BaseModel): + id: ragna.core.RagnaId + name: str + + @classmethod + def from_core_document(cls, document: ragna.core.Document) -> Document: + return cls( + id=document.id, + name=document.name, + ) + + +class DocumentUploadInfo(BaseModel): + url: HttpUrl + data: dict + document: Document + + +class Source(BaseModel): + id: ragna.core.RagnaId + document: Document + location: str + + @classmethod + def from_core_source(cls, source: ragna.core.Source) -> Source: + return cls( + id=source.id, + document=Document(id=source.document_id, name=source.document_name), + location=source.location, + ) + + +class Message(BaseModel): + id: ragna.core.RagnaId + role: ragna.core.MessageRole + content: str + sources: list[Source] + timestamp: datetime.datetime + + @classmethod + def from_core_message(cls, message: ragna.core.Message) -> Message: + return cls( + id=message.id, + role=message.role, + content=message.content, + sources=[Source.from_core_source(s) for s in message.sources], + timestamp=message.timestamp, + ) + + +class ChatMetadataBase(BaseModel): + name: str + source_storage: str + assistant: str + params: dict + + +class ChatMetadataCreate(ChatMetadataBase): + # For some reason list[RagnaId] does not work and will get parsed into list[UUID]. + # Thus, we use a validator below to do the conversion. + document_ids: list[UUID] + + @validator("document_ids") + def uuid_to_ragna_id(cls, document_ids: list[UUID]) -> list[ragna.core.RagnaId]: + return [ragna.core.RagnaId.from_uuid(u) for u in document_ids] + + +class ChatMetadata(ChatMetadataBase): + documents: list[Document] + + @classmethod + def from_core_chat(cls, chat: ragna.core.Chat) -> ChatMetadata: + return cls( + name=chat.name, + documents=[Document.from_core_document(d) for d in chat.documents], + source_storage=str(chat.source_storage), + assistant=str(chat.assistant), + params=chat.params, + ) + + +class Chat(BaseModel): + id: ragna.core.RagnaId + metadata: ChatMetadata + messages: list[Message] + started: bool + closed: bool + + @classmethod + def from_core_chat(cls, chat: ragna.core.Chat) -> Chat: + return cls( + id=chat.id, + metadata=ChatMetadata.from_core_chat(chat), + messages=[Message.from_core_message(m) for m in chat.messages], + started=chat._started, + closed=chat._closed, + ) + + +class AnswerOutput(BaseModel): + message: Message + chat: Chat + + +class Components(BaseModel): + source_storages: list[str] + assistants: list[str] diff --git a/ragna/core/_document.py b/ragna/core/_document.py index b553056d..6b376ebf 100644 --- a/ragna/core/_document.py +++ b/ragna/core/_document.py @@ -91,7 +91,7 @@ async def get_upload_info( import jwt - url = f"{config.ragna_api_url}/document/upload" + url = f"{config.ragna_api_url}/document" data = { "token": jwt.encode( payload={ diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index c776fa3f..0d5b6da4 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -72,9 +72,11 @@ async def new_chat( return chat - def _parse_documents(self, document: Sequence[Any], *, user: str) -> list[Document]: + def _parse_documents( + self, documents: Sequence[Any], *, user: str + ) -> list[Document]: documents_ = [] - for document in document: + for document in documents: if isinstance(document, RagnaId): document = self._get_document(id=document, user=user) else: diff --git a/scripts/add_chats.py b/scripts/add_chats.py index 2f701144..29e358bd 100644 --- a/scripts/add_chats.py +++ b/scripts/add_chats.py @@ -18,7 +18,7 @@ def main(): for i in range(5): name = f"document{i}.txt" document_info = client.get( - f"{url}/document/new", params={"user": user, "name": name} + f"{url}/document", params={"user": user, "name": name} ).json() client.post( document_info["url"], @@ -30,68 +30,71 @@ def main(): ## chat 1 chat = client.post( - f"{url}/chat/new", + f"{url}/chats", params={"user": user}, json={ "name": "Test chat", "document_ids": document_ids[:2], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", + "params": {}, }, ).json() client.post( - f"{url}/chat/{chat['id']}/start", + f"{url}/chats/{chat['id']}/start", params={"user": user}, ) client.post( - f"{url}/chat/{chat['id']}/answer", + f"{url}/chats/{chat['id']}/answer", params={"user": user, "prompt": "Hello!"}, ) ## chat 2 chat = client.post( - f"{url}/chat/new", + f"{url}/chats", params={"user": user}, json={ "name": f"Chat {datetime.datetime.now():%x %X}", "document_ids": document_ids[2:4], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", + "params": {}, }, ).json() client.post( - f"{url}/chat/{chat['id']}/start", + f"{url}/chats/{chat['id']}/start", params={"user": user}, ) for _ in range(3): client.post( - f"{url}/chat/{chat['id']}/answer", + f"{url}/chats/{chat['id']}/answer", params={"user": user, "prompt": "What is Ragna? Please, I need to know!"}, ) ## chat 3 chat = client.post( - f"{url}/chat/new", + f"{url}/chats", params={"user": user}, json={ "name": "Really long chat name that likely needs to be truncated somehow. If you can read this, truncating failed :boom:", "document_ids": [document_ids[i] for i in [0, 2, 4]], "source_storage": "Ragna/DemoSourceStorage", "assistant": "Ragna/DemoAssistant", + "params": {}, }, ).json() client.post( - f"{url}/chat/{chat['id']}/start", + f"{url}/chats/{chat['id']}/start", params={"user": user}, ) client.post( - f"{url}/chat/{chat['id']}/answer", + f"{url}/chats/{chat['id']}/answer", params={"user": user, "prompt": "Hello!"}, ) client.post( - f"{url}/chat/{chat['id']}/answer", + f"{url}/chats/{chat['id']}/answer", params={ "user": user, "prompt": "Ok, in that case show me some pretty markdown!", diff --git a/tests/test_importable.py b/tests/test_importable.py index d3e41f88..de09a4bb 100644 --- a/tests/test_importable.py +++ b/tests/test_importable.py @@ -13,10 +13,12 @@ def main(): else: path = path.with_suffix("") - if path.name.startswith("_"): + path = path.relative_to(PROJECT_ROOT) + + if any(part.startswith("_") for part in path.parts): continue - name = path.relative_to(PROJECT_ROOT).as_posix().replace("/", ".") + name = path.as_posix().replace("/", ".") try: importlib.import_module(name) From dede905ecda5c30b09be3c553ef35aea4d05581a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 10 Oct 2023 09:46:43 +0200 Subject: [PATCH 2/3] add LanceDB as SourceStorage (#66) --- examples/python_api/python_api.ipynb | 22 +++-- pyproject.toml | 4 + ragna/source_storage/__init__.py | 1 + ragna/source_storage/_lancedb.py | 117 +++++++++++++++++++++++++++ 4 files changed, 137 insertions(+), 7 deletions(-) create mode 100644 ragna/source_storage/_lancedb.py diff --git a/examples/python_api/python_api.ipynb b/examples/python_api/python_api.ipynb index a13ab72a..1b2a3b1a 100644 --- a/examples/python_api/python_api.ipynb +++ b/examples/python_api/python_api.ipynb @@ -21,7 +21,7 @@ { "data": { "text/plain": [ - "'0.1.dev21+g7f8ec2d.d20230925071852'" + "'0.1.dev29+g03721c7.d20231009084220'" ] }, "execution_count": 2, @@ -44,7 +44,7 @@ { "data": { "text/plain": [ - "Config(local_cache_root=PosixPath('/home/philip/.cache/ragna'), state_database_url='sqlite://', queue_database_url='memory', ragna_api_url='http://127.0.0.1:31476', ragna_ui_url='http://127.0.0.1:31477', document_class=, upload_token_secret='245be1b4c5656eefec1ac16d4a856f189e9e35d35aa014eb7426573e5b86c03f', upload_token_ttl=30, registered_source_storage_classes={'Ragna/DemoSourceStorage': }, registered_assistant_classes={'Ragna/DemoAssistant': })" + "Config(local_cache_root=PosixPath('/home/philip/.cache/ragna'), state_database_url='sqlite://', queue_database_url='memory', ragna_api_url='http://127.0.0.1:31476', ragna_ui_url='http://127.0.0.1:31477', document_class=, upload_token_secret='d9d5c32fcb2d4f3a3a36cb5d95c8147ab2e9c664ad2a3976f3e2eb9ef80b53c9', upload_token_ttl=30, registered_source_storage_classes={'Ragna/DemoSourceStorage': }, registered_assistant_classes={'Ragna/DemoAssistant': })" ] }, "execution_count": 3, @@ -98,7 +98,11 @@ " OpenaiGpt35Turbo16kAssistant,\n", " OpenaiGpt4Assistant,\n", ")\n", - "from ragna.source_storage import ChromaSourceStorage, RagnaDemoSourceStorage\n", + "from ragna.source_storage import (\n", + " ChromaSourceStorage,\n", + " RagnaDemoSourceStorage,\n", + " LanceDBSourceStorage,\n", + ")\n", "\n", "rag = Rag(demo_config)\n", "\n", @@ -123,8 +127,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", - " ('Chroma', 'OpenAI/gpt-4'): }\n" + "{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", + " ('Chroma', 'OpenAI/gpt-4'): ,\n", + " ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", + " ('LanceDB', 'OpenAI/gpt-4'): }\n" ] } ], @@ -133,7 +139,7 @@ "import asyncio\n", "from pprint import pprint\n", "\n", - "source_storages = [ChromaSourceStorage]\n", + "source_storages = [ChromaSourceStorage, LanceDBSourceStorage]\n", "assistants = [OpenaiGpt35Turbo16kAssistant, OpenaiGpt4Assistant]\n", "\n", "\n", @@ -168,7 +174,9 @@ "output_type": "stream", "text": [ "{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): Ragna is an open-source RAG (Response Analysis Graph) orchestration app. It is designed to help users create conversational AI applications by providing a framework for managing and orchestrating the flow of conversations. Ragna allows developers to define conversation flows, handle user inputs, and generate dynamic responses based on predefined rules and logic. It is built on top of the Rasa framework and provides additional features and functionalities to simplify the development process.,\n", - " ('Chroma', 'OpenAI/gpt-4'): Ragna is an open-source RAG orchestration app.}\n" + " ('Chroma', 'OpenAI/gpt-4'): Ragna is an open-source RAG orchestration app.,\n", + " ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): Ragna is an open-source rag orchestration app. It is a software application that allows users to create and arrange musical compositions using ragtime music. It is designed to be accessible and customizable for musicians and composers.,\n", + " ('LanceDB', 'OpenAI/gpt-4'): Ragna is an open-source rag orchestration app.}\n" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 7415b4f1..133340d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,8 @@ ui = [ builtin-components = [ "chromadb>=0.4", "httpx", + "lancedb>=0.2", + "sentence-transformers", "tiktoken", ] complete = [ @@ -56,8 +58,10 @@ complete = [ "chromadb>=0.4", "fastapi", "httpx", + "lancedb>=0.2", "python-multipart", "rich", + "sentence-transformers", "tiktoken", "uvicorn", ] diff --git a/ragna/source_storage/__init__.py b/ragna/source_storage/__init__.py index c3a3b529..125207f6 100644 --- a/ragna/source_storage/__init__.py +++ b/ragna/source_storage/__init__.py @@ -1,2 +1,3 @@ from ._chroma import ChromaSourceStorage from ._demo import RagnaDemoSourceStorage +from ._lancedb import LanceDBSourceStorage diff --git a/ragna/source_storage/_lancedb.py b/ragna/source_storage/_lancedb.py new file mode 100644 index 00000000..7d5789e9 --- /dev/null +++ b/ragna/source_storage/_lancedb.py @@ -0,0 +1,117 @@ +from ragna.core import ( + Document, + PackageRequirement, + RagnaId, + Requirement, + Source, + SourceStorage, +) + +from ragna.utils import chunk_pages, page_numbers_to_str, take_sources_up_to_max_tokens + + +class LanceDBSourceStorage(SourceStorage): + @classmethod + def display_name(cls) -> str: + return "LanceDB" + + @classmethod + def requirements(cls) -> list[Requirement]: + return [ + PackageRequirement("lancedb>=0.2"), + # FIXME: re-add this after https://github.com/apache/arrow/issues/38167 is + # resolved. + # PackageRequirement("pyarrow"), + PackageRequirement("sentence-transformers"), + ] + + def __init__(self, config): + super().__init__(config) + + import lancedb + import pyarrow as pa + from sentence_transformers import SentenceTransformer + + self._db = lancedb.connect(config.local_cache_root / "lancedb") + self._model = SentenceTransformer("paraphrase-albert-small-v2") + self._schema = pa.schema( + [ + pa.field("document_id", pa.string()), + pa.field("document_name", pa.string()), + pa.field("page_numbers", pa.string()), + pa.field("text", pa.string()), + pa.field( + self._VECTOR_COLUMN_NAME, + pa.list_(pa.float32(), self._model[-1].word_embedding_dimension), + ), + pa.field("num_tokens", pa.int32()), + ] + ) + + def _embed(self, batch): + return [self._model.encode(sentence) for sentence in batch] + + _VECTOR_COLUMN_NAME = "embedded_text" + + def store( + self, + documents: list[Document], + *, + chat_id: RagnaId, + chunk_size: int = 500, + chunk_overlap: int = 250, + ) -> None: + table = self._db.create_table(name=str(chat_id), schema=self._schema) + + for document in documents: + for chunk in chunk_pages( + document.extract_pages(), + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + tokenizer=self._model.tokenizer, + ): + table.add( + [ + { + "document_id": str(document.id), + "document_name": document.name, + "page_numbers": page_numbers_to_str(chunk.page_numbers), + "text": chunk.text, + self._VECTOR_COLUMN_NAME: self._model.encode(chunk.text), + "num_tokens": chunk.num_tokens, + } + ] + ) + + def retrieve( + self, + prompt: str, + *, + chat_id: RagnaId, + chunk_size: int = 500, + num_tokens: int = 1024, + ) -> list[Source]: + table = self._db.open_table(str(chat_id)) + + # We cannot retrieve source by a maximum number of tokens. Thus, we estimate how + # many sources we have to query. We overestimate by a factor of two to avoid + # retrieving to few sources and needed to query again. + limit = int(num_tokens * 2 / chunk_size) + results = table.search().limit(limit).to_arrow() + + return list( + take_sources_up_to_max_tokens( + ( + Source( + id=RagnaId.make(), + document_id=RagnaId(result["document_id"]), + document_name=result["document_name"], + location=result["page_numbers"], + content=result["text"], + num_tokens=result["num_tokens"], + ) + for result in results.to_pylist() + ), + max_tokens=num_tokens, + ) + ) From 2970082c3540965c2d62a884fd6b49a353513485 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 10 Oct 2023 14:09:11 +0200 Subject: [PATCH 3/3] fix component loading (#64) --- examples/local_llm/local_llm.ipynb | 2 +- examples/python_api/python_api.ipynb | 22 +++-- examples/rest_api/rest_api.ipynb | 106 +++++++++++------------ examples/s3_documents/s3_documents.ipynb | 62 ++----------- ragna/_api/schemas.py | 4 +- ragna/core/_queue.py | 28 +++--- ragna/core/_rag.py | 8 +- ragna/source_storage/_chroma.py | 2 +- 8 files changed, 103 insertions(+), 131 deletions(-) diff --git a/examples/local_llm/local_llm.ipynb b/examples/local_llm/local_llm.ipynb index 9a675441..6ce93bd2 100644 --- a/examples/local_llm/local_llm.ipynb +++ b/examples/local_llm/local_llm.ipynb @@ -127,7 +127,7 @@ "output_type": "stream", "text": [ "User: What is Ragna?\n", - "Assistant: Ragna is an OSS app for RAG workflows that offers a Python and REST API as well as web UI.\n" + "Assistant: Ragna is an open-source application (OSS) for RAG workflows. It offers a Python and REST API as well as a web UI.\n" ] } ], diff --git a/examples/python_api/python_api.ipynb b/examples/python_api/python_api.ipynb index 1b2a3b1a..9380aad3 100644 --- a/examples/python_api/python_api.ipynb +++ b/examples/python_api/python_api.ipynb @@ -44,7 +44,7 @@ { "data": { "text/plain": [ - "Config(local_cache_root=PosixPath('/home/philip/.cache/ragna'), state_database_url='sqlite://', queue_database_url='memory', ragna_api_url='http://127.0.0.1:31476', ragna_ui_url='http://127.0.0.1:31477', document_class=, upload_token_secret='d9d5c32fcb2d4f3a3a36cb5d95c8147ab2e9c664ad2a3976f3e2eb9ef80b53c9', upload_token_ttl=30, registered_source_storage_classes={'Ragna/DemoSourceStorage': }, registered_assistant_classes={'Ragna/DemoAssistant': })" + "Config(local_cache_root=PosixPath('/home/philip/.cache/ragna'), state_database_url='sqlite://', queue_database_url='memory', ragna_api_url='http://127.0.0.1:31476', ragna_ui_url='http://127.0.0.1:31477', document_class=, upload_token_secret='a9e79b896f90c136820e70ced28f30e6cbbaca6d19f52e41c35b154df21a657a', upload_token_ttl=30, registered_source_storage_classes={'Ragna/DemoSourceStorage': }, registered_assistant_classes={'Ragna/DemoAssistant': })" ] }, "execution_count": 3, @@ -127,10 +127,10 @@ "name": "stdout", "output_type": "stream", "text": [ - "{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", - " ('Chroma', 'OpenAI/gpt-4'): ,\n", - " ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", - " ('LanceDB', 'OpenAI/gpt-4'): }\n" + "{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", + " ('Chroma', 'OpenAI/gpt-4'): ,\n", + " ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): ,\n", + " ('LanceDB', 'OpenAI/gpt-4'): }\n" ] } ], @@ -173,6 +173,18 @@ "name": "stdout", "output_type": "stream", "text": [ + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", + "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\n", + "To disable this warning, you can either:\n", + "\t- Avoid using `tokenizers` before the fork if possible\n", + "\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n", "{('Chroma', 'OpenAI/gpt-3.5-turbo-16k'): Ragna is an open-source RAG (Response Analysis Graph) orchestration app. It is designed to help users create conversational AI applications by providing a framework for managing and orchestrating the flow of conversations. Ragna allows developers to define conversation flows, handle user inputs, and generate dynamic responses based on predefined rules and logic. It is built on top of the Rasa framework and provides additional features and functionalities to simplify the development process.,\n", " ('Chroma', 'OpenAI/gpt-4'): Ragna is an open-source RAG orchestration app.,\n", " ('LanceDB', 'OpenAI/gpt-3.5-turbo-16k'): Ragna is an open-source rag orchestration app. It is a software application that allows users to create and arrange musical compositions using ragtime music. It is designed to be accessible and customizable for musicians and composers.,\n", diff --git a/examples/rest_api/rest_api.ipynb b/examples/rest_api/rest_api.ipynb index aacb3f16..7dff2d34 100644 --- a/examples/rest_api/rest_api.ipynb +++ b/examples/rest_api/rest_api.ipynb @@ -44,14 +44,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"GET /health HTTP/1.1\" 200 OK\n" + "INFO: 127.0.0.1:60998 - \"GET /health HTTP/1.1\" 200 OK\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "INFO: Started server process [39690]\n", + "INFO: Started server process [32153]\n", "INFO: Waiting for application startup.\n", "INFO: Application startup complete.\n", "INFO: Uvicorn running on http://127.0.0.1:31476 (Press CTRL+C to quit)\n" @@ -137,7 +137,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"GET /chats?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"GET /chats?user=Ragna HTTP/1.1\" 200 OK\n", "[]\n" ] } @@ -167,7 +167,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"GET /components?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"GET /components?user=Ragna HTTP/1.1\" 200 OK\n", "{'assistants': ['Ragna/DemoAssistant'],\n", " 'source_storages': ['Ragna/DemoSourceStorage']}\n" ] @@ -229,9 +229,9 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"GET /document?user=Ragna&name=document0.txt HTTP/1.1\" 200 OK\n", - "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImRiZjdiOWE2LWEzMGItNGEyNS1iNGJhLTE3ODIwNjNhYzJhZiIsImV4cCI6MTY5NjQ1NTg2MS40MjkwMjE0fQ.gjMsP0qZ36hO9fAOFPxeKom-oYsyIyxggG82oP0svoc'},\n", - " 'document': {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + "INFO: 127.0.0.1:60998 - \"GET /document?user=Ragna&name=document0.txt HTTP/1.1\" 200 OK\n", + "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImRlNWJiYjMxLTQ1MDEtNDFhNC05ZDRlLTFiN2EzNjk1Mjk5NSIsImV4cCI6MTY5Njg3Nzk1OS40NTE3NzcyfQ.XC7qKuCZXkbrI7isQ39IqYZMCL_O6oVP4GfI8wqHpdE'},\n", + " 'document': {'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995',\n", " 'name': 'document0.txt'},\n", " 'url': 'http://127.0.0.1:31476/document'}\n" ] @@ -264,7 +264,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"POST /document HTTP/1.1\" 200 OK\n" + "INFO: 127.0.0.1:60998 - \"POST /document HTTP/1.1\" 200 OK\n" ] } ], @@ -295,18 +295,18 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"GET /document?user=Ragna&name=document1.txt HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:42394 - \"POST /document HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:42394 - \"GET /document?user=Ragna&name=document2.txt HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:42394 - \"POST /document HTTP/1.1\" 200 OK\n" + "INFO: 127.0.0.1:60998 - \"GET /document?user=Ragna&name=document1.txt HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"POST /document HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"GET /document?user=Ragna&name=document2.txt HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"POST /document HTTP/1.1\" 200 OK\n" ] }, { "data": { "text/plain": [ - "[{'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af', 'name': 'document0.txt'},\n", - " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3', 'name': 'document1.txt'},\n", - " {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04', 'name': 'document2.txt'}]" + "[{'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995', 'name': 'document0.txt'},\n", + " {'id': '8e5e17b4-e6e3-4707-806a-917f0f5b46fb', 'name': 'document1.txt'},\n", + " {'id': '336436fa-1b6a-4a29-8821-791cd682a7cb', 'name': 'document2.txt'}]" ] }, "execution_count": 9, @@ -349,16 +349,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"POST /chats?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"POST /chats?user=Ragna HTTP/1.1\" 200 OK\n", "{'closed': False,\n", - " 'id': 'f68e3999-444c-4a92-8d25-0f80dd3b2219',\n", + " 'id': '0777396d-9993-47ce-b807-2e9c48364c63',\n", " 'messages': [],\n", " 'metadata': {'assistant': 'Ragna/DemoAssistant',\n", - " 'documents': [{'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " 'documents': [{'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995',\n", " 'name': 'document0.txt'},\n", - " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " {'id': '8e5e17b4-e6e3-4707-806a-917f0f5b46fb',\n", " 'name': 'document1.txt'},\n", - " {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " {'id': '336436fa-1b6a-4a29-8821-791cd682a7cb',\n", " 'name': 'document2.txt'}],\n", " 'name': 'Ragna REST API example',\n", " 'params': {},\n", @@ -401,20 +401,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"POST /chats/f68e3999-444c-4a92-8d25-0f80dd3b2219/start?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"POST /chats/0777396d-9993-47ce-b807-2e9c48364c63/start?user=Ragna HTTP/1.1\" 200 OK\n", "{'closed': False,\n", - " 'id': 'f68e3999-444c-4a92-8d25-0f80dd3b2219',\n", + " 'id': '0777396d-9993-47ce-b807-2e9c48364c63',\n", " 'messages': [{'content': 'How can I help you with the documents?',\n", - " 'id': '09df1d1c-3c8f-4d42-ad68-fa8a4d833d85',\n", + " 'id': 'f9b43a27-67ac-4984-9deb-3844ff4a6ea1',\n", " 'role': 'system',\n", " 'sources': [],\n", - " 'timestamp': '2023-10-04T21:43:51.941124'}],\n", + " 'timestamp': '2023-10-09T18:58:50.356957'}],\n", " 'metadata': {'assistant': 'Ragna/DemoAssistant',\n", - " 'documents': [{'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'documents': [{'id': '336436fa-1b6a-4a29-8821-791cd682a7cb',\n", " 'name': 'document2.txt'},\n", - " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " {'id': '8e5e17b4-e6e3-4707-806a-917f0f5b46fb',\n", " 'name': 'document1.txt'},\n", - " {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " {'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995',\n", " 'name': 'document0.txt'}],\n", " 'name': 'Ragna REST API example',\n", " 'params': {},\n", @@ -449,7 +449,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"POST /chats/f68e3999-444c-4a92-8d25-0f80dd3b2219/answer?user=Ragna&prompt=What%20is%20Ragna%3F HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"POST /chats/0777396d-9993-47ce-b807-2e9c48364c63/answer?user=Ragna&prompt=What%20is%20Ragna%3F HTTP/1.1\" 200 OK\n", "{'content': \"I can't really help you with your prompt:\\n\"\n", " '\\n'\n", " '> What is Ragna?\\n'\n", @@ -459,21 +459,21 @@ " '- document2.txt: This is content of document 2\\n'\n", " '- document1.txt: This is content of document 1\\n'\n", " '- document0.txt: This is content of document 0',\n", - " 'id': '822615fa-c05c-4b03-8d25-60b8687ba568',\n", + " 'id': '00f34df9-5e74-4588-8808-e1e11987db7e',\n", " 'role': 'assistant',\n", - " 'sources': [{'document': {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'sources': [{'document': {'id': '336436fa-1b6a-4a29-8821-791cd682a7cb',\n", " 'name': 'document2.txt'},\n", - " 'id': '6bfba599-7fdd-4a36-8f43-c1cb552f83eb',\n", + " 'id': 'a07d9305-860b-45c7-9997-bdf47b7dfe58',\n", " 'location': ''},\n", - " {'document': {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " {'document': {'id': '8e5e17b4-e6e3-4707-806a-917f0f5b46fb',\n", " 'name': 'document1.txt'},\n", - " 'id': 'e7ebeac9-891b-4867-b35b-62d571a50837',\n", + " 'id': '1fac7064-f8c0-4a85-9e52-06ade4bfff86',\n", " 'location': ''},\n", - " {'document': {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " {'document': {'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995',\n", " 'name': 'document0.txt'},\n", - " 'id': '37b9a731-032c-4bd6-9286-3693b1149952',\n", + " 'id': '8e5e94d7-35ab-4426-b369-88e04ef19e80',\n", " 'location': ''}],\n", - " 'timestamp': '2023-10-04T21:43:51.953583'}\n", + " 'timestamp': '2023-10-09T18:58:50.368654'}\n", "I can't really help you with your prompt:\n", "\n", "> What is Ragna?\n", @@ -513,19 +513,19 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:42394 - \"POST /chats/f68e3999-444c-4a92-8d25-0f80dd3b2219/close?user=Ragna HTTP/1.1\" 200 OK\n", + "INFO: 127.0.0.1:60998 - \"POST /chats/0777396d-9993-47ce-b807-2e9c48364c63/close?user=Ragna HTTP/1.1\" 200 OK\n", "{'closed': True,\n", - " 'id': 'f68e3999-444c-4a92-8d25-0f80dd3b2219',\n", + " 'id': '0777396d-9993-47ce-b807-2e9c48364c63',\n", " 'messages': [{'content': 'How can I help you with the documents?',\n", - " 'id': '09df1d1c-3c8f-4d42-ad68-fa8a4d833d85',\n", + " 'id': 'f9b43a27-67ac-4984-9deb-3844ff4a6ea1',\n", " 'role': 'system',\n", " 'sources': [],\n", - " 'timestamp': '2023-10-04T21:43:51.941124'},\n", + " 'timestamp': '2023-10-09T18:58:50.356957'},\n", " {'content': 'What is Ragna?',\n", - " 'id': '362e63d9-16d6-4c0e-a3d7-26cab7ddd980',\n", + " 'id': '1d7f401a-556e-4c00-af11-8765c01d1709',\n", " 'role': 'user',\n", " 'sources': [],\n", - " 'timestamp': '2023-10-04T21:43:51.951584'},\n", + " 'timestamp': '2023-10-09T18:58:50.366988'},\n", " {'content': \"I can't really help you with your prompt:\\n\"\n", " '\\n'\n", " '> What is Ragna?\\n'\n", @@ -536,27 +536,27 @@ " '- document2.txt: This is content of document 2\\n'\n", " '- document1.txt: This is content of document 1\\n'\n", " '- document0.txt: This is content of document 0',\n", - " 'id': '822615fa-c05c-4b03-8d25-60b8687ba568',\n", + " 'id': '00f34df9-5e74-4588-8808-e1e11987db7e',\n", " 'role': 'assistant',\n", - " 'sources': [{'document': {'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'sources': [{'document': {'id': '336436fa-1b6a-4a29-8821-791cd682a7cb',\n", " 'name': 'document2.txt'},\n", - " 'id': '6bfba599-7fdd-4a36-8f43-c1cb552f83eb',\n", + " 'id': 'a07d9305-860b-45c7-9997-bdf47b7dfe58',\n", " 'location': ''},\n", - " {'document': {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " {'document': {'id': '8e5e17b4-e6e3-4707-806a-917f0f5b46fb',\n", " 'name': 'document1.txt'},\n", - " 'id': 'e7ebeac9-891b-4867-b35b-62d571a50837',\n", + " 'id': '1fac7064-f8c0-4a85-9e52-06ade4bfff86',\n", " 'location': ''},\n", - " {'document': {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " {'document': {'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995',\n", " 'name': 'document0.txt'},\n", - " 'id': '37b9a731-032c-4bd6-9286-3693b1149952',\n", + " 'id': '8e5e94d7-35ab-4426-b369-88e04ef19e80',\n", " 'location': ''}],\n", - " 'timestamp': '2023-10-04T21:43:51.953583'}],\n", + " 'timestamp': '2023-10-09T18:58:50.368654'}],\n", " 'metadata': {'assistant': 'Ragna/DemoAssistant',\n", - " 'documents': [{'id': '1ae72430-bde4-42a3-b623-4aadd234bc04',\n", + " 'documents': [{'id': '336436fa-1b6a-4a29-8821-791cd682a7cb',\n", " 'name': 'document2.txt'},\n", - " {'id': '5c54ded3-67ae-4dd4-8829-190e09b2f3e3',\n", + " {'id': '8e5e17b4-e6e3-4707-806a-917f0f5b46fb',\n", " 'name': 'document1.txt'},\n", - " {'id': 'dbf7b9a6-a30b-4a25-b4ba-1782063ac2af',\n", + " {'id': 'de5bbb31-4501-41a4-9d4e-1b7a36952995',\n", " 'name': 'document0.txt'}],\n", " 'name': 'Ragna REST API example',\n", " 'params': {},\n", diff --git a/examples/s3_documents/s3_documents.ipynb b/examples/s3_documents/s3_documents.ipynb index f70df0ec..f1a33a5d 100644 --- a/examples/s3_documents/s3_documents.ipynb +++ b/examples/s3_documents/s3_documents.ipynb @@ -419,25 +419,7 @@ "execution_count": 5, "id": "ae31bfb0-01f6-47c8-9733-ba1204b153d8", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO: 127.0.0.1:56670 - \"GET /health HTTP/1.1\" 200 OK\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO: Started server process [40064]\n", - "INFO: Waiting for application startup.\n", - "INFO: Application startup complete.\n", - "INFO: Uvicorn running on http://127.0.0.1:31476 (Press CTRL+C to quit)\n" - ] - } - ], + "outputs": [], "source": [ "import contextlib\n", "import subprocess\n", @@ -491,9 +473,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:56670 - \"GET /document?user=Ragna&name=document0.txt HTTP/1.1\" 200 OK\n", - "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImViY2E1NTZhLTQxZTgtNDQ4Ni1iOGI3LThkMmM3OGFlOTc1OCIsImV4cCI6MTY5NjQ1NjA2NS41ODAyMDUyfQ.MwFbgwv2nTcetRl9osgjnsnplvNC8fIxHBSBadFX7tg'},\n", - " 'document': {'id': 'ebca556a-41e8-4486-b8b7-8d2c78ae9758',\n", + "{'data': {'token': 'eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiUmFnbmEiLCJpZCI6ImIyZDRkNTgwLTA0OGQtNDBkMS1iMzljLTk1NGFmNDI5OTMwOSIsImV4cCI6MTY5Njg3Nzk4Mi4yOTA1NTk1fQ.0rVxqW0h5A3c8CJmp-roaRFUoJ7IAc4zUsQHJL2RuaU'},\n", + " 'document': {'id': 'b2d4d580-048d-40d1-b39c-954af4299309',\n", " 'name': 'document0.txt'},\n", " 'url': 'http://127.0.0.1:31476/document'}\n" ] @@ -515,15 +496,7 @@ "execution_count": 8, "id": "fdc5c6f4-bcf8-4902-b441-768b6e3a66bc", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO: 127.0.0.1:56670 - \"POST /document HTTP/1.1\" 200 OK\n" - ] - } - ], + "outputs": [], "source": [ "response = await client.post(\n", " document_info[\"url\"],\n", @@ -539,22 +512,12 @@ "id": "f0d769ff-d721-460a-8888-5b033eb4a909", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO: 127.0.0.1:56670 - \"GET /document?user=Ragna&name=document1.txt HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:56670 - \"POST /document HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:56670 - \"GET /document?user=Ragna&name=document2.txt HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:56670 - \"POST /document HTTP/1.1\" 200 OK\n" - ] - }, { "data": { "text/plain": [ - "[{'id': 'ebca556a-41e8-4486-b8b7-8d2c78ae9758', 'name': 'document0.txt'},\n", - " {'id': 'a723700a-8401-4b09-b727-ead6797f9725', 'name': 'document1.txt'},\n", - " {'id': '9028a6c3-2a07-436e-a953-7b92a393113b', 'name': 'document2.txt'}]" + "[{'id': 'b2d4d580-048d-40d1-b39c-954af4299309', 'name': 'document0.txt'},\n", + " {'id': '525728e0-d478-4bce-8fdc-9f8b92f0cfe2', 'name': 'document1.txt'},\n", + " {'id': '308a261b-47a8-43b6-80aa-4c3c12549f99', 'name': 'document2.txt'}]" ] }, "execution_count": 9, @@ -585,17 +548,10 @@ "id": "ae364aec-e63b-4f00-8b55-217148d6df24", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "INFO: 127.0.0.1:56670 - \"POST /chats?user=Ragna HTTP/1.1\" 200 OK\n" - ] - }, { "data": { "text/plain": [ - "'http://127.0.0.1:31476/chats/98a82c99-a221-4ff5-8b24-cc42ced36477'" + "'http://127.0.0.1:31476/chats/29e99963-08e9-4006-8b07-2f76e499a9c3'" ] }, "execution_count": 10, @@ -632,8 +588,6 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO: 127.0.0.1:56670 - \"POST /chats/98a82c99-a221-4ff5-8b24-cc42ced36477/start?user=Ragna HTTP/1.1\" 200 OK\n", - "INFO: 127.0.0.1:56670 - \"POST /chats/98a82c99-a221-4ff5-8b24-cc42ced36477/answer?user=Ragna&prompt=Hello%20World%21 HTTP/1.1\" 200 OK\n", "I can't really help you with your prompt:\n", "\n", "> Hello World!\n", diff --git a/ragna/_api/schemas.py b/ragna/_api/schemas.py index a6a16beb..d00b0fcd 100644 --- a/ragna/_api/schemas.py +++ b/ragna/_api/schemas.py @@ -85,8 +85,8 @@ def from_core_chat(cls, chat: ragna.core.Chat) -> ChatMetadata: return cls( name=chat.name, documents=[Document.from_core_document(d) for d in chat.documents], - source_storage=str(chat.source_storage), - assistant=str(chat.assistant), + source_storage=chat.source_storage.display_name(), + assistant=chat.assistant.display_name(), params=chat.params, ) diff --git a/ragna/core/_queue.py b/ragna/core/_queue.py index def37b95..74e07c96 100644 --- a/ragna/core/_queue.py +++ b/ragna/core/_queue.py @@ -21,11 +21,12 @@ def decorator(fn): return decorator -_COMPONENTS: dict[Type[RagComponent], RagComponent] = {} +_COMPONENTS: dict[Type[RagComponent], Optional[RagComponent]] = {} def execute(component, fn, args, kwargs): self = _COMPONENTS[component] + assert self is not None return fn(self, *args, **kwargs) @@ -41,12 +42,11 @@ def __init__(self, config: Config, *, load_components: Optional[bool]): if load_components is None: load_components = isinstance(self._huey, huey.MemoryHuey) - if load_components: - for component in itertools.chain( - config.registered_source_storage_classes.values(), - config.registered_assistant_classes.values(), - ): - self.load_component(component) + for component in itertools.chain( + config.registered_source_storage_classes.values(), + config.registered_assistant_classes.values(), + ): + self.parse_component(component, load=load_components) def _load_huey(self, url: Optional[str]): # FIXME: we need to store_none=True here. SourceStorage.store returns None and @@ -80,8 +80,11 @@ def _load_huey(self, url: Optional[str]): return _huey - def load_component( - self, component: Union[Type[RagComponent], RagComponent, str] + def parse_component( + self, + component: Union[Type[RagComponent], RagComponent, str], + *, + load: bool = False, ) -> Type[RagComponent]: if isinstance(component, type) and issubclass(component, RagComponent): cls = component @@ -98,10 +101,13 @@ def load_component( raise RagnaException("Unknown component", component=component) instance = None - if cls in _COMPONENTS: + if instance is None: + instance = _COMPONENTS.get(cls) + + if instance is not None: return cls - if instance is None: + if load: if not cls.is_available(): raise RagnaException("Component not available", name=cls.display_name()) diff --git a/ragna/core/_rag.py b/ragna/core/_rag.py index 0d5b6da4..2f50f284 100644 --- a/ragna/core/_rag.py +++ b/ragna/core/_rag.py @@ -46,8 +46,8 @@ async def new_chat( **params, ): documents = self._parse_documents(documents, user=user) - source_storage = self._queue.load_component(source_storage) - assistant = self._queue.load_component(assistant) + source_storage = self._queue.parse_component(source_storage, load=True) + assistant = self._queue.parse_component(assistant, load=True) chat = Chat( rag=self, @@ -135,8 +135,8 @@ def _get_chats(self, *, user: str): ) for document_state in chat_state.document_states ], - source_storage=self._queue.load_component(chat_state.source_storage), - assistant=self._queue.load_component(chat_state.assistant), + source_storage=self._queue.parse_component(chat_state.source_storage), + assistant=self._queue.parse_component(chat_state.assistant), messages=[ Message( id=message_state.id, diff --git a/ragna/source_storage/_chroma.py b/ragna/source_storage/_chroma.py index b64c9dad..3f674f97 100644 --- a/ragna/source_storage/_chroma.py +++ b/ragna/source_storage/_chroma.py @@ -49,7 +49,7 @@ def store( chunk_size: int = 500, chunk_overlap: int = 250, ) -> None: - collection = self._client.get_or_create_collection( + collection = self._client.create_collection( str(chat_id), embedding_function=self._embedding_function )