Skip to content

Commit

Permalink
Improvement 4.1.3 #91
Browse files Browse the repository at this point in the history
  • Loading branch information
AliRn76 authored Apr 11, 2024
2 parents 005f37c + 0026bc2 commit 5407729
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 25 deletions.
16 changes: 11 additions & 5 deletions panther/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,15 @@ def validate_input(cls, model, request: Request):
# `request` will be ignored in regular `BaseModel`
return model(**request.data, request=request)
except ValidationError as validation_error:
error = {'.'.join(loc for loc in e['loc']): e['msg'] for e in validation_error.errors()}
error = {'.'.join(str(loc) for loc in e['loc']): e['msg'] for e in validation_error.errors()}
raise BadRequestAPIError(detail=error)
except JSONDecodeError:
raise JSONDecodeAPIError


class GenericAPI:
input_model: type[ModelSerializer] | type[BaseModel] = None
output_model: type[ModelSerializer] | type[BaseModel] = None
input_model: type[ModelSerializer] | type[BaseModel] | None = None
output_model: type[ModelSerializer] | type[BaseModel] | None = None
auth: bool = False
permissions: list | None = None
throttling: Throttling | None = None
Expand All @@ -181,6 +181,12 @@ async def patch(self, *args, **kwargs):
async def delete(self, *args, **kwargs):
raise MethodNotAllowedAPIError

async def get_input_model(self, request: Request) -> type[ModelSerializer] | type[BaseModel] | None:
return None

async def get_output_model(self, request: Request) -> type[ModelSerializer] | type[BaseModel] | None:
return None

async def call_method(self, request: Request):
match request.method:
case 'GET':
Expand All @@ -197,8 +203,8 @@ async def call_method(self, request: Request):
raise MethodNotAllowedAPIError

return await API(
input_model=self.input_model,
output_model=self.output_model,
input_model=self.input_model or await self.get_input_model(request=request),
output_model=self.output_model or await self.get_output_model(request=request),
auth=self.auth,
permissions=self.permissions,
throttling=self.throttling,
Expand Down
2 changes: 1 addition & 1 deletion panther/db/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def validate_object_id(value, handler):
else:
try:
return bson.ObjectId(value)
except bson.objectid.InvalidId as e:
except Exception as e:
msg = 'Invalid ObjectId'
raise ValueError(msg) from e
return str(value)
Expand Down
33 changes: 24 additions & 9 deletions panther/db/queries/mongodb_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,36 @@ async def delete_many(cls, _filter: dict | None = None, /, **kwargs) -> int:

# # # # # Update # # # # #
async def update(self, _update: dict | None = None, /, **kwargs) -> None:
document = self._merge(_update, kwargs)
document.pop('_id', None)
self._validate_data(data=document, is_updating=True)
merged_update_query = self._merge(_update, kwargs)
merged_update_query.pop('_id', None)

for field, value in document.items():
setattr(self, field, value)
update_fields = {'$set': document}
await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_fields)
self._validate_data(data=merged_update_query, is_updating=True)

update_query = {}
for field, value in merged_update_query.items():
if field.startswith('$'):
update_query[field] = value
else:
update_query['$set'] = update_query.get('$set', {})
update_query['$set'][field] = value
setattr(self, field, value)

await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_query)

@classmethod
async def update_one(cls, _filter: dict, _update: dict | None = None, /, **kwargs) -> bool:
prepare_id_for_query(_filter, is_mongo=True)
update_fields = {'$set': cls._merge(_update, kwargs)}
merged_update_query = cls._merge(_update, kwargs)

update_query = {}
for field, value in merged_update_query.items():
if field.startswith('$'):
update_query[field] = value
else:
update_query['$set'] = update_query.get('$set', {})
update_query['$set'][field] = value

result = await db.session[cls.__name__].update_one(_filter, update_fields)
result = await db.session[cls.__name__].update_one(_filter, update_query)
return bool(result.matched_count)

@classmethod
Expand Down
10 changes: 7 additions & 3 deletions panther/db/queries/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,6 @@ async def find_one_or_raise(cls, _filter: dict | None = None, /, **kwargs) -> Se

raise NotFoundAPIError(detail=f'{cls.__name__} Does Not Exist')

@check_connection
@log_query
async def save(self) -> None:
"""
Save the document
Expand All @@ -384,8 +382,14 @@ async def save(self) -> None:
>>> user = User(name='Ali')
>>> await user.save()
"""
document = self.model_dump(exclude=['_id'])
document = {field: getattr(self, field) for field in self.model_fields_set if field != 'request'}

if self.id:
await self.update(document)
else:
await self.insert_one(document)

async def reload(self) -> Self:
new_obj = await self.find_one(id=self.id)
[setattr(self, f, getattr(new_obj, f)) for f in new_obj.model_fields]
return self
19 changes: 18 additions & 1 deletion panther/file_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from functools import cached_property

from panther import status
from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_serializer

from panther.exceptions import APIError

Expand All @@ -15,6 +15,23 @@ class File(BaseModel):
def size(self):
return len(self.file)

def save(self) -> str:
if hasattr(self, '_file_name'):
return self._file_name

self._file_name = self.file_name
# TODO: check for duplication
with open(self._file_name, 'wb') as file:
file.write(self.file)

return self.file_name

@model_serializer(mode='wrap')
def _serialize(self, handler):
result = handler(self)
result['path'] = self.save()
return result

def __repr__(self) -> str:
return f'{self.__repr_name__()}(file_name={self.file_name}, content_type={self.content_type})'

Expand Down
27 changes: 23 additions & 4 deletions panther/generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class ObjectRequired:
def _check_object(self, instance):
if issubclass(type(instance), Model) is False:
if instance and issubclass(type(instance), Model) is False:
logger.critical(f'`{self.__class__.__name__}.object()` should return instance of a Model --> `find_one()`')
raise APIError

Expand Down Expand Up @@ -129,9 +129,11 @@ class CreateAPI(GenericAPI):
input_model: type[ModelSerializer]

async def post(self, request: Request, **kwargs):
instance = await request.validated_data.create(
validated_data=request.validated_data.model_dump()
)
instance = await request.validated_data.create(validated_data={
field: getattr(request.validated_data, field)
for field in request.validated_data.model_fields_set
if field != 'request'
})
return Response(data=instance, status_code=status.HTTP_201_CREATED)


Expand Down Expand Up @@ -160,13 +162,30 @@ async def patch(self, request: Request, **kwargs):


class DeleteAPI(GenericAPI, ObjectRequired):
async def pre_delete(self, instance, request: Request, **kwargs):
pass

async def post_delete(self, instance, request: Request, **kwargs):
pass

async def delete(self, request: Request, **kwargs):
instance = await self.object(request=request, **kwargs)
self._check_object(instance)

await self.pre_delete(instance, request=request, **kwargs)
await instance.delete()
await self.post_delete(instance, request=request, **kwargs)

return Response(status_code=status.HTTP_204_NO_CONTENT)


class ListCreateAPI(CreateAPI, ListAPI):
pass


class UpdateDeleteAPI(UpdateAPI, DeleteAPI):
pass


class RetrieveUpdateDeleteAPI(RetrieveAPI, UpdateAPI, DeleteAPI):
pass
4 changes: 2 additions & 2 deletions panther/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def check_config(cls, cls_name: str, namespace: dict) -> None:

# Check `model` type
try:
if not issubclass(model, Model):
msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.'
if not issubclass(model, (Model, BaseModel)):
msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model` or `pydantic.BaseModel`.'
raise AttributeError(msg) from None
except TypeError:
msg = f'`{cls_name}.Config.model` is not subclass of `panther.db.Model`.'
Expand Down

0 comments on commit 5407729

Please sign in to comment.