From 6a15625efebce5dd9916579b8e2898739bda72e1 Mon Sep 17 00:00:00 2001 From: siesto1elemento Date: Sun, 13 Oct 2024 17:04:36 +0530 Subject: [PATCH 1/2] added_subscription support --- cvat/apps/engine/admin.py | 8 +- cvat/apps/engine/middleware.py | 168 ++++++++++++++++++ .../apps/engine/migrations/0082_subscriber.py | 39 ++++ ...3_remove_subscriber_subscribed_and_more.py | 26 +++ cvat/apps/engine/models.py | 19 ++ cvat/settings/base.py | 5 + 6 files changed, 264 insertions(+), 1 deletion(-) create mode 100644 cvat/apps/engine/migrations/0082_subscriber.py create mode 100644 cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py diff --git a/cvat/apps/engine/admin.py b/cvat/apps/engine/admin.py index 05e4b40a0f9b..a9cf0d6e83d1 100644 --- a/cvat/apps/engine/admin.py +++ b/cvat/apps/engine/admin.py @@ -5,7 +5,7 @@ from django.contrib import admin from .models import Task, Segment, Job, Label, AttributeSpec, Project, \ - CloudStorage, Storage, Data, AnnotationGuide, Asset + CloudStorage, Storage, Data, AnnotationGuide, Asset, Subscriber class JobInline(admin.TabularInline): model = Job @@ -187,6 +187,11 @@ def has_add_permission(self, request): AssetInline ] +class SubscriberAdmin(admin.ModelAdmin): + list_display = ('user', 'subscription_class') + list_filter = ('subscription_class',) + + admin.site.register(Task, TaskAdmin) admin.site.register(Segment, SegmentAdmin) admin.site.register(Label, LabelAdmin) @@ -195,3 +200,4 @@ def has_add_permission(self, request): admin.site.register(CloudStorage, CloudStorageAdmin) admin.site.register(Data, DataAdmin) admin.site.register(AnnotationGuide, AnnotationGuideAdmin) +admin.site.register(Subscriber, SubscriberAdmin) diff --git a/cvat/apps/engine/middleware.py b/cvat/apps/engine/middleware.py index f2b990a14b50..8c5dffe0d056 100644 --- a/cvat/apps/engine/middleware.py +++ b/cvat/apps/engine/middleware.py @@ -3,6 +3,9 @@ # SPDX-License-Identifier: MIT from uuid import uuid4 +from .models import Subscriber, Project,Task,Job +from django.core.exceptions import PermissionDenied +from rest_framework.views import APIView class RequestTrackingMiddleware: def __init__(self, get_response): @@ -18,3 +21,168 @@ def __call__(self, request): response.headers['X-Request-Id'] = request.uuid return response + + +class ProjectLimitCheckMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == 'POST': + project_creation_path = '/api/projects' + drf_request = APIView().initialize_request(request) + user = drf_request.user + + if request.path == project_creation_path: + project_count = Project.objects.filter(owner=user).count() + + try: + subscription = Subscriber.objects.get(user=user) + + # Enforce project limits based on subscription class + if subscription.subscription_class == 'basic' and project_count >= 3: + raise PermissionDenied("You have reached your limit of 3 projects. Please subscribe for more.") + elif subscription.subscription_class == 'silver' and project_count >= 5: + raise PermissionDenied("You have reached your limit of 5 projects. Please upgrade to Gold for unlimited projects.") + # Gold users have unlimited projects, no limit check needed + except Subscriber.DoesNotExist: + # Default to basic limits if no subscriber record is found + if project_count >= 3: + raise PermissionDenied("You have reached your limit of 3 projects. Please subscribe for more.") + + response = self.get_response(request) + return response + + +class TaskLimitCheckMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == 'POST': + task_creation_path = '/api/tasks' + drf_request = APIView().initialize_request(request) + user = drf_request.user + + if request.path == task_creation_path: + task_count = Task.objects.filter(owner=user).count() + + try: + subscription = Subscriber.objects.get(user=user) + + # Enforce task limits based on subscription class + if subscription.subscription_class == 'basic' and task_count >= 10: + raise PermissionDenied("You have reached your limit of 10 tasks. Please subscribe for more.") + elif subscription.subscription_class == 'silver' and task_count >= 20: + raise PermissionDenied("You have reached your limit of 20 tasks. Please upgrade to Gold for unlimited tasks.") + # Gold users have unlimited tasks, no limit check needed + except Subscriber.DoesNotExist: + # Default to basic limits if no subscriber record is found + if task_count >= 10: + raise PermissionDenied("You have reached your limit of 10 tasks. Please subscribe for more.") + + response = self.get_response(request) + return response + +class ExportJobAnnotationsMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == 'GET': + drf_request = APIView().initialize_request(request) + user = drf_request.user + + + if request.path.startswith('/api/jobs/') and request.path.endswith('/annotations') and 'format' in request.GET: + try: + + + subscription = Subscriber.objects.get(user=user) + + if subscription.subscription_class == 'basic': + raise PermissionDenied("Exporting annotations with images is not available for Basic subscription.") + elif subscription.subscription_class in ['silver', 'gold']: + + pass + + except Job.DoesNotExist: + raise PermissionDenied("Job not found or you do not have permission to access it.") + except Subscriber.DoesNotExist: + raise PermissionDenied("You need a valid subscription to export annotations with images.") + + + response = self.get_response(request) + return response + +class ExportTaskAnnotationsMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == 'GET': + drf_request = APIView().initialize_request(request) + user = drf_request.user + + + if request.path.startswith('/api/tasks/') and request.path.endswith('/annotations') and 'format' in request.GET: + try: + + + subscription = Subscriber.objects.get(user=user) + + if subscription.subscription_class == 'basic': + raise PermissionDenied("Exporting task annotations with audio is not available for Basic subscription.") + elif subscription.subscription_class in ['silver', 'gold']: + + pass + + except Task.DoesNotExist: + raise PermissionDenied("Task not found or you do not have permission to access it.") + except Subscriber.DoesNotExist: + raise PermissionDenied("You need a valid subscription to export task annotations with audio.") + + + response = self.get_response(request) + return response + +class ProjectTaskLimitMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == 'POST' and request.path == '/api/tasks': + drf_request = APIView().initialize_request(request) + user = drf_request.user + + + project_id = drf_request.data.get('project_id') + + if project_id: + try: + + subscription = Subscriber.objects.get(user=user) + + + project = Project.objects.get(id=project_id) + + + task_count = Task.objects.filter(project=project).count() + + + if subscription.subscription_class == 'basic' and task_count >= 5: + raise PermissionDenied("Basic subscription allows up to 5 tasks per project. Please upgrade to add more tasks.") + elif subscription.subscription_class == 'silver' and task_count >= 10: + raise PermissionDenied("Silver subscription allows up to 10 tasks per project. Please upgrade to add more tasks.") + elif subscription.subscription_class == 'gold': + + pass + + except Project.DoesNotExist: + raise PermissionDenied("Project not found.") + except Subscriber.DoesNotExist: + raise PermissionDenied("You need a valid subscription to create tasks in a project.") + + + response = self.get_response(request) + return response \ No newline at end of file diff --git a/cvat/apps/engine/migrations/0082_subscriber.py b/cvat/apps/engine/migrations/0082_subscriber.py new file mode 100644 index 000000000000..33f0274dbd1d --- /dev/null +++ b/cvat/apps/engine/migrations/0082_subscriber.py @@ -0,0 +1,39 @@ +# Generated by Django 4.2.13 on 2024-10-06 05:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("engine", "0081_task_extra_params"), + ] + + operations = [ + migrations.CreateModel( + name="Subscriber", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("subscribed", models.BooleanField(default=False)), + ( + "user", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="subscriber", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + ), + ] \ No newline at end of file diff --git a/cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py b/cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py new file mode 100644 index 000000000000..73f28e18016f --- /dev/null +++ b/cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.13 on 2024-10-10 13:22 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("engine", "0082_subscriber"), + ] + + operations = [ + migrations.RemoveField( + model_name="subscriber", + name="subscribed", + ), + migrations.AddField( + model_name="subscriber", + name="subscription_class", + field=models.CharField( + choices=[("gold", "Gold"), ("silver", "Silver"), ("basic", "Basic")], + default="basic", + max_length=6, + ), + ), + ] \ No newline at end of file diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index d597432fa9f7..fe18ff8b21ed 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -1194,3 +1194,22 @@ def organization_id(self): def get_asset_dir(self): return os.path.join(settings.ASSETS_ROOT, str(self.uuid)) + + +class Subscriber(models.Model): + user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="subscriber") + + SUBSCRIPTION_CHOICES = [ + ('gold', 'Gold'), + ('silver', 'Silver'), + ('basic', 'Basic'), # Added basic subscription class + ] + + subscription_class = models.CharField( + max_length=6, + choices=SUBSCRIPTION_CHOICES, + default='basic' # Default set to basic + ) + + def __str__(self): + return f"{self.user.username} - Subscription Class: {self.subscription_class}" diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 3e4d610915bd..9f91ddef10ce 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -197,6 +197,11 @@ def generate_secret_key(): 'dj_pagination.middleware.PaginationMiddleware', 'cvat.apps.iam.middleware.ContextMiddleware', 'allauth.account.middleware.AccountMiddleware', + 'cvat.apps.engine.middleware.ProjectLimitCheckMiddleware', + 'cvat.apps.engine.middleware.TaskLimitCheckMiddleware', + 'cvat.apps.engine.middleware.ExportJobAnnotationsMiddleware', + 'cvat.apps.engine.middleware.ExportTaskAnnotationsMiddleware', + 'cvat.apps.engine.middleware.ProjectTaskLimitMiddleware', ] UI_URL = os.getenv('UI_URL', '') From 8b69bd043218cf624dbac59055e8cd55d38b8e03 Mon Sep 17 00:00:00 2001 From: siesto1elemento Date: Mon, 14 Oct 2024 17:59:58 +0530 Subject: [PATCH 2/2] formatting changes --- cvat/apps/engine/middleware.py | 147 ++++++++++++++++++++------------- 1 file changed, 91 insertions(+), 56 deletions(-) diff --git a/cvat/apps/engine/middleware.py b/cvat/apps/engine/middleware.py index 8c5dffe0d056..5ab87950f370 100644 --- a/cvat/apps/engine/middleware.py +++ b/cvat/apps/engine/middleware.py @@ -3,10 +3,11 @@ # SPDX-License-Identifier: MIT from uuid import uuid4 -from .models import Subscriber, Project,Task,Job +from .models import Subscriber, Project, Task, Job from django.core.exceptions import PermissionDenied from rest_framework.views import APIView + class RequestTrackingMiddleware: def __init__(self, get_response): self.get_response = get_response @@ -18,7 +19,7 @@ def _generate_id(): def __call__(self, request): request.uuid = self._generate_id() response = self.get_response(request) - response.headers['X-Request-Id'] = request.uuid + response.headers["X-Request-Id"] = request.uuid return response @@ -28,8 +29,8 @@ def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - if request.method == 'POST': - project_creation_path = '/api/projects' + if request.method == "POST": + project_creation_path = "/api/projects" drf_request = APIView().initialize_request(request) user = drf_request.user @@ -39,16 +40,25 @@ def __call__(self, request): try: subscription = Subscriber.objects.get(user=user) - # Enforce project limits based on subscription class - if subscription.subscription_class == 'basic' and project_count >= 3: - raise PermissionDenied("You have reached your limit of 3 projects. Please subscribe for more.") - elif subscription.subscription_class == 'silver' and project_count >= 5: - raise PermissionDenied("You have reached your limit of 5 projects. Please upgrade to Gold for unlimited projects.") - # Gold users have unlimited projects, no limit check needed + if ( + subscription.subscription_class == "basic" + and project_count >= 3 + ): + raise PermissionDenied( + "You have reached your limit of 3 projects. Please subscribe for more." + ) + elif ( + subscription.subscription_class == "silver" + and project_count >= 5 + ): + raise PermissionDenied( + "You have reached your limit of 5 projects. Please upgrade to Gold for unlimited projects." + ) except Subscriber.DoesNotExist: - # Default to basic limits if no subscriber record is found if project_count >= 3: - raise PermissionDenied("You have reached your limit of 3 projects. Please subscribe for more.") + raise PermissionDenied( + "You have reached your limit of 3 projects. Please subscribe for more." + ) response = self.get_response(request) return response @@ -59,8 +69,8 @@ def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - if request.method == 'POST': - task_creation_path = '/api/tasks' + if request.method == "POST": + task_creation_path = "/api/tasks" drf_request = APIView().initialize_request(request) user = drf_request.user @@ -70,119 +80,144 @@ def __call__(self, request): try: subscription = Subscriber.objects.get(user=user) - # Enforce task limits based on subscription class - if subscription.subscription_class == 'basic' and task_count >= 10: - raise PermissionDenied("You have reached your limit of 10 tasks. Please subscribe for more.") - elif subscription.subscription_class == 'silver' and task_count >= 20: - raise PermissionDenied("You have reached your limit of 20 tasks. Please upgrade to Gold for unlimited tasks.") - # Gold users have unlimited tasks, no limit check needed + if subscription.subscription_class == "basic" and task_count >= 10: + raise PermissionDenied( + "You have reached your limit of 10 tasks. Please subscribe for more." + ) + elif ( + subscription.subscription_class == "silver" and task_count >= 20 + ): + raise PermissionDenied( + "You have reached your limit of 20 tasks. Please upgrade to Gold for unlimited tasks." + ) except Subscriber.DoesNotExist: - # Default to basic limits if no subscriber record is found if task_count >= 10: - raise PermissionDenied("You have reached your limit of 10 tasks. Please subscribe for more.") + raise PermissionDenied( + "You have reached your limit of 10 tasks. Please subscribe for more." + ) response = self.get_response(request) return response + class ExportJobAnnotationsMiddleware: def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - if request.method == 'GET': + if request.method == "GET": drf_request = APIView().initialize_request(request) user = drf_request.user - - if request.path.startswith('/api/jobs/') and request.path.endswith('/annotations') and 'format' in request.GET: + if ( + request.path.startswith("/api/jobs/") + and request.path.endswith("/annotations") + and "format" in request.GET + ): try: - subscription = Subscriber.objects.get(user=user) - if subscription.subscription_class == 'basic': - raise PermissionDenied("Exporting annotations with images is not available for Basic subscription.") - elif subscription.subscription_class in ['silver', 'gold']: + if subscription.subscription_class == "basic": + raise PermissionDenied( + "Exporting annotations with videos is not available for Basic subscription." + ) + elif subscription.subscription_class in ["silver", "gold"]: pass except Job.DoesNotExist: - raise PermissionDenied("Job not found or you do not have permission to access it.") + raise PermissionDenied( + "Job not found or you do not have permission to access it." + ) except Subscriber.DoesNotExist: - raise PermissionDenied("You need a valid subscription to export annotations with images.") - + raise PermissionDenied( + "You need a valid subscription to export annotations with videos." + ) response = self.get_response(request) return response + class ExportTaskAnnotationsMiddleware: def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - if request.method == 'GET': + if request.method == "GET": drf_request = APIView().initialize_request(request) user = drf_request.user - - if request.path.startswith('/api/tasks/') and request.path.endswith('/annotations') and 'format' in request.GET: + if ( + request.path.startswith("/api/tasks/") + and request.path.endswith("/annotations") + and "format" in request.GET + ): try: - subscription = Subscriber.objects.get(user=user) - if subscription.subscription_class == 'basic': - raise PermissionDenied("Exporting task annotations with audio is not available for Basic subscription.") - elif subscription.subscription_class in ['silver', 'gold']: + if subscription.subscription_class == "basic": + raise PermissionDenied( + "Exporting task annotations with audio is not available for Basic subscription." + ) + elif subscription.subscription_class in ["silver", "gold"]: pass except Task.DoesNotExist: - raise PermissionDenied("Task not found or you do not have permission to access it.") + raise PermissionDenied( + "Task not found or you do not have permission to access it." + ) except Subscriber.DoesNotExist: - raise PermissionDenied("You need a valid subscription to export task annotations with audio.") - + raise PermissionDenied( + "You need a valid subscription to export task annotations with audio." + ) response = self.get_response(request) return response + class ProjectTaskLimitMiddleware: def __init__(self, get_response): self.get_response = get_response def __call__(self, request): - if request.method == 'POST' and request.path == '/api/tasks': + if request.method == "POST" and request.path == "/api/tasks": drf_request = APIView().initialize_request(request) user = drf_request.user - - project_id = drf_request.data.get('project_id') + project_id = drf_request.data.get("project_id") if project_id: try: subscription = Subscriber.objects.get(user=user) - project = Project.objects.get(id=project_id) - task_count = Task.objects.filter(project=project).count() - - if subscription.subscription_class == 'basic' and task_count >= 5: - raise PermissionDenied("Basic subscription allows up to 5 tasks per project. Please upgrade to add more tasks.") - elif subscription.subscription_class == 'silver' and task_count >= 10: - raise PermissionDenied("Silver subscription allows up to 10 tasks per project. Please upgrade to add more tasks.") - elif subscription.subscription_class == 'gold': + if subscription.subscription_class == "basic" and task_count >= 5: + raise PermissionDenied( + "Basic subscription allows up to 5 tasks per project. Please upgrade to add more tasks." + ) + elif ( + subscription.subscription_class == "silver" and task_count >= 10 + ): + raise PermissionDenied( + "Silver subscription allows up to 10 tasks per project. Please upgrade to add more tasks." + ) + elif subscription.subscription_class == "gold": pass except Project.DoesNotExist: raise PermissionDenied("Project not found.") except Subscriber.DoesNotExist: - raise PermissionDenied("You need a valid subscription to create tasks in a project.") - + raise PermissionDenied( + "You need a valid subscription to create tasks in a project." + ) response = self.get_response(request) - return response \ No newline at end of file + return response