diff --git a/.env.dev b/.env.dev index 7505d9d..7c8dca0 100644 --- a/.env.dev +++ b/.env.dev @@ -3,4 +3,6 @@ TRAEFIK_DOMAIN="whisperbox-transcribe.localhost" WHISPER_MODEL="tiny" ENVIRONMENT="development" DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite" -BROKER_URL="redis://redis:6379/0" + +RABBITMQ_DEFAULT_USER="rabbitmq" +RABBITMQ_DEFAULT_PASS="rabbitmq_password" diff --git a/.env.example b/.env.example index d0f0e0f..ebef8d3 100644 --- a/.env.example +++ b/.env.example @@ -16,6 +16,8 @@ TRAEFIK_SSLEMAIL="" # --- # below settings match the default docker-compose configuration. -BROKER_URL="redis://redis:6379/0" +RABBITMQ_DEFAULT_USER="rabbitmq" +RABBITMQ_DEFAULT_PASS="rabbitmq_password" + DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite" ENVIRONMENT="production" diff --git a/README.md b/README.md index cf3b60c..08baa55 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ Builds and starts the docker containers. ``` # Bindings http://localhost:5555 => Celery dashboard +http://localhost:15672 => RabbitMQ dashboard http://whisperbox-transcribe.localhost => API http://whisperbox-transcribe.localhost/docs => API docs ./whisperbox-transcribe.sqlite => Database diff --git a/app/shared/celery.py b/app/shared/celery.py index 8bd334a..71c1342 100644 --- a/app/shared/celery.py +++ b/app/shared/celery.py @@ -9,4 +9,5 @@ def get_celery_binding() -> Celery: broker_connection_retry=False, broker_connection_retry_on_startup=False, ) + return celery diff --git a/app/shared/db/models.py b/app/shared/db/models.py index e7c3375..b63bcb8 100644 --- a/app/shared/db/models.py +++ b/app/shared/db/models.py @@ -52,6 +52,11 @@ class JobConfig(BaseModel): class JobMeta(BaseModel): """(JSON) Metadata relating to a job's execution.""" + attempts: int | None = Field( + default=None, + description="Number of processing attempts a job has taken.", + ) + error: str | None = Field( default=None, description="Will contain a descriptive error message if processing failed.", diff --git a/app/web/main.py b/app/web/main.py index baa7842..5738b9b 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,4 +1,3 @@ -from contextlib import asynccontextmanager from typing import Annotated, Callable, Generator from uuid import UUID @@ -8,7 +7,6 @@ import app.shared.db.models as models import app.web.dtos as dtos -from app.shared.db.base import SessionLocal from app.shared.settings import settings from app.web.security import authenticate_api_key from app.web.task_queue import TaskQueue @@ -21,17 +19,10 @@ def app_factory( task_queue = TaskQueue() - @asynccontextmanager - async def lifespan(_: FastAPI): - with SessionLocal() as session: - task_queue.rehydrate(session) - yield - app = FastAPI( description=( "whisperbox-transcribe is an async HTTP wrapper for openai/whisper." ), - lifespan=lifespan, title="whisperbox-transcribe", ) diff --git a/app/web/task_queue.py b/app/web/task_queue.py index 589d465..1d630ab 100644 --- a/app/web/task_queue.py +++ b/app/web/task_queue.py @@ -1,8 +1,4 @@ -from asyncio.log import logger - from celery import Celery -from sqlalchemy import or_ -from sqlalchemy.orm import Session import app.shared.db.models as models from app.shared.celery import get_celery_binding @@ -22,25 +18,3 @@ def queue_task(self, job: models.Job): transcribe = self.celery.signature("app.worker.main.transcribe") # TODO: catch delivery errors? transcribe.delay(job.id) - - def rehydrate(self, session: Session): - # TODO: we could use `acks_late` to handle this scenario within celery itself. - # the reason this does not work well in our case is that `visibility_timeout` - # needs to be very high since whisper workers can be long running. - # doing this app-side bears the risk of poison pilling the worker though, - # implement a workaround with an acceptable trade-off. (=> retry only once?) - jobs = ( - session.query(models.Job) - .filter( - or_( - models.Job.status == models.JobStatus.processing, - models.Job.status == models.JobStatus.create, - ) - ) - .order_by(models.Job.created_at) - ).all() - - logger.info(f"Requeueing {len(jobs)} jobs.") - - for job in jobs: - self.queue_task(job) diff --git a/app/worker/main.py b/app/worker/main.py index 7957971..0e98169 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -39,6 +39,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: bind=True, soft_time_limit=settings.TASK_SOFT_TIME_LIMIT, time_limit=settings.TASK_HARD_TIME_LIMIT, + task_acks_late=True, + task_acks_on_failure_or_timeout=True, + task_reject_on_worker_lost=True, ) def transcribe(self: Task, job_id: UUID) -> None: try: @@ -59,9 +62,20 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.debug(f"[{job.id}]: start processing {job.type} job.") + if job.meta: + attempts = 1 + (job.meta.get("attempts") or 0) + else: + attempts = 1 + + # SAFEGUARD: celery's retry policies do not handle lost workers, retry once. + # @see https://github.com/celery/celery/pull/6103 + if attempts > 2: + raise Exception("Maximum number of retries exceeded for killed worker.") + # unit of work: set task status to processing. - job.meta = {"task_id": self.request.id} + job.meta = {"task_id": self.request.id, "attempts": attempts} + job.status = models.JobStatus.processing db.commit() @@ -83,7 +97,11 @@ def transcribe(self: Task, job_id: UUID) -> None: if job and db: if db.in_transaction(): db.rollback() - job.meta = {**job.meta, "error": str(e)} # type: ignore + if job.meta: + job.meta = {**job.meta, "error": str(e)} # type: ignore + else: + job.meta = {"error": str(e)} + job.status = models.JobStatus.error db.commit() raise diff --git a/conf/rabbitmq.conf b/conf/rabbitmq.conf new file mode 100644 index 0000000..fad073a --- /dev/null +++ b/conf/rabbitmq.conf @@ -0,0 +1 @@ +vm_memory_high_watermark.absolute = 192MB diff --git a/docker-compose.base.yml b/docker-compose.base.yml index d4dfe81..2148512 100644 --- a/docker-compose.base.yml +++ b/docker-compose.base.yml @@ -1,3 +1,6 @@ +x-broker-environment: &broker-environment + BROKER_URL: "amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672" + version: "3.8" name: whisperbox-transcribe @@ -12,46 +15,59 @@ services: networks: - traefik - redis: - image: redis:7-alpine + rabbitmq: + env_file: .env + image: rabbitmq:3-alpine networks: - app deploy: resources: limits: - memory: 128M + memory: 256M + healthcheck: + test: rabbitmq-diagnostics check_port_connectivity + interval: 3s + timeout: 3s + retries: 10 + + volumes: + - ./conf/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf + - rabbitmq-data:/var/lib/rabbitmq/mnesia/ worker: env_file: .env + environment: + <<: *broker-environment build: context: . dockerfile: worker.Dockerfile args: WHISPER_MODEL: ${WHISPER_MODEL} + depends_on: + rabbitmq: + condition: service_healthy networks: - app - depends_on: - - redis - healthcheck: - test: ["CMD-SHELL", "celery -b ${BROKER_URL} inspect ping -d celery@$$HOSTNAME"] - interval: 5s - timeout: 5s - retries: 5 web: env_file: .env + environment: + <<: *broker-environment build: context: . dockerfile: web.Dockerfile + depends_on: + rabbitmq: + condition: service_healthy networks: - app - traefik - depends_on: - worker: - condition: service_healthy networks: app: driver: bridge traefik: driver: bridge + +volumes: + rabbitmq-data: diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml index e6bc559..41d6c28 100644 --- a/docker-compose.dev.yml +++ b/docker-compose.dev.yml @@ -13,6 +13,8 @@ services: web: command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info" + # NOTE: the docker on mac mount adapter (virtioFS) does not support flock. + # this can cause the sqlite database to corrupt when written from worker <> api simultaneously. volumes: - ./:/etc/whisperbox-transcribe/ labels: @@ -26,13 +28,18 @@ services: volumes: - ./:/etc/whisperbox-transcribe/ + rabbitmq: + image: rabbitmq:3-management-alpine + ports: + - 15672:15672 + flower: image: mher/flower - command: celery --broker redis://redis:6379/0 flower --port=5555 + command: celery --broker amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672 flower --port=5555 ports: - 5555:5555 depends_on: - worker: - condition: service_healthy + - worker + - rabbitmq networks: - app diff --git a/pyproject.toml b/pyproject.toml index 3d82800..a2c4233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ description = "" version = "1.0.0" dependencies=[ - "celery[redis] ==5.3.1", + "celery ==5.3.1", "sqlalchemy[mypy] ==2.0.20", "pydantic ==2.1.1", "pydantic-settings ==2.0.3"