From cc742d16ba89cc2d93dc376ebbcc906d14af5178 Mon Sep 17 00:00:00 2001 From: David Schrooten Date: Wed, 22 Nov 2023 13:13:49 +0100 Subject: [PATCH] Add kwargs to default methods --- pydantic_mongo/abstract_repository.py | 29 +++++++++++++++++---------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/pydantic_mongo/abstract_repository.py b/pydantic_mongo/abstract_repository.py index 141718b..233e634 100644 --- a/pydantic_mongo/abstract_repository.py +++ b/pydantic_mongo/abstract_repository.py @@ -104,7 +104,7 @@ def to_model(self, data: dict) -> T: """ return self.to_model_custom(self.__document_class, data) - def save(self, model: T) -> Union[InsertOneResult, UpdateResult]: + def save(self, model: T, **kwargs) -> Union[InsertOneResult, UpdateResult]: """ Save entity to database. It will update the entity if it has id, otherwise it will insert it. """ @@ -116,11 +116,11 @@ def save(self, model: T) -> Union[InsertOneResult, UpdateResult]: {"_id": mongo_id}, {"$set": document}, upsert=True ) - result = self.get_collection().insert_one(document) + result = self.get_collection().insert_one(document, **kwargs) model.id = result.inserted_id return result - def save_many(self, models: Iterable[T]): + def save_many(self, models: Iterable[T], **kwargs): """ Save multiple entities to database """ @@ -149,24 +149,24 @@ def save_many(self, models: Iterable[T]): UpdateOne({"_id": mongo_id}, {"$set": document}, upsert=True) for mongo_id, document in zip(mongo_ids, documents_to_update) ] - self.get_collection().bulk_write(bulk_operations) + self.get_collection().bulk_write(bulk_operations, **kwargs) - def delete(self, model: T): - return self.get_collection().delete_one({"_id": model.id}) + def delete(self, model: T, **kwargs): + return self.get_collection().delete_one({"_id": model.id}, **kwargs) - def find_one_by_id(self, _id: Any) -> Optional[T]: + def find_one_by_id(self, _id: Any, **kwargs) -> Optional[T]: """ Find entity by id Note: The id should be of the same type as the id field in the document class, ie. ObjectId """ - return self.find_one_by({"id": _id}) + return self.find_one_by({"id": _id}, **kwargs) - def find_one_by(self, query: dict) -> Optional[T]: + def find_one_by(self, query: dict, **kwargs) -> Optional[T]: """ Find entity by mongo query """ - result = self.get_collection().find_one(self.__map_id(query)) + result = self.get_collection().find_one(self.__map_id(query), **kwargs) return self.to_model(result) if result else None def find_by_with_output_type( @@ -177,6 +177,7 @@ def find_by_with_output_type( limit: Optional[int] = None, sort: Optional[Sort] = None, projection: Optional[Dict[str, int]] = None, + **kwargs, ) -> Iterable[OutputT]: """ Find entities by mongo query allowing custom output type @@ -190,7 +191,7 @@ def find_by_with_output_type( """ mapped_projection = self.__map_id(projection) if projection else None mapped_sort = self.__map_sort(sort) if sort else None - cursor = self.get_collection().find(self.__map_id(query), mapped_projection) + cursor = self.get_collection().find(self.__map_id(query), mapped_projection, **kwargs) if limit: cursor.limit(limit) if skip: @@ -206,6 +207,7 @@ def find_by( limit: Optional[int] = None, sort: Optional[Sort] = None, projection: Optional[Dict[str, int]] = None, + **kwargs, ) -> Iterable[T]: """ " Find entities by mongo query @@ -217,6 +219,7 @@ def find_by( limit=limit, sort=sort, projection=projection, + **kwargs, ) def get_pagination_query( @@ -259,6 +262,7 @@ def paginate_with_output_type( before: Optional[str] = None, sort: Optional[Sort] = None, projection: Optional[Dict[str, int]] = None, + **kwargs, ) -> Iterable[Edge[OutputT]]: """ Paginate entities by mongo query allowing custom output type @@ -279,6 +283,7 @@ def paginate_with_output_type( limit=limit, sort=sort, projection=projection, + **kwargs, ) return map( @@ -299,6 +304,7 @@ def paginate( before: Optional[str] = None, sort: Optional[Sort] = None, projection: Optional[Dict[str, int]] = None, + **kwargs, ) -> Iterable[Edge[T]]: """ Paginate entities by mongo query using cursor based pagination @@ -313,4 +319,5 @@ def paginate( before=before, sort=sort, projection=projection, + **kwargs, )