From 90884c6bab7d05871b3500ef7c9d6ff72f2b9608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felix=20Sp=C3=B6ttel?= <1682504+fspoettel@users.noreply.github.com> Date: Thu, 17 Aug 2023 20:37:30 +0200 Subject: [PATCH] refactor: improve use dependency injection - remove settings global - remove engine / session globals --- app/shared/celery.py | 10 ++----- app/shared/db/alembic/env.py | 4 ++- app/shared/db/base.py | 33 +++++++++------------ app/tests/conftest.py | 48 ++++++++++++++++-------------- app/tests/test_api.py | 20 ++++++------- app/web/__init__.py | 6 ++-- app/web/injections/__init__.py | 0 app/web/injections/db.py | 12 ++++++++ app/web/injections/security.py | 39 ++++++++++++++++++++++++ app/web/injections/settings.py | 5 ++++ app/web/injections/task_queue.py | 5 ++++ app/web/main.py | 32 ++++++++++++-------- app/web/security.py | 16 ---------- app/web/task_queue.py | 4 +-- app/worker/main.py | 51 +++++++++++++++++++++----------- mypy.ini | 1 + 16 files changed, 176 insertions(+), 110 deletions(-) create mode 100644 app/web/injections/__init__.py create mode 100644 app/web/injections/db.py create mode 100644 app/web/injections/security.py create mode 100644 app/web/injections/settings.py create mode 100644 app/web/injections/task_queue.py delete mode 100644 app/web/security.py diff --git a/app/shared/celery.py b/app/shared/celery.py index 71c1342..7c216bd 100644 --- a/app/shared/celery.py +++ b/app/shared/celery.py @@ -1,13 +1,9 @@ from celery import Celery -from app.shared.settings import settings - -def get_celery_binding() -> Celery: - celery = Celery( - broker_url=settings.BROKER_URL, +def get_celery_binding(broker_url: str) -> Celery: + return Celery( + broker_url=broker_url, broker_connection_retry=False, broker_connection_retry_on_startup=False, ) - - return celery diff --git a/app/shared/db/alembic/env.py b/app/shared/db/alembic/env.py index 53ffd1f..97bce90 100644 --- a/app/shared/db/alembic/env.py +++ b/app/shared/db/alembic/env.py @@ -4,7 +4,9 @@ from sqlalchemy import engine_from_config, pool from app.shared.db.models import Base -from app.shared.settings import settings +from app.shared.settings import Settings + +settings = Settings() # type: ignore # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/app/shared/db/base.py b/app/shared/db/base.py index a717f39..52fa033 100644 --- a/app/shared/db/base.py +++ b/app/shared/db/base.py @@ -1,26 +1,21 @@ -from typing import Any, Generator +from typing import Any -from sqlalchemy import create_engine, event -from sqlalchemy.orm import Session, sessionmaker +from sqlalchemy import Engine, create_engine, event +from sqlalchemy.orm import sessionmaker -from app.shared.settings import settings -engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False}) +def get_engine(database_url: str): + engine = create_engine(database_url, connect_args={"check_same_thread": False}) + @event.listens_for(engine, "connect") + def set_sqlite_pragma(conn: Any, _: Any) -> None: + cursor = conn.cursor() + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() -@event.listens_for(engine, "connect") -def set_sqlite_pragma(conn: Any, _: Any) -> None: - cursor = conn.cursor() - cursor.execute("PRAGMA journal_mode=WAL") - cursor.close() + return engine -SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - -def get_session() -> Generator[Session, None, None]: - session: Session = SessionLocal() - try: - yield session - finally: - session.close() +def get_session_local(engine: Engine): + session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine) + return session_local diff --git a/app/tests/conftest.py b/app/tests/conftest.py index b2b08d9..3c227ed 100644 --- a/app/tests/conftest.py +++ b/app/tests/conftest.py @@ -3,44 +3,55 @@ from sqlalchemy_utils import create_database, database_exists, drop_database import app.shared.db.models as models -from app.shared.db.base import SessionLocal, engine -from app.shared.settings import settings +from app.shared.db.base import get_engine, get_session_local +from app.shared.settings import Settings +from app.web.injections.db import get_session from app.web.main import app_factory -def pytest_configure() -> None: - if not database_exists(engine.url): - create_database(engine.url) - - -def pytest_unconfigure() -> None: - if database_exists(engine.url): - drop_database(engine.url) +@pytest.fixture() +def settings(): + return Settings(_env_file=".env.test") # type: ignore # noqa: E501 @pytest.fixture() -def auth_headers() -> dict[str, str]: +def auth_headers(settings) -> dict[str, str]: return {"Authorization": f"Bearer {settings.API_SECRET}"} @pytest.fixture() -def test_db(): +def test_db(settings): + engine = get_engine(settings.DATABASE_URI) + + if not database_exists(engine.url): + create_database(engine.url) + models.Base.metadata.create_all(engine) + connection = engine.connect() yield connection connection.close() + models.Base.metadata.drop_all(bind=engine) + drop_database(engine.url) @pytest.fixture() def db_session(test_db): - with SessionLocal(bind=test_db) as session: + session_local = get_session_local(test_db) + with session_local() as session: yield session @pytest.fixture() -def client(db_session): - app = app_factory(lambda: db_session) +def app(settings, db_session): + app = app_factory(settings) + app.dependency_overrides[get_session] = lambda: db_session + return app + + +@pytest.fixture() +def client(app): client = TestClient(app) return client @@ -66,10 +77,3 @@ def mock_artifact(db_session, mock_job): db_session.add(artifact) db_session.commit() return artifact - - -@pytest.fixture() -def sharing_enabled(): - settings.ENABLE_SHARING = True - yield - settings.ENABLE_SHARING = False diff --git a/app/tests/test_api.py b/app/tests/test_api.py index 8d86677..db58c28 100644 --- a/app/tests/test_api.py +++ b/app/tests/test_api.py @@ -1,7 +1,6 @@ -from fastapi.testclient import TestClient - import app.shared.db.models as models -from app.web.main import app_factory +from app.shared.settings import Settings +from app.web.injections.settings import get_settings # POST /api/v1/jobs @@ -61,24 +60,25 @@ def test_get_job_not_found(client, auth_headers: dict[str, str], mock_job): assert res.status_code == 404 -def test_get_job_sharing_disabled(client, mock_job): +def test_get_job_sharing_enabled(client, app, mock_job): + app.dependency_overrides[get_settings] = lambda: Settings( + _env_file=".env.test", ENABLE_SHARING=True # type: ignore + ) + res = client.get( f"/api/v1/jobs/{mock_job.id}", headers={}, ) - assert res.status_code == 401 + assert res.status_code == 200 -def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled): - # HACK: delay construction until settings are patched. - client = TestClient(app_factory(lambda: db_session)) +def test_get_job_sharing_disabled(client, mock_job): res = client.get( f"/api/v1/jobs/{mock_job.id}", headers={}, ) - - assert res.status_code == 200 + assert res.status_code == 401 # GET /api/v1/jobs/:id/artifacts diff --git a/app/web/__init__.py b/app/web/__init__.py index 61dd17a..b761491 100644 --- a/app/web/__init__.py +++ b/app/web/__init__.py @@ -1,4 +1,6 @@ -from app.shared.db.base import get_session +from app.shared.settings import Settings from app.web.main import app_factory -app = app_factory(get_session) + +def app(): + return app_factory(Settings()) # type: ignore diff --git a/app/web/injections/__init__.py b/app/web/injections/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/web/injections/db.py b/app/web/injections/db.py new file mode 100644 index 0000000..0645c8b --- /dev/null +++ b/app/web/injections/db.py @@ -0,0 +1,12 @@ +from typing import Generator + +from fastapi import Request +from sqlalchemy.orm import Session + + +def get_session(request: Request) -> Generator[Session, None, None]: + session: Session = request.app.state.session_local() + try: + yield session + finally: + session.close() diff --git a/app/web/injections/security.py b/app/web/injections/security.py new file mode 100644 index 0000000..220fc9e --- /dev/null +++ b/app/web/injections/security.py @@ -0,0 +1,39 @@ +from hmac import compare_digest +from typing import Annotated + +from fastapi import Depends, HTTPException +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer + +from app.shared.settings import Settings +from app.web.injections.settings import get_settings + + +def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str): + # use compare_digest to counter timing attacks. + if ( + not credentials + or not secret + or not compare_digest(secret, credentials.credentials) + ): + raise HTTPException(status_code=401) + + +def authenticate_api_key( + credentials: Annotated[ + HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False)) + ], + settings: Annotated[Settings, Depends(get_settings)], +): + validate_credentials(credentials, settings.API_SECRET) + + +def authenticate_sharing( + credentials: Annotated[ + HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False)) + ], + settings: Annotated[Settings, Depends(get_settings)], +): + if settings.ENABLE_SHARING: + pass + else: + validate_credentials(credentials, settings.API_SECRET) diff --git a/app/web/injections/settings.py b/app/web/injections/settings.py new file mode 100644 index 0000000..f92c908 --- /dev/null +++ b/app/web/injections/settings.py @@ -0,0 +1,5 @@ +from fastapi import Request + + +def get_settings(request: Request): + return request.app.state.settings diff --git a/app/web/injections/task_queue.py b/app/web/injections/task_queue.py new file mode 100644 index 0000000..ff765ec --- /dev/null +++ b/app/web/injections/task_queue.py @@ -0,0 +1,5 @@ +from fastapi import Request + + +def get_task_queue(request: Request): + return request.app.state.task_queue diff --git a/app/web/main.py b/app/web/main.py index 5738b9b..a78705e 100644 --- a/app/web/main.py +++ b/app/web/main.py @@ -1,24 +1,23 @@ -from typing import Annotated, Callable, Generator +from typing import Annotated from uuid import UUID from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path from pydantic import AnyHttpUrl, BaseModel, Field +from sqlalchemy import create_engine from sqlalchemy.orm import Session import app.shared.db.models as models import app.web.dtos as dtos -from app.shared.settings import settings -from app.web.security import authenticate_api_key +from app.shared.db.base import get_session_local +from app.web.injections.db import get_session +from app.web.injections.security import authenticate_api_key, authenticate_sharing +from app.web.injections.task_queue import get_task_queue from app.web.task_queue import TaskQueue +DatabaseSession = Annotated[Session, Depends(get_session)] -def app_factory( - session_getter: Callable[[], Generator[Session, None, None]] -) -> FastAPI: - DatabaseSession = Annotated[Session, Depends(session_getter)] - - task_queue = TaskQueue() +def app_factory(settings): app = FastAPI( description=( "whisperbox-transcribe is an async HTTP wrapper for openai/whisper." @@ -26,10 +25,16 @@ def app_factory( title="whisperbox-transcribe", ) + engine = create_engine(settings.DATABASE_URI) + + app.state.settings = settings # type: ignore + app.state.session_local = get_session_local(engine) + app.state.task_queue = TaskQueue(settings.BROKER_URL) + api_router = APIRouter(prefix="/api/v1") - @api_router.get("/", response_model=None, status_code=204) - def api_root() -> None: + @api_router.get("/", status_code=204) + def api_root(): return None @api_router.get( @@ -52,7 +57,7 @@ def get_jobs( @api_router.get( "/jobs/{id}", - dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], + dependencies=[Depends(authenticate_sharing)], response_model=dtos.Job, summary="Get metadata for one job", ) @@ -72,7 +77,7 @@ def get_job( @api_router.get( "/jobs/{id}/artifacts", - dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)], + dependencies=[Depends(authenticate_sharing)], response_model=list[dtos.Artifact], summary="Get all artifacts for one job", ) @@ -138,6 +143,7 @@ class PostJobPayload(BaseModel): def create_job( payload: PostJobPayload, session: DatabaseSession, + task_queue: Annotated[TaskQueue, Depends(get_task_queue)], ) -> models.Job: """ Enqueue a new whisper job for processing. diff --git a/app/web/security.py b/app/web/security.py deleted file mode 100644 index 6d66ef1..0000000 --- a/app/web/security.py +++ /dev/null @@ -1,16 +0,0 @@ -from hmac import compare_digest - -from fastapi import Depends, HTTPException -from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer - -from app.shared.settings import settings - - -def authenticate_api_key( - credentials: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)), -) -> None: - # use compare_digest to counter timing attacks. - if not credentials or not compare_digest( - settings.API_SECRET, credentials.credentials - ): - raise HTTPException(status_code=401) diff --git a/app/web/task_queue.py b/app/web/task_queue.py index 1d630ab..77d1353 100644 --- a/app/web/task_queue.py +++ b/app/web/task_queue.py @@ -7,8 +7,8 @@ class TaskQueue: celery: Celery - def __init__(self) -> None: - self.celery = get_celery_binding() + def __init__(self, broker_url: str) -> None: + self.celery = get_celery_binding(broker_url=broker_url) def queue_task(self, job: models.Job): """ diff --git a/app/worker/main.py b/app/worker/main.py index 0e98169..20acb6a 100644 --- a/app/worker/main.py +++ b/app/worker/main.py @@ -7,11 +7,14 @@ import app.shared.db.models as models from app.shared.celery import get_celery_binding -from app.shared.db.base import SessionLocal -from app.shared.settings import settings +from app.shared.db.base import get_engine, get_session_local +from app.shared.settings import Settings from app.worker.strategies.local import LocalStrategy -celery = get_celery_binding() +settings = Settings() # type: ignore +celery = get_celery_binding(settings.BROKER_URL) +engine = get_engine(settings.DATABASE_URI) +SessionLocal = get_session_local(engine) class TranscribeTask(Task): @@ -43,14 +46,24 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: task_acks_on_failure_or_timeout=True, task_reject_on_worker_lost=True, ) -def transcribe(self: Task, job_id: UUID) -> None: +def transcribe(self: TranscribeTask, job_id: UUID) -> None: + session: Session | None = None + job: models.Job | None = None + try: + if not self.strategy: + raise Exception("expected a transcription strategy to be defined.") + # runs in a separate thread => requires sqlite's WAL mode to be enabled. - db: Session = SessionLocal() + session = SessionLocal() + + # work around mypy not inferring the sum type correctly. + if not session: + raise Exception("failed to acquire a session.") # check if passed job should be processed. - job = db.query(models.Job).filter(models.Job.id == job_id).one_or_none() + job = session.query(models.Job).filter(models.Job.id == job_id).one_or_none() if job is None: logger.warn("[{job.id}]: Received unknown job, abort.") @@ -62,7 +75,7 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.debug(f"[{job.id}]: start processing {job.type} job.") - if job.meta: + if job.meta is not None: attempts = 1 + (job.meta.get("attempts") or 0) else: attempts = 1 @@ -77,7 +90,7 @@ def transcribe(self: Task, job_id: UUID) -> None: job.meta = {"task_id": self.request.id, "attempts": attempts} job.status = models.JobStatus.processing - db.commit() + session.commit() logger.debug(f"[{job.id}]: finished setting task to {job.status}.") @@ -86,25 +99,27 @@ def transcribe(self: Task, job_id: UUID) -> None: logger.debug(f"[{job.id}]: successfully processed audio.") artifact = models.Artifact(job_id=str(job.id), data=result, type=result_type) - db.add(artifact) + session.add(artifact) job.status = models.JobStatus.success - db.commit() + session.commit() logger.debug(f"[{job.id}]: successfully stored artifact.") except Exception as e: - if job and db: - if db.in_transaction(): - db.rollback() - if job.meta: - job.meta = {**job.meta, "error": str(e)} # type: ignore + if job and session: + if session.in_transaction(): + session.rollback() + if job.meta is not None: + job.meta = {**job.meta, "error": str(e)} else: job.meta = {"error": str(e)} job.status = models.JobStatus.error - db.commit() + session.commit() raise finally: - self.strategy.cleanup(job_id) - db.close() + if self.strategy: + self.strategy.cleanup(job_id) + if session: + session.close() diff --git a/mypy.ini b/mypy.ini index fa12c48..aa36381 100644 --- a/mypy.ini +++ b/mypy.ini @@ -2,3 +2,4 @@ plugins = sqlalchemy.ext.mypy.plugin ignore_missing_imports = True disallow_untyped_defs = False +check_untyped_defs = True