diff --git a/test_src/test_proj/migrations/0043_manufacturer_store_product_option_manufacturer_store_and_more.py b/test_src/test_proj/migrations/0043_manufacturer_store_product_option_manufacturer_store_and_more.py new file mode 100644 index 00000000..b8907db8 --- /dev/null +++ b/test_src/test_proj/migrations/0043_manufacturer_store_product_option_manufacturer_store_and_more.py @@ -0,0 +1,74 @@ +# Generated by Django 4.2.8 on 2023-12-20 06:45 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('test_proj', '0042_testexternalcustommodel'), + ] + + operations = [ + migrations.CreateModel( + name='Manufacturer', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255)), + ], + options={ + 'default_related_name': 'manufacturers', + }, + ), + migrations.CreateModel( + name='Store', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255)), + ], + options={ + 'default_related_name': 'stores', + }, + ), + migrations.CreateModel( + name='Product', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255)), + ('price', models.DecimalField(decimal_places=2, max_digits=10)), + ('manufacturer', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_proj.manufacturer')), + ('store', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_proj.store')), + ], + options={ + 'default_related_name': 'products', + }, + ), + migrations.CreateModel( + name='Option', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255)), + ('product', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_proj.product')), + ], + options={ + 'default_related_name': 'options', + }, + ), + migrations.AddField( + model_name='manufacturer', + name='store', + field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_proj.store'), + ), + migrations.CreateModel( + name='Attribute', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=255)), + ('product', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='test_proj.product')), + ], + options={ + 'default_related_name': 'attributes', + }, + ), + ] diff --git a/test_src/test_proj/models/__init__.py b/test_src/test_proj/models/__init__.py index 96fba894..54fa6086 100644 --- a/test_src/test_proj/models/__init__.py +++ b/test_src/test_proj/models/__init__.py @@ -14,3 +14,4 @@ from .fields_testing import Post, ExtraPost, Author, ModelWithChangedFk, ModelWithCrontabField, ModelWithUuidFK, ModelWithUuidPk from .cacheable import CachableModel, CachableProxyModel from .deep import Group, ModelWithNestedModels, GroupWithFK, AnotherDeepNested, ProtectedBySignal +from .nested_models import Option, Attribute, Store, Product, Manufacturer diff --git a/test_src/test_proj/models/deep.py b/test_src/test_proj/models/deep.py index 9cf39c2f..176e51b2 100644 --- a/test_src/test_proj/models/deep.py +++ b/test_src/test_proj/models/deep.py @@ -66,5 +66,5 @@ class Meta: 'protected': { 'allow_append': True, 'model': ProtectedBySignal, - } + }, } diff --git a/test_src/test_proj/models/nested_models.py b/test_src/test_proj/models/nested_models.py new file mode 100644 index 00000000..c49b1550 --- /dev/null +++ b/test_src/test_proj/models/nested_models.py @@ -0,0 +1,84 @@ +from django.db import models +from django.dispatch import receiver +from django.db.models.signals import pre_delete +from django.core.validators import ValidationError +from rest_framework.permissions import BasePermission +from rest_framework.request import Request +from rest_framework.views import APIView + +from vstutils.models import BaseModel + + +class DisallowStaffPermission(BasePermission): + def has_permission(self, request, view): + if not request.user.is_superuser and request.user.is_staff: + return False + return super().has_permission(request, view) + + +class Option(BaseModel): + name = models.CharField(max_length=255) + product = models.ForeignKey('Product', on_delete=models.CASCADE) + + class Meta: + default_related_name = 'options' + + +class Attribute(BaseModel): + name = models.CharField(max_length=255) + product = models.ForeignKey('Product', on_delete=models.CASCADE) + + class Meta: + default_related_name = 'attributes' + _permission_classes = [DisallowStaffPermission] + + +class Product(BaseModel): + name = models.CharField(max_length=255) + price = models.DecimalField(max_digits=10, decimal_places=2) + store = models.ForeignKey('Store', on_delete=models.CASCADE) + manufacturer = models.ForeignKey('Manufacturer', on_delete=models.CASCADE) + + class Meta: + default_related_name = 'products' + _nested = { + 'options': { + 'allow_append': True, + 'model': Option, + }, + 'attributes': { + 'allow_append': True, + 'model': Attribute, + } + } + + +class Manufacturer(BaseModel): + name = models.CharField(max_length=255) + store = models.ForeignKey('Store', on_delete=models.CASCADE) + + class Meta: + default_related_name = 'manufacturers' + _nested = { + 'products': { + 'allow_append': False, + 'model': Product, + } + } + + +class Store(BaseModel): + name = models.CharField(max_length=255) + + class Meta: + default_related_name = 'stores' + _nested = { + 'products': { + 'allow_append': True, + 'model': Product, + }, + 'manufacturers': { + 'allow_append': False, + 'model': Manufacturer, + } + } diff --git a/test_src/test_proj/settings.py b/test_src/test_proj/settings.py index 2c10a048..71c3187a 100644 --- a/test_src/test_proj/settings.py +++ b/test_src/test_proj/settings.py @@ -68,6 +68,9 @@ API[VST_API_VERSION][r'modelwithnested'] = dict( model='test_proj.models.ModelWithNestedModels' ) +API[VST_API_VERSION][r'stores'] = dict( + model='test_proj.models.Store' +) API[VST_API_VERSION][r'modelwithcrontab'] = dict( model='test_proj.models.ModelWithCrontabField' ) diff --git a/test_src/test_proj/tests.py b/test_src/test_proj/tests.py index 7e81eea2..bddb9ab7 100644 --- a/test_src/test_proj/tests.py +++ b/test_src/test_proj/tests.py @@ -86,6 +86,11 @@ ModelWithNestedModels, ProtectedBySignal, ModelWithUuidPk, + Store, + Manufacturer, + Option, + Attribute, + Product ) from rest_framework.exceptions import ValidationError from base64 import b64encode @@ -2321,6 +2326,23 @@ def has_deep_parent_filter(params): schema = self.endpoint_schema() self.assertTrue(schema['definitions']['User']['properties']['is_staff']['readOnly']) + # check that nested endponit's permissions took into account + user = self._create_user(is_super_user=False, is_staff=True) + with self.user_as(self, user): + schema = self.endpoint_schema() + schemas_differance = set(api['paths'].keys()) - set(schema['paths'].keys()) + expected_differance = { + '/stores/{id}/products/{products_id}/attributes/', + '/stores/{id}/products/{products_id}/attributes/{attributes_id}/', + '/stores/{id}/manufacturers/{manufacturers_id}/products/{products_id}/attributes/{attributes_id}/', + '/stores/{id}/manufacturers/{manufacturers_id}/products/{products_id}/attributes/', + } + # Check that only expected endpoints were banned. + self.assertEqual( + schemas_differance, + expected_differance + ) + def test_search_fields(self): self.assertEqual( self.get_model_class('test_proj.Variable').generated_view.search_fields, @@ -5043,6 +5065,57 @@ def serializer_test(serializer): generated_serializer = ModelWithBinaryFiles.generated_view.serializer_class() serializer_test(generated_serializer) + def test_nested_views_permissions(self): + # Test nested model viewsets permissions. + store = Store.objects.create( + name='test' + ) + manufacturer = Manufacturer.objects.create( + name='test man', + store=store + ) + product = Product.objects.create( + name='test prod', + store=store, + price = 100, + manufacturer=manufacturer + ) + attr = Attribute.objects.create( + name='test attr', + product=product + ) + option = Option.objects.create( + name='test option', + product=product, + ) + + endpoints_to_test = [ + {'method': 'get', 'path': f'/stores/{store.id}/products/{product.id}/attributes/'}, + {'method': 'get', 'path': f'/stores/{store.id}/products/{product.id}/attributes/{attr.id}/'}, + {'method': 'get', 'path': f'/stores/{store.id}/manufacturers/{manufacturer.id}/products/{product.id}/attributes/{attr.id}/'}, + {'method': 'get', 'path': f'/stores/{store.id}/manufacturers/{manufacturer.id}/products/{product.id}/attributes/'}, + ] + + always_available = [ + {'method': 'get', 'path': f'/stores/{store.id}/products/{product.id}/options/{option.id}/'}, + {'method': 'get', 'path': f'/stores/{store.id}/products/{product.id}/options/'}, + ] + + results = self.bulk(endpoints_to_test + always_available) + + for result in results: + self.assertEqual(result['status'], 200, result['path']) + + user = self._create_user(is_super_user=False, is_staff=True) + with self.user_as(self, user): + results = self.bulk(endpoints_to_test) + for result in results: + self.assertEqual(result['status'], 403, result['path']) + + results = self.bulk(always_available) + for result in results: + self.assertEqual(result['status'], 200, result['path']) + def test_deep_nested(self): results = self.bulk([ # [0-2] Create 3 nested objects diff --git a/vstutils/__init__.py b/vstutils/__init__.py index 4826670b..14ada44f 100644 --- a/vstutils/__init__.py +++ b/vstutils/__init__.py @@ -1,2 +1,2 @@ # pylint: disable=django-not-available -__version__: str = '5.8.18' +__version__: str = '5.8.19' diff --git a/vstutils/api/permissions.py b/vstutils/api/permissions.py index 4301f919..c638db3a 100644 --- a/vstutils/api/permissions.py +++ b/vstutils/api/permissions.py @@ -5,17 +5,17 @@ from ..utils import raise_context -class IsAuthenticatedOpenApiRequest(permissions.IsAuthenticated): +def is_openapi_request(request): + return ( + request.path.startswith(f'/{settings.API_URL}/openapi/') or + request.path.startswith(f'/{settings.API_URL}/endpoint/') or + request.path == f'/{settings.API_URL}/{request.version}/_openapi/' + ) - def is_openapi(self, request): - return ( - request.path.startswith(f'/{settings.API_URL}/openapi/') or - request.path.startswith(f'/{settings.API_URL}/endpoint/') or - request.path == f'/{settings.API_URL}/{request.version}/_openapi/' - ) +class IsAuthenticatedOpenApiRequest(permissions.IsAuthenticated): def has_permission(self, request, view): - return self.is_openapi(request) or super().has_permission(request, view) + return is_openapi_request(request) or super().has_permission(request, view) class SuperUserPermission(IsAuthenticatedOpenApiRequest): @@ -29,7 +29,7 @@ def has_permission(self, request, view): issubclass(view.get_queryset().model, AbstractUser) and str(view.kwargs['pk']) in (str(request.user.pk), 'profile') ) - return self.is_openapi(request) + return is_openapi_request(request) def has_object_permission(self, request, view, obj): if request.user.is_superuser: diff --git a/vstutils/api/schema/generators.py b/vstutils/api/schema/generators.py index 66f1171f..47e99934 100644 --- a/vstutils/api/schema/generators.py +++ b/vstutils/api/schema/generators.py @@ -10,6 +10,8 @@ from drf_yasg.inspectors import field as field_insp from vstutils.utils import raise_context_decorator_with_default +from .schema import get_nested_view_obj, _get_nested_view_and_subaction + def get_centrifugo_public_address(request: drf_request.Request): address = settings.CENTRIFUGO_PUBLIC_HOST @@ -103,6 +105,12 @@ def get_path_parameters(self, path, view_cls): continue # nocv return parameters + def should_include_endpoint(self, path, method, view, public): + nested_view, sub_action = _get_nested_view_and_subaction(view) + if nested_view and sub_action: + view = get_nested_view_obj(view, nested_view, sub_action, method) + return super().should_include_endpoint(path, method, view, public) + def get_operation_keys(self, subpath, method, view): keys = super().get_operation_keys(subpath, method, view) subpath_keys = list(filter(bool, subpath.split('/'))) diff --git a/vstutils/api/schema/schema.py b/vstutils/api/schema/schema.py index f5e36c37..ec6aefc8 100644 --- a/vstutils/api/schema/schema.py +++ b/vstutils/api/schema/schema.py @@ -19,6 +19,79 @@ View = _t.Type[views.APIView] +def _get_nested_view_and_subaction(view, default=None): + sub_action = getattr(view, view.action, None) + return getattr(sub_action, '_nested_view', default), sub_action + + +def _get_nested_view_class(nested_view, view_action_func): + # pylint: disable=protected-access + if not hasattr(view_action_func, '_nested_name'): + return nested_view + + nested_action_name = '_'.join(view_action_func._nested_name.split('_')[1:]) + + if nested_view is None: + return nested_view # nocv + + if hasattr(view_action_func, '_nested_view'): + nested_view_class = view_action_func._nested_view + view_action_func = getattr(nested_view_class, nested_action_name, None) + else: # nocv + nested_view_class = None + view_action_func = None + + if view_action_func is None: + return nested_view + + return _get_nested_view_class(nested_view_class, view_action_func) + + +def get_nested_view_obj(view, nested_view: 'View', view_action_func, method): + # pylint: disable=protected-access + # Get nested view recursively + nested_view: 'View' = utils.get_if_lazy(_get_nested_view_class(nested_view, view_action_func)) + # Get action suffix + replace_pattern = view_action_func._nested_subname + '_' + replace_index = view.action.index(replace_pattern) + len(replace_pattern) + action_suffix = view.action[replace_index:] + # Check detail or list action + is_detail = action_suffix.endswith('detail') + is_list = action_suffix.endswith('list') + # Create view object + method = method.lower() + nested_view_obj = nested_view() + nested_view_obj.request = view.request + nested_view_obj.kwargs = view.kwargs + nested_view_obj.lookup_field = view.lookup_field + nested_view_obj.lookup_url_kwarg = view.lookup_url_kwarg + nested_view_obj.format_kwarg = None + nested_view_obj.format_kwarg = None + nested_view_obj._nested_wrapped_view = getattr(view_action_func, '_nested_wrapped_view', None) + # Check operation action + if method == 'post' and is_list: + nested_view_obj.action = 'create' + elif method == 'get' and is_list: + nested_view_obj.action = 'list' + elif method == 'get' and is_detail: + nested_view_obj.action = 'retrieve' + elif method == 'put' and is_detail: + nested_view_obj.action = 'update' + elif method == 'patch' and is_detail: + nested_view_obj.action = 'partial_update' + elif method == 'delete' and is_detail: + nested_view_obj.action = 'destroy' + else: + nested_view_obj.action = action_suffix + new_view = getattr(nested_view_obj, action_suffix, None) + if new_view is not None: + serializer_class = new_view.kwargs.get('serializer_class', None) + if serializer_class: + nested_view_obj.serializer_class = serializer_class + + return nested_view_obj + + class ExtendedSwaggerAutoSchema(SwaggerAutoSchema): def get_query_parameters(self): result = super().get_query_parameters() @@ -107,72 +180,6 @@ def __init__(self, *args, **kwargs): self._sch.view = args[0] self.request._schema = self - def _get_nested_view_class(self, nested_view: 'View', view_action_func): - # pylint: disable=protected-access - if not hasattr(view_action_func, '_nested_name'): - return nested_view - - nested_action_name = '_'.join(view_action_func._nested_name.split('_')[1:]) - - if nested_view is None: - return nested_view # nocv - - if hasattr(view_action_func, '_nested_view'): - nested_view_class = view_action_func._nested_view - view_action_func = getattr(nested_view_class, nested_action_name, None) - else: # nocv - nested_view_class = None - view_action_func = None - - if view_action_func is None: - return nested_view - - return self._get_nested_view_class(nested_view_class, view_action_func) - - def __get_nested_view_obj(self, nested_view: 'View', view_action_func): - # pylint: disable=protected-access - # Get nested view recursively - nested_view: 'View' = utils.get_if_lazy(self._get_nested_view_class(nested_view, view_action_func)) - # Get action suffix - replace_pattern = view_action_func._nested_subname + '_' - replace_index = self.view.action.index(replace_pattern) + len(replace_pattern) - action_suffix = self.view.action[replace_index:] - # Check detail or list action - is_detail = action_suffix.endswith('detail') - is_list = action_suffix.endswith('list') - # Create view object - method = self.method.lower() - nested_view_obj = nested_view() - nested_view_obj.request = self.view.request - nested_view_obj.kwargs = self.view.kwargs - nested_view_obj.lookup_field = self.view.lookup_field - nested_view_obj.lookup_url_kwarg = self.view.lookup_url_kwarg - nested_view_obj.format_kwarg = None - nested_view_obj.format_kwarg = None - nested_view_obj._nested_wrapped_view = getattr(view_action_func, '_nested_wrapped_view', None) - # Check operation action - if method == 'post' and is_list: - nested_view_obj.action = 'create' - elif method == 'get' and is_list: - nested_view_obj.action = 'list' - elif method == 'get' and is_detail: - nested_view_obj.action = 'retrieve' - elif method == 'put' and is_detail: - nested_view_obj.action = 'update' - elif method == 'patch' and is_detail: - nested_view_obj.action = 'partial_update' - elif method == 'delete' and is_detail: - nested_view_obj.action = 'destroy' - else: - nested_view_obj.action = action_suffix - view = getattr(nested_view_obj, action_suffix, None) - if view is not None: - serializer_class = view.kwargs.get('serializer_class', None) - if serializer_class: - nested_view_obj.serializer_class = serializer_class - - return nested_view_obj - def get_operation_id(self, operation_keys=None): new_operation_keys: _t.List[str] = [] append_new_operation_keys = new_operation_keys.append @@ -189,17 +196,13 @@ def get_response_schemas(self, response_serializers): response.description = self.default_status_messages.get(response_code, 'Action accepted.') return responses - def __get_nested_view_and_subaction(self, default=None): - sub_action = getattr(self.view, self.view.action, None) - return getattr(sub_action, '_nested_view', default), sub_action - def __perform_with_nested(self, func_name, *args, **kwargs): # pylint: disable=protected-access - nested_view, sub_action = self.__get_nested_view_and_subaction() + nested_view, sub_action = _get_nested_view_and_subaction(self.view) if nested_view and sub_action: schema = copy(self) try: - schema.view = self.__get_nested_view_obj(nested_view, sub_action) + schema.view = get_nested_view_obj(self.view, nested_view, sub_action, self.method) result = getattr(schema, func_name)(*args, **kwargs) if result: return result @@ -260,7 +263,7 @@ def get_operation(self, operation_keys=None): params_to_override = ('x-title', 'x-icons') if self.method.lower() == 'get': - subscribe_view = self.__get_nested_view_and_subaction(self.view)[0] + subscribe_view = _get_nested_view_and_subaction(self.view, self.view)[0] queryset = getattr(subscribe_view, 'queryset', None) if queryset is not None: # pylint: disable=protected-access diff --git a/vstutils/management/commands/rpc_worker.py b/vstutils/management/commands/rpc_worker.py index 09eecfd3..3ea5b828 100644 --- a/vstutils/management/commands/rpc_worker.py +++ b/vstutils/management/commands/rpc_worker.py @@ -96,9 +96,9 @@ def handle(self, *args, **options): # nocv proc.kill() proc.wait() raise exc - except KeyboardInterrupt: + except KeyboardInterrupt: # nocv self._print('Exit by user...', 'WARNING') - except BaseException as err: + except BaseException as err: # nocv self._print(traceback.format_exc()) self._print(str(err), 'ERROR') sys.exit(1)