diff --git a/api/src/deploy/routes.py b/api/src/deploy/routes.py index a008c3f..a96579c 100644 --- a/api/src/deploy/routes.py +++ b/api/src/deploy/routes.py @@ -11,12 +11,13 @@ from sqlalchemy import select from typing_extensions import Annotated +from src import settings from src.build import build_and_publish_image_to_ecr, unzip_file, write_to_zipfile from src.constants import API_VERSION from src.core.models import DeployConfig, ServiceConfig from src.db import get_db from src.deploy import create_ecr_repository, deploy_python_lambda_function_from_ecr -from src.middleware import get_user +from src.middleware import get_total_services_deployed_for_user, get_user from src.models import Deployment, Service, User if TYPE_CHECKING: @@ -81,8 +82,15 @@ async def deploy( file: Annotated[UploadFile, File()], json_data: Annotated[str, Form()], user: User = Depends(get_user), + service_count: int = Depends(get_total_services_deployed_for_user), db: AsyncSession = Depends(get_db), ): + if service_count >= settings.MAX_SERVICES_PER_USER: + raise HTTPException( + status_code=403, + detail="User has reached the maximum number of services.", + ) + try: deploy_config = DeployConfig(**json.loads(json_data)) # TODO: more careful handling here diff --git a/api/src/middleware.py b/api/src/middleware.py index 9ed2978..5b7ab48 100644 --- a/api/src/middleware.py +++ b/api/src/middleware.py @@ -4,12 +4,12 @@ from fastapi import Depends from fastapi.responses import JSONResponse -from sqlalchemy import select +from sqlalchemy import func, select from sqlalchemy.exc import MultipleResultsFound, NoResultFound from src import settings from src.db import get_db -from src.models import User +from src.models import Deployment, Service, User if TYPE_CHECKING: from fastapi import FastAPI, Request @@ -33,6 +33,22 @@ async def get_user(request: Request, db: AsyncSession = Depends(get_db)) -> User return user +async def get_total_services_deployed_for_user( + user: User, db: AsyncSession = Depends(get_db) +) -> int: + async with db as session: + query = ( + func.count(Service.id) + .select_from(Service) + .join(Service.deployment) + .join(Deployment.user) + .where(User.id == user.id) + ) + + result = await session.execute(query) + return result.scalar_one() + + async def get_deploy_version(request: Request) -> str | None: try: return request.state.deploy_version diff --git a/api/src/settings.py b/api/src/settings.py index 53900fc..98581ed 100644 --- a/api/src/settings.py +++ b/api/src/settings.py @@ -39,3 +39,6 @@ "PARE_ATOMIC_DEPLOYMENT_HEADER", "X-Pare-Atomic-Deployment" ) PARE_API_KEY_HEADER: str = env.str("PARE_API_KEY_HEADER", "X-Pare-API-Key") + + +MAX_SERVICES_PER_USER: int = env.int("MAX_SERVICES_PER_USER", 50)