Skip to content

Commit

Permalink
Async sdk support (#928)
Browse files Browse the repository at this point in the history
* add async sdk 2 support

* lint

* fix

* bump

* add timeout setting support
  • Loading branch information
vangheem authored May 23, 2023
1 parent 67ee9f3 commit 25dec72
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 35 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.12.0
2.13.0
3 changes: 2 additions & 1 deletion nucliadb_sdk/nucliadb_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,12 @@
get_or_create,
list_kbs,
)
from nucliadb_sdk.v2 import NucliaSDK, Region, exceptions
from nucliadb_sdk.v2 import NucliaSDK, NucliaSDKAsync, Region, exceptions
from nucliadb_sdk.vectors import Vector

__all__ = (
"NucliaSDK",
"NucliaSDKAsync",
"Region",
"exceptions",
# OLD support APIs
Expand Down
4 changes: 2 additions & 2 deletions nucliadb_sdk/nucliadb_sdk/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from .sdk import NucliaSDK, Region
from .sdk import NucliaSDK, NucliaSDKAsync, Region

__all__ = ("NucliaSDK", "Region")
__all__ = ("NucliaSDK", "NucliaSDKAsync", "Region")
174 changes: 143 additions & 31 deletions nucliadb_sdk/nucliadb_sdk/v2/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import base64
import enum
import io
from typing import Any, Callable, Optional, Type, Union

import httpx
import orjson
from pydantic import BaseModel

from nucliadb_models.conversation import InputMessage
from nucliadb_models.entities import (
CreateEntitiesGroupPayload,
EntitiesGroup,
Expand All @@ -49,6 +52,7 @@
from nucliadb_models.writer import (
CreateResourcePayload,
ResourceCreated,
ResourceFieldAdded,
ResourceUpdated,
UpdateResourcePayload,
)
Expand Down Expand Up @@ -89,11 +93,33 @@ def chat_response_parser(response: httpx.Response) -> ChatResponse:
)


def _parse_list_of_pydantic(
    data: list[Any],
) -> str:
    """Serialize a list that may contain pydantic models into a JSON string.

    Every element that is a ``BaseModel`` instance is converted with its
    ``.dict()`` method; any other element is passed through untouched. The
    resulting list is encoded with ``orjson`` and decoded to ``str``.

    :param data: items to serialize, models and/or plain values.
    :return: the JSON document as UTF-8 decoded text.
    """
    serializable = [
        entry.dict() if isinstance(entry, BaseModel) else entry for entry in data
    ]
    encoded = orjson.dumps(serializable)
    return encoded.decode("utf-8")


def _parse_response(response_type, resp: httpx.Response) -> Any:
    """Turn a raw HTTP response into the value an endpoint method returns.

    :param response_type: ``None``, a ``BaseModel`` subclass, or a callable
        that accepts the response object.
    :param resp: the response whose body is to be decoded.
    :return: ``resp.content`` when ``response_type`` is ``None``; the model
        built with ``parse_raw(resp.content)`` when it is a ``BaseModel``
        subclass; otherwise whatever ``response_type(resp)`` returns.
    """
    # No declared type: hand the raw body back to the caller.
    if response_type is None:
        return resp.content

    # The isinstance check comes first because issubclass() raises on
    # non-class arguments such as plain parser functions.
    is_model_class = isinstance(response_type, type) and issubclass(
        response_type, BaseModel
    )
    if is_model_class:
        return response_type.parse_raw(resp.content)  # type: ignore

    # Anything else is treated as a custom parser taking the whole response.
    return response_type(resp)  # type: ignore


def _request_builder(
path_template: str,
method: str,
path_params: tuple[str, ...],
request_type: Optional[Type[BaseModel]],
request_type: Optional[Union[Type[BaseModel], list[Any]]],
response_type: Optional[
Union[Type[BaseModel], Callable[[httpx.Response], BaseModel]]
],
Expand All @@ -109,44 +135,43 @@ def _func(self: "NucliaSDK", content: Optional[Any] = None, **kwargs):
data = None
if request_type is not None:
if content is not None:
if not isinstance(content, request_type):
raise TypeError(f"Expected {request_type}, got {type(content)}")
else:
data = content.json()
try:
if not isinstance(content, request_type): # type: ignore
raise TypeError(f"Expected {request_type}, got {type(content)}")
else:
data = content.json()
except TypeError:
if not isinstance(content, list):
raise
data = _parse_list_of_pydantic(content)
else:
# pull properties out of kwargs now
content_data = {}
for key in list(kwargs.keys()):
if key in request_type.__fields__:
if key in request_type.__fields__: # type: ignore
content_data[key] = kwargs.pop(key)
data = request_type.parse_obj(content_data).json()
data = request_type.parse_obj(content_data).json() # type: ignore

query_params = kwargs.pop("query_params", None)
if len(kwargs) > 0:
raise TypeError(f"Invalid arguments provided: {kwargs}")

resp = self._request(path, method, data=data, query_params=query_params)

if response_type is not None:
if isinstance(response_type, type) and issubclass(response_type, BaseModel):
return response_type.parse_raw(resp.content) # type: ignore
else:
return response_type(resp) # type: ignore
else:
return resp.content
if asyncio.iscoroutine(resp):

return _func
async def _wrapped_resp():
real_resp = await resp
return _parse_response(response_type, real_resp)

return _wrapped_resp()
else:
return _parse_response(response_type, resp)

class NucliaSDK:
"""
Example usage:
return _func

from nucliadb_sdk.v2.sdk import *
sdk = NucliaSDK(region=Region.EUROPE1, api_key="api-key")
sdk.list_resources(kbid='70a2530a-5863-41ec-b42b-bfe795bef2eb')
"""

class _NucliaSDKBase:
def __init__(
self,
*,
Expand Down Expand Up @@ -175,7 +200,7 @@ def __init__(
self.base_url = url.rstrip("/")
headers["X-STF-SERVICEACCOUNT"] = f"Bearer {api_key}"

self.session = httpx.Client(headers=headers, base_url=self.base_url)
self.headers = headers

def _request(
self,
Expand All @@ -184,14 +209,9 @@ def _request(
data: Optional[Union[str, bytes]] = None,
query_params: Optional[dict[str, str]] = None,
):
url = f"{self.base_url}{path}"
opts: dict[str, Any] = {}
if data is not None:
opts["data"] = data
if query_params is not None:
opts["params"] = query_params
response: httpx.Response = getattr(self.session, method.lower())(url, **opts)
raise NotImplementedError

def _check_response(self, response: httpx.Response):
if response.status_code < 300:
return response
elif response.status_code in (401, 403):
Expand All @@ -208,7 +228,7 @@ def _request(
raise exceptions.ConflictError(response.text)
elif response.status_code == 404:
raise exceptions.NotFoundError(
f"Resource not found at url {url}: {response.text}"
f"Resource not found at url {response.url}: {response.text}"
)
else:
raise exceptions.UnknownError(
Expand Down Expand Up @@ -260,6 +280,15 @@ def _request(
"/v1/kb/{kbid}/resources", "GET", ("kbid",), None, ResourceList
)

# Conversation endpoints
add_conversation_message = _request_builder(
"/v1/kb/{kbid}/resource/{rid}/conversation/{field_id}/messages",
"PUT",
("kbid", "rid", "field_id"),
list[InputMessage], # type: ignore
ResourceFieldAdded,
)

# Labels
set_labelset = _request_builder(
"/v1/kb/{kbid}/labelset/{labelset}",
Expand Down Expand Up @@ -339,3 +368,86 @@ def _request(
chat = _request_builder(
"/v1/kb/{kbid}/chat", "POST", ("kbid",), ChatRequest, chat_response_parser
)


class NucliaSDK(_NucliaSDKBase):
    """
    Synchronous SDK client backed by an ``httpx.Client`` session.

    Example usage:

        from nucliadb_sdk.v2.sdk import *
        sdk = NucliaSDK(region=Region.EUROPE1, api_key="api-key")
        sdk.list_resources(kbid='70a2530a-5863-41ec-b42b-bfe795bef2eb')
    """

    def __init__(
        self,
        *,
        region: Region = Region.EUROPE1,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        headers: Optional[dict[str, str]] = None,
        timeout: Optional[float] = 60.0,
    ):
        """Create the client and open its HTTP session.

        :param region: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param api_key: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param url: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param headers: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param timeout: passed as ``timeout`` to ``httpx.Client``
            (seconds; see the httpx documentation for the meaning of
            ``None``).
        """
        super().__init__(region=region, api_key=api_key, url=url, headers=headers)
        # The base class only computes ``headers`` and ``base_url``; the
        # concrete transport is created here.
        self.session = httpx.Client(
            headers=self.headers, base_url=self.base_url, timeout=timeout
        )

    def _request(
        self,
        path,
        method: str,
        data: Optional[Union[str, bytes]] = None,
        query_params: Optional[dict[str, str]] = None,
    ):
        """Send one request and return the validated response.

        :param path: URL path appended to ``self.base_url``.
        :param method: HTTP verb, in any letter case.
        :param data: already-serialized request body, if any.
        :param query_params: query-string parameters, if any.
        :return: the value returned by ``self._check_response``.
        :raises: whatever ``self._check_response`` raises for the response.
        """
        url = f"{self.base_url}{path}"
        opts: dict[str, Any] = {}
        if data is not None:
            # httpx expects raw str/bytes bodies under ``content=``; passing
            # them as ``data=`` only works through a deprecated compat path.
            opts["content"] = data
        if query_params is not None:
            opts["params"] = query_params
        # ``request()`` accepts a body for every verb, unlike the per-verb
        # helpers (``get``/``delete`` take no body arguments).
        response: httpx.Response = self.session.request(method.upper(), url, **opts)
        return self._check_response(response)


class NucliaSDKAsync(_NucliaSDKBase):
    """
    Asynchronous SDK client backed by an ``httpx.AsyncClient`` session.

    Endpoint methods return awaitables because ``_request`` is a coroutine.

    Example usage:

        from nucliadb_sdk.v2.sdk import *
        sdk = NucliaSDKAsync(region=Region.EUROPE1, api_key="api-key")
        await sdk.list_resources(kbid='70a2530a-5863-41ec-b42b-bfe795bef2eb')
    """

    def __init__(
        self,
        *,
        region: Region = Region.EUROPE1,
        api_key: Optional[str] = None,
        url: Optional[str] = None,
        headers: Optional[dict[str, str]] = None,
        timeout: Optional[float] = 60.0,
    ):
        """Create the client and open its asynchronous HTTP session.

        :param region: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param api_key: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param url: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param headers: forwarded unchanged to ``_NucliaSDKBase.__init__``.
        :param timeout: passed as ``timeout`` to ``httpx.AsyncClient``
            (seconds; see the httpx documentation for the meaning of
            ``None``).
        """
        super().__init__(region=region, api_key=api_key, url=url, headers=headers)
        # The base class only computes ``headers`` and ``base_url``; the
        # concrete transport is created here.
        self.session = httpx.AsyncClient(
            headers=self.headers, base_url=self.base_url, timeout=timeout
        )

    async def _request(
        self,
        path,
        method: str,
        data: Optional[Union[str, bytes]] = None,
        query_params: Optional[dict[str, str]] = None,
    ):
        """Send one request and return the response once it is validated.

        :param path: URL path appended to ``self.base_url``.
        :param method: HTTP verb, in any letter case.
        :param data: already-serialized request body, if any.
        :param query_params: query-string parameters, if any.
        :return: the ``httpx.Response`` received from the server.
        :raises: whatever ``self._check_response`` raises for the response.
        """
        url = f"{self.base_url}{path}"
        opts: dict[str, Any] = {}
        if data is not None:
            # httpx expects raw str/bytes bodies under ``content=``; passing
            # them as ``data=`` only works through a deprecated compat path.
            opts["content"] = data
        if query_params is not None:
            opts["params"] = query_params
        # ``request()`` accepts a body for every verb, unlike the per-verb
        # helpers (``get``/``delete`` take no body arguments).
        response: httpx.Response = await self.session.request(
            method.upper(), url, **opts
        )
        # Called for its raising side effect on error status codes.
        self._check_response(response)
        return response
1 change: 1 addition & 0 deletions nucliadb_sdk/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ httpx
requests
nucliadb-models
urllib3<1.27,>=1.21.1
orjson

1 comment on commit 25dec72

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 25dec72 Previous: d19d131 Ratio
nucliadb/tests/benchmarks/test_search.py::test_search_returns_labels 43.26942082600021 iter/sec (stddev: 0.0015731842547920796) 68.48176450054224 iter/sec (stddev: 0.00011698806646091496) 1.58
nucliadb/tests/benchmarks/test_search.py::test_search_relations 115.99527276548751 iter/sec (stddev: 0.000504999455476682) 128.13602179672588 iter/sec (stddev: 0.000027569798386079616) 1.10

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.