diff --git a/.github/workflows/pull_request.yaml b/.github/workflows/pull_request.yaml index c317a282..739f2688 100644 --- a/.github/workflows/pull_request.yaml +++ b/.github/workflows/pull_request.yaml @@ -23,8 +23,24 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt + - name: Install pymongo + run: | + pip uninstall -y bson + pip install pymongo + + - name: Start MongoDB + uses: supercharge/mongodb-github-action@1.10.0 + with: + mongodb-version: '4.4' + + - name: Remove Build Directory + run: rm -rf build/ + - name: Run Tests - run: pytest tests + run: coverage run -m pytest tests + + - name: Show Coverage + run: coverage report tests-windows: name: Test Windows @@ -46,7 +62,7 @@ jobs: pip install -r requirements.txt - name: Run Tests - run: pytest tests + run: pytest -m "not mongodb" tests tests-macos: name: Test MacOS @@ -68,4 +84,4 @@ jobs: pip install -r requirements.txt - name: Run Tests - run: pytest tests + run: pytest -m "not mongodb" tests diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml index 967d43ef..752887fb 100644 --- a/.github/workflows/release.yaml +++ b/.github/workflows/release.yaml @@ -30,7 +30,7 @@ jobs: pip install -r requirements.txt - name: Run Tests - run: pytest tests + run: pytest -m "not mongodb" tests tests-windows: name: Test Windows Python${{ matrix.python-version }} @@ -55,7 +55,7 @@ jobs: pip install -r requirements.txt - name: Run Tests - run: pytest tests + run: pytest -m "not mongodb" tests tests-macos: name: Test MacOS Python${{ matrix.python-version }} @@ -80,12 +80,46 @@ jobs: pip install -r requirements.txt - name: Run Tests - run: pytest tests + run: pytest -m "not mongodb" tests + + tests-mongodb: + name: Test MongoDB ${{ matrix.python-version }} + runs-on: ubuntu-latest + strategy: + matrix: + mongodb-version: ['4.2', '4.4', '5.0', '6.0'] + + steps: + - name: Checkout source + uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.11' + + - name: Install Requirements + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + - name: Install pymongo + run: | + pip uninstall -y bson + pip install pymongo + + - name: Start MongoDB + uses: supercharge/mongodb-github-action@1.10.0 + with: + mongodb-version: ${{ matrix.mongodb-version }} + + - name: Run Tests + run: pytest -m "mongodb" tests coverage: name: Coverage runs-on: ubuntu-latest - needs: [tests-linux, tests-macos, tests-windows] + needs: [tests-linux, tests-mongodb, tests-macos, tests-windows] steps: - name: Checkout source @@ -101,8 +135,24 @@ jobs: python -m pip install --upgrade pip pip install -r requirements.txt + - name: Install pymongo + run: | + pip uninstall -y bson + pip install pymongo + + - name: Start MongoDB + uses: supercharge/mongodb-github-action@1.10.0 + with: + mongodb-version: '4.4' + + - name: Remove Build Directory + run: rm -rf build/ + - name: Run Coverage Tests - run: coverage run -m unittest tests/test_*.py + run: coverage run -m pytest tests + + - name: Show Coverage + run: coverage report - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/panther/_load_configs.py b/panther/_load_configs.py index b0e3bf64..86aea90d 100644 --- a/panther/_load_configs.py +++ b/panther/_load_configs.py @@ -132,6 +132,8 @@ def load_jwt_config(configs: dict, /) -> JWTConfig | None: if getattr(config['authentication'], '__name__', None) == 'JWTAuthentication': user_config = configs.get('JWTConfig', {}) if 'key' not in user_config: + if config['secret_key'] is None: + raise PantherException('"SECRET_KEY" is required when using "JWTAuthentication"') user_config['key'] = config['secret_key'].decode() return JWTConfig(**user_config) diff --git a/panther/_utils.py b/panther/_utils.py index 6538671c..bb3f5dee 100644 --- a/panther/_utils.py +++ b/panther/_utils.py @@ -121,28 +121,6 @@ def is_function_async(func: Callable) -> bool: return bool(func.__code__.co_flags & (1 << 7)) -def run_sync_async_function(func: Callable): - # Async - if is_function_async(func): - # Get Event Loop - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - # Add Coroutine To Event Loop - if loop and loop.is_running(): - loop.create_task(func()) - - # Start New Event Loop - else: - asyncio.run(func()) - - # Sync - else: - func() - - def clean_traceback_message(exception: Exception) -> str: """We are ignoring packages traceback message""" tb = TracebackException(type(exception), exception, exception.__traceback__) diff --git a/panther/app.py b/panther/app.py index 9745cc32..73c7ed33 100644 --- a/panther/app.py +++ b/panther/app.py @@ -119,7 +119,8 @@ def handle_authentications(self) -> None: auth_class = config['authentication'] if self.auth: if not auth_class: - raise TypeError('"AUTHENTICATION" has not been set in core/configs') + logger.critical('"AUTHENTICATION" has not been set in configs') + raise APIException user = auth_class.authentication(self.request) self.request.set_user(user=user) diff --git a/panther/authentications.py b/panther/authentications.py index 80e8401b..e88627f1 100644 --- a/panther/authentications.py +++ b/panther/authentications.py @@ -46,7 +46,11 @@ def get_authorization_header(cls, request: Request) -> bytes: raise cls.exception(msg) from None if isinstance(auth, str): - auth = auth.encode(JWTAuthentication.HTTP_HEADER_ENCODING) + try: + auth = auth.encode(JWTAuthentication.HTTP_HEADER_ENCODING) + except UnicodeEncodeError as e: + raise cls.exception(e) from None + return auth @classmethod @@ -126,6 +130,6 @@ def login(cls, user_id: IDType) -> str: return cls.encode_jwt(user_id=user_id) @staticmethod - def exception(message: str | JWTError, /) -> type[AuthenticationException]: + def exception(message: str | JWTError | UnicodeEncodeError, /) -> type[AuthenticationException]: logger.error(f'JWT Authentication Error: "{message}"') return AuthenticationException diff --git a/panther/configs.py b/panther/configs.py index 19372764..bf61f008 100644 --- a/panther/configs.py +++ b/panther/configs.py @@ -63,8 +63,10 @@ class Config(TypedDict): 'default_cache_exp': None, 'throttling': None, 'secret_key': None, - 'middlewares': [], - 'reversed_middlewares': [], + 'http_middlewares': [], + 'ws_middlewares': [], + 'reversed_http_middlewares': [], + 'reversed_ws_middlewares': [], 'user_model': None, 'authentication': None, 'jwt_config': None, diff --git a/panther/db/connection.py b/panther/db/connection.py index 67eb9f5e..7e80609a 100644 --- a/panther/db/connection.py +++ b/panther/db/connection.py @@ -45,8 +45,11 @@ def name(self) -> str: return self._db_name def _create_mongodb_session(self, db_url: str) -> None: - from pymongo import MongoClient - + try: + from pymongo import MongoClient + except ModuleNotFoundError: + msg = "No module named 'pymongo'. Hint: `pip install pymongo`" + raise ValueError(msg) self._client: MongoClient = MongoClient(db_url) self._session: Database = self._client.get_database() diff --git a/panther/db/queries/mongodb_queries.py b/panther/db/queries/mongodb_queries.py index 7df28530..6e8732aa 100644 --- a/panther/db/queries/mongodb_queries.py +++ b/panther/db/queries/mongodb_queries.py @@ -86,4 +86,4 @@ def update_many(cls, _filter: dict, _data: dict | None = None, /, **kwargs) -> i update_fields = {'$set': cls._merge(_data, kwargs)} result = db.session[cls.__name__].update_many(_filter, update_fields) - return result.updated_count + return result.modified_count diff --git a/panther/db/utils.py b/panther/db/utils.py index 34b6f184..eb03aa48 100644 --- a/panther/db/utils.py +++ b/panther/db/utils.py @@ -13,11 +13,6 @@ def log_query(func): def log(*args, **kwargs): - # Check Database Connection - if config['db_engine'] == '': - msg = "You don't have active database connection, Check your middlewares" - raise NotImplementedError(msg) - if config['log_queries'] is False: return func(*args, **kwargs) start = perf_counter() diff --git a/panther/main.py b/panther/main.py index 91c1e7b8..545b45f0 100644 --- a/panther/main.py +++ b/panther/main.py @@ -113,7 +113,7 @@ def _create_ws_connections_instance(self): if config['has_ws']: config['websocket_connections'] = WebsocketConnections() # Websocket Redis Connection - for middleware in config['middlewares']: + for middleware in config['http_middlewares']: if middleware.__class__.__name__ == 'RedisMiddleware': self.ws_redis_connection = middleware.redis_connection_for_ws() break diff --git a/panther/response.py b/panther/response.py index b554ed7e..3fbe2bfd 100644 --- a/panther/response.py +++ b/panther/response.py @@ -84,6 +84,13 @@ def _serialize_with_output_model(cls, data: any, /, output_model: ModelMetaclass msg = 'Type of Response data is not match with `output_model`.\n*hint: You may want to remove `output_model`' raise TypeError(msg) + def __str__(self): + if len(data := str(self.data)) > 30: + data = f'{data:.27}...' + return f'Response(status_code={self.status_code}, data={data})' + + __repr__ = __str__ + class HTMLResponse(Response): content_type = 'text/html; charset=utf-8' diff --git a/pyproject.toml b/pyproject.toml index aa2a8b05..7a41fc38 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,4 +3,9 @@ requires = [ "setuptools>=42", "wheel" ] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + +[tool.pytest.ini_options] +markers = [ + "mongodb: marks mongodb tests" +] diff --git a/requirements.txt b/requirements.txt index b9e42924..478b61b5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,4 @@ python-jose faker coverage pytest -cryptography~=41.0 \ No newline at end of file +cryptography~=41.0 diff --git a/tests/test_authentication.py b/tests/test_authentication.py new file mode 100644 index 00000000..a53be6b9 --- /dev/null +++ b/tests/test_authentication.py @@ -0,0 +1,186 @@ +import asyncio +from pydantic import BaseModel as PydanticBaseModel, Field, field_validator +from pathlib import Path +from unittest import TestCase + +from panther import Panther +from panther.app import API +from panther.configs import config +from panther.db.queries import Query +from panther.db.queries.pantherdb_queries import BasePantherDBQuery +from panther.request import Request +from panther.test import APIClient + + +@API() +async def without_auth(request: Request): + return request.user + + +@API(auth=True) +async def auth_required(request: Request): + return request.user + + +urls = { + 'without': without_auth, + 'auth-required': auth_required, +} + + +class Model(PydanticBaseModel): + id: int | None = Field(None, validation_alias='_id') + + @field_validator('id', mode='before') + def validate_id(cls, value): + return value + + @property + def _id(self): + return self.id + + +class User(Model, BasePantherDBQuery, Query): + username: str + password: str + + +AUTHENTICATION = 'panther.authentications.JWTAuthentication' +SECRET_KEY = 'hvdhRspoTPh1cJVBHcuingQeOKNc1uRhIP2k7suLe2g=' +DB_PATH = 'test.pdb' +MIDDLEWARES = [ + ('panther.middlewares.db.DatabaseMiddleware', {'url': f'pantherdb://{DB_PATH}'}), +] +USER_MODEL = 'tests.test_authentication.User' + + +class TestAuthentication(TestCase): + SHORT_TOKEN = {'Authorization': 'TOKEN'} + NOT_ENOUGH_SEGMENT_TOKEN = {'Authorization': 'Bearer XXX'} + JUST_BEARER_TOKEN = {'Authorization': 'Bearer'} + BAD_UNICODE_TOKEN = {'Authorization': 'Bearer علی'} + BAD_SIGNATURE_TOKEN = {'Authorization': 'Bearer eyJhbGciOiJIUzI1NiJ9.eyJpZCI6MX0.JAWUkAU2mWhxcd6MS8r9pd44yBIfkEBmpr3WLeqIccM'} + TOKEN_WITHOUT_USER_ID = {'Authorization': 'Bearer eyJhbGciOiJIUzI1NiJ9.eyJpZCI6MX0.PpyXW0PgmGSPaaNirm_Ei4Y2fw9nb4TN26RN1u9RHSo'} + TOKEN = {'Authorization': 'Bearer eyJhbGciOiJIUzI1NiJ9.eyJ1c2VyX2lkIjoxfQ.AF3nsj8IQ6t0ncqIx4quoyPfYaZ-pqUOW4z_euUztPM'} + + @classmethod + def setUpClass(cls) -> None: + app = Panther(__name__, configs=__name__, urls=urls) + cls.client = APIClient(app=app) + + def setUp(self) -> None: + for middleware in config['http_middlewares']: + asyncio.run(middleware.before(request=None)) + + def tearDown(self) -> None: + for middleware in config['reversed_http_middlewares']: + asyncio.run(middleware.after(response=None)) + Path(DB_PATH).unlink() + + def test_user_without_auth(self): + res = self.client.get('without') + assert res.status_code == 200 + assert res.data is None + + res = self.client.get('without', headers={'Authorization': 'Token'}) + assert res.status_code == 200 + assert res.data is None + + def test_user_auth_required_without_auth_class(self): + auth_config = config['authentication'] + config['authentication'] = None + with self.assertLogs(level='CRITICAL') as captured: + res = self.client.get('auth-required') + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == '"AUTHENTICATION" has not been set in configs' + assert res.status_code == 500 + assert res.data['detail'] == 'Internal Server Error' + config['authentication'] = auth_config + + def test_user_auth_required_without_token(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required') + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "Authorization is required"' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_bad_token_1(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.SHORT_TOKEN) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "Authorization keyword is not valid"' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_bad_token2(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.JUST_BEARER_TOKEN) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "Authorization should have 2 part"' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_bad_token3(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.BAD_UNICODE_TOKEN) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == ( + 'JWT Authentication Error: "\'latin-1\' codec can\'t encode characters in position 7-9: ' + 'ordinal not in range(256)"' + ) + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_bad_token4(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.NOT_ENOUGH_SEGMENT_TOKEN) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "Not enough segments"' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_invalid_token_signature(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.BAD_SIGNATURE_TOKEN) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "Signature verification failed."' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_token_without_user_id(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.TOKEN_WITHOUT_USER_ID) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "Payload does not have user_id"' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_token_user_not_found(self): + with self.assertLogs(level='ERROR') as captured: + res = self.client.get('auth-required', headers=self.TOKEN) + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == 'JWT Authentication Error: "User not found"' + assert res.status_code == 401 + assert res.data['detail'] == 'Authentication Error' + + def test_user_auth_required_with_token(self): + User.insert_one(username='Username', password='Password') + + with self.assertNoLogs(level='ERROR'): + res = self.client.get('auth-required', headers=self.TOKEN) + + assert res.status_code == 200 + assert [*res.data.keys()] == ['id', 'username', 'password'] + assert res.data['id'] == 1 + assert res.data['username'] == 'Username' + assert res.data['password'] == 'Password' + diff --git a/tests/test_background_tasks.py b/tests/test_background_tasks.py index 838b7586..b59aaf0a 100644 --- a/tests/test_background_tasks.py +++ b/tests/test_background_tasks.py @@ -155,5 +155,3 @@ def func(_numbers): time.sleep(3) self.assertEqual(len(numbers), 2) - -# TODO: Run tests for every_minutes(), every_hours(), every_days(), every_weeks(), at() diff --git a/tests/test_pantherdb.py b/tests/test_database.py similarity index 75% rename from tests/test_pantherdb.py rename to tests/test_database.py index abfb37ed..297862e4 100644 --- a/tests/test_pantherdb.py +++ b/tests/test_database.py @@ -3,19 +3,23 @@ from pathlib import Path from unittest import TestCase +import bson import faker +import pytest from pydantic import BaseModel as PydanticBaseModel from pydantic import Field, field_validator from panther import Panther from panther.configs import config +from panther.db.connection import db from panther.db.queries import Query from panther.db.queries.pantherdb_queries import BasePantherDBQuery +from panther.exceptions import DBException f = faker.Faker() -class Model(PydanticBaseModel): +class PantherDBModel(PydanticBaseModel): id: int | None = Field(None, validation_alias='_id') @field_validator('id', mode='before') @@ -27,31 +31,35 @@ def _id(self): return self.id -class Book(Model, BasePantherDBQuery, Query): - name: str - author: str - pages_count: int +class MongoDBModel(PydanticBaseModel): + id: str | None = Field(None, validation_alias='_id') + @field_validator('id', mode='before') + def validate_id(cls, value): + if isinstance(value, str): + try: + bson.ObjectId(value) + except bson.objectid.InvalidId as e: + msg = 'Invalid ObjectId' + raise ValueError(msg) from e + elif not isinstance(value, bson.ObjectId): + msg = 'ObjectId required' + raise ValueError(msg) from None + value = str(value) + return value -DB_PATH = 'test.pdb' -MIDDLEWARES = [ - ('panther.middlewares.db.DatabaseMiddleware', {'url': f'pantherdb://{DB_PATH}'}), -] + @property + def _id(self): + return bson.ObjectId(self.id) if self.id else None -class TestPantherDB(TestCase): - @classmethod - def setUpClass(cls) -> None: - Panther(__name__, configs=__name__, urls={}) +class BaseBook: + name: str + author: str + pages_count: int - def setUp(self) -> None: - for middleware in config['middlewares']: - asyncio.run(middleware.before(request=None)) - def tearDown(self) -> None: - for middleware in config['reversed_middlewares']: - asyncio.run(middleware.after(response=None)) - Path(DB_PATH).unlink() +class _BaseDatabaseTestCase: # # # Insert def test_insert_one(self): @@ -61,7 +69,7 @@ def test_insert_one(self): book = Book.insert_one(name=name, author=author, pages_count=pages_count) assert isinstance(book, Book) - assert book.id == 1 + assert book.id assert book.name == name assert book.pages_count == pages_count @@ -81,7 +89,7 @@ def test_find_one_not_found(self): def test_find_one_in_many_when_its_last(self): # Insert Many - insert_count = self._insert_many() + self._insert_many() # Insert One name = f.name() @@ -93,8 +101,8 @@ def test_find_one_in_many_when_its_last(self): book = Book.find_one(name=name, author=author, pages_count=pages_count) assert isinstance(book, Book) - assert book.id == insert_count + 1 - assert book._id == book.id + assert book.id + assert str(book._id) == str(book.id) assert book.name == name assert book.pages_count == pages_count assert created_book == book @@ -116,7 +124,7 @@ def test_find_one_in_many_when_its_middle(self): book = Book.find_one(name=name, author=author, pages_count=pages_count) assert isinstance(book, Book) - assert book.id == insert_count + 1 + assert book.id assert book.name == name assert book.pages_count == pages_count assert created_book == book @@ -135,7 +143,7 @@ def test_first(self): book = Book.first(name=name, author=author, pages_count=pages_count) assert isinstance(book, Book) - assert book.id == insert_count + 1 + assert book.id assert book.name == name assert book.pages_count == pages_count @@ -285,6 +293,21 @@ def test_delete_one(self): # Count Them After Deletion assert Book.count(name=name) == insert_count - 1 + def test_delete_self(self): + # Insert Many + self._insert_many() + + # Insert With Specific Name + name = f.name() + insert_count = self._insert_many_with_specific_params(name=name) + + # Delete One + book = Book.find_one(name=name) + book.delete() + + # Count Them After Deletion + assert Book.count(name=name) == insert_count - 1 + def test_delete_one_not_found(self): # Insert Many insert_count = self._insert_many() @@ -359,6 +382,33 @@ def test_update_one(self): assert Book.count(name=name) == 0 assert Book.count() == insert_count + 1 + def test_update_self(self): + # Insert Many + insert_count = self._insert_many() + + # Insert With Specific Name + name = f.name() + author = f.name() + pages_count = random.randint(0, 10) + Book.insert_one(name=name, author=author, pages_count=pages_count) + + # Update One + book = Book.find_one(name=name) + new_name = 'New Name' + book.update(name=new_name) + + assert book.name == new_name + assert book.author == author + + book = Book.find_one(name=new_name) + assert isinstance(book, Book) + assert book.author == author + assert book.pages_count == pages_count + + # Count Them After Update + assert Book.count(name=name) == 0 + assert Book.count() == insert_count + 1 + def test_update_one_not_found(self): # Insert Many insert_count = self._insert_many() @@ -443,3 +493,64 @@ def _insert_many_with_specific_params( Book.insert_one(name=name, author=author, pages_count=pages_count) return insert_count + + +class TestPantherDB(_BaseDatabaseTestCase, TestCase): + DB_PATH = 'test.pdb' + + @classmethod + def setUpClass(cls) -> None: + global MIDDLEWARES, Book + MIDDLEWARES = [ + ('panther.middlewares.db.DatabaseMiddleware', {'url': f'pantherdb://{cls.DB_PATH}'}), + ] + Book = type('Book', (BaseBook, PantherDBModel, BasePantherDBQuery, Query), {}) + Panther(__name__, configs=__name__, urls={}) + + def setUp(self) -> None: + for middleware in config['http_middlewares']: + asyncio.run(middleware.before(request=None)) + + def tearDown(self) -> None: + Path(self.DB_PATH).unlink() + for middleware in config['reversed_http_middlewares']: + asyncio.run(middleware.after(response=None)) + + +@pytest.mark.mongodb +class TestMongoDB(_BaseDatabaseTestCase, TestCase): + DB_NAME = 'test.pdb' + + @classmethod + def setUpClass(cls) -> None: + global MIDDLEWARES, Book + MIDDLEWARES = [ + ('panther.middlewares.db.DatabaseMiddleware', {'url': f'mongodb://127.0.0.1:27017/{cls.DB_NAME}'}), + ] + Book = type('Book', (BaseBook, MongoDBModel, Query), {}) + Panther(__name__, configs=__name__, urls={}) + + def setUp(self) -> None: + for middleware in config['http_middlewares']: + asyncio.run(middleware.before(request=None)) + + def tearDown(self) -> None: + db.session.drop_collection('Book') + for middleware in config['reversed_http_middlewares']: + asyncio.run(middleware.after(response=None)) + + def test_last(self): + try: + super().test_last() + except DBException as exc: + assert exc.args[0] == 'last() is not supported in MongoDB yet.' + else: + assert False + + def test_last_not_found(self): + try: + super().test_last_not_found() + except DBException as exc: + assert exc.args[0] == 'last() is not supported in MongoDB yet.' + else: + assert False diff --git a/tests/test_request.py b/tests/test_request.py index 8c1caeae..3312d5eb 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -1,7 +1,9 @@ from unittest import TestCase +import orjson as json + from panther import Panther -from panther.app import API +from panther.app import API, GenericAPI from panther.request import Request from panther.response import Response from panther.test import APIClient @@ -27,6 +29,11 @@ async def request_data(request: Request): return request.data +@API() +async def request_path_variables(name: str, age: int, is_alive: bool): + return {'name': name, 'age': age, 'is_alive': is_alive} + + @API() async def request_header(request: Request): return request.headers.__dict__ @@ -43,56 +50,119 @@ async def request_header_by_item(request: Request): # # # Methods +class AllMethods(GenericAPI): + def get(self, *args, **kwargs): + return Response() + + def post(self, *args, **kwargs): + return Response() + + def put(self, *args, **kwargs): + return Response() + + def patch(self, *args, **kwargs): + return Response() + + def delete(self, *args, **kwargs): + return Response() + + @API() -async def request_all(): +async def all_methods(): return Response() +class GetMethod(GenericAPI): + def get(self, *args, **kwargs): + return Response() + + @API(methods=['GET']) -async def request_get(): +async def get_method(): return Response() +class PostMethod(GenericAPI): + def post(self, *args, **kwargs): + return Response() + + @API(methods=['POST']) -async def request_post(): +async def post_method(): return Response() +class PutMethod(GenericAPI): + def put(self, *args, **kwargs): + return Response() + + @API(methods=['PUT']) -async def request_put(): +async def put_method(): return Response() +class PatchMethod(GenericAPI): + def patch(self, *args, **kwargs): + return Response() + + @API(methods=['PATCH']) -async def request_patch(): +async def patch_method(): return Response() +class DeleteMethod(GenericAPI): + def delete(self, *args, **kwargs): + return Response() + + @API(methods=['DELETE']) -async def request_delete(): +async def delete_method(): return Response() +class GetPostPatchMethods(GenericAPI): + def get(self, *args, **kwargs): + return Response() + + def post(self, *args, **kwargs): + return Response() + + def patch(self, *args, **kwargs): + return Response() + + @API(methods=['GET', 'POST', 'PATCH']) -async def request_get_post_patch(): +async def get_post_patch_methods(): return Response() urls = { - 'request-header': request_header, - 'request-header-attr': request_header_by_attr, - 'request-header-item': request_header_by_item, - 'request-path': request_path, - 'request-client': request_client, - 'request-query_params': request_query_params, - 'request-data': request_data, - 'all': request_all, - 'get': request_get, - 'post': request_post, - 'put': request_put, - 'patch': request_patch, - 'delete': request_delete, - 'get-post-patch': request_get_post_patch, + 'path': request_path, + 'client': request_client, + 'query-params': request_query_params, + 'data': request_data, + 'path//variable///': request_path_variables, + + 'header': request_header, + 'header-attr': request_header_by_attr, + 'header-item': request_header_by_item, + + 'all-func': all_methods, + 'all-class': AllMethods, + 'get-func': get_method, + 'get-class': GetMethod, + 'post-func': post_method, + 'post-class': PostMethod, + 'put-func': put_method, + 'put-class': PutMethod, + 'patch-func': patch_method, + 'patch-class': PatchMethod, + 'delete-func': delete_method, + 'delete-class': DeleteMethod, + 'get-post-patch-func': get_post_patch_methods, + 'get-post-patch-class': GetPostPatchMethods, } @@ -103,31 +173,43 @@ def setUpClass(cls) -> None: cls.client = APIClient(app=app) def test_path(self): - res = self.client.get('request-path/') + res = self.client.get('path/') assert res.status_code == 200 - assert res.data == '/request-path/' + assert res.data == '/path/' def test_client(self): - res = self.client.get('request-client/') + res = self.client.get('client/') assert res.status_code == 200 assert res.data == ['127.0.0.1', 8585] def test_query_params(self): res = self.client.get( - 'request-query_params/', + 'query-params/', query_params={'my': 'name', 'is': 'ali', 'how': 'are'}, ) assert res.status_code == 200 assert res.data == {'my': 'name', 'is': 'ali', 'how': 'are'} + def test_data(self): + payload = {'detail': 'ok'} + res = self.client.post('data/', payload=json.dumps(payload)) + assert res.status_code == 200 + assert res.data == payload + + def test_path_variables(self): + res = self.client.post('path/Ali/variable/27/true/') + expected_response = {'name': 'Ali', 'age': 27, 'is_alive': True} + assert res.status_code == 200 + assert res.data == expected_response + # # # Headers def test_headers_none(self): - res = self.client.get('request-header') + res = self.client.get('header') expected_headers = {} assert res.data == expected_headers def test_headers_content_type(self): - res = self.client.post('request-header') + res = self.client.post('header') expected_headers = {'content-type': 'application/json'} assert res.data == expected_headers @@ -141,7 +223,7 @@ def test_headers_full_items(self): 'Connection': 'keep-alive', 'Content-Length': 546, } - res = self.client.post('request-header', headers=headers) + res = self.client.post('header', headers=headers) expected_headers = { 'content-type': 'application/json', 'User-Agent': 'PostmanRuntime/7.36.0', @@ -159,7 +241,7 @@ def test_headers_unknown_items(self): 'Header1': 'PostmanRuntime/7.36.0', 'Header2': '*/*', } - res = self.client.post('request-header', headers=headers) + res = self.client.post('header', headers=headers) expected_headers = { 'content-type': 'application/json', 'Header1': 'PostmanRuntime/7.36.0', @@ -171,125 +253,195 @@ def test_headers_authorization_by_getattr(self): headers = { 'Authorization': 'Token xxx', } - res = self.client.post('request-header-attr', headers=headers) + res = self.client.post('header-attr', headers=headers) assert res.data == 'Token xxx' def test_headers_authorization_by_getitem(self): headers = { 'Authorization': 'Token xxx', } - res = self.client.post('request-header-item', headers=headers) + res = self.client.post('header-item', headers=headers) assert res.data == 'Token xxx' # # # Methods def test_method_all(self): - res = self.client.get('all/') - assert res.status_code == 200 - - res = self.client.post('all/') - assert res.status_code == 200 - - res = self.client.put('all/') - assert res.status_code == 200 - - res = self.client.patch('all/') - assert res.status_code == 200 - - res = self.client.delete('all/') - assert res.status_code == 200 + res_func = self.client.get('all-func/') + res_class = self.client.get('all-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.post('all-func/') + res_class = self.client.post('all-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.put('all-func/') + res_class = self.client.put('all-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.patch('all-func/') + res_class = self.client.patch('all-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.delete('all-func/') + res_class = self.client.delete('all-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 def test_method_get(self): - res = self.client.get('get/') - assert res.status_code == 200 - - res = self.client.post('get/') - assert res.status_code == 405 - - res = self.client.put('get/') - assert res.status_code == 405 - - res = self.client.patch('get/') - assert res.status_code == 405 - - res = self.client.delete('get/') - assert res.status_code == 405 + res_func = self.client.get('get-func/') + res_class = self.client.get('get-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.post('get-func/') + res_class = self.client.post('get-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.put('get-func/') + res_class = self.client.put('get-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.patch('get-func/') + res_class = self.client.patch('get-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.delete('get-func/') + res_class = self.client.delete('get-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 def test_method_post(self): - res = self.client.get('post/') - assert res.status_code == 405 - - res = self.client.post('post/') - assert res.status_code == 200 - - res = self.client.put('post/') - assert res.status_code == 405 - - res = self.client.patch('post/') - assert res.status_code == 405 - - res = self.client.delete('post/') - assert res.status_code == 405 + res_func = self.client.get('post-func/') + res_class = self.client.get('post-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.post('post-func/') + res_class = self.client.post('post-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.put('post-func/') + res_class = self.client.put('post-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.patch('post-func/') + res_class = self.client.patch('post-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.delete('post-func/') + res_class = self.client.delete('post-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 def test_method_put(self): - res = self.client.get('put/') - assert res.status_code == 405 - - res = self.client.post('put/') - assert res.status_code == 405 - - res = self.client.put('put/') - assert res.status_code == 200 - - res = self.client.patch('put/') - assert res.status_code == 405 - - res = self.client.delete('put/') - assert res.status_code == 405 + res_func = self.client.get('put-func/') + res_class = self.client.get('put-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.post('put-func/') + res_class = self.client.post('put-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.put('put-func/') + res_class = self.client.put('put-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.patch('put-func/') + res_class = self.client.patch('put-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.delete('put-func/') + res_class = self.client.delete('put-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 def test_method_patch(self): - res = self.client.get('patch/') - assert res.status_code == 405 - - res = self.client.post('patch/') - assert res.status_code == 405 - - res = self.client.put('patch/') - assert res.status_code == 405 - - res = self.client.patch('patch/') - assert res.status_code == 200 - - res = self.client.delete('patch/') - assert res.status_code == 405 + res_func = self.client.get('patch-func/') + res_class = self.client.get('patch-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.post('patch-func/') + res_class = self.client.post('patch-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.put('patch-func/') + res_class = self.client.put('patch-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.patch('patch-func/') + res_class = self.client.patch('patch-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.delete('patch-func/') + res_class = self.client.delete('patch-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 def test_method_delete(self): - res = self.client.get('delete/') - assert res.status_code == 405 - - res = self.client.post('delete/') - assert res.status_code == 405 - - res = self.client.put('delete/') - assert res.status_code == 405 - - res = self.client.patch('delete/') - assert res.status_code == 405 - - res = self.client.delete('delete/') - assert res.status_code == 200 + res_func = self.client.get('delete-func/') + res_class = self.client.get('delete-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.post('delete-func/') + res_class = self.client.post('delete-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.put('delete-func/') + res_class = self.client.put('delete-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.patch('delete-func/') + res_class = self.client.patch('delete-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.delete('delete-func/') + res_class = self.client.delete('delete-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 def test_method_get_post_patch(self): - res = self.client.get('get-post-patch/') - assert res.status_code == 200 - - res = self.client.post('get-post-patch/') - assert res.status_code == 200 - - res = self.client.put('get-post-patch/') - assert res.status_code == 405 - - res = self.client.patch('get-post-patch/') - assert res.status_code == 200 - - res = self.client.delete('get-post-patch/') - assert res.status_code == 405 + res_func = self.client.get('get-post-patch-func/') + res_class = self.client.get('get-post-patch-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.post('get-post-patch-func/') + res_class = self.client.post('get-post-patch-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.put('get-post-patch-func/') + res_class = self.client.put('get-post-patch-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 + + res_func = self.client.patch('get-post-patch-func/') + res_class = self.client.patch('get-post-patch-class/') + assert res_func.status_code == 200 + assert res_class.status_code == 200 + + res_func = self.client.delete('get-post-patch-func/') + res_class = self.client.delete('get-post-patch-class/') + assert res_func.status_code == 405 + assert res_class.status_code == 405 diff --git a/tests/test_utils.py b/tests/test_utils.py index 440d3f22..db94400c 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ from pathlib import Path from unittest import TestCase +import panther.utils from panther import Panther from panther.middlewares import BaseMiddleware from panther.utils import generate_hash_value_from_string, load_env, round_datetime, encrypt_password @@ -185,9 +186,8 @@ def test_encrypt_password(self): class TestLoadConfigs(TestCase): def test_urls_not_found(self): - global URLs, MIDDLEWARES + global URLs URLs = None - MIDDLEWARES = [] with self.assertLogs(level='ERROR') as captured: try: @@ -201,9 +201,8 @@ def test_urls_not_found(self): assert captured.records[0].getMessage() == "Invalid 'URLs': is required." def test_urls_cant_be_dict(self): - global URLs, MIDDLEWARES + global URLs URLs = {} - MIDDLEWARES = None with self.assertLogs(level='ERROR') as captured: try: @@ -221,9 +220,8 @@ def test_urls_cant_be_dict(self): assert captured.records[0].getMessage() == msg def test_urls_not_string(self): - global URLs, MIDDLEWARES + global URLs URLs = True - MIDDLEWARES = None with self.assertLogs(level='ERROR') as captured: try: @@ -237,9 +235,8 @@ def test_urls_not_string(self): assert captured.records[0].getMessage() == "Invalid 'URLs': should be dotted string." def test_urls_invalid_target(self): - global URLs, MIDDLEWARES + global URLs URLs = 'tests.test_utils.TestLoadConfigs' - MIDDLEWARES = None with self.assertLogs(level='ERROR') as captured: try: @@ -253,9 +250,8 @@ def test_urls_invalid_target(self): assert captured.records[0].getMessage() == "Invalid 'URLs': should point to a dict." def test_urls_invalid_module_path(self): - global URLs, MIDDLEWARES + global URLs URLs = 'fake.module' - MIDDLEWARES = None with self.assertLogs(level='ERROR') as captured: try: @@ -273,6 +269,7 @@ def test_middlewares_invalid_path(self): MIDDLEWARES = [ ('fake.module', {}) ] + with self.assertLogs(level='ERROR') as captured: try: Panther(name=__name__, configs=__name__, urls={}) @@ -280,6 +277,8 @@ def test_middlewares_invalid_path(self): assert True else: assert False + finally: + MIDDLEWARES = [] assert len(captured.records) == 1 assert captured.records[0].getMessage() == "Invalid 'MIDDLEWARES': fake.module is not a valid middleware path" @@ -287,6 +286,7 @@ def test_middlewares_invalid_path(self): def test_middlewares_invalid_structure(self): global MIDDLEWARES MIDDLEWARES = ['fake.module'] + with self.assertLogs(level='ERROR') as captured: try: Panther(name=__name__, configs=__name__, urls={}) @@ -294,6 +294,8 @@ def test_middlewares_invalid_structure(self): assert True else: assert False + finally: + MIDDLEWARES = [] assert len(captured.records) == 1 assert captured.records[0].getMessage() == "Invalid 'MIDDLEWARES': fake.module should have 2 part: (path, kwargs)" @@ -303,6 +305,7 @@ def test_middlewares_too_many_args(self): MIDDLEWARES = [ ('fake.module', 1, 2) ] + with self.assertLogs(level='ERROR') as captured: try: Panther(name=__name__, configs=__name__, urls={}) @@ -310,6 +313,8 @@ def test_middlewares_too_many_args(self): assert True else: assert False + finally: + MIDDLEWARES = [] assert len(captured.records) == 1 assert captured.records[0].getMessage() == "Invalid 'MIDDLEWARES': ('fake.module', 1, 2) too many arguments" @@ -319,14 +324,18 @@ def test_middlewares_without_args(self): MIDDLEWARES = [ ('tests.test_utils.CorrectTestMiddleware', ) ] + with self.assertNoLogs(level='ERROR'): Panther(name=__name__, configs=__name__, urls={}) + MIDDLEWARES = [] + def test_middlewares_invalid_middleware_parent(self): global MIDDLEWARES MIDDLEWARES = [ ('tests.test_utils.TestMiddleware', ) ] + with self.assertLogs(level='ERROR') as captured: try: Panther(name=__name__, configs=__name__, urls={}) @@ -334,10 +343,45 @@ def test_middlewares_invalid_middleware_parent(self): assert True else: assert False + finally: + MIDDLEWARES = [] assert len(captured.records) == 1 assert captured.records[0].getMessage() == "Invalid 'MIDDLEWARES': is not a sub class of BaseMiddleware" + def test_jwt_auth_without_secret_key(self): + global AUTHENTICATION + AUTHENTICATION = 'panther.authentications.JWTAuthentication' + + with self.assertLogs(level='ERROR') as captured: + try: + Panther(name=__name__, configs=__name__, urls={}) + except SystemExit: + assert True + else: + assert False + finally: + AUTHENTICATION = None + + assert len(captured.records) == 1 + assert captured.records[0].getMessage() == '"SECRET_KEY" is required when using "JWTAuthentication"' + + def test_jwt_auth_with_secret_key(self): + global AUTHENTICATION, SECRET_KEY + AUTHENTICATION = 'panther.authentications.JWTAuthentication' + SECRET_KEY = panther.utils.generate_secret_key() + + with self.assertNoLogs(level='ERROR'): + try: + Panther(name=__name__, configs=__name__, urls={}) + except SystemExit: + assert False + else: + assert True + finally: + AUTHENTICATION = None + SECRET_KEY = None + class CorrectTestMiddleware(BaseMiddleware): pass