diff --git a/README.md b/README.md index 014092c..531e730 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,11 @@ - Built-in API **Caching** System (In Memory, **Redis**) - Built-in **Authentication** Classes - Built-in **Permission** Classes +- Built-in Visual API **Monitoring** (In Terminal) - Support Custom **Background Tasks** - Support Custom **Middlewares** - Support Custom **Throttling** -- Visual API **Monitoring** (In Terminal) +- Support **Function-Base** and **Class-Base** APIs - It's One Of The **Fastest Python Frameworks** --- @@ -176,6 +177,12 @@ $ pip install panther --- +### How Panther Works! + +![diagram](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/diagram.png) + +--- + ### Roadmap ![roadmap](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/roadmap.jpg) diff --git a/docs/docs/authentications.md b/docs/docs/authentications.md index 3b2efe6..6563086 100644 --- a/docs/docs/authentications.md +++ b/docs/docs/authentications.md @@ -3,12 +3,20 @@ > Type: `str` > > Default: `None` - -You can set your Authentication class in `configs`, now, if you set `auth=True` in `@API()`, Panther will use this class for authentication of every `API`, and put the `user` in `request.user` or `raise HTTP_401_UNAUTHORIZED` -We implemented a built-in authentication class which used `JWT` for authentication. +You can set Authentication class in your `configs` + +Panther use it, to authenticate every API/ WS if `auth=True` and give you the user or `raise HTTP_401_UNAUTHORIZED` + +The `user` will be in `request.user` in APIs and in `self.user` in WSs + +--- + +We implemented 2 built-in authentication classes which use `JWT` for authentication. + But, You can use your own custom authentication class too. +--- ### JWTAuthentication @@ -17,7 +25,7 @@ This class will - Get the `token` from `Authorization` header of request. - Check the `Bearer` - `decode` the `token` -- Find the matched `user` (It uses the `USER_MODEL`) +- Find the matched `user` > `JWTAuthentication` is going to use `panther.db.models.BaseUser` if you didn't set the `USER_MODEL` in your `configs` @@ -59,9 +67,7 @@ JWTConfig = { #### Websocket Authentication -This class is very useful when you are trying to authenticate the user in websocket - -Add this into your `configs`: +The `QueryParamJWTAuthentication` is very useful when you are trying to authenticate the user in websocket, you just have to add this into your `configs`: ```python WS_AUTHENTICATION = 'panther.authentications.QueryParamJWTAuthentication' ``` @@ -77,7 +83,9 @@ WS_AUTHENTICATION = 'panther.authentications.QueryParamJWTAuthentication' - Or raise `panther.exceptions.AuthenticationAPIError` -- Address it in `configs` - - `AUTHENTICATION = 'project_name.core.authentications.CustomAuthentication'` +- Add it into your `configs` + ```python + AUTHENTICATION = 'project_name.core.authentications.CustomAuthentication' + ``` -> You can see the source code of JWTAuthentication [[here]](https://github.com/AliRn76/panther/blob/da2654ccdd83ebcacda91a1aaf51d5aeb539eff5/panther/authentications.py#L38) \ No newline at end of file +> You can see the source code of JWTAuthentication [[here]](https://github.com/AliRn76/panther/blob/da2654ccdd83ebcacda91a1aaf51d5aeb539eff5/panther/authentications.py) \ No newline at end of file diff --git a/docs/docs/background_tasks.md b/docs/docs/background_tasks.md index 670f455..b5057f3 100644 --- a/docs/docs/background_tasks.md +++ b/docs/docs/background_tasks.md @@ -110,7 +110,7 @@ Panther is going to run the `background tasks` as a thread in the background on task = BackgroundTask(do_something, 'Ali', 26) ``` -- > Default interval is 1. +- > Default of interval() is 1. - > The -1 interval means infinite, diff --git a/docs/docs/class_first_crud.md b/docs/docs/class_first_crud.md index 9c81906..274cc28 100644 --- a/docs/docs/class_first_crud.md +++ b/docs/docs/class_first_crud.md @@ -1,6 +1,6 @@ We assume you could run the project with [Introduction](https://pantherpy.github.io/#installation) -Now let's write custom API `Create`, `Retrieve`, `Update` and `Delete` for a `Book`: +Now let's write custom APIs for `Create`, `Retrieve`, `Update` and `Delete` a `Book`: ## Structure & Requirements ### Create Model @@ -82,7 +82,7 @@ Now we are going to create a book on `post` request, We need to: class BookAPI(GenericAPI): - def post(self): + async def post(self): ... ``` @@ -95,7 +95,7 @@ Now we are going to create a book on `post` request, We need to: class BookAPI(GenericAPI): - def post(self, request: Request): + async def post(self, request: Request): ... ``` @@ -122,7 +122,7 @@ Now we are going to create a book on `post` request, We need to: class BookAPI(GenericAPI): input_model = BookSerializer - def post(self, request: Request): + async def post(self, request: Request): ... ``` @@ -138,7 +138,7 @@ Now we are going to create a book on `post` request, We need to: class BookAPI(GenericAPI): input_model = BookSerializer - def post(self, request: Request): + async def post(self, request: Request): body: BookSerializer = request.validated_data ... ``` @@ -156,9 +156,9 @@ Now we are going to create a book on `post` request, We need to: class BookAPI(GenericAPI): input_model = BookSerializer - def post(self, request: Request): + async def post(self, request: Request): body: BookSerializer = request.validated_data - Book.insert_one( + await Book.insert_one( name=body.name, author=body.author, pages_count=body.pages_count, @@ -180,15 +180,14 @@ Now we are going to create a book on `post` request, We need to: class BookAPI(GenericAPI): input_model = BookSerializer - def post(self, request: Request): + async def post(self, request: Request): body: BookSerializer = request.validated_data - book = Book.insert_one( + book = await Book.insert_one( name=body.name, author=body.author, pages_count=body.pages_count, ) return Response(data=book, status_code=status.HTTP_201_CREATED) - ``` > The response.data can be `Instance of Models`, `dict`, `str`, `tuple`, `list`, `str` or `None` @@ -213,11 +212,11 @@ from app.models import Book class BookAPI(GenericAPI): input_model = BookSerializer - def post(self, request: Request): + async def post(self, request: Request): ... - def get(self): - books: list[Book] = Book.find() + async def get(self): + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) ``` @@ -256,11 +255,11 @@ Assume we don't want to return field `author` in response: input_model = BookSerializer output_model = BookOutputSerializer - def post(self, request: Request): + async def post(self, request: Request): ... - def get(self): - books: list[Book] = Book.find() + async def get(self): + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) ``` @@ -292,11 +291,11 @@ class BookAPI(GenericAPI): cache = True cache_exp_time = timedelta(seconds=10) - def post(self, request: Request): + async def post(self, request: Request): ... - def get(self): - books: list[Book] = Book.find() + async def get(self): + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) ``` @@ -329,11 +328,11 @@ class BookAPI(GenericAPI): cache_exp_time = timedelta(seconds=10) throttling = Throttling(rate=10, duration=timedelta(minutes=1)) - def post(self, request: Request): + async def post(self, request: Request): ... - def get(self): - books: list[Book] = Book.find() + async def get(self): + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) ``` @@ -383,8 +382,8 @@ For `retrieve`, `update` and `delete` API, we are going to class SingleBookAPI(GenericAPI): - def get(self, book_id: int): - if book := Book.find_one(id=book_id): + async def get(self, book_id: int): + if book := await Book.find_one(id=book_id): return Response(data=book, status_code=status.HTTP_200_OK) else: return Response(status_code=status.HTTP_404_NOT_FOUND) @@ -406,11 +405,11 @@ from app.serializers import BookSerializer class SingleBookAPI(GenericAPI): input_model = BookSerializer - def get(self, book_id: int): + async def get(self, book_id: int): ... - def put(self, request: Request, book_id: int): - is_updated: bool = Book.update_one({'id': book_id}, request.validated_data.model_dump()) + async def put(self, request: Request, book_id: int): + is_updated: bool = await Book.update_one({'id': book_id}, request.validated_data.model_dump()) data = {'is_updated': is_updated} return Response(data=data, status_code=status.HTTP_202_ACCEPTED) ``` @@ -431,14 +430,14 @@ from app.models import Book class SingleBookAPI(GenericAPI): input_model = BookSerializer - def get(self, book_id: int): + async def get(self, book_id: int): ... - def put(self, request: Request, book_id: int): + async def put(self, request: Request, book_id: int): ... - def delete(self, book_id: int): - is_deleted: bool = Book.delete_one(id=book_id) + async def delete(self, book_id: int): + is_deleted: bool = await Book.delete_one(id=book_id) if is_deleted: return Response(status_code=status.HTTP_204_NO_CONTENT) else: diff --git a/docs/docs/configs.md b/docs/docs/configs.md index 294ef32..3a37775 100644 --- a/docs/docs/configs.md +++ b/docs/docs/configs.md @@ -7,12 +7,14 @@ Panther collect all the configs from your `core/configs.py` or the module you pa > Type: `bool` (Default: `False`) It should be `True` if you want to use `panther monitor` command -and see the monitoring logs +and watch the monitoring If `True`: - Log every request in `logs/monitoring.log` +_Requires [watchfiles](https://watchfiles.helpmanual.io) package._ + --- ### [LOG_QUERIES](https://pantherpy.github.io/log_queries) > Type: `bool` (Default: `False`) @@ -33,6 +35,8 @@ List of middlewares you want to use Every request goes through `authentication()` method of this `class` (if `auth = True`) +_Requires [python-jose](https://python-jose.readthedocs.io/en/latest/) package._ + _Example:_ `AUTHENTICATION = 'panther.authentications.JWTAuthentication'` --- @@ -119,6 +123,8 @@ It will reformat your code on every reload (on every change if you run the proje You may want to write your custom `ruff.toml` in root of your project. +_Requires [ruff](https://docs.astral.sh/ruff/) package._ + Reference: [https://docs.astral.sh/ruff/formatter/](https://docs.astral.sh/ruff/formatter/) _Example:_ `AUTO_REFORMAT = True` @@ -133,4 +139,16 @@ We use it to create `database` connection ### [REDIS](https://pantherpy.github.io/redis) > Type: `dict` (Default: `{}`) +_Requires [redis](https://redis-py.readthedocs.io/en/stable/) package._ + We use it to create `redis` connection + + +--- +### [TIMEZONE](https://pantherpy.github.io/timezone) +> Type: `str` (Default: `'UTC'`) + +Used in `panther.utils.timezone_now()` which returns a `datetime` based on your `timezone` + +And `panther.utils.timezone_now()` used in `BaseUser.date_created` and `BaseUser.last_login` + diff --git a/docs/docs/events.md b/docs/docs/events.md new file mode 100644 index 0000000..bf7a21e --- /dev/null +++ b/docs/docs/events.md @@ -0,0 +1,29 @@ +## Startup Event + +Use `Event.startup` decorator + +```python +from panther.events import Event + + +@Event.startup +def do_something_on_startup(): + print('Hello, I am at startup') +``` + +## Shutdown Event + +```python +from panther.events import Event + + +@Event.shutdown +def do_something_on_shutdown(): + print('Good Bye, I am at shutdown') +``` + + +## Notice + +- You can have **multiple** events on `startup` and `shutdown` +- Events can be `sync` or `async` diff --git a/docs/docs/function_first_crud.md b/docs/docs/function_first_crud.md index 3c83565..dd97e3b 100644 --- a/docs/docs/function_first_crud.md +++ b/docs/docs/function_first_crud.md @@ -1,6 +1,6 @@ We assume you could run the project with [Introduction](https://pantherpy.github.io/#installation) -Now let's write custom API `Create`, `Retrieve`, `Update` and `Delete` for a `Book`: +Now let's write custom APIs for `Create`, `Retrieve`, `Update` and `Delete` a `Book`: ## Structure & Requirements ### Create Model @@ -142,7 +142,7 @@ Now we are going to create a book on `post` request, We need to: async def book_api(request: Request): body: BookSerializer = request.validated_data - Book.insert_one( + await Book.insert_one( name=body.name, author=body.author, pages_count=body.pages_count, @@ -165,7 +165,7 @@ Now we are going to create a book on `post` request, We need to: if request.method == 'POST': body: BookSerializer = request.validated_data - Book.create( + await Book.insert_one( name=body.name, author=body.author, pages_count=body.pages_count, @@ -189,7 +189,7 @@ Now we are going to create a book on `post` request, We need to: if request.method == 'POST': body: BookSerializer = request.validated_data - book: Book = Book.create( + book: Book = await Book.insert_one( name=body.name, author=body.author, pages_count=body.pages_count, @@ -224,7 +224,7 @@ async def book_api(request: Request): ... elif request.method == 'GET': - books: list[Book] = Book.find() + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED) @@ -265,7 +265,7 @@ Assume we don't want to return field `author` in response: ... elif request.method == 'GET': - books: list[Book] = Book.find() + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED) @@ -299,7 +299,7 @@ async def book_api(request: Request): ... elif request.method == 'GET': - books: list[Book] = Book.find() + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED) @@ -338,7 +338,7 @@ async def book_api(request: Request): ... elif request.method == 'GET': - books: list[Book] = Book.find() + books = await Book.find() return Response(data=books, status_code=status.HTTP_200_OK) return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED) @@ -393,7 +393,7 @@ For `retrieve`, `update` and `delete` API, we are going to @API() async def single_book_api(request: Request, book_id: int): if request.method == 'GET': - if book := Book.find_one(id=book_id): + if book := await Book.find_one(id=book_id): return Response(data=book, status_code=status.HTTP_200_OK) else: return Response(status_code=status.HTTP_404_NOT_FOUND) @@ -420,12 +420,12 @@ For `retrieve`, `update` and `delete` API, we are going to if request.method == 'GET': ... elif request.method == 'PUT': - book: Book = Book.find_one(id=book_id) - book.update( + book: Book = await Book.find_one(id=book_id) + await book.update( name=body.name, author=body.author, pages_count=body.pages_count - ) + ) return Response(status_code=status.HTTP_202_ACCEPTED) ``` @@ -446,7 +446,7 @@ For `retrieve`, `update` and `delete` API, we are going to if request.method == 'GET': ... elif request.method == 'PUT': - is_updated: bool = Book.update_one({'id': book_id}, request.validated_data.model_dump()) + is_updated: bool = await Book.update_one({'id': book_id}, request.validated_data.model_dump()) data = {'is_updated': is_updated} return Response(data=data, status_code=status.HTTP_202_ACCEPTED) ``` @@ -468,7 +468,7 @@ For `retrieve`, `update` and `delete` API, we are going to if request.method == 'GET': ... elif request.method == 'PUT': - updated_count: int = Book.update_many({'id': book_id}, request.validated_data.model_dump()) + updated_count: int = await Book.update_many({'id': book_id}, request.validated_data.model_dump()) data = {'updated_count': updated_count} return Response(data=data, status_code=status.HTTP_202_ACCEPTED) ``` @@ -496,7 +496,7 @@ For `retrieve`, `update` and `delete` API, we are going to elif request.method == 'PUT': ... elif request.method == 'DELETE': - is_deleted: bool = Book.delete_one(id=book_id) + is_deleted: bool = await Book.delete_one(id=book_id) if is_deleted: return Response(status_code=status.HTTP_204_NO_CONTENT) else: @@ -520,7 +520,7 @@ For `retrieve`, `update` and `delete` API, we are going to elif request.method == 'PUT': ... elif request.method == 'DELETE': - is_deleted: bool = Book.delete_one(id=book_id) + is_deleted: bool = await Book.delete_one(id=book_id) return Response(status_code=status.HTTP_204_NO_CONTENT) ``` @@ -542,6 +542,6 @@ For `retrieve`, `update` and `delete` API, we are going to elif request.method == 'PUT': ... elif request.method == 'DELETE': - deleted_count: int = Book.delete_many(id=book_id) + deleted_count: int = await Book.delete_many(id=book_id) return Response(status_code=status.HTTP_204_NO_CONTENT) ``` \ No newline at end of file diff --git a/docs/docs/generic_crud.md b/docs/docs/generic_crud.md new file mode 100644 index 0000000..ab6ed28 --- /dev/null +++ b/docs/docs/generic_crud.md @@ -0,0 +1,313 @@ +We assume you could run the project with [Introduction](https://pantherpy.github.io/#installation) + +Now let's write custom APIs for `Create`, `Retrieve`, `Update` and `Delete` a `Book`: + +## Structure & Requirements +### Create Model + +Create a model named `Book` in `app/models.py`: + +```python +from panther.db import Model + + +class Book(Model): + name: str + author: str + pages_count: int +``` + +### Create API Class + +Create the `BookAPI()` in `app/apis.py`: + +```python +from panther.app import GenericAPI + + +class BookAPI(GenericAPI): + ... +``` + +> We are going to complete it later ... + +### Update URLs + +Add the `BookAPI` in `app/urls.py`: + +```python +from app.apis import BookAPI + + +urls = { + 'book/': BookAPI, +} +``` + +We assume that the `urls` in `core/urls.py` pointing to `app/urls.py`, like below: + +```python +from app.urls import urls as app_urls + + +urls = { + '/': app_urls, +} +``` + +### Add Database + +Add `DATABASE` in `configs`, we are going to add `pantherdb` +> [PantherDB](https://github.com/PantherPy/PantherDB/#readme) is a Simple, File-Base and Document Oriented database + +```python +... +DATABASE = { + 'engine': { + 'class': 'panther.db.connections.PantherDBConnection', + } +} +... +``` + +## APIs +### API - Create a Book + +Now we are going to create a book on `POST` request, we need to: + +1. Inherit from `CreateAPI`: + ```python + from panther.generics import CreateAPI + + + class BookAPI(CreateAPI): + ... + ``` + +2. Create a ModelSerializer in `app/serializers.py`, for `validation` of the `request.data`: + + ```python + from panther.serializer import ModelSerializer + from app.models import Book + + class BookSerializer(ModelSerializer): + class Config: + model = Book + fields = ['name', 'author', 'pages_count'] + ``` + +3. Set the created serializer in `BookAPI` as `input_model`: + + ```python + from panther.app import CreateAPI + from app.serializers import BookSerializer + + + class BookAPI(CreateAPI): + input_model = BookSerializer + ``` + +It is going to create a `Book` with incoming `request.data` and return that instance to user with status code of `201` + + +### API - List of Books + +Let's return list of books of `GET` method, we need to: + +1. Inherit from `ListAPI` + ```python + from panther.generics import CreateAPI, ListAPI + from app.serializers import BookSerializer + + class BookAPI(CreateAPI, ListAPI): + input_model = BookSerializer + ... + ``` + +2. define `objects` method, so the `ListAPI` knows to return which books + + ```python + from panther.generics import CreateAPI, ListAPI + from panther.request import Request + + from app.models import Book + from app.serializers import BookSerializer + + class BookAPI(CreateAPI, ListAPI): + input_model = BookSerializer + + async def objects(self, request: Request, **kwargs): + return await Book.find() + ``` + +### Pagination, Search, Filter, Sort + +#### Pagination + +Use `panther.pagination.Pagination` as `pagination` + +**Usage:** It will look for the `limit` and `skip` in the `query params` and return its own response template + +**Example:** `?limit=10&skip=20` + +--- + +#### Search + +Define the fields you want the search, query on them in `search_fields` + +The value of `search_fields` should be `list` + +**Usage:** It works with `search` query param --> `?search=maybe_name_of_the_book_or_author` + +**Example:** `?search=maybe_name_of_the_book_or_author` + +--- + +#### Filter + +Define the fields you want to be filterable in `filter_fields` + +The value of `filter_fields` should be `list` + +**Usage:** It will look for each value of the `filter_fields` in the `query params` and query on them + +**Example:** `?name=name_of_the_book&author=author_of_the_book` + +--- + + +#### Sort + +Define the fields you want to be sortable in `sort_fields` + +The value of `sort_fields` should be `list` + + + +**Usage:** It will look for each value of the `sort_fields` in the `query params` and sort with them + +**Example:** `?sort=pages_count,-name` + +**Notice:** + - fields should be separated with a column `,` + - use `field_name` for ascending sort + - use `-field_name` for descending sort +--- + +#### Example +```python +from panther.generics import CreateAPI, ListAPI +from panther.pagination import Pagination +from panther.request import Request + +from app.models import Book +from app.serializers import BookSerializer, BookOutputSerializer + +class BookAPI(CreateAPI, ListAPI): + input_model = BookSerializer + output_model = BookOutputSerializer + pagination = Pagination + search_fields = ['name', 'author'] + filter_fields = ['name', 'author'] + sort_fields = ['name', 'pages_count'] + + async def objects(self, request: Request, **kwargs): + return await Book.find() +``` + + +### API - Retrieve a Book + +Now we are going to retrieve a book on `GET` request, we need to: + +1. Inherit from `RetrieveAPI`: + ```python + from panther.generics import RetrieveAPI + + + class SingleBookAPI(RetrieveAPI): + ... + ``` + +2. define `object` method, so the `RetrieveAPI` knows to return which book + + ```python + from panther.generics import RetrieveAPI + from panther.request import Request + from app.models import Book + + + class SingleBookAPI(RetrieveAPI): + async def object(self, request: Request, **kwargs): + return await Book.find_one_or_raise(id=kwargs['book_id']) + ``` + +3. Add it in `app/urls.py`: + + ```python + from app.apis import BookAPI, SingleBookAPI + + + urls = { + 'book/': BookAPI, + 'book//': SingleBookAPI, + } + ``` + + > You should write the [Path Variable](https://pantherpy.github.io/urls/#path-variables-are-handled-like-below) in `<` and `>` + + +### API - Update a Book + +1. Inherit from `UpdateAPI` + + ```python + from panther.generics import RetrieveAPI, UpdateAPI + from panther.request import Request + from app.models import Book + + + class SingleBookAPI(RetrieveAPI, UpdateAPI): + ... + + async def object(self, request: Request, **kwargs): + return await Book.find_one_or_raise(id=kwargs['book_id']) + ``` + +2. Add `input_model` so the `UpdateAPI` knows how to validate the `request.data` +> We use the same serializer as CreateAPI serializer we defined above + + ```python + from panther.generics import RetrieveAPI, UpdateAPI + from panther.request import Request + from app.models import Book + from app.serializers import BookSerializer + + + class SingleBookAPI(RetrieveAPI, UpdateAPI): + input_model = BookSerializer + + async def object(self, request: Request, **kwargs): + return await Book.find_one_or_raise(id=kwargs['book_id']) + ``` + +### API - Delete a Book + +1. Inherit from `DeleteAPI` + + ```python + from panther.generics import RetrieveAPI, UpdateAPI, DeleteAPI + from panther.request import Request + from app.models import Book + from app.serializers import BookSerializer + + + class SingleBookAPI(RetrieveAPI, UpdateAPI, DeleteAPI): + input_model = BookSerializer + + async def object(self, request: Request, **kwargs): + return await Book.find_one_or_raise(id=kwargs['book_id']) + ``` + +2. It requires `object` method which we defined before, so it's done. diff --git a/docs/docs/images/diagram.png b/docs/docs/images/diagram.png new file mode 100644 index 0000000..2c60eaa Binary files /dev/null and b/docs/docs/images/diagram.png differ diff --git a/docs/docs/images/roadmap.jpg b/docs/docs/images/roadmap.jpg index 5eb7378..cf376ac 100644 Binary files a/docs/docs/images/roadmap.jpg and b/docs/docs/images/roadmap.jpg differ diff --git a/docs/docs/index.md b/docs/docs/index.md index d59067b..03b4f38 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -21,10 +21,11 @@ - Built-in API **Caching** System (In Memory, **Redis**) - Built-in **Authentication** Classes - Built-in **Permission** Classes +- Built-in Visual API **Monitoring** (In Terminal) - Support Custom **Background Tasks** - Support Custom **Middlewares** - Support Custom **Throttling** -- Visual API **Monitoring** (In Terminal) +- Support **Function-Base** and **Class-Base** APIs - It's One Of The **Fastest Python Frameworks** --- @@ -198,6 +199,12 @@ --- +### How Panther Works! + +![diagram](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/diagram.png) + +--- + ### Roadmap ![roadmap](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/roadmap.jpg) diff --git a/docs/docs/middlewares.md b/docs/docs/middlewares.md index e745153..5fd2a1d 100644 --- a/docs/docs/middlewares.md +++ b/docs/docs/middlewares.md @@ -14,8 +14,11 @@ ## Custom Middleware ### Middleware Types We have 3 type of Middlewares, make sure that you are inheriting from the correct one: + - `Base Middleware`: which is used for both `websocket` and `http` requests + - `HTTP Middleware`: which is only used for `http` requests + - `Websocket Middleware`: which is only used for `websocket` requests ### Write Custom Middleware diff --git a/docs/docs/panther_odm.md b/docs/docs/panther_odm.md index 13f3e80..ee849ba 100644 --- a/docs/docs/panther_odm.md +++ b/docs/docs/panther_odm.md @@ -16,18 +16,36 @@ user: User = await User.find_one({'id': 1}, name='Ali') Get documents from the database. ```python -users: list[User] = await User.find(age=18, name='Ali') +users: Cursor = await User.find(age=18, name='Ali') or -users: list[User] = await User.find({'age': 18, 'name': 'Ali'}) +users: Cursor = await User.find({'age': 18, 'name': 'Ali'}) or -users: list[User] = await User.find({'age': 18}, name='Ali') +users: Cursor = await User.find({'age': 18}, name='Ali') ``` +#### skip, limit and sort + +`skip()`, `limit()` and `sort()` are also available as chain methods + +```python +users: Cursor = await User.find(age=18, name='Ali').skip(10).limit(10).sort([('age', -1), ('name', 1)]) +``` + +#### cursor + +The `find()` method, returns a `Crusor` depends on the database +- `from panther.db.cursor import Cursor ` for `MongoDB` +- `from pantherdb import Cursor` for `PantherDB` + +You can work with these cursors as list and directly pass them to `Response(data=cursor)` + +They are designed to return an instance of your model in `iterations`(`__next__()`) or accessing by index(`__getitem__()`) + ### all Get all documents from the database. (Alias of `.find()`) ```python -users: list[User] = User.all() +users: Cursor = User.all() ``` ### first @@ -73,19 +91,7 @@ pipeline = [ ... ] users: Iterable[dict] = await User.aggregate(pipeline) -``` - -### find with skip, limit, sort -Get limited documents from the database from offset and sorted by something -> Only available in mongodb - -```python -users: list[User] = await User.find(age=18, name='Ali').limit(10).skip(10).sort('_id', -1) -or -users: list[User] = await User.find({'age': 18, 'name': 'Ali'}).limit(10).skip(10).sort('_id', -1) -or -users: list[User] = await User.find({'age': 18}, name='Ali').limit(10).skip(10).sort('_id', -1) -``` +``` ### count Count the number of documents in this collection. @@ -234,7 +240,7 @@ is_inserted, user = await User.find_one_or_insert({'age': 18}, name='Ali') ### find_one_or_raise - Get a single document from the database] - **or** -- Raise an `APIError(f'{Model} Does Not Exist')` +- Raise an `NotFoundAPIError(f'{Model} Does Not Exist')` ```python from app.models import User @@ -248,6 +254,7 @@ user: User = await User.find_one_or_raise({'age': 18}, name='Ali') ### save Save the document. + - If it has id --> `Update` It - else --> `Insert` It diff --git a/docs/docs/redis.md b/docs/docs/redis.md index 22dade6..eb2a44d 100644 --- a/docs/docs/redis.md +++ b/docs/docs/redis.md @@ -31,7 +31,7 @@ REDIS = { ### How it works? -- Panther creates a redis connection depends on `REDIS` block you defined in `configs` +- Panther creates an async redis connection depends on `REDIS` block you defined in `configs` - You can access to it from `from panther.db.connections import redis` @@ -39,7 +39,7 @@ REDIS = { ```python from panther.db.connections import redis - redis.set('name', 'Ali') - result = redis.get('name') + await redis.set('name', 'Ali') + result = await redis.get('name') print(result) ``` \ No newline at end of file diff --git a/docs/docs/release_notes.md b/docs/docs/release_notes.md index 12df57d..47a49f3 100644 --- a/docs/docs/release_notes.md +++ b/docs/docs/release_notes.md @@ -1,10 +1,15 @@ ### 4.0.0 - Move `database` and `redis` connections from `MIDDLEWARES` to their own block, `DATABASE` and `REDIS` -- Make queries `async` +- Make `Database` queries `async` +- Make `Redis` queries `async` +- Add `StreamingResponse` +- Add `generics` API classes - Add `login()` & `logout()` to `JWTAuthentication` and used it in `BaseUser` - Support `Authentication` & `Authorization` in `Websocket` - Rename all exceptions suffix from `Exception` to `Error` (https://peps.python.org/pep-0008/#exception-names) -- Support `pantherdb 1.4` +- Support `pantherdb 2.0.0` (`Cursor` Added) +- Remove `watchfiles` from required dependencies +- Support `exclude` and `optional_fields` in `ModelSerializer` - Minor Improvements ### 3.9.0 diff --git a/docs/docs/serializer.md b/docs/docs/serializer.md index 7d4852a..5afec4b 100644 --- a/docs/docs/serializer.md +++ b/docs/docs/serializer.md @@ -45,13 +45,21 @@ Use panther `ModelSerializer` to write your serializer which will use your `mode first_name: str = Field(default='', min_length=2) last_name: str = Field(default='', min_length=4) - + # type 1 - using fields class UserModelSerializer(ModelSerializer): class Config: model = User fields = ['username', 'first_name', 'last_name'] required_fields = ['first_name'] + # type 2 - using exclude + class UserModelSerializer(ModelSerializer): + class Config: + model = User + fields = '*' + required_fields = ['first_name'] + exclude = ['id', 'password'] + @API(input_model=UserModelSerializer) async def model_serializer_example(request: Request): @@ -59,9 +67,14 @@ Use panther `ModelSerializer` to write your serializer which will use your `mode ``` ### Notes -1. In the example above, `ModelSerializer` will look up for the value of `Config.fields` in the `User.fields` and use their `type` and `value` for the `validation`. +1. In the example above, `ModelSerializer` will look up for the value of `Config.fields` in the `User.model_fields` and use their `type` and `value` for the `validation`. 2. `Config.model` and `Config.fields` are `required` when you are using `ModelSerializer`. -3. If you want to use `Config.required_fields`, you have to put its value in `Config.fields` too. +3. You can force a field to be required with `Config.required_fields` +4. You can force a field to be optional with `Config.optional_fields` +5. `Config.required_fields` and `Config.optional_fields` can't include same fields +6. If you want to use `Config.required_fields` or `Config.optional_fields` you have to put their value in `Config.fields` too. +7. `Config.fields` & `Config.required_fields` & `Config.optional_fields` can be `*` too (Include all the model fields) +8. `Config.exclude` is mostly used when `Config.fields` is `'*'` @@ -94,8 +107,9 @@ You can use `pydantic.BaseModel` features in `ModelSerializer` too. class Config: model = User - fields = ['username', 'first_name', 'last_name'] + fields = ['first_name', 'last_name'] required_fields = ['first_name'] + optional_fields = ['last_name'] @field_validator('username') def validate_username(cls, username): diff --git a/docs/docs/single_file.md b/docs/docs/single_file.md index 8676c8f..62f2b71 100644 --- a/docs/docs/single_file.md +++ b/docs/docs/single_file.md @@ -42,15 +42,16 @@ If you want to work with `Panther` in a `single-file` structure, follow the step app = Panther(__name__, configs=__name__, urls=url_routing) ``` 4. Run the project - ```bash - panther run - ``` + + - If name of your file is `main.py` --> + ```panther run``` + - else use `uvicorn` --> + ```uvicorn file_name:app``` ### Notes - `URLs` is a required config unless you pass the `urls` directly to the `Panther` - When you pass the `configs` to the `Panther(configs=...)`, Panther is going to load the configs from this file, else it is going to load `core/configs.py` file -- You can pass the `startup` and `shutdown` functions to the `Panther()` too. ```python from panther import Panther @@ -64,11 +65,5 @@ else it is going to load `core/configs.py` file '/': hello_world_api, } - def startup(): - pass - - def shutdown(): - pass - - app = Panther(__name__, configs=__name__, urls=url_routing, startup=startup, shutdown=shutdown) + app = Panther(__name__, configs=__name__, urls=url_routing) ``` \ No newline at end of file diff --git a/docs/docs/urls.md b/docs/docs/urls.md index a497d24..028d42c 100644 --- a/docs/docs/urls.md +++ b/docs/docs/urls.md @@ -14,7 +14,7 @@ - Example: `user//blog//` - The `endpoint` should have parameters with those names too - Example Function-Base: `async def profile_api(user_id: int, title: str):` - - Example Class-Base: `async def get(user_id: int, title: str):` + - Example Class-Base: `async def get(self, user_id: int, title: str):` ## Example diff --git a/docs/docs/websocket.md b/docs/docs/websocket.md index 385bc73..ec21692 100644 --- a/docs/docs/websocket.md +++ b/docs/docs/websocket.md @@ -15,7 +15,7 @@ class BookWebsocket(GenericWebsocket): await self.accept() print(f'{self.connection_id=}') - async def receive(self, data: str | bytes = None): + async def receive(self, data: str | bytes): # Just Echo The Message await self.send(data=data) ``` @@ -64,7 +64,7 @@ you have to use `--preload`, like below: from panther import status await self.close(code=status.WS_1000_NORMAL_CLOSURE, reason='I just want to close it') ``` - - Out of websocket class scope **(Not Recommended)**: You can close it with `close_websocket_connection()` from `panther.websocket`, it's `async` function with takes 3 args, `connection_id`, `code` and `reason`, like below: + - Out of websocket class scope: You can close it with `close_websocket_connection()` from `panther.websocket`, it's `async` function with takes 3 args, `connection_id`, `code` and `reason`, like below: ```python from panther import status from panther.websocket import close_websocket_connection @@ -83,6 +83,5 @@ you have to use `--preload`, like below: '/ws/<user_id>/<room_id>/': UserWebsocket } ``` -12. WebSocket Echo Example -> [Https://GitHub.com/PantherPy/echo_websocket](https://github.com/PantherPy/echo_websocket) -13. Enjoy. +12. Enjoy. diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 10de3bb..67db86a 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -26,9 +26,10 @@ edit_uri: edit/master/docs/ nav: - Introduction: 'index.md' - - First Crud: + - How To CRUD: - Function Base: 'function_first_crud.md' - Class Base: 'class_first_crud.md' + - Generic: 'generic_crud.md' - Database: 'database.md' - Panther ODM: 'panther_odm.md' - Configs: 'configs.md' @@ -37,6 +38,7 @@ nav: - WebSocket: 'websocket.md' - Middlewares: 'middlewares.md' - Authentications: 'authentications.md' + - Events: 'events.md' - URLs: 'urls.md' - Throttling: 'throttling.md' - Background Tasks: 'background_tasks.md' diff --git a/example/app/apis.py b/example/app/apis.py index 140ed70..45dac10 100644 --- a/example/app/apis.py +++ b/example/app/apis.py @@ -16,9 +16,13 @@ from panther.app import API, GenericAPI from panther.authentications import JWTAuthentication from panther.background_tasks import BackgroundTask, background_tasks +from pantherdb import Cursor as PantherDBCursor from panther.db.connections import redis +from panther.db.cursor import Cursor +from panther.generics import ListAPI +from panther.pagination import Pagination from panther.request import Request -from panther.response import HTMLResponse, Response +from panther.response import HTMLResponse, Response, StreamingResponse from panther.throttling import Throttling from panther.websocket import close_websocket_connection, send_message_to_websocket @@ -85,8 +89,8 @@ async def res_request_data_with_output_model(request: Request): @API(input_model=UserInputSerializer) async def using_redis(request: Request): - redis.set('ali', '1') - logger.debug(f"{redis.get('ali') = }") + await redis.set('ali', '1') + logger.debug(f"{await redis.get('ali') = }") return Response() @@ -213,3 +217,30 @@ async def login_api(): @API(auth=True) def logout_api(request: Request): return request.user.logout() + + +def reader(): + from faker import Faker + import time + f = Faker() + for _ in range(5): + name = f.name() + print(f'{name=}') + yield name + time.sleep(1) + + +@API() +def stream_api(): + # Test --> curl http://127.0.0.1:8000/stream/ --no-buffer + return StreamingResponse(reader()) + + +class PaginationAPI(ListAPI): + pagination = Pagination + sort_fields = ['username', 'id'] + filter_fields = ['username'] + search_fields = ['username'] + + async def objects(self, request: Request, **kwargs) -> Cursor | PantherDBCursor: + return await User.find() diff --git a/example/app/urls.py b/example/app/urls.py index b6defc1..7adb05c 100644 --- a/example/app/urls.py +++ b/example/app/urls.py @@ -39,4 +39,6 @@ async def test(*args, **kwargs): 'custom-response/': custom_response_class_api, 'image/': ImageAPI, 'logout/': logout_api, + 'stream/': stream_api, + 'pagination/': PaginationAPI, } diff --git a/example/core/configs.py b/example/core/configs.py index d6ee78e..cff954d 100644 --- a/example/core/configs.py +++ b/example/core/configs.py @@ -50,9 +50,9 @@ DATABASE = { 'engine': { - 'class': 'panther.db.connections.MongoDBConnection', - # 'class': 'panther.db.connections.PantherDBConnection', - 'host': f'mongodb://{DB_HOST}:27017/{DB_NAME}' + # 'class': 'panther.db.connections.MongoDBConnection', + 'class': 'panther.db.connections.PantherDBConnection', + # 'host': f'mongodb://{DB_HOST}:27017/{DB_NAME}' }, # 'query': ..., } @@ -78,3 +78,5 @@ async def shutdown(): SHUTDOWN = 'core.configs.shutdown' AUTO_REFORMAT = False + +TIMEZONE = 'UTC' diff --git a/panther/_load_configs.py b/panther/_load_configs.py index abdfea6..62b3e7d 100644 --- a/panther/_load_configs.py +++ b/panther/_load_configs.py @@ -21,6 +21,7 @@ 'load_redis', 'load_startup', 'load_shutdown', + 'load_timezone', 'load_database', 'load_secret_key', 'load_monitoring', @@ -39,10 +40,10 @@ logger = logging.getLogger('panther') -def load_configs_module(_configs, /) -> dict: +def load_configs_module(module_name: str, /) -> dict: """Read the config file as dict""" - if _configs: - _module = sys.modules[_configs] + if module_name: + _module = sys.modules[module_name] else: try: _module = import_module('core.configs') @@ -55,10 +56,10 @@ def load_redis(_configs: dict, /) -> None: if redis_config := _configs.get('REDIS'): # Check redis module installation try: - from redis import Redis as _Redis - except ModuleNotFoundError as e: + from redis.asyncio import Redis + except ImportError as e: raise import_error(e, package='redis') - redis_class_path = redis_config.get('class', 'panther.db.connections.Redis') + redis_class_path = redis_config.get('class', 'panther.db.connections.RedisConnection') redis_class = import_class(redis_class_path) # We have to create another dict then pop the 'class' else we can't pass the tests args = redis_config.copy() @@ -68,12 +69,17 @@ def load_redis(_configs: dict, /) -> None: def load_startup(_configs: dict, /) -> None: if startup := _configs.get('STARTUP'): - config['startup'] = import_class(startup) + config.STARTUP = import_class(startup) def load_shutdown(_configs: dict, /) -> None: if shutdown := _configs.get('SHUTDOWN'): - config['shutdown'] = import_class(shutdown) + config.SHUTDOWN = import_class(shutdown) + + +def load_timezone(_configs: dict, /) -> None: + if timezone := _configs.get('TIMEZONE'): + config.TIMEZONE = timezone def load_database(_configs: dict, /) -> None: @@ -87,42 +93,42 @@ def load_database(_configs: dict, /) -> None: # We have to create another dict then pop the 'class' else we can't pass the tests args = database_config['engine'].copy() args.pop('class') - config['database'] = engine_class(**args) + config.DATABASE = engine_class(**args) if engine_class_path == 'panther.db.connections.PantherDBConnection': - config['query_engine'] = BasePantherDBQuery + config.QUERY_ENGINE = BasePantherDBQuery elif engine_class_path == 'panther.db.connections.MongoDBConnection': - config['query_engine'] = BaseMongoDBQuery + config.QUERY_ENGINE = BaseMongoDBQuery if 'query' in database_config: - if config['query_engine']: + if config.QUERY_ENGINE: logger.warning('`DATABASE.query` has already been filled.') - config['query_engine'] = import_class(database_config['query']) + config.QUERY_ENGINE = import_class(database_config['query']) def load_secret_key(_configs: dict, /) -> None: if secret_key := _configs.get('SECRET_KEY'): - config['secret_key'] = secret_key.encode() + config.SECRET_KEY = secret_key.encode() def load_monitoring(_configs: dict, /) -> None: if _configs.get('MONITORING'): - config['monitoring'] = True + config.MONITORING = True def load_throttling(_configs: dict, /) -> None: if throttling := _configs.get('THROTTLING'): - config['throttling'] = throttling + config.THROTTLING = throttling def load_user_model(_configs: dict, /) -> None: - config['user_model'] = import_class(_configs.get('USER_MODEL', 'panther.db.models.BaseUser')) - config['models'].append(config['user_model']) + config.USER_MODEL = import_class(_configs.get('USER_MODEL', 'panther.db.models.BaseUser')) + config.MODELS.append(config.USER_MODEL) def load_log_queries(_configs: dict, /) -> None: if _configs.get('LOG_QUERIES'): - config['log_queries'] = True + config.LOG_QUERIES = True def load_middlewares(_configs: dict, /) -> None: @@ -153,42 +159,39 @@ def load_middlewares(_configs: dict, /) -> None: if issubclass(middleware_class, BaseMiddleware) is False: raise _exception_handler(field='MIDDLEWARES', error='is not a sub class of BaseMiddleware') - middleware_instance = middleware_class(**data) - if isinstance(middleware_instance, BaseMiddleware | HTTPMiddleware): - middlewares['http'].append(middleware_instance) + if middleware_class.__bases__[0] in (BaseMiddleware, HTTPMiddleware): + middlewares['http'].append((middleware_class, data)) - if isinstance(middleware_instance, BaseMiddleware | WebsocketMiddleware): - middlewares['ws'].append(middleware_instance) + if middleware_class.__bases__[0] in (BaseMiddleware, WebsocketMiddleware): + middlewares['ws'].append((middleware_class, data)) - config['http_middlewares'] = middlewares['http'] - config['ws_middlewares'] = middlewares['ws'] - config['reversed_http_middlewares'] = middlewares['http'][::-1] - config['reversed_ws_middlewares'] = middlewares['ws'][::-1] + config.HTTP_MIDDLEWARES = middlewares['http'] + config.WS_MIDDLEWARES = middlewares['ws'] def load_auto_reformat(_configs: dict, /) -> None: if _configs.get('AUTO_REFORMAT'): - config['auto_reformat'] = True + config.AUTO_REFORMAT = True def load_background_tasks(_configs: dict, /) -> None: if _configs.get('BACKGROUND_TASKS'): - config['background_tasks'] = True + config.BACKGROUND_TASKS = True background_tasks.initialize() def load_default_cache_exp(_configs: dict, /) -> None: if default_cache_exp := _configs.get('DEFAULT_CACHE_EXP'): - config['default_cache_exp'] = default_cache_exp + config.DEFAULT_CACHE_EXP = default_cache_exp def load_authentication_class(_configs: dict, /) -> None: """Should be after `load_secret_key()`""" if authentication := _configs.get('AUTHENTICATION'): - config['authentication'] = import_class(authentication) + config.AUTHENTICATION = import_class(authentication) if ws_authentication := _configs.get('WS_AUTHENTICATION'): - config['ws_authentication'] = import_class(ws_authentication) + config.WS_AUTHENTICATION = import_class(ws_authentication) load_jwt_config(_configs) @@ -196,16 +199,16 @@ def load_authentication_class(_configs: dict, /) -> None: def load_jwt_config(_configs: dict, /) -> None: """Only Collect JWT Config If Authentication Is JWTAuthentication""" auth_is_jwt = ( - getattr(config['authentication'], '__name__', None) == 'JWTAuthentication' or - getattr(config['ws_authentication'], '__name__', None) == 'QueryParamJWTAuthentication' + getattr(config.AUTHENTICATION, '__name__', None) == 'JWTAuthentication' or + getattr(config.WS_AUTHENTICATION, '__name__', None) == 'QueryParamJWTAuthentication' ) jwt = _configs.get('JWTConfig', {}) if auth_is_jwt or jwt: if 'key' not in jwt: - if config['secret_key'] is None: + if config.SECRET_KEY is None: raise _exception_handler(field='JWTConfig', error='`JWTConfig.key` or `SECRET_KEY` is required.') - jwt['key'] = config['secret_key'].decode() - config['jwt_config'] = JWTConfig(**jwt) + jwt['key'] = config.SECRET_KEY.decode() + config.JWT_CONFIG = JWTConfig(**jwt) def load_urls(_configs: dict, /, urls: dict | None) -> None: @@ -237,14 +240,14 @@ def load_urls(_configs: dict, /, urls: dict | None) -> None: if not isinstance(urls, dict): raise _exception_handler(field='URLs', error='should point to a dict.') - config['flat_urls'] = flatten_urls(urls) - config['urls'] = finalize_urls(config['flat_urls']) - config['urls']['_panel'] = finalize_urls(flatten_urls(panel_urls)) + config.FLAT_URLS = flatten_urls(urls) + config.URLS = finalize_urls(config.FLAT_URLS) + config.URLS['_panel'] = finalize_urls(flatten_urls(panel_urls)) def load_websocket_connections(): """Should be after `load_redis()`""" - if config['has_ws']: + if config.HAS_WS: # Check `websockets` try: import websockets @@ -253,7 +256,7 @@ def load_websocket_connections(): # Use the redis pubsub if `redis.is_connected`, else use the `multiprocessing.Manager` pubsub_connection = redis.create_connection_for_websocket() if redis.is_connected else Manager() - config['websocket_connections'] = WebsocketConnections(pubsub_connection=pubsub_connection) + config.WEBSOCKET_CONNECTIONS = WebsocketConnections(pubsub_connection=pubsub_connection) def _exception_handler(field: str, error: str | Exception) -> PantherError: diff --git a/panther/_utils.py b/panther/_utils.py index e1371ba..e7fffcb 100644 --- a/panther/_utils.py +++ b/panther/_utils.py @@ -1,57 +1,20 @@ +import asyncio import importlib import logging import re import subprocess import types -import typing +from typing import Any, Generator, Iterator, AsyncGenerator from collections.abc import Callable from traceback import TracebackException -import orjson as json - -from panther import status from panther.exceptions import PantherError from panther.file_handler import File logger = logging.getLogger('panther') -async def _http_response_start(send: Callable, /, headers: dict, status_code: int) -> None: - bytes_headers = [[k.encode(), str(v).encode()] for k, v in (headers or {}).items()] - await send({ - 'type': 'http.response.start', - 'status': status_code, - 'headers': bytes_headers, - }) - - -async def _http_response_body(send: Callable, /, body: bytes) -> None: - if body: - await send({'type': 'http.response.body', 'body': body}) - else: # body = b'' - await send({'type': 'http.response.body'}) - - -async def http_response( - send: Callable, - /, - *, - monitoring, # type: MonitoringMiddleware - status_code: int, - headers: dict, - body: bytes = b'', - exception: bool = False, -) -> None: - if exception: - body = json.dumps({'detail': status.status_text[status_code]}) - - await monitoring.after(status_code) - - await _http_response_start(send, headers=headers, status_code=status_code) - await _http_response_body(send, body=body) - - -def import_class(dotted_path: str, /) -> type[typing.Any]: +def import_class(dotted_path: str, /) -> type[Any]: """ Example: ------- @@ -148,4 +111,25 @@ def check_class_type_endpoint(endpoint: Callable) -> Callable: logger.critical(f'You may have forgotten to inherit from GenericAPI on the {endpoint.__name__}()') raise TypeError - return endpoint.call_method + return endpoint().call_method + + +def async_next(iterator: Iterator): + """ + The StopIteration exception is a special case in Python, + particularly when it comes to asynchronous programming and the use of asyncio. + This is because StopIteration is not meant to be caught in the traditional sense; + it's used internally by Python to signal the end of an iteration. + """ + try: + return next(iterator) + except StopIteration: + raise StopAsyncIteration + + +async def to_async_generator(generator: Generator) -> AsyncGenerator: + while True: + try: + yield await asyncio.to_thread(async_next, iter(generator)) + except StopAsyncIteration: + break diff --git a/panther/app.py b/panther/app.py index 2f4282f..8dff292 100644 --- a/panther/app.py +++ b/panther/app.py @@ -1,13 +1,18 @@ import functools import logging -from datetime import datetime, timedelta +from datetime import timedelta from typing import Literal from orjson import JSONDecodeError -from pydantic import ValidationError +from pydantic import ValidationError, BaseModel from panther._utils import is_function_async -from panther.caching import cache_key, get_cached_response_data, set_cache_response +from panther.caching import ( + get_response_from_cache, + set_response_in_cache, + get_throttling_from_cache, + increment_throttling_in_cache +) from panther.configs import config from panther.exceptions import ( APIError, @@ -19,8 +24,8 @@ ) from panther.request import Request from panther.response import Response -from panther.throttling import Throttling, throttling_storage -from panther.utils import round_datetime +from panther.serializer import ModelSerializer +from panther.throttling import Throttling __all__ = ('API', 'GenericAPI') @@ -31,11 +36,11 @@ class API: def __init__( self, *, - input_model=None, - output_model=None, + input_model: type[ModelSerializer] | type[BaseModel] | None = None, + output_model: type[ModelSerializer] | type[BaseModel] | None = None, auth: bool = False, permissions: list | None = None, - throttling: Throttling = None, + throttling: Throttling | None = None, cache: bool = False, cache_exp_time: timedelta | int | None = None, methods: list[Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']] | None = None, @@ -53,20 +58,20 @@ def __init__( def __call__(self, func): @functools.wraps(func) async def wrapper(request: Request) -> Response: - self.request: Request = request # noqa: Non-self attribute could not be type hinted + self.request = request # 1. Check Method if self.methods and self.request.method not in self.methods: raise MethodNotAllowedAPIError # 2. Authentication - await self.handle_authentications() + await self.handle_authentication() - # 3. Throttling - self.handle_throttling() + # 3. Permissions + await self.handle_permission() - # 4. Permissions - await self.handle_permissions() + # 4. Throttling + await self.handle_throttling() # 5. Validate Input if self.request.method in ['POST', 'PUT', 'PATCH']: @@ -74,7 +79,7 @@ async def wrapper(request: Request) -> Response: # 6. Get Cached Response if self.cache and self.request.method == 'GET': - if cached := get_cached_response_data(request=self.request, cache_exp_time=self.cache_exp_time): + if cached := await get_response_from_cache(request=self.request, cache_exp_time=self.cache_exp_time): return Response(data=cached.data, status_code=cached.status_code) # 7. Put PathVariables and Request(If User Wants It) In kwargs @@ -89,11 +94,12 @@ async def wrapper(request: Request) -> Response: # 9. Clean Response if not isinstance(response, Response): response = Response(data=response) - response._clean_data_with_output_model(output_model=self.output_model) # noqa: SLF001 + if self.output_model and response.data: + response.data = response.apply_output_model(response.data, output_model=self.output_model) # 10. Set New Response To Cache if self.cache and self.request.method == 'GET': - set_cache_response( + await set_response_in_cache( request=self.request, response=response, cache_exp_time=self.cache_exp_time @@ -107,26 +113,21 @@ async def wrapper(request: Request) -> Response: return wrapper - async def handle_authentications(self) -> None: - auth_class = config['authentication'] + async def handle_authentication(self) -> None: if self.auth: - if not auth_class: + if not config.AUTHENTICATION: logger.critical('"AUTHENTICATION" has not been set in configs') raise APIError - user = await auth_class.authentication(self.request) - self.request.user = user - - def handle_throttling(self) -> None: - if throttling := self.throttling or config['throttling']: - key = cache_key(self.request) - time = round_datetime(datetime.now(), throttling.duration) - throttling_key = f'{time}-{key}' - if throttling_storage[throttling_key] > throttling.rate: + self.request.user = await config.AUTHENTICATION.authentication(self.request) + + async def handle_throttling(self) -> None: + if throttling := self.throttling or config.THROTTLING: + if await get_throttling_from_cache(self.request, duration=throttling.duration) + 1 > throttling.rate: raise ThrottlingAPIError - throttling_storage[throttling_key] += 1 + await increment_throttling_in_cache(self.request, duration=throttling.duration) - async def handle_permissions(self) -> None: + async def handle_permission(self) -> None: for perm in self.permissions: if type(perm.authorization).__name__ != 'method': logger.error(f'{perm.__name__}.authorization should be "classmethod"') @@ -136,15 +137,17 @@ async def handle_permissions(self) -> None: def handle_input_validation(self): if self.input_model: - validated_data = self.validate_input(model=self.input_model, request=self.request) - self.request.set_validated_data(validated_data) + self.request.validated_data = self.validate_input(model=self.input_model, request=self.request) @classmethod def validate_input(cls, model, request: Request): + if isinstance(request.data, bytes): + raise BadRequestAPIError(detail='Content-Type is not valid') + if request.data is None: + raise BadRequestAPIError(detail='Request body is required') try: - if isinstance(request.data, bytes): - raise BadRequestAPIError(detail='Content-Type is not valid') - return model(**request.data) + # `request` will be ignored in regular `BaseModel` + return model(**request.data, request=request) except ValidationError as validation_error: error = {'.'.join(loc for loc in e['loc']): e['msg'] for e in validation_error.errors()} raise BadRequestAPIError(detail=error) @@ -153,8 +156,8 @@ def validate_input(cls, model, request: Request): class GenericAPI: - input_model = None - output_model = None + input_model: type[ModelSerializer] | type[BaseModel] = None + output_model: type[ModelSerializer] | type[BaseModel] = None auth: bool = False permissions: list | None = None throttling: Throttling | None = None @@ -176,26 +179,27 @@ async def patch(self, *args, **kwargs): async def delete(self, *args, **kwargs): raise MethodNotAllowedAPIError - @classmethod - async def call_method(cls, *args, **kwargs): - match kwargs['request'].method: + async def call_method(self, request: Request): + match request.method: case 'GET': - func = cls().get + func = self.get case 'POST': - func = cls().post + func = self.post case 'PUT': - func = cls().put + func = self.put case 'PATCH': - func = cls().patch + func = self.patch case 'DELETE': - func = cls().delete + func = self.delete + case _: + raise MethodNotAllowedAPIError return await API( - input_model=cls.input_model, - output_model=cls.output_model, - auth=cls.auth, - permissions=cls.permissions, - throttling=cls.throttling, - cache=cls.cache, - cache_exp_time=cls.cache_exp_time, - )(func)(*args, **kwargs) + input_model=self.input_model, + output_model=self.output_model, + auth=self.auth, + permissions=self.permissions, + throttling=self.throttling, + cache=self.cache, + cache_exp_time=self.cache_exp_time, + )(func)(request=request) diff --git a/panther/authentications.py b/panther/authentications.py index 1861714..27454ea 100644 --- a/panther/authentications.py +++ b/panther/authentications.py @@ -67,7 +67,7 @@ async def authentication(cls, request: Request | Websocket) -> Model: msg = 'Authorization keyword is not valid' raise cls.exception(msg) from None - if redis.is_connected and cls._check_in_cache(token=token): + if redis.is_connected and await cls._check_in_cache(token=token): msg = 'User logged out' raise cls.exception(msg) from None @@ -82,8 +82,8 @@ def decode_jwt(cls, token: str) -> dict: try: return jwt.decode( token=token, - key=config['jwt_config'].key, - algorithms=[config['jwt_config'].algorithm], + key=config.JWT_CONFIG.key, + algorithms=[config.JWT_CONFIG.algorithm], ) except JWTError as e: raise cls.exception(e) from None @@ -95,7 +95,7 @@ async def get_user(cls, payload: dict) -> Model: msg = 'Payload does not have `user_id`' raise cls.exception(msg) - user_model = config['user_model'] or cls.model + user_model = config.USER_MODEL or cls.model if user := await user_model.find_one(id=user_id): return user @@ -107,9 +107,9 @@ def encode_jwt(cls, user_id: str, token_type: Literal['access', 'refresh'] = 'ac """Encode JWT from user_id.""" issued_at = datetime.now(timezone.utc).timestamp() if token_type == 'access': - expire = issued_at + config['jwt_config'].life_time + expire = issued_at + config.JWT_CONFIG.life_time else: - expire = issued_at + config['jwt_config'].refresh_life_time + expire = issued_at + config.JWT_CONFIG.refresh_life_time claims = { 'token_type': token_type, @@ -119,8 +119,8 @@ def encode_jwt(cls, user_id: str, token_type: Literal['access', 'refresh'] = 'ac } return jwt.encode( claims, - key=config['jwt_config'].key, - algorithm=config['jwt_config'].algorithm, + key=config.JWT_CONFIG.key, + algorithm=config.JWT_CONFIG.algorithm, ) @classmethod @@ -132,24 +132,24 @@ def login(cls, user_id: str) -> dict: } @classmethod - def logout(cls, raw_token: str) -> None: + async def logout(cls, raw_token: str) -> None: *_, token = raw_token.split() if redis.is_connected: payload = cls.decode_jwt(token=token) remaining_exp_time = payload['exp'] - time.time() - cls._set_in_cache(token=token, exp=int(remaining_exp_time)) + await cls._set_in_cache(token=token, exp=int(remaining_exp_time)) else: logger.error('`redis` middleware is required for `logout()`') @classmethod - def _set_in_cache(cls, token: str, exp: int) -> None: + async def _set_in_cache(cls, token: str, exp: int) -> None: key = generate_hash_value_from_string(token) - redis.set(key, b'', ex=exp) + await redis.set(key, b'', ex=exp) @classmethod - def _check_in_cache(cls, token: str) -> bool: + async def _check_in_cache(cls, token: str) -> bool: key = generate_hash_value_from_string(token) - return bool(redis.exists(key)) + return bool(await redis.exists(key)) @staticmethod def exception(message: str | JWTError | UnicodeEncodeError, /) -> type[AuthenticationAPIError]: diff --git a/panther/background_tasks.py b/panther/background_tasks.py index c925375..9165b04 100644 --- a/panther/background_tasks.py +++ b/panther/background_tasks.py @@ -9,7 +9,6 @@ from panther._utils import is_function_async from panther.utils import Singleton - __all__ = ( 'BackgroundTask', 'background_tasks', diff --git a/panther/base_request.py b/panther/base_request.py index 314e44a..5e012f7 100644 --- a/panther/base_request.py +++ b/panther/base_request.py @@ -60,8 +60,6 @@ def __init__(self, scope: dict, receive: Callable, send: Callable): self.scope = scope self.asgi_send = send self.asgi_receive = receive - self._data = ... - self._validated_data = None self._headers: Headers | None = None self._params: dict | None = None self.user: Model | None = None @@ -116,25 +114,22 @@ def collect_path_variables(self, found_path: str): } def clean_parameters(self, func: Callable) -> dict: - kwargs = {} + kwargs = self.path_variables.copy() + for variable_name, variable_type in func.__annotations__.items(): # Put Request/ Websocket In kwargs (If User Wants It) if issubclass(variable_type, BaseRequest): kwargs[variable_name] = self - continue - - for name, value in self.path_variables.items(): - if name == variable_name: - # Check the type and convert the value - if variable_type is bool: - kwargs[name] = value.lower() not in ['false', '0'] - - elif variable_type is int: - try: - kwargs[name] = int(value) - except ValueError: - raise InvalidPathVariableAPIError(value=value, variable_type=variable_type) - else: - kwargs[name] = value - return kwargs + elif variable_name in kwargs: + # Cast To Boolean + if variable_type is bool: + kwargs[variable_name] = kwargs[variable_name].lower() not in ['false', '0'] + + # Cast To Int + elif variable_type is int: + try: + kwargs[variable_name] = int(kwargs[variable_name]) + except ValueError: + raise InvalidPathVariableAPIError(value=kwargs[variable_name], variable_type=variable_type) + return kwargs diff --git a/panther/base_websocket.py b/panther/base_websocket.py index 7da3b74..5ba79e6 100644 --- a/panther/base_websocket.py +++ b/panther/base_websocket.py @@ -1,11 +1,8 @@ from __future__ import annotations import asyncio -import contextlib import logging -from multiprocessing import Manager from multiprocessing.managers import SyncManager -from threading import Thread from typing import TYPE_CHECKING, Literal import orjson as json @@ -15,16 +12,17 @@ from panther.configs import config from panther.db.connections import redis from panther.exceptions import AuthenticationAPIError, InvalidPathVariableAPIError +from panther.monitoring import Monitoring from panther.utils import Singleton, ULID if TYPE_CHECKING: - from redis import Redis + from redis.asyncio import Redis logger = logging.getLogger('panther') class PubSub: - def __init__(self, manager): + def __init__(self, manager: SyncManager): self._manager = manager self._subscribers = self._manager.list() @@ -38,17 +36,8 @@ def publish(self, msg): queue.put(msg) -class WebsocketListener(Thread): - def __init__(self): - super().__init__(target=config['websocket_connections'], daemon=True) - - def run(self): - with contextlib.suppress(Exception): - super().run() - - class WebsocketConnections(Singleton): - def __init__(self, pubsub_connection: Redis | Manager): + def __init__(self, pubsub_connection: Redis | SyncManager): self.connections = {} self.connections_count = 0 self.pubsub_connection = pubsub_connection @@ -56,20 +45,21 @@ def __init__(self, pubsub_connection: Redis | Manager): if isinstance(self.pubsub_connection, SyncManager): self.pubsub = PubSub(manager=self.pubsub_connection) - def __call__(self): + async def __call__(self): if isinstance(self.pubsub_connection, SyncManager): - # We don't have redis connection, so use the `multiprocessing.PubSub` + # We don't have redis connection, so use the `multiprocessing.Manager` + self.pubsub: PubSub queue = self.pubsub.subscribe() logger.info("Subscribed to 'websocket_connections' queue") while True: received_message = queue.get() - self._handle_received_message(received_message=received_message) + await self._handle_received_message(received_message=received_message) else: # We have a redis connection, so use it for pubsub self.pubsub = self.pubsub_connection.pubsub() - self.pubsub.subscribe('websocket_connections') + await self.pubsub.subscribe('websocket_connections') logger.info("Subscribed to 'websocket_connections' channel") - for channel_data in self.pubsub.listen(): + async for channel_data in self.pubsub.listen(): match channel_data['type']: # Subscribed case 'subscribe': @@ -78,12 +68,12 @@ def __call__(self): # Message Received case 'message': loaded_data = json.loads(channel_data['data'].decode()) - self._handle_received_message(received_message=loaded_data) + await self._handle_received_message(received_message=loaded_data) case unknown_type: - logger.debug(f'Unknown Channel Type: {unknown_type}') + logger.error(f'Unknown Channel Type: {unknown_type}') - def _handle_received_message(self, received_message): + async def _handle_received_message(self, received_message): if ( isinstance(received_message, dict) and (connection_id := received_message.get('connection_id')) @@ -94,120 +84,156 @@ def _handle_received_message(self, received_message): # Check Action of WS match received_message['action']: case 'send': - asyncio.run(self.connections[connection_id].send(data=received_message['data'])) + await self.connections[connection_id].send(data=received_message['data']) case 'close': - with contextlib.suppress(RuntimeError): - asyncio.run(self.connections[connection_id].close( - code=received_message['data']['code'], - reason=received_message['data']['reason'] - )) - # We are trying to disconnect the connection between a thread and a user - # from another thread, it's working, but we have to find another solution for it - # - # Error: - # Task <Task pending coro=<Websocket.close()>> got Future - # <Task pending coro=<WebSocketCommonProtocol.transfer_data()>> - # attached to a different loop + await self.connections[connection_id].close( + code=received_message['data']['code'], + reason=received_message['data']['reason'] + ) case unknown_action: - logger.debug(f'Unknown Message Action: {unknown_action}') + logger.error(f'Unknown Message Action: {unknown_action}') - def publish(self, connection_id: str, action: Literal['send', 'close'], data: any): + async def publish(self, connection_id: str, action: Literal['send', 'close'], data: any): publish_data = {'connection_id': connection_id, 'action': action, 'data': data} if redis.is_connected: - redis.publish('websocket_connections', json.dumps(publish_data)) + await redis.publish('websocket_connections', json.dumps(publish_data)) else: self.pubsub.publish(publish_data) - async def new_connection(self, connection: Websocket) -> None: + async def listen(self, connection: Websocket) -> None: # 1. Authentication - connection_closed = await self.handle_authentication(connection=connection) + if not connection.is_rejected: + await self.handle_authentication(connection=connection) # 2. Permissions - connection_closed = connection_closed or await self.handle_permissions(connection=connection) + if not connection.is_rejected: + await self.handle_permissions(connection=connection) - if connection_closed: - # Don't run the following code... + if connection.is_rejected: + # Connection is rejected so don't continue the flow ... return None # 3. Put PathVariables and Request(If User Wants It) In kwargs try: kwargs = connection.clean_parameters(connection.connect) except InvalidPathVariableAPIError as e: - return await connection.close(status.WS_1000_NORMAL_CLOSURE, reason=str(e)) + connection.log(e.detail) + return await connection.close() # 4. Connect To Endpoint await connection.connect(**kwargs) - if not hasattr(connection, '_connection_id'): - # User didn't even call the `self.accept()` so close the connection - await connection.close() + # 5. Check Connection + if not connection.is_connected and not connection.is_rejected: + # User didn't call the `self.accept()` or `self.close()` so we `close()` the connection (reject) + return await connection.close() - # 5. Connection Accepted - if connection.is_connected: - self.connections_count += 1 + # 6. Listen Connection + await self.listen_connection(connection=connection) + + async def listen_connection(self, connection: Websocket): + while True: + response = await connection.asgi_receive() + if response['type'] == 'websocket.connect': + continue + + if response['type'] == 'websocket.disconnect': + # Connect has to be closed by the client + await self.connection_closed(connection=connection) + break - # Save New ConnectionID - self.connections[connection.connection_id] = connection + if 'text' in response: + await connection.receive(data=response['text']) + else: + await connection.receive(data=response['bytes']) - def remove_connection(self, connection: Websocket) -> None: + async def connection_accepted(self, connection: Websocket) -> None: + # Generate ConnectionID + connection._connection_id = ULID.new() + + # Save Connection + self.connections[connection.connection_id] = connection + + # Logs + await connection.monitoring.after('Accepted') + connection.log(f'Accepted {connection.connection_id}') + + async def connection_closed(self, connection: Websocket, from_server: bool = False) -> None: if connection.is_connected: - self.connections_count -= 1 del self.connections[connection.connection_id] + await connection.monitoring.after('Closed') + connection.log(f'Closed {connection.connection_id}') + connection._connection_id = '' + + elif connection.is_rejected is False and from_server is True: + await connection.monitoring.after('Rejected') + connection.log('Rejected') + connection._is_rejected = True + + async def start(self): + """ + Start Websocket Listener (Redis/ Queue) + + Cause of --preload in gunicorn we have to keep this function here, + and we can't move it to __init__ of Panther + + * Each process should start this listener for itself, + but they have same Manager() + """ + + if config.HAS_WS: + # Schedule the async function to run in the background, + # We don't need to await for this task + asyncio.create_task(self()) @classmethod - async def handle_authentication(cls, connection: Websocket) -> bool: + async def handle_authentication(cls, connection: Websocket): """Return True if connection is closed, False otherwise.""" if connection.auth: - if not config.ws_authentication: - logger.critical('"WS_AUTHENTICATION" has not been set in configs') - await connection.close(reason='Authentication Error') - return True - try: - connection.user = await config.ws_authentication.authentication(connection) - except AuthenticationAPIError as e: - await connection.close(reason=e.detail) - return False + if not config.WS_AUTHENTICATION: + logger.critical('`WS_AUTHENTICATION` has not been set in configs') + await connection.close() + else: + try: + connection.user = await config.WS_AUTHENTICATION.authentication(connection) + except AuthenticationAPIError as e: + connection.log(e.detail) + await connection.close() @classmethod - async def handle_permissions(cls, connection: Websocket) -> bool: + async def handle_permissions(cls, connection: Websocket): """Return True if connection is closed, False otherwise.""" for perm in connection.permissions: if type(perm.authorization).__name__ != 'method': - logger.error(f'{perm.__name__}.authorization should be "classmethod"') - await connection.close(reason='Permission Denied') - return True - if await perm.authorization(connection) is False: - await connection.close(reason='Permission Denied') - return True - return False + logger.critical(f'{perm.__name__}.authorization should be "classmethod"') + await connection.close() + elif await perm.authorization(connection) is False: + connection.log('Permission Denied') + await connection.close() class Websocket(BaseRequest): - is_connected: bool = False auth: bool = False permissions: list = [] + _connection_id: str = '' + _is_rejected: bool = False + _monitoring: Monitoring def __init_subclass__(cls, **kwargs): if cls.__module__ != 'panther.websocket': - config['has_ws'] = True + config.HAS_WS = True async def connect(self, **kwargs) -> None: - """Check your conditions then self.accept() the connection""" - await self.accept() - - async def accept(self, subprotocol: str | None = None, headers: dict | None = None) -> None: - await self.asgi_send({'type': 'websocket.accept', 'subprotocol': subprotocol, 'headers': headers or {}}) - self.is_connected = True - - # Generate ConnectionID - self._connection_id = ULID.new() - - logger.debug(f'Accepting WS Connection {self._connection_id}') + pass async def receive(self, data: str | bytes) -> None: pass + async def accept(self, subprotocol: str | None = None, headers: dict | None = None) -> None: + await self.asgi_send({'type': 'websocket.accept', 'subprotocol': subprotocol, 'headers': headers or {}}) + await config.WEBSOCKET_CONNECTIONS.connection_accepted(connection=self) + async def send(self, data: any = None) -> None: logger.debug(f'Sending WS Message to {self.connection_id}') if data: @@ -225,32 +251,26 @@ async def send_bytes(self, bytes_data: bytes) -> None: await self.asgi_send({'type': 'websocket.send', 'bytes': bytes_data}) async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = '') -> None: - connection_id = getattr(self, '_connection_id', '') - logger.debug(f'Closing WS Connection {connection_id} Code: {code}') - self.is_connected = False - config['websocket_connections'].remove_connection(self) await self.asgi_send({'type': 'websocket.close', 'code': code, 'reason': reason}) + await config.WEBSOCKET_CONNECTIONS.connection_closed(connection=self, from_server=True) - async def listen(self) -> None: - while self.is_connected: - response = await self.asgi_receive() - if response['type'] == 'websocket.connect': - continue - - if response['type'] == 'websocket.disconnect': - break + @property + def connection_id(self) -> str: + if self.is_connected: + return self._connection_id + logger.error('You should first `self.accept()` the connection then use the `self.connection_id`') - if 'text' in response: - await self.receive(data=response['text']) - else: - await self.receive(data=response['bytes']) + @property + def is_connected(self) -> bool: + return bool(self._connection_id) - def set_connection_id(self, connection_id: str) -> None: - self._connection_id = connection_id + @property + def is_rejected(self) -> bool: + return self._is_rejected @property - def connection_id(self) -> str: - connection_id = getattr(self, '_connection_id', None) - if connection_id is None: - logger.error('You should first `self.accept()` the connection then use the `self.connection_id`') - return connection_id + def monitoring(self) -> Monitoring: + return self._monitoring + + def log(self, message: str): + logger.debug(f'WS {self.path} --> {message}') diff --git a/panther/caching.py b/panther/caching.py index 5738679..b8c5db9 100644 --- a/panther/caching.py +++ b/panther/caching.py @@ -9,51 +9,53 @@ from panther.db.connections import redis from panther.request import Request from panther.response import Response, ResponseDataTypes +from panther.throttling import throttling_storage from panther.utils import generate_hash_value_from_string, round_datetime logger = logging.getLogger('panther') caches = {} -CachedResponse = namedtuple('Cached', ['data', 'status_code']) +CachedResponse = namedtuple('CachedResponse', ['data', 'status_code']) -def cache_key(request: Request, /) -> str: +def api_cache_key(request: Request, cache_exp_time: timedelta | None = None) -> str: client = request.user and request.user.id or request.client.ip query_params_hash = generate_hash_value_from_string(request.scope['query_string'].decode('utf-8')) - return f'{client}-{request.path}-{query_params_hash}-{request.validated_data}' + key = f'{client}-{request.path}-{query_params_hash}-{request.validated_data}' - -def local_cache_key(*, request: Request, cache_exp_time: timedelta | None = None) -> str: - key = cache_key(request) if cache_exp_time: time = round_datetime(datetime.now(), cache_exp_time) return f'{time}-{key}' - else: - return key + + return key + + +def throttling_cache_key(request: Request, duration: timedelta) -> str: + client = request.user and request.user.id or request.client.ip + time = round_datetime(datetime.now(), duration) + return f'{time}-{client}-{request.path}' -def get_cached_response_data(*, request: Request, cache_exp_time: timedelta) -> CachedResponse | None: +async def get_response_from_cache(*, request: Request, cache_exp_time: timedelta) -> CachedResponse | None: """ If redis.is_connected: Get Cached Data From Redis else: Get Cached Data From Memory """ - if redis.is_connected: # noqa: Unresolved References - key = cache_key(request) - data = (redis.get(key) or b'{}').decode() + if redis.is_connected: + key = api_cache_key(request=request) + data = (await redis.get(key) or b'{}').decode() if cached_value := json.loads(data): return CachedResponse(*cached_value) else: - key = local_cache_key(request=request, cache_exp_time=cache_exp_time) + key = api_cache_key(request=request, cache_exp_time=cache_exp_time) if cached_value := caches.get(key): return CachedResponse(*cached_value) - return None - -def set_cache_response(*, request: Request, response: Response, cache_exp_time: timedelta | int) -> None: +async def set_response_in_cache(*, request: Request, response: Response, cache_exp_time: timedelta | int) -> None: """ If redis.is_connected: Cache The Data In Redis @@ -64,9 +66,9 @@ def set_cache_response(*, request: Request, response: Response, cache_exp_time: cache_data: tuple[ResponseDataTypes, int] = (response.data, response.status_code) if redis.is_connected: - key = cache_key(request) + key = api_cache_key(request=request) - cache_exp_time = cache_exp_time or config['default_cache_exp'] + cache_exp_time = cache_exp_time or config.DEFAULT_CACHE_EXP cache_data: bytes = json.dumps(cache_data) if not isinstance(cache_exp_time, timedelta | int | NoneType): @@ -76,15 +78,48 @@ def set_cache_response(*, request: Request, response: Response, cache_exp_time: if cache_exp_time is None: logger.warning( 'your response are going to cache in redis forever ' - '** set DEFAULT_CACHE_EXP in configs or pass the cache_exp_time in @API.get() for prevent this **' + '** set DEFAULT_CACHE_EXP in `configs` or set the `cache_exp_time` in `@API.get()` to prevent this **' ) - redis.set(key, cache_data) + await redis.set(key, cache_data) else: - redis.set(key, cache_data, ex=cache_exp_time) + await redis.set(key, cache_data, ex=cache_exp_time) else: - key = local_cache_key(request=request, cache_exp_time=cache_exp_time) + key = api_cache_key(request=request, cache_exp_time=cache_exp_time) caches[key] = cache_data if cache_exp_time: - logger.info('"cache_exp_time" is not very accurate when redis is not connected.') + logger.info('`cache_exp_time` is not very accurate when `redis` is not connected.') + + +async def get_throttling_from_cache(request: Request, duration: timedelta) -> int: + """ + If redis.is_connected: + Get Cached Data From Redis + else: + Get Cached Data From Memory + """ + key = throttling_cache_key(request=request, duration=duration) + + if redis.is_connected: + data = (await redis.get(key) or b'0').decode() + return json.loads(data) + + else: + return throttling_storage[key] + + +async def increment_throttling_in_cache(request: Request, duration: timedelta) -> None: + """ + If redis.is_connected: + Increment The Data In Redis + else: + Increment The Data In Memory + """ + key = throttling_cache_key(request=request, duration=duration) + + if redis.is_connected: + await redis.incrby(key, amount=1) + + else: + throttling_storage[key] += 1 diff --git a/panther/cli/create_command.py b/panther/cli/create_command.py index 2800564..48d4779 100644 --- a/panther/cli/create_command.py +++ b/panther/cli/create_command.py @@ -16,7 +16,7 @@ AUTO_REFORMAT_PART, DATABASE_PANTHERDB_PART, DATABASE_MONGODB_PART, - USER_MODEL_PART, + USER_MODEL_PART, REDIS_PART, ) from panther.cli.utils import cli_error @@ -34,6 +34,7 @@ def __init__(self): self.base_directory = '.' self.database = '0' self.database_encryption = False + self.redis = False self.authentication = False self.monitoring = True self.log_queries = True @@ -60,7 +61,7 @@ def __init__(self): }, { 'field': 'database', - 'message': ' 0: PantherDB\n 1: MongoDB (Required `pymongo`)\n 2: No Database\nChoose Your Database (default is 0)', + 'message': ' 0: PantherDB (File-Base, No Requirements)\n 1: MongoDB (Required `pymongo`)\n 2: No Database\nChoose Your Database (default is 0)', 'validation_func': lambda x: x in ['0', '1', '2'], 'error_message': "Invalid Choice, '{}' not in ['0', '1', '2']", }, @@ -70,6 +71,11 @@ def __init__(self): 'is_boolean': True, 'condition': "self.database == '0'" }, + { + 'field': 'redis', + 'message': 'Do You Want To Use Redis (Required `redis`)', + 'is_boolean': True, + }, { 'field': 'authentication', 'message': 'Do You Want To Use JWT Authentication (Required `python-jose`)', @@ -77,7 +83,7 @@ def __init__(self): }, { 'field': 'monitoring', - 'message': 'Do You Want To Use Built-in Monitoring', + 'message': 'Do You Want To Use Built-in Monitoring (Required `watchfiles`)', 'is_boolean': True, }, { @@ -87,7 +93,7 @@ def __init__(self): }, { 'field': 'auto_reformat', - 'message': 'Do You Want To Use Auto Reformat (Required `ruff`)', + 'message': 'Do You Want To Use Auto Code Reformat (Required `ruff`)', 'is_boolean': True, }, ] @@ -139,6 +145,8 @@ def _create_file(self, *, path: str, data: str): log_queries_part = LOG_QUERIES_PART if self.log_queries else '' auto_reformat_part = AUTO_REFORMAT_PART if self.auto_reformat else '' database_encryption = 'True' if self.database_encryption else 'False' + database_extension = 'pdb' if self.database_encryption else 'json' + redis_part = REDIS_PART if self.redis else '' if self.database == '0': database_part = DATABASE_PANTHERDB_PART elif self.database == '1': @@ -153,6 +161,8 @@ def _create_file(self, *, path: str, data: str): data = data.replace('{AUTO_REFORMAT}', auto_reformat_part) data = data.replace('{DATABASE}', database_part) data = data.replace('{PANTHERDB_ENCRYPTION}', database_encryption) # Should be after `DATABASE` + data = data.replace('{PANTHERDB_EXTENSION}', database_extension) # Should be after `DATABASE` + data = data.replace('{REDIS}', redis_part) data = data.replace('{PROJECT_NAME}', self.project_name.lower()) data = data.replace('{PANTHER_VERSION}', version()) @@ -167,19 +177,19 @@ def collect_creation_data(self): field_name = question.pop('field') question['default'] = getattr(self, field_name) is_boolean = question.pop('is_boolean', False) - clean_output = str # Do Nothing + convert_output = str # Do Nothing if is_boolean: question['message'] += f' (default is {self._to_str(question["default"])})' question['validation_func'] = self._is_boolean question['error_message'] = "Invalid Choice, '{}' not in ['y', 'n']" - clean_output = self._to_boolean + convert_output = self._to_boolean # Check Question Condition if 'condition' in question and eval(question.pop('condition')) is False: print(flush=True) # Ask Question else: - setattr(self, field_name, clean_output(self.ask(**question))) + setattr(self, field_name, convert_output(self.ask(**question))) self.progress(i + 1) def ask( @@ -192,6 +202,7 @@ def ask( ) -> str: value = Prompt.ask(message, console=self.input_console).lower() or default while not validation_func(value): + # Remove the last line, show error message and ask again [print(end=self.REMOVE_LAST_LINE, flush=True) for _ in range(message.count('\n') + 1)] error = validation_func(value, return_error=True) if show_validation_error else value self.console.print(error_message.format(error), style='bold red') diff --git a/panther/cli/monitor_command.py b/panther/cli/monitor_command.py index 9916b23..8d7bf0e 100644 --- a/panther/cli/monitor_command.py +++ b/panther/cli/monitor_command.py @@ -1,71 +1,97 @@ import contextlib +import logging import os +import signal from collections import deque from pathlib import Path from rich import box from rich.align import Align from rich.console import Group -from rich.layout import Layout from rich.live import Live from rich.panel import Panel from rich.table import Table -from watchfiles import watch -from panther.cli.utils import cli_error +from panther.cli.utils import import_error from panther.configs import config +with contextlib.suppress(ImportError): + from watchfiles import watch -def monitor() -> None: - monitoring_log_file = Path(config['base_dir'] / 'logs' / 'monitoring.log') +loggerr = logging.getLogger('panther') - def _generate_table(rows: deque) -> Panel: - layout = Layout() - rows = list(rows) - _, lines = os.get_terminal_size() +class Monitoring: + def __init__(self): + self.rows = deque() + self.monitoring_log_file = Path(config.BASE_DIR / 'logs' / 'monitoring.log') + + def monitor(self) -> None: + if error := self.initialize(): + # Don't continue if initialize() has error + loggerr.error(error) + return + + with ( + self.monitoring_log_file.open() as f, + Live( + self.generate_table(), + vertical_overflow='visible', + screen=True, + ) as live, + contextlib.suppress(KeyboardInterrupt) + ): + f.readlines() # Set cursor at the end of the file + + for _ in watch(self.monitoring_log_file): + for line in f.readlines(): + self.rows.append(line.split('|')) + live.update(self.generate_table()) + + def initialize(self) -> str: + # Check requirements + try: + from watchfiles import watch + except ImportError as e: + return import_error(e, package='watchfiles').args[0] + + # Check log file + if not self.monitoring_log_file.exists(): + return f'`{self.monitoring_log_file}` file not found. (Make sure `MONITORING` is `True` in `configs`' + + # Initialize Deque + self.update_rows() + + # Register the signal handler + signal.signal(signal.SIGWINCH, self.update_rows) + + def generate_table(self) -> Panel: + # 2023-03-24 01:42:52 | GET | /user/317/ | 127.0.0.1:48856 | 0.0366 ms | 200 table = Table(box=box.MINIMAL_DOUBLE_HEAD) table.add_column('Datetime', justify='center', style='magenta', no_wrap=True) - table.add_column('Method', justify='center', style='cyan') - table.add_column('Path', justify='center', style='cyan') + table.add_column('Method', justify='center', style='cyan', no_wrap=True) + table.add_column('Path', justify='center', style='cyan', no_wrap=True) table.add_column('Client', justify='center', style='cyan') table.add_column('Response Time', justify='center', style='blue') - table.add_column('Status Code', justify='center', style='blue') + table.add_column('Status', justify='center', style='blue', no_wrap=True) - for row in rows[-lines:]: # It will give us "lines" last lines of "rows" + for row in self.rows: table.add_row(*row) - layout.update(table) return Panel( Align.center(Group(table)), box=box.ROUNDED, - padding=(1, 2), + padding=(0, 2), title='Monitoring', border_style='bright_blue', ) - if not monitoring_log_file.exists(): - return cli_error('Monitoring file not found. (You need at least one monitoring record for this action)') + def update_rows(self, *args, **kwargs): + # Top = -4, Bottom = -2 --> -6 + # Print of each line needs two line, so --> x // 2 + lines = (os.get_terminal_size()[1] - 6) // 2 + self.rows = deque(self.rows, maxlen=lines) - with monitoring_log_file.open() as f: - f.readlines() # Set cursor at the end of file - _, init_lines_count = os.get_terminal_size() - messages = deque(maxlen=init_lines_count - 10) # Save space for header and footer - - with ( - Live( - _generate_table(messages), - auto_refresh=False, - vertical_overflow='visible', - screen=True, - ) as live, - contextlib.suppress(KeyboardInterrupt), - ): - for _ in watch(monitoring_log_file): - data = f.readline().split('|') - # 2023-03-24 01:42:52 | GET | /user/317/ | 127.0.0.1:48856 | 0.0366 ms | 200 - messages.append(data) - live.update(_generate_table(messages)) - live.refresh() +monitor = Monitoring().monitor diff --git a/panther/cli/template.py b/panther/cli/template.py index 6e10437..187a902 100644 --- a/panther/cli/template.py +++ b/panther/cli/template.py @@ -11,6 +11,7 @@ from panther.app import API from panther.request import Request from panther.response import Response +from panther.utils import timezone_now @API() @@ -24,7 +25,7 @@ async def info_api(request: Request): 'panther_version': version(), 'method': request.method, 'query_params': request.query_params, - 'datetime_now': datetime.now().isoformat(), + 'datetime_now': timezone_now().isoformat(), 'user_agent': request.headers.user_agent, } return Response(data=data, status_code=status.HTTP_202_ACCEPTED) @@ -55,19 +56,19 @@ async def info_api(request: Request): {PROJECT_NAME} Project (Generated by Panther on %s) \""" -from datetime import timedelta from pathlib import Path -from panther.throttling import Throttling from panther.utils import load_env BASE_DIR = Path(__name__).resolve().parent env = load_env(BASE_DIR / '.env') -SECRET_KEY = env['SECRET_KEY']{DATABASE}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT} +SECRET_KEY = env['SECRET_KEY']{DATABASE}{REDIS}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT} # More Info: https://PantherPy.GitHub.io/urls/ URLs = 'core.urls.url_routing' + +TIMEZONE = 'UTC' """ % datetime.now().date().isoformat() env = """SECRET_KEY='%s' @@ -128,15 +129,16 @@ async def info_api(request: Request): from panther.request import Request from panther.response import Response from panther.throttling import Throttling -from panther.utils import load_env +from panther.utils import load_env, timezone_now BASE_DIR = Path(__name__).resolve().parent env = load_env(BASE_DIR / '.env') -SECRET_KEY = env['SECRET_KEY']{DATABASE}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT} +SECRET_KEY = env['SECRET_KEY']{DATABASE}{REDIS}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT} InfoThrottling = Throttling(rate=5, duration=timedelta(minutes=1)) +TIMEZONE = 'UTC' @API() async def hello_world_api(): @@ -149,7 +151,7 @@ async def info_api(request: Request): 'panther_version': version(), 'method': request.method, 'query_params': request.query_params, - 'datetime_now': datetime.now().isoformat(), + 'datetime_now': timezone_now().isoformat(), 'user_agent': request.headers.user_agent, } return Response(data=data, status_code=status.HTTP_202_ACCEPTED) @@ -157,6 +159,7 @@ async def info_api(request: Request): url_routing = { '/': hello_world_api, + 'info/': info_api, } app = Panther(__name__, configs=__name__, urls=url_routing) @@ -170,17 +173,19 @@ async def info_api(request: Request): } DATABASE_PANTHERDB_PART = """ -# More Info: Https://PantherPy.GitHub.io/middlewares/ + +# More Info: https://PantherPy.GitHub.io/database/ DATABASE = { 'engine': { 'class': 'panther.db.connections.PantherDBConnection', - 'path': BASE_DIR, + 'path': BASE_DIR / 'database.{PANTHERDB_EXTENSION}', 'encryption': {PANTHERDB_ENCRYPTION} } }""" DATABASE_MONGODB_PART = """ -# More Info: Https://PantherPy.GitHub.io/middlewares/ + +# More Info: https://PantherPy.GitHub.io/database/ DATABASE = { 'engine': { 'class': 'panther.db.connections.MongoDBConnection', @@ -190,6 +195,16 @@ async def info_api(request: Request): } }""" +REDIS_PART = """ + +# More Info: https://PantherPy.GitHub.io/redis/ +REDIS = { + 'class': 'panther.db.connections.RedisConnection', + 'host': '127.0.0.1', + 'port': 6379, + 'db': 0, +}""" + USER_MODEL_PART = """ # More Info: https://PantherPy.GitHub.io/configs/#user_model diff --git a/panther/cli/utils.py b/panther/cli/utils.py index 2df3400..7e2fdac 100644 --- a/panther/cli/utils.py +++ b/panther/cli/utils.py @@ -3,6 +3,7 @@ from rich import print as rprint +from panther.configs import Config from panther.exceptions import PantherError logger = logging.getLogger('panther') @@ -108,21 +109,20 @@ def print_uvicorn_help_message(): rprint('Run `uvicorn --help` for more help') -def print_info(config: dict): +def print_info(config: Config): from panther.db.connections import redis - - mo = config['monitoring'] - lq = config['log_queries'] - bt = config['background_tasks'] - ws = config['has_ws'] + mo = config.MONITORING + lq = config.LOG_QUERIES + bt = config.BACKGROUND_TASKS + ws = config.HAS_WS rd = redis.is_connected - bd = '{0:<39}'.format(str(config['base_dir'])) + bd = '{0:<39}'.format(str(config.BASE_DIR)) if len(bd) > 39: bd = f'{bd[:36]}...' # Monitoring - if config['monitoring']: + if config.MONITORING: monitor = f'{h} * Run "panther monitor" in another session for Monitoring{h}\n' else: monitor = None @@ -139,7 +139,7 @@ def print_info(config: dict): # Gunicorn if Websocket gunicorn_msg = None - if config['has_ws']: + if config.HAS_WS: try: import gunicorn gunicorn_msg = f'{h} * You have WS so make sure to run gunicorn with --preload{h}\n' diff --git a/panther/configs.py b/panther/configs.py index 7433a0f..39d4d12 100644 --- a/panther/configs.py +++ b/panther/configs.py @@ -1,3 +1,4 @@ +import copy import typing from dataclasses import dataclass from datetime import timedelta @@ -7,7 +8,6 @@ from pydantic._internal._model_construction import ModelMetaclass from panther.throttling import Throttling -from panther.utils import Singleton class JWTConfig: @@ -41,80 +41,78 @@ def observe(cls, observer): @classmethod def update(cls): for observer in cls.observers: - observer._reload_bases(parent=config.query_engine) + observer._reload_bases(parent=config.QUERY_ENGINE) @dataclass -class Config(Singleton): - base_dir: Path - monitoring: bool - log_queries: bool - default_cache_exp: timedelta | None - throttling: Throttling | None - secret_key: bytes | None - http_middlewares: list - ws_middlewares: list - reversed_http_middlewares: list - reversed_ws_middlewares: list - user_model: ModelMetaclass | None - authentication: ModelMetaclass | None - ws_authentication: ModelMetaclass | None - jwt_config: JWTConfig | None - models: list[dict] - flat_urls: dict - urls: dict - websocket_connections: typing.Callable | None - background_tasks: bool - has_ws: bool - startup: Callable | None - shutdown: Callable | None - auto_reformat: bool - query_engine: typing.Callable | None - database: typing.Callable | None +class Config: + BASE_DIR: Path + MONITORING: bool + LOG_QUERIES: bool + DEFAULT_CACHE_EXP: timedelta | None + THROTTLING: Throttling | None + SECRET_KEY: bytes | None + HTTP_MIDDLEWARES: list[tuple] + WS_MIDDLEWARES: list[tuple] + USER_MODEL: ModelMetaclass | None + AUTHENTICATION: ModelMetaclass | None + WS_AUTHENTICATION: ModelMetaclass | None + JWT_CONFIG: JWTConfig | None + MODELS: list[dict] + FLAT_URLS: dict + URLS: dict + WEBSOCKET_CONNECTIONS: typing.Callable | None + BACKGROUND_TASKS: bool + HAS_WS: bool + STARTUPS: list[Callable] + SHUTDOWNS: list[Callable] + TIMEZONE: str + AUTO_REFORMAT: bool + QUERY_ENGINE: typing.Callable | None + DATABASE: typing.Callable | None def __setattr__(self, key, value): super().__setattr__(key, value) - if key == 'query_engine' and value: + if key == 'QUERY_ENGINE' and value: QueryObservable.update() def __setitem__(self, key, value): - setattr(self, key, value) + setattr(self, key.upper(), value) def __getitem__(self, item): - return getattr(self, item) + return getattr(self, item.upper()) def refresh(self): # In some tests we need to `refresh` the `config` values - for key, value in default_configs.items(): + for key, value in copy.deepcopy(default_configs).items(): setattr(self, key, value) default_configs = { - 'base_dir': Path(), - 'monitoring': False, - 'log_queries': False, - 'default_cache_exp': None, - 'throttling': None, - 'secret_key': None, - 'http_middlewares': [], - 'ws_middlewares': [], - 'reversed_http_middlewares': [], - 'reversed_ws_middlewares': [], - 'user_model': None, - 'authentication': None, - 'ws_authentication': None, - 'jwt_config': None, - 'models': [], - 'flat_urls': {}, - 'urls': {}, - 'websocket_connections': None, - 'background_tasks': False, - 'has_ws': False, - 'startup': None, - 'shutdown': None, - 'auto_reformat': False, - 'query_engine': None, - 'database': None, + 'BASE_DIR': Path(), + 'MONITORING': False, + 'LOG_QUERIES': False, + 'DEFAULT_CACHE_EXP': None, + 'THROTTLING': None, + 'SECRET_KEY': None, + 'HTTP_MIDDLEWARES': [], + 'WS_MIDDLEWARES': [], + 'USER_MODEL': None, + 'AUTHENTICATION': None, + 'WS_AUTHENTICATION': None, + 'JWT_CONFIG': None, + 'MODELS': [], + 'FLAT_URLS': {}, + 'URLS': {}, + 'WEBSOCKET_CONNECTIONS': None, + 'BACKGROUND_TASKS': False, + 'HAS_WS': False, + 'STARTUPS': [], + 'SHUTDOWNS': [], + 'TIMEZONE': 'UTC', + 'AUTO_REFORMAT': False, + 'QUERY_ENGINE': None, + 'DATABASE': None, } -config = Config(**default_configs) +config = Config(**copy.deepcopy(default_configs)) diff --git a/panther/db/connections.py b/panther/db/connections.py index 3fde392..6bb82b9 100644 --- a/panther/db/connections.py +++ b/panther/db/connections.py @@ -10,8 +10,8 @@ from panther.utils import Singleton try: - from redis import Redis as _Redis -except ModuleNotFoundError: + from redis.asyncio import Redis as _Redis +except ImportError: # This '_Redis' is not going to be used, # If user really wants to use redis, # we are going to force him to install it in `panther._load_configs.load_redis` @@ -75,13 +75,13 @@ def session(self): class PantherDBConnection(BaseDatabaseConnection): def init(self, path: str | None = None, encryption: bool = False): - params = {'db_name': path, 'return_dict': True} + params = {'db_name': str(path), 'return_dict': True, 'return_cursor': True} if encryption: try: import cryptography except ImportError as e: raise import_error(e, package='cryptography') - params['secret_key'] = config['secret_key'] + params['secret_key'] = config.SECRET_KEY self._connection: PantherDB = PantherDB(**params) @@ -93,7 +93,7 @@ def session(self): class DatabaseConnection(Singleton): @property def session(self): - return config['database'].session + return config.DATABASE.session class RedisConnection(Singleton, _Redis): @@ -117,7 +117,12 @@ def __init__( super().__init__(host=host, port=port, db=db, **kwargs) self.is_connected = True - self.ping() + self.sync_ping() + + def sync_ping(self): + from redis import Redis + + Redis(host=self.host, port=self.port, **self.kwargs).ping() def create_connection_for_websocket(self) -> _Redis: if not hasattr(self, 'websocket_connection'): diff --git a/panther/db/cursor.py b/panther/db/cursor.py index 923806e..1171518 100644 --- a/panther/db/cursor.py +++ b/panther/db/cursor.py @@ -2,7 +2,13 @@ from sys import version_info -from pymongo.cursor import Cursor as _Cursor +try: + from pymongo.cursor import Cursor as _Cursor +except ImportError: + # This '_Cursor' is not going to be used, + # If user really wants to use it, + # we are going to force him to install it in `panther.db.connections.MongoDBConnection.init` + _Cursor = type('_Cursor', (), {}) if version_info >= (3, 11): from typing import Self @@ -20,6 +26,7 @@ def __init__(self, collection, *args, cls=None, **kwargs): if cls: self.models[collection.name] = cls self.cls = cls + self.filter = kwargs['filter'] else: self.cls = self.models[collection.name] super().__init__(collection, *args, **kwargs) diff --git a/panther/db/models.py b/panther/db/models.py index dd235f3..2d6bc3b 100644 --- a/panther/db/models.py +++ b/panther/db/models.py @@ -1,41 +1,42 @@ import contextlib +import os from datetime import datetime +from typing import Annotated -from pydantic import BaseModel as PydanticBaseModel -from pydantic import Field, field_validator +from pydantic import Field, WrapValidator, PlainSerializer, BaseModel as PydanticBaseModel from panther.configs import config from panther.db.queries import Query - +from panther.utils import scrypt, URANDOM_SIZE, timezone_now with contextlib.suppress(ImportError): # Only required if user wants to use mongodb import bson +def validate_object_id(value, handler): + if config.DATABASE.__class__.__name__ == 'MongoDBConnection': + if isinstance(value, bson.ObjectId): + return value + else: + try: + return bson.ObjectId(value) + except bson.objectid.InvalidId as e: + msg = 'Invalid ObjectId' + raise ValueError(msg) from e + return str(value) + + +ID = Annotated[str, WrapValidator(validate_object_id), PlainSerializer(lambda x: str(x), return_type=str)] + + class Model(PydanticBaseModel, Query): def __init_subclass__(cls, **kwargs): if cls.__module__ == 'panther.db.models' and cls.__name__ == 'BaseUser': return - config['models'].append(cls) - - id: str | None = Field(None, validation_alias='_id') - - @field_validator('id', mode='before') - def validate_id(cls, value) -> str: - if config['database'].__class__.__name__ == 'MongoDBConnection': - if isinstance(value, str): - try: - bson.ObjectId(value) - except bson.objectid.InvalidId as e: - msg = 'Invalid ObjectId' - raise ValueError(msg) from e - - elif not isinstance(value, bson.ObjectId): - msg = 'ObjectId required' - raise ValueError(msg) from None + config.MODELS.append(cls) - return str(value) + id: ID | None = Field(None, validation_alias='_id') @property def _id(self): @@ -44,26 +45,40 @@ def _id(self): `str` for PantherDB `ObjectId` for MongoDB """ - if config['database'].__class__.__name__ == 'MongoDBConnection': - return bson.ObjectId(self.id) if self.id else None return self.id - def dict(self, *args, **kwargs) -> dict: - return self.model_dump(*args, **kwargs) - class BaseUser(Model): - first_name: str = Field('', max_length=64) - last_name: str = Field('', max_length=64) + password: str = Field('', max_length=64) last_login: datetime | None = None + date_created: datetime | None = Field(default_factory=timezone_now) async def update_last_login(self) -> None: - await self.update(last_login=datetime.now()) + await self.update(last_login=timezone_now()) async def login(self) -> dict: """Return dict of access and refresh token""" - await self.update_last_login() - return config.authentication.login(self.id) + return config.AUTHENTICATION.login(self.id) + + async def logout(self) -> dict: + return await config.AUTHENTICATION.logout(self._auth_token) + + def set_password(self, password: str): + """ + URANDOM_SIZE = 16 char --> + salt = 16 bytes + salt.hex() = 32 char + derived_key = 32 char + """ + salt = os.urandom(URANDOM_SIZE) + derived_key = scrypt(password=password, salt=salt, digest=True) + + self.password = f'{salt.hex()}{derived_key}' + + def check_password(self, new_password: str) -> bool: + size = URANDOM_SIZE * 2 + salt = self.password[:size] + stored_hash = self.password[size:] + derived_key = scrypt(password=new_password, salt=bytes.fromhex(salt), digest=True) - def logout(self) -> dict: - return config.authentication.logout(self._auth_token) + return derived_key == stored_hash diff --git a/panther/db/queries/base_queries.py b/panther/db/queries/base_queries.py index 51c8c0f..4bdb44a 100644 --- a/panther/db/queries/base_queries.py +++ b/panther/db/queries/base_queries.py @@ -2,6 +2,7 @@ from abc import abstractmethod from functools import reduce from sys import version_info +from typing import Iterator from pydantic_core._pydantic_core import ValidationError @@ -10,7 +11,7 @@ from panther.exceptions import DatabaseError if version_info >= (3, 11): - from typing import Self, Iterator + from typing import Self else: from typing import TypeVar diff --git a/panther/db/queries/mongodb_queries.py b/panther/db/queries/mongodb_queries.py index a83cd9f..5e2f10e 100644 --- a/panther/db/queries/mongodb_queries.py +++ b/panther/db/queries/mongodb_queries.py @@ -3,13 +3,19 @@ from sys import version_info from typing import Iterable, Sequence -from bson.codec_options import CodecOptions - from panther.db.connections import db from panther.db.cursor import Cursor from panther.db.queries.base_queries import BaseQuery from panther.db.utils import prepare_id_for_query +try: + from bson.codec_options import CodecOptions +except ImportError: + # This 'CodecOptions' is not going to be used, + # If user really wants to use it, + # we are going to force him to install it in `panther.db.connections.MongoDBConnection.init` + CodecOptions = type('CodecOptions', (), {}) + if version_info >= (3, 11): from typing import Self else: @@ -23,24 +29,21 @@ class BaseMongoDBQuery(BaseQuery): def _merge(cls, *args, is_mongo: bool = True) -> dict: return super()._merge(*args, is_mongo=is_mongo) - @classmethod - def collection(cls): - return db.session.get_collection( - name=cls.__name__, - codec_options=CodecOptions(document_class=dict) - # codec_options=CodecOptions(document_class=cls) TODO: https://jira.mongodb.org/browse/PYTHON-4192 - ) + # TODO: https://jira.mongodb.org/browse/PYTHON-4192 + # @classmethod + # def collection(cls): + # return db.session.get_collection(name=cls.__name__, codec_options=CodecOptions(document_class=cls)) # # # # # Find # # # # # @classmethod async def find_one(cls, _filter: dict | None = None, /, **kwargs) -> Self | None: - if document := await cls.collection().find_one(cls._merge(_filter, kwargs)): + if document := await db.session[cls.__name__].find_one(cls._merge(_filter, kwargs)): return cls._create_model_instance(document=document) return None @classmethod async def find(cls, _filter: dict | None = None, /, **kwargs) -> Cursor: - return Cursor(cls=cls, collection=cls.collection().delegate, filter=cls._merge(_filter, kwargs)) + return Cursor(cls=cls, collection=db.session[cls.__name__].delegate, filter=cls._merge(_filter, kwargs)) @classmethod async def first(cls, _filter: dict | None = None, /, **kwargs) -> Self | None: @@ -58,12 +61,12 @@ async def last(cls, _filter: dict | None = None, /, **kwargs) -> Self | None: @classmethod async def aggregate(cls, pipeline: Sequence[dict]) -> Iterable[dict]: - return await cls.collection().aggregate(pipeline) + return await db.session[cls.__name__].aggregate(pipeline) # # # # # Count # # # # # @classmethod async def count(cls, _filter: dict | None = None, /, **kwargs) -> int: - return await cls.collection().count_documents(cls._merge(_filter, kwargs)) + return await db.session[cls.__name__].count_documents(cls._merge(_filter, kwargs)) # # # # # Insert # # # # # @classmethod @@ -71,7 +74,7 @@ async def insert_one(cls, _document: dict | None = None, /, **kwargs) -> Self: document = cls._merge(_document, kwargs) cls._validate_data(data=document) - await cls.collection().insert_one(document) + await db.session[cls.__name__].insert_one(document) return cls._create_model_instance(document=document) @classmethod @@ -80,21 +83,21 @@ async def insert_many(cls, documents: Iterable[dict]) -> list[Self]: prepare_id_for_query(document, is_mongo=True) cls._validate_data(data=document) - await cls.collection().insert_many(documents) + await db.session[cls.__name__].insert_many(documents) return [cls._create_model_instance(document=document) for document in documents] # # # # # Delete # # # # # async def delete(self) -> None: - await self.collection().delete_one({'_id': self._id}) + await db.session[self.__class__.__name__].delete_one({'_id': self._id}) @classmethod async def delete_one(cls, _filter: dict | None = None, /, **kwargs) -> bool: - result = await cls.collection().delete_one(cls._merge(_filter, kwargs)) + result = await db.session[cls.__name__].delete_one(cls._merge(_filter, kwargs)) return bool(result.deleted_count) @classmethod async def delete_many(cls, _filter: dict | None = None, /, **kwargs) -> int: - result = await cls.collection().delete_many(cls._merge(_filter, kwargs)) + result = await db.session[cls.__name__].delete_many(cls._merge(_filter, kwargs)) return result.deleted_count # # # # # Update # # # # # @@ -106,14 +109,14 @@ async def update(self, _update: dict | None = None, /, **kwargs) -> None: for field, value in document.items(): setattr(self, field, value) update_fields = {'$set': document} - await self.collection().update_one({'_id': self._id}, update_fields) + await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_fields) @classmethod async def update_one(cls, _filter: dict, _update: dict | None = None, /, **kwargs) -> bool: prepare_id_for_query(_filter, is_mongo=True) update_fields = {'$set': cls._merge(_update, kwargs)} - result = await cls.collection().update_one(_filter, update_fields) + result = await db.session[cls.__name__].update_one(_filter, update_fields) return bool(result.matched_count) @classmethod @@ -121,5 +124,5 @@ async def update_many(cls, _filter: dict, _update: dict | None = None, /, **kwar prepare_id_for_query(_filter, is_mongo=True) update_fields = {'$set': cls._merge(_update, kwargs)} - result = await cls.collection().update_many(_filter, update_fields) + result = await db.session[cls.__name__].update_many(_filter, update_fields) return result.modified_count diff --git a/panther/db/queries/pantherdb_queries.py b/panther/db/queries/pantherdb_queries.py index 17f0af5..caa2eae 100644 --- a/panther/db/queries/pantherdb_queries.py +++ b/panther/db/queries/pantherdb_queries.py @@ -1,6 +1,8 @@ from sys import version_info from typing import Iterable +from pantherdb import Cursor + from panther.db.connections import db from panther.db.queries.base_queries import BaseQuery from panther.db.utils import prepare_id_for_query @@ -27,9 +29,11 @@ async def find_one(cls, _filter: dict | None = None, /, **kwargs) -> Self | None return None @classmethod - async def find(cls, _filter: dict | None = None, /, **kwargs) -> list[Self]: - documents = db.session.collection(cls.__name__).find(**cls._merge(_filter, kwargs)) - return [cls._create_model_instance(document=document) for document in documents] + async def find(cls, _filter: dict | None = None, /, **kwargs) -> Cursor: + cursor = db.session.collection(cls.__name__).find(**cls._merge(_filter, kwargs)) + cursor.response_type = cls._create_model_instance + cursor.cls = cls + return cursor @classmethod async def first(cls, _filter: dict | None = None, /, **kwargs) -> Self | None: @@ -65,12 +69,12 @@ async def insert_one(cls, _document: dict | None = None, /, **kwargs) -> Self: @classmethod async def insert_many(cls, documents: Iterable[dict]) -> list[Self]: result = [] - for _document in documents: - prepare_id_for_query(_document, is_mongo=False) - cls._validate_data(data=_document) - document = db.session.collection(cls.__name__).insert_one(**_document) - result.append(document) - + for document in documents: + prepare_id_for_query(document, is_mongo=False) + cls._validate_data(data=document) + inserted_document = db.session.collection(cls.__name__).insert_one(**document) + document['_id'] = inserted_document['_id'] + result.append(cls._create_model_instance(document=document)) return result # # # # # Delete # # # # # diff --git a/panther/db/queries/queries.py b/panther/db/queries/queries.py index 5009990..e06eee0 100644 --- a/panther/db/queries/queries.py +++ b/panther/db/queries/queries.py @@ -1,6 +1,8 @@ import sys from typing import Sequence, Iterable +from pantherdb import Cursor as PantherDBCursor + from panther.configs import QueryObservable from panther.db.cursor import Cursor from panther.db.queries.base_queries import BaseQuery @@ -57,7 +59,7 @@ async def find_one(cls, _filter: dict | None = None, /, **kwargs) -> Self | None @classmethod @check_connection @log_query - async def find(cls, _filter: dict | None = None, /, **kwargs) -> list[Self] | Cursor: + async def find(cls, _filter: dict | None = None, /, **kwargs) -> PantherDBCursor | Cursor: """ Get documents from the database. @@ -193,7 +195,7 @@ async def insert_many(cls, documents: Iterable[dict]) -> list[Self]: >>> ] >>> await User.insert_many(users) """ - return super().insert_many(documents) + return await super().insert_many(documents) # # # # # Delete # # # # # @check_connection @@ -324,7 +326,7 @@ async def all(cls) -> list[Self] | Cursor: return await cls.find() @classmethod - async def find_one_or_insert(cls, _filter: dict | None = None, /, **kwargs) -> tuple[bool, Self]: + async def find_one_or_insert(cls, _filter: dict | None = None, /, **kwargs) -> tuple[Self, bool]: """ Get a single document from the database. or @@ -341,8 +343,8 @@ async def find_one_or_insert(cls, _filter: dict | None = None, /, **kwargs) -> t >>> await User.find_one_or_insert({'age': 18}, name='Ali') """ if obj := await cls.find_one(_filter, **kwargs): - return False, obj - return True, await cls.insert_one(_filter, **kwargs) + return obj, False + return await cls.insert_one(_filter, **kwargs), True @classmethod async def find_one_or_raise(cls, _filter: dict | None = None, /, **kwargs) -> Self: diff --git a/panther/db/utils.py b/panther/db/utils.py index 8154020..9fe2fdc 100644 --- a/panther/db/utils.py +++ b/panther/db/utils.py @@ -14,7 +14,7 @@ def log_query(func): async def log(*args, **kwargs): - if config['log_queries'] is False: + if config.LOG_QUERIES is False: return await func(*args, **kwargs) start = perf_counter() response = await func(*args, **kwargs) @@ -28,7 +28,7 @@ async def log(*args, **kwargs): def check_connection(func): async def wrapper(*args, **kwargs): - if config['query_engine'] is None: + if config.QUERY_ENGINE is None: msg = "You don't have active database connection, Check your middlewares" raise NotImplementedError(msg) return await func(*args, **kwargs) diff --git a/panther/events.py b/panther/events.py new file mode 100644 index 0000000..7fe84a4 --- /dev/null +++ b/panther/events.py @@ -0,0 +1,44 @@ +import asyncio + +from panther._utils import is_function_async +from panther.configs import config + + +class Event: + @staticmethod + def startup(func): + config.STARTUPS.append(func) + + def wrapper(): + return func() + return wrapper + + @staticmethod + def shutdown(func): + config.SHUTDOWNS.append(func) + + def wrapper(): + return func() + return wrapper + + @staticmethod + async def run_startups(): + for func in config.STARTUPS: + if is_function_async(func): + await func() + else: + func() + + @staticmethod + def run_shutdowns(): + for func in config.SHUTDOWNS: + if is_function_async(func): + try: + asyncio.run(func()) + except ModuleNotFoundError: + # Error: import of asyncio halted; None in sys.modules + # And as I figured it out, it only happens when we are running with + # gunicorn and Uvicorn workers (-k uvicorn.workers.UvicornWorker) + pass + else: + func() diff --git a/panther/generics.py b/panther/generics.py new file mode 100644 index 0000000..a79384a --- /dev/null +++ b/panther/generics.py @@ -0,0 +1,163 @@ +import contextlib +import logging + +from pantherdb import Cursor as PantherDBCursor + +from panther import status +from panther.app import GenericAPI +from panther.configs import config +from panther.db import Model +from panther.db.cursor import Cursor +from panther.exceptions import APIError +from panther.pagination import Pagination +from panther.request import Request +from panther.response import Response +from panther.serializer import ModelSerializer + +with contextlib.suppress(ImportError): + # Only required if user wants to use mongodb + import bson + +logger = logging.getLogger('panther') + + +class ObjectRequired: + def _check_object(self, instance): + if issubclass(type(instance), Model) is False: + logger.critical(f'`{self.__class__.__name__}.object()` should return instance of a Model --> `find_one()`') + raise APIError + + async def object(self, request: Request, **kwargs) -> Model: + """ + Used in `RetrieveAPI`, `UpdateAPI`, `DeleteAPI` + """ + logger.error(f'`object()` method is not implemented in {self.__class__} .') + raise APIError(status_code=status.HTTP_501_NOT_IMPLEMENTED) + + +class ObjectsRequired: + def _check_objects(self, cursor): + if isinstance(cursor, (Cursor, PantherDBCursor)) is False: + logger.critical(f'`{self.__class__.__name__}.objects()` should return a Cursor --> `find()`') + raise APIError + + async def objects(self, request: Request, **kwargs) -> Cursor | PantherDBCursor: + """ + Used in `ListAPI` + Should return `.find()` + """ + logger.error(f'`objects()` method is not implemented in {self.__class__} .') + raise APIError(status_code=status.HTTP_501_NOT_IMPLEMENTED) + + +class RetrieveAPI(GenericAPI, ObjectRequired): + async def get(self, request: Request, **kwargs): + instance = await self.object(request=request, **kwargs) + self._check_object(instance) + + return Response(data=instance, status_code=status.HTTP_200_OK) + + +class ListAPI(GenericAPI, ObjectsRequired): + sort_fields: list[str] + search_fields: list[str] + filter_fields: list[str] + pagination: type[Pagination] + + async def get(self, request: Request, **kwargs): + cursor = await self.objects(request=request, **kwargs) + self._check_objects(cursor) + + query = {} + query |= self.process_filters(query_params=request.query_params, cursor=cursor) + query |= self.process_search(query_params=request.query_params) + + if query: + cursor = await cursor.cls.find(cursor.filter | query) + + if sort := self.process_sort(query_params=request.query_params): + cursor = cursor.sort(sort) + + if pagination := self.process_pagination(query_params=request.query_params, cursor=cursor): + cursor = await pagination.paginate() + + return Response(data=cursor, status_code=status.HTTP_200_OK) + + def process_filters(self, query_params: dict, cursor: Cursor | PantherDBCursor) -> dict: + _filter = {} + if hasattr(self, 'filter_fields'): + for field in self.filter_fields: + if field in query_params: + if config.DATABASE.__class__.__name__ == 'MongoDBConnection': + with contextlib.suppress(Exception): + if cursor.cls.model_fields[field].metadata[0].func.__name__ == 'validate_object_id': + _filter[field] = bson.ObjectId(query_params[field]) + continue + _filter[field] = query_params[field] + return _filter + + def process_search(self, query_params: dict) -> dict: + if hasattr(self, 'search_fields') and 'search' in query_params: + value = query_params['search'] + if config.DATABASE.__class__.__name__ == 'MongoDBConnection': + if search := [{field: {'$regex': value}} for field in self.search_fields]: + return {'$or': search} + else: + logger.warning(f'`?search={value} does not work well while using `PantherDB` as Database') + return {field: value for field in self.search_fields} + return {} + + def process_sort(self, query_params: dict) -> list: + if hasattr(self, 'sort_fields') and 'sort' in query_params: + return [ + (field, -1 if param[0] == '-' else 1) + for field in self.sort_fields for param in query_params['sort'].split(',') + if field == param.removeprefix('-') + ] + + def process_pagination(self, query_params: dict, cursor: Cursor | PantherDBCursor) -> Pagination | None: + if hasattr(self, 'pagination'): + return self.pagination(query_params=query_params, cursor=cursor) + + +class CreateAPI(GenericAPI): + input_model: type[ModelSerializer] + + async def post(self, request: Request, **kwargs): + instance = await request.validated_data.create( + validated_data=request.validated_data.model_dump() + ) + return Response(data=instance, status_code=status.HTTP_201_CREATED) + + +class UpdateAPI(GenericAPI, ObjectRequired): + input_model: type[ModelSerializer] + + async def put(self, request: Request, **kwargs): + instance = await self.object(request=request, **kwargs) + self._check_object(instance) + + await request.validated_data.update( + instance=instance, + validated_data=request.validated_data.model_dump() + ) + return Response(data=instance, status_code=status.HTTP_200_OK) + + async def patch(self, request: Request, **kwargs): + instance = await self.object(request=request, **kwargs) + self._check_object(instance) + + await request.validated_data.partial_update( + instance=instance, + validated_data=request.validated_data.model_dump(exclude_none=True) + ) + return Response(data=instance, status_code=status.HTTP_200_OK) + + +class DeleteAPI(GenericAPI, ObjectRequired): + async def delete(self, request: Request, **kwargs): + instance = await self.object(request=request, **kwargs) + self._check_object(instance) + + await instance.delete() + return Response(status_code=status.HTTP_204_NO_CONTENT) diff --git a/panther/logging.py b/panther/logging.py index 2f37930..346dfbb 100644 --- a/panther/logging.py +++ b/panther/logging.py @@ -2,7 +2,7 @@ from pathlib import Path from panther.configs import config -LOGS_DIR = config['base_dir'] / 'logs' +LOGS_DIR = config.BASE_DIR / 'logs' class FileHandler(logging.FileHandler): @@ -63,6 +63,11 @@ def __init__(self, filename, mode='a', encoding=None, delay=False, errors=None): 'query': { 'handlers': ['default', 'query_file'], 'level': 'DEBUG', - } + }, + 'uvicorn.error': { + 'handlers': ['default'], + 'level': 'WARNING', + 'propagate': False, + }, } } \ No newline at end of file diff --git a/panther/main.py b/panther/main.py index 13ab18f..a4a6cd9 100644 --- a/panther/main.py +++ b/panther/main.py @@ -7,14 +7,15 @@ from logging.config import dictConfig from pathlib import Path +import orjson as json + import panther.logging from panther import status from panther._load_configs import * -from panther._utils import clean_traceback_message, http_response, is_function_async, reformat_code, \ - check_class_type_endpoint, check_function_type_endpoint -from panther.base_websocket import WebsocketListener +from panther._utils import clean_traceback_message, reformat_code, check_class_type_endpoint, check_function_type_endpoint from panther.cli.utils import print_info from panther.configs import config +from panther.events import Event from panther.exceptions import APIError, PantherError from panther.monitoring import Monitoring from panther.request import Request @@ -26,35 +27,20 @@ class Panther: - def __init__( - self, - name: str, - configs=None, - urls: dict | None = None, - startup: Callable = None, - shutdown: Callable = None - ): + def __init__(self, name: str, configs: str | None = None, urls: dict | None = None): self._configs_module_name = configs self._urls = urls - self._startup = startup - self._shutdown = shutdown - config['base_dir'] = Path(name).resolve().parent + config.BASE_DIR = Path(name).resolve().parent try: self.load_configs() - if config['auto_reformat']: - reformat_code(base_dir=config['base_dir']) + if config.AUTO_REFORMAT: + reformat_code(base_dir=config.BASE_DIR) except Exception as e: # noqa: BLE001 - if isinstance(e, PantherError): - logger.error(e.args[0]) - else: - logger.error(clean_traceback_message(e)) + logger.error(e.args[0] if isinstance(e, PantherError) else clean_traceback_message(e)) sys.exit() - # Monitoring - self.monitoring = Monitoring(is_active=config['monitoring']) - # Print Info print_info(config) @@ -65,6 +51,7 @@ def load_configs(self) -> None: load_redis(self._configs_module) load_startup(self._configs_module) load_shutdown(self._configs_module) + load_timezone(self._configs_module) load_database(self._configs_module) load_secret_key(self._configs_module) load_monitoring(self._configs_module) @@ -80,25 +67,14 @@ def load_configs(self) -> None: load_websocket_connections() async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None: - """ - 1. - await func(scope, receive, send) - 2. - async with asyncio.TaskGroup() as tg: - tg.create_task(func(scope, receive, send)) - 3. - async with anyio.create_task_group() as task_group: - task_group.start_soon(func, scope, receive, send) - await anyio.to_thread.run_sync(func, scope, receive, send) - 4. - with ProcessPoolExecutor() as e: - e.submit(func, scope, receive, send) - """ if scope['type'] == 'lifespan': message = await receive() if message["type"] == 'lifespan.startup': - await self.handle_ws_listener() - await self.handle_startup() + await config.WEBSOCKET_CONNECTIONS.start() + await Event.run_startups() + elif message["type"] == 'lifespan.shutdown': + # It's not happening :\, so handle the shutdowns in __del__ ... + pass return func = self.handle_http if scope['type'] == 'http' else self.handle_ws @@ -108,66 +84,75 @@ async def handle_ws(self, scope: dict, receive: Callable, send: Callable) -> Non from panther.websocket import GenericWebsocket, Websocket # Monitoring - monitoring = Monitoring(is_active=config['monitoring'], is_ws=True) + monitoring = Monitoring(is_ws=True) # Create Temp Connection temp_connection = Websocket(scope=scope, receive=receive, send=send) await monitoring.before(request=temp_connection) + temp_connection._monitoring = monitoring # Find Endpoint endpoint, found_path = find_endpoint(path=temp_connection.path) if endpoint is None: - await monitoring.after('Rejected') - return await temp_connection.close(status.WS_1000_NORMAL_CLOSURE) + logger.debug(f'Path `{temp_connection.path}` not found') + return await temp_connection.close() # Check Endpoint Type if not issubclass(endpoint, GenericWebsocket): - logger.critical(f'You may have forgotten to inherit from GenericWebsocket on the {endpoint.__name__}()') - await monitoring.after('Rejected') - return await temp_connection.close(status.WS_1014_BAD_GATEWAY) + logger.critical(f'You may have forgotten to inherit from `GenericWebsocket` on the `{endpoint.__name__}()`') + return await temp_connection.close() # Create The Connection del temp_connection connection = endpoint(scope=scope, receive=receive, send=send) + connection._monitoring = monitoring # Collect Path Variables connection.collect_path_variables(found_path=found_path) - # Call 'Before' Middlewares - if await self._run_ws_middlewares_before_listen(connection=connection): - # Only Listen() If Middlewares Didn't Raise Anything - await config['websocket_connections'].new_connection(connection=connection) - await monitoring.after('Accepted') - await connection.listen() + middlewares = [middleware(**data) for middleware, data in config.WS_MIDDLEWARES] + + # Call Middlewares .before() + await self._run_ws_middlewares_before_listen(connection=connection, middlewares=middlewares) - # Call 'After' Middlewares - await self._run_ws_middlewares_after_listen(connection=connection) + # Listen The Connection + await config.WEBSOCKET_CONNECTIONS.listen(connection=connection) - # Done - await monitoring.after('Closed') - return None + # Call Middlewares .after() + middlewares.reverse() + await self._run_ws_middlewares_after_listen(connection=connection, middlewares=middlewares) @classmethod - async def _run_ws_middlewares_before_listen(cls, *, connection) -> bool: - for middleware in config['ws_middlewares']: - try: - connection = await middleware.before(request=connection) - except APIError: - await connection.close() - return False - return True + async def _run_ws_middlewares_before_listen(cls, *, connection, middlewares): + try: + for middleware in middlewares: + new_connection = await middleware.before(request=connection) + if new_connection is None: + logger.critical( + f'Make sure to return the `request` at the end of `{middleware.__class__.__name__}.before()`') + await connection.close() + connection = new_connection + except APIError as e: + connection.log(e.detail) + await connection.close() @classmethod - async def _run_ws_middlewares_after_listen(cls, *, connection): - for middleware in config['reversed_ws_middlewares']: + async def _run_ws_middlewares_after_listen(cls, *, connection, middlewares): + for middleware in middlewares: with contextlib.suppress(APIError): - await middleware.after(response=connection) + connection = await middleware.after(response=connection) + if connection is None: + logger.critical( + f'Make sure to return the `response` at the end of `{middleware.__class__.__name__}.after()`') + break async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> None: + # Monitoring + monitoring = Monitoring() + request = Request(scope=scope, receive=receive, send=send) - # Monitoring - await self.monitoring.before(request=request) + await monitoring.before(request=request) # Read Request Payload await request.read_body() @@ -175,7 +160,7 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N # Find Endpoint _endpoint, found_path = find_endpoint(path=request.path) if _endpoint is None: - return await self._raise(send, status_code=status.HTTP_404_NOT_FOUND) + return await self._raise(send, monitoring=monitoring, status_code=status.HTTP_404_NOT_FOUND) # Check Endpoint Type try: @@ -184,15 +169,20 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N else: endpoint = check_class_type_endpoint(endpoint=_endpoint) except TypeError: - return await self._raise(send, status_code=status.HTTP_501_NOT_IMPLEMENTED) + return await self._raise(send, monitoring=monitoring, status_code=status.HTTP_501_NOT_IMPLEMENTED) # Collect Path Variables request.collect_path_variables(found_path=found_path) - try: # They Both(middleware.before() & _endpoint()) Have The Same Exception (APIException) - # Call 'Before' Middlewares - for middleware in config['http_middlewares']: + middlewares = [middleware(**data) for middleware, data in config.HTTP_MIDDLEWARES] + try: # They Both(middleware.before() & _endpoint()) Have The Same Exception (APIError) + # Call Middlewares .before() + for middleware in middlewares: request = await middleware.before(request=request) + if request is None: + logger.critical( + f'Make sure to return the `request` at the end of `{middleware.__class__.__name__}.before()`') + return await self._raise(send, monitoring=monitoring) # Call Endpoint response = await endpoint(request=request) @@ -201,60 +191,27 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N response = self._handle_exceptions(e) except Exception as e: # noqa: BLE001 - # Every unhandled exception in Panther or code will catch here + # All unhandled exceptions are caught here exception = clean_traceback_message(exception=e) logger.critical(exception) - return await self._raise(send, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + return await self._raise(send, monitoring=monitoring) - # Call 'After' Middleware - for middleware in config['reversed_http_middlewares']: + # Call Middlewares .after() + middlewares.reverse() + for middleware in middlewares: try: response = await middleware.after(response=response) + if response is None: + logger.critical( + f'Make sure to return the `response` at the end of `{middleware.__class__.__name__}.after()`') + return await self._raise(send, monitoring=monitoring) except APIError as e: # noqa: PERF203 response = self._handle_exceptions(e) - await http_response( - send, - status_code=response.status_code, - monitoring=self.monitoring, - headers=response.headers, - body=response.body, - ) - - async def handle_ws_listener(self): - """ - Cause of --preload in gunicorn we have to keep this function here, - and we can't move it to __init__ of Panther - - * Each process should start this listener for itself, - but they have same Manager() - """ - # Start Websocket Listener (Redis/ Queue) - if config['has_ws']: - WebsocketListener().start() - - async def handle_startup(self): - if startup := config['startup'] or self._startup: - if is_function_async(startup): - await startup() - else: - startup() - - def handle_shutdown(self): - if shutdown := config['shutdown'] or self._shutdown: - if is_function_async(shutdown): - try: - asyncio.run(shutdown()) - except ModuleNotFoundError: - # Error: import of asyncio halted; None in sys.modules - # And as I figured it out, it only happens when we are running with - # gunicorn and Uvicorn workers (-k uvicorn.workers.UvicornWorker) - pass - else: - shutdown() + await response.send(send, receive, monitoring=monitoring) def __del__(self): - self.handle_shutdown() + Event.run_shutdowns() @classmethod def _handle_exceptions(cls, e: APIError, /) -> Response: @@ -263,11 +220,11 @@ def _handle_exceptions(cls, e: APIError, /) -> Response: status_code=e.status_code, ) - async def _raise(self, send, *, status_code: int): - await http_response( - send, - headers={'content-type': 'application/json'}, - status_code=status_code, - monitoring=self.monitoring, - exception=True, - ) + @classmethod + async def _raise(cls, send, *, monitoring, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR): + headers = [[b'Content-Type', b'application/json']] + body = json.dumps({'detail': status.status_text[status_code]}) + await monitoring.after(status_code) + await send({'type': 'http.response.start', 'status': status_code, 'headers': headers}) + await send({'type': 'http.response.body', 'body': body, 'more_body': False}) + diff --git a/panther/monitoring.py b/panther/monitoring.py index 730a0a8..3af1089 100644 --- a/panther/monitoring.py +++ b/panther/monitoring.py @@ -3,7 +3,7 @@ from typing import Literal from panther.base_request import BaseRequest - +from panther.configs import config logger = logging.getLogger('monitoring') @@ -13,12 +13,11 @@ class Monitoring: Create Log Message Like Below: date time | method | path | ip:port | response_time [ms, s] | status """ - def __init__(self, is_active: bool, is_ws: bool = False): - self.is_active = is_active + def __init__(self, is_ws: bool = False): self.is_ws = is_ws async def before(self, request: BaseRequest): - if self.is_active: + if config.MONITORING: ip, port = request.client if self.is_ws: @@ -30,7 +29,7 @@ async def before(self, request: BaseRequest): self.start_time = perf_counter() async def after(self, status: int | Literal['Accepted', 'Rejected', 'Closed'], /): - if self.is_active: + if config.MONITORING: response_time = perf_counter() - self.start_time time_unit = ' s' @@ -38,4 +37,8 @@ async def after(self, status: int | Literal['Accepted', 'Rejected', 'Closed'], / response_time = response_time * 1_000 time_unit = 'ms' + elif response_time >= 10: + response_time = response_time / 60 + time_unit = ' m' + logger.info(f'{self.log} | {round(response_time, 4)} {time_unit} | {status}') diff --git a/panther/pagination.py b/panther/pagination.py new file mode 100644 index 0000000..6cb70fd --- /dev/null +++ b/panther/pagination.py @@ -0,0 +1,48 @@ +from panther.db.cursor import Cursor +from pantherdb import Cursor as PantherDBCursor + + +class Pagination: + """ + Request URL: + example.com/users?limit=10&skip=0 + Response Data: + { + 'count': 10, + 'next': '?limit=10&skip=10', + 'previous': None, + results: [...] + } + """ + DEFAULT_LIMIT = 20 + DEFAULT_SKIP = 0 + + def __init__(self, query_params: dict, cursor: Cursor | PantherDBCursor): + self.limit = self.get_limit(query_params=query_params) + self.skip = self.get_skip(query_params=query_params) + self.cursor = cursor + + def get_limit(self, query_params: dict) -> int: + return int(query_params.get('limit', self.DEFAULT_LIMIT)) + + def get_skip(self, query_params: dict) -> int: + return int(query_params.get('skip', self.DEFAULT_SKIP)) + + def build_next_params(self): + next_skip = self.skip + self.limit + return f'?limit={self.limit}&skip={next_skip}' + + def build_previous_params(self): + previous_skip = max(self.skip - self.limit, 0) + return f'?limit={self.limit}&skip={previous_skip}' + + async def paginate(self): + count = await self.cursor.cls.count(self.cursor.filter) + has_next = not bool(self.limit + self.skip >= count) + + return { + 'count': count, + 'next': self.build_next_params() if has_next else None, + 'previous': self.build_previous_params() if self.skip else None, + 'results': self.cursor.skip(skip=self.skip).limit(limit=self.limit) + } diff --git a/panther/panel/apis.py b/panther/panel/apis.py index c30d3c4..f3f067a 100644 --- a/panther/panel/apis.py +++ b/panther/panel/apis.py @@ -20,12 +20,12 @@ async def models_api(): 'name': model.__name__, 'module': model.__module__, 'index': i - } for i, model in enumerate(config['models'])] + } for i, model in enumerate(config.MODELS)] @API(methods=['GET', 'POST']) async def documents_api(request: Request, index: int): - model = config['models'][index] + model = config.MODELS[index] if request.method == 'POST': validated_data = API.validate_input(model=model, request=request) @@ -45,7 +45,7 @@ async def documents_api(request: Request, index: int): @API(methods=['PUT', 'DELETE', 'GET']) async def single_document_api(request: Request, index: int, document_id: int | str): - model = config['models'][index] + model = config.MODELS[index] if document := model.find_one(id=document_id): if request.method == 'PUT': @@ -69,7 +69,7 @@ async def healthcheck_api(): checks = [] # Database - if config['query_engine'].__name__ == 'BaseMongoDBQuery': + if config.QUERY_ENGINE.__name__ == 'BaseMongoDBQuery': with pymongo.timeout(3): try: ping = db.session.command('ping').get('ok') == 1.0 @@ -78,6 +78,6 @@ async def healthcheck_api(): checks.append(False) # Redis if redis.is_connected: - checks.append(redis.ping()) + checks.append(await redis.ping()) return Response(all(checks)) diff --git a/panther/permissions.py b/panther/permissions.py index decfa0f..26f9cbd 100644 --- a/panther/permissions.py +++ b/panther/permissions.py @@ -10,4 +10,4 @@ async def authorization(cls, request: Request) -> bool: class AdminPermission(BasePermission): @classmethod async def authorization(cls, request: Request) -> bool: - return request.user and hasattr(request.user, 'is_admin') and request.user.is_admin + return request.user and getattr(request.user, 'is_admin', False) diff --git a/panther/request.py b/panther/request.py index e980e42..921cabe 100644 --- a/panther/request.py +++ b/panther/request.py @@ -1,5 +1,5 @@ import logging -from typing import Literal +from typing import Literal, Callable import orjson as json @@ -10,6 +10,11 @@ class Request(BaseRequest): + def __init__(self, scope: dict, receive: Callable, send: Callable): + self._data = ... + self.validated_data = None # It's been set in API.validate_input() + super().__init__(scope=scope, receive=receive, send=send) + @property def method(self) -> Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']: return self.scope['method'] @@ -29,17 +34,6 @@ def data(self) -> dict | bytes: self._data = self.__body return self._data - def set_validated_data(self, validated_data) -> None: - self._validated_data = validated_data - - @property - def validated_data(self): - """ - Return The Validated Data - It has been set on API.validate_input() while request is happening. - """ - return getattr(self, '_validated_data', None) - async def read_body(self) -> None: """Read the entire body from an incoming ASGI message.""" self.__body = b'' diff --git a/panther/response.py b/panther/response.py index 69b7bbf..9943223 100644 --- a/panther/response.py +++ b/panther/response.py @@ -1,13 +1,20 @@ +import asyncio from types import NoneType +from typing import Generator, AsyncGenerator import orjson as json from pydantic import BaseModel as PydanticBaseModel from pydantic._internal._model_construction import ModelMetaclass +from panther import status +from panther._utils import to_async_generator from panther.db.cursor import Cursor +from pantherdb import Cursor as PantherDBCursor +from panther.monitoring import Monitoring -ResponseDataTypes = list | tuple | set | dict | int | float | str | bool | bytes | NoneType | ModelMetaclass -IterableDataTypes = list | tuple | set | Cursor +ResponseDataTypes = list | tuple | set | Cursor | PantherDBCursor | dict | int | float | str | bool | bytes | NoneType | ModelMetaclass +IterableDataTypes = list | tuple | set | Cursor | PantherDBCursor +StreamingDataTypes = Generator | AsyncGenerator class Response: @@ -17,16 +24,16 @@ def __init__( self, data: ResponseDataTypes = None, headers: dict | None = None, - status_code: int = 200, + status_code: int = status.HTTP_200_OK, ): """ - :param data: should be int | float | dict | list | tuple | set | str | bool | bytes | NoneType - or instance of Pydantic.BaseModel + :param data: should be an instance of ResponseDataTypes + :param headers: should be dict of headers :param status_code: should be int """ - self.data = self._clean_data_type(data) - self._check_status_code(status_code) - self._headers = headers + self.headers = headers or {} + self.data = self.prepare_data(data=data) + self.status_code = self.check_status_code(status_code=status_code) @property def body(self) -> bytes: @@ -40,54 +47,73 @@ def body(self) -> bytes: @property def headers(self) -> dict: return { - 'content-type': self.content_type, - 'content-length': len(self.body), - 'access-control-allow-origin': '*', - } | (self._headers or {}) + 'Content-Type': self.content_type, + 'Content-Length': len(self.body), + 'Access-Control-Allow-Origin': '*', + } | self._headers - def _clean_data_type(self, data: any): - """Make sure the response data is only ResponseDataTypes or Iterable of ResponseDataTypes""" - if issubclass(type(data), PydanticBaseModel): - return data.model_dump() + @property + def bytes_headers(self) -> list[list[bytes]]: + return [[k.encode(), str(v).encode()] for k, v in (self.headers or {}).items()] - elif isinstance(data, IterableDataTypes): - return [self._clean_data_type(d) for d in data] + @headers.setter + def headers(self, headers: dict): + self._headers = headers + + def prepare_data(self, data: any): + """Make sure the response data is only ResponseDataTypes or Iterable of ResponseDataTypes""" + if isinstance(data, (int | float | str | bool | bytes | NoneType)): + return data elif isinstance(data, dict): - return {key: self._clean_data_type(value) for key, value in data.items()} + return {key: self.prepare_data(value) for key, value in data.items()} - elif isinstance(data, (int | float | str | bool | bytes | NoneType)): - return data + elif issubclass(type(data), PydanticBaseModel): + return data.model_dump() + + elif isinstance(data, IterableDataTypes): + return [self.prepare_data(d) for d in data] else: msg = f'Invalid Response Type: {type(data)}' raise TypeError(msg) - def _check_status_code(self, status_code: any): + @classmethod + def check_status_code(cls, status_code: any): if not isinstance(status_code, int): - error = f'Response "status_code" Should Be "int". ("{status_code}" is {type(status_code)})' + error = f'Response `status_code` Should Be `int`. (`{status_code}` is {type(status_code)})' raise TypeError(error) - - self.status_code = status_code - - def _clean_data_with_output_model(self, output_model: ModelMetaclass | None): - if self.data and output_model: - self.data = self._serialize_with_output_model(self.data, output_model=output_model) + return status_code @classmethod - def _serialize_with_output_model(cls, data: any, /, output_model: ModelMetaclass): + def apply_output_model(cls, data: any, /, output_model: ModelMetaclass): + """This method is called in API.__call__""" # Dict if isinstance(data, dict): + for field_name, field in output_model.model_fields.items(): + if field.validation_alias and field_name in data: + data[field.validation_alias] = data.pop(field_name) return output_model(**data).model_dump() # Iterable if isinstance(data, IterableDataTypes): - return [cls._serialize_with_output_model(d, output_model=output_model) for d in data] + return [cls.apply_output_model(d, output_model=output_model) for d in data] # Str | Bool | Bytes msg = 'Type of Response data is not match with `output_model`.\n*hint: You may want to remove `output_model`' raise TypeError(msg) + async def send_headers(self, send, /): + await send({'type': 'http.response.start', 'status': self.status_code, 'headers': self.bytes_headers}) + + async def send_body(self, send, receive, /): + await send({'type': 'http.response.body', 'body': self.body, 'more_body': False}) + + async def send(self, send, receive, /, monitoring: Monitoring): + await self.send_headers(send) + await self.send_body(send, receive) + await monitoring.after(self.status_code) + def __str__(self): if len(data := str(self.data)) > 30: data = f'{data:.27}...' @@ -96,6 +122,57 @@ def __str__(self): __repr__ = __str__ +class StreamingResponse(Response): + content_type = 'application/octet-stream' + + def __init__(self, *args, **kwargs): + self.connection_closed = False + super().__init__(*args, **kwargs) + + async def listen_to_disconnection(self, receive): + message = await receive() + if message['type'] == 'http.disconnect': + self.connection_closed = True + + def prepare_data(self, data: any) -> AsyncGenerator: + if isinstance(data, AsyncGenerator): + return data + elif isinstance(data, Generator): + return to_async_generator(data) + msg = f'Invalid Response Type: {type(data)}' + raise TypeError(msg) + + @property + def headers(self) -> dict: + return { + 'Content-Type': self.content_type, + 'Access-Control-Allow-Origin': '*', + } | self._headers + + @headers.setter + def headers(self, headers: dict): + self._headers = headers + + @property + async def body(self) -> AsyncGenerator: + async for chunk in self.data: + if isinstance(chunk, bytes): + yield chunk + elif chunk is None: + yield b'' + else: + yield json.dumps(chunk) + + async def send_body(self, send, receive, /): + asyncio.create_task(self.listen_to_disconnection(receive)) + async for chunk in self.body: + if self.connection_closed: + break + await send({'type': 'http.response.body', 'body': chunk, 'more_body': True}) + else: + await send({'type': 'http.response.body', 'body': b'', 'more_body': False}) + + class HTMLResponse(Response): content_type = 'text/html; charset=utf-8' diff --git a/panther/routings.py b/panther/routings.py index 5471613..2543926 100644 --- a/panther/routings.py +++ b/panther/routings.py @@ -1,4 +1,3 @@ -import logging import re from collections import Counter from collections.abc import Callable, Mapping, MutableMapping @@ -6,9 +5,7 @@ from functools import partial, reduce from panther.configs import config - - -logger = logging.getLogger('panther') +from panther.exceptions import PantherError def flatten_urls(urls: dict) -> dict: @@ -28,20 +25,17 @@ def _flattening_urls(data: dict | Callable, url: str = ''): url = url.removeprefix('/') # Collect it, if it doesn't have problem - if _is_url_endpoint_valid(url=url, endpoint=data): - yield url, data + _is_url_endpoint_valid(url=url, endpoint=data) + yield url, data -def _is_url_endpoint_valid(url: str, endpoint: Callable) -> bool: +def _is_url_endpoint_valid(url: str, endpoint: Callable): if endpoint is ...: - logger.error(f"URL Can't Point To Ellipsis. ('{url}' -> ...)") + raise PantherError(f"URL Can't Point To Ellipsis. ('{url}' -> ...)") elif endpoint is None: - logger.error(f"URL Can't Point To None. ('{url}' -> None)") + raise PantherError(f"URL Can't Point To None. ('{url}' -> None)") elif url and not re.match(r'^[a-zA-Z<>0-9_/-]+$', url): - logger.error(f"URL Is Not Valid. --> '{url}'") - else: - return True - return False + raise PantherError(f"URL Is Not Valid. --> '{url}'") def finalize_urls(urls: dict) -> dict: @@ -60,7 +54,33 @@ def finalize_urls(urls: dict) -> dict: else: path = {single_path: path or endpoint} urls_list.append(path) - return _merge(*urls_list) if urls_list else {} + final_urls = _merge(*urls_list) if urls_list else {} + check_urls_path_variables(final_urls) + return final_urls + + +def check_urls_path_variables(urls: dict, path: str = '', ) -> None: + middle_route_error = [] + last_route_error = [] + for key, value in urls.items(): + new_path = f'{path}/{key}' + + if isinstance(value, dict): + if key.startswith('<'): + middle_route_error.append(new_path) + check_urls_path_variables(value, path=new_path) + elif key.startswith('<'): + last_route_error.append(new_path) + + if len(middle_route_error) > 1: + msg = '\n\t- ' + '\n\t- '.join(e for e in middle_route_error) + raise PantherError( + f"URLs can't have same-level path variables that point to a dict: {msg}") + + if len(last_route_error) > 1: + msg = '\n\t- ' + '\n\t- '.join(e for e in last_route_error) + raise PantherError( + f"URLs can't have same-level path variables that point to an endpoint: {msg}") def _merge(destination: MutableMapping, *sources) -> MutableMapping: @@ -106,7 +126,7 @@ def _is_recursive_merge(a, b): def find_endpoint(path: str) -> tuple[Callable | None, str]: - urls = config['urls'] + urls = config.URLS # 'user/list/?name=ali' --> 'user/list/' --> 'user/list' --> ['user', 'list'] parts = path.split('?')[0].strip('/').split('/') diff --git a/panther/serializer.py b/panther/serializer.py index 2eea063..c2f538e 100644 --- a/panther/serializer.py +++ b/panther/serializer.py @@ -1,14 +1,16 @@ import typing +from typing import TypeVar, Type -from pydantic import create_model, BaseModel -from pydantic.fields import FieldInfo +from pydantic import create_model, BaseModel, ConfigDict +from pydantic.fields import FieldInfo, Field from pydantic_core._pydantic_core import PydanticUndefined from panther.db import Model +from panther.request import Request class MetaModelSerializer: - KNOWN_CONFIGS = ['model', 'fields', 'required_fields'] + KNOWN_CONFIGS = ['model', 'fields', 'exclude', 'required_fields', 'optional_fields'] def __new__( cls, @@ -18,6 +20,9 @@ def __new__( **kwargs ): if cls_name == 'ModelSerializer': + # Put `model` and `request` to the main class with `create_model()` + namespace['__annotations__'].pop('model') + namespace['__annotations__'].pop('request') cls.model_serializer = type(cls_name, (), namespace) return super().__new__(cls) @@ -38,7 +43,8 @@ def __new__( __module__=namespace['__module__'], __validators__=namespace, __base__=(cls.model_serializer, BaseModel), - model=(typing.ClassVar, config.model), + model=(typing.ClassVar[type[BaseModel]], config.model), + request=(Request, Field(None, exclude=True)), **field_definitions ) @@ -67,22 +73,70 @@ def check_config(cls, cls_name: str, namespace: dict) -> None: raise AttributeError(msg) from None # Check `fields` - if (fields := getattr(config, 'fields', None)) is None: + if not hasattr(config, 'fields'): msg = f'`{cls_name}.Config.fields` is required.' raise AttributeError(msg) from None - for field_name in fields: - if field_name not in model.model_fields: - msg = f'`{cls_name}.Config.fields.{field_name}` is not valid.' - raise AttributeError(msg) from None + if config.fields != '*': + for field_name in config.fields: + if field_name == '*': + msg = f"`{cls_name}.Config.fields.{field_name}` is not valid. Did you mean `fields = '*'`" + raise AttributeError(msg) from None + + if field_name not in model.model_fields: + msg = f'`{cls_name}.Config.fields.{field_name}` is not in `{model.__name__}.model_fields`' + raise AttributeError(msg) from None # Check `required_fields` if not hasattr(config, 'required_fields'): config.required_fields = [] - for required in config.required_fields: - if required not in config.fields: - msg = f'`{cls_name}.Config.required_fields.{required}` should be in `Config.fields` too.' + if config.required_fields != '*': + for required in config.required_fields: + if required not in config.fields: + msg = f'`{cls_name}.Config.required_fields.{required}` should be in `Config.fields` too.' + raise AttributeError(msg) from None + + # Check `optional_fields` + if not hasattr(config, 'optional_fields'): + config.optional_fields = [] + + if config.optional_fields != '*': + for optional in config.optional_fields: + if optional not in config.fields: + msg = f'`{cls_name}.Config.optional_fields.{optional}` should be in `Config.fields` too.' + raise AttributeError(msg) from None + + # Check `required_fields` and `optional_fields` together + if ( + (config.optional_fields == '*' and config.required_fields != []) or + (config.required_fields == '*' and config.optional_fields != []) + ): + msg = ( + f"`{cls_name}.Config.optional_fields` and " + f"`{cls_name}.Config.required_fields` can't include same fields at the same time" + ) + raise AttributeError(msg) from None + for optional in config.optional_fields: + for required in config.required_fields: + if optional == required: + msg = ( + f"`{optional}` can't be in `{cls_name}.Config.optional_fields` and " + f"`{cls_name}.Config.required_fields` at the same time" + ) + raise AttributeError(msg) from None + + # Check `exclude` + if not hasattr(config, 'exclude'): + config.exclude = [] + + for field_name in config.exclude: + if field_name not in model.model_fields: + msg = f'`{cls_name}.Config.exclude.{field_name}` is not valid.' + raise AttributeError(msg) from None + + if config.fields != '*' and field_name not in config.fields: + msg = f'`{cls_name}.Config.exclude.{field_name}` is not defined in `Config.fields`.' raise AttributeError(msg) from None @classmethod @@ -90,15 +144,35 @@ def collect_fields(cls, config: typing.Callable, namespace: dict) -> dict: field_definitions = {} # Define `fields` - for field_name in config.fields: - field_definitions[field_name] = ( - config.model.model_fields[field_name].annotation, - config.model.model_fields[field_name] - ) + if config.fields == '*': + for field_name, field in config.model.model_fields.items(): + field_definitions[field_name] = (field.annotation, field) + else: + for field_name in config.fields: + field_definitions[field_name] = ( + config.model.model_fields[field_name].annotation, + config.model.model_fields[field_name] + ) + + # Apply `exclude` + for field_name in config.exclude: + del field_definitions[field_name] # Apply `required_fields` - for required in config.required_fields: - field_definitions[required][1].default = PydanticUndefined + if config.required_fields == '*': + for value in field_definitions.values(): + value[1].default = PydanticUndefined + else: + for field_name in config.required_fields: + field_definitions[field_name][1].default = PydanticUndefined + + # Apply `optional_fields` + if config.optional_fields == '*': + for value in field_definitions.values(): + value[1].default = value[0]() + else: + for field_name in config.optional_fields: + field_definitions[field_name][1].default = field_definitions[field_name][0]() # Collect and Override `Class Fields` for key, value in namespace.pop('__annotations__', {}).items(): @@ -112,9 +186,43 @@ def collect_model_config(cls, config: typing.Callable, namespace: dict) -> dict: return { attr: getattr(config, attr) for attr in dir(config) if not attr.startswith('__') and attr not in cls.KNOWN_CONFIGS - } | namespace.pop('model_config', {}) + } | namespace.pop('model_config', {}) | {'arbitrary_types_allowed': True} class ModelSerializer(metaclass=MetaModelSerializer): - def create(self) -> type[Model]: - return self.model.insert_one(self.model_dump()) + """ + Doc: + https://pantherpy.github.io/serializer/#style-2-model-serializer + Example: + class PersonSerializer(ModelSerializer): + class Meta: + model = Person + fields = '*' + exclude = ['created_date'] # Optional + required_fields = ['first_name', 'last_name'] # Optional + optional_fields = ['age'] # Optional + """ + model: type[BaseModel] + request: Request + + async def create(self, validated_data: dict) -> Model: + """ + validated_data = ModelSerializer.model_dump() + """ + return await self.model.insert_one(validated_data) + + async def update(self, instance: Model, validated_data: dict) -> Model: + """ + instance = UpdateAPI.object() + validated_data = ModelSerializer.model_dump() + """ + await instance.update(validated_data) + return instance + + async def partial_update(self, instance: Model, validated_data: dict) -> Model: + """ + instance = UpdateAPI.object() + validated_data = ModelSerializer.model_dump(exclude_none=True) + """ + await instance.update(validated_data) + return instance diff --git a/panther/test.py b/panther/test.py index 554ae6b..7e45052 100644 --- a/panther/test.py +++ b/panther/test.py @@ -4,7 +4,7 @@ import orjson as json -from panther.response import Response +from panther.response import Response, HTMLResponse, PlainTextResponse, StreamingResponse __all__ = ('APIClient', 'WebsocketClient') @@ -12,12 +12,13 @@ class RequestClient: def __init__(self, app: Callable): self.app = app + self.response = b'' async def send(self, data: dict): if data['type'] == 'http.response.start': self.header = data else: - self.response = data + self.response += data['body'] async def receive(self): return { @@ -54,11 +55,22 @@ async def request( receive=self.receive, send=self.send, ) - return Response( - data=json.loads(self.response.get('body', b'null')), - status_code=self.header['status'], - headers=self.header['headers'], - ) + response_headers = {key.decode(): value.decode() for key, value in self.header['headers']} + if response_headers['Content-Type'] == 'text/html; charset=utf-8': + data = self.response.decode() + return HTMLResponse(data=data, status_code=self.header['status'], headers=response_headers) + + elif response_headers['Content-Type'] == 'text/plain; charset=utf-8': + data = self.response.decode() + return PlainTextResponse(data=data, status_code=self.header['status'], headers=response_headers) + + elif response_headers['Content-Type'] == 'application/octet-stream': + data = self.response.decode() + return PlainTextResponse(data=data, status_code=self.header['status'], headers=response_headers) + + else: + data = json.loads(self.response or b'null') + return Response(data=data, status_code=self.header['status'], headers=response_headers) class APIClient: diff --git a/panther/utils.py b/panther/utils.py index 0d19e82..337d8a5 100644 --- a/panther/utils.py +++ b/panther/utils.py @@ -7,6 +7,10 @@ from pathlib import Path from typing import ClassVar +import pytz + +from panther.configs import config + logger = logging.getLogger('panther') URANDOM_SIZE = 16 @@ -25,8 +29,7 @@ def load_env(env_file: str | Path, /) -> dict[str, str]: variables = {} if env_file is None or not Path(env_file).is_file(): - logger.critical(f'"{env_file}" is not valid file for load_env()') - return variables + raise ValueError(f'"{env_file}" is not a file.') from None with open(env_file) as file: for line in file.readlines(): @@ -80,26 +83,8 @@ def scrypt(password: str, salt: bytes, digest: bool = False) -> str | bytes: dklen=dk_len ) if digest: - return hashlib.md5(derived_key).hexdigest() - else: - return derived_key - - -def encrypt_password(password: str) -> str: - salt = os.urandom(URANDOM_SIZE) - derived_key = scrypt(password=password, salt=salt, digest=True) - - return f'{salt.hex()}{derived_key}' - - -def check_password(stored_password: str, new_password: str) -> bool: - size = URANDOM_SIZE * 2 - salt = stored_password[:size] - stored_hash = stored_password[size:] - derived_key = scrypt(password=new_password, salt=bytes.fromhex(salt), digest=True) - - return derived_key == stored_hash + return derived_key class ULID: @@ -121,3 +106,8 @@ def _generate(cls, bits: str) -> str: cls.crockford_base32_characters[int(bits[i: i + 5], base=2)] for i in range(0, 130, 5) ) + + +def timezone_now(): + tz = pytz.timezone(config.TIMEZONE) + return datetime.now(tz=tz) diff --git a/panther/websocket.py b/panther/websocket.py index f32784a..37f0972 100644 --- a/panther/websocket.py +++ b/panther/websocket.py @@ -22,16 +22,16 @@ async def receive(self, data: str | bytes): async def send(self, data: any = None): """ - We are using this method to send message to the client, + Send message to the client, You may want to override it with your custom scenario. (not recommended) """ return await super().send(data=data) async def send_message_to_websocket(connection_id: str, data: any): - config.websocket_connections.publish(connection_id=connection_id, action='send', data=data) + await config.WEBSOCKET_CONNECTIONS.publish(connection_id=connection_id, action='send', data=data) async def close_websocket_connection(connection_id: str, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = ''): data = {'code': code, 'reason': reason} - config.websocket_connections.publish(connection_id=connection_id, action='close', data=data) + await config.WEBSOCKET_CONNECTIONS.publish(connection_id=connection_id, action='close', data=data) diff --git a/setup.py b/setup.py index 003e262..8fce320 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ def panther_version() -> str: 'python-jose~=3.3', 'websockets~=12.0', 'cryptography~=42.0', + 'watchfiles~=0.21.0', ], } @@ -50,11 +51,11 @@ def panther_version() -> str: }, install_requires=[ 'httptools~=0.6', - 'pantherdb==1.4.0', + 'pantherdb==2.0.0', 'pydantic~=2.6', 'rich~=13.7', 'uvicorn~=0.27', - 'watchfiles~=0.21.0', + 'pytz~=2024.1', ], extras_require=EXTRAS_REQUIRE, ) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 821fe6c..b223af3 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -69,15 +69,15 @@ async def test_user_without_auth(self): assert res.data is None async def test_user_auth_required_without_auth_class(self): - auth_config = config['authentication'] - config['authentication'] = None + auth_config = config.AUTHENTICATION + config.AUTHENTICATION = None with self.assertLogs(level='CRITICAL') as captured: res = await self.client.get('auth-required') assert len(captured.records) == 1 assert captured.records[0].getMessage() == '"AUTHENTICATION" has not been set in configs' assert res.status_code == 500 assert res.data['detail'] == 'Internal Server Error' - config['authentication'] = auth_config + config.AUTHENTICATION = auth_config async def test_user_auth_required_without_token(self): with self.assertLogs(level='ERROR') as captured: diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index 7a8bf30..4530064 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -9,11 +9,11 @@ class TestBackgroundTasks(TestCase): def setUp(self): self.obj = BackgroundTasks() - config['background_tasks'] = True + config.BACKGROUND_TASKS = True def tearDown(self): del Singleton._instances[BackgroundTasks] - config['background_tasks'] = False + config.BACKGROUND_TASKS = False def test_background_tasks_singleton(self): new_obj = BackgroundTasks() @@ -60,7 +60,7 @@ def func(_numbers): _numbers.append(1) with self.assertLogs() as captured: self.obj.add_task(task) assert len(captured.records) == 1 - assert captured.records[0].getMessage() == 'Task will be ignored, `BACKGROUND_TASKS` is not True in `core/configs.py`' + assert captured.records[0].getMessage() == 'Task will be ignored, `BACKGROUND_TASKS` is not True in `configs`' assert self.obj.tasks == [] def test_add_task_with_args(self): diff --git a/tests/test_caching.py b/tests/test_caching.py index efbea48..25efa68 100644 --- a/tests/test_caching.py +++ b/tests/test_caching.py @@ -1,4 +1,5 @@ import time +import asyncio from datetime import timedelta from unittest import IsolatedAsyncioTestCase @@ -10,19 +11,19 @@ @API() async def without_cache_api(): - time.sleep(0.01) + await asyncio.sleep(0.01) return {'detail': time.time()} @API(cache=True) async def with_cache_api(): - time.sleep(0.01) + await asyncio.sleep(0.01) return {'detail': time.time()} @API(cache=True, cache_exp_time=timedelta(seconds=5)) async def expired_cache_api(): - time.sleep(0.01) + await asyncio.sleep(0.01) return {'detail': time.time()} @@ -65,7 +66,7 @@ async def test_with_cache_5second_exp_time(self): # Check Logs assert len(captured.records) == 1 - assert captured.records[0].getMessage() == '"cache_exp_time" is not very accurate when redis is not connected.' + assert captured.records[0].getMessage() == '`cache_exp_time` is not very accurate when `redis` is not connected.' # Second Request res2 = await self.client.get('with-expired-cache') @@ -74,7 +75,7 @@ async def test_with_cache_5second_exp_time(self): # Response should be cached assert check_two_dicts(res1.data, res2.data) is True - time.sleep(5) + await asyncio.sleep(5) # Third Request res3 = await self.client.get('with-expired-cache') diff --git a/tests/test_cli.py b/tests/test_cli.py index 955fd5f..2e7212a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -23,7 +23,7 @@ def interactive_cli_1_mock_responses(index=None): global interactive_cli_1_index if index is None: index = interactive_cli_1_index - responses = ['project1', 'project1_dir', 'n', '0', 'y', 'y', 'y', 'y', 'y'] + responses = ['project1', 'project1_dir', 'n', '0', 'y', 'n', 'y', 'y', 'y', 'y'] response = responses[index] interactive_cli_1_index += 1 return response @@ -33,7 +33,7 @@ def interactive_cli_2_mock_responses(index=None): global interactive_cli_2_index if index is None: index = interactive_cli_2_index - responses = ['project2', 'project2_dir', 'y', '0', 'y', 'y', 'y', 'y', 'y'] + responses = ['project2', 'project2_dir', 'y', '0', 'y', 'n', 'y', 'y', 'y', 'y'] response = responses[index] interactive_cli_2_index += 1 return response diff --git a/tests/test_database.py b/tests/test_database.py index eed5750..560a596 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -8,7 +8,8 @@ from panther import Panther from panther.db import Model from panther.db.connections import db -from panther.db.cursor import Cursor +from panther.db.cursor import Cursor as MongoCursor +from pantherdb import Cursor as PantherDBCursor f = faker.Faker() @@ -33,10 +34,24 @@ async def test_insert_one(self): assert book.name == name assert book.pages_count == pages_count - async def test_insert_many(self): + async def test_insert_many_with_insert_one(self): insert_count = await self._insert_many() assert insert_count > 1 + async def test_insert_many(self): + insert_count = random.randint(2, 10) + data = [ + {'name': f.name(), 'author': f.name(), 'pages_count': random.randint(0, 10)} + for _ in range(insert_count) + ] + books = await Book.insert_many(data) + inserted_books = [ + {'_id': book._id, 'name': book.name, 'author': book.author, 'pages_count': book.pages_count} + for book in books + ] + assert len(books) == insert_count + assert data == inserted_books + # # # FindOne async def test_find_one_not_found(self): # Insert Many @@ -157,9 +172,9 @@ async def test_find(self): _len = sum(1 for _ in books) if self.__class__.__name__ == 'TestMongoDB': - assert isinstance(books, Cursor) + assert isinstance(books, MongoCursor) else: - assert isinstance(books, list) + assert isinstance(books, PantherDBCursor) assert _len == insert_count for book in books: assert isinstance(book, Book) @@ -174,9 +189,9 @@ async def test_find_not_found(self): _len = sum(1 for _ in books) if self.__class__.__name__ == 'TestMongoDB': - assert isinstance(books, Cursor) + assert isinstance(books, MongoCursor) else: - assert isinstance(books, list) + assert isinstance(books, PantherDBCursor) assert _len == 0 async def test_find_without_filter(self): @@ -188,9 +203,9 @@ async def test_find_without_filter(self): _len = sum(1 for _ in books) if self.__class__.__name__ == 'TestMongoDB': - assert isinstance(books, Cursor) + assert isinstance(books, MongoCursor) else: - assert isinstance(books, list) + assert isinstance(books, PantherDBCursor) assert _len == insert_count for book in books: assert isinstance(book, Book) @@ -204,9 +219,9 @@ async def test_all(self): _len = sum(1 for _ in books) if self.__class__.__name__ == 'TestMongoDB': - assert isinstance(books, Cursor) + assert isinstance(books, MongoCursor) else: - assert isinstance(books, list) + assert isinstance(books, PantherDBCursor) assert _len == insert_count for book in books: @@ -418,9 +433,9 @@ async def test_update_many(self): _len = sum(1 for _ in books) if self.__class__.__name__ == 'TestMongoDB': - assert isinstance(books, Cursor) + assert isinstance(books, MongoCursor) else: - assert isinstance(books, list) + assert isinstance(books, PantherDBCursor) assert _len == updated_count == insert_count for book in books: assert book.author == author diff --git a/tests/test_events.py b/tests/test_events.py new file mode 100644 index 0000000..5eff523 --- /dev/null +++ b/tests/test_events.py @@ -0,0 +1,123 @@ +import logging +from unittest import IsolatedAsyncioTestCase + +from panther.configs import config +from panther.events import Event + +logger = logging.getLogger('panther') + + +class TestEvents(IsolatedAsyncioTestCase): + def tearDown(self): + config.refresh() + + async def test_async_startup(self): + assert len(config.STARTUPS) == 0 + + async def startup_event(): + logger.info('This Is Startup.') + + Event.startup(startup_event) + + assert len(config.STARTUPS) == 1 + assert config.STARTUPS[0] == startup_event + + with self.assertLogs(level='INFO') as capture: + await Event.run_startups() + + assert len(capture.records) == 1 + assert capture.records[0].getMessage() == 'This Is Startup.' + + async def test_sync_startup(self): + assert len(config.STARTUPS) == 0 + + def startup_event(): + logger.info('This Is Startup.') + + Event.startup(startup_event) + + assert len(config.STARTUPS) == 1 + assert config.STARTUPS[0] == startup_event + + with self.assertLogs(level='INFO') as capture: + await Event.run_startups() + + assert len(capture.records) == 1 + assert capture.records[0].getMessage() == 'This Is Startup.' + + async def test_startup(self): + assert len(config.STARTUPS) == 0 + + def startup_event1(): + logger.info('This Is Startup1.') + + async def startup_event2(): + logger.info('This Is Startup2.') + + Event.startup(startup_event1) + Event.startup(startup_event2) + + assert len(config.STARTUPS) == 2 + assert config.STARTUPS[0] == startup_event1 + assert config.STARTUPS[1] == startup_event2 + + with self.assertLogs(level='INFO') as capture: + await Event.run_startups() + + assert len(capture.records) == 2 + assert capture.records[0].getMessage() == 'This Is Startup1.' + assert capture.records[1].getMessage() == 'This Is Startup2.' + + async def test_sync_shutdown(self): + assert len(config.SHUTDOWNS) == 0 + + def shutdown_event(): + logger.info('This Is Shutdown.') + + Event.shutdown(shutdown_event) + + assert len(config.SHUTDOWNS) == 1 + assert config.SHUTDOWNS[0] == shutdown_event + + with self.assertLogs(level='INFO') as capture: + Event.run_shutdowns() + + assert len(capture.records) == 1 + assert capture.records[0].getMessage() == 'This Is Shutdown.' + + async def shutdown_event(self): + logger.info('This Is Shutdown.') + + def test_async_shutdown(self): + assert len(config.SHUTDOWNS) == 0 + + Event.shutdown(self.shutdown_event) + + assert len(config.SHUTDOWNS) == 1 + assert config.SHUTDOWNS[0] == self.shutdown_event + + with self.assertLogs(level='INFO') as capture: + Event.run_shutdowns() + + assert len(capture.records) == 1 + assert capture.records[0].getMessage() == 'This Is Shutdown.' + + def test_shutdown(self): + assert len(config.SHUTDOWNS) == 0 + + def shutdown_event_sync(): + logger.info('This Is Sync Shutdown.') + + Event.shutdown(self.shutdown_event) + Event.shutdown(shutdown_event_sync) + + assert len(config.SHUTDOWNS) == 2 + assert config.SHUTDOWNS[0] == self.shutdown_event + assert config.SHUTDOWNS[1] == shutdown_event_sync + + with self.assertLogs(level='INFO') as capture: + Event.run_shutdowns() + + assert len(capture.records) == 2 + assert capture.records[0].getMessage() == 'This Is Shutdown.' + assert capture.records[1].getMessage() == 'This Is Sync Shutdown.' diff --git a/tests/test_generics.py b/tests/test_generics.py new file mode 100644 index 0000000..97a2a6a --- /dev/null +++ b/tests/test_generics.py @@ -0,0 +1,226 @@ +from pathlib import Path +from unittest import IsolatedAsyncioTestCase + +from panther import Panther +from panther.db import Model +from panther.generics import RetrieveAPI, ListAPI, UpdateAPI, DeleteAPI, CreateAPI +from panther.pagination import Pagination +from panther.request import Request +from panther.serializer import ModelSerializer +from panther.test import APIClient + + +class User(Model): + name: str + + +class Person(User): + age: int + + +class RetrieveAPITest(RetrieveAPI): + async def object(self, request: Request, **kwargs) -> Model: + return await User.find_one(id=kwargs['id']) + + +class ListAPITest(ListAPI): + async def objects(self, request: Request, **kwargs): + return await User.find() + + +class FullListAPITest(ListAPI): + sort_fields = ['name', 'age'] + search_fields = ['name'] + filter_fields = ['name', 'age'] + pagination = Pagination + + async def objects(self, request: Request, **kwargs): + return await Person.find() + + +class UserSerializer(ModelSerializer): + class Config: + model = User + fields = '*' + + +class UpdateAPITest(UpdateAPI): + input_model = UserSerializer + + async def object(self, request: Request, **kwargs) -> Model: + return await User.find_one(id=kwargs['id']) + + +class CreateAPITest(CreateAPI): + input_model = UserSerializer + + +class DeleteAPITest(DeleteAPI): + async def object(self, request: Request, **kwargs) -> Model: + return await User.find_one(id=kwargs['id']) + + +urls = { + 'retrieve/<id>': RetrieveAPITest, + 'list': ListAPITest, + 'full-list': FullListAPITest, + 'update/<id>': UpdateAPITest, + 'create': CreateAPITest, + 'delete/<id>': DeleteAPITest, +} + + +class TestGeneric(IsolatedAsyncioTestCase): + DB_PATH = 'test.pdb' + + @classmethod + def setUpClass(cls) -> None: + global DATABASE + DATABASE = { + 'engine': { + 'class': 'panther.db.connections.PantherDBConnection', + 'path': cls.DB_PATH + }, + } + app = Panther(__name__, configs=__name__, urls=urls) + cls.client = APIClient(app=app) + + def tearDown(self) -> None: + Path(self.DB_PATH).unlink() + + async def test_retrieve(self): + user = await User.insert_one(name='Ali') + res = await self.client.get(f'retrieve/{user.id}') + assert res.status_code == 200 + assert res.data == {'id': user.id, 'name': user.name} + + async def test_list(self): + users = await User.insert_many([{'name': 'Ali'}, {'name': 'Hamed'}]) + res = await self.client.get('list') + assert res.status_code == 200 + assert res.data == [{'id': u.id, 'name': u.name} for u in users] + + async def test_list_features(self): + await Person.insert_many([ + {'name': 'Ali', 'age': 0}, + {'name': 'Ali', 'age': 1}, + {'name': 'Saba', 'age': 0}, + {'name': 'Saba', 'age': 1}, + ]) + res = await self.client.get('full-list') + assert res.status_code == 200 + assert set(res.data.keys()) == {'results', 'count', 'previous', 'next'} + + # Normal + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Ali', 'age': 0}, + {'name': 'Ali', 'age': 1}, + {'name': 'Saba', 'age': 0}, + {'name': 'Saba', 'age': 1}, + ] + + # Sort 1 + res = await self.client.get('full-list', query_params={'sort': '-name'}) + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Saba', 'age': 0}, + {'name': 'Saba', 'age': 1}, + {'name': 'Ali', 'age': 0}, + {'name': 'Ali', 'age': 1}, + ] + + # Sort 2 + res = await self.client.get('full-list', query_params={'sort': '-name,-age'}) + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Saba', 'age': 1}, + {'name': 'Saba', 'age': 0}, + {'name': 'Ali', 'age': 1}, + {'name': 'Ali', 'age': 0}, + ] + + # Sort 3 + res = await self.client.get('full-list', query_params={'sort': 'name,-age'}) + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Ali', 'age': 1}, + {'name': 'Ali', 'age': 0}, + {'name': 'Saba', 'age': 1}, + {'name': 'Saba', 'age': 0}, + ] + + # Filter 1 + res = await self.client.get('full-list', query_params={'sort': 'name,-age', 'name': 'Ali'}) + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Ali', 'age': 1}, + {'name': 'Ali', 'age': 0}, + ] + + # Filter 2 + res = await self.client.get('full-list', query_params={'sort': 'name,-age', 'name': 'Alex'}) + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [] + + # Search + res = await self.client.get('full-list', query_params={'sort': 'name,-age', 'search': 'Ali'}) + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Ali', 'age': 1}, + {'name': 'Ali', 'age': 0}, + ] + + # Pagination 1 + res = await self.client.get('full-list', query_params={'sort': 'name,-age'}) + assert res.data['previous'] is None + assert res.data['next'] is None + assert res.data['count'] == 4 + + # Pagination 2 + res = await self.client.get('full-list', query_params={'limit': 2}) + assert res.data['previous'] is None + assert res.data['next'] == '?limit=2&skip=2' + assert res.data['count'] == 4 + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Ali', 'age': 0}, + {'name': 'Ali', 'age': 1}, + ] + + res = await self.client.get('full-list', query_params={'limit': 2, 'skip': 2}) + assert res.data['previous'] == '?limit=2&skip=0' + assert res.data['next'] is None + assert res.data['count'] == 4 + response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']] + assert response == [ + {'name': 'Saba', 'age': 0}, + {'name': 'Saba', 'age': 1}, + ] + + async def test_update(self): + users = await User.insert_many([{'name': 'Ali'}, {'name': 'Hamed'}]) + res = await self.client.put(f'update/{users[1].id}', payload={'name': 'NewName'}) + assert res.status_code == 200 + assert res.data['name'] == 'NewName' + + new_users = await User.find() + users[1].name = 'NewName' + assert {(u.id, u.name) for u in new_users} == {(u.id, u.name) for u in users} + + async def test_create(self): + res = await self.client.post('create', payload={'name': 'Sara'}) + assert res.status_code == 201 + assert res.data['name'] == 'Sara' + + new_users = await User.find() + assert len([i for i in new_users]) + assert new_users[0].name == 'Sara' + + async def test_delete(self): + users = await User.insert_many([{'name': 'Ali'}, {'name': 'Hamed'}]) + res = await self.client.delete(f'delete/{users[1].id}') + assert res.status_code == 204 + new_users = await User.find() + assert len([u for u in new_users]) == 1 + assert new_users[0].model_dump() == users[0].model_dump() diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 0000000..f00cc54 --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,456 @@ +from unittest import IsolatedAsyncioTestCase + +from panther import Panther +from panther.app import API, GenericAPI +from panther.response import Response, HTMLResponse, PlainTextResponse, StreamingResponse +from panther.test import APIClient + + +@API() +async def return_nothing(): + pass + + +class ReturnNothing(GenericAPI): + def get(self): + pass + + +@API() +async def return_none(): + return None + + +class ReturnNone(GenericAPI): + def get(self): + return None + + +@API() +async def return_string(): + return 'Hello' + + +class ReturnString(GenericAPI): + def get(self): + return 'Hello' + + +@API() +async def return_dict(): + return {'detail': 'ok'} + + +class ReturnDict(GenericAPI): + def get(self): + return {'detail': 'ok'} + + +@API() +async def return_list(): + return [1, 2, 3] + + +class ReturnList(GenericAPI): + def get(self): + return [1, 2, 3] + + +@API() +async def return_tuple(): + return 1, 2, 3, 4 + + +class ReturnTuple(GenericAPI): + def get(self): + return 1, 2, 3, 4 + + +@API() +async def return_response_none(): + return Response() + + +class ReturnResponseNone(GenericAPI): + def get(self): + return Response() + + +@API() +async def return_response_dict(): + return Response(data={'detail': 'ok'}) + + +class ReturnResponseDict(GenericAPI): + def get(self): + return Response(data={'detail': 'ok'}) + + +@API() +async def return_response_list(): + return Response(data=['car', 'home', 'phone']) + + +class ReturnResponseList(GenericAPI): + def get(self): + return Response(data=['car', 'home', 'phone']) + + +@API() +async def return_response_tuple(): + return Response(data=('car', 'home', 'phone', 'book')) + + +class ReturnResponseTuple(GenericAPI): + def get(self): + return Response(data=('car', 'home', 'phone', 'book')) + + +@API() +async def return_html_response(): + return HTMLResponse('<html><head><title>') + + +class ReturnHTMLResponse(GenericAPI): + def get(self): + return HTMLResponse('') + + +@API() +async def return_plain_response(): + return PlainTextResponse('Hello World') + + +class ReturnPlainResponse(GenericAPI): + def get(self): + return PlainTextResponse('Hello World') + + +class ReturnStreamingResponse(GenericAPI): + def get(self): + def f(): + for i in range(5): + yield i + return StreamingResponse(f()) + + +class ReturnAsyncStreamingResponse(GenericAPI): + async def get(self): + async def f(): + for i in range(6): + yield i + return StreamingResponse(f()) + + +class ReturnInvalidStatusCode(GenericAPI): + def get(self): + return Response(status_code='ali') + + +urls = { + 'nothing': return_nothing, + 'none': return_none, + 'dict': return_dict, + 'str': return_string, + 'list': return_list, + 'tuple': return_tuple, + 'response-none': return_response_none, + 'response-dict': return_response_dict, + 'response-list': return_response_list, + 'response-tuple': return_response_tuple, + 'html': return_html_response, + 'plain': return_plain_response, + + 'nothing-cls': ReturnNothing, + 'none-cls': ReturnNone, + 'dict-cls': ReturnDict, + 'str-cls': ReturnString, + 'list-cls': ReturnList, + 'tuple-cls': ReturnTuple, + 'response-none-cls': ReturnResponseNone, + 'response-dict-cls': ReturnResponseDict, + 'response-list-cls': ReturnResponseList, + 'response-tuple-cls': ReturnResponseTuple, + 'html-cls': ReturnHTMLResponse, + 'plain-cls': ReturnPlainResponse, + + 'stream': ReturnStreamingResponse, + 'async-stream': ReturnAsyncStreamingResponse, + 'invalid-status-code': ReturnInvalidStatusCode, +} + + +class TestResponses(IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls) -> None: + app = Panther(__name__, configs=__name__, urls=urls) + cls.client = APIClient(app=app) + + async def test_nothing(self): + res = await self.client.get('nothing/') + assert res.status_code == 200 + assert res.data is None + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '0' + + async def test_nothing_cls(self): + res = await self.client.get('nothing-cls/') + assert res.status_code == 200 + assert res.data is None + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '0' + + async def test_none(self): + res = await self.client.get('none/') + assert res.status_code == 200 + assert res.data is None + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '0' + + async def test_none_cls(self): + res = await self.client.get('none-cls/') + assert res.status_code == 200 + assert res.data is None + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '0' + + async def test_dict(self): + res = await self.client.get('dict/') + assert res.status_code == 200 + assert res.data == {'detail': 'ok'} + assert res.body == b'{"detail":"ok"}' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '15' + + async def test_dict_cls(self): + res = await self.client.get('dict-cls/') + assert res.status_code == 200 + assert res.data == {'detail': 'ok'} + assert res.body == b'{"detail":"ok"}' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '15' + + async def test_string(self): + res = await self.client.get('str/') + assert res.status_code == 200 + assert res.data == 'Hello' + assert res.body == b'"Hello"' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '7' + + async def test_string_cls(self): + res = await self.client.get('str-cls/') + assert res.status_code == 200 + assert res.data == 'Hello' + assert res.body == b'"Hello"' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '7' + + async def test_list(self): + res = await self.client.get('list/') + assert res.status_code == 200 + assert res.data == [1, 2, 3] + assert res.body == b'[1,2,3]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '7' + + async def test_list_cls(self): + res = await self.client.get('list-cls/') + assert res.status_code == 200 + assert res.data == [1, 2, 3] + assert res.body == b'[1,2,3]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '7' + + async def test_tuple(self): + res = await self.client.get('tuple/') + assert res.status_code == 200 + assert res.data == [1, 2, 3, 4] + assert res.body == b'[1,2,3,4]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '9' + + async def test_tuple_cls(self): + res = await self.client.get('tuple-cls/') + assert res.status_code == 200 + assert res.data == [1, 2, 3, 4] + assert res.body == b'[1,2,3,4]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '9' + + async def test_response_none(self): + res = await self.client.get('response-none/') + assert res.status_code == 200 + assert res.data is None + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '0' + + async def test_response_none_cls(self): + res = await self.client.get('response-none-cls/') + assert res.status_code == 200 + assert res.data is None + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '0' + + async def test_response_dict(self): + res = await self.client.get('response-dict/') + assert res.status_code == 200 + assert res.data == {'detail': 'ok'} + assert res.body == b'{"detail":"ok"}' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '15' + + async def test_response_dict_cls(self): + res = await self.client.get('response-dict-cls/') + assert res.status_code == 200 + assert res.data == {'detail': 'ok'} + assert res.body == b'{"detail":"ok"}' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '15' + + async def test_response_list(self): + res = await self.client.get('response-list/') + assert res.status_code == 200 + assert res.data == ['car', 'home', 'phone'] + assert res.body == b'["car","home","phone"]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '22' + + async def test_response_list_cls(self): + res = await self.client.get('response-list-cls/') + assert res.status_code == 200 + assert res.data == ['car', 'home', 'phone'] + assert res.body == b'["car","home","phone"]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '22' + + async def test_response_tuple(self): + res = await self.client.get('response-tuple/') + assert res.status_code == 200 + assert res.data == ['car', 'home', 'phone', 'book'] + assert res.body == b'["car","home","phone","book"]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '29' + + async def test_response_tuple_cls(self): + res = await self.client.get('response-tuple-cls/') + assert res.status_code == 200 + assert res.data == ['car', 'home', 'phone', 'book'] + assert res.body == b'["car","home","phone","book"]' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '29' + + async def test_response_html(self): + res = await self.client.get('html/') + assert res.status_code == 200 + assert res.data == '' + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'text/html; charset=utf-8' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '41' + + async def test_response_html_cls(self): + res = await self.client.get('html-cls/') + assert res.status_code == 200 + assert res.data == '' + assert res.body == b'' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'text/html; charset=utf-8' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '41' + + async def test_response_plain(self): + res = await self.client.get('plain/') + assert res.status_code == 200 + assert res.data == 'Hello World' + assert res.body == b'Hello World' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'text/plain; charset=utf-8' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '11' + + async def test_response_plain_cls(self): + res = await self.client.get('plain-cls/') + assert res.status_code == 200 + assert res.data == 'Hello World' + assert res.body == b'Hello World' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'text/plain; charset=utf-8' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == '11' + + async def test_streaming_response(self): + res = await self.client.get('stream/') + assert res.status_code == 200 + assert res.headers['Content-Type'] == 'application/octet-stream' + assert res.data == '01234' + assert res.body == b'01234' + + async def test_async_streaming_response(self): + res = await self.client.get('async-stream/') + assert res.status_code == 200 + assert res.headers['Content-Type'] == 'application/octet-stream' + assert res.data == '012345' + assert res.body == b'012345' + + async def test_invalid_status_code(self): + with self.assertLogs(level='CRITICAL') as captured: + res = await self.client.get('invalid-status-code/') + + assert len(captured.records) == 1 + assert captured.records[0].getMessage().split('\n')[0] == "Response `status_code` Should Be `int`. (`ali` is )" + + assert res.status_code == 500 + assert res.data == {'detail': 'Internal Server Error'} + assert res.body == b'{"detail":"Internal Server Error"}' + assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'} + assert res.headers['Content-Type'] == 'application/json' + assert res.headers['Access-Control-Allow-Origin'] == '*' + assert res.headers['Content-Length'] == 34 diff --git a/tests/test_routing.py b/tests/test_routing.py index 6ef398e..91da3f7 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,8 +1,9 @@ import random from unittest import TestCase +from panther.base_request import BaseRequest +from panther.exceptions import PantherError from panther.routings import ( - collect_path_variables, finalize_urls, find_endpoint, flatten_urls, @@ -13,67 +14,88 @@ class TestRoutingFunctions(TestCase): def tearDown(self) -> None: from panther.configs import config - config['urls'] = {} + config.URLS = {} # Collecting def test_collect_ellipsis_urls(self): - urls = { + urls1 = { 'user/': { '/': ..., - 'profile/': ..., + }, + } + urls2 = { + 'user/': { 'list/': ..., }, } - with self.assertLogs() as captured: - collected_urls = flatten_urls(urls) - - assert len(captured.records) == 3 - assert captured.records[0].getMessage() == "URL Can't Point To Ellipsis. ('user//' -> ...)" - assert captured.records[1].getMessage() == "URL Can't Point To Ellipsis. ('user/profile/' -> ...)" - assert captured.records[2].getMessage() == "URL Can't Point To Ellipsis. ('user/list/' -> ...)" + try: + flatten_urls(urls1) + except PantherError as exc: + assert exc.args[0] == "URL Can't Point To Ellipsis. ('user//' -> ...)" + else: + assert False - assert collected_urls == {} + try: + flatten_urls(urls2) + except PantherError as exc: + assert exc.args[0] == "URL Can't Point To Ellipsis. ('user/list/' -> ...)" + else: + assert False - def test_collect_None_urls(self): # noqa: N802 - urls = { + def test_collect_None_urls(self): + urls1 = { 'user/': { - '/': None, - 'profile/': None, 'list/': None, }, } + urls2 = { + 'user/': { + '/': None, + }, + } - with self.assertLogs() as captured: - collected_urls = flatten_urls(urls) - - assert len(captured.records) == 3 - assert captured.records[0].getMessage() == "URL Can't Point To None. ('user//' -> None)" - assert captured.records[1].getMessage() == "URL Can't Point To None. ('user/profile/' -> None)" - assert captured.records[2].getMessage() == "URL Can't Point To None. ('user/list/' -> None)" + try: + flatten_urls(urls1) + except PantherError as exc: + assert exc.args[0] == "URL Can't Point To None. ('user/list/' -> None)" + else: + assert False - assert collected_urls == {} + try: + flatten_urls(urls2) + except PantherError as exc: + assert exc.args[0] == "URL Can't Point To None. ('user//' -> None)" + else: + assert False def test_collect_invalid_urls(self): def temp_func(): pass - urls = { + urls1 = { 'user/': { '?': temp_func, - '%^': temp_func, + }, + } + urls2 = { + 'user/': { 'لیست': temp_func, }, } - with self.assertLogs() as captured: - collected_urls = flatten_urls(urls) - - assert len(captured.records) == 3 - assert captured.records[0].getMessage() == "URL Is Not Valid. --> 'user/?/'" - assert captured.records[1].getMessage() == "URL Is Not Valid. --> 'user/%^/'" - assert captured.records[2].getMessage() == "URL Is Not Valid. --> 'user/لیست/'" + try: + flatten_urls(urls1) + except PantherError as exc: + assert exc.args[0] == "URL Is Not Valid. --> 'user/?/'" + else: + assert False - assert collected_urls == {} + try: + flatten_urls(urls2) + except PantherError as exc: + assert exc.args[0] == "URL Is Not Valid. --> 'user/لیست/'" + else: + assert False def test_collect_empty_url(self): def temp_func(): pass @@ -504,13 +526,54 @@ def temp_func(): pass } assert finalized_urls == expected_result + def test_finalize_urls_with_same_level_path_variables(self): + def temp_func(): pass + + urls1 = { + 'user': { + '/': temp_func, + '/': temp_func, + } + } + urls2 = { + 'user': { + '/': {'detail': temp_func}, + '/': temp_func, + '/': {'detail': temp_func}, + '/': {'detail': temp_func}, + } + } + + try: + finalize_urls(flatten_urls(urls1)) + except PantherError as exc: + assert exc.args[0] == ( + "URLs can't have same-level path variables that point to an endpoint: " + "\n\t- /user/" + "\n\t- /user/" + ) + else: + assert False + + try: + finalize_urls(flatten_urls(urls2)) + except PantherError as exc: + assert exc.args[0] == ( + "URLs can't have same-level path variables that point to a dict: " + "\n\t- /user/" + "\n\t- /user/" + "\n\t- /user/" + ) + else: + assert False + # Find Endpoint def test_find_endpoint_root_url(self): def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { '': temp_func, } _func, _ = find_endpoint('') @@ -534,7 +597,7 @@ def admin_v2_users_detail_not_registered(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { '': { 'profile': { @@ -598,7 +661,7 @@ def admin_v2_users_detail_not_registered(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { '': { 'profile': { @@ -661,7 +724,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { 'list': temp_func, }, @@ -687,7 +750,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { 'list': temp_func, }, @@ -713,7 +776,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { '': temp_func, }, @@ -739,7 +802,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { '': temp_func, }, @@ -765,7 +828,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user/name': temp_func, } func, path = find_endpoint('user/name/troublemaker') @@ -778,7 +841,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user/name': temp_func, } func, path = find_endpoint('user/') @@ -795,7 +858,7 @@ def temp_3(): pass from panther.configs import config - config['urls'] = { + config.URLS = { '': temp_1, '': { '': temp_2, @@ -819,7 +882,7 @@ def temp_3(): pass from panther.configs import config - config['urls'] = { + config.URLS = { '': temp_1, '': { '': temp_2, @@ -843,7 +906,7 @@ def temp_3(): pass from panther.configs import config - config['urls'] = { + config.URLS = { '': temp_1, 'hello': { '': temp_2, @@ -868,7 +931,7 @@ def temp_3(): pass from panther.configs import config - config['urls'] = { + config.URLS = { '': temp_1, 'hello': { '': temp_2, @@ -888,7 +951,7 @@ def test_find_endpoint_with_params(self): def user_id_profile_id(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { '': { 'profile': user_id_profile_id, @@ -905,7 +968,7 @@ def temp_func(): pass from panther.configs import config - config['urls'] = { + config.URLS = { 'user': { '': { 'profile': { @@ -920,7 +983,9 @@ def temp_func(): pass request_path = f'user/{_user_id}/profile/{_id}' _, found_path = find_endpoint(request_path) - path_variables = collect_path_variables(request_path=request_path, found_path=found_path) + request = BaseRequest(scope={'path': request_path}, receive=lambda x: x, send=lambda x: x) + request.collect_path_variables(found_path=found_path) + path_variables = request.path_variables assert isinstance(path_variables, dict) diff --git a/tests/test_run.py b/tests/test_run.py index c05a9eb..52d742e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -30,30 +30,28 @@ def test_load_configs(self): Panther(__name__) assert isinstance(config, Config) - assert config['base_dir'] == base_dir - assert config['monitoring'] is True - assert config['log_queries'] is True - assert config['default_cache_exp'] == timedelta(seconds=10) - assert config['throttling'].rate == 10 - assert config['throttling'].duration == timedelta(seconds=10) - assert config['secret_key'] == secret_key.encode() + assert config.BASE_DIR == base_dir + assert config.MONITORING is True + assert config.LOG_QUERIES is True + assert config.DEFAULT_CACHE_EXP == timedelta(seconds=10) + assert config.THROTTLING.rate == 10 + assert config.THROTTLING.duration == timedelta(seconds=10) + assert config.SECRET_KEY == secret_key.encode() - assert len(config['http_middlewares']) == 0 - assert len(config['reversed_http_middlewares']) == 0 - assert len(config['ws_middlewares']) == 0 - assert len(config['reversed_ws_middlewares']) == 0 + assert len(config.HTTP_MIDDLEWARES) == 0 + assert len(config.WS_MIDDLEWARES) == 0 - assert config['user_model'].__name__ == tests.sample_project.app.models.User.__name__ - assert config['user_model'].__module__.endswith('app.models') - assert config['jwt_config'].algorithm == 'HS256' - assert config['jwt_config'].life_time == timedelta(days=2).total_seconds() - assert config['jwt_config'].key == secret_key + assert config.USER_MODEL.__name__ == tests.sample_project.app.models.User.__name__ + assert config.USER_MODEL.__module__.endswith('app.models') + assert config.JWT_CONFIG.algorithm == 'HS256' + assert config.JWT_CONFIG.life_time == timedelta(days=2).total_seconds() + assert config.JWT_CONFIG.key == secret_key - assert '' in config['urls'] - config['urls'].pop('') + assert '' in config.URLS + config.URLS.pop('') - assert 'second' in config['urls'] - config['urls'].pop('second') + assert 'second' in config.URLS + config.URLS.pop('second') urls = { '_panel': { @@ -65,5 +63,5 @@ def test_load_configs(self): 'health': healthcheck_api }, } - assert config['urls'] == urls - assert config['query_engine'].__name__ == 'BasePantherDBQuery' + assert config.URLS == urls + assert config.QUERY_ENGINE.__name__ == 'BasePantherDBQuery' diff --git a/tests/test_serializer.py b/tests/test_serializer.py index 3a776b6..3232822 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -233,7 +233,7 @@ class Config: fields = ['ok', 'no'] except Exception as e: assert isinstance(e, AttributeError) - assert e.args[0] == '`Serializer3.Config.fields.ok` is not valid.' + assert e.args[0] == '`Serializer3.Config.fields.ok` is not in `Book.model_fields`' else: assert False @@ -262,6 +262,48 @@ class Config: else: assert False + async def test_define_class_with_invalid_exclude_1(self): + try: + class Serializer6(ModelSerializer): + class Config: + model = Book + fields = ['name', 'author', 'pages_count'] + exclude = ['not_found'] + + except Exception as e: + assert isinstance(e, AttributeError) + assert e.args[0] == '`Serializer6.Config.exclude.not_found` is not valid.' + else: + assert False + + async def test_define_class_with_invalid_exclude_2(self): + try: + class Serializer7(ModelSerializer): + class Config: + model = Book + fields = ['name', 'pages_count'] + exclude = ['author'] + + except Exception as e: + assert isinstance(e, AttributeError) + assert e.args[0] == '`Serializer7.Config.exclude.author` is not defined in `Config.fields`.' + else: + assert False + + async def test_with_star_fields_with_exclude3(self): + try: + class Serializer8(ModelSerializer): + class Config: + model = Book + fields = ['*'] + exclude = ['author'] + + except Exception as e: + assert isinstance(e, AttributeError) + assert e.args[0] == "`Serializer8.Config.fields.*` is not valid. Did you mean `fields = '*'`" + else: + assert False + # # # Serializer Usage async def test_with_simple_model_config(self): class Serializer(ModelSerializer): @@ -319,3 +361,39 @@ class Config: serialized = Serializer2(name='book', author='AliRn', pages_count='12') assert serialized.__doc__ is None + + async def test_with_exclude(self): + class Serializer(ModelSerializer): + class Config: + model = Book + fields = ['name', 'author', 'pages_count'] + exclude = ['author'] + + serialized = Serializer(name='book', author='AliRn', pages_count='12') + assert set(serialized.model_dump().keys()) == {'name', 'pages_count'} + assert serialized.name == 'book' + assert serialized.pages_count == 12 + + async def test_with_star_fields(self): + class Serializer(ModelSerializer): + class Config: + model = Book + fields = '*' + + serialized = Serializer(name='book', author='AliRn', pages_count='12') + assert set(serialized.model_dump().keys()) == {'id', 'name', 'author', 'pages_count'} + assert serialized.name == 'book' + assert serialized.author == 'AliRn' + assert serialized.pages_count == 12 + + async def test_with_star_fields_with_exclude(self): + class Serializer(ModelSerializer): + class Config: + model = Book + fields = '*' + exclude = ['author'] + + serialized = Serializer(name='book', author='AliRn', pages_count='12') + assert set(serialized.model_dump().keys()) == {'id', 'name', 'pages_count'} + assert serialized.name == 'book' + assert serialized.pages_count == 12 diff --git a/tests/test_simple_responses.py b/tests/test_simple_responses.py deleted file mode 100644 index 9499d2d..0000000 --- a/tests/test_simple_responses.py +++ /dev/null @@ -1,116 +0,0 @@ -from unittest import IsolatedAsyncioTestCase - -from panther import Panther -from panther.app import API -from panther.response import Response -from panther.test import APIClient - - -@API() -async def return_nothing(): - pass - - -@API() -async def return_none(): - return None - - -@API() -async def return_dict(): - return {'detail': 'ok'} - - -@API() -async def return_list(): - return [1, 2, 3] - - -@API() -async def return_tuple(): - return 1, 2, 3, 4 - - -@API() -async def return_response_none(): - return Response() - - -@API() -async def return_response_dict(): - return Response(data={'detail': 'ok'}) - - -@API() -async def return_response_list(): - return Response(data=['car', 'home', 'phone']) - - -@API() -async def return_response_tuple(): - return Response(data=('car', 'home', 'phone', 'book')) - - -urls = { - 'nothing': return_nothing, - 'none': return_none, - 'dict': return_dict, - 'list': return_list, - 'tuple': return_tuple, - 'response-none': return_response_none, - 'response-dict': return_response_dict, - 'response-list': return_response_list, - 'response-tuple': return_response_tuple, -} - - -class TestSimpleResponses(IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls) -> None: - app = Panther(__name__, configs=__name__, urls=urls) - cls.client = APIClient(app=app) - - async def test_nothing(self): - res = await self.client.get('nothing/') - assert res.status_code == 200 - assert res.data is None - - async def test_none(self): - res = await self.client.get('none/') - assert res.status_code == 200 - assert res.data is None - - async def test_dict(self): - res = await self.client.get('dict/') - assert res.status_code == 200 - assert res.data == {'detail': 'ok'} - - async def test_list(self): - res = await self.client.get('list/') - assert res.status_code == 200 - assert res.data == [1, 2, 3] - - async def test_tuple(self): - res = await self.client.get('tuple/') - assert res.status_code == 200 - assert res.data == [1, 2, 3, 4] - - async def test_response_none(self): - res = await self.client.get('response-none/') - assert res.status_code == 200 - assert res.data is None - - async def test_response_dict(self): - res = await self.client.get('response-dict/') - assert res.status_code == 200 - assert res.data == {'detail': 'ok'} - - async def test_response_list(self): - res = await self.client.get('response-list/') - assert res.status_code == 200 - assert res.data == ['car', 'home', 'phone'] - - async def test_response_tuple(self): - res = await self.client.get('response-tuple/') - assert res.status_code == 200 - assert res.data == ['car', 'home', 'phone', 'book'] diff --git a/tests/test_throttling.py b/tests/test_throttling.py new file mode 100644 index 0000000..10034c5 --- /dev/null +++ b/tests/test_throttling.py @@ -0,0 +1,74 @@ +import asyncio +from datetime import timedelta +from unittest import IsolatedAsyncioTestCase + +from panther import Panther +from panther.app import API +from panther.test import APIClient +from panther.throttling import Throttling + + +@API() +async def without_throttling_api(): + return 'ok' + + +@API(throttling=Throttling(rate=3, duration=timedelta(seconds=3))) +async def with_throttling_api(): + return 'ok' + + +urls = { + 'without-throttling': without_throttling_api, + 'with-throttling': with_throttling_api, +} + + +class TestThrottling(IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls) -> None: + app = Panther(__name__, configs=__name__, urls=urls) + cls.client = APIClient(app=app) + + async def test_without_throttling(self): + res1 = await self.client.get('without-throttling') + assert res1.status_code == 200 + + res2 = await self.client.get('without-throttling') + assert res2.status_code == 200 + + res3 = await self.client.get('without-throttling') + assert res3.status_code == 200 + + async def test_with_throttling(self): + res1 = await self.client.get('with-throttling') + assert res1.status_code == 200 + + res2 = await self.client.get('with-throttling') + assert res2.status_code == 200 + + res3 = await self.client.get('with-throttling') + assert res3.status_code == 200 + + res4 = await self.client.get('with-throttling') + assert res4.status_code == 429 + + res5 = await self.client.get('with-throttling') + assert res5.status_code == 429 + + await asyncio.sleep(3) # Sleep and try again + + res6 = await self.client.get('with-throttling') + assert res6.status_code == 200 + + res7 = await self.client.get('with-throttling') + assert res7.status_code == 200 + + res8 = await self.client.get('with-throttling') + assert res8.status_code == 200 + + res9 = await self.client.get('with-throttling') + assert res9.status_code == 429 + + res10 = await self.client.get('with-throttling') + assert res10.status_code == 429 diff --git a/tests/test_utils.py b/tests/test_utils.py index 36d5e6b..083e467 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,11 +3,10 @@ from pathlib import Path from unittest import TestCase -import panther.utils from panther import Panther from panther.configs import config from panther.middlewares import BaseMiddleware -from panther.utils import generate_hash_value_from_string, load_env, round_datetime, encrypt_password +from panther.utils import generate_secret_key, generate_hash_value_from_string, load_env, round_datetime class TestLoadEnvFile(TestCase): @@ -25,11 +24,10 @@ def _create_env_file(self, file_data): file.write(file_data) def test_load_env_invalid_file(self): - with self.assertLogs(level='ERROR') as captured: + try: load_env('fake.file') - - assert len(captured.records) == 1 - assert captured.records[0].getMessage() == '"fake.file" is not valid file for load_env()' + except ValueError as e: + assert e.args[0] == '"fake.file" is not a file.' def test_load_env_double_quote(self): self._create_env_file(f""" @@ -179,11 +177,6 @@ def test_generate_hash_value_from_string(self): assert hashed_1 == hashed_2 assert text != hashed_1 - def test_encrypt_password(self): - password = 'Password' - encrypted = encrypt_password(password=password) - assert password != encrypted - class TestLoadConfigs(TestCase): def setUp(self): @@ -202,7 +195,7 @@ def test_urls_not_found(self): assert False assert len(captured.records) == 1 - assert captured.records[0].getMessage() == "Invalid 'URLs': is required." + assert captured.records[0].getMessage() == "Invalid 'URLs': required." def test_urls_cant_be_dict(self): global URLs @@ -373,7 +366,7 @@ def test_jwt_auth_without_secret_key(self): def test_jwt_auth_with_secret_key(self): global AUTHENTICATION, SECRET_KEY AUTHENTICATION = 'panther.authentications.JWTAuthentication' - SECRET_KEY = panther.utils.generate_secret_key() + SECRET_KEY = generate_secret_key() with self.assertNoLogs(level='ERROR'): try: diff --git a/tests/test_websockets.py b/tests/test_websockets.py index 35d1464..da8d310 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -131,7 +131,7 @@ async def connect(self): class TestWebsocket(TestCase): @classmethod def setUpClass(cls) -> None: - config['has_ws'] = True + config.HAS_WS = True cls.app = Panther(__name__, configs=__name__, urls=urls) def test_without_accept(self): @@ -266,7 +266,7 @@ def test_with_auth_failed(self): assert responses[0]['type'] == 'websocket.close' assert responses[0]['code'] == 1000 - assert responses[0]['reason'] == 'Authentication Error' + assert responses[0]['reason'] == '' def test_with_auth_not_defined(self): ws = WebsocketClient(app=self.app) @@ -274,11 +274,11 @@ def test_with_auth_not_defined(self): responses = ws.connect('with-auth?authorization=Bearer token') assert len(captured.records) == 1 - assert captured.records[0].getMessage() == '"WS_AUTHENTICATION" has not been set in configs' + assert captured.records[0].getMessage() == '`WS_AUTHENTICATION` has not been set in configs' assert responses[0]['type'] == 'websocket.close' assert responses[0]['code'] == 1000 - assert responses[0]['reason'] == 'Authentication Error' + assert responses[0]['reason'] == '' def test_with_auth_success(self): global WS_AUTHENTICATION, SECRET_KEY, DATABASE @@ -304,7 +304,7 @@ def test_with_auth_success(self): assert responses[0]['type'] == 'websocket.close' assert responses[0]['code'] == 1000 - assert responses[0]['reason'] == 'Authentication Error' + assert responses[0]['reason'] == '' def test_with_permission(self): ws = WebsocketClient(app=self.app) @@ -327,4 +327,4 @@ def test_without_permission(self): assert responses[0]['type'] == 'websocket.close' assert responses[0]['code'] == 1000 - assert responses[0]['reason'] == 'Permission Denied' + assert responses[0]['reason'] == ''