Skip to content

Commit

Permalink
Proper api surface like openai for audio transcription
Browse files Browse the repository at this point in the history
  • Loading branch information
Ledoux committed Oct 30, 2024
1 parent 6435004 commit f925654
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 47 deletions.
38 changes: 31 additions & 7 deletions app/endpoints/audio.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,50 @@
from fastapi import APIRouter, Body, File, Security, UploadFile
from typing import List, Literal

from fastapi import APIRouter, File, Form, Request, Security, UploadFile
from fastapi.responses import PlainTextResponse
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES

from schemas.audio import Transcription, TranscriptionRequest
from schemas.audio import AudioTranscription, AudioTranscriptionVerbose
from utils.args import args
from utils.exceptions import ModelNotFoundException
from utils.lifespan import pipelines
from utils.security import check_api_key


router = APIRouter()

SUPPORTED_LANGUAGES = set(list(LANGUAGES.keys()) + list(TO_LANGUAGE_CODE.keys()))


@router.post("/audio/transcriptions")
async def audio_transcriptions(file: UploadFile = File(...), request: TranscriptionRequest = Body(...), api_key=Security(check_api_key)):
async def audio_transcriptions(
request: Request,
file: UploadFile = File(...),
model: str = Form(args.model),
language: str = Form("en"),
prompt: str = Form(None), # TODO: implement
response_format: Literal["text", "json"] = Form("json"),
temperature: float = Form(0),
timestamp_granularities: List[str] = Form(alias="timestamp_granularities[]", default=["segment"]), # TODO: implement
api_key=Security(check_api_key)
) -> AudioTranscription | AudioTranscriptionVerbose:
"""
Audio transcriptions API similar to OpenAI's API.
See https://platform.openai.com/docs/api-reference/audio/create-transcription for the API specification.
"""

if model != args.model:
raise ModelNotFoundException()

if language not in SUPPORTED_LANGUAGES:
raise ValueError(f"Language {language} not supported")

file = await file.read()
result = pipelines[request.model](
file, generate_kwargs={"language": request.language, "temperature": request.temperature}, return_timestamps=True
result = pipelines[model](
file, generate_kwargs={"language": language, "temperature": temperature}, return_timestamps=True
)

if request.response_format == "text":
if response_format == "text":
return PlainTextResponse(result["text"])

return Transcription(**result)
return AudioTranscription(**result)
64 changes: 25 additions & 39 deletions app/schemas/audio.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,35 @@
from typing import Literal, Optional
import json
from typing import List

from openai.types.audio import Transcription
from pydantic import BaseModel, Field, field_validator, model_validator
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES
from pydantic import BaseModel

from utils.args import args
from utils.exceptions import ModelNotFoundException

class AudioTranscription(Transcription):
pass

SUPPORTED_LANGUAGES = set(list(LANGUAGES.keys()) + list(TO_LANGUAGE_CODE.keys()))

class Word(BaseModel):
word: str
start: float
end: float

class Transcription(Transcription):
pass

class Segment(BaseModel):
id: int
seek: int
start: float
end: float
text: str
tokens: List[int]
temperature: float
avg_logprob: float
compression_ratio: float
no_speech_prob: float


class TranscriptionRequest(BaseModel):
model: str = Field(default=args.model)
language: str= Field(default="en")
response_format: Literal["text", "json"] = Field(default="json")
temperature: Optional[float] = Field(default=None, ge=0, le=1)


@field_validator("language")
@classmethod
def validate_language(cls, value):
if value not in SUPPORTED_LANGUAGES:
raise ValueError(f"Language {value} not supported")
return value

@field_validator("model")
@classmethod
def validate_model(cls, model):
if model != args.model:
raise ModelNotFoundException()
return model

@model_validator(mode="before")
@classmethod
def validate_to_json(cls, value):
if isinstance(value, str):
return cls(**json.loads(value))
return value


class Transcription(BaseModel):
class AudioTranscriptionVerbose(AudioTranscription):
language: str
duration: float
text: str
words: List[Word]
segments: List[Segment]
3 changes: 2 additions & 1 deletion app/utils/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from utils.config import API_KEY
from utils.exceptions import InvalidAuthenticationSchemeException, InvalidAPIKeyException


auth_scheme = HTTPBearer(scheme_name="API key")

if not API_KEY:
Expand All @@ -22,4 +23,4 @@ def check_api_key(api_key: Annotated[HTTPAuthorizationCredentials, Depends(auth_
if api_key.credentials != API_KEY:
raise InvalidAPIKeyException()

return api_key.credentials
return api_key.credentials

0 comments on commit f925654

Please sign in to comment.