diff --git a/README.md b/README.md
index 014092c..531e730 100644
--- a/README.md
+++ b/README.md
@@ -20,10 +20,11 @@
- Built-in API **Caching** System (In Memory, **Redis**)
- Built-in **Authentication** Classes
- Built-in **Permission** Classes
+- Built-in Visual API **Monitoring** (In Terminal)
- Support Custom **Background Tasks**
- Support Custom **Middlewares**
- Support Custom **Throttling**
-- Visual API **Monitoring** (In Terminal)
+- Support **Function-Base** and **Class-Base** APIs
- It's One Of The **Fastest Python Frameworks**
---
@@ -176,6 +177,12 @@ $ pip install panther
---
+### How Panther Works!
+
+![diagram](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/diagram.png)
+
+---
+
### Roadmap
![roadmap](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/roadmap.jpg)
diff --git a/docs/docs/authentications.md b/docs/docs/authentications.md
index 3b2efe6..6563086 100644
--- a/docs/docs/authentications.md
+++ b/docs/docs/authentications.md
@@ -3,12 +3,20 @@
> Type: `str`
>
> Default: `None`
-
-You can set your Authentication class in `configs`, now, if you set `auth=True` in `@API()`, Panther will use this class for authentication of every `API`, and put the `user` in `request.user` or `raise HTTP_401_UNAUTHORIZED`
-We implemented a built-in authentication class which used `JWT` for authentication.
+You can set the Authentication class in your `configs`
+
+Panther uses it to authenticate every API/WS that has `auth=True`, and either gives you the user or raises `HTTP_401_UNAUTHORIZED`
+
+The `user` will be available as `request.user` in APIs and as `self.user` in WSs
+
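+A minimal sketch of enabling it (the endpoint name is illustrative; `JWTAuthentication` is one of the built-in classes described below):
+
+```python
+# core/configs.py
+AUTHENTICATION = 'panther.authentications.JWTAuthentication'
+```
+
+```python
+from panther.app import API
+from panther.request import Request
+
+
+@API(auth=True)
+async def whoami_api(request: Request):
+    return request.user
+```
+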
+---
+
+We implemented two built-in authentication classes that use `JWT` for authentication.
+
But, You can use your own custom authentication class too.
+
+---
### JWTAuthentication
@@ -17,7 +25,7 @@ This class will
- Get the `token` from `Authorization` header of request.
- Check the `Bearer`
- `decode` the `token`
-- Find the matched `user` (It uses the `USER_MODEL`)
+- Find the matched `user`
> `JWTAuthentication` is going to use `panther.db.models.BaseUser` if you didn't set the `USER_MODEL` in your `configs`
@@ -59,9 +67,7 @@ JWTConfig = {
#### Websocket Authentication
-This class is very useful when you are trying to authenticate the user in websocket
-
-Add this into your `configs`:
+The `QueryParamJWTAuthentication` is very useful when you want to authenticate the user in a websocket; you just have to add this to your `configs`:
```python
WS_AUTHENTICATION = 'panther.authentications.QueryParamJWTAuthentication'
```
@@ -77,7 +83,9 @@ WS_AUTHENTICATION = 'panther.authentications.QueryParamJWTAuthentication'
- Or raise `panther.exceptions.AuthenticationAPIError`
-- Address it in `configs`
- - `AUTHENTICATION = 'project_name.core.authentications.CustomAuthentication'`
+- Add it into your `configs` (see the sketch after this list)
+ ```python
+ AUTHENTICATION = 'project_name.core.authentications.CustomAuthentication'
+ ```
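+
+Putting the steps above together, a minimal sketch of a custom class (the header access and user lookup are assumptions for illustration; only the `authentication()` classmethod and the exception come from the steps above):
+
+```python
+# project_name/core/authentications.py
+from panther.exceptions import AuthenticationAPIError
+from panther.request import Request
+
+
+class CustomAuthentication:
+    @classmethod
+    async def authentication(cls, request: Request):
+        # Assumption: the credential arrives in the `Authorization` header.
+        token = request.headers.authorization
+        if token is None:
+            raise AuthenticationAPIError('Authorization header is required')
+        # Look the user up with your own logic and return it, e.g.:
+        # return await User.find_one(token=token)
+        ...
+```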
-> You can see the source code of JWTAuthentication [[here]](https://github.com/AliRn76/panther/blob/da2654ccdd83ebcacda91a1aaf51d5aeb539eff5/panther/authentications.py#L38)
\ No newline at end of file
+> You can see the source code of JWTAuthentication [[here]](https://github.com/AliRn76/panther/blob/da2654ccdd83ebcacda91a1aaf51d5aeb539eff5/panther/authentications.py)
\ No newline at end of file
diff --git a/docs/docs/background_tasks.md b/docs/docs/background_tasks.md
index 670f455..b5057f3 100644
--- a/docs/docs/background_tasks.md
+++ b/docs/docs/background_tasks.md
@@ -110,7 +110,7 @@ Panther is going to run the `background tasks` as a thread in the background on
task = BackgroundTask(do_something, 'Ali', 26)
```
-- > Default interval is 1.
+- > The default of `interval()` is 1 (see the sketch below).
- > The -1 interval means infinite,
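+
+A minimal sketch of scheduling a repeating task (assuming the chainable `interval()` described above and `background_tasks.add_task()` for submitting the task):
+
+```python
+from panther.background_tasks import BackgroundTask, background_tasks
+
+
+def do_something(name: str, age: int):
+    print(name, age)
+
+
+# Run the task twice; an interval of -1 would repeat it forever.
+task = BackgroundTask(do_something, 'Ali', 26).interval(2)
+background_tasks.add_task(task)
+```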
diff --git a/docs/docs/class_first_crud.md b/docs/docs/class_first_crud.md
index 9c81906..274cc28 100644
--- a/docs/docs/class_first_crud.md
+++ b/docs/docs/class_first_crud.md
@@ -1,6 +1,6 @@
We assume you could run the project with [Introduction](https://pantherpy.github.io/#installation)
-Now let's write custom API `Create`, `Retrieve`, `Update` and `Delete` for a `Book`:
+Now let's write custom APIs to `Create`, `Retrieve`, `Update` and `Delete` a `Book`:
## Structure & Requirements
### Create Model
@@ -82,7 +82,7 @@ Now we are going to create a book on `post` request, We need to:
class BookAPI(GenericAPI):
- def post(self):
+ async def post(self):
...
```
@@ -95,7 +95,7 @@ Now we are going to create a book on `post` request, We need to:
class BookAPI(GenericAPI):
- def post(self, request: Request):
+ async def post(self, request: Request):
...
```
@@ -122,7 +122,7 @@ Now we are going to create a book on `post` request, We need to:
class BookAPI(GenericAPI):
input_model = BookSerializer
- def post(self, request: Request):
+ async def post(self, request: Request):
...
```
@@ -138,7 +138,7 @@ Now we are going to create a book on `post` request, We need to:
class BookAPI(GenericAPI):
input_model = BookSerializer
- def post(self, request: Request):
+ async def post(self, request: Request):
body: BookSerializer = request.validated_data
...
```
@@ -156,9 +156,9 @@ Now we are going to create a book on `post` request, We need to:
class BookAPI(GenericAPI):
input_model = BookSerializer
- def post(self, request: Request):
+ async def post(self, request: Request):
body: BookSerializer = request.validated_data
- Book.insert_one(
+ await Book.insert_one(
name=body.name,
author=body.author,
pages_count=body.pages_count,
@@ -180,15 +180,14 @@ Now we are going to create a book on `post` request, We need to:
class BookAPI(GenericAPI):
input_model = BookSerializer
- def post(self, request: Request):
+ async def post(self, request: Request):
body: BookSerializer = request.validated_data
- book = Book.insert_one(
+ book = await Book.insert_one(
name=body.name,
author=body.author,
pages_count=body.pages_count,
)
return Response(data=book, status_code=status.HTTP_201_CREATED)
-
```
> The response.data can be `Instance of Models`, `dict`, `str`, `tuple`, `list`, `str` or `None`
@@ -213,11 +212,11 @@ from app.models import Book
class BookAPI(GenericAPI):
input_model = BookSerializer
- def post(self, request: Request):
+ async def post(self, request: Request):
...
- def get(self):
- books: list[Book] = Book.find()
+ async def get(self):
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
```
@@ -256,11 +255,11 @@ Assume we don't want to return field `author` in response:
input_model = BookSerializer
output_model = BookOutputSerializer
- def post(self, request: Request):
+ async def post(self, request: Request):
...
- def get(self):
- books: list[Book] = Book.find()
+ async def get(self):
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
```
@@ -292,11 +291,11 @@ class BookAPI(GenericAPI):
cache = True
cache_exp_time = timedelta(seconds=10)
- def post(self, request: Request):
+ async def post(self, request: Request):
...
- def get(self):
- books: list[Book] = Book.find()
+ async def get(self):
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
```
@@ -329,11 +328,11 @@ class BookAPI(GenericAPI):
cache_exp_time = timedelta(seconds=10)
throttling = Throttling(rate=10, duration=timedelta(minutes=1))
- def post(self, request: Request):
+ async def post(self, request: Request):
...
- def get(self):
- books: list[Book] = Book.find()
+ async def get(self):
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
```
@@ -383,8 +382,8 @@ For `retrieve`, `update` and `delete` API, we are going to
class SingleBookAPI(GenericAPI):
- def get(self, book_id: int):
- if book := Book.find_one(id=book_id):
+ async def get(self, book_id: int):
+ if book := await Book.find_one(id=book_id):
return Response(data=book, status_code=status.HTTP_200_OK)
else:
return Response(status_code=status.HTTP_404_NOT_FOUND)
@@ -406,11 +405,11 @@ from app.serializers import BookSerializer
class SingleBookAPI(GenericAPI):
input_model = BookSerializer
- def get(self, book_id: int):
+ async def get(self, book_id: int):
...
- def put(self, request: Request, book_id: int):
- is_updated: bool = Book.update_one({'id': book_id}, request.validated_data.model_dump())
+ async def put(self, request: Request, book_id: int):
+ is_updated: bool = await Book.update_one({'id': book_id}, request.validated_data.model_dump())
data = {'is_updated': is_updated}
return Response(data=data, status_code=status.HTTP_202_ACCEPTED)
```
@@ -431,14 +430,14 @@ from app.models import Book
class SingleBookAPI(GenericAPI):
input_model = BookSerializer
- def get(self, book_id: int):
+ async def get(self, book_id: int):
...
- def put(self, request: Request, book_id: int):
+ async def put(self, request: Request, book_id: int):
...
- def delete(self, book_id: int):
- is_deleted: bool = Book.delete_one(id=book_id)
+ async def delete(self, book_id: int):
+ is_deleted: bool = await Book.delete_one(id=book_id)
if is_deleted:
return Response(status_code=status.HTTP_204_NO_CONTENT)
else:
diff --git a/docs/docs/configs.md b/docs/docs/configs.md
index 294ef32..3a37775 100644
--- a/docs/docs/configs.md
+++ b/docs/docs/configs.md
@@ -7,12 +7,14 @@ Panther collect all the configs from your `core/configs.py` or the module you pa
> Type: `bool` (Default: `False`)
It should be `True` if you want to use `panther monitor` command
-and see the monitoring logs
+and watch the monitoring in your terminal
If `True`:
- Log every request in `logs/monitoring.log`
+_Requires the [watchfiles](https://watchfiles.helpmanual.io) package._
+
---
### [LOG_QUERIES](https://pantherpy.github.io/log_queries)
> Type: `bool` (Default: `False`)
@@ -33,6 +35,8 @@ List of middlewares you want to use
Every request goes through `authentication()` method of this `class` (if `auth = True`)
+_Requires the [python-jose](https://python-jose.readthedocs.io/en/latest/) package._
+
_Example:_ `AUTHENTICATION = 'panther.authentications.JWTAuthentication'`
---
@@ -119,6 +123,8 @@ It will reformat your code on every reload (on every change if you run the proje
You may want to write your custom `ruff.toml` in root of your project.
+_Requires the [ruff](https://docs.astral.sh/ruff/) package._
+
Reference: [https://docs.astral.sh/ruff/formatter/](https://docs.astral.sh/ruff/formatter/)
_Example:_ `AUTO_REFORMAT = True`
@@ -133,4 +139,16 @@ We use it to create `database` connection
### [REDIS](https://pantherpy.github.io/redis)
> Type: `dict` (Default: `{}`)
+_Requires the [redis](https://redis-py.readthedocs.io/en/stable/) package._
+
We use it to create `redis` connection
+
+
+---
+### [TIMEZONE](https://pantherpy.github.io/timezone)
+> Type: `str` (Default: `'UTC'`)
+
+Used in `panther.utils.timezone_now()`, which returns a `datetime` based on your `TIMEZONE`
+
+`panther.utils.timezone_now()` is used in `BaseUser.date_created` and `BaseUser.last_login`
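+
+_Example:_ `TIMEZONE = 'Asia/Tehran'`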
+
diff --git a/docs/docs/events.md b/docs/docs/events.md
new file mode 100644
index 0000000..bf7a21e
--- /dev/null
+++ b/docs/docs/events.md
@@ -0,0 +1,29 @@
+## Startup Event
+
+Use the `Event.startup` decorator
+
+```python
+from panther.events import Event
+
+
+@Event.startup
+def do_something_on_startup():
+ print('Hello, I am at startup')
+```
+
+## Shutdown Event
+
+Use the `Event.shutdown` decorator
+
+```python
+from panther.events import Event
+
+
+@Event.shutdown
+def do_something_on_shutdown():
+ print('Good Bye, I am at shutdown')
+```
+
+
+## Notice
+
+- You can have **multiple** events on `startup` and `shutdown`
+- Events can be `sync` or `async` (see the sketch below)
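+
+For example, a minimal `async` startup event (the body is illustrative):
+
+```python
+from panther.events import Event
+
+
+@Event.startup
+async def warm_up():
+    print('Warming up ...')
+```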
diff --git a/docs/docs/function_first_crud.md b/docs/docs/function_first_crud.md
index 3c83565..dd97e3b 100644
--- a/docs/docs/function_first_crud.md
+++ b/docs/docs/function_first_crud.md
@@ -1,6 +1,6 @@
We assume you could run the project with [Introduction](https://pantherpy.github.io/#installation)
-Now let's write custom API `Create`, `Retrieve`, `Update` and `Delete` for a `Book`:
+Now let's write custom APIs to `Create`, `Retrieve`, `Update` and `Delete` a `Book`:
## Structure & Requirements
### Create Model
@@ -142,7 +142,7 @@ Now we are going to create a book on `post` request, We need to:
async def book_api(request: Request):
body: BookSerializer = request.validated_data
- Book.insert_one(
+ await Book.insert_one(
name=body.name,
author=body.author,
pages_count=body.pages_count,
@@ -165,7 +165,7 @@ Now we are going to create a book on `post` request, We need to:
if request.method == 'POST':
body: BookSerializer = request.validated_data
- Book.create(
+ await Book.insert_one(
name=body.name,
author=body.author,
pages_count=body.pages_count,
@@ -189,7 +189,7 @@ Now we are going to create a book on `post` request, We need to:
if request.method == 'POST':
body: BookSerializer = request.validated_data
- book: Book = Book.create(
+ book: Book = await Book.insert_one(
name=body.name,
author=body.author,
pages_count=body.pages_count,
@@ -224,7 +224,7 @@ async def book_api(request: Request):
...
elif request.method == 'GET':
- books: list[Book] = Book.find()
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED)
@@ -265,7 +265,7 @@ Assume we don't want to return field `author` in response:
...
elif request.method == 'GET':
- books: list[Book] = Book.find()
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED)
@@ -299,7 +299,7 @@ async def book_api(request: Request):
...
elif request.method == 'GET':
- books: list[Book] = Book.find()
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED)
@@ -338,7 +338,7 @@ async def book_api(request: Request):
...
elif request.method == 'GET':
- books: list[Book] = Book.find()
+ books = await Book.find()
return Response(data=books, status_code=status.HTTP_200_OK)
return Response(status_code=status.HTTP_501_NOT_IMPLEMENTED)
@@ -393,7 +393,7 @@ For `retrieve`, `update` and `delete` API, we are going to
@API()
async def single_book_api(request: Request, book_id: int):
if request.method == 'GET':
- if book := Book.find_one(id=book_id):
+ if book := await Book.find_one(id=book_id):
return Response(data=book, status_code=status.HTTP_200_OK)
else:
return Response(status_code=status.HTTP_404_NOT_FOUND)
@@ -420,12 +420,12 @@ For `retrieve`, `update` and `delete` API, we are going to
if request.method == 'GET':
...
elif request.method == 'PUT':
- book: Book = Book.find_one(id=book_id)
- book.update(
+ book: Book = await Book.find_one(id=book_id)
+ await book.update(
name=body.name,
author=body.author,
pages_count=body.pages_count
- )
+ )
return Response(status_code=status.HTTP_202_ACCEPTED)
```
@@ -446,7 +446,7 @@ For `retrieve`, `update` and `delete` API, we are going to
if request.method == 'GET':
...
elif request.method == 'PUT':
- is_updated: bool = Book.update_one({'id': book_id}, request.validated_data.model_dump())
+ is_updated: bool = await Book.update_one({'id': book_id}, request.validated_data.model_dump())
data = {'is_updated': is_updated}
return Response(data=data, status_code=status.HTTP_202_ACCEPTED)
```
@@ -468,7 +468,7 @@ For `retrieve`, `update` and `delete` API, we are going to
if request.method == 'GET':
...
elif request.method == 'PUT':
- updated_count: int = Book.update_many({'id': book_id}, request.validated_data.model_dump())
+ updated_count: int = await Book.update_many({'id': book_id}, request.validated_data.model_dump())
data = {'updated_count': updated_count}
return Response(data=data, status_code=status.HTTP_202_ACCEPTED)
```
@@ -496,7 +496,7 @@ For `retrieve`, `update` and `delete` API, we are going to
elif request.method == 'PUT':
...
elif request.method == 'DELETE':
- is_deleted: bool = Book.delete_one(id=book_id)
+ is_deleted: bool = await Book.delete_one(id=book_id)
if is_deleted:
return Response(status_code=status.HTTP_204_NO_CONTENT)
else:
@@ -520,7 +520,7 @@ For `retrieve`, `update` and `delete` API, we are going to
elif request.method == 'PUT':
...
elif request.method == 'DELETE':
- is_deleted: bool = Book.delete_one(id=book_id)
+ is_deleted: bool = await Book.delete_one(id=book_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
```
@@ -542,6 +542,6 @@ For `retrieve`, `update` and `delete` API, we are going to
elif request.method == 'PUT':
...
elif request.method == 'DELETE':
- deleted_count: int = Book.delete_many(id=book_id)
+ deleted_count: int = await Book.delete_many(id=book_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)
```
\ No newline at end of file
diff --git a/docs/docs/generic_crud.md b/docs/docs/generic_crud.md
new file mode 100644
index 0000000..ab6ed28
--- /dev/null
+++ b/docs/docs/generic_crud.md
@@ -0,0 +1,313 @@
+We assume you could run the project with [Introduction](https://pantherpy.github.io/#installation)
+
+Now let's write custom APIs to `Create`, `Retrieve`, `Update` and `Delete` a `Book`:
+
+## Structure & Requirements
+### Create Model
+
+Create a model named `Book` in `app/models.py`:
+
+```python
+from panther.db import Model
+
+
+class Book(Model):
+ name: str
+ author: str
+ pages_count: int
+```
+
+### Create API Class
+
+Create the `BookAPI()` in `app/apis.py`:
+
+```python
+from panther.app import GenericAPI
+
+
+class BookAPI(GenericAPI):
+ ...
+```
+
+> We are going to complete it later ...
+
+### Update URLs
+
+Add the `BookAPI` in `app/urls.py`:
+
+```python
+from app.apis import BookAPI
+
+
+urls = {
+ 'book/': BookAPI,
+}
+```
+
+We assume that the `urls` in `core/urls.py` is pointing to `app/urls.py`, like below:
+
+```python
+from app.urls import urls as app_urls
+
+
+urls = {
+ '/': app_urls,
+}
+```
+
+### Add Database
+
+Add `DATABASE` to your `configs`; we are going to use `pantherdb`
+> [PantherDB](https://github.com/PantherPy/PantherDB/#readme) is a Simple, File-Base and Document Oriented database
+
+```python
+...
+DATABASE = {
+ 'engine': {
+ 'class': 'panther.db.connections.PantherDBConnection',
+ }
+}
+...
+```
+
+## APIs
+### API - Create a Book
+
+Now we are going to create a book on a `POST` request; we need to:
+
+1. Inherit from `CreateAPI`:
+ ```python
+ from panther.generics import CreateAPI
+
+
+ class BookAPI(CreateAPI):
+ ...
+ ```
+
+2. Create a `ModelSerializer` in `app/serializers.py` for `validation` of the `request.data`:
+
+ ```python
+ from panther.serializer import ModelSerializer
+ from app.models import Book
+
+ class BookSerializer(ModelSerializer):
+ class Config:
+ model = Book
+ fields = ['name', 'author', 'pages_count']
+ ```
+
+3. Set the created serializer in `BookAPI` as `input_model`:
+
+ ```python
+    from panther.generics import CreateAPI
+ from app.serializers import BookSerializer
+
+
+ class BookAPI(CreateAPI):
+ input_model = BookSerializer
+ ```
+
+It is going to create a `Book` from the incoming `request.data` and return that instance to the user with status code `201`
+
+
+### API - List of Books
+
+Let's return the list of books on the `GET` method; we need to:
+
+1. Inherit from `ListAPI`
+ ```python
+ from panther.generics import CreateAPI, ListAPI
+ from app.serializers import BookSerializer
+
+ class BookAPI(CreateAPI, ListAPI):
+ input_model = BookSerializer
+ ...
+ ```
+
+2. Define the `objects` method, so the `ListAPI` knows which books to return
+
+ ```python
+ from panther.generics import CreateAPI, ListAPI
+ from panther.request import Request
+
+ from app.models import Book
+ from app.serializers import BookSerializer
+
+ class BookAPI(CreateAPI, ListAPI):
+ input_model = BookSerializer
+
+ async def objects(self, request: Request, **kwargs):
+ return await Book.find()
+ ```
+
+### Pagination, Search, Filter, Sort
+
+#### Pagination
+
+Use `panther.pagination.Pagination` as `pagination`
+
+**Usage:** It will look for the `limit` and `skip` in the `query params` and return its own response template
+
+**Example:** `?limit=10&skip=20`
+
+---
+
+#### Search
+
+Define the fields you want the search to query on in `search_fields`
+
+The value of `search_fields` should be `list`
+
+**Usage:** It works with `search` query param --> `?search=maybe_name_of_the_book_or_author`
+
+**Example:** `?search=maybe_name_of_the_book_or_author`
+
+---
+
+#### Filter
+
+Define the fields you want to be filterable in `filter_fields`
+
+The value of `filter_fields` should be `list`
+
+**Usage:** It will look for each value of the `filter_fields` in the `query params` and query on them
+
+**Example:** `?name=name_of_the_book&author=author_of_the_book`
+
+---
+
+
+#### Sort
+
+Define the fields you want to be sortable in `sort_fields`
+
+The value of `sort_fields` should be `list`
+
+
+
+**Usage:** It will look for the `sort` key in the `query params` and sort by the given fields
+
+**Example:** `?sort=pages_count,-name`
+
+**Notice:**
+    - fields should be separated with a comma `,`
+ - use `field_name` for ascending sort
+ - use `-field_name` for descending sort
+---
+
+#### Example
+```python
+from panther.generics import CreateAPI, ListAPI
+from panther.pagination import Pagination
+from panther.request import Request
+
+from app.models import Book
+from app.serializers import BookSerializer, BookOutputSerializer
+
+class BookAPI(CreateAPI, ListAPI):
+ input_model = BookSerializer
+ output_model = BookOutputSerializer
+ pagination = Pagination
+ search_fields = ['name', 'author']
+ filter_fields = ['name', 'author']
+ sort_fields = ['name', 'pages_count']
+
+ async def objects(self, request: Request, **kwargs):
+ return await Book.find()
+```
+
+
+### API - Retrieve a Book
+
+Now we are going to retrieve a book on a `GET` request; we need to:
+
+1. Inherit from `RetrieveAPI`:
+ ```python
+ from panther.generics import RetrieveAPI
+
+
+ class SingleBookAPI(RetrieveAPI):
+ ...
+ ```
+
+2. Define the `object` method, so the `RetrieveAPI` knows which book to return
+
+ ```python
+ from panther.generics import RetrieveAPI
+ from panther.request import Request
+ from app.models import Book
+
+
+ class SingleBookAPI(RetrieveAPI):
+ async def object(self, request: Request, **kwargs):
+ return await Book.find_one_or_raise(id=kwargs['book_id'])
+ ```
+
+3. Add it in `app/urls.py`:
+
+ ```python
+ from app.apis import BookAPI, SingleBookAPI
+
+
+ urls = {
+ 'book/': BookAPI,
+        'book/<book_id>/': SingleBookAPI,
+ }
+ ```
+
+ > You should write the [Path Variable](https://pantherpy.github.io/urls/#path-variables-are-handled-like-below) in `<` and `>`
+
+
+### API - Update a Book
+
+1. Inherit from `UpdateAPI`
+
+ ```python
+ from panther.generics import RetrieveAPI, UpdateAPI
+ from panther.request import Request
+ from app.models import Book
+
+
+ class SingleBookAPI(RetrieveAPI, UpdateAPI):
+ ...
+
+ async def object(self, request: Request, **kwargs):
+ return await Book.find_one_or_raise(id=kwargs['book_id'])
+ ```
+
+2. Add `input_model` so the `UpdateAPI` knows how to validate the `request.data`
+> We use the same serializer we defined above for `CreateAPI`
+
+ ```python
+ from panther.generics import RetrieveAPI, UpdateAPI
+ from panther.request import Request
+ from app.models import Book
+ from app.serializers import BookSerializer
+
+
+ class SingleBookAPI(RetrieveAPI, UpdateAPI):
+ input_model = BookSerializer
+
+ async def object(self, request: Request, **kwargs):
+ return await Book.find_one_or_raise(id=kwargs['book_id'])
+ ```
+
+### API - Delete a Book
+
+1. Inherit from `DeleteAPI`
+
+ ```python
+ from panther.generics import RetrieveAPI, UpdateAPI, DeleteAPI
+ from panther.request import Request
+ from app.models import Book
+ from app.serializers import BookSerializer
+
+
+ class SingleBookAPI(RetrieveAPI, UpdateAPI, DeleteAPI):
+ input_model = BookSerializer
+
+ async def object(self, request: Request, **kwargs):
+ return await Book.find_one_or_raise(id=kwargs['book_id'])
+ ```
+
+2. It requires the `object` method, which we defined before, so it's done.
diff --git a/docs/docs/images/diagram.png b/docs/docs/images/diagram.png
new file mode 100644
index 0000000..2c60eaa
Binary files /dev/null and b/docs/docs/images/diagram.png differ
diff --git a/docs/docs/images/roadmap.jpg b/docs/docs/images/roadmap.jpg
index 5eb7378..cf376ac 100644
Binary files a/docs/docs/images/roadmap.jpg and b/docs/docs/images/roadmap.jpg differ
diff --git a/docs/docs/index.md b/docs/docs/index.md
index d59067b..03b4f38 100644
--- a/docs/docs/index.md
+++ b/docs/docs/index.md
@@ -21,10 +21,11 @@
- Built-in API **Caching** System (In Memory, **Redis**)
- Built-in **Authentication** Classes
- Built-in **Permission** Classes
+- Built-in Visual API **Monitoring** (In Terminal)
- Support Custom **Background Tasks**
- Support Custom **Middlewares**
- Support Custom **Throttling**
-- Visual API **Monitoring** (In Terminal)
+- Support **Function-Base** and **Class-Base** APIs
- It's One Of The **Fastest Python Frameworks**
---
@@ -198,6 +199,12 @@
---
+### How Panther Works!
+
+![diagram](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/diagram.png)
+
+---
+
### Roadmap
![roadmap](https://raw.githubusercontent.com/AliRn76/panther/master/docs/docs/images/roadmap.jpg)
diff --git a/docs/docs/middlewares.md b/docs/docs/middlewares.md
index e745153..5fd2a1d 100644
--- a/docs/docs/middlewares.md
+++ b/docs/docs/middlewares.md
@@ -14,8 +14,11 @@
## Custom Middleware
### Middleware Types
We have 3 type of Middlewares, make sure that you are inheriting from the correct one:
+
- `Base Middleware`: which is used for both `websocket` and `http` requests
+
- `HTTP Middleware`: which is only used for `http` requests
+
- `Websocket Middleware`: which is only used for `websocket` requests
### Write Custom Middleware
diff --git a/docs/docs/panther_odm.md b/docs/docs/panther_odm.md
index 13f3e80..ee849ba 100644
--- a/docs/docs/panther_odm.md
+++ b/docs/docs/panther_odm.md
@@ -16,18 +16,36 @@ user: User = await User.find_one({'id': 1}, name='Ali')
Get documents from the database.
```python
-users: list[User] = await User.find(age=18, name='Ali')
+users: Cursor = await User.find(age=18, name='Ali')
or
-users: list[User] = await User.find({'age': 18, 'name': 'Ali'})
+users: Cursor = await User.find({'age': 18, 'name': 'Ali'})
or
-users: list[User] = await User.find({'age': 18}, name='Ali')
+users: Cursor = await User.find({'age': 18}, name='Ali')
```
+#### skip, limit and sort
+
+`skip()`, `limit()` and `sort()` are also available as chain methods
+
+```python
+users: Cursor = await User.find(age=18, name='Ali').skip(10).limit(10).sort([('age', -1), ('name', 1)])
+```
+
+#### cursor
+
+The `find()` method returns a `Cursor`, depending on the database:
+- `from panther.db.cursor import Cursor` for `MongoDB`
+- `from pantherdb import Cursor` for `PantherDB`
+
+You can work with these cursors as a list and directly pass them to `Response(data=cursor)`
+
+They are designed to return an instance of your model on iteration (`__next__()`) or when accessed by index (`__getitem__()`)
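+
+A minimal sketch of both access patterns (based on the behavior described above):
+
+```python
+users = await User.find(age=18)
+
+# Access by index (`__getitem__()`)
+first_user = users[0]
+
+# Or iterate (`__next__()`)
+for user in users:
+    print(user.name)
+```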
+
### all
Get all documents from the database. (Alias of `.find()`)
```python
-users: list[User] = User.all()
+users: Cursor = await User.all()
```
### first
@@ -73,19 +91,7 @@ pipeline = [
...
]
users: Iterable[dict] = await User.aggregate(pipeline)
-```
-
-### find with skip, limit, sort
-Get limited documents from the database from offset and sorted by something
-> Only available in mongodb
-
-```python
-users: list[User] = await User.find(age=18, name='Ali').limit(10).skip(10).sort('_id', -1)
-or
-users: list[User] = await User.find({'age': 18, 'name': 'Ali'}).limit(10).skip(10).sort('_id', -1)
-or
-users: list[User] = await User.find({'age': 18}, name='Ali').limit(10).skip(10).sort('_id', -1)
-```
+```
### count
Count the number of documents in this collection.
@@ -234,7 +240,7 @@ is_inserted, user = await User.find_one_or_insert({'age': 18}, name='Ali')
### find_one_or_raise
- Get a single document from the database]
- **or**
-- Raise an `APIError(f'{Model} Does Not Exist')`
+- Raise a `NotFoundAPIError(f'{Model} Does Not Exist')`
```python
from app.models import User
@@ -248,6 +254,7 @@ user: User = await User.find_one_or_raise({'age': 18}, name='Ali')
### save
Save the document.
+
- If it has id --> `Update` It
- else --> `Insert` It
diff --git a/docs/docs/redis.md b/docs/docs/redis.md
index 22dade6..eb2a44d 100644
--- a/docs/docs/redis.md
+++ b/docs/docs/redis.md
@@ -31,7 +31,7 @@ REDIS = {
### How it works?
-- Panther creates a redis connection depends on `REDIS` block you defined in `configs`
+- Panther creates an async redis connection depending on the `REDIS` block you defined in `configs`
- You can access to it from `from panther.db.connections import redis`
@@ -39,7 +39,7 @@ REDIS = {
```python
from panther.db.connections import redis
- redis.set('name', 'Ali')
- result = redis.get('name')
+ await redis.set('name', 'Ali')
+ result = await redis.get('name')
print(result)
```
\ No newline at end of file
diff --git a/docs/docs/release_notes.md b/docs/docs/release_notes.md
index 12df57d..47a49f3 100644
--- a/docs/docs/release_notes.md
+++ b/docs/docs/release_notes.md
@@ -1,10 +1,15 @@
### 4.0.0
- Move `database` and `redis` connections from `MIDDLEWARES` to their own block, `DATABASE` and `REDIS`
-- Make queries `async`
+- Make `Database` queries `async`
+- Make `Redis` queries `async`
+- Add `StreamingResponse`
+- Add `generics` API classes
- Add `login()` & `logout()` to `JWTAuthentication` and used it in `BaseUser`
- Support `Authentication` & `Authorization` in `Websocket`
- Rename all exceptions suffix from `Exception` to `Error` (https://peps.python.org/pep-0008/#exception-names)
-- Support `pantherdb 1.4`
+- Support `pantherdb 2.0.0` (`Cursor` Added)
+- Remove `watchfiles` from required dependencies
+- Support `exclude` and `optional_fields` in `ModelSerializer`
- Minor Improvements
### 3.9.0
diff --git a/docs/docs/serializer.md b/docs/docs/serializer.md
index 7d4852a..5afec4b 100644
--- a/docs/docs/serializer.md
+++ b/docs/docs/serializer.md
@@ -45,13 +45,21 @@ Use panther `ModelSerializer` to write your serializer which will use your `mode
first_name: str = Field(default='', min_length=2)
last_name: str = Field(default='', min_length=4)
-
+ # type 1 - using fields
class UserModelSerializer(ModelSerializer):
class Config:
model = User
fields = ['username', 'first_name', 'last_name']
required_fields = ['first_name']
+ # type 2 - using exclude
+ class UserModelSerializer(ModelSerializer):
+ class Config:
+ model = User
+ fields = '*'
+ required_fields = ['first_name']
+ exclude = ['id', 'password']
+
@API(input_model=UserModelSerializer)
async def model_serializer_example(request: Request):
@@ -59,9 +67,14 @@ Use panther `ModelSerializer` to write your serializer which will use your `mode
```
### Notes
-1. In the example above, `ModelSerializer` will look up for the value of `Config.fields` in the `User.fields` and use their `type` and `value` for the `validation`.
+1. In the example above, `ModelSerializer` will look up the values of `Config.fields` in `User.model_fields` and use their `type` and `value` for the `validation`.
2. `Config.model` and `Config.fields` are `required` when you are using `ModelSerializer`.
-3. If you want to use `Config.required_fields`, you have to put its value in `Config.fields` too.
+3. You can force a field to be required with `Config.required_fields`
+4. You can force a field to be optional with `Config.optional_fields`
+5. `Config.required_fields` and `Config.optional_fields` can't include the same fields
+6. If you want to use `Config.required_fields` or `Config.optional_fields`, you have to put their values in `Config.fields` too.
+7. `Config.fields` & `Config.required_fields` & `Config.optional_fields` can be `*` too (include all the model fields)
+8. `Config.exclude` is mostly used when `Config.fields` is `'*'`
@@ -94,8 +107,9 @@ You can use `pydantic.BaseModel` features in `ModelSerializer` too.
class Config:
model = User
- fields = ['username', 'first_name', 'last_name']
+ fields = ['first_name', 'last_name']
required_fields = ['first_name']
+ optional_fields = ['last_name']
@field_validator('username')
def validate_username(cls, username):
diff --git a/docs/docs/single_file.md b/docs/docs/single_file.md
index 8676c8f..62f2b71 100644
--- a/docs/docs/single_file.md
+++ b/docs/docs/single_file.md
@@ -42,15 +42,16 @@ If you want to work with `Panther` in a `single-file` structure, follow the step
app = Panther(__name__, configs=__name__, urls=url_routing)
```
4. Run the project
- ```bash
- panther run
- ```
+
+    - If the name of your file is `main.py` -->
+    ```panther run```
+    - Otherwise, use `uvicorn` -->
+    ```uvicorn file_name:app```
### Notes
- `URLs` is a required config unless you pass the `urls` directly to the `Panther`
- When you pass the `configs` to the `Panther(configs=...)`, Panther is going to load the configs from this file,
else it is going to load `core/configs.py` file
-- You can pass the `startup` and `shutdown` functions to the `Panther()` too.
```python
from panther import Panther
@@ -64,11 +65,5 @@ else it is going to load `core/configs.py` file
'/': hello_world_api,
}
- def startup():
- pass
-
- def shutdown():
- pass
-
- app = Panther(__name__, configs=__name__, urls=url_routing, startup=startup, shutdown=shutdown)
+ app = Panther(__name__, configs=__name__, urls=url_routing)
```
\ No newline at end of file
diff --git a/docs/docs/urls.md b/docs/docs/urls.md
index a497d24..028d42c 100644
--- a/docs/docs/urls.md
+++ b/docs/docs/urls.md
@@ -14,7 +14,7 @@
- Example: `user/<user_id>/blog/<title>/`
- The `endpoint` should have parameters with those names too
- Example Function-Base: `async def profile_api(user_id: int, title: str):`
- - Example Class-Base: `async def get(user_id: int, title: str):`
+ - Example Class-Base: `async def get(self, user_id: int, title: str):`
## Example
diff --git a/docs/docs/websocket.md b/docs/docs/websocket.md
index 385bc73..ec21692 100644
--- a/docs/docs/websocket.md
+++ b/docs/docs/websocket.md
@@ -15,7 +15,7 @@ class BookWebsocket(GenericWebsocket):
await self.accept()
print(f'{self.connection_id=}')
- async def receive(self, data: str | bytes = None):
+ async def receive(self, data: str | bytes):
# Just Echo The Message
await self.send(data=data)
```
@@ -64,7 +64,7 @@ you have to use `--preload`, like below:
from panther import status
await self.close(code=status.WS_1000_NORMAL_CLOSURE, reason='I just want to close it')
```
- - Out of websocket class scope **(Not Recommended)**: You can close it with `close_websocket_connection()` from `panther.websocket`, it's `async` function with takes 3 args, `connection_id`, `code` and `reason`, like below:
+    - Out of websocket class scope: You can close it with `close_websocket_connection()` from `panther.websocket`; it's an `async` function that takes 3 args: `connection_id`, `code` and `reason`, like below:
```python
from panther import status
from panther.websocket import close_websocket_connection
@@ -83,6 +83,5 @@ you have to use `--preload`, like below:
'/ws///': UserWebsocket
}
```
-12. WebSocket Echo Example -> [Https://GitHub.com/PantherPy/echo_websocket](https://github.com/PantherPy/echo_websocket)
-13. Enjoy.
+12. Enjoy.
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml
index 10de3bb..67db86a 100644
--- a/docs/mkdocs.yml
+++ b/docs/mkdocs.yml
@@ -26,9 +26,10 @@ edit_uri: edit/master/docs/
nav:
- Introduction: 'index.md'
- - First Crud:
+ - How To CRUD:
- Function Base: 'function_first_crud.md'
- Class Base: 'class_first_crud.md'
+ - Generic: 'generic_crud.md'
- Database: 'database.md'
- Panther ODM: 'panther_odm.md'
- Configs: 'configs.md'
@@ -37,6 +38,7 @@ nav:
- WebSocket: 'websocket.md'
- Middlewares: 'middlewares.md'
- Authentications: 'authentications.md'
+ - Events: 'events.md'
- URLs: 'urls.md'
- Throttling: 'throttling.md'
- Background Tasks: 'background_tasks.md'
diff --git a/example/app/apis.py b/example/app/apis.py
index 140ed70..45dac10 100644
--- a/example/app/apis.py
+++ b/example/app/apis.py
@@ -16,9 +16,13 @@
from panther.app import API, GenericAPI
from panther.authentications import JWTAuthentication
from panther.background_tasks import BackgroundTask, background_tasks
+from pantherdb import Cursor as PantherDBCursor
from panther.db.connections import redis
+from panther.db.cursor import Cursor
+from panther.generics import ListAPI
+from panther.pagination import Pagination
from panther.request import Request
-from panther.response import HTMLResponse, Response
+from panther.response import HTMLResponse, Response, StreamingResponse
from panther.throttling import Throttling
from panther.websocket import close_websocket_connection, send_message_to_websocket
@@ -85,8 +89,8 @@ async def res_request_data_with_output_model(request: Request):
@API(input_model=UserInputSerializer)
async def using_redis(request: Request):
- redis.set('ali', '1')
- logger.debug(f"{redis.get('ali') = }")
+ await redis.set('ali', '1')
+ logger.debug(f"{await redis.get('ali') = }")
return Response()
@@ -213,3 +217,30 @@ async def login_api():
@API(auth=True)
def logout_api(request: Request):
return request.user.logout()
+
+
+def reader():
+ from faker import Faker
+ import time
+ f = Faker()
+ for _ in range(5):
+ name = f.name()
+ print(f'{name=}')
+ yield name
+ time.sleep(1)
+
+
+@API()
+def stream_api():
+ # Test --> curl http://127.0.0.1:8000/stream/ --no-buffer
+ return StreamingResponse(reader())
+
+
+class PaginationAPI(ListAPI):
+ pagination = Pagination
+ sort_fields = ['username', 'id']
+ filter_fields = ['username']
+ search_fields = ['username']
+
+ async def objects(self, request: Request, **kwargs) -> Cursor | PantherDBCursor:
+ return await User.find()
diff --git a/example/app/urls.py b/example/app/urls.py
index b6defc1..7adb05c 100644
--- a/example/app/urls.py
+++ b/example/app/urls.py
@@ -39,4 +39,6 @@ async def test(*args, **kwargs):
'custom-response/': custom_response_class_api,
'image/': ImageAPI,
'logout/': logout_api,
+ 'stream/': stream_api,
+ 'pagination/': PaginationAPI,
}
diff --git a/example/core/configs.py b/example/core/configs.py
index d6ee78e..cff954d 100644
--- a/example/core/configs.py
+++ b/example/core/configs.py
@@ -50,9 +50,9 @@
DATABASE = {
'engine': {
- 'class': 'panther.db.connections.MongoDBConnection',
- # 'class': 'panther.db.connections.PantherDBConnection',
- 'host': f'mongodb://{DB_HOST}:27017/{DB_NAME}'
+ # 'class': 'panther.db.connections.MongoDBConnection',
+ 'class': 'panther.db.connections.PantherDBConnection',
+ # 'host': f'mongodb://{DB_HOST}:27017/{DB_NAME}'
},
# 'query': ...,
}
@@ -78,3 +78,5 @@ async def shutdown():
SHUTDOWN = 'core.configs.shutdown'
AUTO_REFORMAT = False
+
+TIMEZONE = 'UTC'
diff --git a/panther/_load_configs.py b/panther/_load_configs.py
index abdfea6..62b3e7d 100644
--- a/panther/_load_configs.py
+++ b/panther/_load_configs.py
@@ -21,6 +21,7 @@
'load_redis',
'load_startup',
'load_shutdown',
+ 'load_timezone',
'load_database',
'load_secret_key',
'load_monitoring',
@@ -39,10 +40,10 @@
logger = logging.getLogger('panther')
-def load_configs_module(_configs, /) -> dict:
+def load_configs_module(module_name: str, /) -> dict:
"""Read the config file as dict"""
- if _configs:
- _module = sys.modules[_configs]
+ if module_name:
+ _module = sys.modules[module_name]
else:
try:
_module = import_module('core.configs')
@@ -55,10 +56,10 @@ def load_redis(_configs: dict, /) -> None:
if redis_config := _configs.get('REDIS'):
# Check redis module installation
try:
- from redis import Redis as _Redis
- except ModuleNotFoundError as e:
+ from redis.asyncio import Redis
+ except ImportError as e:
raise import_error(e, package='redis')
- redis_class_path = redis_config.get('class', 'panther.db.connections.Redis')
+ redis_class_path = redis_config.get('class', 'panther.db.connections.RedisConnection')
redis_class = import_class(redis_class_path)
# We have to create another dict then pop the 'class' else we can't pass the tests
args = redis_config.copy()
@@ -68,12 +69,17 @@ def load_redis(_configs: dict, /) -> None:
def load_startup(_configs: dict, /) -> None:
if startup := _configs.get('STARTUP'):
- config['startup'] = import_class(startup)
+ config.STARTUP = import_class(startup)
def load_shutdown(_configs: dict, /) -> None:
if shutdown := _configs.get('SHUTDOWN'):
- config['shutdown'] = import_class(shutdown)
+ config.SHUTDOWN = import_class(shutdown)
+
+
+def load_timezone(_configs: dict, /) -> None:
+ if timezone := _configs.get('TIMEZONE'):
+ config.TIMEZONE = timezone
def load_database(_configs: dict, /) -> None:
@@ -87,42 +93,42 @@ def load_database(_configs: dict, /) -> None:
# We have to create another dict then pop the 'class' else we can't pass the tests
args = database_config['engine'].copy()
args.pop('class')
- config['database'] = engine_class(**args)
+ config.DATABASE = engine_class(**args)
if engine_class_path == 'panther.db.connections.PantherDBConnection':
- config['query_engine'] = BasePantherDBQuery
+ config.QUERY_ENGINE = BasePantherDBQuery
elif engine_class_path == 'panther.db.connections.MongoDBConnection':
- config['query_engine'] = BaseMongoDBQuery
+ config.QUERY_ENGINE = BaseMongoDBQuery
if 'query' in database_config:
- if config['query_engine']:
+ if config.QUERY_ENGINE:
logger.warning('`DATABASE.query` has already been filled.')
- config['query_engine'] = import_class(database_config['query'])
+ config.QUERY_ENGINE = import_class(database_config['query'])
def load_secret_key(_configs: dict, /) -> None:
if secret_key := _configs.get('SECRET_KEY'):
- config['secret_key'] = secret_key.encode()
+ config.SECRET_KEY = secret_key.encode()
def load_monitoring(_configs: dict, /) -> None:
if _configs.get('MONITORING'):
- config['monitoring'] = True
+ config.MONITORING = True
def load_throttling(_configs: dict, /) -> None:
if throttling := _configs.get('THROTTLING'):
- config['throttling'] = throttling
+ config.THROTTLING = throttling
def load_user_model(_configs: dict, /) -> None:
- config['user_model'] = import_class(_configs.get('USER_MODEL', 'panther.db.models.BaseUser'))
- config['models'].append(config['user_model'])
+ config.USER_MODEL = import_class(_configs.get('USER_MODEL', 'panther.db.models.BaseUser'))
+ config.MODELS.append(config.USER_MODEL)
def load_log_queries(_configs: dict, /) -> None:
if _configs.get('LOG_QUERIES'):
- config['log_queries'] = True
+ config.LOG_QUERIES = True
def load_middlewares(_configs: dict, /) -> None:
@@ -153,42 +159,39 @@ def load_middlewares(_configs: dict, /) -> None:
if issubclass(middleware_class, BaseMiddleware) is False:
raise _exception_handler(field='MIDDLEWARES', error='is not a sub class of BaseMiddleware')
- middleware_instance = middleware_class(**data)
- if isinstance(middleware_instance, BaseMiddleware | HTTPMiddleware):
- middlewares['http'].append(middleware_instance)
+ if middleware_class.__bases__[0] in (BaseMiddleware, HTTPMiddleware):
+ middlewares['http'].append((middleware_class, data))
- if isinstance(middleware_instance, BaseMiddleware | WebsocketMiddleware):
- middlewares['ws'].append(middleware_instance)
+ if middleware_class.__bases__[0] in (BaseMiddleware, WebsocketMiddleware):
+ middlewares['ws'].append((middleware_class, data))
- config['http_middlewares'] = middlewares['http']
- config['ws_middlewares'] = middlewares['ws']
- config['reversed_http_middlewares'] = middlewares['http'][::-1]
- config['reversed_ws_middlewares'] = middlewares['ws'][::-1]
+ config.HTTP_MIDDLEWARES = middlewares['http']
+ config.WS_MIDDLEWARES = middlewares['ws']
def load_auto_reformat(_configs: dict, /) -> None:
if _configs.get('AUTO_REFORMAT'):
- config['auto_reformat'] = True
+ config.AUTO_REFORMAT = True
def load_background_tasks(_configs: dict, /) -> None:
if _configs.get('BACKGROUND_TASKS'):
- config['background_tasks'] = True
+ config.BACKGROUND_TASKS = True
background_tasks.initialize()
def load_default_cache_exp(_configs: dict, /) -> None:
if default_cache_exp := _configs.get('DEFAULT_CACHE_EXP'):
- config['default_cache_exp'] = default_cache_exp
+ config.DEFAULT_CACHE_EXP = default_cache_exp
def load_authentication_class(_configs: dict, /) -> None:
"""Should be after `load_secret_key()`"""
if authentication := _configs.get('AUTHENTICATION'):
- config['authentication'] = import_class(authentication)
+ config.AUTHENTICATION = import_class(authentication)
if ws_authentication := _configs.get('WS_AUTHENTICATION'):
- config['ws_authentication'] = import_class(ws_authentication)
+ config.WS_AUTHENTICATION = import_class(ws_authentication)
load_jwt_config(_configs)
@@ -196,16 +199,16 @@ def load_authentication_class(_configs: dict, /) -> None:
def load_jwt_config(_configs: dict, /) -> None:
"""Only Collect JWT Config If Authentication Is JWTAuthentication"""
auth_is_jwt = (
- getattr(config['authentication'], '__name__', None) == 'JWTAuthentication' or
- getattr(config['ws_authentication'], '__name__', None) == 'QueryParamJWTAuthentication'
+ getattr(config.AUTHENTICATION, '__name__', None) == 'JWTAuthentication' or
+ getattr(config.WS_AUTHENTICATION, '__name__', None) == 'QueryParamJWTAuthentication'
)
jwt = _configs.get('JWTConfig', {})
if auth_is_jwt or jwt:
if 'key' not in jwt:
- if config['secret_key'] is None:
+ if config.SECRET_KEY is None:
raise _exception_handler(field='JWTConfig', error='`JWTConfig.key` or `SECRET_KEY` is required.')
- jwt['key'] = config['secret_key'].decode()
- config['jwt_config'] = JWTConfig(**jwt)
+ jwt['key'] = config.SECRET_KEY.decode()
+ config.JWT_CONFIG = JWTConfig(**jwt)
def load_urls(_configs: dict, /, urls: dict | None) -> None:
@@ -237,14 +240,14 @@ def load_urls(_configs: dict, /, urls: dict | None) -> None:
if not isinstance(urls, dict):
raise _exception_handler(field='URLs', error='should point to a dict.')
- config['flat_urls'] = flatten_urls(urls)
- config['urls'] = finalize_urls(config['flat_urls'])
- config['urls']['_panel'] = finalize_urls(flatten_urls(panel_urls))
+ config.FLAT_URLS = flatten_urls(urls)
+ config.URLS = finalize_urls(config.FLAT_URLS)
+ config.URLS['_panel'] = finalize_urls(flatten_urls(panel_urls))
def load_websocket_connections():
"""Should be after `load_redis()`"""
- if config['has_ws']:
+ if config.HAS_WS:
# Check `websockets`
try:
import websockets
@@ -253,7 +256,7 @@ def load_websocket_connections():
# Use the redis pubsub if `redis.is_connected`, else use the `multiprocessing.Manager`
pubsub_connection = redis.create_connection_for_websocket() if redis.is_connected else Manager()
- config['websocket_connections'] = WebsocketConnections(pubsub_connection=pubsub_connection)
+ config.WEBSOCKET_CONNECTIONS = WebsocketConnections(pubsub_connection=pubsub_connection)
def _exception_handler(field: str, error: str | Exception) -> PantherError:
diff --git a/panther/_utils.py b/panther/_utils.py
index e1371ba..e7fffcb 100644
--- a/panther/_utils.py
+++ b/panther/_utils.py
@@ -1,57 +1,20 @@
+import asyncio
import importlib
import logging
import re
import subprocess
import types
-import typing
+from typing import Any, Generator, Iterator, AsyncGenerator
from collections.abc import Callable
from traceback import TracebackException
-import orjson as json
-
-from panther import status
from panther.exceptions import PantherError
from panther.file_handler import File
logger = logging.getLogger('panther')
-async def _http_response_start(send: Callable, /, headers: dict, status_code: int) -> None:
- bytes_headers = [[k.encode(), str(v).encode()] for k, v in (headers or {}).items()]
- await send({
- 'type': 'http.response.start',
- 'status': status_code,
- 'headers': bytes_headers,
- })
-
-
-async def _http_response_body(send: Callable, /, body: bytes) -> None:
- if body:
- await send({'type': 'http.response.body', 'body': body})
- else: # body = b''
- await send({'type': 'http.response.body'})
-
-
-async def http_response(
- send: Callable,
- /,
- *,
- monitoring, # type: MonitoringMiddleware
- status_code: int,
- headers: dict,
- body: bytes = b'',
- exception: bool = False,
-) -> None:
- if exception:
- body = json.dumps({'detail': status.status_text[status_code]})
-
- await monitoring.after(status_code)
-
- await _http_response_start(send, headers=headers, status_code=status_code)
- await _http_response_body(send, body=body)
-
-
-def import_class(dotted_path: str, /) -> type[typing.Any]:
+def import_class(dotted_path: str, /) -> type[Any]:
"""
Example:
-------
@@ -148,4 +111,25 @@ def check_class_type_endpoint(endpoint: Callable) -> Callable:
logger.critical(f'You may have forgotten to inherit from GenericAPI on the {endpoint.__name__}()')
raise TypeError
- return endpoint.call_method
+ return endpoint().call_method
+
+
+def async_next(iterator: Iterator):
+ """
+ The StopIteration exception is a special case in Python,
+ particularly when it comes to asynchronous programming and the use of asyncio.
+ This is because StopIteration is not meant to be caught in the traditional sense;
+ it's used internally by Python to signal the end of an iteration.
+ """
+ try:
+ return next(iterator)
+ except StopIteration:
+ raise StopAsyncIteration
+
+
+async def to_async_generator(generator: Generator) -> AsyncGenerator:
+ while True:
+ try:
+ yield await asyncio.to_thread(async_next, iter(generator))
+ except StopAsyncIteration:
+ break
diff --git a/panther/app.py b/panther/app.py
index 2f4282f..8dff292 100644
--- a/panther/app.py
+++ b/panther/app.py
@@ -1,13 +1,18 @@
import functools
import logging
-from datetime import datetime, timedelta
+from datetime import timedelta
from typing import Literal
from orjson import JSONDecodeError
-from pydantic import ValidationError
+from pydantic import ValidationError, BaseModel
from panther._utils import is_function_async
-from panther.caching import cache_key, get_cached_response_data, set_cache_response
+from panther.caching import (
+ get_response_from_cache,
+ set_response_in_cache,
+ get_throttling_from_cache,
+ increment_throttling_in_cache
+)
from panther.configs import config
from panther.exceptions import (
APIError,
@@ -19,8 +24,8 @@
)
from panther.request import Request
from panther.response import Response
-from panther.throttling import Throttling, throttling_storage
-from panther.utils import round_datetime
+from panther.serializer import ModelSerializer
+from panther.throttling import Throttling
__all__ = ('API', 'GenericAPI')
@@ -31,11 +36,11 @@ class API:
def __init__(
self,
*,
- input_model=None,
- output_model=None,
+ input_model: type[ModelSerializer] | type[BaseModel] | None = None,
+ output_model: type[ModelSerializer] | type[BaseModel] | None = None,
auth: bool = False,
permissions: list | None = None,
- throttling: Throttling = None,
+ throttling: Throttling | None = None,
cache: bool = False,
cache_exp_time: timedelta | int | None = None,
methods: list[Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']] | None = None,
@@ -53,20 +58,20 @@ def __init__(
def __call__(self, func):
@functools.wraps(func)
async def wrapper(request: Request) -> Response:
- self.request: Request = request # noqa: Non-self attribute could not be type hinted
+ self.request = request
# 1. Check Method
if self.methods and self.request.method not in self.methods:
raise MethodNotAllowedAPIError
# 2. Authentication
- await self.handle_authentications()
+ await self.handle_authentication()
- # 3. Throttling
- self.handle_throttling()
+ # 3. Permissions
+ await self.handle_permission()
- # 4. Permissions
- await self.handle_permissions()
+ # 4. Throttling
+ await self.handle_throttling()
# 5. Validate Input
if self.request.method in ['POST', 'PUT', 'PATCH']:
@@ -74,7 +79,7 @@ async def wrapper(request: Request) -> Response:
# 6. Get Cached Response
if self.cache and self.request.method == 'GET':
- if cached := get_cached_response_data(request=self.request, cache_exp_time=self.cache_exp_time):
+ if cached := await get_response_from_cache(request=self.request, cache_exp_time=self.cache_exp_time):
return Response(data=cached.data, status_code=cached.status_code)
# 7. Put PathVariables and Request(If User Wants It) In kwargs
@@ -89,11 +94,12 @@ async def wrapper(request: Request) -> Response:
# 9. Clean Response
if not isinstance(response, Response):
response = Response(data=response)
- response._clean_data_with_output_model(output_model=self.output_model) # noqa: SLF001
+ if self.output_model and response.data:
+ response.data = response.apply_output_model(response.data, output_model=self.output_model)
# 10. Set New Response To Cache
if self.cache and self.request.method == 'GET':
- set_cache_response(
+ await set_response_in_cache(
request=self.request,
response=response,
cache_exp_time=self.cache_exp_time
@@ -107,26 +113,21 @@ async def wrapper(request: Request) -> Response:
return wrapper
- async def handle_authentications(self) -> None:
- auth_class = config['authentication']
+ async def handle_authentication(self) -> None:
if self.auth:
- if not auth_class:
+ if not config.AUTHENTICATION:
logger.critical('"AUTHENTICATION" has not been set in configs')
raise APIError
- user = await auth_class.authentication(self.request)
- self.request.user = user
-
- def handle_throttling(self) -> None:
- if throttling := self.throttling or config['throttling']:
- key = cache_key(self.request)
- time = round_datetime(datetime.now(), throttling.duration)
- throttling_key = f'{time}-{key}'
- if throttling_storage[throttling_key] > throttling.rate:
+ self.request.user = await config.AUTHENTICATION.authentication(self.request)
+
+ async def handle_throttling(self) -> None:
+ if throttling := self.throttling or config.THROTTLING:
+ if await get_throttling_from_cache(self.request, duration=throttling.duration) + 1 > throttling.rate:
raise ThrottlingAPIError
- throttling_storage[throttling_key] += 1
+ await increment_throttling_in_cache(self.request, duration=throttling.duration)
- async def handle_permissions(self) -> None:
+ async def handle_permission(self) -> None:
for perm in self.permissions:
if type(perm.authorization).__name__ != 'method':
logger.error(f'{perm.__name__}.authorization should be "classmethod"')
@@ -136,15 +137,17 @@ async def handle_permissions(self) -> None:
def handle_input_validation(self):
if self.input_model:
- validated_data = self.validate_input(model=self.input_model, request=self.request)
- self.request.set_validated_data(validated_data)
+ self.request.validated_data = self.validate_input(model=self.input_model, request=self.request)
@classmethod
def validate_input(cls, model, request: Request):
+ if isinstance(request.data, bytes):
+ raise BadRequestAPIError(detail='Content-Type is not valid')
+ if request.data is None:
+ raise BadRequestAPIError(detail='Request body is required')
try:
- if isinstance(request.data, bytes):
- raise BadRequestAPIError(detail='Content-Type is not valid')
- return model(**request.data)
+ # `request` will be ignored in regular `BaseModel`
+ return model(**request.data, request=request)
except ValidationError as validation_error:
error = {'.'.join(loc for loc in e['loc']): e['msg'] for e in validation_error.errors()}
raise BadRequestAPIError(detail=error)
@@ -153,8 +156,8 @@ def validate_input(cls, model, request: Request):
class GenericAPI:
- input_model = None
- output_model = None
+ input_model: type[ModelSerializer] | type[BaseModel] = None
+ output_model: type[ModelSerializer] | type[BaseModel] = None
auth: bool = False
permissions: list | None = None
throttling: Throttling | None = None
@@ -176,26 +179,27 @@ async def patch(self, *args, **kwargs):
async def delete(self, *args, **kwargs):
raise MethodNotAllowedAPIError
- @classmethod
- async def call_method(cls, *args, **kwargs):
- match kwargs['request'].method:
+ async def call_method(self, request: Request):
+ match request.method:
case 'GET':
- func = cls().get
+ func = self.get
case 'POST':
- func = cls().post
+ func = self.post
case 'PUT':
- func = cls().put
+ func = self.put
case 'PATCH':
- func = cls().patch
+ func = self.patch
case 'DELETE':
- func = cls().delete
+ func = self.delete
+ case _:
+ raise MethodNotAllowedAPIError
return await API(
- input_model=cls.input_model,
- output_model=cls.output_model,
- auth=cls.auth,
- permissions=cls.permissions,
- throttling=cls.throttling,
- cache=cls.cache,
- cache_exp_time=cls.cache_exp_time,
- )(func)(*args, **kwargs)
+ input_model=self.input_model,
+ output_model=self.output_model,
+ auth=self.auth,
+ permissions=self.permissions,
+ throttling=self.throttling,
+ cache=self.cache,
+ cache_exp_time=self.cache_exp_time,
+ )(func)(request=request)
diff --git a/panther/authentications.py b/panther/authentications.py
index 1861714..27454ea 100644
--- a/panther/authentications.py
+++ b/panther/authentications.py
@@ -67,7 +67,7 @@ async def authentication(cls, request: Request | Websocket) -> Model:
msg = 'Authorization keyword is not valid'
raise cls.exception(msg) from None
- if redis.is_connected and cls._check_in_cache(token=token):
+ if redis.is_connected and await cls._check_in_cache(token=token):
msg = 'User logged out'
raise cls.exception(msg) from None
@@ -82,8 +82,8 @@ def decode_jwt(cls, token: str) -> dict:
try:
return jwt.decode(
token=token,
- key=config['jwt_config'].key,
- algorithms=[config['jwt_config'].algorithm],
+ key=config.JWT_CONFIG.key,
+ algorithms=[config.JWT_CONFIG.algorithm],
)
except JWTError as e:
raise cls.exception(e) from None
@@ -95,7 +95,7 @@ async def get_user(cls, payload: dict) -> Model:
msg = 'Payload does not have `user_id`'
raise cls.exception(msg)
- user_model = config['user_model'] or cls.model
+ user_model = config.USER_MODEL or cls.model
if user := await user_model.find_one(id=user_id):
return user
@@ -107,9 +107,9 @@ def encode_jwt(cls, user_id: str, token_type: Literal['access', 'refresh'] = 'ac
"""Encode JWT from user_id."""
issued_at = datetime.now(timezone.utc).timestamp()
if token_type == 'access':
- expire = issued_at + config['jwt_config'].life_time
+ expire = issued_at + config.JWT_CONFIG.life_time
else:
- expire = issued_at + config['jwt_config'].refresh_life_time
+ expire = issued_at + config.JWT_CONFIG.refresh_life_time
claims = {
'token_type': token_type,
@@ -119,8 +119,8 @@ def encode_jwt(cls, user_id: str, token_type: Literal['access', 'refresh'] = 'ac
}
return jwt.encode(
claims,
- key=config['jwt_config'].key,
- algorithm=config['jwt_config'].algorithm,
+ key=config.JWT_CONFIG.key,
+ algorithm=config.JWT_CONFIG.algorithm,
)
@classmethod
@@ -132,24 +132,24 @@ def login(cls, user_id: str) -> dict:
}
@classmethod
- def logout(cls, raw_token: str) -> None:
+ async def logout(cls, raw_token: str) -> None:
*_, token = raw_token.split()
if redis.is_connected:
payload = cls.decode_jwt(token=token)
remaining_exp_time = payload['exp'] - time.time()
- cls._set_in_cache(token=token, exp=int(remaining_exp_time))
+ await cls._set_in_cache(token=token, exp=int(remaining_exp_time))
else:
logger.error('`redis` middleware is required for `logout()`')
@classmethod
- def _set_in_cache(cls, token: str, exp: int) -> None:
+ async def _set_in_cache(cls, token: str, exp: int) -> None:
key = generate_hash_value_from_string(token)
- redis.set(key, b'', ex=exp)
+ await redis.set(key, b'', ex=exp)
@classmethod
- def _check_in_cache(cls, token: str) -> bool:
+ async def _check_in_cache(cls, token: str) -> bool:
key = generate_hash_value_from_string(token)
- return bool(redis.exists(key))
+ return bool(await redis.exists(key))
@staticmethod
def exception(message: str | JWTError | UnicodeEncodeError, /) -> type[AuthenticationAPIError]:
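
Since `logout()` and the cache helpers are now coroutines, callers have to `await` them. A hedged sketch of the round trip, assuming Redis is connected (without it `logout()` only logs an error):

```python
from panther.authentications import JWTAuthentication


def issue_tokens(user_id: str) -> dict:
    # Returns the access/refresh token pair built by encode_jwt()
    return JWTAuthentication.login(user_id)


async def revoke(raw_authorization_header: str) -> None:
    # e.g. 'Bearer <access-token>'; the token hash is kept in Redis until its
    # original `exp`, so authentication() raises 'User logged out' meanwhile.
    await JWTAuthentication.logout(raw_authorization_header)
```
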
diff --git a/panther/background_tasks.py b/panther/background_tasks.py
index c925375..9165b04 100644
--- a/panther/background_tasks.py
+++ b/panther/background_tasks.py
@@ -9,7 +9,6 @@
from panther._utils import is_function_async
from panther.utils import Singleton
-
__all__ = (
'BackgroundTask',
'background_tasks',
diff --git a/panther/base_request.py b/panther/base_request.py
index 314e44a..5e012f7 100644
--- a/panther/base_request.py
+++ b/panther/base_request.py
@@ -60,8 +60,6 @@ def __init__(self, scope: dict, receive: Callable, send: Callable):
self.scope = scope
self.asgi_send = send
self.asgi_receive = receive
- self._data = ...
- self._validated_data = None
self._headers: Headers | None = None
self._params: dict | None = None
self.user: Model | None = None
@@ -116,25 +114,22 @@ def collect_path_variables(self, found_path: str):
}
def clean_parameters(self, func: Callable) -> dict:
- kwargs = {}
+ kwargs = self.path_variables.copy()
+
for variable_name, variable_type in func.__annotations__.items():
# Put Request/ Websocket In kwargs (If User Wants It)
if issubclass(variable_type, BaseRequest):
kwargs[variable_name] = self
- continue
-
- for name, value in self.path_variables.items():
- if name == variable_name:
- # Check the type and convert the value
- if variable_type is bool:
- kwargs[name] = value.lower() not in ['false', '0']
-
- elif variable_type is int:
- try:
- kwargs[name] = int(value)
- except ValueError:
- raise InvalidPathVariableAPIError(value=value, variable_type=variable_type)
- else:
- kwargs[name] = value
- return kwargs
+ elif variable_name in kwargs:
+ # Cast To Boolean
+ if variable_type is bool:
+ kwargs[variable_name] = kwargs[variable_name].lower() not in ['false', '0']
+
+ # Cast To Int
+ elif variable_type is int:
+ try:
+ kwargs[variable_name] = int(kwargs[variable_name])
+ except ValueError:
+ raise InvalidPathVariableAPIError(value=kwargs[variable_name], variable_type=variable_type)
+ return kwargs
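
The rewritten `clean_parameters()` starts from the collected path variables and casts them according to the handler's annotations. An illustrative sketch (the route and names are hypothetical):

```python
from panther import status
from panther.app import API
from panther.request import Request
from panther.response import Response


@API()
async def toggle_user(request: Request, user_id: int, active: bool):
    # Assuming `user_id` and `active` are path variables on the route:
    # '12' arrives as int(12); 'false' and '0' become False, any other
    # value becomes True; unannotated variables stay str.
    return Response(data={'user_id': user_id, 'active': active}, status_code=status.HTTP_200_OK)
```
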
diff --git a/panther/base_websocket.py b/panther/base_websocket.py
index 7da3b74..5ba79e6 100644
--- a/panther/base_websocket.py
+++ b/panther/base_websocket.py
@@ -1,11 +1,8 @@
from __future__ import annotations
import asyncio
-import contextlib
import logging
-from multiprocessing import Manager
from multiprocessing.managers import SyncManager
-from threading import Thread
from typing import TYPE_CHECKING, Literal
import orjson as json
@@ -15,16 +12,17 @@
from panther.configs import config
from panther.db.connections import redis
from panther.exceptions import AuthenticationAPIError, InvalidPathVariableAPIError
+from panther.monitoring import Monitoring
from panther.utils import Singleton, ULID
if TYPE_CHECKING:
- from redis import Redis
+ from redis.asyncio import Redis
logger = logging.getLogger('panther')
class PubSub:
- def __init__(self, manager):
+ def __init__(self, manager: SyncManager):
self._manager = manager
self._subscribers = self._manager.list()
@@ -38,17 +36,8 @@ def publish(self, msg):
queue.put(msg)
-class WebsocketListener(Thread):
- def __init__(self):
- super().__init__(target=config['websocket_connections'], daemon=True)
-
- def run(self):
- with contextlib.suppress(Exception):
- super().run()
-
-
class WebsocketConnections(Singleton):
- def __init__(self, pubsub_connection: Redis | Manager):
+ def __init__(self, pubsub_connection: Redis | SyncManager):
self.connections = {}
self.connections_count = 0
self.pubsub_connection = pubsub_connection
@@ -56,20 +45,21 @@ def __init__(self, pubsub_connection: Redis | Manager):
if isinstance(self.pubsub_connection, SyncManager):
self.pubsub = PubSub(manager=self.pubsub_connection)
- def __call__(self):
+ async def __call__(self):
if isinstance(self.pubsub_connection, SyncManager):
- # We don't have redis connection, so use the `multiprocessing.PubSub`
+ # We don't have a redis connection, so use the `multiprocessing.Manager`
+ self.pubsub: PubSub
queue = self.pubsub.subscribe()
logger.info("Subscribed to 'websocket_connections' queue")
while True:
received_message = queue.get()
- self._handle_received_message(received_message=received_message)
+ await self._handle_received_message(received_message=received_message)
else:
# We have a redis connection, so use it for pubsub
self.pubsub = self.pubsub_connection.pubsub()
- self.pubsub.subscribe('websocket_connections')
+ await self.pubsub.subscribe('websocket_connections')
logger.info("Subscribed to 'websocket_connections' channel")
- for channel_data in self.pubsub.listen():
+ async for channel_data in self.pubsub.listen():
match channel_data['type']:
# Subscribed
case 'subscribe':
@@ -78,12 +68,12 @@ def __call__(self):
# Message Received
case 'message':
loaded_data = json.loads(channel_data['data'].decode())
- self._handle_received_message(received_message=loaded_data)
+ await self._handle_received_message(received_message=loaded_data)
case unknown_type:
- logger.debug(f'Unknown Channel Type: {unknown_type}')
+ logger.error(f'Unknown Channel Type: {unknown_type}')
- def _handle_received_message(self, received_message):
+ async def _handle_received_message(self, received_message):
if (
isinstance(received_message, dict)
and (connection_id := received_message.get('connection_id'))
@@ -94,120 +84,156 @@ def _handle_received_message(self, received_message):
# Check Action of WS
match received_message['action']:
case 'send':
- asyncio.run(self.connections[connection_id].send(data=received_message['data']))
+ await self.connections[connection_id].send(data=received_message['data'])
case 'close':
- with contextlib.suppress(RuntimeError):
- asyncio.run(self.connections[connection_id].close(
- code=received_message['data']['code'],
- reason=received_message['data']['reason']
- ))
- # We are trying to disconnect the connection between a thread and a user
- # from another thread, it's working, but we have to find another solution for it
- #
- # Error:
- # Task > got Future
- # >
- # attached to a different loop
+ await self.connections[connection_id].close(
+ code=received_message['data']['code'],
+ reason=received_message['data']['reason']
+ )
case unknown_action:
- logger.debug(f'Unknown Message Action: {unknown_action}')
+ logger.error(f'Unknown Message Action: {unknown_action}')
- def publish(self, connection_id: str, action: Literal['send', 'close'], data: any):
+ async def publish(self, connection_id: str, action: Literal['send', 'close'], data: any):
publish_data = {'connection_id': connection_id, 'action': action, 'data': data}
if redis.is_connected:
- redis.publish('websocket_connections', json.dumps(publish_data))
+ await redis.publish('websocket_connections', json.dumps(publish_data))
else:
self.pubsub.publish(publish_data)
- async def new_connection(self, connection: Websocket) -> None:
+ async def listen(self, connection: Websocket) -> None:
# 1. Authentication
- connection_closed = await self.handle_authentication(connection=connection)
+ if not connection.is_rejected:
+ await self.handle_authentication(connection=connection)
# 2. Permissions
- connection_closed = connection_closed or await self.handle_permissions(connection=connection)
+ if not connection.is_rejected:
+ await self.handle_permissions(connection=connection)
- if connection_closed:
- # Don't run the following code...
+ if connection.is_rejected:
+ # Connection is rejected so don't continue the flow ...
return None
# 3. Put PathVariables and Request(If User Wants It) In kwargs
try:
kwargs = connection.clean_parameters(connection.connect)
except InvalidPathVariableAPIError as e:
- return await connection.close(status.WS_1000_NORMAL_CLOSURE, reason=str(e))
+ connection.log(e.detail)
+ return await connection.close()
# 4. Connect To Endpoint
await connection.connect(**kwargs)
- if not hasattr(connection, '_connection_id'):
- # User didn't even call the `self.accept()` so close the connection
- await connection.close()
+ # 5. Check Connection
+ if not connection.is_connected and not connection.is_rejected:
+ # User didn't call `self.accept()` or `self.close()`, so we `close()` the connection (reject)
+ return await connection.close()
- # 5. Connection Accepted
- if connection.is_connected:
- self.connections_count += 1
+ # 6. Listen Connection
+ await self.listen_connection(connection=connection)
+
+ async def listen_connection(self, connection: Websocket):
+ while True:
+ response = await connection.asgi_receive()
+ if response['type'] == 'websocket.connect':
+ continue
+
+ if response['type'] == 'websocket.disconnect':
+ # Connection has been closed by the client
+ await self.connection_closed(connection=connection)
+ break
- # Save New ConnectionID
- self.connections[connection.connection_id] = connection
+ if 'text' in response:
+ await connection.receive(data=response['text'])
+ else:
+ await connection.receive(data=response['bytes'])
- def remove_connection(self, connection: Websocket) -> None:
+ async def connection_accepted(self, connection: Websocket) -> None:
+ # Generate ConnectionID
+ connection._connection_id = ULID.new()
+
+ # Save Connection
+ self.connections[connection.connection_id] = connection
+
+ # Logs
+ await connection.monitoring.after('Accepted')
+ connection.log(f'Accepted {connection.connection_id}')
+
+ async def connection_closed(self, connection: Websocket, from_server: bool = False) -> None:
if connection.is_connected:
- self.connections_count -= 1
del self.connections[connection.connection_id]
+ await connection.monitoring.after('Closed')
+ connection.log(f'Closed {connection.connection_id}')
+ connection._connection_id = ''
+
+ elif connection.is_rejected is False and from_server is True:
+ await connection.monitoring.after('Rejected')
+ connection.log('Rejected')
+ connection._is_rejected = True
+
+ async def start(self):
+ """
+ Start the websocket listener (Redis / in-memory queue).
+
+ Because of gunicorn's --preload we have to keep this function here;
+ it can't be moved into Panther's __init__.
+
+ * Each process should start this listener for itself,
+ but they all share the same Manager().
+ """
+
+ if config.HAS_WS:
+ # Schedule the async function to run in the background,
+ # We don't need to await this task
+ asyncio.create_task(self())
@classmethod
- async def handle_authentication(cls, connection: Websocket) -> bool:
+ async def handle_authentication(cls, connection: Websocket):
"""Return True if connection is closed, False otherwise."""
if connection.auth:
- if not config.ws_authentication:
- logger.critical('"WS_AUTHENTICATION" has not been set in configs')
- await connection.close(reason='Authentication Error')
- return True
- try:
- connection.user = await config.ws_authentication.authentication(connection)
- except AuthenticationAPIError as e:
- await connection.close(reason=e.detail)
- return False
+ if not config.WS_AUTHENTICATION:
+ logger.critical('`WS_AUTHENTICATION` has not been set in configs')
+ await connection.close()
+ else:
+ try:
+ connection.user = await config.WS_AUTHENTICATION.authentication(connection)
+ except AuthenticationAPIError as e:
+ connection.log(e.detail)
+ await connection.close()
@classmethod
- async def handle_permissions(cls, connection: Websocket) -> bool:
+ async def handle_permissions(cls, connection: Websocket):
"""Return True if connection is closed, False otherwise."""
for perm in connection.permissions:
if type(perm.authorization).__name__ != 'method':
- logger.error(f'{perm.__name__}.authorization should be "classmethod"')
- await connection.close(reason='Permission Denied')
- return True
- if await perm.authorization(connection) is False:
- await connection.close(reason='Permission Denied')
- return True
- return False
+ logger.critical(f'{perm.__name__}.authorization should be "classmethod"')
+ await connection.close()
+ elif await perm.authorization(connection) is False:
+ connection.log('Permission Denied')
+ await connection.close()
class Websocket(BaseRequest):
- is_connected: bool = False
auth: bool = False
permissions: list = []
+ _connection_id: str = ''
+ _is_rejected: bool = False
+ _monitoring: Monitoring
def __init_subclass__(cls, **kwargs):
if cls.__module__ != 'panther.websocket':
- config['has_ws'] = True
+ config.HAS_WS = True
async def connect(self, **kwargs) -> None:
- """Check your conditions then self.accept() the connection"""
- await self.accept()
-
- async def accept(self, subprotocol: str | None = None, headers: dict | None = None) -> None:
- await self.asgi_send({'type': 'websocket.accept', 'subprotocol': subprotocol, 'headers': headers or {}})
- self.is_connected = True
-
- # Generate ConnectionID
- self._connection_id = ULID.new()
-
- logger.debug(f'Accepting WS Connection {self._connection_id}')
+ pass
async def receive(self, data: str | bytes) -> None:
pass
+ async def accept(self, subprotocol: str | None = None, headers: dict | None = None) -> None:
+ await self.asgi_send({'type': 'websocket.accept', 'subprotocol': subprotocol, 'headers': headers or {}})
+ await config.WEBSOCKET_CONNECTIONS.connection_accepted(connection=self)
+
async def send(self, data: any = None) -> None:
logger.debug(f'Sending WS Message to {self.connection_id}')
if data:
@@ -225,32 +251,26 @@ async def send_bytes(self, bytes_data: bytes) -> None:
await self.asgi_send({'type': 'websocket.send', 'bytes': bytes_data})
async def close(self, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = '') -> None:
- connection_id = getattr(self, '_connection_id', '')
- logger.debug(f'Closing WS Connection {connection_id} Code: {code}')
- self.is_connected = False
- config['websocket_connections'].remove_connection(self)
await self.asgi_send({'type': 'websocket.close', 'code': code, 'reason': reason})
+ await config.WEBSOCKET_CONNECTIONS.connection_closed(connection=self, from_server=True)
- async def listen(self) -> None:
- while self.is_connected:
- response = await self.asgi_receive()
- if response['type'] == 'websocket.connect':
- continue
-
- if response['type'] == 'websocket.disconnect':
- break
+ @property
+ def connection_id(self) -> str:
+ if self.is_connected:
+ return self._connection_id
+ logger.error('You should first `self.accept()` the connection then use the `self.connection_id`')
- if 'text' in response:
- await self.receive(data=response['text'])
- else:
- await self.receive(data=response['bytes'])
+ @property
+ def is_connected(self) -> bool:
+ return bool(self._connection_id)
- def set_connection_id(self, connection_id: str) -> None:
- self._connection_id = connection_id
+ @property
+ def is_rejected(self) -> bool:
+ return self._is_rejected
@property
- def connection_id(self) -> str:
- connection_id = getattr(self, '_connection_id', None)
- if connection_id is None:
- logger.error('You should first `self.accept()` the connection then use the `self.connection_id`')
- return connection_id
+ def monitoring(self) -> Monitoring:
+ return self._monitoring
+
+ def log(self, message: str):
+ logger.debug(f'WS {self.path} --> {message}')
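
Under the new lifecycle the listener authenticates, checks permissions, then calls `connect()`; an endpoint that neither accepts nor closes is rejected, and `connection_id` only exists after `accept()`. A minimal sketch (the class name and echo behaviour are illustrative, and `GenericWebsocket` is assumed to be the public base in `panther.websocket`):

```python
from panther.websocket import GenericWebsocket  # assumed public base class


class EchoWebsocket(GenericWebsocket):
    async def connect(self, **kwargs):
        # Skipping both accept() and close() here would get the connection
        # rejected by `WebsocketConnections.listen()`.
        await self.accept()

    async def receive(self, data: str | bytes):
        # `self.connection_id` is only valid once accept() has run.
        await self.send(data)
```
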
diff --git a/panther/caching.py b/panther/caching.py
index 5738679..b8c5db9 100644
--- a/panther/caching.py
+++ b/panther/caching.py
@@ -9,51 +9,53 @@
from panther.db.connections import redis
from panther.request import Request
from panther.response import Response, ResponseDataTypes
+from panther.throttling import throttling_storage
from panther.utils import generate_hash_value_from_string, round_datetime
logger = logging.getLogger('panther')
caches = {}
-CachedResponse = namedtuple('Cached', ['data', 'status_code'])
+CachedResponse = namedtuple('CachedResponse', ['data', 'status_code'])
-def cache_key(request: Request, /) -> str:
+def api_cache_key(request: Request, cache_exp_time: timedelta | None = None) -> str:
client = request.user and request.user.id or request.client.ip
query_params_hash = generate_hash_value_from_string(request.scope['query_string'].decode('utf-8'))
- return f'{client}-{request.path}-{query_params_hash}-{request.validated_data}'
+ key = f'{client}-{request.path}-{query_params_hash}-{request.validated_data}'
-
-def local_cache_key(*, request: Request, cache_exp_time: timedelta | None = None) -> str:
- key = cache_key(request)
if cache_exp_time:
time = round_datetime(datetime.now(), cache_exp_time)
return f'{time}-{key}'
- else:
- return key
+
+ return key
+
+
+def throttling_cache_key(request: Request, duration: timedelta) -> str:
+ client = request.user and request.user.id or request.client.ip
+ time = round_datetime(datetime.now(), duration)
+ return f'{time}-{client}-{request.path}'
-def get_cached_response_data(*, request: Request, cache_exp_time: timedelta) -> CachedResponse | None:
+async def get_response_from_cache(*, request: Request, cache_exp_time: timedelta) -> CachedResponse | None:
"""
If redis.is_connected:
Get Cached Data From Redis
else:
Get Cached Data From Memory
"""
- if redis.is_connected: # noqa: Unresolved References
- key = cache_key(request)
- data = (redis.get(key) or b'{}').decode()
+ if redis.is_connected:
+ key = api_cache_key(request=request)
+ data = (await redis.get(key) or b'{}').decode()
if cached_value := json.loads(data):
return CachedResponse(*cached_value)
else:
- key = local_cache_key(request=request, cache_exp_time=cache_exp_time)
+ key = api_cache_key(request=request, cache_exp_time=cache_exp_time)
if cached_value := caches.get(key):
return CachedResponse(*cached_value)
- return None
-
-def set_cache_response(*, request: Request, response: Response, cache_exp_time: timedelta | int) -> None:
+async def set_response_in_cache(*, request: Request, response: Response, cache_exp_time: timedelta | int) -> None:
"""
If redis.is_connected:
Cache The Data In Redis
@@ -64,9 +66,9 @@ def set_cache_response(*, request: Request, response: Response, cache_exp_time:
cache_data: tuple[ResponseDataTypes, int] = (response.data, response.status_code)
if redis.is_connected:
- key = cache_key(request)
+ key = api_cache_key(request=request)
- cache_exp_time = cache_exp_time or config['default_cache_exp']
+ cache_exp_time = cache_exp_time or config.DEFAULT_CACHE_EXP
cache_data: bytes = json.dumps(cache_data)
if not isinstance(cache_exp_time, timedelta | int | NoneType):
@@ -76,15 +78,48 @@ def set_cache_response(*, request: Request, response: Response, cache_exp_time:
if cache_exp_time is None:
logger.warning(
'your response are going to cache in redis forever '
- '** set DEFAULT_CACHE_EXP in configs or pass the cache_exp_time in @API.get() for prevent this **'
+ '** set DEFAULT_CACHE_EXP in `configs` or set the `cache_exp_time` in `@API.get()` to prevent this **'
)
- redis.set(key, cache_data)
+ await redis.set(key, cache_data)
else:
- redis.set(key, cache_data, ex=cache_exp_time)
+ await redis.set(key, cache_data, ex=cache_exp_time)
else:
- key = local_cache_key(request=request, cache_exp_time=cache_exp_time)
+ key = api_cache_key(request=request, cache_exp_time=cache_exp_time)
caches[key] = cache_data
if cache_exp_time:
- logger.info('"cache_exp_time" is not very accurate when redis is not connected.')
+ logger.info('`cache_exp_time` is not very accurate when `redis` is not connected.')
+
+
+async def get_throttling_from_cache(request: Request, duration: timedelta) -> int:
+ """
+ If redis.is_connected:
+ Get Cached Data From Redis
+ else:
+ Get Cached Data From Memory
+ """
+ key = throttling_cache_key(request=request, duration=duration)
+
+ if redis.is_connected:
+ data = (await redis.get(key) or b'0').decode()
+ return json.loads(data)
+
+ else:
+ return throttling_storage[key]
+
+
+async def increment_throttling_in_cache(request: Request, duration: timedelta) -> None:
+ """
+ If redis.is_connected:
+ Increment The Data In Redis
+ else:
+ Increment The Data In Memory
+ """
+ key = throttling_cache_key(request=request, duration=duration)
+
+ if redis.is_connected:
+ await redis.incrby(key, amount=1)
+
+ else:
+ throttling_storage[key] += 1
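
The two throttling helpers share a time-bucketed key (the rounded current time, the client id/IP and the path), so `handle_throttling()` earlier in this diff reads the current count and then increments it. A hedged sketch of that combination outside the API class (the `ThrottlingAPIError` import path is assumed):

```python
from datetime import timedelta

from panther.caching import get_throttling_from_cache, increment_throttling_in_cache
from panther.exceptions import ThrottlingAPIError  # assumed import path
from panther.throttling import Throttling

info_throttling = Throttling(rate=5, duration=timedelta(minutes=1))


async def check_throttling(request) -> None:
    count = await get_throttling_from_cache(request, duration=info_throttling.duration)
    if count + 1 > info_throttling.rate:
        raise ThrottlingAPIError
    await increment_throttling_in_cache(request, duration=info_throttling.duration)
```
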
diff --git a/panther/cli/create_command.py b/panther/cli/create_command.py
index 2800564..48d4779 100644
--- a/panther/cli/create_command.py
+++ b/panther/cli/create_command.py
@@ -16,7 +16,7 @@
AUTO_REFORMAT_PART,
DATABASE_PANTHERDB_PART,
DATABASE_MONGODB_PART,
- USER_MODEL_PART,
+ USER_MODEL_PART, REDIS_PART,
)
from panther.cli.utils import cli_error
@@ -34,6 +34,7 @@ def __init__(self):
self.base_directory = '.'
self.database = '0'
self.database_encryption = False
+ self.redis = False
self.authentication = False
self.monitoring = True
self.log_queries = True
@@ -60,7 +61,7 @@ def __init__(self):
},
{
'field': 'database',
- 'message': ' 0: PantherDB\n 1: MongoDB (Required `pymongo`)\n 2: No Database\nChoose Your Database (default is 0)',
+ 'message': ' 0: PantherDB (File-Base, No Requirements)\n 1: MongoDB (Required `pymongo`)\n 2: No Database\nChoose Your Database (default is 0)',
'validation_func': lambda x: x in ['0', '1', '2'],
'error_message': "Invalid Choice, '{}' not in ['0', '1', '2']",
},
@@ -70,6 +71,11 @@ def __init__(self):
'is_boolean': True,
'condition': "self.database == '0'"
},
+ {
+ 'field': 'redis',
+ 'message': 'Do You Want To Use Redis (Required `redis`)',
+ 'is_boolean': True,
+ },
{
'field': 'authentication',
'message': 'Do You Want To Use JWT Authentication (Required `python-jose`)',
@@ -77,7 +83,7 @@ def __init__(self):
},
{
'field': 'monitoring',
- 'message': 'Do You Want To Use Built-in Monitoring',
+ 'message': 'Do You Want To Use Built-in Monitoring (Required `watchfiles`)',
'is_boolean': True,
},
{
@@ -87,7 +93,7 @@ def __init__(self):
},
{
'field': 'auto_reformat',
- 'message': 'Do You Want To Use Auto Reformat (Required `ruff`)',
+ 'message': 'Do You Want To Use Auto Code Reformat (Required `ruff`)',
'is_boolean': True,
},
]
@@ -139,6 +145,8 @@ def _create_file(self, *, path: str, data: str):
log_queries_part = LOG_QUERIES_PART if self.log_queries else ''
auto_reformat_part = AUTO_REFORMAT_PART if self.auto_reformat else ''
database_encryption = 'True' if self.database_encryption else 'False'
+ database_extension = 'pdb' if self.database_encryption else 'json'
+ redis_part = REDIS_PART if self.redis else ''
if self.database == '0':
database_part = DATABASE_PANTHERDB_PART
elif self.database == '1':
@@ -153,6 +161,8 @@ def _create_file(self, *, path: str, data: str):
data = data.replace('{AUTO_REFORMAT}', auto_reformat_part)
data = data.replace('{DATABASE}', database_part)
data = data.replace('{PANTHERDB_ENCRYPTION}', database_encryption) # Should be after `DATABASE`
+ data = data.replace('{PANTHERDB_EXTENSION}', database_extension) # Should be after `DATABASE`
+ data = data.replace('{REDIS}', redis_part)
data = data.replace('{PROJECT_NAME}', self.project_name.lower())
data = data.replace('{PANTHER_VERSION}', version())
@@ -167,19 +177,19 @@ def collect_creation_data(self):
field_name = question.pop('field')
question['default'] = getattr(self, field_name)
is_boolean = question.pop('is_boolean', False)
- clean_output = str # Do Nothing
+ convert_output = str # Do Nothing
if is_boolean:
question['message'] += f' (default is {self._to_str(question["default"])})'
question['validation_func'] = self._is_boolean
question['error_message'] = "Invalid Choice, '{}' not in ['y', 'n']"
- clean_output = self._to_boolean
+ convert_output = self._to_boolean
# Check Question Condition
if 'condition' in question and eval(question.pop('condition')) is False:
print(flush=True)
# Ask Question
else:
- setattr(self, field_name, clean_output(self.ask(**question)))
+ setattr(self, field_name, convert_output(self.ask(**question)))
self.progress(i + 1)
def ask(
@@ -192,6 +202,7 @@ def ask(
) -> str:
value = Prompt.ask(message, console=self.input_console).lower() or default
while not validation_func(value):
+ # Remove the last line, show error message and ask again
[print(end=self.REMOVE_LAST_LINE, flush=True) for _ in range(message.count('\n') + 1)]
error = validation_func(value, return_error=True) if show_validation_error else value
self.console.print(error_message.format(error), style='bold red')
diff --git a/panther/cli/monitor_command.py b/panther/cli/monitor_command.py
index 9916b23..8d7bf0e 100644
--- a/panther/cli/monitor_command.py
+++ b/panther/cli/monitor_command.py
@@ -1,71 +1,97 @@
import contextlib
+import logging
import os
+import signal
from collections import deque
from pathlib import Path
from rich import box
from rich.align import Align
from rich.console import Group
-from rich.layout import Layout
from rich.live import Live
from rich.panel import Panel
from rich.table import Table
-from watchfiles import watch
-from panther.cli.utils import cli_error
+from panther.cli.utils import import_error
from panther.configs import config
+with contextlib.suppress(ImportError):
+ from watchfiles import watch
-def monitor() -> None:
- monitoring_log_file = Path(config['base_dir'] / 'logs' / 'monitoring.log')
+logger = logging.getLogger('panther')
- def _generate_table(rows: deque) -> Panel:
- layout = Layout()
- rows = list(rows)
- _, lines = os.get_terminal_size()
+class Monitoring:
+ def __init__(self):
+ self.rows = deque()
+ self.monitoring_log_file = Path(config.BASE_DIR / 'logs' / 'monitoring.log')
+
+ def monitor(self) -> None:
+ if error := self.initialize():
+ # Don't continue if initialize() returned an error
+ logger.error(error)
+ return
+
+ with (
+ self.monitoring_log_file.open() as f,
+ Live(
+ self.generate_table(),
+ vertical_overflow='visible',
+ screen=True,
+ ) as live,
+ contextlib.suppress(KeyboardInterrupt)
+ ):
+ f.readlines() # Set cursor at the end of the file
+
+ for _ in watch(self.monitoring_log_file):
+ for line in f.readlines():
+ self.rows.append(line.split('|'))
+ live.update(self.generate_table())
+
+ def initialize(self) -> str:
+ # Check requirements
+ try:
+ from watchfiles import watch
+ except ImportError as e:
+ return import_error(e, package='watchfiles').args[0]
+
+ # Check log file
+ if not self.monitoring_log_file.exists():
+ return f'`{self.monitoring_log_file}` file not found. (Make sure `MONITORING` is `True` in `configs`)'
+
+ # Initialize Deque
+ self.update_rows()
+
+ # Register the signal handler
+ signal.signal(signal.SIGWINCH, self.update_rows)
+
+ def generate_table(self) -> Panel:
+ # 2023-03-24 01:42:52 | GET | /user/317/ | 127.0.0.1:48856 | 0.0366 ms | 200
table = Table(box=box.MINIMAL_DOUBLE_HEAD)
table.add_column('Datetime', justify='center', style='magenta', no_wrap=True)
- table.add_column('Method', justify='center', style='cyan')
- table.add_column('Path', justify='center', style='cyan')
+ table.add_column('Method', justify='center', style='cyan', no_wrap=True)
+ table.add_column('Path', justify='center', style='cyan', no_wrap=True)
table.add_column('Client', justify='center', style='cyan')
table.add_column('Response Time', justify='center', style='blue')
- table.add_column('Status Code', justify='center', style='blue')
+ table.add_column('Status', justify='center', style='blue', no_wrap=True)
- for row in rows[-lines:]: # It will give us "lines" last lines of "rows"
+ for row in self.rows:
table.add_row(*row)
- layout.update(table)
return Panel(
Align.center(Group(table)),
box=box.ROUNDED,
- padding=(1, 2),
+ padding=(0, 2),
title='Monitoring',
border_style='bright_blue',
)
- if not monitoring_log_file.exists():
- return cli_error('Monitoring file not found. (You need at least one monitoring record for this action)')
+ def update_rows(self, *args, **kwargs):
+ # Top takes 4 lines, bottom takes 2 --> -6
+ # Each row is printed over two lines, so --> x // 2
+ lines = (os.get_terminal_size()[1] - 6) // 2
+ self.rows = deque(self.rows, maxlen=lines)
- with monitoring_log_file.open() as f:
- f.readlines() # Set cursor at the end of file
- _, init_lines_count = os.get_terminal_size()
- messages = deque(maxlen=init_lines_count - 10) # Save space for header and footer
-
- with (
- Live(
- _generate_table(messages),
- auto_refresh=False,
- vertical_overflow='visible',
- screen=True,
- ) as live,
- contextlib.suppress(KeyboardInterrupt),
- ):
- for _ in watch(monitoring_log_file):
- data = f.readline().split('|')
- # 2023-03-24 01:42:52 | GET | /user/317/ | 127.0.0.1:48856 | 0.0366 ms | 200
- messages.append(data)
- live.update(_generate_table(messages))
- live.refresh()
+monitor = Monitoring().monitor
diff --git a/panther/cli/template.py b/panther/cli/template.py
index 6e10437..187a902 100644
--- a/panther/cli/template.py
+++ b/panther/cli/template.py
@@ -11,6 +11,7 @@
from panther.app import API
from panther.request import Request
from panther.response import Response
+from panther.utils import timezone_now
@API()
@@ -24,7 +25,7 @@ async def info_api(request: Request):
'panther_version': version(),
'method': request.method,
'query_params': request.query_params,
- 'datetime_now': datetime.now().isoformat(),
+ 'datetime_now': timezone_now().isoformat(),
'user_agent': request.headers.user_agent,
}
return Response(data=data, status_code=status.HTTP_202_ACCEPTED)
@@ -55,19 +56,19 @@ async def info_api(request: Request):
{PROJECT_NAME} Project (Generated by Panther on %s)
\"""
-from datetime import timedelta
from pathlib import Path
-from panther.throttling import Throttling
from panther.utils import load_env
BASE_DIR = Path(__name__).resolve().parent
env = load_env(BASE_DIR / '.env')
-SECRET_KEY = env['SECRET_KEY']{DATABASE}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT}
+SECRET_KEY = env['SECRET_KEY']{DATABASE}{REDIS}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT}
# More Info: https://PantherPy.GitHub.io/urls/
URLs = 'core.urls.url_routing'
+
+TIMEZONE = 'UTC'
""" % datetime.now().date().isoformat()
env = """SECRET_KEY='%s'
@@ -128,15 +129,16 @@ async def info_api(request: Request):
from panther.request import Request
from panther.response import Response
from panther.throttling import Throttling
-from panther.utils import load_env
+from panther.utils import load_env, timezone_now
BASE_DIR = Path(__name__).resolve().parent
env = load_env(BASE_DIR / '.env')
-SECRET_KEY = env['SECRET_KEY']{DATABASE}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT}
+SECRET_KEY = env['SECRET_KEY']{DATABASE}{REDIS}{USER_MODEL}{AUTHENTICATION}{MONITORING}{LOG_QUERIES}{AUTO_REFORMAT}
InfoThrottling = Throttling(rate=5, duration=timedelta(minutes=1))
+TIMEZONE = 'UTC'
@API()
async def hello_world_api():
@@ -149,7 +151,7 @@ async def info_api(request: Request):
'panther_version': version(),
'method': request.method,
'query_params': request.query_params,
- 'datetime_now': datetime.now().isoformat(),
+ 'datetime_now': timezone_now().isoformat(),
'user_agent': request.headers.user_agent,
}
return Response(data=data, status_code=status.HTTP_202_ACCEPTED)
@@ -157,6 +159,7 @@ async def info_api(request: Request):
url_routing = {
'/': hello_world_api,
+ 'info/': info_api,
}
app = Panther(__name__, configs=__name__, urls=url_routing)
@@ -170,17 +173,19 @@ async def info_api(request: Request):
}
DATABASE_PANTHERDB_PART = """
-# More Info: Https://PantherPy.GitHub.io/middlewares/
+
+# More Info: https://PantherPy.GitHub.io/database/
DATABASE = {
'engine': {
'class': 'panther.db.connections.PantherDBConnection',
- 'path': BASE_DIR,
+ 'path': BASE_DIR / 'database.{PANTHERDB_EXTENSION}',
'encryption': {PANTHERDB_ENCRYPTION}
}
}"""
DATABASE_MONGODB_PART = """
-# More Info: Https://PantherPy.GitHub.io/middlewares/
+
+# More Info: https://PantherPy.GitHub.io/database/
DATABASE = {
'engine': {
'class': 'panther.db.connections.MongoDBConnection',
@@ -190,6 +195,16 @@ async def info_api(request: Request):
}
}"""
+REDIS_PART = """
+
+# More Info: https://PantherPy.GitHub.io/redis/
+REDIS = {
+ 'class': 'panther.db.connections.RedisConnection',
+ 'host': '127.0.0.1',
+ 'port': 6379,
+ 'db': 0,
+}"""
+
USER_MODEL_PART = """
# More Info: https://PantherPy.GitHub.io/configs/#user_model
diff --git a/panther/cli/utils.py b/panther/cli/utils.py
index 2df3400..7e2fdac 100644
--- a/panther/cli/utils.py
+++ b/panther/cli/utils.py
@@ -3,6 +3,7 @@
from rich import print as rprint
+from panther.configs import Config
from panther.exceptions import PantherError
logger = logging.getLogger('panther')
@@ -108,21 +109,20 @@ def print_uvicorn_help_message():
rprint('Run `uvicorn --help` for more help')
-def print_info(config: dict):
+def print_info(config: Config):
from panther.db.connections import redis
-
- mo = config['monitoring']
- lq = config['log_queries']
- bt = config['background_tasks']
- ws = config['has_ws']
+ mo = config.MONITORING
+ lq = config.LOG_QUERIES
+ bt = config.BACKGROUND_TASKS
+ ws = config.HAS_WS
rd = redis.is_connected
- bd = '{0:<39}'.format(str(config['base_dir']))
+ bd = '{0:<39}'.format(str(config.BASE_DIR))
if len(bd) > 39:
bd = f'{bd[:36]}...'
# Monitoring
- if config['monitoring']:
+ if config.MONITORING:
monitor = f'{h} * Run "panther monitor" in another session for Monitoring{h}\n'
else:
monitor = None
@@ -139,7 +139,7 @@ def print_info(config: dict):
# Gunicorn if Websocket
gunicorn_msg = None
- if config['has_ws']:
+ if config.HAS_WS:
try:
import gunicorn
gunicorn_msg = f'{h} * You have WS so make sure to run gunicorn with --preload{h}\n'
diff --git a/panther/configs.py b/panther/configs.py
index 7433a0f..39d4d12 100644
--- a/panther/configs.py
+++ b/panther/configs.py
@@ -1,3 +1,4 @@
+import copy
import typing
from dataclasses import dataclass
from datetime import timedelta
@@ -7,7 +8,6 @@
from pydantic._internal._model_construction import ModelMetaclass
from panther.throttling import Throttling
-from panther.utils import Singleton
class JWTConfig:
@@ -41,80 +41,78 @@ def observe(cls, observer):
@classmethod
def update(cls):
for observer in cls.observers:
- observer._reload_bases(parent=config.query_engine)
+ observer._reload_bases(parent=config.QUERY_ENGINE)
@dataclass
-class Config(Singleton):
- base_dir: Path
- monitoring: bool
- log_queries: bool
- default_cache_exp: timedelta | None
- throttling: Throttling | None
- secret_key: bytes | None
- http_middlewares: list
- ws_middlewares: list
- reversed_http_middlewares: list
- reversed_ws_middlewares: list
- user_model: ModelMetaclass | None
- authentication: ModelMetaclass | None
- ws_authentication: ModelMetaclass | None
- jwt_config: JWTConfig | None
- models: list[dict]
- flat_urls: dict
- urls: dict
- websocket_connections: typing.Callable | None
- background_tasks: bool
- has_ws: bool
- startup: Callable | None
- shutdown: Callable | None
- auto_reformat: bool
- query_engine: typing.Callable | None
- database: typing.Callable | None
+class Config:
+ BASE_DIR: Path
+ MONITORING: bool
+ LOG_QUERIES: bool
+ DEFAULT_CACHE_EXP: timedelta | None
+ THROTTLING: Throttling | None
+ SECRET_KEY: bytes | None
+ HTTP_MIDDLEWARES: list[tuple]
+ WS_MIDDLEWARES: list[tuple]
+ USER_MODEL: ModelMetaclass | None
+ AUTHENTICATION: ModelMetaclass | None
+ WS_AUTHENTICATION: ModelMetaclass | None
+ JWT_CONFIG: JWTConfig | None
+ MODELS: list[dict]
+ FLAT_URLS: dict
+ URLS: dict
+ WEBSOCKET_CONNECTIONS: typing.Callable | None
+ BACKGROUND_TASKS: bool
+ HAS_WS: bool
+ STARTUPS: list[Callable]
+ SHUTDOWNS: list[Callable]
+ TIMEZONE: str
+ AUTO_REFORMAT: bool
+ QUERY_ENGINE: typing.Callable | None
+ DATABASE: typing.Callable | None
def __setattr__(self, key, value):
super().__setattr__(key, value)
- if key == 'query_engine' and value:
+ if key == 'QUERY_ENGINE' and value:
QueryObservable.update()
def __setitem__(self, key, value):
- setattr(self, key, value)
+ setattr(self, key.upper(), value)
def __getitem__(self, item):
- return getattr(self, item)
+ return getattr(self, item.upper())
def refresh(self):
# In some tests we need to `refresh` the `config` values
- for key, value in default_configs.items():
+ for key, value in copy.deepcopy(default_configs).items():
setattr(self, key, value)
default_configs = {
- 'base_dir': Path(),
- 'monitoring': False,
- 'log_queries': False,
- 'default_cache_exp': None,
- 'throttling': None,
- 'secret_key': None,
- 'http_middlewares': [],
- 'ws_middlewares': [],
- 'reversed_http_middlewares': [],
- 'reversed_ws_middlewares': [],
- 'user_model': None,
- 'authentication': None,
- 'ws_authentication': None,
- 'jwt_config': None,
- 'models': [],
- 'flat_urls': {},
- 'urls': {},
- 'websocket_connections': None,
- 'background_tasks': False,
- 'has_ws': False,
- 'startup': None,
- 'shutdown': None,
- 'auto_reformat': False,
- 'query_engine': None,
- 'database': None,
+ 'BASE_DIR': Path(),
+ 'MONITORING': False,
+ 'LOG_QUERIES': False,
+ 'DEFAULT_CACHE_EXP': None,
+ 'THROTTLING': None,
+ 'SECRET_KEY': None,
+ 'HTTP_MIDDLEWARES': [],
+ 'WS_MIDDLEWARES': [],
+ 'USER_MODEL': None,
+ 'AUTHENTICATION': None,
+ 'WS_AUTHENTICATION': None,
+ 'JWT_CONFIG': None,
+ 'MODELS': [],
+ 'FLAT_URLS': {},
+ 'URLS': {},
+ 'WEBSOCKET_CONNECTIONS': None,
+ 'BACKGROUND_TASKS': False,
+ 'HAS_WS': False,
+ 'STARTUPS': [],
+ 'SHUTDOWNS': [],
+ 'TIMEZONE': 'UTC',
+ 'AUTO_REFORMAT': False,
+ 'QUERY_ENGINE': None,
+ 'DATABASE': None,
}
-config = Config(**default_configs)
+config = Config(**copy.deepcopy(default_configs))
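
Config attributes are now upper-case, while `__getitem__`/`__setitem__` upper-case the key, so the old lower-case item access keeps working. A small illustration:

```python
from panther.configs import config

# Attribute style (new canonical form)
base_dir = config.BASE_DIR

# Item style still works: the key is upper-cased internally,
# so config['base_dir'] resolves to config.BASE_DIR.
assert config['base_dir'] == config.BASE_DIR
```
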
diff --git a/panther/db/connections.py b/panther/db/connections.py
index 3fde392..6bb82b9 100644
--- a/panther/db/connections.py
+++ b/panther/db/connections.py
@@ -10,8 +10,8 @@
from panther.utils import Singleton
try:
- from redis import Redis as _Redis
-except ModuleNotFoundError:
+ from redis.asyncio import Redis as _Redis
+except ImportError:
# This '_Redis' is not going to be used,
# If user really wants to use redis,
# we are going to force him to install it in `panther._load_configs.load_redis`
@@ -75,13 +75,13 @@ def session(self):
class PantherDBConnection(BaseDatabaseConnection):
def init(self, path: str | None = None, encryption: bool = False):
- params = {'db_name': path, 'return_dict': True}
+ params = {'db_name': str(path), 'return_dict': True, 'return_cursor': True}
if encryption:
try:
import cryptography
except ImportError as e:
raise import_error(e, package='cryptography')
- params['secret_key'] = config['secret_key']
+ params['secret_key'] = config.SECRET_KEY
self._connection: PantherDB = PantherDB(**params)
@@ -93,7 +93,7 @@ def session(self):
class DatabaseConnection(Singleton):
@property
def session(self):
- return config['database'].session
+ return config.DATABASE.session
class RedisConnection(Singleton, _Redis):
@@ -117,7 +117,12 @@ def __init__(
super().__init__(host=host, port=port, db=db, **kwargs)
self.is_connected = True
- self.ping()
+ self.sync_ping()
+
+ def sync_ping(self):
+ from redis import Redis
+
+ Redis(host=self.host, port=self.port, **self.kwargs).ping()
def create_connection_for_websocket(self) -> _Redis:
if not hasattr(self, 'websocket_connection'):
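
With the client switched to `redis.asyncio`, every call is awaited; only the initial connectivity check stays synchronous via `sync_ping()`. A hedged usage sketch that mirrors the awaited `set`/`get` calls in `panther/caching.py` (the 60-second expiry is an arbitrary example value):

```python
from panther.db.connections import redis


async def remember(key: str, value: bytes) -> bytes | None:
    if redis.is_connected:
        await redis.set(key, value, ex=60)
        return await redis.get(key)
    return None
```
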
diff --git a/panther/db/cursor.py b/panther/db/cursor.py
index 923806e..1171518 100644
--- a/panther/db/cursor.py
+++ b/panther/db/cursor.py
@@ -2,7 +2,13 @@
from sys import version_info
-from pymongo.cursor import Cursor as _Cursor
+try:
+ from pymongo.cursor import Cursor as _Cursor
+except ImportError:
+ # This '_Cursor' is not going to be used,
+ # If user really wants to use it,
+ # we are going to force him to install it in `panther.db.connections.MongoDBConnection.init`
+ _Cursor = type('_Cursor', (), {})
if version_info >= (3, 11):
from typing import Self
@@ -20,6 +26,7 @@ def __init__(self, collection, *args, cls=None, **kwargs):
if cls:
self.models[collection.name] = cls
self.cls = cls
+ self.filter = kwargs['filter']
else:
self.cls = self.models[collection.name]
super().__init__(collection, *args, **kwargs)
diff --git a/panther/db/models.py b/panther/db/models.py
index dd235f3..2d6bc3b 100644
--- a/panther/db/models.py
+++ b/panther/db/models.py
@@ -1,41 +1,42 @@
import contextlib
+import os
from datetime import datetime
+from typing import Annotated
-from pydantic import BaseModel as PydanticBaseModel
-from pydantic import Field, field_validator
+from pydantic import Field, WrapValidator, PlainSerializer, BaseModel as PydanticBaseModel
from panther.configs import config
from panther.db.queries import Query
-
+from panther.utils import scrypt, URANDOM_SIZE, timezone_now
with contextlib.suppress(ImportError):
# Only required if user wants to use mongodb
import bson
+def validate_object_id(value, handler):
+ if config.DATABASE.__class__.__name__ == 'MongoDBConnection':
+ if isinstance(value, bson.ObjectId):
+ return value
+ else:
+ try:
+ return bson.ObjectId(value)
+ except bson.objectid.InvalidId as e:
+ msg = 'Invalid ObjectId'
+ raise ValueError(msg) from e
+ return str(value)
+
+
+ID = Annotated[str, WrapValidator(validate_object_id), PlainSerializer(lambda x: str(x), return_type=str)]
+
+
class Model(PydanticBaseModel, Query):
def __init_subclass__(cls, **kwargs):
if cls.__module__ == 'panther.db.models' and cls.__name__ == 'BaseUser':
return
- config['models'].append(cls)
-
- id: str | None = Field(None, validation_alias='_id')
-
- @field_validator('id', mode='before')
- def validate_id(cls, value) -> str:
- if config['database'].__class__.__name__ == 'MongoDBConnection':
- if isinstance(value, str):
- try:
- bson.ObjectId(value)
- except bson.objectid.InvalidId as e:
- msg = 'Invalid ObjectId'
- raise ValueError(msg) from e
-
- elif not isinstance(value, bson.ObjectId):
- msg = 'ObjectId required'
- raise ValueError(msg) from None
+ config.MODELS.append(cls)
- return str(value)
+ id: ID | None = Field(None, validation_alias='_id')
@property
def _id(self):
@@ -44,26 +45,40 @@ def _id(self):
`str` for PantherDB
`ObjectId` for MongoDB
"""
- if config['database'].__class__.__name__ == 'MongoDBConnection':
- return bson.ObjectId(self.id) if self.id else None
return self.id
- def dict(self, *args, **kwargs) -> dict:
- return self.model_dump(*args, **kwargs)
-
class BaseUser(Model):
- first_name: str = Field('', max_length=64)
- last_name: str = Field('', max_length=64)
+ password: str = Field('', max_length=64)
last_login: datetime | None = None
+ date_created: datetime | None = Field(default_factory=timezone_now)
async def update_last_login(self) -> None:
- await self.update(last_login=datetime.now())
+ await self.update(last_login=timezone_now())
async def login(self) -> dict:
"""Return dict of access and refresh token"""
- await self.update_last_login()
- return config.authentication.login(self.id)
+ return config.AUTHENTICATION.login(self.id)
+
+ async def logout(self) -> dict:
+ return await config.AUTHENTICATION.logout(self._auth_token)
+
+ def set_password(self, password: str):
+ """
+ URANDOM_SIZE = 16 bytes -->
+ salt = 16 bytes
+ salt.hex() = 32 char
+ derived_key = 32 char
+ """
+ salt = os.urandom(URANDOM_SIZE)
+ derived_key = scrypt(password=password, salt=salt, digest=True)
+
+ self.password = f'{salt.hex()}{derived_key}'
+
+ def check_password(self, new_password: str) -> bool:
+ size = URANDOM_SIZE * 2
+ salt = self.password[:size]
+ stored_hash = self.password[size:]
+ derived_key = scrypt(password=new_password, salt=bytes.fromhex(salt), digest=True)
- def logout(self) -> dict:
- return config.authentication.logout(self._auth_token)
+ return derived_key == stored_hash
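
`BaseUser` now stores `salt.hex()` (32 hex chars for a 16-byte salt) followed by the scrypt-derived key, and `check_password()` splits the stored value at `URANDOM_SIZE * 2`. A short hedged usage sketch:

```python
from panther.db.models import BaseUser

user = BaseUser()
user.set_password('S3cret!')             # password field becomes f'{salt.hex()}{derived_key}'
assert user.check_password('S3cret!')    # re-derives the key from the stored salt
assert not user.check_password('wrong')
```
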
diff --git a/panther/db/queries/base_queries.py b/panther/db/queries/base_queries.py
index 51c8c0f..4bdb44a 100644
--- a/panther/db/queries/base_queries.py
+++ b/panther/db/queries/base_queries.py
@@ -2,6 +2,7 @@
from abc import abstractmethod
from functools import reduce
from sys import version_info
+from typing import Iterator
from pydantic_core._pydantic_core import ValidationError
@@ -10,7 +11,7 @@
from panther.exceptions import DatabaseError
if version_info >= (3, 11):
- from typing import Self, Iterator
+ from typing import Self
else:
from typing import TypeVar
diff --git a/panther/db/queries/mongodb_queries.py b/panther/db/queries/mongodb_queries.py
index a83cd9f..5e2f10e 100644
--- a/panther/db/queries/mongodb_queries.py
+++ b/panther/db/queries/mongodb_queries.py
@@ -3,13 +3,19 @@
from sys import version_info
from typing import Iterable, Sequence
-from bson.codec_options import CodecOptions
-
from panther.db.connections import db
from panther.db.cursor import Cursor
from panther.db.queries.base_queries import BaseQuery
from panther.db.utils import prepare_id_for_query
+try:
+ from bson.codec_options import CodecOptions
+except ImportError:
+ # This 'CodecOptions' is not going to be used,
+ # If user really wants to use it,
+ # we are going to force him to install it in `panther.db.connections.MongoDBConnection.init`
+ CodecOptions = type('CodecOptions', (), {})
+
if version_info >= (3, 11):
from typing import Self
else:
@@ -23,24 +29,21 @@ class BaseMongoDBQuery(BaseQuery):
def _merge(cls, *args, is_mongo: bool = True) -> dict:
return super()._merge(*args, is_mongo=is_mongo)
- @classmethod
- def collection(cls):
- return db.session.get_collection(
- name=cls.__name__,
- codec_options=CodecOptions(document_class=dict)
- # codec_options=CodecOptions(document_class=cls) TODO: https://jira.mongodb.org/browse/PYTHON-4192
- )
+ # TODO: https://jira.mongodb.org/browse/PYTHON-4192
+ # @classmethod
+ # def collection(cls):
+ # return db.session.get_collection(name=cls.__name__, codec_options=CodecOptions(document_class=cls))
# # # # # Find # # # # #
@classmethod
async def find_one(cls, _filter: dict | None = None, /, **kwargs) -> Self | None:
- if document := await cls.collection().find_one(cls._merge(_filter, kwargs)):
+ if document := await db.session[cls.__name__].find_one(cls._merge(_filter, kwargs)):
return cls._create_model_instance(document=document)
return None
@classmethod
async def find(cls, _filter: dict | None = None, /, **kwargs) -> Cursor:
- return Cursor(cls=cls, collection=cls.collection().delegate, filter=cls._merge(_filter, kwargs))
+ return Cursor(cls=cls, collection=db.session[cls.__name__].delegate, filter=cls._merge(_filter, kwargs))
@classmethod
async def first(cls, _filter: dict | None = None, /, **kwargs) -> Self | None:
@@ -58,12 +61,12 @@ async def last(cls, _filter: dict | None = None, /, **kwargs) -> Self | None:
@classmethod
async def aggregate(cls, pipeline: Sequence[dict]) -> Iterable[dict]:
- return await cls.collection().aggregate(pipeline)
+ return await db.session[cls.__name__].aggregate(pipeline)
# # # # # Count # # # # #
@classmethod
async def count(cls, _filter: dict | None = None, /, **kwargs) -> int:
- return await cls.collection().count_documents(cls._merge(_filter, kwargs))
+ return await db.session[cls.__name__].count_documents(cls._merge(_filter, kwargs))
# # # # # Insert # # # # #
@classmethod
@@ -71,7 +74,7 @@ async def insert_one(cls, _document: dict | None = None, /, **kwargs) -> Self:
document = cls._merge(_document, kwargs)
cls._validate_data(data=document)
- await cls.collection().insert_one(document)
+ await db.session[cls.__name__].insert_one(document)
return cls._create_model_instance(document=document)
@classmethod
@@ -80,21 +83,21 @@ async def insert_many(cls, documents: Iterable[dict]) -> list[Self]:
prepare_id_for_query(document, is_mongo=True)
cls._validate_data(data=document)
- await cls.collection().insert_many(documents)
+ await db.session[cls.__name__].insert_many(documents)
return [cls._create_model_instance(document=document) for document in documents]
# # # # # Delete # # # # #
async def delete(self) -> None:
- await self.collection().delete_one({'_id': self._id})
+ await db.session[self.__class__.__name__].delete_one({'_id': self._id})
@classmethod
async def delete_one(cls, _filter: dict | None = None, /, **kwargs) -> bool:
- result = await cls.collection().delete_one(cls._merge(_filter, kwargs))
+ result = await db.session[cls.__name__].delete_one(cls._merge(_filter, kwargs))
return bool(result.deleted_count)
@classmethod
async def delete_many(cls, _filter: dict | None = None, /, **kwargs) -> int:
- result = await cls.collection().delete_many(cls._merge(_filter, kwargs))
+ result = await db.session[cls.__name__].delete_many(cls._merge(_filter, kwargs))
return result.deleted_count
# # # # # Update # # # # #
@@ -106,14 +109,14 @@ async def update(self, _update: dict | None = None, /, **kwargs) -> None:
for field, value in document.items():
setattr(self, field, value)
update_fields = {'$set': document}
- await self.collection().update_one({'_id': self._id}, update_fields)
+ await db.session[self.__class__.__name__].update_one({'_id': self._id}, update_fields)
@classmethod
async def update_one(cls, _filter: dict, _update: dict | None = None, /, **kwargs) -> bool:
prepare_id_for_query(_filter, is_mongo=True)
update_fields = {'$set': cls._merge(_update, kwargs)}
- result = await cls.collection().update_one(_filter, update_fields)
+ result = await db.session[cls.__name__].update_one(_filter, update_fields)
return bool(result.matched_count)
@classmethod
@@ -121,5 +124,5 @@ async def update_many(cls, _filter: dict, _update: dict | None = None, /, **kwar
prepare_id_for_query(_filter, is_mongo=True)
update_fields = {'$set': cls._merge(_update, kwargs)}
- result = await cls.collection().update_many(_filter, update_fields)
+ result = await db.session[cls.__name__].update_many(_filter, update_fields)
return result.modified_count
diff --git a/panther/db/queries/pantherdb_queries.py b/panther/db/queries/pantherdb_queries.py
index 17f0af5..caa2eae 100644
--- a/panther/db/queries/pantherdb_queries.py
+++ b/panther/db/queries/pantherdb_queries.py
@@ -1,6 +1,8 @@
from sys import version_info
from typing import Iterable
+from pantherdb import Cursor
+
from panther.db.connections import db
from panther.db.queries.base_queries import BaseQuery
from panther.db.utils import prepare_id_for_query
@@ -27,9 +29,11 @@ async def find_one(cls, _filter: dict | None = None, /, **kwargs) -> Self | None
return None
@classmethod
- async def find(cls, _filter: dict | None = None, /, **kwargs) -> list[Self]:
- documents = db.session.collection(cls.__name__).find(**cls._merge(_filter, kwargs))
- return [cls._create_model_instance(document=document) for document in documents]
+ async def find(cls, _filter: dict | None = None, /, **kwargs) -> Cursor:
+ cursor = db.session.collection(cls.__name__).find(**cls._merge(_filter, kwargs))
+ cursor.response_type = cls._create_model_instance
+ cursor.cls = cls
+ return cursor
@classmethod
async def first(cls, _filter: dict | None = None, /, **kwargs) -> Self | None:
@@ -65,12 +69,12 @@ async def insert_one(cls, _document: dict | None = None, /, **kwargs) -> Self:
@classmethod
async def insert_many(cls, documents: Iterable[dict]) -> list[Self]:
result = []
- for _document in documents:
- prepare_id_for_query(_document, is_mongo=False)
- cls._validate_data(data=_document)
- document = db.session.collection(cls.__name__).insert_one(**_document)
- result.append(document)
-
+ for document in documents:
+ prepare_id_for_query(document, is_mongo=False)
+ cls._validate_data(data=document)
+ inserted_document = db.session.collection(cls.__name__).insert_one(**document)
+ document['_id'] = inserted_document['_id']
+ result.append(cls._create_model_instance(document=document))
return result
# # # # # Delete # # # # #
diff --git a/panther/db/queries/queries.py b/panther/db/queries/queries.py
index 5009990..e06eee0 100644
--- a/panther/db/queries/queries.py
+++ b/panther/db/queries/queries.py
@@ -1,6 +1,8 @@
import sys
from typing import Sequence, Iterable
+from pantherdb import Cursor as PantherDBCursor
+
from panther.configs import QueryObservable
from panther.db.cursor import Cursor
from panther.db.queries.base_queries import BaseQuery
@@ -57,7 +59,7 @@ async def find_one(cls, _filter: dict | None = None, /, **kwargs) -> Self | None
@classmethod
@check_connection
@log_query
- async def find(cls, _filter: dict | None = None, /, **kwargs) -> list[Self] | Cursor:
+ async def find(cls, _filter: dict | None = None, /, **kwargs) -> PantherDBCursor | Cursor:
"""
Get documents from the database.
@@ -193,7 +195,7 @@ async def insert_many(cls, documents: Iterable[dict]) -> list[Self]:
>>> ]
>>> await User.insert_many(users)
"""
- return super().insert_many(documents)
+ return await super().insert_many(documents)
# # # # # Delete # # # # #
@check_connection
@@ -324,7 +326,7 @@ async def all(cls) -> list[Self] | Cursor:
return await cls.find()
@classmethod
- async def find_one_or_insert(cls, _filter: dict | None = None, /, **kwargs) -> tuple[bool, Self]:
+ async def find_one_or_insert(cls, _filter: dict | None = None, /, **kwargs) -> tuple[Self, bool]:
"""
Get a single document from the database.
or
@@ -341,8 +343,8 @@ async def find_one_or_insert(cls, _filter: dict | None = None, /, **kwargs) -> t
>>> await User.find_one_or_insert({'age': 18}, name='Ali')
"""
if obj := await cls.find_one(_filter, **kwargs):
- return False, obj
- return True, await cls.insert_one(_filter, **kwargs)
+ return obj, False
+ return await cls.insert_one(_filter, **kwargs), True
@classmethod
async def find_one_or_raise(cls, _filter: dict | None = None, /, **kwargs) -> Self:
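
Note the swapped return order of `find_one_or_insert()`: the document now comes first and the `created` flag second. A hedged sketch reusing the docstring's example query (`User` stands in for the caller's Model subclass):

```python
async def get_or_create_ali(User):
    user, created = await User.find_one_or_insert({'age': 18}, name='Ali')
    if created:
        ...  # the document was inserted in this call
    return user
```
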
diff --git a/panther/db/utils.py b/panther/db/utils.py
index 8154020..9fe2fdc 100644
--- a/panther/db/utils.py
+++ b/panther/db/utils.py
@@ -14,7 +14,7 @@
def log_query(func):
async def log(*args, **kwargs):
- if config['log_queries'] is False:
+ if config.LOG_QUERIES is False:
return await func(*args, **kwargs)
start = perf_counter()
response = await func(*args, **kwargs)
@@ -28,7 +28,7 @@ async def log(*args, **kwargs):
def check_connection(func):
async def wrapper(*args, **kwargs):
- if config['query_engine'] is None:
+ if config.QUERY_ENGINE is None:
msg = "You don't have active database connection, Check your middlewares"
raise NotImplementedError(msg)
return await func(*args, **kwargs)
diff --git a/panther/events.py b/panther/events.py
new file mode 100644
index 0000000..7fe84a4
--- /dev/null
+++ b/panther/events.py
@@ -0,0 +1,44 @@
+import asyncio
+
+from panther._utils import is_function_async
+from panther.configs import config
+
+
+class Event:
+ @staticmethod
+ def startup(func):
+ config.STARTUPS.append(func)
+
+ def wrapper():
+ return func()
+ return wrapper
+
+ @staticmethod
+ def shutdown(func):
+ config.SHUTDOWNS.append(func)
+
+ def wrapper():
+ return func()
+ return wrapper
+
+ @staticmethod
+ async def run_startups():
+ for func in config.STARTUPS:
+ if is_function_async(func):
+ await func()
+ else:
+ func()
+
+ @staticmethod
+ def run_shutdowns():
+ for func in config.SHUTDOWNS:
+ if is_function_async(func):
+ try:
+ asyncio.run(func())
+ except ModuleNotFoundError:
+ # Error: import of asyncio halted; None in sys.modules
+ # As far as we can tell, this only happens when running with
+ # gunicorn and Uvicorn workers (-k uvicorn.workers.UvicornWorker)
+ pass
+ else:
+ func()
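
The new `Event` helpers simply register callables (sync or async) in `config.STARTUPS`/`config.SHUTDOWNS`. A minimal hedged sketch of how an application is expected to hook into them (the handler bodies are illustrative):

```python
from panther.events import Event


@Event.startup
async def connect_to_broker():
    ...  # awaited once by run_startups() on application start


@Event.shutdown
def flush_metrics():
    ...  # called by run_shutdowns(); async handlers go through asyncio.run()
```
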
diff --git a/panther/generics.py b/panther/generics.py
new file mode 100644
index 0000000..a79384a
--- /dev/null
+++ b/panther/generics.py
@@ -0,0 +1,163 @@
+import contextlib
+import logging
+
+from pantherdb import Cursor as PantherDBCursor
+
+from panther import status
+from panther.app import GenericAPI
+from panther.configs import config
+from panther.db import Model
+from panther.db.cursor import Cursor
+from panther.exceptions import APIError
+from panther.pagination import Pagination
+from panther.request import Request
+from panther.response import Response
+from panther.serializer import ModelSerializer
+
+with contextlib.suppress(ImportError):
+ # Only required if user wants to use mongodb
+ import bson
+
+logger = logging.getLogger('panther')
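+
+# A minimal usage sketch (the `Author` model and its fields are illustrative):
+#
+#     class AuthorListAPI(ListAPI):
+#         sort_fields = ['name']
+#         pagination = Pagination
+#
+#         async def objects(self, request: Request, **kwargs):
+#             return await Author.find()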
+
+
+class ObjectRequired:
+ def _check_object(self, instance):
+ if issubclass(type(instance), Model) is False:
+            logger.critical(f'`{self.__class__.__name__}.object()` should return an instance of a Model --> `find_one()`')
+ raise APIError
+
+ async def object(self, request: Request, **kwargs) -> Model:
+ """
+ Used in `RetrieveAPI`, `UpdateAPI`, `DeleteAPI`
+ """
+        logger.error(f'The `object()` method is not implemented in {self.__class__}.')
+ raise APIError(status_code=status.HTTP_501_NOT_IMPLEMENTED)
+
+
+class ObjectsRequired:
+ def _check_objects(self, cursor):
+ if isinstance(cursor, (Cursor, PantherDBCursor)) is False:
+ logger.critical(f'`{self.__class__.__name__}.objects()` should return a Cursor --> `find()`')
+ raise APIError
+
+ async def objects(self, request: Request, **kwargs) -> Cursor | PantherDBCursor:
+ """
+ Used in `ListAPI`
+ Should return `.find()`
+ """
+        logger.error(f'The `objects()` method is not implemented in {self.__class__}.')
+ raise APIError(status_code=status.HTTP_501_NOT_IMPLEMENTED)
+
+
+class RetrieveAPI(GenericAPI, ObjectRequired):
+ async def get(self, request: Request, **kwargs):
+ instance = await self.object(request=request, **kwargs)
+ self._check_object(instance)
+
+ return Response(data=instance, status_code=status.HTTP_200_OK)
+
+
+class ListAPI(GenericAPI, ObjectsRequired):
+ sort_fields: list[str]
+ search_fields: list[str]
+ filter_fields: list[str]
+ pagination: type[Pagination]
+
+ async def get(self, request: Request, **kwargs):
+ cursor = await self.objects(request=request, **kwargs)
+ self._check_objects(cursor)
+
+ query = {}
+ query |= self.process_filters(query_params=request.query_params, cursor=cursor)
+ query |= self.process_search(query_params=request.query_params)
+
+ if query:
+ cursor = await cursor.cls.find(cursor.filter | query)
+
+ if sort := self.process_sort(query_params=request.query_params):
+ cursor = cursor.sort(sort)
+
+ if pagination := self.process_pagination(query_params=request.query_params, cursor=cursor):
+ cursor = await pagination.paginate()
+
+ return Response(data=cursor, status_code=status.HTTP_200_OK)
+
+ def process_filters(self, query_params: dict, cursor: Cursor | PantherDBCursor) -> dict:
+ _filter = {}
+ if hasattr(self, 'filter_fields'):
+ for field in self.filter_fields:
+ if field in query_params:
+ if config.DATABASE.__class__.__name__ == 'MongoDBConnection':
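+                    # if the field validates as an ObjectId, cast the query-param string to bson.ObjectId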
+ with contextlib.suppress(Exception):
+ if cursor.cls.model_fields[field].metadata[0].func.__name__ == 'validate_object_id':
+ _filter[field] = bson.ObjectId(query_params[field])
+ continue
+ _filter[field] = query_params[field]
+ return _filter
+
+ def process_search(self, query_params: dict) -> dict:
+ if hasattr(self, 'search_fields') and 'search' in query_params:
+ value = query_params['search']
+ if config.DATABASE.__class__.__name__ == 'MongoDBConnection':
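+                # e.g. ?search=ali with search_fields=['name'] -> {'$or': [{'name': {'$regex': 'ali'}}]}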
+ if search := [{field: {'$regex': value}} for field in self.search_fields]:
+ return {'$or': search}
+ else:
+                logger.warning(f'`?search={value}` does not work well while using `PantherDB` as the database')
+ return {field: value for field in self.search_fields}
+ return {}
+
+ def process_sort(self, query_params: dict) -> list:
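+        # e.g. ?sort=-age,name with sort_fields=['name', 'age'] -> [('name', 1), ('age', -1)]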
+ if hasattr(self, 'sort_fields') and 'sort' in query_params:
+ return [
+ (field, -1 if param[0] == '-' else 1)
+ for field in self.sort_fields for param in query_params['sort'].split(',')
+ if field == param.removeprefix('-')
+ ]
+
+ def process_pagination(self, query_params: dict, cursor: Cursor | PantherDBCursor) -> Pagination | None:
+ if hasattr(self, 'pagination'):
+ return self.pagination(query_params=query_params, cursor=cursor)
+
+
+class CreateAPI(GenericAPI):
+ input_model: type[ModelSerializer]
+
+ async def post(self, request: Request, **kwargs):
+ instance = await request.validated_data.create(
+ validated_data=request.validated_data.model_dump()
+ )
+ return Response(data=instance, status_code=status.HTTP_201_CREATED)
+
+
+class UpdateAPI(GenericAPI, ObjectRequired):
+ input_model: type[ModelSerializer]
+
+ async def put(self, request: Request, **kwargs):
+ instance = await self.object(request=request, **kwargs)
+ self._check_object(instance)
+
+ await request.validated_data.update(
+ instance=instance,
+ validated_data=request.validated_data.model_dump()
+ )
+ return Response(data=instance, status_code=status.HTTP_200_OK)
+
+ async def patch(self, request: Request, **kwargs):
+ instance = await self.object(request=request, **kwargs)
+ self._check_object(instance)
+
+ await request.validated_data.partial_update(
+ instance=instance,
+ validated_data=request.validated_data.model_dump(exclude_none=True)
+ )
+ return Response(data=instance, status_code=status.HTTP_200_OK)
+
+
+class DeleteAPI(GenericAPI, ObjectRequired):
+ async def delete(self, request: Request, **kwargs):
+ instance = await self.object(request=request, **kwargs)
+ self._check_object(instance)
+
+ await instance.delete()
+ return Response(status_code=status.HTTP_204_NO_CONTENT)
diff --git a/panther/logging.py b/panther/logging.py
index 2f37930..346dfbb 100644
--- a/panther/logging.py
+++ b/panther/logging.py
@@ -2,7 +2,7 @@
from pathlib import Path
from panther.configs import config
-LOGS_DIR = config['base_dir'] / 'logs'
+LOGS_DIR = config.BASE_DIR / 'logs'
class FileHandler(logging.FileHandler):
@@ -63,6 +63,11 @@ def __init__(self, filename, mode='a', encoding=None, delay=False, errors=None):
'query': {
'handlers': ['default', 'query_file'],
'level': 'DEBUG',
- }
+ },
+ 'uvicorn.error': {
+ 'handlers': ['default'],
+ 'level': 'WARNING',
+ 'propagate': False,
+ },
}
}
\ No newline at end of file
diff --git a/panther/main.py b/panther/main.py
index 13ab18f..a4a6cd9 100644
--- a/panther/main.py
+++ b/panther/main.py
@@ -7,14 +7,15 @@
from logging.config import dictConfig
from pathlib import Path
+import orjson as json
+
import panther.logging
from panther import status
from panther._load_configs import *
-from panther._utils import clean_traceback_message, http_response, is_function_async, reformat_code, \
- check_class_type_endpoint, check_function_type_endpoint
-from panther.base_websocket import WebsocketListener
+from panther._utils import clean_traceback_message, reformat_code, check_class_type_endpoint, check_function_type_endpoint
from panther.cli.utils import print_info
from panther.configs import config
+from panther.events import Event
from panther.exceptions import APIError, PantherError
from panther.monitoring import Monitoring
from panther.request import Request
@@ -26,35 +27,20 @@
class Panther:
- def __init__(
- self,
- name: str,
- configs=None,
- urls: dict | None = None,
- startup: Callable = None,
- shutdown: Callable = None
- ):
+ def __init__(self, name: str, configs: str | None = None, urls: dict | None = None):
self._configs_module_name = configs
self._urls = urls
- self._startup = startup
- self._shutdown = shutdown
- config['base_dir'] = Path(name).resolve().parent
+ config.BASE_DIR = Path(name).resolve().parent
try:
self.load_configs()
- if config['auto_reformat']:
- reformat_code(base_dir=config['base_dir'])
+ if config.AUTO_REFORMAT:
+ reformat_code(base_dir=config.BASE_DIR)
except Exception as e: # noqa: BLE001
- if isinstance(e, PantherError):
- logger.error(e.args[0])
- else:
- logger.error(clean_traceback_message(e))
+ logger.error(e.args[0] if isinstance(e, PantherError) else clean_traceback_message(e))
sys.exit()
- # Monitoring
- self.monitoring = Monitoring(is_active=config['monitoring'])
-
# Print Info
print_info(config)
@@ -65,6 +51,7 @@ def load_configs(self) -> None:
load_redis(self._configs_module)
load_startup(self._configs_module)
load_shutdown(self._configs_module)
+ load_timezone(self._configs_module)
load_database(self._configs_module)
load_secret_key(self._configs_module)
load_monitoring(self._configs_module)
@@ -80,25 +67,14 @@ def load_configs(self) -> None:
load_websocket_connections()
async def __call__(self, scope: dict, receive: Callable, send: Callable) -> None:
- """
- 1.
- await func(scope, receive, send)
- 2.
- async with asyncio.TaskGroup() as tg:
- tg.create_task(func(scope, receive, send))
- 3.
- async with anyio.create_task_group() as task_group:
- task_group.start_soon(func, scope, receive, send)
- await anyio.to_thread.run_sync(func, scope, receive, send)
- 4.
- with ProcessPoolExecutor() as e:
- e.submit(func, scope, receive, send)
- """
if scope['type'] == 'lifespan':
message = await receive()
if message["type"] == 'lifespan.startup':
- await self.handle_ws_listener()
- await self.handle_startup()
+ await config.WEBSOCKET_CONNECTIONS.start()
+ await Event.run_startups()
+ elif message["type"] == 'lifespan.shutdown':
+                # The 'lifespan.shutdown' message is not received reliably, so shutdowns are handled in `__del__` instead
+ pass
return
func = self.handle_http if scope['type'] == 'http' else self.handle_ws
@@ -108,66 +84,75 @@ async def handle_ws(self, scope: dict, receive: Callable, send: Callable) -> Non
from panther.websocket import GenericWebsocket, Websocket
# Monitoring
- monitoring = Monitoring(is_active=config['monitoring'], is_ws=True)
+ monitoring = Monitoring(is_ws=True)
# Create Temp Connection
temp_connection = Websocket(scope=scope, receive=receive, send=send)
await monitoring.before(request=temp_connection)
+ temp_connection._monitoring = monitoring
# Find Endpoint
endpoint, found_path = find_endpoint(path=temp_connection.path)
if endpoint is None:
- await monitoring.after('Rejected')
- return await temp_connection.close(status.WS_1000_NORMAL_CLOSURE)
+ logger.debug(f'Path `{temp_connection.path}` not found')
+ return await temp_connection.close()
# Check Endpoint Type
if not issubclass(endpoint, GenericWebsocket):
- logger.critical(f'You may have forgotten to inherit from GenericWebsocket on the {endpoint.__name__}()')
- await monitoring.after('Rejected')
- return await temp_connection.close(status.WS_1014_BAD_GATEWAY)
+ logger.critical(f'You may have forgotten to inherit from `GenericWebsocket` on the `{endpoint.__name__}()`')
+ return await temp_connection.close()
# Create The Connection
del temp_connection
connection = endpoint(scope=scope, receive=receive, send=send)
+ connection._monitoring = monitoring
# Collect Path Variables
connection.collect_path_variables(found_path=found_path)
- # Call 'Before' Middlewares
- if await self._run_ws_middlewares_before_listen(connection=connection):
- # Only Listen() If Middlewares Didn't Raise Anything
- await config['websocket_connections'].new_connection(connection=connection)
- await monitoring.after('Accepted')
- await connection.listen()
+ middlewares = [middleware(**data) for middleware, data in config.WS_MIDDLEWARES]
+
+ # Call Middlewares .before()
+ await self._run_ws_middlewares_before_listen(connection=connection, middlewares=middlewares)
- # Call 'After' Middlewares
- await self._run_ws_middlewares_after_listen(connection=connection)
+ # Listen The Connection
+ await config.WEBSOCKET_CONNECTIONS.listen(connection=connection)
- # Done
- await monitoring.after('Closed')
- return None
+ # Call Middlewares .after()
+ middlewares.reverse()
+ await self._run_ws_middlewares_after_listen(connection=connection, middlewares=middlewares)
@classmethod
- async def _run_ws_middlewares_before_listen(cls, *, connection) -> bool:
- for middleware in config['ws_middlewares']:
- try:
- connection = await middleware.before(request=connection)
- except APIError:
- await connection.close()
- return False
- return True
+ async def _run_ws_middlewares_before_listen(cls, *, connection, middlewares):
+ try:
+ for middleware in middlewares:
+ new_connection = await middleware.before(request=connection)
+ if new_connection is None:
+ logger.critical(
+ f'Make sure to return the `request` at the end of `{middleware.__class__.__name__}.before()`')
+ await connection.close()
+ connection = new_connection
+ except APIError as e:
+ connection.log(e.detail)
+ await connection.close()
@classmethod
- async def _run_ws_middlewares_after_listen(cls, *, connection):
- for middleware in config['reversed_ws_middlewares']:
+ async def _run_ws_middlewares_after_listen(cls, *, connection, middlewares):
+ for middleware in middlewares:
with contextlib.suppress(APIError):
- await middleware.after(response=connection)
+ connection = await middleware.after(response=connection)
+ if connection is None:
+ logger.critical(
+ f'Make sure to return the `response` at the end of `{middleware.__class__.__name__}.after()`')
+ break
async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> None:
+ # Monitoring
+ monitoring = Monitoring()
+
request = Request(scope=scope, receive=receive, send=send)
- # Monitoring
- await self.monitoring.before(request=request)
+ await monitoring.before(request=request)
# Read Request Payload
await request.read_body()
@@ -175,7 +160,7 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N
# Find Endpoint
_endpoint, found_path = find_endpoint(path=request.path)
if _endpoint is None:
- return await self._raise(send, status_code=status.HTTP_404_NOT_FOUND)
+ return await self._raise(send, monitoring=monitoring, status_code=status.HTTP_404_NOT_FOUND)
# Check Endpoint Type
try:
@@ -184,15 +169,20 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N
else:
endpoint = check_class_type_endpoint(endpoint=_endpoint)
except TypeError:
- return await self._raise(send, status_code=status.HTTP_501_NOT_IMPLEMENTED)
+ return await self._raise(send, monitoring=monitoring, status_code=status.HTTP_501_NOT_IMPLEMENTED)
# Collect Path Variables
request.collect_path_variables(found_path=found_path)
- try: # They Both(middleware.before() & _endpoint()) Have The Same Exception (APIException)
- # Call 'Before' Middlewares
- for middleware in config['http_middlewares']:
+ middlewares = [middleware(**data) for middleware, data in config.HTTP_MIDDLEWARES]
+ try: # They Both(middleware.before() & _endpoint()) Have The Same Exception (APIError)
+ # Call Middlewares .before()
+ for middleware in middlewares:
request = await middleware.before(request=request)
+ if request is None:
+ logger.critical(
+ f'Make sure to return the `request` at the end of `{middleware.__class__.__name__}.before()`')
+ return await self._raise(send, monitoring=monitoring)
# Call Endpoint
response = await endpoint(request=request)
@@ -201,60 +191,27 @@ async def handle_http(self, scope: dict, receive: Callable, send: Callable) -> N
response = self._handle_exceptions(e)
except Exception as e: # noqa: BLE001
- # Every unhandled exception in Panther or code will catch here
+ # All unhandled exceptions are caught here
exception = clean_traceback_message(exception=e)
logger.critical(exception)
- return await self._raise(send, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
+ return await self._raise(send, monitoring=monitoring)
- # Call 'After' Middleware
- for middleware in config['reversed_http_middlewares']:
+ # Call Middlewares .after()
+ middlewares.reverse()
+ for middleware in middlewares:
try:
response = await middleware.after(response=response)
+ if response is None:
+ logger.critical(
+ f'Make sure to return the `response` at the end of `{middleware.__class__.__name__}.after()`')
+ return await self._raise(send, monitoring=monitoring)
except APIError as e: # noqa: PERF203
response = self._handle_exceptions(e)
- await http_response(
- send,
- status_code=response.status_code,
- monitoring=self.monitoring,
- headers=response.headers,
- body=response.body,
- )
-
- async def handle_ws_listener(self):
- """
- Cause of --preload in gunicorn we have to keep this function here,
- and we can't move it to __init__ of Panther
-
- * Each process should start this listener for itself,
- but they have same Manager()
- """
- # Start Websocket Listener (Redis/ Queue)
- if config['has_ws']:
- WebsocketListener().start()
-
- async def handle_startup(self):
- if startup := config['startup'] or self._startup:
- if is_function_async(startup):
- await startup()
- else:
- startup()
-
- def handle_shutdown(self):
- if shutdown := config['shutdown'] or self._shutdown:
- if is_function_async(shutdown):
- try:
- asyncio.run(shutdown())
- except ModuleNotFoundError:
- # Error: import of asyncio halted; None in sys.modules
- # And as I figured it out, it only happens when we are running with
- # gunicorn and Uvicorn workers (-k uvicorn.workers.UvicornWorker)
- pass
- else:
- shutdown()
+ await response.send(send, receive, monitoring=monitoring)
def __del__(self):
- self.handle_shutdown()
+ Event.run_shutdowns()
@classmethod
def _handle_exceptions(cls, e: APIError, /) -> Response:
@@ -263,11 +220,11 @@ def _handle_exceptions(cls, e: APIError, /) -> Response:
status_code=e.status_code,
)
- async def _raise(self, send, *, status_code: int):
- await http_response(
- send,
- headers={'content-type': 'application/json'},
- status_code=status_code,
- monitoring=self.monitoring,
- exception=True,
- )
+ @classmethod
+ async def _raise(cls, send, *, monitoring, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR):
+ headers = [[b'Content-Type', b'application/json']]
+ body = json.dumps({'detail': status.status_text[status_code]})
+ await monitoring.after(status_code)
+ await send({'type': 'http.response.start', 'status': status_code, 'headers': headers})
+ await send({'type': 'http.response.body', 'body': body, 'more_body': False})
+
diff --git a/panther/monitoring.py b/panther/monitoring.py
index 730a0a8..3af1089 100644
--- a/panther/monitoring.py
+++ b/panther/monitoring.py
@@ -3,7 +3,7 @@
from typing import Literal
from panther.base_request import BaseRequest
-
+from panther.configs import config
logger = logging.getLogger('monitoring')
@@ -13,12 +13,11 @@ class Monitoring:
Create Log Message Like Below:
date time | method | path | ip:port | response_time [ms, s] | status
"""
- def __init__(self, is_active: bool, is_ws: bool = False):
- self.is_active = is_active
+ def __init__(self, is_ws: bool = False):
self.is_ws = is_ws
async def before(self, request: BaseRequest):
- if self.is_active:
+ if config.MONITORING:
ip, port = request.client
if self.is_ws:
@@ -30,7 +29,7 @@ async def before(self, request: BaseRequest):
self.start_time = perf_counter()
async def after(self, status: int | Literal['Accepted', 'Rejected', 'Closed'], /):
- if self.is_active:
+ if config.MONITORING:
response_time = perf_counter() - self.start_time
time_unit = ' s'
@@ -38,4 +37,8 @@ async def after(self, status: int | Literal['Accepted', 'Rejected', 'Closed'], /
response_time = response_time * 1_000
time_unit = 'ms'
+ elif response_time >= 10:
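+                # responses of ten seconds or more are reported in minutes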
+ response_time = response_time / 60
+ time_unit = ' m'
+
logger.info(f'{self.log} | {round(response_time, 4)} {time_unit} | {status}')
diff --git a/panther/pagination.py b/panther/pagination.py
new file mode 100644
index 0000000..6cb70fd
--- /dev/null
+++ b/panther/pagination.py
@@ -0,0 +1,48 @@
+from panther.db.cursor import Cursor
+from pantherdb import Cursor as PantherDBCursor
+
+
+class Pagination:
+ """
+ Request URL:
+ example.com/users?limit=10&skip=0
+ Response Data:
+ {
+ 'count': 10,
+ 'next': '?limit=10&skip=10',
+ 'previous': None,
+            'results': [...]
+ }
+ """
+ DEFAULT_LIMIT = 20
+ DEFAULT_SKIP = 0
+
+ def __init__(self, query_params: dict, cursor: Cursor | PantherDBCursor):
+ self.limit = self.get_limit(query_params=query_params)
+ self.skip = self.get_skip(query_params=query_params)
+ self.cursor = cursor
+
+ def get_limit(self, query_params: dict) -> int:
+ return int(query_params.get('limit', self.DEFAULT_LIMIT))
+
+ def get_skip(self, query_params: dict) -> int:
+ return int(query_params.get('skip', self.DEFAULT_SKIP))
+
+ def build_next_params(self):
+ next_skip = self.skip + self.limit
+ return f'?limit={self.limit}&skip={next_skip}'
+
+ def build_previous_params(self):
+ previous_skip = max(self.skip - self.limit, 0)
+ return f'?limit={self.limit}&skip={previous_skip}'
+
+ async def paginate(self):
+ count = await self.cursor.cls.count(self.cursor.filter)
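+        # e.g. count=25, limit=10, skip=20 -> 20 + 10 >= 25, so there is no next page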
+        has_next = self.limit + self.skip < count
+
+ return {
+ 'count': count,
+ 'next': self.build_next_params() if has_next else None,
+ 'previous': self.build_previous_params() if self.skip else None,
+ 'results': self.cursor.skip(skip=self.skip).limit(limit=self.limit)
+ }
diff --git a/panther/panel/apis.py b/panther/panel/apis.py
index c30d3c4..f3f067a 100644
--- a/panther/panel/apis.py
+++ b/panther/panel/apis.py
@@ -20,12 +20,12 @@ async def models_api():
'name': model.__name__,
'module': model.__module__,
'index': i
- } for i, model in enumerate(config['models'])]
+ } for i, model in enumerate(config.MODELS)]
@API(methods=['GET', 'POST'])
async def documents_api(request: Request, index: int):
- model = config['models'][index]
+ model = config.MODELS[index]
if request.method == 'POST':
validated_data = API.validate_input(model=model, request=request)
@@ -45,7 +45,7 @@ async def documents_api(request: Request, index: int):
@API(methods=['PUT', 'DELETE', 'GET'])
async def single_document_api(request: Request, index: int, document_id: int | str):
- model = config['models'][index]
+ model = config.MODELS[index]
if document := model.find_one(id=document_id):
if request.method == 'PUT':
@@ -69,7 +69,7 @@ async def healthcheck_api():
checks = []
# Database
- if config['query_engine'].__name__ == 'BaseMongoDBQuery':
+ if config.QUERY_ENGINE.__name__ == 'BaseMongoDBQuery':
with pymongo.timeout(3):
try:
ping = db.session.command('ping').get('ok') == 1.0
@@ -78,6 +78,6 @@ async def healthcheck_api():
checks.append(False)
# Redis
if redis.is_connected:
- checks.append(redis.ping())
+ checks.append(await redis.ping())
return Response(all(checks))
diff --git a/panther/permissions.py b/panther/permissions.py
index decfa0f..26f9cbd 100644
--- a/panther/permissions.py
+++ b/panther/permissions.py
@@ -10,4 +10,4 @@ async def authorization(cls, request: Request) -> bool:
class AdminPermission(BasePermission):
@classmethod
async def authorization(cls, request: Request) -> bool:
- return request.user and hasattr(request.user, 'is_admin') and request.user.is_admin
+ return request.user and getattr(request.user, 'is_admin', False)
diff --git a/panther/request.py b/panther/request.py
index e980e42..921cabe 100644
--- a/panther/request.py
+++ b/panther/request.py
@@ -1,5 +1,5 @@
import logging
-from typing import Literal
+from typing import Literal, Callable
import orjson as json
@@ -10,6 +10,11 @@
class Request(BaseRequest):
+ def __init__(self, scope: dict, receive: Callable, send: Callable):
+ self._data = ...
+ self.validated_data = None # It's been set in API.validate_input()
+ super().__init__(scope=scope, receive=receive, send=send)
+
@property
def method(self) -> Literal['GET', 'POST', 'PUT', 'PATCH', 'DELETE']:
return self.scope['method']
@@ -29,17 +34,6 @@ def data(self) -> dict | bytes:
self._data = self.__body
return self._data
- def set_validated_data(self, validated_data) -> None:
- self._validated_data = validated_data
-
- @property
- def validated_data(self):
- """
- Return The Validated Data
- It has been set on API.validate_input() while request is happening.
- """
- return getattr(self, '_validated_data', None)
-
async def read_body(self) -> None:
"""Read the entire body from an incoming ASGI message."""
self.__body = b''
diff --git a/panther/response.py b/panther/response.py
index 69b7bbf..9943223 100644
--- a/panther/response.py
+++ b/panther/response.py
@@ -1,13 +1,20 @@
+import asyncio
from types import NoneType
+from typing import Generator, AsyncGenerator
import orjson as json
from pydantic import BaseModel as PydanticBaseModel
from pydantic._internal._model_construction import ModelMetaclass
+from panther import status
+from panther._utils import to_async_generator
from panther.db.cursor import Cursor
+from pantherdb import Cursor as PantherDBCursor
+from panther.monitoring import Monitoring
-ResponseDataTypes = list | tuple | set | dict | int | float | str | bool | bytes | NoneType | ModelMetaclass
-IterableDataTypes = list | tuple | set | Cursor
+ResponseDataTypes = list | tuple | set | Cursor | PantherDBCursor | dict | int | float | str | bool | bytes | NoneType | ModelMetaclass
+IterableDataTypes = list | tuple | set | Cursor | PantherDBCursor
+StreamingDataTypes = Generator | AsyncGenerator
class Response:
@@ -17,16 +24,16 @@ def __init__(
self,
data: ResponseDataTypes = None,
headers: dict | None = None,
- status_code: int = 200,
+ status_code: int = status.HTTP_200_OK,
):
"""
- :param data: should be int | float | dict | list | tuple | set | str | bool | bytes | NoneType
- or instance of Pydantic.BaseModel
+ :param data: should be an instance of ResponseDataTypes
+ :param headers: should be dict of headers
:param status_code: should be int
"""
- self.data = self._clean_data_type(data)
- self._check_status_code(status_code)
- self._headers = headers
+ self.headers = headers or {}
+ self.data = self.prepare_data(data=data)
+ self.status_code = self.check_status_code(status_code=status_code)
@property
def body(self) -> bytes:
@@ -40,54 +47,73 @@ def body(self) -> bytes:
@property
def headers(self) -> dict:
return {
- 'content-type': self.content_type,
- 'content-length': len(self.body),
- 'access-control-allow-origin': '*',
- } | (self._headers or {})
+ 'Content-Type': self.content_type,
+ 'Content-Length': len(self.body),
+ 'Access-Control-Allow-Origin': '*',
+ } | self._headers
- def _clean_data_type(self, data: any):
- """Make sure the response data is only ResponseDataTypes or Iterable of ResponseDataTypes"""
- if issubclass(type(data), PydanticBaseModel):
- return data.model_dump()
+ @property
+ def bytes_headers(self) -> list[list[bytes]]:
+ return [[k.encode(), str(v).encode()] for k, v in (self.headers or {}).items()]
- elif isinstance(data, IterableDataTypes):
- return [self._clean_data_type(d) for d in data]
+ @headers.setter
+ def headers(self, headers: dict):
+ self._headers = headers
+
+ def prepare_data(self, data: any):
+ """Make sure the response data is only ResponseDataTypes or Iterable of ResponseDataTypes"""
+ if isinstance(data, (int | float | str | bool | bytes | NoneType)):
+ return data
elif isinstance(data, dict):
- return {key: self._clean_data_type(value) for key, value in data.items()}
+ return {key: self.prepare_data(value) for key, value in data.items()}
- elif isinstance(data, (int | float | str | bool | bytes | NoneType)):
- return data
+ elif issubclass(type(data), PydanticBaseModel):
+ return data.model_dump()
+
+ elif isinstance(data, IterableDataTypes):
+ return [self.prepare_data(d) for d in data]
else:
msg = f'Invalid Response Type: {type(data)}'
raise TypeError(msg)
- def _check_status_code(self, status_code: any):
+ @classmethod
+ def check_status_code(cls, status_code: any):
if not isinstance(status_code, int):
- error = f'Response "status_code" Should Be "int". ("{status_code}" is {type(status_code)})'
+ error = f'Response `status_code` Should Be `int`. (`{status_code}` is {type(status_code)})'
raise TypeError(error)
-
- self.status_code = status_code
-
- def _clean_data_with_output_model(self, output_model: ModelMetaclass | None):
- if self.data and output_model:
- self.data = self._serialize_with_output_model(self.data, output_model=output_model)
+ return status_code
@classmethod
- def _serialize_with_output_model(cls, data: any, /, output_model: ModelMetaclass):
+ def apply_output_model(cls, data: any, /, output_model: ModelMetaclass):
+ """This method is called in API.__call__"""
# Dict
if isinstance(data, dict):
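+            # rename keys to each field's validation_alias (when set) so `output_model(**data)` validates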
+ for field_name, field in output_model.model_fields.items():
+ if field.validation_alias and field_name in data:
+ data[field.validation_alias] = data.pop(field_name)
return output_model(**data).model_dump()
# Iterable
if isinstance(data, IterableDataTypes):
- return [cls._serialize_with_output_model(d, output_model=output_model) for d in data]
+ return [cls.apply_output_model(d, output_model=output_model) for d in data]
# Str | Bool | Bytes
msg = 'Type of Response data is not match with `output_model`.\n*hint: You may want to remove `output_model`'
raise TypeError(msg)
+ async def send_headers(self, send, /):
+ await send({'type': 'http.response.start', 'status': self.status_code, 'headers': self.bytes_headers})
+
+ async def send_body(self, send, receive, /):
+ await send({'type': 'http.response.body', 'body': self.body, 'more_body': False})
+
+ async def send(self, send, receive, /, monitoring: Monitoring):
+ await self.send_headers(send)
+ await self.send_body(send, receive)
+ await monitoring.after(self.status_code)
+
def __str__(self):
if len(data := str(self.data)) > 30:
data = f'{data:.27}...'
@@ -96,6 +122,57 @@ def __str__(self):
__repr__ = __str__
+class StreamingResponse(Response):
+ content_type = 'application/octet-stream'
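+
+    # A minimal usage sketch (the generator name is illustrative):
+    #
+    #     def stream_numbers():
+    #         for i in range(3):
+    #             yield {'i': i}
+    #
+    #     return StreamingResponse(data=stream_numbers())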
+
+ def __init__(self, *args, **kwargs):
+ self.connection_closed = False
+ super().__init__(*args, **kwargs)
+
+ async def listen_to_disconnection(self, receive):
+ message = await receive()
+ if message['type'] == 'http.disconnect':
+ self.connection_closed = True
+
+ def prepare_data(self, data: any) -> AsyncGenerator:
+ if isinstance(data, AsyncGenerator):
+ return data
+ elif isinstance(data, Generator):
+ return to_async_generator(data)
+ msg = f'Invalid Response Type: {type(data)}'
+ raise TypeError(msg)
+
+ @property
+ def headers(self) -> dict:
+ return {
+ 'Content-Type': self.content_type,
+ 'Access-Control-Allow-Origin': '*',
+ } | self._headers
+
+ @headers.setter
+ def headers(self, headers: dict):
+ self._headers = headers
+
+ @property
+ async def body(self) -> AsyncGenerator:
+ async for chunk in self.data:
+ if isinstance(chunk, bytes):
+ yield chunk
+ elif chunk is None:
+ yield b''
+ else:
+ yield json.dumps(chunk)
+
+ async def send_body(self, send, receive, /):
+ asyncio.create_task(self.listen_to_disconnection(receive))
+ async for chunk in self.body:
+ if self.connection_closed:
+ break
+ await send({'type': 'http.response.body', 'body': chunk, 'more_body': True})
+ else:
+ await send({'type': 'http.response.body', 'body': b'', 'more_body': False})
+
+
class HTMLResponse(Response):
content_type = 'text/html; charset=utf-8'
diff --git a/panther/routings.py b/panther/routings.py
index 5471613..2543926 100644
--- a/panther/routings.py
+++ b/panther/routings.py
@@ -1,4 +1,3 @@
-import logging
import re
from collections import Counter
from collections.abc import Callable, Mapping, MutableMapping
@@ -6,9 +5,7 @@
from functools import partial, reduce
from panther.configs import config
-
-
-logger = logging.getLogger('panther')
+from panther.exceptions import PantherError
def flatten_urls(urls: dict) -> dict:
@@ -28,20 +25,17 @@ def _flattening_urls(data: dict | Callable, url: str = ''):
url = url.removeprefix('/')
# Collect it, if it doesn't have problem
- if _is_url_endpoint_valid(url=url, endpoint=data):
- yield url, data
+ _is_url_endpoint_valid(url=url, endpoint=data)
+ yield url, data
-def _is_url_endpoint_valid(url: str, endpoint: Callable) -> bool:
+def _is_url_endpoint_valid(url: str, endpoint: Callable):
if endpoint is ...:
- logger.error(f"URL Can't Point To Ellipsis. ('{url}' -> ...)")
+ raise PantherError(f"URL Can't Point To Ellipsis. ('{url}' -> ...)")
elif endpoint is None:
- logger.error(f"URL Can't Point To None. ('{url}' -> None)")
+ raise PantherError(f"URL Can't Point To None. ('{url}' -> None)")
elif url and not re.match(r'^[a-zA-Z<>0-9_/-]+$', url):
- logger.error(f"URL Is Not Valid. --> '{url}'")
- else:
- return True
- return False
+ raise PantherError(f"URL Is Not Valid. --> '{url}'")
def finalize_urls(urls: dict) -> dict:
@@ -60,7 +54,33 @@ def finalize_urls(urls: dict) -> dict:
else:
path = {single_path: path or endpoint}
urls_list.append(path)
- return _merge(*urls_list) if urls_list else {}
+ final_urls = _merge(*urls_list) if urls_list else {}
+ check_urls_path_variables(final_urls)
+ return final_urls
+
+
+def check_urls_path_variables(urls: dict, path: str = '') -> None:
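+    """Reject URL trees with ambiguous same-level path variables, e.g. both `user/<id>` and `user/<name>` pointing to endpoints."""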
+ middle_route_error = []
+ last_route_error = []
+ for key, value in urls.items():
+ new_path = f'{path}/{key}'
+
+ if isinstance(value, dict):
+ if key.startswith('<'):
+ middle_route_error.append(new_path)
+ check_urls_path_variables(value, path=new_path)
+ elif key.startswith('<'):
+ last_route_error.append(new_path)
+
+ if len(middle_route_error) > 1:
+ msg = '\n\t- ' + '\n\t- '.join(e for e in middle_route_error)
+ raise PantherError(
+ f"URLs can't have same-level path variables that point to a dict: {msg}")
+
+ if len(last_route_error) > 1:
+ msg = '\n\t- ' + '\n\t- '.join(e for e in last_route_error)
+ raise PantherError(
+ f"URLs can't have same-level path variables that point to an endpoint: {msg}")
def _merge(destination: MutableMapping, *sources) -> MutableMapping:
@@ -106,7 +126,7 @@ def _is_recursive_merge(a, b):
def find_endpoint(path: str) -> tuple[Callable | None, str]:
- urls = config['urls']
+ urls = config.URLS
# 'user/list/?name=ali' --> 'user/list/' --> 'user/list' --> ['user', 'list']
parts = path.split('?')[0].strip('/').split('/')
diff --git a/panther/serializer.py b/panther/serializer.py
index 2eea063..c2f538e 100644
--- a/panther/serializer.py
+++ b/panther/serializer.py
@@ -1,14 +1,16 @@
import typing
+from typing import TypeVar, Type
-from pydantic import create_model, BaseModel
-from pydantic.fields import FieldInfo
+from pydantic import create_model, BaseModel, ConfigDict
+from pydantic.fields import FieldInfo, Field
from pydantic_core._pydantic_core import PydanticUndefined
from panther.db import Model
+from panther.request import Request
class MetaModelSerializer:
- KNOWN_CONFIGS = ['model', 'fields', 'required_fields']
+ KNOWN_CONFIGS = ['model', 'fields', 'exclude', 'required_fields', 'optional_fields']
def __new__(
cls,
@@ -18,6 +20,9 @@ def __new__(
**kwargs
):
if cls_name == 'ModelSerializer':
+            # Drop the `model` and `request` annotations here; subclasses get them back via `create_model()`
+ namespace['__annotations__'].pop('model')
+ namespace['__annotations__'].pop('request')
cls.model_serializer = type(cls_name, (), namespace)
return super().__new__(cls)
@@ -38,7 +43,8 @@ def __new__(
__module__=namespace['__module__'],
__validators__=namespace,
__base__=(cls.model_serializer, BaseModel),
- model=(typing.ClassVar, config.model),
+ model=(typing.ClassVar[type[BaseModel]], config.model),
+ request=(Request, Field(None, exclude=True)),
**field_definitions
)
@@ -67,22 +73,70 @@ def check_config(cls, cls_name: str, namespace: dict) -> None:
raise AttributeError(msg) from None
# Check `fields`
- if (fields := getattr(config, 'fields', None)) is None:
+ if not hasattr(config, 'fields'):
msg = f'`{cls_name}.Config.fields` is required.'
raise AttributeError(msg) from None
- for field_name in fields:
- if field_name not in model.model_fields:
- msg = f'`{cls_name}.Config.fields.{field_name}` is not valid.'
- raise AttributeError(msg) from None
+ if config.fields != '*':
+ for field_name in config.fields:
+ if field_name == '*':
+                    msg = f"`{cls_name}.Config.fields.{field_name}` is not valid. Did you mean `fields = '*'`?"
+ raise AttributeError(msg) from None
+
+ if field_name not in model.model_fields:
+ msg = f'`{cls_name}.Config.fields.{field_name}` is not in `{model.__name__}.model_fields`'
+ raise AttributeError(msg) from None
# Check `required_fields`
if not hasattr(config, 'required_fields'):
config.required_fields = []
- for required in config.required_fields:
- if required not in config.fields:
- msg = f'`{cls_name}.Config.required_fields.{required}` should be in `Config.fields` too.'
+ if config.required_fields != '*':
+ for required in config.required_fields:
+ if required not in config.fields:
+ msg = f'`{cls_name}.Config.required_fields.{required}` should be in `Config.fields` too.'
+ raise AttributeError(msg) from None
+
+ # Check `optional_fields`
+ if not hasattr(config, 'optional_fields'):
+ config.optional_fields = []
+
+ if config.optional_fields != '*':
+ for optional in config.optional_fields:
+ if optional not in config.fields:
+ msg = f'`{cls_name}.Config.optional_fields.{optional}` should be in `Config.fields` too.'
+ raise AttributeError(msg) from None
+
+ # Check `required_fields` and `optional_fields` together
+ if (
+ (config.optional_fields == '*' and config.required_fields != []) or
+ (config.required_fields == '*' and config.optional_fields != [])
+ ):
+ msg = (
+                f"`{cls_name}.Config.optional_fields` and "
+                f"`{cls_name}.Config.required_fields` can't be used together when either of them is `'*'`"
+ )
+ raise AttributeError(msg) from None
+ for optional in config.optional_fields:
+ for required in config.required_fields:
+ if optional == required:
+ msg = (
+ f"`{optional}` can't be in `{cls_name}.Config.optional_fields` and "
+ f"`{cls_name}.Config.required_fields` at the same time"
+ )
+ raise AttributeError(msg) from None
+
+ # Check `exclude`
+ if not hasattr(config, 'exclude'):
+ config.exclude = []
+
+ for field_name in config.exclude:
+ if field_name not in model.model_fields:
+ msg = f'`{cls_name}.Config.exclude.{field_name}` is not valid.'
+ raise AttributeError(msg) from None
+
+ if config.fields != '*' and field_name not in config.fields:
+ msg = f'`{cls_name}.Config.exclude.{field_name}` is not defined in `Config.fields`.'
raise AttributeError(msg) from None
@classmethod
@@ -90,15 +144,35 @@ def collect_fields(cls, config: typing.Callable, namespace: dict) -> dict:
field_definitions = {}
# Define `fields`
- for field_name in config.fields:
- field_definitions[field_name] = (
- config.model.model_fields[field_name].annotation,
- config.model.model_fields[field_name]
- )
+ if config.fields == '*':
+ for field_name, field in config.model.model_fields.items():
+ field_definitions[field_name] = (field.annotation, field)
+ else:
+ for field_name in config.fields:
+ field_definitions[field_name] = (
+ config.model.model_fields[field_name].annotation,
+ config.model.model_fields[field_name]
+ )
+
+ # Apply `exclude`
+ for field_name in config.exclude:
+ del field_definitions[field_name]
# Apply `required_fields`
- for required in config.required_fields:
- field_definitions[required][1].default = PydanticUndefined
+ if config.required_fields == '*':
+ for value in field_definitions.values():
+ value[1].default = PydanticUndefined
+ else:
+ for field_name in config.required_fields:
+ field_definitions[field_name][1].default = PydanticUndefined
+
+ # Apply `optional_fields`
+ if config.optional_fields == '*':
+ for value in field_definitions.values():
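+                # default each field to its annotation's zero value, e.g. '' for str, 0 for int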
+ value[1].default = value[0]()
+ else:
+ for field_name in config.optional_fields:
+ field_definitions[field_name][1].default = field_definitions[field_name][0]()
# Collect and Override `Class Fields`
for key, value in namespace.pop('__annotations__', {}).items():
@@ -112,9 +186,43 @@ def collect_model_config(cls, config: typing.Callable, namespace: dict) -> dict:
return {
attr: getattr(config, attr) for attr in dir(config)
if not attr.startswith('__') and attr not in cls.KNOWN_CONFIGS
- } | namespace.pop('model_config', {})
+ } | namespace.pop('model_config', {}) | {'arbitrary_types_allowed': True}
class ModelSerializer(metaclass=MetaModelSerializer):
- def create(self) -> type[Model]:
- return self.model.insert_one(self.model_dump())
+ """
+ Doc:
+ https://pantherpy.github.io/serializer/#style-2-model-serializer
+ Example:
+ class PersonSerializer(ModelSerializer):
+            class Config:
+ model = Person
+ fields = '*'
+ exclude = ['created_date'] # Optional
+ required_fields = ['first_name', 'last_name'] # Optional
+ optional_fields = ['age'] # Optional
+ """
+ model: type[BaseModel]
+ request: Request
+
+ async def create(self, validated_data: dict) -> Model:
+ """
+ validated_data = ModelSerializer.model_dump()
+ """
+ return await self.model.insert_one(validated_data)
+
+ async def update(self, instance: Model, validated_data: dict) -> Model:
+ """
+ instance = UpdateAPI.object()
+ validated_data = ModelSerializer.model_dump()
+ """
+ await instance.update(validated_data)
+ return instance
+
+ async def partial_update(self, instance: Model, validated_data: dict) -> Model:
+ """
+ instance = UpdateAPI.object()
+ validated_data = ModelSerializer.model_dump(exclude_none=True)
+ """
+ await instance.update(validated_data)
+ return instance
diff --git a/panther/test.py b/panther/test.py
index 554ae6b..7e45052 100644
--- a/panther/test.py
+++ b/panther/test.py
@@ -4,7 +4,7 @@
import orjson as json
-from panther.response import Response
+from panther.response import Response, HTMLResponse, PlainTextResponse, StreamingResponse
__all__ = ('APIClient', 'WebsocketClient')
@@ -12,12 +12,13 @@
class RequestClient:
def __init__(self, app: Callable):
self.app = app
+ self.response = b''
async def send(self, data: dict):
if data['type'] == 'http.response.start':
self.header = data
else:
- self.response = data
+ self.response += data['body']
async def receive(self):
return {
@@ -54,11 +55,22 @@ async def request(
receive=self.receive,
send=self.send,
)
- return Response(
- data=json.loads(self.response.get('body', b'null')),
- status_code=self.header['status'],
- headers=self.header['headers'],
- )
+ response_headers = {key.decode(): value.decode() for key, value in self.header['headers']}
+ if response_headers['Content-Type'] == 'text/html; charset=utf-8':
+ data = self.response.decode()
+ return HTMLResponse(data=data, status_code=self.header['status'], headers=response_headers)
+
+ elif response_headers['Content-Type'] == 'text/plain; charset=utf-8':
+ data = self.response.decode()
+ return PlainTextResponse(data=data, status_code=self.header['status'], headers=response_headers)
+
+ elif response_headers['Content-Type'] == 'application/octet-stream':
+ data = self.response.decode()
+ return PlainTextResponse(data=data, status_code=self.header['status'], headers=response_headers)
+
+ else:
+ data = json.loads(self.response or b'null')
+ return Response(data=data, status_code=self.header['status'], headers=response_headers)
class APIClient:
diff --git a/panther/utils.py b/panther/utils.py
index 0d19e82..337d8a5 100644
--- a/panther/utils.py
+++ b/panther/utils.py
@@ -7,6 +7,10 @@
from pathlib import Path
from typing import ClassVar
+import pytz
+
+from panther.configs import config
+
logger = logging.getLogger('panther')
URANDOM_SIZE = 16
@@ -25,8 +29,7 @@ def load_env(env_file: str | Path, /) -> dict[str, str]:
variables = {}
if env_file is None or not Path(env_file).is_file():
- logger.critical(f'"{env_file}" is not valid file for load_env()')
- return variables
+ raise ValueError(f'"{env_file}" is not a file.') from None
with open(env_file) as file:
for line in file.readlines():
@@ -80,26 +83,8 @@ def scrypt(password: str, salt: bytes, digest: bool = False) -> str | bytes:
dklen=dk_len
)
if digest:
-
return hashlib.md5(derived_key).hexdigest()
- else:
- return derived_key
-
-
-def encrypt_password(password: str) -> str:
- salt = os.urandom(URANDOM_SIZE)
- derived_key = scrypt(password=password, salt=salt, digest=True)
-
- return f'{salt.hex()}{derived_key}'
-
-
-def check_password(stored_password: str, new_password: str) -> bool:
- size = URANDOM_SIZE * 2
- salt = stored_password[:size]
- stored_hash = stored_password[size:]
- derived_key = scrypt(password=new_password, salt=bytes.fromhex(salt), digest=True)
-
- return derived_key == stored_hash
+ return derived_key
class ULID:
@@ -121,3 +106,8 @@ def _generate(cls, bits: str) -> str:
cls.crockford_base32_characters[int(bits[i: i + 5], base=2)]
for i in range(0, 130, 5)
)
+
+
+def timezone_now():
+ tz = pytz.timezone(config.TIMEZONE)
+ return datetime.now(tz=tz)
diff --git a/panther/websocket.py b/panther/websocket.py
index f32784a..37f0972 100644
--- a/panther/websocket.py
+++ b/panther/websocket.py
@@ -22,16 +22,16 @@ async def receive(self, data: str | bytes):
async def send(self, data: any = None):
"""
- We are using this method to send message to the client,
+ Send message to the client,
You may want to override it with your custom scenario. (not recommended)
"""
return await super().send(data=data)
async def send_message_to_websocket(connection_id: str, data: any):
- config.websocket_connections.publish(connection_id=connection_id, action='send', data=data)
+ await config.WEBSOCKET_CONNECTIONS.publish(connection_id=connection_id, action='send', data=data)
async def close_websocket_connection(connection_id: str, code: int = status.WS_1000_NORMAL_CLOSURE, reason: str = ''):
data = {'code': code, 'reason': reason}
- config.websocket_connections.publish(connection_id=connection_id, action='close', data=data)
+ await config.WEBSOCKET_CONNECTIONS.publish(connection_id=connection_id, action='close', data=data)
diff --git a/setup.py b/setup.py
index 003e262..8fce320 100644
--- a/setup.py
+++ b/setup.py
@@ -21,6 +21,7 @@ def panther_version() -> str:
'python-jose~=3.3',
'websockets~=12.0',
'cryptography~=42.0',
+ 'watchfiles~=0.21.0',
],
}
@@ -50,11 +51,11 @@ def panther_version() -> str:
},
install_requires=[
'httptools~=0.6',
- 'pantherdb==1.4.0',
+ 'pantherdb==2.0.0',
'pydantic~=2.6',
'rich~=13.7',
'uvicorn~=0.27',
- 'watchfiles~=0.21.0',
+ 'pytz~=2024.1',
],
extras_require=EXTRAS_REQUIRE,
)
diff --git a/tests/test_authentication.py b/tests/test_authentication.py
index 821fe6c..b223af3 100644
--- a/tests/test_authentication.py
+++ b/tests/test_authentication.py
@@ -69,15 +69,15 @@ async def test_user_without_auth(self):
assert res.data is None
async def test_user_auth_required_without_auth_class(self):
- auth_config = config['authentication']
- config['authentication'] = None
+ auth_config = config.AUTHENTICATION
+ config.AUTHENTICATION = None
with self.assertLogs(level='CRITICAL') as captured:
res = await self.client.get('auth-required')
assert len(captured.records) == 1
assert captured.records[0].getMessage() == '"AUTHENTICATION" has not been set in configs'
assert res.status_code == 500
assert res.data['detail'] == 'Internal Server Error'
- config['authentication'] = auth_config
+ config.AUTHENTICATION = auth_config
async def test_user_auth_required_without_token(self):
with self.assertLogs(level='ERROR') as captured:
diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py
index 7a8bf30..4530064 100644
--- a/tests/test_background_tasks.py
+++ b/tests/test_background_tasks.py
@@ -9,11 +9,11 @@
class TestBackgroundTasks(TestCase):
def setUp(self):
self.obj = BackgroundTasks()
- config['background_tasks'] = True
+ config.BACKGROUND_TASKS = True
def tearDown(self):
del Singleton._instances[BackgroundTasks]
- config['background_tasks'] = False
+ config.BACKGROUND_TASKS = False
def test_background_tasks_singleton(self):
new_obj = BackgroundTasks()
@@ -60,7 +60,7 @@ def func(_numbers): _numbers.append(1)
with self.assertLogs() as captured:
self.obj.add_task(task)
assert len(captured.records) == 1
- assert captured.records[0].getMessage() == 'Task will be ignored, `BACKGROUND_TASKS` is not True in `core/configs.py`'
+ assert captured.records[0].getMessage() == 'Task will be ignored, `BACKGROUND_TASKS` is not True in `configs`'
assert self.obj.tasks == []
def test_add_task_with_args(self):
diff --git a/tests/test_caching.py b/tests/test_caching.py
index efbea48..25efa68 100644
--- a/tests/test_caching.py
+++ b/tests/test_caching.py
@@ -1,4 +1,5 @@
import time
+import asyncio
from datetime import timedelta
from unittest import IsolatedAsyncioTestCase
@@ -10,19 +11,19 @@
@API()
async def without_cache_api():
- time.sleep(0.01)
+ await asyncio.sleep(0.01)
return {'detail': time.time()}
@API(cache=True)
async def with_cache_api():
- time.sleep(0.01)
+ await asyncio.sleep(0.01)
return {'detail': time.time()}
@API(cache=True, cache_exp_time=timedelta(seconds=5))
async def expired_cache_api():
- time.sleep(0.01)
+ await asyncio.sleep(0.01)
return {'detail': time.time()}
@@ -65,7 +66,7 @@ async def test_with_cache_5second_exp_time(self):
# Check Logs
assert len(captured.records) == 1
- assert captured.records[0].getMessage() == '"cache_exp_time" is not very accurate when redis is not connected.'
+ assert captured.records[0].getMessage() == '`cache_exp_time` is not very accurate when `redis` is not connected.'
# Second Request
res2 = await self.client.get('with-expired-cache')
@@ -74,7 +75,7 @@ async def test_with_cache_5second_exp_time(self):
# Response should be cached
assert check_two_dicts(res1.data, res2.data) is True
- time.sleep(5)
+ await asyncio.sleep(5)
# Third Request
res3 = await self.client.get('with-expired-cache')
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 955fd5f..2e7212a 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -23,7 +23,7 @@ def interactive_cli_1_mock_responses(index=None):
global interactive_cli_1_index
if index is None:
index = interactive_cli_1_index
- responses = ['project1', 'project1_dir', 'n', '0', 'y', 'y', 'y', 'y', 'y']
+ responses = ['project1', 'project1_dir', 'n', '0', 'y', 'n', 'y', 'y', 'y', 'y']
response = responses[index]
interactive_cli_1_index += 1
return response
@@ -33,7 +33,7 @@ def interactive_cli_2_mock_responses(index=None):
global interactive_cli_2_index
if index is None:
index = interactive_cli_2_index
- responses = ['project2', 'project2_dir', 'y', '0', 'y', 'y', 'y', 'y', 'y']
+ responses = ['project2', 'project2_dir', 'y', '0', 'y', 'n', 'y', 'y', 'y', 'y']
response = responses[index]
interactive_cli_2_index += 1
return response
diff --git a/tests/test_database.py b/tests/test_database.py
index eed5750..560a596 100644
--- a/tests/test_database.py
+++ b/tests/test_database.py
@@ -8,7 +8,8 @@
from panther import Panther
from panther.db import Model
from panther.db.connections import db
-from panther.db.cursor import Cursor
+from panther.db.cursor import Cursor as MongoCursor
+from pantherdb import Cursor as PantherDBCursor
f = faker.Faker()
@@ -33,10 +34,24 @@ async def test_insert_one(self):
assert book.name == name
assert book.pages_count == pages_count
- async def test_insert_many(self):
+ async def test_insert_many_with_insert_one(self):
insert_count = await self._insert_many()
assert insert_count > 1
+ async def test_insert_many(self):
+ insert_count = random.randint(2, 10)
+ data = [
+ {'name': f.name(), 'author': f.name(), 'pages_count': random.randint(0, 10)}
+ for _ in range(insert_count)
+ ]
+ books = await Book.insert_many(data)
+ inserted_books = [
+ {'_id': book._id, 'name': book.name, 'author': book.author, 'pages_count': book.pages_count}
+ for book in books
+ ]
+ assert len(books) == insert_count
+ assert data == inserted_books
+
# # # FindOne
async def test_find_one_not_found(self):
# Insert Many
@@ -157,9 +172,9 @@ async def test_find(self):
_len = sum(1 for _ in books)
if self.__class__.__name__ == 'TestMongoDB':
- assert isinstance(books, Cursor)
+ assert isinstance(books, MongoCursor)
else:
- assert isinstance(books, list)
+ assert isinstance(books, PantherDBCursor)
assert _len == insert_count
for book in books:
assert isinstance(book, Book)
@@ -174,9 +189,9 @@ async def test_find_not_found(self):
_len = sum(1 for _ in books)
if self.__class__.__name__ == 'TestMongoDB':
- assert isinstance(books, Cursor)
+ assert isinstance(books, MongoCursor)
else:
- assert isinstance(books, list)
+ assert isinstance(books, PantherDBCursor)
assert _len == 0
async def test_find_without_filter(self):
@@ -188,9 +203,9 @@ async def test_find_without_filter(self):
_len = sum(1 for _ in books)
if self.__class__.__name__ == 'TestMongoDB':
- assert isinstance(books, Cursor)
+ assert isinstance(books, MongoCursor)
else:
- assert isinstance(books, list)
+ assert isinstance(books, PantherDBCursor)
assert _len == insert_count
for book in books:
assert isinstance(book, Book)
@@ -204,9 +219,9 @@ async def test_all(self):
_len = sum(1 for _ in books)
if self.__class__.__name__ == 'TestMongoDB':
- assert isinstance(books, Cursor)
+ assert isinstance(books, MongoCursor)
else:
- assert isinstance(books, list)
+ assert isinstance(books, PantherDBCursor)
assert _len == insert_count
for book in books:
@@ -418,9 +433,9 @@ async def test_update_many(self):
_len = sum(1 for _ in books)
if self.__class__.__name__ == 'TestMongoDB':
- assert isinstance(books, Cursor)
+ assert isinstance(books, MongoCursor)
else:
- assert isinstance(books, list)
+ assert isinstance(books, PantherDBCursor)
assert _len == updated_count == insert_count
for book in books:
assert book.author == author
diff --git a/tests/test_events.py b/tests/test_events.py
new file mode 100644
index 0000000..5eff523
--- /dev/null
+++ b/tests/test_events.py
@@ -0,0 +1,123 @@
+import logging
+from unittest import IsolatedAsyncioTestCase
+
+from panther.configs import config
+from panther.events import Event
+
+logger = logging.getLogger('panther')
+
+
+class TestEvents(IsolatedAsyncioTestCase):
+ def tearDown(self):
+ config.refresh()
+
+ async def test_async_startup(self):
+ assert len(config.STARTUPS) == 0
+
+ async def startup_event():
+ logger.info('This Is Startup.')
+
+ Event.startup(startup_event)
+
+ assert len(config.STARTUPS) == 1
+ assert config.STARTUPS[0] == startup_event
+
+ with self.assertLogs(level='INFO') as capture:
+ await Event.run_startups()
+
+ assert len(capture.records) == 1
+ assert capture.records[0].getMessage() == 'This Is Startup.'
+
+ async def test_sync_startup(self):
+ assert len(config.STARTUPS) == 0
+
+ def startup_event():
+ logger.info('This Is Startup.')
+
+ Event.startup(startup_event)
+
+ assert len(config.STARTUPS) == 1
+ assert config.STARTUPS[0] == startup_event
+
+ with self.assertLogs(level='INFO') as capture:
+ await Event.run_startups()
+
+ assert len(capture.records) == 1
+ assert capture.records[0].getMessage() == 'This Is Startup.'
+
+ async def test_startup(self):
+ assert len(config.STARTUPS) == 0
+
+ def startup_event1():
+ logger.info('This Is Startup1.')
+
+ async def startup_event2():
+ logger.info('This Is Startup2.')
+
+ Event.startup(startup_event1)
+ Event.startup(startup_event2)
+
+ assert len(config.STARTUPS) == 2
+ assert config.STARTUPS[0] == startup_event1
+ assert config.STARTUPS[1] == startup_event2
+
+ with self.assertLogs(level='INFO') as capture:
+ await Event.run_startups()
+
+ assert len(capture.records) == 2
+ assert capture.records[0].getMessage() == 'This Is Startup1.'
+ assert capture.records[1].getMessage() == 'This Is Startup2.'
+
+ async def test_sync_shutdown(self):
+ assert len(config.SHUTDOWNS) == 0
+
+ def shutdown_event():
+ logger.info('This Is Shutdown.')
+
+ Event.shutdown(shutdown_event)
+
+ assert len(config.SHUTDOWNS) == 1
+ assert config.SHUTDOWNS[0] == shutdown_event
+
+ with self.assertLogs(level='INFO') as capture:
+ Event.run_shutdowns()
+
+ assert len(capture.records) == 1
+ assert capture.records[0].getMessage() == 'This Is Shutdown.'
+
+ async def shutdown_event(self):
+ logger.info('This Is Shutdown.')
+
+ def test_async_shutdown(self):
+ assert len(config.SHUTDOWNS) == 0
+
+ Event.shutdown(self.shutdown_event)
+
+ assert len(config.SHUTDOWNS) == 1
+ assert config.SHUTDOWNS[0] == self.shutdown_event
+
+ with self.assertLogs(level='INFO') as capture:
+ Event.run_shutdowns()
+
+ assert len(capture.records) == 1
+ assert capture.records[0].getMessage() == 'This Is Shutdown.'
+
+ def test_shutdown(self):
+ assert len(config.SHUTDOWNS) == 0
+
+ def shutdown_event_sync():
+ logger.info('This Is Sync Shutdown.')
+
+ Event.shutdown(self.shutdown_event)
+ Event.shutdown(shutdown_event_sync)
+
+ assert len(config.SHUTDOWNS) == 2
+ assert config.SHUTDOWNS[0] == self.shutdown_event
+ assert config.SHUTDOWNS[1] == shutdown_event_sync
+
+ with self.assertLogs(level='INFO') as capture:
+ Event.run_shutdowns()
+
+ assert len(capture.records) == 2
+ assert capture.records[0].getMessage() == 'This Is Shutdown.'
+ assert capture.records[1].getMessage() == 'This Is Sync Shutdown.'
diff --git a/tests/test_generics.py b/tests/test_generics.py
new file mode 100644
index 0000000..97a2a6a
--- /dev/null
+++ b/tests/test_generics.py
@@ -0,0 +1,226 @@
+from pathlib import Path
+from unittest import IsolatedAsyncioTestCase
+
+from panther import Panther
+from panther.db import Model
+from panther.generics import RetrieveAPI, ListAPI, UpdateAPI, DeleteAPI, CreateAPI
+from panther.pagination import Pagination
+from panther.request import Request
+from panther.serializer import ModelSerializer
+from panther.test import APIClient
+
+
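+# Models shared by the generic API tests; Person extends User with an age field.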
+class User(Model):
+ name: str
+
+
+class Person(User):
+ age: int
+
+
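+# Generic API subclasses under test, each overriding only the hooks it needs.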
+class RetrieveAPITest(RetrieveAPI):
+ async def object(self, request: Request, **kwargs) -> Model:
+ return await User.find_one(id=kwargs['id'])
+
+
+class ListAPITest(ListAPI):
+ async def objects(self, request: Request, **kwargs):
+ return await User.find()
+
+
+class FullListAPITest(ListAPI):
+ sort_fields = ['name', 'age']
+ search_fields = ['name']
+ filter_fields = ['name', 'age']
+ pagination = Pagination
+
+ async def objects(self, request: Request, **kwargs):
+ return await Person.find()
+
+
+class UserSerializer(ModelSerializer):
+ class Config:
+ model = User
+ fields = '*'
+
+
+class UpdateAPITest(UpdateAPI):
+ input_model = UserSerializer
+
+ async def object(self, request: Request, **kwargs) -> Model:
+ return await User.find_one(id=kwargs['id'])
+
+
+class CreateAPITest(CreateAPI):
+ input_model = UserSerializer
+
+
+class DeleteAPITest(DeleteAPI):
+ async def object(self, request: Request, **kwargs) -> Model:
+ return await User.find_one(id=kwargs['id'])
+
+
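+# Routes used by the TestGeneric test case below.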
+urls = {
+ 'retrieve/': RetrieveAPITest,
+ 'list': ListAPITest,
+ 'full-list': FullListAPITest,
+ 'update/': UpdateAPITest,
+ 'create': CreateAPITest,
+ 'delete/': DeleteAPITest,
+}
+
+
+class TestGeneric(IsolatedAsyncioTestCase):
+ DB_PATH = 'test.pdb'
+
+ @classmethod
+ def setUpClass(cls) -> None:
+ global DATABASE
+ DATABASE = {
+ 'engine': {
+ 'class': 'panther.db.connections.PantherDBConnection',
+ 'path': cls.DB_PATH
+ },
+ }
+ app = Panther(__name__, configs=__name__, urls=urls)
+ cls.client = APIClient(app=app)
+
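+    # Remove the on-disk PantherDB file after each test so cases stay isolated.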
+ def tearDown(self) -> None:
+ Path(self.DB_PATH).unlink()
+
+ async def test_retrieve(self):
+ user = await User.insert_one(name='Ali')
+ res = await self.client.get(f'retrieve/{user.id}')
+ assert res.status_code == 200
+ assert res.data == {'id': user.id, 'name': user.name}
+
+ async def test_list(self):
+ users = await User.insert_many([{'name': 'Ali'}, {'name': 'Hamed'}])
+ res = await self.client.get('list')
+ assert res.status_code == 200
+ assert res.data == [{'id': u.id, 'name': u.name} for u in users]
+
+ async def test_list_features(self):
+ await Person.insert_many([
+ {'name': 'Ali', 'age': 0},
+ {'name': 'Ali', 'age': 1},
+ {'name': 'Saba', 'age': 0},
+ {'name': 'Saba', 'age': 1},
+ ])
+ res = await self.client.get('full-list')
+ assert res.status_code == 200
+ assert set(res.data.keys()) == {'results', 'count', 'previous', 'next'}
+
+ # Normal
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Ali', 'age': 0},
+ {'name': 'Ali', 'age': 1},
+ {'name': 'Saba', 'age': 0},
+ {'name': 'Saba', 'age': 1},
+ ]
+
+ # Sort 1
+ res = await self.client.get('full-list', query_params={'sort': '-name'})
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Saba', 'age': 0},
+ {'name': 'Saba', 'age': 1},
+ {'name': 'Ali', 'age': 0},
+ {'name': 'Ali', 'age': 1},
+ ]
+
+ # Sort 2
+ res = await self.client.get('full-list', query_params={'sort': '-name,-age'})
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Saba', 'age': 1},
+ {'name': 'Saba', 'age': 0},
+ {'name': 'Ali', 'age': 1},
+ {'name': 'Ali', 'age': 0},
+ ]
+
+ # Sort 3
+ res = await self.client.get('full-list', query_params={'sort': 'name,-age'})
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Ali', 'age': 1},
+ {'name': 'Ali', 'age': 0},
+ {'name': 'Saba', 'age': 1},
+ {'name': 'Saba', 'age': 0},
+ ]
+
+ # Filter 1
+ res = await self.client.get('full-list', query_params={'sort': 'name,-age', 'name': 'Ali'})
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Ali', 'age': 1},
+ {'name': 'Ali', 'age': 0},
+ ]
+
+ # Filter 2
+ res = await self.client.get('full-list', query_params={'sort': 'name,-age', 'name': 'Alex'})
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == []
+
+ # Search
+ res = await self.client.get('full-list', query_params={'sort': 'name,-age', 'search': 'Ali'})
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Ali', 'age': 1},
+ {'name': 'Ali', 'age': 0},
+ ]
+
+ # Pagination 1
+ res = await self.client.get('full-list', query_params={'sort': 'name,-age'})
+ assert res.data['previous'] is None
+ assert res.data['next'] is None
+ assert res.data['count'] == 4
+
+ # Pagination 2
+ res = await self.client.get('full-list', query_params={'limit': 2})
+ assert res.data['previous'] is None
+ assert res.data['next'] == '?limit=2&skip=2'
+ assert res.data['count'] == 4
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Ali', 'age': 0},
+ {'name': 'Ali', 'age': 1},
+ ]
+
+ res = await self.client.get('full-list', query_params={'limit': 2, 'skip': 2})
+ assert res.data['previous'] == '?limit=2&skip=0'
+ assert res.data['next'] is None
+ assert res.data['count'] == 4
+ response = [{'name': r['name'], 'age': r['age']} for r in res.data['results']]
+ assert response == [
+ {'name': 'Saba', 'age': 0},
+ {'name': 'Saba', 'age': 1},
+ ]
+
+ async def test_update(self):
+ users = await User.insert_many([{'name': 'Ali'}, {'name': 'Hamed'}])
+ res = await self.client.put(f'update/{users[1].id}', payload={'name': 'NewName'})
+ assert res.status_code == 200
+ assert res.data['name'] == 'NewName'
+
+ new_users = await User.find()
+ users[1].name = 'NewName'
+ assert {(u.id, u.name) for u in new_users} == {(u.id, u.name) for u in users}
+
+ async def test_create(self):
+ res = await self.client.post('create', payload={'name': 'Sara'})
+ assert res.status_code == 201
+ assert res.data['name'] == 'Sara'
+
+ new_users = await User.find()
+        assert len([i for i in new_users]) == 1
+ assert new_users[0].name == 'Sara'
+
+ async def test_delete(self):
+ users = await User.insert_many([{'name': 'Ali'}, {'name': 'Hamed'}])
+ res = await self.client.delete(f'delete/{users[1].id}')
+ assert res.status_code == 204
+ new_users = await User.find()
+ assert len([u for u in new_users]) == 1
+ assert new_users[0].model_dump() == users[0].model_dump()
diff --git a/tests/test_response.py b/tests/test_response.py
new file mode 100644
index 0000000..f00cc54
--- /dev/null
+++ b/tests/test_response.py
@@ -0,0 +1,456 @@
+from unittest import IsolatedAsyncioTestCase
+
+from panther import Panther
+from panther.app import API, GenericAPI
+from panther.response import Response, HTMLResponse, PlainTextResponse, StreamingResponse
+from panther.test import APIClient
+
+
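+# Endpoints covering the supported return types; most are defined both as function-based and class-based APIs.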
+@API()
+async def return_nothing():
+ pass
+
+
+class ReturnNothing(GenericAPI):
+ def get(self):
+ pass
+
+
+@API()
+async def return_none():
+ return None
+
+
+class ReturnNone(GenericAPI):
+ def get(self):
+ return None
+
+
+@API()
+async def return_string():
+ return 'Hello'
+
+
+class ReturnString(GenericAPI):
+ def get(self):
+ return 'Hello'
+
+
+@API()
+async def return_dict():
+ return {'detail': 'ok'}
+
+
+class ReturnDict(GenericAPI):
+ def get(self):
+ return {'detail': 'ok'}
+
+
+@API()
+async def return_list():
+ return [1, 2, 3]
+
+
+class ReturnList(GenericAPI):
+ def get(self):
+ return [1, 2, 3]
+
+
+@API()
+async def return_tuple():
+ return 1, 2, 3, 4
+
+
+class ReturnTuple(GenericAPI):
+ def get(self):
+ return 1, 2, 3, 4
+
+
+@API()
+async def return_response_none():
+ return Response()
+
+
+class ReturnResponseNone(GenericAPI):
+ def get(self):
+ return Response()
+
+
+@API()
+async def return_response_dict():
+ return Response(data={'detail': 'ok'})
+
+
+class ReturnResponseDict(GenericAPI):
+ def get(self):
+ return Response(data={'detail': 'ok'})
+
+
+@API()
+async def return_response_list():
+ return Response(data=['car', 'home', 'phone'])
+
+
+class ReturnResponseList(GenericAPI):
+ def get(self):
+ return Response(data=['car', 'home', 'phone'])
+
+
+@API()
+async def return_response_tuple():
+ return Response(data=('car', 'home', 'phone', 'book'))
+
+
+class ReturnResponseTuple(GenericAPI):
+ def get(self):
+ return Response(data=('car', 'home', 'phone', 'book'))
+
+
+@API()
+async def return_html_response():
+ return HTMLResponse('')
+
+
+class ReturnHTMLResponse(GenericAPI):
+ def get(self):
+ return HTMLResponse('')
+
+
+@API()
+async def return_plain_response():
+ return PlainTextResponse('Hello World')
+
+
+class ReturnPlainResponse(GenericAPI):
+ def get(self):
+ return PlainTextResponse('Hello World')
+
+
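+# Streaming endpoints backed by a sync generator and an async generator.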
+class ReturnStreamingResponse(GenericAPI):
+ def get(self):
+ def f():
+ for i in range(5):
+ yield i
+ return StreamingResponse(f())
+
+
+class ReturnAsyncStreamingResponse(GenericAPI):
+ async def get(self):
+ async def f():
+ for i in range(6):
+ yield i
+ return StreamingResponse(f())
+
+
+class ReturnInvalidStatusCode(GenericAPI):
+ def get(self):
+ return Response(status_code='ali')
+
+
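+# Routes used by the TestResponses test case below.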
+urls = {
+ 'nothing': return_nothing,
+ 'none': return_none,
+ 'dict': return_dict,
+ 'str': return_string,
+ 'list': return_list,
+ 'tuple': return_tuple,
+ 'response-none': return_response_none,
+ 'response-dict': return_response_dict,
+ 'response-list': return_response_list,
+ 'response-tuple': return_response_tuple,
+ 'html': return_html_response,
+ 'plain': return_plain_response,
+
+ 'nothing-cls': ReturnNothing,
+ 'none-cls': ReturnNone,
+ 'dict-cls': ReturnDict,
+ 'str-cls': ReturnString,
+ 'list-cls': ReturnList,
+ 'tuple-cls': ReturnTuple,
+ 'response-none-cls': ReturnResponseNone,
+ 'response-dict-cls': ReturnResponseDict,
+ 'response-list-cls': ReturnResponseList,
+ 'response-tuple-cls': ReturnResponseTuple,
+ 'html-cls': ReturnHTMLResponse,
+ 'plain-cls': ReturnPlainResponse,
+
+ 'stream': ReturnStreamingResponse,
+ 'async-stream': ReturnAsyncStreamingResponse,
+ 'invalid-status-code': ReturnInvalidStatusCode,
+}
+
+
+class TestResponses(IsolatedAsyncioTestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ app = Panther(__name__, configs=__name__, urls=urls)
+ cls.client = APIClient(app=app)
+
+ async def test_nothing(self):
+ res = await self.client.get('nothing/')
+ assert res.status_code == 200
+ assert res.data is None
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '0'
+
+ async def test_nothing_cls(self):
+ res = await self.client.get('nothing-cls/')
+ assert res.status_code == 200
+ assert res.data is None
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '0'
+
+ async def test_none(self):
+ res = await self.client.get('none/')
+ assert res.status_code == 200
+ assert res.data is None
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '0'
+
+ async def test_none_cls(self):
+ res = await self.client.get('none-cls/')
+ assert res.status_code == 200
+ assert res.data is None
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '0'
+
+ async def test_dict(self):
+ res = await self.client.get('dict/')
+ assert res.status_code == 200
+ assert res.data == {'detail': 'ok'}
+ assert res.body == b'{"detail":"ok"}'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '15'
+
+ async def test_dict_cls(self):
+ res = await self.client.get('dict-cls/')
+ assert res.status_code == 200
+ assert res.data == {'detail': 'ok'}
+ assert res.body == b'{"detail":"ok"}'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '15'
+
+ async def test_string(self):
+ res = await self.client.get('str/')
+ assert res.status_code == 200
+ assert res.data == 'Hello'
+ assert res.body == b'"Hello"'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '7'
+
+ async def test_string_cls(self):
+ res = await self.client.get('str-cls/')
+ assert res.status_code == 200
+ assert res.data == 'Hello'
+ assert res.body == b'"Hello"'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '7'
+
+ async def test_list(self):
+ res = await self.client.get('list/')
+ assert res.status_code == 200
+ assert res.data == [1, 2, 3]
+ assert res.body == b'[1,2,3]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '7'
+
+ async def test_list_cls(self):
+ res = await self.client.get('list-cls/')
+ assert res.status_code == 200
+ assert res.data == [1, 2, 3]
+ assert res.body == b'[1,2,3]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '7'
+
+ async def test_tuple(self):
+ res = await self.client.get('tuple/')
+ assert res.status_code == 200
+ assert res.data == [1, 2, 3, 4]
+ assert res.body == b'[1,2,3,4]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '9'
+
+ async def test_tuple_cls(self):
+ res = await self.client.get('tuple-cls/')
+ assert res.status_code == 200
+ assert res.data == [1, 2, 3, 4]
+ assert res.body == b'[1,2,3,4]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '9'
+
+ async def test_response_none(self):
+ res = await self.client.get('response-none/')
+ assert res.status_code == 200
+ assert res.data is None
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '0'
+
+ async def test_response_none_cls(self):
+ res = await self.client.get('response-none-cls/')
+ assert res.status_code == 200
+ assert res.data is None
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '0'
+
+ async def test_response_dict(self):
+ res = await self.client.get('response-dict/')
+ assert res.status_code == 200
+ assert res.data == {'detail': 'ok'}
+ assert res.body == b'{"detail":"ok"}'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '15'
+
+ async def test_response_dict_cls(self):
+ res = await self.client.get('response-dict-cls/')
+ assert res.status_code == 200
+ assert res.data == {'detail': 'ok'}
+ assert res.body == b'{"detail":"ok"}'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '15'
+
+ async def test_response_list(self):
+ res = await self.client.get('response-list/')
+ assert res.status_code == 200
+ assert res.data == ['car', 'home', 'phone']
+ assert res.body == b'["car","home","phone"]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '22'
+
+ async def test_response_list_cls(self):
+ res = await self.client.get('response-list-cls/')
+ assert res.status_code == 200
+ assert res.data == ['car', 'home', 'phone']
+ assert res.body == b'["car","home","phone"]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '22'
+
+ async def test_response_tuple(self):
+ res = await self.client.get('response-tuple/')
+ assert res.status_code == 200
+ assert res.data == ['car', 'home', 'phone', 'book']
+ assert res.body == b'["car","home","phone","book"]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '29'
+
+ async def test_response_tuple_cls(self):
+ res = await self.client.get('response-tuple-cls/')
+ assert res.status_code == 200
+ assert res.data == ['car', 'home', 'phone', 'book']
+ assert res.body == b'["car","home","phone","book"]'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '29'
+
+ async def test_response_html(self):
+ res = await self.client.get('html/')
+ assert res.status_code == 200
+ assert res.data == ''
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'text/html; charset=utf-8'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '41'
+
+ async def test_response_html_cls(self):
+ res = await self.client.get('html-cls/')
+ assert res.status_code == 200
+ assert res.data == ''
+ assert res.body == b''
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'text/html; charset=utf-8'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '41'
+
+ async def test_response_plain(self):
+ res = await self.client.get('plain/')
+ assert res.status_code == 200
+ assert res.data == 'Hello World'
+ assert res.body == b'Hello World'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'text/plain; charset=utf-8'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '11'
+
+ async def test_response_plain_cls(self):
+ res = await self.client.get('plain-cls/')
+ assert res.status_code == 200
+ assert res.data == 'Hello World'
+ assert res.body == b'Hello World'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'text/plain; charset=utf-8'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+ assert res.headers['Content-Length'] == '11'
+
+ async def test_streaming_response(self):
+ res = await self.client.get('stream/')
+ assert res.status_code == 200
+ assert res.headers['Content-Type'] == 'application/octet-stream'
+ assert res.data == '01234'
+ assert res.body == b'01234'
+
+ async def test_async_streaming_response(self):
+ res = await self.client.get('async-stream/')
+ assert res.status_code == 200
+ assert res.headers['Content-Type'] == 'application/octet-stream'
+ assert res.data == '012345'
+ assert res.body == b'012345'
+
+ async def test_invalid_status_code(self):
+ with self.assertLogs(level='CRITICAL') as captured:
+ res = await self.client.get('invalid-status-code/')
+
+ assert len(captured.records) == 1
+        assert captured.records[0].getMessage().split('\n')[0] == "Response `status_code` Should Be `int`. (`ali` is <class 'str'>)"
+
+ assert res.status_code == 500
+ assert res.data == {'detail': 'Internal Server Error'}
+ assert res.body == b'{"detail":"Internal Server Error"}'
+ assert set(res.headers.keys()) == {'Content-Type', 'Access-Control-Allow-Origin', 'Content-Length'}
+ assert res.headers['Content-Type'] == 'application/json'
+ assert res.headers['Access-Control-Allow-Origin'] == '*'
+        assert res.headers['Content-Length'] == '34'
diff --git a/tests/test_routing.py b/tests/test_routing.py
index 6ef398e..91da3f7 100644
--- a/tests/test_routing.py
+++ b/tests/test_routing.py
@@ -1,8 +1,9 @@
import random
from unittest import TestCase
+from panther.base_request import BaseRequest
+from panther.exceptions import PantherError
from panther.routings import (
- collect_path_variables,
finalize_urls,
find_endpoint,
flatten_urls,
@@ -13,67 +14,88 @@ class TestRoutingFunctions(TestCase):
def tearDown(self) -> None:
from panther.configs import config
- config['urls'] = {}
+ config.URLS = {}
# Collecting
def test_collect_ellipsis_urls(self):
- urls = {
+ urls1 = {
'user/': {
'/': ...,
- 'profile/': ...,
+ },
+ }
+ urls2 = {
+ 'user/': {
'list/': ...,
},
}
- with self.assertLogs() as captured:
- collected_urls = flatten_urls(urls)
-
- assert len(captured.records) == 3
- assert captured.records[0].getMessage() == "URL Can't Point To Ellipsis. ('user//' -> ...)"
- assert captured.records[1].getMessage() == "URL Can't Point To Ellipsis. ('user/profile/' -> ...)"
- assert captured.records[2].getMessage() == "URL Can't Point To Ellipsis. ('user/list/' -> ...)"
+ try:
+ flatten_urls(urls1)
+ except PantherError as exc:
+ assert exc.args[0] == "URL Can't Point To Ellipsis. ('user//' -> ...)"
+ else:
+ assert False
- assert collected_urls == {}
+ try:
+ flatten_urls(urls2)
+ except PantherError as exc:
+ assert exc.args[0] == "URL Can't Point To Ellipsis. ('user/list/' -> ...)"
+ else:
+ assert False
- def test_collect_None_urls(self): # noqa: N802
- urls = {
+ def test_collect_None_urls(self):
+ urls1 = {
'user/': {
- '/': None,
- 'profile/': None,
'list/': None,
},
}
+ urls2 = {
+ 'user/': {
+ '/': None,
+ },
+ }
- with self.assertLogs() as captured:
- collected_urls = flatten_urls(urls)
-
- assert len(captured.records) == 3
- assert captured.records[0].getMessage() == "URL Can't Point To None. ('user//' -> None)"
- assert captured.records[1].getMessage() == "URL Can't Point To None. ('user/profile/' -> None)"
- assert captured.records[2].getMessage() == "URL Can't Point To None. ('user/list/' -> None)"
+ try:
+ flatten_urls(urls1)
+ except PantherError as exc:
+ assert exc.args[0] == "URL Can't Point To None. ('user/list/' -> None)"
+ else:
+ assert False
- assert collected_urls == {}
+ try:
+ flatten_urls(urls2)
+ except PantherError as exc:
+ assert exc.args[0] == "URL Can't Point To None. ('user//' -> None)"
+ else:
+ assert False
def test_collect_invalid_urls(self):
def temp_func(): pass
- urls = {
+ urls1 = {
'user/': {
'?': temp_func,
- '%^': temp_func,
+ },
+ }
+ urls2 = {
+ 'user/': {
'لیست': temp_func,
},
}
- with self.assertLogs() as captured:
- collected_urls = flatten_urls(urls)
-
- assert len(captured.records) == 3
- assert captured.records[0].getMessage() == "URL Is Not Valid. --> 'user/?/'"
- assert captured.records[1].getMessage() == "URL Is Not Valid. --> 'user/%^/'"
- assert captured.records[2].getMessage() == "URL Is Not Valid. --> 'user/لیست/'"
+ try:
+ flatten_urls(urls1)
+ except PantherError as exc:
+ assert exc.args[0] == "URL Is Not Valid. --> 'user/?/'"
+ else:
+ assert False
- assert collected_urls == {}
+ try:
+ flatten_urls(urls2)
+ except PantherError as exc:
+ assert exc.args[0] == "URL Is Not Valid. --> 'user/لیست/'"
+ else:
+ assert False
def test_collect_empty_url(self):
def temp_func(): pass
@@ -504,13 +526,54 @@ def temp_func(): pass
}
assert finalized_urls == expected_result
+ def test_finalize_urls_with_same_level_path_variables(self):
+ def temp_func(): pass
+
+        urls1 = {
+            'user': {
+                '<user_id>/': temp_func,
+                '<id>/': temp_func,
+            }
+        }
+        urls2 = {
+            'user': {
+                '<user_id>/': {'detail': temp_func},
+                '<id>/': temp_func,
+                '<name>/': {'detail': temp_func},
+                '<age>/': {'detail': temp_func},
+            }
+        }
+
+ try:
+ finalize_urls(flatten_urls(urls1))
+ except PantherError as exc:
+            assert exc.args[0] == (
+                "URLs can't have same-level path variables that point to an endpoint: "
+                "\n\t- /user/<user_id>"
+                "\n\t- /user/<id>"
+            )
+ else:
+ assert False
+
+ try:
+ finalize_urls(flatten_urls(urls2))
+ except PantherError as exc:
+            assert exc.args[0] == (
+                "URLs can't have same-level path variables that point to a dict: "
+                "\n\t- /user/<user_id>"
+                "\n\t- /user/<name>"
+                "\n\t- /user/<age>"
+            )
+ else:
+ assert False
+
# Find Endpoint
def test_find_endpoint_root_url(self):
def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'': temp_func,
}
_func, _ = find_endpoint('')
@@ -534,7 +597,7 @@ def admin_v2_users_detail_not_registered(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'': {
'profile': {
@@ -598,7 +661,7 @@ def admin_v2_users_detail_not_registered(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'': {
'profile': {
@@ -661,7 +724,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'list': temp_func,
},
@@ -687,7 +750,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'list': temp_func,
},
@@ -713,7 +776,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'': temp_func,
},
@@ -739,7 +802,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'': temp_func,
},
@@ -765,7 +828,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user/name': temp_func,
}
func, path = find_endpoint('user/name/troublemaker')
@@ -778,7 +841,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user/name': temp_func,
}
func, path = find_endpoint('user/')
@@ -795,7 +858,7 @@ def temp_3(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'': temp_1,
'': {
'': temp_2,
@@ -819,7 +882,7 @@ def temp_3(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'': temp_1,
'': {
'': temp_2,
@@ -843,7 +906,7 @@ def temp_3(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'': temp_1,
'hello': {
'': temp_2,
@@ -868,7 +931,7 @@ def temp_3(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'': temp_1,
'hello': {
'': temp_2,
@@ -888,7 +951,7 @@ def test_find_endpoint_with_params(self):
def user_id_profile_id(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'': {
'profile': user_id_profile_id,
@@ -905,7 +968,7 @@ def temp_func(): pass
from panther.configs import config
- config['urls'] = {
+ config.URLS = {
'user': {
'': {
'profile': {
@@ -920,7 +983,9 @@ def temp_func(): pass
request_path = f'user/{_user_id}/profile/{_id}'
_, found_path = find_endpoint(request_path)
- path_variables = collect_path_variables(request_path=request_path, found_path=found_path)
+ request = BaseRequest(scope={'path': request_path}, receive=lambda x: x, send=lambda x: x)
+ request.collect_path_variables(found_path=found_path)
+ path_variables = request.path_variables
assert isinstance(path_variables, dict)
diff --git a/tests/test_run.py b/tests/test_run.py
index c05a9eb..52d742e 100644
--- a/tests/test_run.py
+++ b/tests/test_run.py
@@ -30,30 +30,28 @@ def test_load_configs(self):
Panther(__name__)
assert isinstance(config, Config)
- assert config['base_dir'] == base_dir
- assert config['monitoring'] is True
- assert config['log_queries'] is True
- assert config['default_cache_exp'] == timedelta(seconds=10)
- assert config['throttling'].rate == 10
- assert config['throttling'].duration == timedelta(seconds=10)
- assert config['secret_key'] == secret_key.encode()
+ assert config.BASE_DIR == base_dir
+ assert config.MONITORING is True
+ assert config.LOG_QUERIES is True
+ assert config.DEFAULT_CACHE_EXP == timedelta(seconds=10)
+ assert config.THROTTLING.rate == 10
+ assert config.THROTTLING.duration == timedelta(seconds=10)
+ assert config.SECRET_KEY == secret_key.encode()
- assert len(config['http_middlewares']) == 0
- assert len(config['reversed_http_middlewares']) == 0
- assert len(config['ws_middlewares']) == 0
- assert len(config['reversed_ws_middlewares']) == 0
+ assert len(config.HTTP_MIDDLEWARES) == 0
+ assert len(config.WS_MIDDLEWARES) == 0
- assert config['user_model'].__name__ == tests.sample_project.app.models.User.__name__
- assert config['user_model'].__module__.endswith('app.models')
- assert config['jwt_config'].algorithm == 'HS256'
- assert config['jwt_config'].life_time == timedelta(days=2).total_seconds()
- assert config['jwt_config'].key == secret_key
+ assert config.USER_MODEL.__name__ == tests.sample_project.app.models.User.__name__
+ assert config.USER_MODEL.__module__.endswith('app.models')
+ assert config.JWT_CONFIG.algorithm == 'HS256'
+ assert config.JWT_CONFIG.life_time == timedelta(days=2).total_seconds()
+ assert config.JWT_CONFIG.key == secret_key
- assert '' in config['urls']
- config['urls'].pop('')
+ assert '' in config.URLS
+ config.URLS.pop('')
- assert 'second' in config['urls']
- config['urls'].pop('second')
+ assert 'second' in config.URLS
+ config.URLS.pop('second')
urls = {
'_panel': {
@@ -65,5 +63,5 @@ def test_load_configs(self):
'health': healthcheck_api
},
}
- assert config['urls'] == urls
- assert config['query_engine'].__name__ == 'BasePantherDBQuery'
+ assert config.URLS == urls
+ assert config.QUERY_ENGINE.__name__ == 'BasePantherDBQuery'
diff --git a/tests/test_serializer.py b/tests/test_serializer.py
index 3a776b6..3232822 100644
--- a/tests/test_serializer.py
+++ b/tests/test_serializer.py
@@ -233,7 +233,7 @@ class Config:
fields = ['ok', 'no']
except Exception as e:
assert isinstance(e, AttributeError)
- assert e.args[0] == '`Serializer3.Config.fields.ok` is not valid.'
+ assert e.args[0] == '`Serializer3.Config.fields.ok` is not in `Book.model_fields`'
else:
assert False
@@ -262,6 +262,48 @@ class Config:
else:
assert False
+ async def test_define_class_with_invalid_exclude_1(self):
+ try:
+ class Serializer6(ModelSerializer):
+ class Config:
+ model = Book
+ fields = ['name', 'author', 'pages_count']
+ exclude = ['not_found']
+
+ except Exception as e:
+ assert isinstance(e, AttributeError)
+ assert e.args[0] == '`Serializer6.Config.exclude.not_found` is not valid.'
+ else:
+ assert False
+
+ async def test_define_class_with_invalid_exclude_2(self):
+ try:
+ class Serializer7(ModelSerializer):
+ class Config:
+ model = Book
+ fields = ['name', 'pages_count']
+ exclude = ['author']
+
+ except Exception as e:
+ assert isinstance(e, AttributeError)
+ assert e.args[0] == '`Serializer7.Config.exclude.author` is not defined in `Config.fields`.'
+ else:
+ assert False
+
+ async def test_with_star_fields_with_exclude3(self):
+ try:
+ class Serializer8(ModelSerializer):
+ class Config:
+ model = Book
+ fields = ['*']
+ exclude = ['author']
+
+ except Exception as e:
+ assert isinstance(e, AttributeError)
+ assert e.args[0] == "`Serializer8.Config.fields.*` is not valid. Did you mean `fields = '*'`"
+ else:
+ assert False
+
# # # Serializer Usage
async def test_with_simple_model_config(self):
class Serializer(ModelSerializer):
@@ -319,3 +361,39 @@ class Config:
serialized = Serializer2(name='book', author='AliRn', pages_count='12')
assert serialized.__doc__ is None
+
+ async def test_with_exclude(self):
+ class Serializer(ModelSerializer):
+ class Config:
+ model = Book
+ fields = ['name', 'author', 'pages_count']
+ exclude = ['author']
+
+ serialized = Serializer(name='book', author='AliRn', pages_count='12')
+ assert set(serialized.model_dump().keys()) == {'name', 'pages_count'}
+ assert serialized.name == 'book'
+ assert serialized.pages_count == 12
+
+ async def test_with_star_fields(self):
+ class Serializer(ModelSerializer):
+ class Config:
+ model = Book
+ fields = '*'
+
+ serialized = Serializer(name='book', author='AliRn', pages_count='12')
+ assert set(serialized.model_dump().keys()) == {'id', 'name', 'author', 'pages_count'}
+ assert serialized.name == 'book'
+ assert serialized.author == 'AliRn'
+ assert serialized.pages_count == 12
+
+ async def test_with_star_fields_with_exclude(self):
+ class Serializer(ModelSerializer):
+ class Config:
+ model = Book
+ fields = '*'
+ exclude = ['author']
+
+ serialized = Serializer(name='book', author='AliRn', pages_count='12')
+ assert set(serialized.model_dump().keys()) == {'id', 'name', 'pages_count'}
+ assert serialized.name == 'book'
+ assert serialized.pages_count == 12
diff --git a/tests/test_simple_responses.py b/tests/test_simple_responses.py
deleted file mode 100644
index 9499d2d..0000000
--- a/tests/test_simple_responses.py
+++ /dev/null
@@ -1,116 +0,0 @@
-from unittest import IsolatedAsyncioTestCase
-
-from panther import Panther
-from panther.app import API
-from panther.response import Response
-from panther.test import APIClient
-
-
-@API()
-async def return_nothing():
- pass
-
-
-@API()
-async def return_none():
- return None
-
-
-@API()
-async def return_dict():
- return {'detail': 'ok'}
-
-
-@API()
-async def return_list():
- return [1, 2, 3]
-
-
-@API()
-async def return_tuple():
- return 1, 2, 3, 4
-
-
-@API()
-async def return_response_none():
- return Response()
-
-
-@API()
-async def return_response_dict():
- return Response(data={'detail': 'ok'})
-
-
-@API()
-async def return_response_list():
- return Response(data=['car', 'home', 'phone'])
-
-
-@API()
-async def return_response_tuple():
- return Response(data=('car', 'home', 'phone', 'book'))
-
-
-urls = {
- 'nothing': return_nothing,
- 'none': return_none,
- 'dict': return_dict,
- 'list': return_list,
- 'tuple': return_tuple,
- 'response-none': return_response_none,
- 'response-dict': return_response_dict,
- 'response-list': return_response_list,
- 'response-tuple': return_response_tuple,
-}
-
-
-class TestSimpleResponses(IsolatedAsyncioTestCase):
- @classmethod
- def setUpClass(cls) -> None:
- app = Panther(__name__, configs=__name__, urls=urls)
- cls.client = APIClient(app=app)
-
- async def test_nothing(self):
- res = await self.client.get('nothing/')
- assert res.status_code == 200
- assert res.data is None
-
- async def test_none(self):
- res = await self.client.get('none/')
- assert res.status_code == 200
- assert res.data is None
-
- async def test_dict(self):
- res = await self.client.get('dict/')
- assert res.status_code == 200
- assert res.data == {'detail': 'ok'}
-
- async def test_list(self):
- res = await self.client.get('list/')
- assert res.status_code == 200
- assert res.data == [1, 2, 3]
-
- async def test_tuple(self):
- res = await self.client.get('tuple/')
- assert res.status_code == 200
- assert res.data == [1, 2, 3, 4]
-
- async def test_response_none(self):
- res = await self.client.get('response-none/')
- assert res.status_code == 200
- assert res.data is None
-
- async def test_response_dict(self):
- res = await self.client.get('response-dict/')
- assert res.status_code == 200
- assert res.data == {'detail': 'ok'}
-
- async def test_response_list(self):
- res = await self.client.get('response-list/')
- assert res.status_code == 200
- assert res.data == ['car', 'home', 'phone']
-
- async def test_response_tuple(self):
- res = await self.client.get('response-tuple/')
- assert res.status_code == 200
- assert res.data == ['car', 'home', 'phone', 'book']
diff --git a/tests/test_throttling.py b/tests/test_throttling.py
new file mode 100644
index 0000000..10034c5
--- /dev/null
+++ b/tests/test_throttling.py
@@ -0,0 +1,74 @@
+import asyncio
+from datetime import timedelta
+from unittest import IsolatedAsyncioTestCase
+
+from panther import Panther
+from panther.app import API
+from panther.test import APIClient
+from panther.throttling import Throttling
+
+
+@API()
+async def without_throttling_api():
+ return 'ok'
+
+
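+# Allows at most 3 requests per 3-second window on this endpoint.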
+@API(throttling=Throttling(rate=3, duration=timedelta(seconds=3)))
+async def with_throttling_api():
+ return 'ok'
+
+
+urls = {
+ 'without-throttling': without_throttling_api,
+ 'with-throttling': with_throttling_api,
+}
+
+
+class TestThrottling(IsolatedAsyncioTestCase):
+ @classmethod
+ def setUpClass(cls) -> None:
+ app = Panther(__name__, configs=__name__, urls=urls)
+ cls.client = APIClient(app=app)
+
+ async def test_without_throttling(self):
+ res1 = await self.client.get('without-throttling')
+ assert res1.status_code == 200
+
+ res2 = await self.client.get('without-throttling')
+ assert res2.status_code == 200
+
+ res3 = await self.client.get('without-throttling')
+ assert res3.status_code == 200
+
+ async def test_with_throttling(self):
+ res1 = await self.client.get('with-throttling')
+ assert res1.status_code == 200
+
+ res2 = await self.client.get('with-throttling')
+ assert res2.status_code == 200
+
+ res3 = await self.client.get('with-throttling')
+ assert res3.status_code == 200
+
+ res4 = await self.client.get('with-throttling')
+ assert res4.status_code == 429
+
+ res5 = await self.client.get('with-throttling')
+ assert res5.status_code == 429
+
+ await asyncio.sleep(3) # Sleep and try again
+
+ res6 = await self.client.get('with-throttling')
+ assert res6.status_code == 200
+
+ res7 = await self.client.get('with-throttling')
+ assert res7.status_code == 200
+
+ res8 = await self.client.get('with-throttling')
+ assert res8.status_code == 200
+
+ res9 = await self.client.get('with-throttling')
+ assert res9.status_code == 429
+
+ res10 = await self.client.get('with-throttling')
+ assert res10.status_code == 429
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 36d5e6b..083e467 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,11 +3,10 @@
from pathlib import Path
from unittest import TestCase
-import panther.utils
from panther import Panther
from panther.configs import config
from panther.middlewares import BaseMiddleware
-from panther.utils import generate_hash_value_from_string, load_env, round_datetime, encrypt_password
+from panther.utils import generate_secret_key, generate_hash_value_from_string, load_env, round_datetime
class TestLoadEnvFile(TestCase):
@@ -25,11 +24,10 @@ def _create_env_file(self, file_data):
file.write(file_data)
def test_load_env_invalid_file(self):
- with self.assertLogs(level='ERROR') as captured:
+ try:
load_env('fake.file')
-
- assert len(captured.records) == 1
- assert captured.records[0].getMessage() == '"fake.file" is not valid file for load_env()'
+ except ValueError as e:
+ assert e.args[0] == '"fake.file" is not a file.'
def test_load_env_double_quote(self):
self._create_env_file(f"""
@@ -179,11 +177,6 @@ def test_generate_hash_value_from_string(self):
assert hashed_1 == hashed_2
assert text != hashed_1
- def test_encrypt_password(self):
- password = 'Password'
- encrypted = encrypt_password(password=password)
- assert password != encrypted
-
class TestLoadConfigs(TestCase):
def setUp(self):
@@ -202,7 +195,7 @@ def test_urls_not_found(self):
assert False
assert len(captured.records) == 1
- assert captured.records[0].getMessage() == "Invalid 'URLs': is required."
+ assert captured.records[0].getMessage() == "Invalid 'URLs': required."
def test_urls_cant_be_dict(self):
global URLs
@@ -373,7 +366,7 @@ def test_jwt_auth_without_secret_key(self):
def test_jwt_auth_with_secret_key(self):
global AUTHENTICATION, SECRET_KEY
AUTHENTICATION = 'panther.authentications.JWTAuthentication'
- SECRET_KEY = panther.utils.generate_secret_key()
+ SECRET_KEY = generate_secret_key()
with self.assertNoLogs(level='ERROR'):
try:
diff --git a/tests/test_websockets.py b/tests/test_websockets.py
index 35d1464..da8d310 100644
--- a/tests/test_websockets.py
+++ b/tests/test_websockets.py
@@ -131,7 +131,7 @@ async def connect(self):
class TestWebsocket(TestCase):
@classmethod
def setUpClass(cls) -> None:
- config['has_ws'] = True
+ config.HAS_WS = True
cls.app = Panther(__name__, configs=__name__, urls=urls)
def test_without_accept(self):
@@ -266,7 +266,7 @@ def test_with_auth_failed(self):
assert responses[0]['type'] == 'websocket.close'
assert responses[0]['code'] == 1000
- assert responses[0]['reason'] == 'Authentication Error'
+ assert responses[0]['reason'] == ''
def test_with_auth_not_defined(self):
ws = WebsocketClient(app=self.app)
@@ -274,11 +274,11 @@ def test_with_auth_not_defined(self):
responses = ws.connect('with-auth?authorization=Bearer token')
assert len(captured.records) == 1
- assert captured.records[0].getMessage() == '"WS_AUTHENTICATION" has not been set in configs'
+ assert captured.records[0].getMessage() == '`WS_AUTHENTICATION` has not been set in configs'
assert responses[0]['type'] == 'websocket.close'
assert responses[0]['code'] == 1000
- assert responses[0]['reason'] == 'Authentication Error'
+ assert responses[0]['reason'] == ''
def test_with_auth_success(self):
global WS_AUTHENTICATION, SECRET_KEY, DATABASE
@@ -304,7 +304,7 @@ def test_with_auth_success(self):
assert responses[0]['type'] == 'websocket.close'
assert responses[0]['code'] == 1000
- assert responses[0]['reason'] == 'Authentication Error'
+ assert responses[0]['reason'] == ''
def test_with_permission(self):
ws = WebsocketClient(app=self.app)
@@ -327,4 +327,4 @@ def test_without_permission(self):
assert responses[0]['type'] == 'websocket.close'
assert responses[0]['code'] == 1000
- assert responses[0]['reason'] == 'Permission Denied'
+ assert responses[0]['reason'] == ''