Skip to content

Commit

Permalink
Add reload() method to queries + Support mongodb operators in `upda…
Browse files Browse the repository at this point in the history
…te()` and `update_one()`
  • Loading branch information
AliRn76 committed Apr 11, 2024
1 parent 3270b8d commit f898354
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 15 deletions.
2 changes: 1 addition & 1 deletion panther/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def validate_object_id(value, handler):
else:
try:
return bson.ObjectId(value)
except bson.objectid.InvalidId as e:
except Exception as e:
msg = 'Invalid ObjectId'
raise ValueError(msg) from e
return str(value)
Expand Down
32 changes: 23 additions & 9 deletions panther/db/queries/mongodb_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,35 @@ async def delete_many(cls, _filter: dict | None = None, /, **kwargs) -> int:

# # # # # Update # # # # #
async def update(self, _update: dict | None = None, /, **kwargs) -> None:
document = self._merge(_update, kwargs)
document.pop('_id', None)
self._validate_data(data=document, is_updating=True)
merged_update_query = self._merge(_update, kwargs)
merged_update_query.pop('_id', None)

for field, value in document.items():
setattr(self, field, value)
update_fields = {'$set': document}
await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_fields)
self._validate_data(data=merged_update_query, is_updating=True)

update_query = {}
for field, value in merged_update_query.items():
if field.startswith('$'):
update_query[field] = value
else:
update_query['$set'] = update_query.get('$set', {})
update_query['$set'][field] = value
setattr(self, field, value)

await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_query)

@classmethod
async def update_one(cls, _filter: dict, _update: dict | None = None, /, **kwargs) -> bool:
prepare_id_for_query(_filter, is_mongo=True)
update_fields = {'$set': cls._merge(_update, kwargs)}
merged_update_query = cls._merge(_update, kwargs)

update_query = {}
for field, value in merged_update_query.items():
if field.startswith('$'):
update_query[field] = value
else:
update_query['$set'][field] = value

result = await db.session[cls.__name__].update_one(_filter, update_fields)
result = await db.session[cls.__name__].update_one(_filter, update_query)
return bool(result.matched_count)

@classmethod
Expand Down
10 changes: 7 additions & 3 deletions panther/db/queries/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,6 @@ async def find_one_or_raise(cls, _filter: dict | None = None, /, **kwargs) -> Se

raise NotFoundAPIError(detail=f'{cls.__name__} Does Not Exist')

@check_connection
@log_query
async def save(self) -> None:
"""
Save the document
Expand All @@ -384,8 +382,14 @@ async def save(self) -> None:
>>> user = User(name='Ali')
>>> await user.save()
"""
document = self.model_dump(exclude=['_id'])
document = {field: getattr(self, field) for field in self.model_fields_set if field != 'request'}

if self.id:
await self.update(document)
else:
await self.insert_one(document)

async def reload(self) -> Self:
new_obj = await self.find_one(id=self.id)
[setattr(self, f, getattr(new_obj, f)) for f in new_obj.model_fields]
return self
4 changes: 2 additions & 2 deletions panther/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def check_config(cls, cls_name: str, namespace: dict) -> None:

# Check `model` type
try:
if not issubclass(model, Model):
msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.'
if not issubclass(model, (Model, BaseModel)):
msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model` or `pydantic.BaseModel`.'
raise AttributeError(msg) from None
except TypeError:
msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.'
Expand Down

0 comments on commit f898354

Please sign in to comment.