feat: configure celery to use rabbitmq broker (#58)
fspoettel authored Aug 17, 2023
1 parent 423018e commit 504975a
Showing 12 changed files with 74 additions and 56 deletions.
4 changes: 3 additions & 1 deletion .env.dev
@@ -3,4 +3,6 @@ TRAEFIK_DOMAIN="whisperbox-transcribe.localhost"
WHISPER_MODEL="tiny"
ENVIRONMENT="development"
DATABASE_URI="sqlite:///./whisperbox-transcribe.sqlite"
BROKER_URL="redis://redis:6379/0"

RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"
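
The Redis `BROKER_URL` is dropped from the env files; the AMQP URL is now assembled from these RabbitMQ credentials in `docker-compose.base.yml` (see below). A minimal sketch of that composition in Python, assuming the same variable names and the `rabbitmq` service hostname:

```python
import os

# Mirror the compose file's x-broker-environment anchor:
# amqp://<user>:<password>@rabbitmq:5672
user = os.environ.get("RABBITMQ_DEFAULT_USER", "rabbitmq")
password = os.environ.get("RABBITMQ_DEFAULT_PASS", "rabbitmq_password")
broker_url = f"amqp://{user}:{password}@rabbitmq:5672"
print(broker_url)
```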
4 changes: 3 additions & 1 deletion .env.example
@@ -16,6 +16,8 @@ TRAEFIK_SSLEMAIL=""
# ---
# below settings match the default docker-compose configuration.

BROKER_URL="redis://redis:6379/0"
RABBITMQ_DEFAULT_USER="rabbitmq"
RABBITMQ_DEFAULT_PASS="rabbitmq_password"

DATABASE_URI="sqlite:////etc/whisperbox-transcribe/data/whisperbox-transcribe.sqlite"
ENVIRONMENT="production"
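
On the application side, a hedged sketch of how `BROKER_URL` might be read via pydantic-settings (the project pins `pydantic-settings ==2.0.3`); the real `Settings` class in `app/shared/settings.py` is not part of this diff, so everything beyond the `BROKER_URL` name is an assumption:

```python
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    # Populated from the environment; in the compose setup BROKER_URL is
    # injected by the x-broker-environment anchor rather than the .env file.
    BROKER_URL: str = "amqp://rabbitmq:rabbitmq_password@rabbitmq:5672"
    ENVIRONMENT: str = "production"


settings = Settings()
print(settings.BROKER_URL)
```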
1 change: 1 addition & 0 deletions README.md
@@ -57,6 +57,7 @@ Builds and starts the docker containers.
```
# Bindings
http://localhost:5555 => Celery dashboard
http://localhost:15672 => RabbitMQ dashboard
http://whisperbox-transcribe.localhost => API
http://whisperbox-transcribe.localhost/docs => API docs
./whisperbox-transcribe.sqlite => Database
1 change: 1 addition & 0 deletions app/shared/celery.py
@@ -9,4 +9,5 @@ def get_celery_binding() -> Celery:
broker_connection_retry=False,
broker_connection_retry_on_startup=False,
)

return celery
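
Only the tail of `get_celery_binding` is visible in this hunk; a sketch of what the full binding plausibly looks like after the change (the `settings` import and the `broker=` wiring are assumptions based on the env files above):

```python
from celery import Celery

from app.shared.settings import settings


def get_celery_binding() -> Celery:
    # Fail fast instead of retrying forever; compose only starts this
    # container once the rabbitmq healthcheck has passed.
    celery = Celery(
        broker=settings.BROKER_URL,
        broker_connection_retry=False,
        broker_connection_retry_on_startup=False,
    )

    return celery
```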
5 changes: 5 additions & 0 deletions app/shared/db/models.py
@@ -52,6 +52,11 @@ class JobConfig(BaseModel):
class JobMeta(BaseModel):
"""(JSON) Metadata relating to a job's execution."""

attempts: int | None = Field(
default=None,
description="Number of processing attempts a job has taken.",
)

error: str | None = Field(
default=None,
description="Will contain a descriptive error message if processing failed.",
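
The new `attempts` field is what the worker uses to count (re)deliveries (see `app/worker/main.py` below). A small usage sketch, assuming `JobMeta` is an ordinary pydantic model as shown in the hunk; the `task_id` field is an assumption based on what the worker writes into `job.meta`:

```python
from pydantic import BaseModel, Field


class JobMeta(BaseModel):
    """(JSON) Metadata relating to a job's execution."""

    attempts: int | None = Field(default=None)
    error: str | None = Field(default=None)
    task_id: str | None = Field(default=None)  # assumed, not shown in this hunk


# First delivery stores attempts=1; a redelivery after a worker crash bumps it to 2.
meta = JobMeta(attempts=1, task_id="task-1")
print(meta.model_dump())
```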
9 changes: 0 additions & 9 deletions app/web/main.py
@@ -1,4 +1,3 @@
from contextlib import asynccontextmanager
from typing import Annotated, Callable, Generator
from uuid import UUID

@@ -8,7 +7,6 @@

import app.shared.db.models as models
import app.web.dtos as dtos
from app.shared.db.base import SessionLocal
from app.shared.settings import settings
from app.web.security import authenticate_api_key
from app.web.task_queue import TaskQueue
@@ -21,17 +19,10 @@ def app_factory(

task_queue = TaskQueue()

@asynccontextmanager
async def lifespan(_: FastAPI):
with SessionLocal() as session:
task_queue.rehydrate(session)
yield

app = FastAPI(
description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
),
lifespan=lifespan,
title="whisperbox-transcribe",
)

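
Startup rehydration is gone because crash recovery now lives at the broker (acks_late on the worker, below). Roughly, the factory reduces to the following; the real `app_factory` takes dependencies and registers routes that are omitted here, so treat this purely as a sketch:

```python
from fastapi import FastAPI

from app.web.task_queue import TaskQueue


def app_factory() -> FastAPI:
    task_queue = TaskQueue()

    # No lifespan hook: unacked messages are redelivered by RabbitMQ instead
    # of being re-queued from the database at startup.
    app = FastAPI(
        description=(
            "whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
        ),
        title="whisperbox-transcribe",
    )

    # ...route definitions using task_queue omitted...
    return app
```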
26 changes: 0 additions & 26 deletions app/web/task_queue.py
@@ -1,8 +1,4 @@
from asyncio.log import logger

from celery import Celery
from sqlalchemy import or_
from sqlalchemy.orm import Session

import app.shared.db.models as models
from app.shared.celery import get_celery_binding
@@ -22,25 +18,3 @@ def queue_task(self, job: models.Job):
transcribe = self.celery.signature("app.worker.main.transcribe")
# TODO: catch delivery errors?
transcribe.delay(job.id)

def rehydrate(self, session: Session):
# TODO: we could use `acks_late` to handle this scenario within celery itself.
# the reason this does not work well in our case is that `visibility_timeout`
# needs to be very high since whisper workers can be long running.
# doing this app-side bears the risk of poison pilling the worker though,
# implement a workaround with an acceptable trade-off. (=> retry only once?)
jobs = (
session.query(models.Job)
.filter(
or_(
models.Job.status == models.JobStatus.processing,
models.Job.status == models.JobStatus.create,
)
)
.order_by(models.Job.created_at)
).all()

logger.info(f"Requeueing {len(jobs)} jobs.")

for job in jobs:
self.queue_task(job)
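
What remains of `TaskQueue` is just enqueueing by task name; roughly as follows (the constructor is not shown in this diff, so it is assumed here):

```python
from celery import Celery

import app.shared.db.models as models
from app.shared.celery import get_celery_binding


class TaskQueue:
    def __init__(self) -> None:
        # assumed constructor: bind a Celery client for publishing only
        self.celery: Celery = get_celery_binding()

    def queue_task(self, job: models.Job) -> None:
        # Send by name so the web container never imports worker code.
        transcribe = self.celery.signature("app.worker.main.transcribe")
        # TODO: catch delivery errors?
        transcribe.delay(job.id)
```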
22 changes: 20 additions & 2 deletions app/worker/main.py
@@ -39,6 +39,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
bind=True,
soft_time_limit=settings.TASK_SOFT_TIME_LIMIT,
time_limit=settings.TASK_HARD_TIME_LIMIT,
task_acks_late=True,
task_acks_on_failure_or_timeout=True,
task_reject_on_worker_lost=True,
)
def transcribe(self: Task, job_id: UUID) -> None:
try:
@@ -59,9 +62,20 @@ def transcribe(self: Task, job_id: UUID) -> None:

logger.debug(f"[{job.id}]: start processing {job.type} job.")

if job.meta:
attempts = 1 + (job.meta.get("attempts") or 0)
else:
attempts = 1

# SAFEGUARD: celery's retry policies do not handle lost workers, retry once.
# @see https://github.com/celery/celery/pull/6103
if attempts > 2:
raise Exception("Maximum number of retries exceeded for killed worker.")

# unit of work: set task status to processing.

job.meta = {"task_id": self.request.id}
job.meta = {"task_id": self.request.id, "attempts": attempts}

job.status = models.JobStatus.processing
db.commit()

@@ -83,7 +97,11 @@ def transcribe(self: Task, job_id: UUID) -> None:
if job and db:
if db.in_transaction():
db.rollback()
job.meta = {**job.meta, "error": str(e)} # type: ignore
if job.meta:
job.meta = {**job.meta, "error": str(e)} # type: ignore
else:
job.meta = {"error": str(e)}

job.status = models.JobStatus.error
db.commit()
raise
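
The retry story after this commit: with `task_acks_late` the message is only acknowledged once `transcribe` finishes, and `task_reject_on_worker_lost` asks the broker to redeliver it if the worker process is killed (e.g. OOM); the `attempts` counter persisted in `job.meta` then caps redelivery at one retry so a crashing job cannot poison-pill the worker. The guard in isolation, as a sketch:

```python
# Sketch of the safeguard added above; `meta` is the job's JSON metadata
# (a dict or None) as written by any previous attempt.
def next_attempt(meta: dict | None) -> int:
    return 1 + ((meta or {}).get("attempts") or 0)


def check_attempts(meta: dict | None, max_attempts: int = 2) -> int:
    attempts = next_attempt(meta)
    if attempts > max_attempts:
        # A second redelivery means the worker already died twice on this
        # job, so give up instead of looping forever.
        raise Exception("Maximum number of retries exceeded for killed worker.")
    return attempts


assert check_attempts(None) == 1
assert check_attempts({"attempts": 1}) == 2
```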
1 change: 1 addition & 0 deletions conf/rabbitmq.conf
@@ -0,0 +1 @@
vm_memory_high_watermark.absolute = 192MB
42 changes: 29 additions & 13 deletions docker-compose.base.yml
@@ -1,3 +1,6 @@
x-broker-environment: &broker-environment
BROKER_URL: "amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672"

version: "3.8"
name: whisperbox-transcribe

@@ -12,46 +15,59 @@ services:
networks:
- traefik

redis:
image: redis:7-alpine
rabbitmq:
env_file: .env
image: rabbitmq:3-alpine
networks:
- app
deploy:
resources:
limits:
memory: 128M
memory: 256M
healthcheck:
test: rabbitmq-diagnostics check_port_connectivity
interval: 3s
timeout: 3s
retries: 10

volumes:
- ./conf/rabbitmq.conf:/etc/rabbitmq/rabbitmq.conf
- rabbitmq-data:/var/lib/rabbitmq/mnesia/

worker:
env_file: .env
environment:
<<: *broker-environment
build:
context: .
dockerfile: worker.Dockerfile
args:
WHISPER_MODEL: ${WHISPER_MODEL}
depends_on:
rabbitmq:
condition: service_healthy
networks:
- app
depends_on:
- redis
healthcheck:
test: ["CMD-SHELL", "celery -b ${BROKER_URL} inspect ping -d celery@$$HOSTNAME"]
interval: 5s
timeout: 5s
retries: 5

web:
env_file: .env
environment:
<<: *broker-environment
build:
context: .
dockerfile: web.Dockerfile
depends_on:
rabbitmq:
condition: service_healthy
networks:
- app
- traefik
depends_on:
worker:
condition: service_healthy

networks:
app:
driver: bridge
traefik:
driver: bridge

volumes:
rabbitmq-data:
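
The `x-broker-environment` anchor injects the same AMQP URL into both `web` and `worker`, and the worker healthcheck pings Celery over that broker. A quick way to verify the same wiring from inside either container, as a sketch mirroring `celery -b $BROKER_URL inspect ping`:

```python
import os

from celery import Celery

# Ping any running workers over the AMQP broker configured by compose.
app = Celery(broker=os.environ["BROKER_URL"])
replies = app.control.ping(timeout=5)
print(replies or "no workers responded")
```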
13 changes: 10 additions & 3 deletions docker-compose.dev.yml
@@ -13,6 +13,8 @@ services:

web:
command: bash -c "alembic upgrade head && uvicorn app.web:app --reload --host ${HOST:-0.0.0.0} --port ${PORT:-8000} --log-level info"
# NOTE: the docker on mac mount adapter (virtioFS) does not support flock.
# this can cause the sqlite database to corrupt when written from worker <> api simultaneously.
volumes:
- ./:/etc/whisperbox-transcribe/
labels:
@@ -26,13 +28,18 @@
volumes:
- ./:/etc/whisperbox-transcribe/

rabbitmq:
image: rabbitmq:3-management-alpine
ports:
- 15672:15672

flower:
image: mher/flower
command: celery --broker redis://redis:6379/0 flower --port=5555
command: celery --broker amqp://${RABBITMQ_DEFAULT_USER}:${RABBITMQ_DEFAULT_PASS}@rabbitmq:5672 flower --port=5555
ports:
- 5555:5555
depends_on:
worker:
condition: service_healthy
- worker
- rabbitmq
networks:
- app
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ description = ""
version = "1.0.0"

dependencies=[
"celery[redis] ==5.3.1",
"celery ==5.3.1",
"sqlalchemy[mypy] ==2.0.20",
"pydantic ==2.1.1",
"pydantic-settings ==2.0.3"