diff --git a/panther/db/models.py b/panther/db/models.py index 8441a62..a4aebb4 100644 --- a/panther/db/models.py +++ b/panther/db/models.py @@ -21,7 +21,7 @@ def validate_object_id(value, handler): else: try: return bson.ObjectId(value) - except bson.objectid.InvalidId as e: + except Exception as e: msg = 'Invalid ObjectId' raise ValueError(msg) from e return str(value) diff --git a/panther/db/queries/mongodb_queries.py b/panther/db/queries/mongodb_queries.py index 5e2f10e..59bf2a7 100644 --- a/panther/db/queries/mongodb_queries.py +++ b/panther/db/queries/mongodb_queries.py @@ -102,21 +102,35 @@ async def delete_many(cls, _filter: dict | None = None, /, **kwargs) -> int: # # # # # Update # # # # # async def update(self, _update: dict | None = None, /, **kwargs) -> None: - document = self._merge(_update, kwargs) - document.pop('_id', None) - self._validate_data(data=document, is_updating=True) + merged_update_query = self._merge(_update, kwargs) + merged_update_query.pop('_id', None) - for field, value in document.items(): - setattr(self, field, value) - update_fields = {'$set': document} - await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_fields) + self._validate_data(data=merged_update_query, is_updating=True) + + update_query = {} + for field, value in merged_update_query.items(): + if field.startswith('$'): + update_query[field] = value + else: + update_query['$set'] = update_query.get('$set', {}) + update_query['$set'][field] = value + setattr(self, field, value) + + await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_query) @classmethod async def update_one(cls, _filter: dict, _update: dict | None = None, /, **kwargs) -> bool: prepare_id_for_query(_filter, is_mongo=True) - update_fields = {'$set': cls._merge(_update, kwargs)} + merged_update_query = cls._merge(_update, kwargs) + + update_query = {} + for field, value in merged_update_query.items(): + if field.startswith('$'): + update_query[field] = value + else: + update_query['$set'][field] = value - result = await db.session[cls.__name__].update_one(_filter, update_fields) + result = await db.session[cls.__name__].update_one(_filter, update_query) return bool(result.matched_count) @classmethod diff --git a/panther/db/queries/queries.py b/panther/db/queries/queries.py index e06eee0..cba2b6d 100644 --- a/panther/db/queries/queries.py +++ b/panther/db/queries/queries.py @@ -364,8 +364,6 @@ async def find_one_or_raise(cls, _filter: dict | None = None, /, **kwargs) -> Se raise NotFoundAPIError(detail=f'{cls.__name__} Does Not Exist') - @check_connection - @log_query async def save(self) -> None: """ Save the document @@ -384,8 +382,14 @@ async def save(self) -> None: >>> user = User(name='Ali') >>> await user.save() """ - document = self.model_dump(exclude=['_id']) + document = {field: getattr(self, field) for field in self.model_fields_set if field != 'request'} + if self.id: await self.update(document) else: await self.insert_one(document) + + async def reload(self) -> Self: + new_obj = await self.find_one(id=self.id) + [setattr(self, f, getattr(new_obj, f)) for f in new_obj.model_fields] + return self diff --git a/panther/serializer.py b/panther/serializer.py index f7bf980..d6fae02 100644 --- a/panther/serializer.py +++ b/panther/serializer.py @@ -65,8 +65,8 @@ def check_config(cls, cls_name: str, namespace: dict) -> None: # Check `model` type try: - if not issubclass(model, Model): - msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.' + if not issubclass(model, (Model, BaseModel)): + msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model` or `pydantic.BaseModel`.' raise AttributeError(msg) from None except TypeError: msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.'