diff --git a/panther/app.py b/panther/app.py index 82c523e..ba5452e 100644 --- a/panther/app.py +++ b/panther/app.py @@ -151,15 +151,15 @@ def validate_input(cls, model, request: Request): # `request` will be ignored in regular `BaseModel` return model(**request.data, request=request) except ValidationError as validation_error: - error = {'.'.join(loc for loc in e['loc']): e['msg'] for e in validation_error.errors()} + error = {'.'.join(str(loc) for loc in e['loc']): e['msg'] for e in validation_error.errors()} raise BadRequestAPIError(detail=error) except JSONDecodeError: raise JSONDecodeAPIError class GenericAPI: - input_model: type[ModelSerializer] | type[BaseModel] = None - output_model: type[ModelSerializer] | type[BaseModel] = None + input_model: type[ModelSerializer] | type[BaseModel] | None = None + output_model: type[ModelSerializer] | type[BaseModel] | None = None auth: bool = False permissions: list | None = None throttling: Throttling | None = None @@ -181,6 +181,12 @@ async def patch(self, *args, **kwargs): async def delete(self, *args, **kwargs): raise MethodNotAllowedAPIError + async def get_input_model(self, request: Request) -> type[ModelSerializer] | type[BaseModel] | None: + return None + + async def get_output_model(self, request: Request) -> type[ModelSerializer] | type[BaseModel] | None: + return None + async def call_method(self, request: Request): match request.method: case 'GET': @@ -197,8 +203,8 @@ async def call_method(self, request: Request): raise MethodNotAllowedAPIError return await API( - input_model=self.input_model, - output_model=self.output_model, + input_model=self.input_model or await self.get_input_model(request=request), + output_model=self.output_model or await self.get_output_model(request=request), auth=self.auth, permissions=self.permissions, throttling=self.throttling, diff --git a/panther/db/models.py b/panther/db/models.py index 8441a62..a4aebb4 100644 --- a/panther/db/models.py +++ b/panther/db/models.py @@ -21,7 +21,7 @@ def validate_object_id(value, handler): else: try: return bson.ObjectId(value) - except bson.objectid.InvalidId as e: + except Exception as e: msg = 'Invalid ObjectId' raise ValueError(msg) from e return str(value) diff --git a/panther/db/queries/mongodb_queries.py b/panther/db/queries/mongodb_queries.py index 5e2f10e..5a2ddf2 100644 --- a/panther/db/queries/mongodb_queries.py +++ b/panther/db/queries/mongodb_queries.py @@ -102,21 +102,36 @@ async def delete_many(cls, _filter: dict | None = None, /, **kwargs) -> int: # # # # # Update # # # # # async def update(self, _update: dict | None = None, /, **kwargs) -> None: - document = self._merge(_update, kwargs) - document.pop('_id', None) - self._validate_data(data=document, is_updating=True) + merged_update_query = self._merge(_update, kwargs) + merged_update_query.pop('_id', None) - for field, value in document.items(): - setattr(self, field, value) - update_fields = {'$set': document} - await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_fields) + self._validate_data(data=merged_update_query, is_updating=True) + + update_query = {} + for field, value in merged_update_query.items(): + if field.startswith('$'): + update_query[field] = value + else: + update_query['$set'] = update_query.get('$set', {}) + update_query['$set'][field] = value + setattr(self, field, value) + + await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_query) @classmethod async def update_one(cls, _filter: dict, _update: dict | None = None, /, **kwargs) -> bool: prepare_id_for_query(_filter, is_mongo=True) - update_fields = {'$set': cls._merge(_update, kwargs)} + merged_update_query = cls._merge(_update, kwargs) + + update_query = {} + for field, value in merged_update_query.items(): + if field.startswith('$'): + update_query[field] = value + else: + update_query['$set'] = update_query.get('$set', {}) + update_query['$set'][field] = value - result = await db.session[cls.__name__].update_one(_filter, update_fields) + result = await db.session[cls.__name__].update_one(_filter, update_query) return bool(result.matched_count) @classmethod diff --git a/panther/db/queries/queries.py b/panther/db/queries/queries.py index e06eee0..cba2b6d 100644 --- a/panther/db/queries/queries.py +++ b/panther/db/queries/queries.py @@ -364,8 +364,6 @@ async def find_one_or_raise(cls, _filter: dict | None = None, /, **kwargs) -> Se raise NotFoundAPIError(detail=f'{cls.__name__} Does Not Exist') - @check_connection - @log_query async def save(self) -> None: """ Save the document @@ -384,8 +382,14 @@ async def save(self) -> None: >>> user = User(name='Ali') >>> await user.save() """ - document = self.model_dump(exclude=['_id']) + document = {field: getattr(self, field) for field in self.model_fields_set if field != 'request'} + if self.id: await self.update(document) else: await self.insert_one(document) + + async def reload(self) -> Self: + new_obj = await self.find_one(id=self.id) + [setattr(self, f, getattr(new_obj, f)) for f in new_obj.model_fields] + return self diff --git a/panther/file_handler.py b/panther/file_handler.py index 9a3acae..c5ff293 100644 --- a/panther/file_handler.py +++ b/panther/file_handler.py @@ -1,7 +1,7 @@ from functools import cached_property from panther import status -from pydantic import BaseModel, field_validator +from pydantic import BaseModel, field_validator, model_serializer from panther.exceptions import APIError @@ -15,6 +15,23 @@ class File(BaseModel): def size(self): return len(self.file) + def save(self) -> str: + if hasattr(self, '_file_name'): + return self._file_name + + self._file_name = self.file_name + # TODO: check for duplication + with open(self._file_name, 'wb') as file: + file.write(self.file) + + return self.file_name + + @model_serializer(mode='wrap') + def _serialize(self, handler): + result = handler(self) + result['path'] = self.save() + return result + def __repr__(self) -> str: return f'{self.__repr_name__()}(file_name={self.file_name}, content_type={self.content_type})' diff --git a/panther/generics.py b/panther/generics.py index ce6bb6f..bf0c389 100644 --- a/panther/generics.py +++ b/panther/generics.py @@ -23,7 +23,7 @@ class ObjectRequired: def _check_object(self, instance): - if issubclass(type(instance), Model) is False: + if instance and issubclass(type(instance), Model) is False: logger.critical(f'`{self.__class__.__name__}.object()` should return instance of a Model --> `find_one()`') raise APIError @@ -129,9 +129,11 @@ class CreateAPI(GenericAPI): input_model: type[ModelSerializer] async def post(self, request: Request, **kwargs): - instance = await request.validated_data.create( - validated_data=request.validated_data.model_dump() - ) + instance = await request.validated_data.create(validated_data={ + field: getattr(request.validated_data, field) + for field in request.validated_data.model_fields_set + if field != 'request' + }) return Response(data=instance, status_code=status.HTTP_201_CREATED) @@ -160,13 +162,30 @@ async def patch(self, request: Request, **kwargs): class DeleteAPI(GenericAPI, ObjectRequired): + async def pre_delete(self, instance, request: Request, **kwargs): + pass + + async def post_delete(self, instance, request: Request, **kwargs): + pass + async def delete(self, request: Request, **kwargs): instance = await self.object(request=request, **kwargs) self._check_object(instance) + await self.pre_delete(instance, request=request, **kwargs) await instance.delete() + await self.post_delete(instance, request=request, **kwargs) + return Response(status_code=status.HTTP_204_NO_CONTENT) class ListCreateAPI(CreateAPI, ListAPI): pass + + +class UpdateDeleteAPI(UpdateAPI, DeleteAPI): + pass + + +class RetrieveUpdateDeleteAPI(RetrieveAPI, UpdateAPI, DeleteAPI): + pass diff --git a/panther/serializer.py b/panther/serializer.py index f7bf980..d6fae02 100644 --- a/panther/serializer.py +++ b/panther/serializer.py @@ -65,8 +65,8 @@ def check_config(cls, cls_name: str, namespace: dict) -> None: # Check `model` type try: - if not issubclass(model, Model): - msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.' + if not issubclass(model, (Model, BaseModel)): + msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model` or `pydantic.BaseModel`.' raise AttributeError(msg) from None except TypeError: msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.'