Skip to content

Commit

Permalink
Add kwargs to default methods
Browse files Browse the repository at this point in the history
  • Loading branch information
davidschrooten committed Nov 22, 2023
1 parent 2dc5149 commit cc742d1
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions pydantic_mongo/abstract_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def to_model(self, data: dict) -> T:
"""
return self.to_model_custom(self.__document_class, data)

def save(self, model: T) -> Union[InsertOneResult, UpdateResult]:
def save(self, model: T, **kwargs) -> Union[InsertOneResult, UpdateResult]:
"""
Save entity to database. It will update the entity if it has id, otherwise it will insert it.
"""
Expand All @@ -116,11 +116,11 @@ def save(self, model: T) -> Union[InsertOneResult, UpdateResult]:
{"_id": mongo_id}, {"$set": document}, upsert=True
)

result = self.get_collection().insert_one(document)
result = self.get_collection().insert_one(document, **kwargs)
model.id = result.inserted_id
return result

def save_many(self, models: Iterable[T]):
def save_many(self, models: Iterable[T], **kwargs):
"""
Save multiple entities to database
"""
Expand Down Expand Up @@ -149,24 +149,24 @@ def save_many(self, models: Iterable[T]):
UpdateOne({"_id": mongo_id}, {"$set": document}, upsert=True)
for mongo_id, document in zip(mongo_ids, documents_to_update)
]
self.get_collection().bulk_write(bulk_operations)
self.get_collection().bulk_write(bulk_operations, **kwargs)

def delete(self, model: T):
return self.get_collection().delete_one({"_id": model.id})
def delete(self, model: T, **kwargs):
return self.get_collection().delete_one({"_id": model.id}, **kwargs)

def find_one_by_id(self, _id: Any) -> Optional[T]:
def find_one_by_id(self, _id: Any, **kwargs) -> Optional[T]:
"""
Find entity by id
Note: The id should be of the same type as the id field in the document class, ie. ObjectId
"""
return self.find_one_by({"id": _id})
return self.find_one_by({"id": _id}, **kwargs)

def find_one_by(self, query: dict) -> Optional[T]:
def find_one_by(self, query: dict, **kwargs) -> Optional[T]:
"""
Find entity by mongo query
"""
result = self.get_collection().find_one(self.__map_id(query))
result = self.get_collection().find_one(self.__map_id(query), **kwargs)
return self.to_model(result) if result else None

def find_by_with_output_type(
Expand All @@ -177,6 +177,7 @@ def find_by_with_output_type(
limit: Optional[int] = None,
sort: Optional[Sort] = None,
projection: Optional[Dict[str, int]] = None,
**kwargs,
) -> Iterable[OutputT]:
"""
Find entities by mongo query allowing custom output type
Expand All @@ -190,7 +191,7 @@ def find_by_with_output_type(
"""
mapped_projection = self.__map_id(projection) if projection else None
mapped_sort = self.__map_sort(sort) if sort else None
cursor = self.get_collection().find(self.__map_id(query), mapped_projection)
cursor = self.get_collection().find(self.__map_id(query), mapped_projection, **kwargs)
if limit:
cursor.limit(limit)
if skip:
Expand All @@ -206,6 +207,7 @@ def find_by(
limit: Optional[int] = None,
sort: Optional[Sort] = None,
projection: Optional[Dict[str, int]] = None,
**kwargs,
) -> Iterable[T]:
""" "
Find entities by mongo query
Expand All @@ -217,6 +219,7 @@ def find_by(
limit=limit,
sort=sort,
projection=projection,
**kwargs,
)

def get_pagination_query(
Expand Down Expand Up @@ -259,6 +262,7 @@ def paginate_with_output_type(
before: Optional[str] = None,
sort: Optional[Sort] = None,
projection: Optional[Dict[str, int]] = None,
**kwargs,
) -> Iterable[Edge[OutputT]]:
"""
Paginate entities by mongo query allowing custom output type
Expand All @@ -279,6 +283,7 @@ def paginate_with_output_type(
limit=limit,
sort=sort,
projection=projection,
**kwargs,
)

return map(
Expand All @@ -299,6 +304,7 @@ def paginate(
before: Optional[str] = None,
sort: Optional[Sort] = None,
projection: Optional[Dict[str, int]] = None,
**kwargs,
) -> Iterable[Edge[T]]:
"""
Paginate entities by mongo query using cursor based pagination
Expand All @@ -313,4 +319,5 @@ def paginate(
before=before,
sort=sort,
projection=projection,
**kwargs,
)

0 comments on commit cc742d1

Please sign in to comment.