Skip to content

Commit

Permalink
fix cursor types
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkLark86 committed Oct 16, 2024
1 parent f6848fe commit 0a04de2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
11 changes: 5 additions & 6 deletions superdesk/core/resources/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

from motor.motor_asyncio import AsyncIOMotorCollection, AsyncIOMotorCursor

from .model import ResourceModel

ResourceModelType = TypeVar("ResourceModelType", bound="ResourceModel")

ResourceModelType = TypeVar("ResourceModelType", bound=ResourceModel)


class ResourceCursorAsync(Generic[ResourceModelType]):
Expand Down Expand Up @@ -63,7 +65,7 @@ def get_model_instance(self, data: Dict[str, Any]):
return self.data_class.from_dict(data)


class ElasticsearchResourceCursorAsync(ResourceCursorAsync):
class ElasticsearchResourceCursorAsync(ResourceCursorAsync[ResourceModelType], Generic[ResourceModelType]):
no_hits = {"hits": {"total": 0, "hits": []}}

def __init__(self, data_class: Type[ResourceModelType], hits=None):
Expand Down Expand Up @@ -113,7 +115,7 @@ def extra(self, response: Dict[str, Any]):
response["_aggregations"] = self.hits["aggregations"]


class MongoResourceCursorAsync(ResourceCursorAsync):
class MongoResourceCursorAsync(ResourceCursorAsync[ResourceModelType], Generic[ResourceModelType]):
def __init__(
self,
data_class: Type[ResourceModelType],
Expand All @@ -137,6 +139,3 @@ async def next_raw(self) -> Optional[Dict[str, Any]]:

async def count(self):
return await self.collection.count_documents(self.lookup)


from .model import ResourceModel # noqa: E402
20 changes: 15 additions & 5 deletions superdesk/core/resources/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Union,
cast,
overload,
Type,
)
import logging
import ast
Expand Down Expand Up @@ -201,12 +202,14 @@ async def search(self, lookup: Dict[str, Any], use_mongo=False) -> ResourceCurso
try:
if not use_mongo:
response = await self.elastic.search(lookup)
return ElasticsearchResourceCursorAsync(self.config.data_class, response)
return ElasticsearchResourceCursorAsync(cast(Type[ResourceModelType], self.config.data_class), response)
except KeyError:
pass

response = self.mongo_async.find(lookup)
return MongoResourceCursorAsync(self.config.data_class, self.mongo_async, response, lookup)
return MongoResourceCursorAsync(
cast(Type[ResourceModelType], self.config.data_class), self.mongo_async, response, lookup
)

async def on_create(self, docs: List[ResourceModelType]) -> None:
"""Hook to run before creating new resource(s)
Expand Down Expand Up @@ -536,13 +539,17 @@ async def find(
try:
if not use_mongo:
cursor, count = await self.elastic.find(search_request)
return ElasticsearchResourceCursorAsync(self.config.data_class, cursor.hits)
return ElasticsearchResourceCursorAsync(
cast(Type[ResourceModelType], self.config.data_class), cursor.hits
)
except KeyError:
pass

return await self._mongo_find(search_request)

async def _mongo_find(self, req: SearchRequest, versioned: bool = False) -> MongoResourceCursorAsync:
async def _mongo_find(
self, req: SearchRequest, versioned: bool = False
) -> MongoResourceCursorAsync[ResourceModelType]:
kwargs: Dict[str, Any] = {}

if req.max_results:
Expand All @@ -567,7 +574,10 @@ async def _mongo_find(self, req: SearchRequest, versioned: bool = False) -> Mong
cursor = self.mongo_async.find(**kwargs) if not versioned else self.mongo_versioned_async.find(**kwargs)

return MongoResourceCursorAsync(
self.config.data_class, self.mongo_async if not versioned else self.mongo_versioned_async, cursor, where
cast(Type[ResourceModelType], self.config.data_class),
self.mongo_async if not versioned else self.mongo_versioned_async,
cursor,
where,
)

def _convert_req_to_mongo_sort(self, sort: SortParam | None) -> SortListParam:
Expand Down

0 comments on commit 0a04de2

Please sign in to comment.