Skip to content

Commit

Permalink
Allow chat on resource by slug (#2067)
Browse files Browse the repository at this point in the history
* Add tests

* Add tests

* parallelize workflow runs more

* fix test

* fix test

* fix endpoint name
  • Loading branch information
lferran authored Apr 16, 2024
1 parent 46bfd20 commit db03a5f
Show file tree
Hide file tree
Showing 8 changed files with 98 additions and 8 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/nucliadb_search.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ on:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
cancel-in-progress: false

permissions:
id-token: write # This is required for requesting the JWT
Expand Down Expand Up @@ -81,7 +81,7 @@ jobs:
runs-on: ubuntu-latest

strategy:
max-parallel: 1
max-parallel: 3
matrix:
include:
- maindb_driver: "tikv"
Expand Down
2 changes: 1 addition & 1 deletion charts/nucliadb_search/templates/search.vs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ spec:
method:
regex: "GET|POST|OPTIONS"
- uri:
regex: '^/api/v\d+/kb/[^/]+/resource/[^/]+/(chat|find|search|ask)$'
regex: '^/api/v\d+/kb/[^/]+/(resource|slug)/[^/]+/(chat|find|search|ask)$'
method:
regex: "GET|POST|OPTIONS"
- uri:
Expand Down
85 changes: 81 additions & 4 deletions nucliadb/nucliadb/search/api/v1/resource/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,16 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from typing import Union
from typing import Optional, Union

from fastapi import Header, Request, Response
from fastapi_versioning import version
from starlette.responses import StreamingResponse

from nucliadb.common import datamanagers
from nucliadb.models.responses import HTTPClientError
from nucliadb.search import predict
from nucliadb.search.api.v1.router import KB_PREFIX, api
from nucliadb.search.api.v1.router import KB_PREFIX, RESOURCE_SLUG_PREFIX, api
from nucliadb.search.search.exceptions import (
IncompleteFindResultsError,
InvalidQueryError,
Expand All @@ -49,7 +50,7 @@
)
@requires(NucliaDBRoles.READER)
@version(1)
async def resource_chat_endpoint(
async def resource_chat_endpoint_by_uuid(
request: Request,
kbid: str,
rid: str,
Expand All @@ -63,6 +64,75 @@ async def resource_chat_endpoint(
"This is slower and requires waiting for entire answer to be ready.",
),
) -> Union[StreamingResponse, HTTPClientError, Response]:
return await resource_chat_endpoint(
request,
kbid,
item,
x_ndb_client,
x_nucliadb_user,
x_forwarded_for,
x_synchronous,
resource_id=rid,
)


@api.post(
f"/{KB_PREFIX}/{{kbid}}/{RESOURCE_SLUG_PREFIX}/{{slug}}/chat",
status_code=200,
name="Chat with a Resource (by slug)",
summary="Chat with a resource",
description="Chat with a resource",
tags=["Search"],
response_model=None,
)
@requires(NucliaDBRoles.READER)
@version(1)
async def resource_chat_endpoint_by_slug(
request: Request,
kbid: str,
slug: str,
item: ChatRequest,
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
x_nucliadb_user: str = Header(""),
x_forwarded_for: str = Header(""),
x_synchronous: bool = Header(
False,
description="When set to true, outputs response as JSON in a non-streaming way. "
"This is slower and requires waiting for entire answer to be ready.",
),
) -> Union[StreamingResponse, HTTPClientError, Response]:
return await resource_chat_endpoint(
request,
kbid,
item,
x_ndb_client,
x_nucliadb_user,
x_forwarded_for,
x_synchronous,
resource_slug=slug,
)


async def resource_chat_endpoint(
request: Request,
kbid: str,
item: ChatRequest,
x_ndb_client: NucliaDBClientType,
x_nucliadb_user: str,
x_forwarded_for: str,
x_synchronous: bool,
resource_id: Optional[str] = None,
resource_slug: Optional[str] = None,
) -> Union[StreamingResponse, HTTPClientError, Response]:

if resource_id is None:
if resource_slug is None:
raise ValueError("Either resource_id or resource_slug must be provided")

resource_id = await get_resource_uuid_by_slug(kbid, resource_slug)
if resource_id is None:
return HTTPClientError(status_code=404, detail="Resource not found")

try:
return await create_chat_response(
kbid,
Expand All @@ -71,7 +141,7 @@ async def resource_chat_endpoint(
x_ndb_client,
x_forwarded_for,
x_synchronous,
resource=rid,
resource=resource_id,
)
except LimitsExceededError as exc:
return HTTPClientError(status_code=exc.status_code, detail=exc.detail)
Expand All @@ -97,3 +167,10 @@ async def resource_chat_endpoint(
)
except InvalidQueryError as exc:
return HTTPClientError(status_code=412, detail=str(exc))


async def get_resource_uuid_by_slug(kbid: str, slug: str) -> Optional[str]:
async with datamanagers.with_transaction() as txn:
return await datamanagers.resources.get_resource_uuid_from_slug(
txn, kbid=kbid, slug=slug
)
1 change: 1 addition & 0 deletions nucliadb/nucliadb/search/api/v1/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
KB_PREFIX = "kb"
KBS_PREFIX = "kbs"
RESOURCE_PREFIX = "resource"
RESOURCE_SLUG_PREFIX = "slug"
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ async def test_resource_chat_endpoint_handles_errors(
response = await resource_chat_endpoint(
request=request,
kbid="kbid",
rid="rid",
item=Mock(),
x_ndb_client=None,
x_nucliadb_user="",
x_forwarded_for="",
x_synchronous=True,
resource_id="rid",
)
assert response.status_code == http_error_response.status_code
assert response.body == http_error_response.body
1 change: 1 addition & 0 deletions nucliadb_sdk/nucliadb_sdk/tests/test_sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def test_search_endpoints(sdk: nucliadb_sdk.NucliaDB, kb):

resource = sdk.create_resource(kbid=kb.uuid, title="Resource", slug="resource")
sdk.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
sdk.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo")
sdk.feedback(kbid=kb.uuid, ident="bar", good=True, feedback="baz", task="CHAT")
with pytest.raises(nucliadb_sdk.v2.exceptions.UnknownError) as err:
sdk.summarize(kbid=kb.uuid, resources=["foobar"])
Expand Down
1 change: 1 addition & 0 deletions nucliadb_sdk/nucliadb_sdk/tests/test_sdk_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ async def test_search_endpoints(sdk_async: nucliadb_sdk.NucliaDBAsync, kb):
kbid=kb.uuid, title="Resource", slug="resource"
)
await sdk_async.chat_on_resource(kbid=kb.uuid, rid=resource.uuid, query="foo")
await sdk_async.chat_on_resource_by_slug(kbid=kb.uuid, slug="resource", query="foo")
await sdk_async.feedback(
kbid=kb.uuid, ident="bar", good=True, feedback="baz", task=FeedbackTasks.CHAT
)
Expand Down
9 changes: 9 additions & 0 deletions nucliadb_sdk/nucliadb_sdk/v2/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,15 @@ def _check_response(self, response: httpx.Response):
response_type=chat_response_parser,
docstring=docstrings.RESOURCE_CHAT,
)
chat_on_resource_by_slug = _request_builder(
name="chat_on_resource",
path_template="/v1/kb/{kbid}/slug/{slug}/chat",
method="POST",
path_params=("kbid", "slug"),
request_type=ChatRequest,
response_type=chat_response_parser,
docstring=docstrings.RESOURCE_CHAT,
)
summarize = _request_builder(
name="summarize",
path_template="/v1/kb/{kbid}/summarize",
Expand Down

3 comments on commit db03a5f

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: db03a5f Previous: d4afd82 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 13461.205250380894 iter/sec (stddev: 1.6438512134225906e-7) 13028.533525895236 iter/sec (stddev: 4.192637045977425e-7) 0.97

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: db03a5f Previous: d4afd82 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 13260.719669714508 iter/sec (stddev: 4.895815906507924e-7) 13028.533525895236 iter/sec (stddev: 4.192637045977425e-7) 0.98

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: db03a5f Previous: d4afd82 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12818.200950585913 iter/sec (stddev: 0.000001204651990772851) 13028.533525895236 iter/sec (stddev: 4.192637045977425e-7) 1.02

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.