Skip to content

Commit

Permalink
Upgrade fastapi requirements to <=0.95.2 (#1364)
Browse files Browse the repository at this point in the history
* Relax and upgrade fastapi requirements to <=0.95.2

* Removing all hack

* Upgrade and pin fastapi to latest version: 0.103.1

* Use lifespan instead of deprecated on_startup and on_shutdown

* Upgrade min version of FastAPI for security reasons

* Update examples as openapi_examples

* Ignore mypy issues with grpc.aio channels

* Revert "Use lifespan instead of deprecated on_startup and on_shutdown"

This reverts commit 6de588f.
  • Loading branch information
jotare authored Sep 27, 2023
1 parent f403694 commit 2b31db7
Show file tree
Hide file tree
Showing 28 changed files with 127 additions and 97 deletions.
4 changes: 2 additions & 2 deletions nucliadb/nucliadb/common/cluster/discovery/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ async def _get_index_node_metadata(
else:
grpc_address = f"{address}:{settings.node_writer_port}"
channel = get_traced_grpc_channel(grpc_address, "discovery", variant="_writer")
stub = nodewriter_pb2_grpc.NodeWriterStub(channel)
stub = nodewriter_pb2_grpc.NodeWriterStub(channel) # type: ignore
metadata: nodewriter_pb2.NodeMetadata = await stub.GetMetadata(noderesources_pb2.EmptyQuery()) # type: ignore
return IndexNodeMetadata(
node_id=metadata.node_id,
Expand All @@ -124,7 +124,7 @@ async def _get_standalone_index_node_metadata(
else:
grpc_address = address
channel = get_traced_grpc_channel(grpc_address, "standalone_proxy")
stub = standalone_pb2_grpc.StandaloneClusterServiceStub(channel)
stub = standalone_pb2_grpc.StandaloneClusterServiceStub(channel) # type: ignore
resp: standalone_pb2.NodeInfoResponse = await stub.NodeInfo(standalone_pb2.NodeInfoRequest()) # type: ignore
return IndexNodeMetadata(
node_id=resp.id,
Expand Down
6 changes: 3 additions & 3 deletions nucliadb/nucliadb/common/cluster/index_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def sidecar(self) -> NodeSidecarStub:
channel = get_traced_grpc_channel(
grpc_address, SERVICE_NAME, variant="_sidecar"
)
SIDECAR_CONNECTIONS[self.address] = NodeSidecarStub(channel)
SIDECAR_CONNECTIONS[self.address] = NodeSidecarStub(channel) # type: ignore
else:
SIDECAR_CONNECTIONS[self.address] = DummySidecarStub()
self._sidecar = SIDECAR_CONNECTIONS[self.address]
Expand All @@ -83,7 +83,7 @@ def writer(self) -> NodeWriterStub:
channel = get_traced_grpc_channel(
grpc_address, SERVICE_NAME, variant="_writer"
)
WRITE_CONNECTIONS[self.address] = NodeWriterStub(channel)
WRITE_CONNECTIONS[self.address] = NodeWriterStub(channel) # type: ignore
else:
WRITE_CONNECTIONS[self.address] = DummyWriterStub()
self._writer = WRITE_CONNECTIONS[self.address]
Expand All @@ -99,7 +99,7 @@ def reader(self) -> NodeReaderStub:
channel = get_traced_grpc_channel(
grpc_address, SERVICE_NAME, variant="_reader"
)
READ_CONNECTIONS[self.address] = NodeReaderStub(channel)
READ_CONNECTIONS[self.address] = NodeReaderStub(channel) # type: ignore
else:
READ_CONNECTIONS[self.address] = DummyReaderStub()
self._reader = READ_CONNECTIONS[self.address]
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/common/cluster/standalone/index_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(self, address: str, type: str, original_type: Any):
else:
grpc_address = address
self._channel = get_traced_grpc_channel(grpc_address, "standalone_proxy")
self._stub = standalone_pb2_grpc.StandaloneClusterServiceStub(self._channel)
self._stub = standalone_pb2_grpc.StandaloneClusterServiceStub(self._channel) # type: ignore

def __getattr__(self, name):
async def call(request):
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async def finalizer():


async def start_grpc_health_service(port: int) -> Callable[[], Awaitable[None]]:
aio.init_grpc_aio()
aio.init_grpc_aio() # type: ignore

server = aio.server()
server.add_insecure_port(f"0.0.0.0:{port}")
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/ingest/service/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


async def start_grpc(service_name: Optional[str] = None):
aio.init_grpc_aio()
aio.init_grpc_aio() # type: ignore

await setup_telemetry(service_name or "ingest")
server = get_traced_grpc_server(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

@pytest.mark.asyncio
async def test_clean_and_upgrade_kb_index(grpc_servicer: IngestFixture):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

kb_id = str(uuid4())
pb = knowledgebox_pb2.KnowledgeBoxNew(slug="test", forceuuid=kb_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
async def test_create_entities_group(
grpc_servicer: IngestFixture, entities_manager_mock
):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

kb_id = str(uuid4())
pb = knowledgebox_pb2.KnowledgeBoxNew(slug="test", forceuuid=kb_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@

@pytest.mark.asyncio
async def test_export_resources(grpc_servicer: IngestFixture):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

pb = knowledgebox_pb2.KnowledgeBoxNew(slug=f"test-{uuid4()}")
pb.config.title = "My Title"
Expand Down Expand Up @@ -143,7 +143,7 @@ async def test_export_resources(grpc_servicer: IngestFixture):

@pytest.mark.asyncio
async def test_upload_download(grpc_servicer: IngestFixture):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

# Create a KB
pb = knowledgebox_pb2.KnowledgeBoxNew(slug=f"test-{uuid4()}")
Expand Down Expand Up @@ -183,7 +183,7 @@ async def upload_iterator():

@pytest.mark.asyncio
async def test_export_file(grpc_servicer: IngestFixture):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

pb = knowledgebox_pb2.KnowledgeBoxNew(slug=f"test-{uuid4()}")
pb.config.title = "My Title"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ async def test_create_knowledgebox(grpc_servicer: IngestFixture, maindb_driver):
if isinstance(maindb_driver, LocalDriver):
pytest.skip("There is a bug in the local driver that needs to be fixed")

stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore
pb_prefix = knowledgebox_pb2.KnowledgeBoxPrefix(prefix="")

count = 0
Expand Down Expand Up @@ -79,8 +79,7 @@ async def get_kb_similarity(txn, kbid) -> utils_pb2.VectorSimilarity.ValueType:

@pytest.mark.asyncio
async def test_create_knowledgebox_with_similarity(grpc_servicer: IngestFixture):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)

stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore
pb = knowledgebox_pb2.KnowledgeBoxNew(slug="test-dot")
pb.config.title = "My Title"
pb.similarity = utils_pb2.VectorSimilarity.DOT
Expand All @@ -98,7 +97,7 @@ async def test_create_knowledgebox_defaults_to_cosine_similarity(
grpc_servicer: IngestFixture,
txn,
):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore
pb = knowledgebox_pb2.KnowledgeBoxNew(slug="test-default")
pb.config.title = "My Title"
result = await stub.NewKnowledgeBox(pb) # type: ignore
Expand All @@ -113,7 +112,7 @@ async def test_create_knowledgebox_defaults_to_cosine_similarity(

@pytest.mark.asyncio
async def test_get_resource_id(grpc_servicer: IngestFixture) -> None:
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

pb = knowledgebox_pb2.KnowledgeBoxNew(slug="test")
pb.config.title = "My Title"
Expand All @@ -128,7 +127,7 @@ async def test_get_resource_id(grpc_servicer: IngestFixture) -> None:
async def test_delete_knowledgebox_handles_unexisting_kb(
grpc_servicer: IngestFixture,
) -> None:
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

pbid = knowledgebox_pb2.KnowledgeBoxID(slug="idonotexist")
result = await stub.DeleteKnowledgeBox(pbid) # type: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

@pytest.mark.asyncio
async def test_list_members(grpc_servicer: IngestFixture):
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel)
stub = writer_pb2_grpc.WriterStub(grpc_servicer.channel) # type: ignore

response = await stub.ListMembers(ListMembersRequest()) # type: ignore

Expand Down
3 changes: 2 additions & 1 deletion nucliadb/nucliadb/ingest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ async def start_ingest(service_name: Optional[str] = None):
nucliadb_settings.nucliadb_ingest, service_name or "ingest"
)
set_utility(Utility.CHANNEL, channel)
set_utility(Utility.INGEST, WriterStub(channel))
ingest = WriterStub(channel) # type: ignore
set_utility(Utility.INGEST, ingest)
else:
# Its not distributed create a ingest
from nucliadb.ingest.service.writer import WriterServicer
Expand Down
13 changes: 7 additions & 6 deletions nucliadb/nucliadb/search/api/v1/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

import pydantic
from fastapi import Body, Header, Request, Response
from fastapi.openapi.models import Example
from fastapi_versioning import version
from starlette.responses import StreamingResponse

Expand Down Expand Up @@ -53,13 +54,13 @@ class SyncChatResponse(pydantic.BaseModel):


CHAT_EXAMPLES = {
"search_and_chat": {
"summary": "Ask who won the league final",
"description": "You can ask a question to your knowledge box", # noqa
"value": {
"search_and_chat": Example(
summary="Ask who won the league final",
description="You can ask a question to your knowledge box", # noqa
value={
"query": "Who won the league final?",
},
},
),
}


Expand All @@ -77,7 +78,7 @@ class SyncChatResponse(pydantic.BaseModel):
async def chat_knowledgebox_endpoint(
request: Request,
kbid: str,
item: ChatRequest = Body(examples=CHAT_EXAMPLES),
item: ChatRequest = Body(openapi_examples=CHAT_EXAMPLES),
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
x_nucliadb_user: str = Header(""),
x_forwarded_for: str = Header(""),
Expand Down
13 changes: 7 additions & 6 deletions nucliadb/nucliadb/search/api/v1/find.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import List, Optional, Union

from fastapi import Body, Header, Request, Response
from fastapi.openapi.models import Example
from fastapi_versioning import version
from pydantic.error_wrappers import ValidationError

Expand All @@ -44,14 +45,14 @@
from nucliadb_utils.exceptions import LimitsExceededError

FIND_EXAMPLES = {
"find_hybrid_search": {
"summary": "Do a hybrid search on a Knowledge Box",
"description": "Perform a hybrid search that will return text and semantic results matching the query",
"value": {
"find_hybrid_search": Example(
summary="Do a hybrid search on a Knowledge Box",
description="Perform a hybrid search that will return text and semantic results matching the query",
value={
"query": "How can I be an effective product manager?",
"features": [SearchOptions.PARAGRAPH, SearchOptions.VECTOR],
},
}
)
}


Expand Down Expand Up @@ -166,7 +167,7 @@ async def find_post_knowledgebox(
request: Request,
response: Response,
kbid: str,
item: FindRequest = Body(examples=FIND_EXAMPLES),
item: FindRequest = Body(openapi_examples=FIND_EXAMPLES),
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
x_nucliadb_user: str = Header(""),
x_forwarded_for: str = Header(""),
Expand Down
13 changes: 7 additions & 6 deletions nucliadb/nucliadb/search/api/v1/resource/ask.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Union

from fastapi import Body, Header, Request, Response
from fastapi.openapi.models import Example
from fastapi_versioning import version
from nucliadb_protos.resources_pb2 import FieldComputedMetadata
from nucliadb_protos.utils_pb2 import ExtractedText
Expand All @@ -40,13 +41,13 @@
from nucliadb_utils.utilities import get_storage, has_feature

ASK_EXAMPLES = {
"Ask a Resource": {
"summary": "Ask a question to the document",
"description": "Ask a question to the document. The whole document is sent as context to the generative AI",
"value": {
"Ask a Resource": Example(
summary="Ask a question to the document",
description="Ask a question to the document. The whole document is sent as context to the generative AI",
value={
"question": "Does this document contain personal information?",
},
}
)
}


Expand All @@ -69,7 +70,7 @@ async def resource_ask_endpoint(
kbid: str,
rid: str,
item: AskRequest = Body(
examples=ASK_EXAMPLES, description="Ask a question payload"
openapi_examples=ASK_EXAMPLES, description="Ask a question payload"
),
x_nucliadb_user: str = Header("", description="User Id", include_in_schema=False),
) -> Union[AskResponse, HTTPClientError]:
Expand Down
23 changes: 12 additions & 11 deletions nucliadb/nucliadb/search/api/v1/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import List, Optional, Tuple, Union

from fastapi import Body, Header, Request, Response
from fastapi.openapi.models import Example
from fastapi_versioning import version
from pydantic.error_wrappers import ValidationError

Expand Down Expand Up @@ -57,24 +58,24 @@
from nucliadb_utils.utilities import get_audit

SEARCH_EXAMPLES = {
"filtering_by_icon": {
"summary": "Search for pdf documents where the text 'Noam Chomsky' appears",
"description": "For a complete list of filters, visit: https://github.com/nuclia/nucliadb/blob/main/docs/internal/SEARCH.md#filters-and-facets", # noqa
"value": {
"filtering_by_icon": Example(
summary="Search for pdf documents where the text 'Noam Chomsky' appears",
description="For a complete list of filters, visit: https://github.com/nuclia/nucliadb/blob/main/docs/internal/SEARCH.md#filters-and-facets", # noqa
value={
"query": "Noam Chomsky",
"filters": ["/n/i/application/pdf"],
"features": [SearchOptions.DOCUMENT],
},
},
"get_language_counts": {
"summary": "Get the number of documents for each language",
"description": "For a complete list of facets, visit: https://github.com/nuclia/nucliadb/blob/main/docs/internal/SEARCH.md#filters-and-facets", # noqa
"value": {
),
"get_language_counts": Example(
summary="Get the number of documents for each language",
description="For a complete list of facets, visit: https://github.com/nuclia/nucliadb/blob/main/docs/internal/SEARCH.md#filters-and-facets", # noqa
value={
"page_size": 0,
"faceted": ["/s/p"],
"features": [SearchOptions.DOCUMENT],
},
},
),
}


Expand Down Expand Up @@ -254,7 +255,7 @@ async def search_post_knowledgebox(
request: Request,
response: Response,
kbid: str,
item: SearchRequest = Body(examples=SEARCH_EXAMPLES),
item: SearchRequest = Body(openapi_examples=SEARCH_EXAMPLES),
x_ndb_client: NucliaDBClientType = Header(NucliaDBClientType.API),
x_nucliadb_user: str = Header(""),
x_forwarded_for: str = Header(""),
Expand Down
7 changes: 4 additions & 3 deletions nucliadb/nucliadb/search/requesters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,12 @@ def validate_node_query_results(results: list[Any]) -> Optional[HTTPException]:
if isinstance(result, AioRpcError):
if result.code() is GrpcStatusCode.INTERNAL:
# handle node response errors
if "AllButQueryForbidden" in result.details():
details = result.details() or "gRPC error without details"
if "AllButQueryForbidden" in details:
status_code = 412
reason = result.details().split(":")[-1].strip().strip("'")
reason = details.split(":")[-1].strip().strip("'")
else:
reason = result.details()
reason = details
logger.exception(f"Unhandled node error", exc_info=result)
else:
logger.error(
Expand Down
4 changes: 2 additions & 2 deletions nucliadb/nucliadb/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,13 +191,13 @@ async def knowledgebox(nucliadb_manager: AsyncClient):

@pytest.fixture(scope="function")
async def nucliadb_grpc(nucliadb: Settings):
stub = WriterStub(aio.insecure_channel(f"localhost:{nucliadb.ingest_grpc_port}"))
stub = WriterStub(aio.insecure_channel(f"localhost:{nucliadb.ingest_grpc_port}")) # type: ignore
return stub


@pytest.fixture(scope="function")
async def nucliadb_train(nucliadb: Settings):
stub = TrainStub(aio.insecure_channel(f"localhost:{nucliadb.train_grpc_port}"))
stub = TrainStub(aio.insecure_channel(f"localhost:{nucliadb.train_grpc_port}")) # type: ignore
return stub


Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/train/tests/test_image_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ async def image_classification_resource(
) as _,
):
resp = await nucliadb_grpc.ProcessMessage( # type: ignore
[broker_message], timeout=10, wait_for_ready=True
iter([broker_message]), timeout=10, wait_for_ready=True
)
assert resp.status == OpStatusWriter.Status.OK
yield
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/train/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def start_train_grpc(service_name: Optional[str] = None):
if actual_service is not None:
return

aio.init_grpc_aio()
aio.init_grpc_aio() # type: ignore

await setup_telemetry(service_name or "train")
server = get_traced_grpc_server(service_name or "train")
Expand Down
Loading

2 comments on commit 2b31db7

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 2b31db7 Previous: 374ff84 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 10081.254638487715 iter/sec (stddev: 1.1767135336427957e-7) 8667.761941074237 iter/sec (stddev: 2.4327377097212595e-7) 0.86

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 2b31db7 Previous: 374ff84 Ratio
nucliadb/tests/benchmarks/test_search.py::test_search_returns_labels[tikv_driver_settings] 52.24998748034538 iter/sec (stddev: 0.0005332181922529258)
nucliadb/tests/benchmarks/test_search.py::test_search_relations[tikv_driver_settings] 143.1320928931033 iter/sec (stddev: 0.00042173125698643165)

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.