diff --git a/cvat/apps/engine/admin.py b/cvat/apps/engine/admin.py index 05e4b40a0f9b..a9cf0d6e83d1 100644 --- a/cvat/apps/engine/admin.py +++ b/cvat/apps/engine/admin.py @@ -5,7 +5,7 @@ from django.contrib import admin from .models import Task, Segment, Job, Label, AttributeSpec, Project, \ - CloudStorage, Storage, Data, AnnotationGuide, Asset + CloudStorage, Storage, Data, AnnotationGuide, Asset, Subscriber class JobInline(admin.TabularInline): model = Job @@ -187,6 +187,11 @@ def has_add_permission(self, request): AssetInline ] +class SubscriberAdmin(admin.ModelAdmin): + list_display = ('user', 'subscription_class') + list_filter = ('subscription_class',) + + admin.site.register(Task, TaskAdmin) admin.site.register(Segment, SegmentAdmin) admin.site.register(Label, LabelAdmin) @@ -195,3 +200,4 @@ def has_add_permission(self, request): admin.site.register(CloudStorage, CloudStorageAdmin) admin.site.register(Data, DataAdmin) admin.site.register(AnnotationGuide, AnnotationGuideAdmin) +admin.site.register(Subscriber, SubscriberAdmin) diff --git a/cvat/apps/engine/middleware.py b/cvat/apps/engine/middleware.py index f2b990a14b50..5ab87950f370 100644 --- a/cvat/apps/engine/middleware.py +++ b/cvat/apps/engine/middleware.py @@ -3,6 +3,10 @@ # SPDX-License-Identifier: MIT from uuid import uuid4 +from .models import Subscriber, Project, Task, Job +from django.core.exceptions import PermissionDenied +from rest_framework.views import APIView + class RequestTrackingMiddleware: def __init__(self, get_response): @@ -15,6 +19,205 @@ def _generate_id(): def __call__(self, request): request.uuid = self._generate_id() response = self.get_response(request) - response.headers['X-Request-Id'] = request.uuid + response.headers["X-Request-Id"] = request.uuid + + return response + + +class ProjectLimitCheckMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == "POST": + project_creation_path = "/api/projects" + drf_request = APIView().initialize_request(request) + user = drf_request.user + + if request.path == project_creation_path: + project_count = Project.objects.filter(owner=user).count() + + try: + subscription = Subscriber.objects.get(user=user) + + if ( + subscription.subscription_class == "basic" + and project_count >= 3 + ): + raise PermissionDenied( + "You have reached your limit of 3 projects. Please subscribe for more." + ) + elif ( + subscription.subscription_class == "silver" + and project_count >= 5 + ): + raise PermissionDenied( + "You have reached your limit of 5 projects. Please upgrade to Gold for unlimited projects." + ) + except Subscriber.DoesNotExist: + if project_count >= 3: + raise PermissionDenied( + "You have reached your limit of 3 projects. Please subscribe for more." + ) + + response = self.get_response(request) + return response + + +class TaskLimitCheckMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == "POST": + task_creation_path = "/api/tasks" + drf_request = APIView().initialize_request(request) + user = drf_request.user + + if request.path == task_creation_path: + task_count = Task.objects.filter(owner=user).count() + + try: + subscription = Subscriber.objects.get(user=user) + + if subscription.subscription_class == "basic" and task_count >= 10: + raise PermissionDenied( + "You have reached your limit of 10 tasks. Please subscribe for more." + ) + elif ( + subscription.subscription_class == "silver" and task_count >= 20 + ): + raise PermissionDenied( + "You have reached your limit of 20 tasks. Please upgrade to Gold for unlimited tasks." + ) + except Subscriber.DoesNotExist: + if task_count >= 10: + raise PermissionDenied( + "You have reached your limit of 10 tasks. Please subscribe for more." + ) + + response = self.get_response(request) + return response + + +class ExportJobAnnotationsMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == "GET": + drf_request = APIView().initialize_request(request) + user = drf_request.user + + if ( + request.path.startswith("/api/jobs/") + and request.path.endswith("/annotations") + and "format" in request.GET + ): + try: + + subscription = Subscriber.objects.get(user=user) + + if subscription.subscription_class == "basic": + raise PermissionDenied( + "Exporting annotations with videos is not available for Basic subscription." + ) + elif subscription.subscription_class in ["silver", "gold"]: + + pass + + except Job.DoesNotExist: + raise PermissionDenied( + "Job not found or you do not have permission to access it." + ) + except Subscriber.DoesNotExist: + raise PermissionDenied( + "You need a valid subscription to export annotations with videos." + ) + + response = self.get_response(request) + return response + + +class ExportTaskAnnotationsMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == "GET": + drf_request = APIView().initialize_request(request) + user = drf_request.user + + if ( + request.path.startswith("/api/tasks/") + and request.path.endswith("/annotations") + and "format" in request.GET + ): + try: + + subscription = Subscriber.objects.get(user=user) + if subscription.subscription_class == "basic": + raise PermissionDenied( + "Exporting task annotations with audio is not available for Basic subscription." + ) + elif subscription.subscription_class in ["silver", "gold"]: + + pass + + except Task.DoesNotExist: + raise PermissionDenied( + "Task not found or you do not have permission to access it." + ) + except Subscriber.DoesNotExist: + raise PermissionDenied( + "You need a valid subscription to export task annotations with audio." + ) + + response = self.get_response(request) + return response + + +class ProjectTaskLimitMiddleware: + def __init__(self, get_response): + self.get_response = get_response + + def __call__(self, request): + if request.method == "POST" and request.path == "/api/tasks": + drf_request = APIView().initialize_request(request) + user = drf_request.user + + project_id = drf_request.data.get("project_id") + + if project_id: + try: + + subscription = Subscriber.objects.get(user=user) + + project = Project.objects.get(id=project_id) + + task_count = Task.objects.filter(project=project).count() + + if subscription.subscription_class == "basic" and task_count >= 5: + raise PermissionDenied( + "Basic subscription allows up to 5 tasks per project. Please upgrade to add more tasks." + ) + elif ( + subscription.subscription_class == "silver" and task_count >= 10 + ): + raise PermissionDenied( + "Silver subscription allows up to 10 tasks per project. Please upgrade to add more tasks." + ) + elif subscription.subscription_class == "gold": + + pass + + except Project.DoesNotExist: + raise PermissionDenied("Project not found.") + except Subscriber.DoesNotExist: + raise PermissionDenied( + "You need a valid subscription to create tasks in a project." + ) + + response = self.get_response(request) return response diff --git a/cvat/apps/engine/migrations/0082_subscriber.py b/cvat/apps/engine/migrations/0082_subscriber.py new file mode 100644 index 000000000000..33f0274dbd1d --- /dev/null +++ b/cvat/apps/engine/migrations/0082_subscriber.py @@ -0,0 +1,39 @@ +# Generated by Django 4.2.13 on 2024-10-06 05:21 + +from django.conf import settings +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ("engine", "0081_task_extra_params"), + ] + + operations = [ + migrations.CreateModel( + name="Subscriber", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("subscribed", models.BooleanField(default=False)), + ( + "user", + models.OneToOneField( + on_delete=django.db.models.deletion.CASCADE, + related_name="subscriber", + to=settings.AUTH_USER_MODEL, + ), + ), + ], + ), + ] \ No newline at end of file diff --git a/cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py b/cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py new file mode 100644 index 000000000000..73f28e18016f --- /dev/null +++ b/cvat/apps/engine/migrations/0083_remove_subscriber_subscribed_and_more.py @@ -0,0 +1,26 @@ +# Generated by Django 4.2.13 on 2024-10-10 13:22 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("engine", "0082_subscriber"), + ] + + operations = [ + migrations.RemoveField( + model_name="subscriber", + name="subscribed", + ), + migrations.AddField( + model_name="subscriber", + name="subscription_class", + field=models.CharField( + choices=[("gold", "Gold"), ("silver", "Silver"), ("basic", "Basic")], + default="basic", + max_length=6, + ), + ), + ] \ No newline at end of file diff --git a/cvat/apps/engine/models.py b/cvat/apps/engine/models.py index d597432fa9f7..fe18ff8b21ed 100644 --- a/cvat/apps/engine/models.py +++ b/cvat/apps/engine/models.py @@ -1194,3 +1194,22 @@ def organization_id(self): def get_asset_dir(self): return os.path.join(settings.ASSETS_ROOT, str(self.uuid)) + + +class Subscriber(models.Model): + user = models.OneToOneField(User, on_delete=models.CASCADE, related_name="subscriber") + + SUBSCRIPTION_CHOICES = [ + ('gold', 'Gold'), + ('silver', 'Silver'), + ('basic', 'Basic'), # Added basic subscription class + ] + + subscription_class = models.CharField( + max_length=6, + choices=SUBSCRIPTION_CHOICES, + default='basic' # Default set to basic + ) + + def __str__(self): + return f"{self.user.username} - Subscription Class: {self.subscription_class}" diff --git a/cvat/settings/base.py b/cvat/settings/base.py index 3e4d610915bd..9f91ddef10ce 100644 --- a/cvat/settings/base.py +++ b/cvat/settings/base.py @@ -197,6 +197,11 @@ def generate_secret_key(): 'dj_pagination.middleware.PaginationMiddleware', 'cvat.apps.iam.middleware.ContextMiddleware', 'allauth.account.middleware.AccountMiddleware', + 'cvat.apps.engine.middleware.ProjectLimitCheckMiddleware', + 'cvat.apps.engine.middleware.TaskLimitCheckMiddleware', + 'cvat.apps.engine.middleware.ExportJobAnnotationsMiddleware', + 'cvat.apps.engine.middleware.ExportTaskAnnotationsMiddleware', + 'cvat.apps.engine.middleware.ProjectTaskLimitMiddleware', ] UI_URL = os.getenv('UI_URL', '')