Merge pull request #1 from eliteprox/rename-audio-to-text
Rename pipeline speech-to-text to audio-to-text
eliteprox authored Jul 5, 2024
2 parents ce7332b + f041465 commit 0908b51
Showing 10 changed files with 215 additions and 216 deletions.
12 changes: 6 additions & 6 deletions runner/app/main.py
@@ -42,10 +42,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
             from app.pipelines.image_to_video import ImageToVideoPipeline
 
             return ImageToVideoPipeline(model_id)
-        case "speech-to-text":
-            from app.pipelines.speech_to_text import SpeechToTextPipeline
+        case "audio-to-text":
+            from app.pipelines.audio_to_text import AudioToTextPipeline
 
-            return SpeechToTextPipeline(model_id)
+            return AudioToTextPipeline(model_id)
         case "frame-interpolation":
             raise NotImplementedError("frame-interpolation pipeline not implemented")
         case "upscale":
@@ -71,10 +71,10 @@ def load_route(pipeline: str) -> any:
             from app.routes import image_to_video
 
             return image_to_video.router
-        case "speech-to-text":
-            from app.routes import speech_to_text
+        case "audio-to-text":
+            from app.routes import audio_to_text
 
-            return speech_to_text.router
+            return audio_to_text.router
         case "frame-interpolation":
             raise NotImplementedError("frame-interpolation pipeline not implemented")
         case "upscale":
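
For reference, a minimal sketch (not part of this diff) of how the renamed pipeline key flows through these two dispatchers; it assumes the runner package is importable and that the model referenced below has already been downloaded into the runner's model directory.

# Hypothetical usage sketch: exercises the renamed "audio-to-text" case in
# load_pipeline and load_route. The model_id is illustrative.
from app.main import load_pipeline, load_route

pipeline = load_pipeline("audio-to-text", "openai/whisper-large-v3")
router = load_route("audio-to-text")

print(pipeline)  # "AudioToTextPipeline model_id=openai/whisper-large-v3" per __str__
print(router.routes)
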
runner/app/pipelines/{speech_to_text.py → audio_to_text.py}
@@ -6,15 +6,14 @@
 
 from huggingface_hub import file_download
 import torch
-import PIL
 from typing import List
 import logging
 import os
 
 logger = logging.getLogger(__name__)
 
 
-class SpeechToTextPipeline(Pipeline):
+class AudioToTextPipeline(Pipeline):
     def __init__(self, model_id: str):
         # kwargs = {"cache_dir": get_model_dir()}
         kwargs = {}
@@ -25,21 +24,19 @@ def __init__(self, model_id: str):
         )
         folder_path = os.path.join(get_model_dir(), folder_name)
         # Load fp16 variant if fp16 safetensors files are found in cache
-        # Special case SDXL-Lightning because the safetensors files are fp16 but are not
-        # named properly right now
         has_fp16_variant = any(
             ".fp16.safetensors" in fname
             for _, _, files in os.walk(folder_path)
             for fname in files
         )
         if torch_device != "cpu" and has_fp16_variant:
-            logger.info("SpeechToTextPipeline loading fp16 variant for %s", model_id)
+            logger.info("AudioToTextPipeline loading fp16 variant for %s", model_id)
 
             kwargs["torch_dtype"] = torch.float16
             kwargs["variant"] = "fp16"
 
         if os.environ.get("BFLOAT16"):
-            logger.info("SpeechToTextPipeline using bfloat16 precision for %s", model_id)
+            logger.info("AudioToTextPipeline using bfloat16 precision for %s", model_id)
             kwargs["torch_dtype"] = torch.bfloat16
 
         self.model_id = model_id
@@ -81,4 +78,4 @@ def __call__(self, audio: str, **kwargs) -> List[File]:
         return result
 
     def __str__(self) -> str:
-        return f"SpeechToTextPipeline model_id={self.model_id}"
+        return f"AudioToTextPipeline model_id={self.model_id}"
runner/app/routes/{speech_to_text.py → audio_to_text.py}
@@ -17,9 +17,9 @@
 
 responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}
 
-@router.post("/speech-to-text", response_model=TextResponse, responses=responses)
-@router.post("/speech-to-text/", response_model=TextResponse, include_in_schema=False)
-async def speech_to_text(
+@router.post("/audio-to-text", response_model=TextResponse, responses=responses)
+@router.post("/audio-to-text/", response_model=TextResponse, include_in_schema=False)
+async def audio_to_text(
     audio: Annotated[UploadFile, File()],
     model_id: Annotated[str, Form()] = "",
     seed: Annotated[int, Form()] = None,
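
A quick client-side sketch (not from this PR) of calling the renamed route. The host and port assume the default audio-to-text container mapping from worker/docker.go below, and the audio file name and model_id are hypothetical.

# Hypothetical client call against the renamed endpoint. It sends the same
# multipart/form-data fields the route declares: audio (file), model_id, seed.
import requests

with open("sample.wav", "rb") as audio_file:
    response = requests.post(
        "http://localhost:8004/audio-to-text",
        files={"audio": audio_file},
        data={"model_id": "openai/whisper-large-v3"},
    )
response.raise_for_status()
print(response.json())  # TextResponse body
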
2 changes: 1 addition & 1 deletion runner/dl_checkpoints.sh
@@ -30,7 +30,7 @@ function download_alpha_models() {
     # Download upscale models
     huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models
 
-    # Download speech-to-text models.
+    # Download audio-to-text models.
     huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models
 
     printf "\nDownloading token-gated models...\n"
6 changes: 4 additions & 2 deletions runner/gen_openapi.py
@@ -5,9 +5,11 @@
 
 import yaml
 from app.main import app, use_route_names_as_operation_ids
-from app.routes import health, image_to_image, image_to_video, text_to_image, upscale, speech_to_text
+from app.routes import health, image_to_image, image_to_video, text_to_image, upscale
 from fastapi.openapi.utils import get_openapi
+
+from app.routes import audio_to_text
 
 # Specify Endpoints for OpenAPI schema generation.
 SERVERS = [
     {
@@ -77,7 +79,7 @@ def write_openapi(fname, entrypoint="runner"):
     app.include_router(image_to_image.router)
     app.include_router(image_to_video.router)
     app.include_router(upscale.router)
-    app.include_router(speech_to_text.router)
+    app.include_router(audio_to_text.router)
 
     use_route_names_as_operation_ids(app)
 
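
With the router import switched over, regenerating the spec picks up the new path and schema names and produces the runner/openapi.json changes shown next. A rough sketch, assuming it is run from the runner directory; how the script is normally invoked (its CLI arguments) is not shown in this diff.

# Hypothetical regeneration call: write_openapi(fname, entrypoint="runner") is
# defined in runner/gen_openapi.py, per the hunk header above.
from gen_openapi import write_openapi

write_openapi("openapi.json", entrypoint="runner")
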
54 changes: 27 additions & 27 deletions runner/openapi.json
@@ -282,15 +282,15 @@
         ]
       }
     },
-    "/speech-to-text": {
+    "/audio-to-text": {
       "post": {
-        "summary": "Speech To Text",
-        "operationId": "speech_to_text",
+        "summary": "Audio To Text",
+        "operationId": "audio_to_text",
         "requestBody": {
           "content": {
             "multipart/form-data": {
               "schema": {
-                "$ref": "#/components/schemas/Body_speech_to_text_speech_to_text_post"
+                "$ref": "#/components/schemas/Body_audio_to_text_audio_to_text_post"
               }
             }
           },
@@ -361,6 +361,29 @@
         ],
         "title": "APIError"
       },
+      "Body_audio_to_text_audio_to_text_post": {
+        "properties": {
+          "audio": {
+            "type": "string",
+            "format": "binary",
+            "title": "Audio"
+          },
+          "model_id": {
+            "type": "string",
+            "title": "Model Id",
+            "default": ""
+          },
+          "seed": {
+            "type": "integer",
+            "title": "Seed"
+          }
+        },
+        "type": "object",
+        "required": [
+          "audio"
+        ],
+        "title": "Body_audio_to_text_audio_to_text_post"
+      },
       "Body_image_to_image_image_to_image_post": {
         "properties": {
           "prompt": {
@@ -472,29 +495,6 @@
         ],
         "title": "Body_image_to_video_image_to_video_post"
       },
-      "Body_speech_to_text_speech_to_text_post": {
-        "properties": {
-          "audio": {
-            "type": "string",
-            "format": "binary",
-            "title": "Audio"
-          },
-          "model_id": {
-            "type": "string",
-            "title": "Model Id",
-            "default": ""
-          },
-          "seed": {
-            "type": "integer",
-            "title": "Seed"
-          }
-        },
-        "type": "object",
-        "required": [
-          "audio"
-        ],
-        "title": "Body_speech_to_text_speech_to_text_post"
-      },
       "Body_upscale_upscale_post": {
         "properties": {
           "prompt": {
2 changes: 1 addition & 1 deletion worker/docker.go
@@ -34,7 +34,7 @@ var containerHostPorts = map[string]string{
 	"image-to-image": "8001",
 	"image-to-video": "8002",
 	"upscale": "8003",
-	"speech-to-text": "8004",
+	"audio-to-text": "8004",
 }
 
 type DockerManager struct {
2 changes: 1 addition & 1 deletion worker/multipart.go
@@ -194,7 +194,7 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*multipart.Writer, error) {
 
 	return mw, nil
 }
-func NewSpeechToTextMultipartWriter(w io.Writer, req SpeechToTextMultipartRequestBody) (*multipart.Writer, error) {
+func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestBody) (*multipart.Writer, error) {
 	mw := multipart.NewWriter(w)
 	writer, err := mw.CreateFormFile("audio", req.Audio.Filename())
 	if err != nil {