Skip to content

Commit

Permalink
refactor: improve use dependency injection
Browse files Browse the repository at this point in the history
- remove settings global
- remove engine / session globals
  • Loading branch information
fspoettel committed Aug 17, 2023
1 parent 504975a commit 90884c6
Show file tree
Hide file tree
Showing 16 changed files with 176 additions and 110 deletions.
10 changes: 3 additions & 7 deletions app/shared/celery.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
from celery import Celery

from app.shared.settings import settings


def get_celery_binding() -> Celery:
celery = Celery(
broker_url=settings.BROKER_URL,
def get_celery_binding(broker_url: str) -> Celery:
return Celery(
broker_url=broker_url,
broker_connection_retry=False,
broker_connection_retry_on_startup=False,
)

return celery
4 changes: 3 additions & 1 deletion app/shared/db/alembic/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from sqlalchemy import engine_from_config, pool

from app.shared.db.models import Base
from app.shared.settings import settings
from app.shared.settings import Settings

settings = Settings() # type: ignore

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
Expand Down
33 changes: 14 additions & 19 deletions app/shared/db/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,21 @@
from typing import Any, Generator
from typing import Any

from sqlalchemy import create_engine, event
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import Engine, create_engine, event
from sqlalchemy.orm import sessionmaker

from app.shared.settings import settings

engine = create_engine(settings.DATABASE_URI, connect_args={"check_same_thread": False})
def get_engine(database_url: str):
engine = create_engine(database_url, connect_args={"check_same_thread": False})

@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn: Any, _: Any) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()

@event.listens_for(engine, "connect")
def set_sqlite_pragma(conn: Any, _: Any) -> None:
cursor = conn.cursor()
cursor.execute("PRAGMA journal_mode=WAL")
cursor.close()
return engine


SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def get_session() -> Generator[Session, None, None]:
session: Session = SessionLocal()
try:
yield session
finally:
session.close()
def get_session_local(engine: Engine):
session_local = sessionmaker(autocommit=False, autoflush=False, bind=engine)
return session_local
48 changes: 26 additions & 22 deletions app/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,55 @@
from sqlalchemy_utils import create_database, database_exists, drop_database

import app.shared.db.models as models
from app.shared.db.base import SessionLocal, engine
from app.shared.settings import settings
from app.shared.db.base import get_engine, get_session_local
from app.shared.settings import Settings
from app.web.injections.db import get_session
from app.web.main import app_factory


def pytest_configure() -> None:
if not database_exists(engine.url):
create_database(engine.url)


def pytest_unconfigure() -> None:
if database_exists(engine.url):
drop_database(engine.url)
@pytest.fixture()
def settings():
return Settings(_env_file=".env.test") # type: ignore # noqa: E501


@pytest.fixture()
def auth_headers() -> dict[str, str]:
def auth_headers(settings) -> dict[str, str]:
return {"Authorization": f"Bearer {settings.API_SECRET}"}


@pytest.fixture()
def test_db():
def test_db(settings):
engine = get_engine(settings.DATABASE_URI)

if not database_exists(engine.url):
create_database(engine.url)

models.Base.metadata.create_all(engine)

connection = engine.connect()
yield connection
connection.close()

models.Base.metadata.drop_all(bind=engine)
drop_database(engine.url)


@pytest.fixture()
def db_session(test_db):
with SessionLocal(bind=test_db) as session:
session_local = get_session_local(test_db)
with session_local() as session:
yield session


@pytest.fixture()
def client(db_session):
app = app_factory(lambda: db_session)
def app(settings, db_session):
app = app_factory(settings)
app.dependency_overrides[get_session] = lambda: db_session
return app


@pytest.fixture()
def client(app):
client = TestClient(app)
return client

Expand All @@ -66,10 +77,3 @@ def mock_artifact(db_session, mock_job):
db_session.add(artifact)
db_session.commit()
return artifact


@pytest.fixture()
def sharing_enabled():
settings.ENABLE_SHARING = True
yield
settings.ENABLE_SHARING = False
20 changes: 10 additions & 10 deletions app/tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from fastapi.testclient import TestClient

import app.shared.db.models as models
from app.web.main import app_factory
from app.shared.settings import Settings
from app.web.injections.settings import get_settings


# POST /api/v1/jobs
Expand Down Expand Up @@ -61,24 +60,25 @@ def test_get_job_not_found(client, auth_headers: dict[str, str], mock_job):
assert res.status_code == 404


def test_get_job_sharing_disabled(client, mock_job):
def test_get_job_sharing_enabled(client, app, mock_job):
app.dependency_overrides[get_settings] = lambda: Settings(
_env_file=".env.test", ENABLE_SHARING=True # type: ignore
)

res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers={},
)
assert res.status_code == 401

assert res.status_code == 200

def test_get_job_sharing_enabled(db_session, mock_job, sharing_enabled):
# HACK: delay construction until settings are patched.
client = TestClient(app_factory(lambda: db_session))

def test_get_job_sharing_disabled(client, mock_job):
res = client.get(
f"/api/v1/jobs/{mock_job.id}",
headers={},
)

assert res.status_code == 200
assert res.status_code == 401


# GET /api/v1/jobs/:id/artifacts
Expand Down
6 changes: 4 additions & 2 deletions app/web/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from app.shared.db.base import get_session
from app.shared.settings import Settings
from app.web.main import app_factory

app = app_factory(get_session)

def app():
return app_factory(Settings()) # type: ignore
Empty file added app/web/injections/__init__.py
Empty file.
12 changes: 12 additions & 0 deletions app/web/injections/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from typing import Generator

from fastapi import Request
from sqlalchemy.orm import Session


def get_session(request: Request) -> Generator[Session, None, None]:
session: Session = request.app.state.session_local()
try:
yield session
finally:
session.close()
39 changes: 39 additions & 0 deletions app/web/injections/security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from hmac import compare_digest
from typing import Annotated

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

from app.shared.settings import Settings
from app.web.injections.settings import get_settings


def validate_credentials(credentials: HTTPAuthorizationCredentials, secret: str):
# use compare_digest to counter timing attacks.
if (
not credentials
or not secret
or not compare_digest(secret, credentials.credentials)
):
raise HTTPException(status_code=401)


def authenticate_api_key(
credentials: Annotated[
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
],
settings: Annotated[Settings, Depends(get_settings)],
):
validate_credentials(credentials, settings.API_SECRET)


def authenticate_sharing(
credentials: Annotated[
HTTPAuthorizationCredentials, Depends(HTTPBearer(auto_error=False))
],
settings: Annotated[Settings, Depends(get_settings)],
):
if settings.ENABLE_SHARING:
pass
else:
validate_credentials(credentials, settings.API_SECRET)
5 changes: 5 additions & 0 deletions app/web/injections/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from fastapi import Request


def get_settings(request: Request):
return request.app.state.settings
5 changes: 5 additions & 0 deletions app/web/injections/task_queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from fastapi import Request


def get_task_queue(request: Request):
return request.app.state.task_queue
32 changes: 19 additions & 13 deletions app/web/main.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,40 @@
from typing import Annotated, Callable, Generator
from typing import Annotated
from uuid import UUID

from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path
from pydantic import AnyHttpUrl, BaseModel, Field
from sqlalchemy import create_engine
from sqlalchemy.orm import Session

import app.shared.db.models as models
import app.web.dtos as dtos
from app.shared.settings import settings
from app.web.security import authenticate_api_key
from app.shared.db.base import get_session_local
from app.web.injections.db import get_session
from app.web.injections.security import authenticate_api_key, authenticate_sharing
from app.web.injections.task_queue import get_task_queue
from app.web.task_queue import TaskQueue

DatabaseSession = Annotated[Session, Depends(get_session)]

def app_factory(
session_getter: Callable[[], Generator[Session, None, None]]
) -> FastAPI:
DatabaseSession = Annotated[Session, Depends(session_getter)]

task_queue = TaskQueue()

def app_factory(settings):
app = FastAPI(
description=(
"whisperbox-transcribe is an async HTTP wrapper for openai/whisper."
),
title="whisperbox-transcribe",
)

engine = create_engine(settings.DATABASE_URI)

app.state.settings = settings # type: ignore
app.state.session_local = get_session_local(engine)
app.state.task_queue = TaskQueue(settings.BROKER_URL)

api_router = APIRouter(prefix="/api/v1")

@api_router.get("/", response_model=None, status_code=204)
def api_root() -> None:
@api_router.get("/", status_code=204)
def api_root():
return None

@api_router.get(
Expand All @@ -52,7 +57,7 @@ def get_jobs(

@api_router.get(
"/jobs/{id}",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
dependencies=[Depends(authenticate_sharing)],
response_model=dtos.Job,
summary="Get metadata for one job",
)
Expand All @@ -72,7 +77,7 @@ def get_job(

@api_router.get(
"/jobs/{id}/artifacts",
dependencies=[] if settings.ENABLE_SHARING else [Depends(authenticate_api_key)],
dependencies=[Depends(authenticate_sharing)],
response_model=list[dtos.Artifact],
summary="Get all artifacts for one job",
)
Expand Down Expand Up @@ -138,6 +143,7 @@ class PostJobPayload(BaseModel):
def create_job(
payload: PostJobPayload,
session: DatabaseSession,
task_queue: Annotated[TaskQueue, Depends(get_task_queue)],
) -> models.Job:
"""
Enqueue a new whisper job for processing.
Expand Down
16 changes: 0 additions & 16 deletions app/web/security.py

This file was deleted.

4 changes: 2 additions & 2 deletions app/web/task_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
class TaskQueue:
celery: Celery

def __init__(self) -> None:
self.celery = get_celery_binding()
def __init__(self, broker_url: str) -> None:
self.celery = get_celery_binding(broker_url=broker_url)

def queue_task(self, job: models.Job):
"""
Expand Down
Loading

0 comments on commit 90884c6

Please sign in to comment.