From db03a5f790ad9cb32e59d165b720ff1e1bf5e913 Mon Sep 17 00:00:00 2001 From: Ferran Llamas Date: Tue, 16 Apr 2024 17:50:50 +0200 Subject: [PATCH] Allow chat on resource by slug (#2067) * Add tests * Add tests * parallelize workflow runs more * fix test * fix test * fix endpoint name --- .github/workflows/nucliadb_search.yml | 4 +- .../nucliadb_search/templates/search.vs.yaml | 2 +- .../nucliadb/search/api/v1/resource/chat.py | 85 ++++++++++++++++++- nucliadb/nucliadb/search/api/v1/router.py | 1 + .../tests/unit/api/v1/resource/test_chat.py | 3 +- nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py | 1 + .../nucliadb_sdk/tests/test_sdk_async.py | 1 + nucliadb_sdk/nucliadb_sdk/v2/sdk.py | 9 ++ 8 files changed, 98 insertions(+), 8 deletions(-) diff --git a/.github/workflows/nucliadb_search.yml b/.github/workflows/nucliadb_search.yml index d367e9118f..d8ed0fb1c9 100644 --- a/.github/workflows/nucliadb_search.yml +++ b/.github/workflows/nucliadb_search.yml @@ -40,7 +40,7 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} - cancel-in-progress: true + cancel-in-progress: false permissions: id-token: write # This is required for requesting the JWT @@ -81,7 +81,7 @@ jobs: runs-on: ubuntu-latest strategy: - max-parallel: 1 + max-parallel: 3 matrix: include: - maindb_driver: "tikv" diff --git a/charts/nucliadb_search/templates/search.vs.yaml b/charts/nucliadb_search/templates/search.vs.yaml index 92906570f2..8413d9469d 100644 --- a/charts/nucliadb_search/templates/search.vs.yaml +++ b/charts/nucliadb_search/templates/search.vs.yaml @@ -39,7 +39,7 @@ spec: method: regex: "GET|POST|OPTIONS" - uri: - regex: '^/api/v\d+/kb/[^/]+/resource/[^/]+/(chat|find|search|ask)$' + regex: '^/api/v\d+/kb/[^/]+/(resource|slug)/[^/]+/(chat|find|search|ask)$' method: regex: "GET|POST|OPTIONS" - uri: diff --git a/nucliadb/nucliadb/search/api/v1/resource/chat.py b/nucliadb/nucliadb/search/api/v1/resource/chat.py index df904474ca..0146bbf2d7 100644 --- a/nucliadb/nucliadb/search/api/v1/resource/chat.py +++ b/nucliadb/nucliadb/search/api/v1/resource/chat.py @@ -17,15 +17,16 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . # -from typing import Union +from typing import Optional, Union from fastapi import Header, Request, Response from fastapi_versioning import version from starlette.responses import StreamingResponse +from nucliadb.common import datamanagers from nucliadb.models.responses import HTTPClientError from nucliadb.search import predict -from nucliadb.search.api.v1.router import KB_PREFIX, api +from nucliadb.search.api.v1.router import KB_PREFIX, RESOURCE_SLUG_PREFIX, api from nucliadb.search.search.exceptions import ( IncompleteFindResultsError, InvalidQueryError, @@ -49,7 +50,7 @@ ) @requires(NucliaDBRoles.READER) @version(1) -async def resource_chat_endpoint( +async def resource_chat_endpoint_by_uuid( request: Request, kbid: str, rid: str, @@ -63,6 +64,75 @@ async def resource_chat_endpoint( "This is slower and requires waiting for entire answer to be ready.", ), ) -> Union[StreamingResponse, HTTPClientError, Response]: + return await resource_chat_endpoint( + request, + kbid, + item, + x_ndb_client, + x_nucliadb_user, + x_forwarded_for, + x_synchronous, + resource_id=rid, + ) + + +@api.post( + f"/{KB_PREFIX}/{{kbid}}/{RESOURCE_SLUG_PREFIX}/{{slug}}/chat", + status_code=200, + name="Chat with a Resource (by slug)", + summary="Chat with a resource", + description="Chat with a resource", + tags=["Search"], + response_model=None, +) +@requires(NucliaDBRoles.READER) +@version(1) +async def resource_chat_endpoint_by_slug( + request: Request, + kbid: str, + slug: str, + item: ChatRequest, + x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API), + x_nucliadb_user: str = Header(""), + x_forwarded_for: str = Header(""), + x_synchronous: bool = Header( + False, + description="When set to true, outputs response as JSON in a non-streaming way. " + "This is slower and requires waiting for entire answer to be ready.", + ), +) -> Union[StreamingResponse, HTTPClientError, Response]: + return await resource_chat_endpoint( + request, + kbid, + item, + x_ndb_client, + x_nucliadb_user, + x_forwarded_for, + x_synchronous, + resource_slug=slug, + ) + + +async def resource_chat_endpoint( + request: Request, + kbid: str, + item: ChatRequest, + x_ndb_client: NucliaDBClientType, + x_nucliadb_user: str, + x_forwarded_for: str, + x_synchronous: bool, + resource_id: Optional[str] = None, + resource_slug: Optional[str] = None, +) -> Union[StreamingResponse, HTTPClientError, Response]: + + if resource_id is None: + if resource_slug is None: + raise ValueError("Either resource_id or resource_slug must be provided") + + resource_id = await get_resource_uuid_by_slug(kbid, resource_slug) + if resource_id is None: + return HTTPClientError(status_code=404, detail="Resource not found") + try: return await create_chat_response( kbid, @@ -71,7 +141,7 @@ async def resource_chat_endpoint( x_ndb_client, x_forwarded_for, x_synchronous, - resource=rid, + resource=resource_id, ) except LimitsExceededError as exc: return HTTPClientError(status_code=exc.status_code, detail=exc.detail) @@ -97,3 +167,10 @@ async def resource_chat_endpoint( ) except InvalidQueryError as exc: return HTTPClientError(status_code=412, detail=str(exc)) + + +async def get_resource_uuid_by_slug(kbid: str, slug: str) -> Optional[str]: + async with datamanagers.with_transaction() as txn: + return await datamanagers.resources.get_resource_uuid_from_slug( + txn, kbid=kbid, slug=slug + ) diff --git a/nucliadb/nucliadb/search/api/v1/router.py b/nucliadb/nucliadb/search/api/v1/router.py index 6afbc55acf..909246d7c4 100644 --- a/nucliadb/nucliadb/search/api/v1/router.py +++ b/nucliadb/nucliadb/search/api/v1/router.py @@ -24,3 +24,4 @@ KB_PREFIX = "kb" KBS_PREFIX = "kbs" RESOURCE_PREFIX = "resource" +RESOURCE_SLUG_PREFIX = "slug" diff --git a/nucliadb/nucliadb/search/tests/unit/api/v1/resource/test_chat.py b/nucliadb/nucliadb/search/tests/unit/api/v1/resource/test_chat.py index 8be82971d3..160caa5e6a 100644 --- a/nucliadb/nucliadb/search/tests/unit/api/v1/resource/test_chat.py +++ b/nucliadb/nucliadb/search/tests/unit/api/v1/resource/test_chat.py @@ -87,11 +87,12 @@ async def test_resource_chat_endpoint_handles_errors( response = await resource_chat_endpoint( request=request, kbid="kbid", - rid="rid", item=Mock(), x_ndb_client=None, x_nucliadb_user="", x_forwarded_for="", + x_synchronous=True, + resource_id="rid", ) assert response.status_code == http_error_response.status_code assert response.body == http_error_response.body diff --git a/nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py b/nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py index 0e31d3e43c..7cbe2e2382 100644 --- a/nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py +++ b/nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py @@ -95,6 +95,7 @@ def test_search_endpoints(sdk: nucliadb_sdk.NucliaDB, kb): resource = sdk.create_resource(kbid=kb.uuid, title="Resource", slug="resource") sdk.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo") + sdk.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo") sdk.feedback(kbid=kb.uuid, ident="bar", good=True, feedback="baz", task="CHAT") with pytest.raises(nucliadb_sdk.v2.exceptions.UnknownError) as err: sdk.summarize(kbid=kb.uuid, resources=["foobar"]) diff --git a/nucliadb_sdk/nucliadb_sdk/tests/test_sdk_async.py b/nucliadb_sdk/nucliadb_sdk/tests/test_sdk_async.py index e74c6546f5..9ea5d272b3 100644 --- a/nucliadb_sdk/nucliadb_sdk/tests/test_sdk_async.py +++ b/nucliadb_sdk/nucliadb_sdk/tests/test_sdk_async.py @@ -87,6 +87,7 @@ async def test_search_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb): kbid=kb.uuid, title="Resource", slug="resource" ) await sdk_async.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo") + await sdk_async.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo") await sdk_async.feedback( kbid=kb.uuid, ident="bar", good=True, feedback="baz", task=FeedbackTasks.CHAT ) diff --git a/nucliadb_sdk/nucliadb_sdk/v2/sdk.py b/nucliadb_sdk/nucliadb_sdk/v2/sdk.py index 6c80a45d8f..16eb9cb853 100644 --- a/nucliadb_sdk/nucliadb_sdk/v2/sdk.py +++ b/nucliadb_sdk/nucliadb_sdk/v2/sdk.py @@ -642,6 +642,15 @@ def _check_response(self, response: httpx.Response): response_type=chat_response_parser, docstring=docstrings.RESOURCE_CHAT, ) + chat_on_resource_by_slug = _request_builder( + name="chat_on_resource", + path_template="/v1/kb/{kbid}/slug/{slug}/chat", + method="POST", + path_params=("kbid", "slug"), + request_type=ChatRequest, + response_type=chat_response_parser, + docstring=docstrings.RESOURCE_CHAT, + ) summarize = _request_builder( name="summarize", path_template="/v1/kb/{kbid}/summarize",