-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Proper API surface like OpenAI's for audio transcription
- Loading branch information
Showing 3 changed files with 58 additions and 47 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,50 @@ | ||
from fastapi import APIRouter, Body, File, Security, UploadFile | ||
from typing import List, Literal | ||
|
||
from fastapi import APIRouter, File, Form, Request, Security, UploadFile | ||
from fastapi.responses import PlainTextResponse | ||
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES | ||
|
||
from schemas.audio import Transcription, TranscriptionRequest | ||
from schemas.audio import AudioTranscription, AudioTranscriptionVerbose | ||
from utils.args import args | ||
from utils.exceptions import ModelNotFoundException | ||
from utils.lifespan import pipelines | ||
from utils.security import check_api_key | ||
|
||
|
||
router = APIRouter() | ||
|
||
SUPPORTED_LANGUAGES = set(list(LANGUAGES.keys()) + list(TO_LANGUAGE_CODE.keys())) | ||
|
||
|
||
@router.post("/audio/transcriptions") | ||
async def audio_transcriptions(file: UploadFile = File(...), request: TranscriptionRequest = Body(...), api_key=Security(check_api_key)): | ||
async def audio_transcriptions( | ||
request: Request, | ||
file: UploadFile = File(...), | ||
model: str = Form(args.model), | ||
language: str = Form("en"), | ||
prompt: str = Form(None), # TODO: implement | ||
response_format: Literal["text", "json"] = Form("json"), | ||
temperature: float = Form(0), | ||
timestamp_granularities: List[str] = Form(alias="timestamp_granularities[]", default=["segment"]), # TODO: implement | ||
api_key=Security(check_api_key) | ||
) -> AudioTranscription | AudioTranscriptionVerbose: | ||
""" | ||
Audio transcriptions API similar to OpenAI's API. | ||
See https://platform.openai.com/docs/api-reference/audio/create-transcription for the API specification. | ||
""" | ||
|
||
if model != args.model: | ||
raise ModelNotFoundException() | ||
|
||
if language not in SUPPORTED_LANGUAGES: | ||
raise ValueError(f"Language {language} not supported") | ||
|
||
file = await file.read() | ||
result = pipelines[request.model]( | ||
file, generate_kwargs={"language": request.language, "temperature": request.temperature}, return_timestamps=True | ||
result = pipelines[model]( | ||
file, generate_kwargs={"language": language, "temperature": temperature}, return_timestamps=True | ||
) | ||
|
||
if request.response_format == "text": | ||
if response_format == "text": | ||
return PlainTextResponse(result["text"]) | ||
|
||
return Transcription(**result) | ||
return AudioTranscription(**result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,49 +1,35 @@ | ||
from typing import Literal, Optional | ||
import json | ||
from typing import List | ||
|
||
from openai.types.audio import Transcription | ||
from pydantic import BaseModel, Field, field_validator, model_validator | ||
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE, LANGUAGES | ||
from pydantic import BaseModel | ||
|
||
from utils.args import args | ||
from utils.exceptions import ModelNotFoundException | ||
|
||
class AudioTranscription(Transcription): | ||
pass | ||
|
||
SUPPORTED_LANGUAGES = set(list(LANGUAGES.keys()) + list(TO_LANGUAGE_CODE.keys())) | ||
|
||
class Word(BaseModel): | ||
word: str | ||
start: float | ||
end: float | ||
|
||
class Transcription(Transcription): | ||
pass | ||
|
||
class Segment(BaseModel): | ||
id: int | ||
seek: int | ||
start: float | ||
end: float | ||
text: str | ||
tokens: List[int] | ||
temperature: float | ||
avg_logprob: float | ||
compression_ratio: float | ||
no_speech_prob: float | ||
|
||
|
||
class TranscriptionRequest(BaseModel): | ||
model: str = Field(default=args.model) | ||
language: str= Field(default="en") | ||
response_format: Literal["text", "json"] = Field(default="json") | ||
temperature: Optional[float] = Field(default=None, ge=0, le=1) | ||
|
||
|
||
@field_validator("language") | ||
@classmethod | ||
def validate_language(cls, value): | ||
if value not in SUPPORTED_LANGUAGES: | ||
raise ValueError(f"Language {value} not supported") | ||
return value | ||
|
||
@field_validator("model") | ||
@classmethod | ||
def validate_model(cls, model): | ||
if model != args.model: | ||
raise ModelNotFoundException() | ||
return model | ||
|
||
@model_validator(mode="before") | ||
@classmethod | ||
def validate_to_json(cls, value): | ||
if isinstance(value, str): | ||
return cls(**json.loads(value)) | ||
return value | ||
|
||
|
||
class Transcription(BaseModel): | ||
class AudioTranscriptionVerbose(AudioTranscription): | ||
language: str | ||
duration: float | ||
text: str | ||
words: List[Word] | ||
segments: List[Segment] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters