From f04146510a50059267252ba7739b2334092979d3 Mon Sep 17 00:00:00 2001 From: Elite Encoder Date: Thu, 4 Jul 2024 19:41:54 -0400 Subject: [PATCH] Rename pipeline speech-to-text to audio-to-text --- runner/app/main.py | 12 +- .../{speech_to_text.py => audio_to_text.py} | 11 +- .../{speech_to_text.py => audio_to_text.py} | 6 +- runner/dl_checkpoints.sh | 2 +- runner/gen_openapi.py | 6 +- runner/openapi.json | 54 +-- worker/docker.go | 2 +- worker/multipart.go | 2 +- worker/runner.gen.go | 312 +++++++++--------- worker/worker.go | 24 +- 10 files changed, 215 insertions(+), 216 deletions(-) rename runner/app/pipelines/{speech_to_text.py => audio_to_text.py} (85%) rename runner/app/routes/{speech_to_text.py => audio_to_text.py} (92%) diff --git a/runner/app/main.py b/runner/app/main.py index 13fc8a39..604808b1 100644 --- a/runner/app/main.py +++ b/runner/app/main.py @@ -42,10 +42,10 @@ def load_pipeline(pipeline: str, model_id: str) -> any: from app.pipelines.image_to_video import ImageToVideoPipeline return ImageToVideoPipeline(model_id) - case "speech-to-text": - from app.pipelines.speech_to_text import SpeechToTextPipeline + case "audio-to-text": + from app.pipelines.audio_to_text import AudioToTextPipeline - return SpeechToTextPipeline(model_id) + return AudioToTextPipeline(model_id) case "frame-interpolation": raise NotImplementedError("frame-interpolation pipeline not implemented") case "upscale": @@ -71,10 +71,10 @@ def load_route(pipeline: str) -> any: from app.routes import image_to_video return image_to_video.router - case "speech-to-text": - from app.routes import speech_to_text + case "audio-to-text": + from app.routes import audio_to_text - return speech_to_text.router + return audio_to_text.router case "frame-interpolation": raise NotImplementedError("frame-interpolation pipeline not implemented") case "upscale": diff --git a/runner/app/pipelines/speech_to_text.py b/runner/app/pipelines/audio_to_text.py similarity index 85% rename from runner/app/pipelines/speech_to_text.py rename to runner/app/pipelines/audio_to_text.py index a9cb1584..9474333b 100644 --- a/runner/app/pipelines/speech_to_text.py +++ b/runner/app/pipelines/audio_to_text.py @@ -6,7 +6,6 @@ from huggingface_hub import file_download import torch -import PIL from typing import List import logging import os @@ -14,7 +13,7 @@ logger = logging.getLogger(__name__) -class SpeechToTextPipeline(Pipeline): +class AudioToTextPipeline(Pipeline): def __init__(self, model_id: str): # kwargs = {"cache_dir": get_model_dir()} kwargs = {} @@ -25,21 +24,19 @@ def __init__(self, model_id: str): ) folder_path = os.path.join(get_model_dir(), folder_name) # Load fp16 variant if fp16 safetensors files are found in cache - # Special case SDXL-Lightning because the safetensors files are fp16 but are not - # named properly right now has_fp16_variant = any( ".fp16.safetensors" in fname for _, _, files in os.walk(folder_path) for fname in files ) if torch_device != "cpu" and has_fp16_variant: - logger.info("SpeechToTextPipeline loading fp16 variant for %s", model_id) + logger.info("AudioToTextPipeline loading fp16 variant for %s", model_id) kwargs["torch_dtype"] = torch.float16 kwargs["variant"] = "fp16" if os.environ.get("BFLOAT16"): - logger.info("SpeechToTextPipeline using bfloat16 precision for %s", model_id) + logger.info("AudioToTextPipeline using bfloat16 precision for %s", model_id) kwargs["torch_dtype"] = torch.bfloat16 self.model_id = model_id @@ -81,4 +78,4 @@ def __call__(self, audio: str, **kwargs) -> List[File]: return result def __str__(self) -> str: - return f"SpeechToTextPipeline model_id={self.model_id}" + return f"AudioToTextPipeline model_id={self.model_id}" diff --git a/runner/app/routes/speech_to_text.py b/runner/app/routes/audio_to_text.py similarity index 92% rename from runner/app/routes/speech_to_text.py rename to runner/app/routes/audio_to_text.py index d81f6807..8e975ed9 100644 --- a/runner/app/routes/speech_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -17,9 +17,9 @@ responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}} -@router.post("/speech-to-text", response_model=TextResponse, responses=responses) -@router.post("/speech-to-text/", response_model=TextResponse, include_in_schema=False) -async def speech_to_text( +@router.post("/audio-to-text", response_model=TextResponse, responses=responses) +@router.post("/audio-to-text/", response_model=TextResponse, include_in_schema=False) +async def audio_to_text( audio: Annotated[UploadFile, File()], model_id: Annotated[str, Form()] = "", seed: Annotated[int, Form()] = None, diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index c528bd21..f1822291 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -30,7 +30,7 @@ function download_alpha_models() { # Download upscale models huggingface-cli download stabilityai/stable-diffusion-x4-upscaler --include "*.fp16.safetensors" --cache-dir models - # Download speech-to-text models. + # Download audio-to-text models. huggingface-cli download openai/whisper-large-v3 --include "*.safetensors" "*.json" --cache-dir models printf "\nDownloading token-gated models...\n" diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index 9d2cbcaa..b2efbe84 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -5,9 +5,11 @@ import yaml from app.main import app, use_route_names_as_operation_ids -from app.routes import health, image_to_image, image_to_video, text_to_image, upscale, speech_to_text +from app.routes import health, image_to_image, image_to_video, text_to_image, upscale from fastapi.openapi.utils import get_openapi +from app.routes import audio_to_text + # Specify Endpoints for OpenAPI schema generation. SERVERS = [ { @@ -77,7 +79,7 @@ def write_openapi(fname, entrypoint="runner"): app.include_router(image_to_image.router) app.include_router(image_to_video.router) app.include_router(upscale.router) - app.include_router(speech_to_text.router) + app.include_router(audio_to_text.router) use_route_names_as_operation_ids(app) diff --git a/runner/openapi.json b/runner/openapi.json index 9c212e0e..aa3630a0 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -282,15 +282,15 @@ ] } }, - "/speech-to-text": { + "/audio-to-text": { "post": { - "summary": "Speech To Text", - "operationId": "speech_to_text", + "summary": "Audio To Text", + "operationId": "audio_to_text", "requestBody": { "content": { "multipart/form-data": { "schema": { - "$ref": "#/components/schemas/Body_speech_to_text_speech_to_text_post" + "$ref": "#/components/schemas/Body_audio_to_text_audio_to_text_post" } } }, @@ -361,6 +361,29 @@ ], "title": "APIError" }, + "Body_audio_to_text_audio_to_text_post": { + "properties": { + "audio": { + "type": "string", + "format": "binary", + "title": "Audio" + }, + "model_id": { + "type": "string", + "title": "Model Id", + "default": "" + }, + "seed": { + "type": "integer", + "title": "Seed" + } + }, + "type": "object", + "required": [ + "audio" + ], + "title": "Body_audio_to_text_audio_to_text_post" + }, "Body_image_to_image_image_to_image_post": { "properties": { "prompt": { @@ -472,29 +495,6 @@ ], "title": "Body_image_to_video_image_to_video_post" }, - "Body_speech_to_text_speech_to_text_post": { - "properties": { - "audio": { - "type": "string", - "format": "binary", - "title": "Audio" - }, - "model_id": { - "type": "string", - "title": "Model Id", - "default": "" - }, - "seed": { - "type": "integer", - "title": "Seed" - } - }, - "type": "object", - "required": [ - "audio" - ], - "title": "Body_speech_to_text_speech_to_text_post" - }, "Body_upscale_upscale_post": { "properties": { "prompt": { diff --git a/worker/docker.go b/worker/docker.go index 9cbef677..8d7f97e0 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -34,7 +34,7 @@ var containerHostPorts = map[string]string{ "image-to-image": "8001", "image-to-video": "8002", "upscale": "8003", - "speech-to-text": "8004", + "audio-to-text": "8004", } type DockerManager struct { diff --git a/worker/multipart.go b/worker/multipart.go index 53492798..55101e2c 100644 --- a/worker/multipart.go +++ b/worker/multipart.go @@ -194,7 +194,7 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*m return mw, nil } -func NewSpeechToTextMultipartWriter(w io.Writer, req SpeechToTextMultipartRequestBody) (*multipart.Writer, error) { +func NewAudioToTextMultipartWriter(w io.Writer, req AudioToTextMultipartRequestBody) (*multipart.Writer, error) { mw := multipart.NewWriter(w) writer, err := mw.CreateFormFile("audio", req.Audio.Filename()) if err != nil { diff --git a/worker/runner.gen.go b/worker/runner.gen.go index fb834673..89a8509c 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -31,6 +31,13 @@ type APIError struct { Msg string `json:"msg"` } +// BodyAudioToTextAudioToTextPost defines model for Body_audio_to_text_audio_to_text_post. +type BodyAudioToTextAudioToTextPost struct { + Audio openapi_types.File `json:"audio"` + ModelId *string `json:"model_id,omitempty"` + Seed *int `json:"seed,omitempty"` +} + // BodyImageToImageImageToImagePost defines model for Body_image_to_image_image_to_image_post. type BodyImageToImageImageToImagePost struct { GuidanceScale *float32 `json:"guidance_scale,omitempty"` @@ -58,13 +65,6 @@ type BodyImageToVideoImageToVideoPost struct { Width *int `json:"width,omitempty"` } -// BodySpeechToTextSpeechToTextPost defines model for Body_speech_to_text_speech_to_text_post. -type BodySpeechToTextSpeechToTextPost struct { - Audio openapi_types.File `json:"audio"` - ModelId *string `json:"model_id,omitempty"` - Seed *int `json:"seed,omitempty"` -} - // BodyUpscaleUpscalePost defines model for Body_upscale_upscale_post. type BodyUpscaleUpscalePost struct { Image openapi_types.File `json:"image"` @@ -150,15 +150,15 @@ type Chunk struct { Timestamp []interface{} `json:"timestamp"` } +// AudioToTextMultipartRequestBody defines body for AudioToText for multipart/form-data ContentType. +type AudioToTextMultipartRequestBody = BodyAudioToTextAudioToTextPost + // ImageToImageMultipartRequestBody defines body for ImageToImage for multipart/form-data ContentType. type ImageToImageMultipartRequestBody = BodyImageToImageImageToImagePost // ImageToVideoMultipartRequestBody defines body for ImageToVideo for multipart/form-data ContentType. type ImageToVideoMultipartRequestBody = BodyImageToVideoImageToVideoPost -// SpeechToTextMultipartRequestBody defines body for SpeechToText for multipart/form-data ContentType. -type SpeechToTextMultipartRequestBody = BodySpeechToTextSpeechToTextPost - // TextToImageJSONRequestBody defines body for TextToImage for application/json ContentType. type TextToImageJSONRequestBody = TextToImageParams @@ -300,6 +300,9 @@ func WithRequestEditorFn(fn RequestEditorFn) ClientOption { // The interface specification for the client above. type ClientInterface interface { + // AudioToTextWithBody request with any body + AudioToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) + // Health request Health(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -309,9 +312,6 @@ type ClientInterface interface { // ImageToVideoWithBody request with any body ImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) - // SpeechToTextWithBody request with any body - SpeechToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) - // TextToImageWithBody request with any body TextToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) @@ -321,8 +321,8 @@ type ClientInterface interface { UpscaleWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) } -func (c *Client) Health(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewHealthRequest(c.Server) +func (c *Client) AudioToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewAudioToTextRequestWithBody(c.Server, contentType, body) if err != nil { return nil, err } @@ -333,8 +333,8 @@ func (c *Client) Health(ctx context.Context, reqEditors ...RequestEditorFn) (*ht return c.Client.Do(req) } -func (c *Client) ImageToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewImageToImageRequestWithBody(c.Server, contentType, body) +func (c *Client) Health(ctx context.Context, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewHealthRequest(c.Server) if err != nil { return nil, err } @@ -345,8 +345,8 @@ func (c *Client) ImageToImageWithBody(ctx context.Context, contentType string, b return c.Client.Do(req) } -func (c *Client) ImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewImageToVideoRequestWithBody(c.Server, contentType, body) +func (c *Client) ImageToImageWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewImageToImageRequestWithBody(c.Server, contentType, body) if err != nil { return nil, err } @@ -357,8 +357,8 @@ func (c *Client) ImageToVideoWithBody(ctx context.Context, contentType string, b return c.Client.Do(req) } -func (c *Client) SpeechToTextWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { - req, err := NewSpeechToTextRequestWithBody(c.Server, contentType, body) +func (c *Client) ImageToVideoWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) { + req, err := NewImageToVideoRequestWithBody(c.Server, contentType, body) if err != nil { return nil, err } @@ -405,8 +405,8 @@ func (c *Client) UpscaleWithBody(ctx context.Context, contentType string, body i return c.Client.Do(req) } -// NewHealthRequest generates requests for Health -func NewHealthRequest(server string) (*http.Request, error) { +// NewAudioToTextRequestWithBody generates requests for AudioToText with any type of body +func NewAudioToTextRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { var err error serverURL, err := url.Parse(server) @@ -414,7 +414,7 @@ func NewHealthRequest(server string) (*http.Request, error) { return nil, err } - operationPath := fmt.Sprintf("/health") + operationPath := fmt.Sprintf("/audio-to-text") if operationPath[0] == '/' { operationPath = "." + operationPath } @@ -424,16 +424,18 @@ func NewHealthRequest(server string) (*http.Request, error) { return nil, err } - req, err := http.NewRequest("GET", queryURL.String(), nil) + req, err := http.NewRequest("POST", queryURL.String(), body) if err != nil { return nil, err } + req.Header.Add("Content-Type", contentType) + return req, nil } -// NewImageToImageRequestWithBody generates requests for ImageToImage with any type of body -func NewImageToImageRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { +// NewHealthRequest generates requests for Health +func NewHealthRequest(server string) (*http.Request, error) { var err error serverURL, err := url.Parse(server) @@ -441,7 +443,7 @@ func NewImageToImageRequestWithBody(server string, contentType string, body io.R return nil, err } - operationPath := fmt.Sprintf("/image-to-image") + operationPath := fmt.Sprintf("/health") if operationPath[0] == '/' { operationPath = "." + operationPath } @@ -451,18 +453,16 @@ func NewImageToImageRequestWithBody(server string, contentType string, body io.R return nil, err } - req, err := http.NewRequest("POST", queryURL.String(), body) + req, err := http.NewRequest("GET", queryURL.String(), nil) if err != nil { return nil, err } - req.Header.Add("Content-Type", contentType) - return req, nil } -// NewImageToVideoRequestWithBody generates requests for ImageToVideo with any type of body -func NewImageToVideoRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { +// NewImageToImageRequestWithBody generates requests for ImageToImage with any type of body +func NewImageToImageRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { var err error serverURL, err := url.Parse(server) @@ -470,7 +470,7 @@ func NewImageToVideoRequestWithBody(server string, contentType string, body io.R return nil, err } - operationPath := fmt.Sprintf("/image-to-video") + operationPath := fmt.Sprintf("/image-to-image") if operationPath[0] == '/' { operationPath = "." + operationPath } @@ -490,8 +490,8 @@ func NewImageToVideoRequestWithBody(server string, contentType string, body io.R return req, nil } -// NewSpeechToTextRequestWithBody generates requests for SpeechToText with any type of body -func NewSpeechToTextRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { +// NewImageToVideoRequestWithBody generates requests for ImageToVideo with any type of body +func NewImageToVideoRequestWithBody(server string, contentType string, body io.Reader) (*http.Request, error) { var err error serverURL, err := url.Parse(server) @@ -499,7 +499,7 @@ func NewSpeechToTextRequestWithBody(server string, contentType string, body io.R return nil, err } - operationPath := fmt.Sprintf("/speech-to-text") + operationPath := fmt.Sprintf("/image-to-video") if operationPath[0] == '/' { operationPath = "." + operationPath } @@ -631,6 +631,9 @@ func WithBaseURL(baseURL string) ClientOption { // ClientWithResponsesInterface is the interface specification for the client with responses above. type ClientWithResponsesInterface interface { + // AudioToTextWithBodyWithResponse request with any body + AudioToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*AudioToTextResponse, error) + // HealthWithResponse request HealthWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*HealthResponse, error) @@ -640,9 +643,6 @@ type ClientWithResponsesInterface interface { // ImageToVideoWithBodyWithResponse request with any body ImageToVideoWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ImageToVideoResponse, error) - // SpeechToTextWithBodyWithResponse request with any body - SpeechToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*SpeechToTextResponse, error) - // TextToImageWithBodyWithResponse request with any body TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) @@ -652,14 +652,17 @@ type ClientWithResponsesInterface interface { UpscaleWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*UpscaleResponse, error) } -type HealthResponse struct { +type AudioToTextResponse struct { Body []byte HTTPResponse *http.Response - JSON200 *HealthCheck + JSON200 *TextResponse + JSON400 *HTTPError + JSON422 *HTTPValidationError + JSON500 *HTTPError } // Status returns HTTPResponse.Status -func (r HealthResponse) Status() string { +func (r AudioToTextResponse) Status() string { if r.HTTPResponse != nil { return r.HTTPResponse.Status } @@ -667,24 +670,21 @@ func (r HealthResponse) Status() string { } // StatusCode returns HTTPResponse.StatusCode -func (r HealthResponse) StatusCode() int { +func (r AudioToTextResponse) StatusCode() int { if r.HTTPResponse != nil { return r.HTTPResponse.StatusCode } return 0 } -type ImageToImageResponse struct { +type HealthResponse struct { Body []byte HTTPResponse *http.Response - JSON200 *ImageResponse - JSON400 *HTTPError - JSON422 *HTTPValidationError - JSON500 *HTTPError + JSON200 *HealthCheck } // Status returns HTTPResponse.Status -func (r ImageToImageResponse) Status() string { +func (r HealthResponse) Status() string { if r.HTTPResponse != nil { return r.HTTPResponse.Status } @@ -692,24 +692,24 @@ func (r ImageToImageResponse) Status() string { } // StatusCode returns HTTPResponse.StatusCode -func (r ImageToImageResponse) StatusCode() int { +func (r HealthResponse) StatusCode() int { if r.HTTPResponse != nil { return r.HTTPResponse.StatusCode } return 0 } -type ImageToVideoResponse struct { +type ImageToImageResponse struct { Body []byte HTTPResponse *http.Response - JSON200 *VideoResponse + JSON200 *ImageResponse JSON400 *HTTPError JSON422 *HTTPValidationError JSON500 *HTTPError } // Status returns HTTPResponse.Status -func (r ImageToVideoResponse) Status() string { +func (r ImageToImageResponse) Status() string { if r.HTTPResponse != nil { return r.HTTPResponse.Status } @@ -717,24 +717,24 @@ func (r ImageToVideoResponse) Status() string { } // StatusCode returns HTTPResponse.StatusCode -func (r ImageToVideoResponse) StatusCode() int { +func (r ImageToImageResponse) StatusCode() int { if r.HTTPResponse != nil { return r.HTTPResponse.StatusCode } return 0 } -type SpeechToTextResponse struct { +type ImageToVideoResponse struct { Body []byte HTTPResponse *http.Response - JSON200 *TextResponse + JSON200 *VideoResponse JSON400 *HTTPError JSON422 *HTTPValidationError JSON500 *HTTPError } // Status returns HTTPResponse.Status -func (r SpeechToTextResponse) Status() string { +func (r ImageToVideoResponse) Status() string { if r.HTTPResponse != nil { return r.HTTPResponse.Status } @@ -742,7 +742,7 @@ func (r SpeechToTextResponse) Status() string { } // StatusCode returns HTTPResponse.StatusCode -func (r SpeechToTextResponse) StatusCode() int { +func (r ImageToVideoResponse) StatusCode() int { if r.HTTPResponse != nil { return r.HTTPResponse.StatusCode } @@ -799,6 +799,15 @@ func (r UpscaleResponse) StatusCode() int { return 0 } +// AudioToTextWithBodyWithResponse request with arbitrary body returning *AudioToTextResponse +func (c *ClientWithResponses) AudioToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*AudioToTextResponse, error) { + rsp, err := c.AudioToTextWithBody(ctx, contentType, body, reqEditors...) + if err != nil { + return nil, err + } + return ParseAudioToTextResponse(rsp) +} + // HealthWithResponse request returning *HealthResponse func (c *ClientWithResponses) HealthWithResponse(ctx context.Context, reqEditors ...RequestEditorFn) (*HealthResponse, error) { rsp, err := c.Health(ctx, reqEditors...) @@ -826,15 +835,6 @@ func (c *ClientWithResponses) ImageToVideoWithBodyWithResponse(ctx context.Conte return ParseImageToVideoResponse(rsp) } -// SpeechToTextWithBodyWithResponse request with arbitrary body returning *SpeechToTextResponse -func (c *ClientWithResponses) SpeechToTextWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*SpeechToTextResponse, error) { - rsp, err := c.SpeechToTextWithBody(ctx, contentType, body, reqEditors...) - if err != nil { - return nil, err - } - return ParseSpeechToTextResponse(rsp) -} - // TextToImageWithBodyWithResponse request with arbitrary body returning *TextToImageResponse func (c *ClientWithResponses) TextToImageWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*TextToImageResponse, error) { rsp, err := c.TextToImageWithBody(ctx, contentType, body, reqEditors...) @@ -861,95 +861,95 @@ func (c *ClientWithResponses) UpscaleWithBodyWithResponse(ctx context.Context, c return ParseUpscaleResponse(rsp) } -// ParseHealthResponse parses an HTTP response from a HealthWithResponse call -func ParseHealthResponse(rsp *http.Response) (*HealthResponse, error) { +// ParseAudioToTextResponse parses an HTTP response from a AudioToTextWithResponse call +func ParseAudioToTextResponse(rsp *http.Response) (*AudioToTextResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) defer func() { _ = rsp.Body.Close() }() if err != nil { return nil, err } - response := &HealthResponse{ + response := &AudioToTextResponse{ Body: bodyBytes, HTTPResponse: rsp, } switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: - var dest HealthCheck + var dest TextResponse if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err } response.JSON200 = &dest + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: + var dest HTTPValidationError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON422 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + } return response, nil } -// ParseImageToImageResponse parses an HTTP response from a ImageToImageWithResponse call -func ParseImageToImageResponse(rsp *http.Response) (*ImageToImageResponse, error) { +// ParseHealthResponse parses an HTTP response from a HealthWithResponse call +func ParseHealthResponse(rsp *http.Response) (*HealthResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) defer func() { _ = rsp.Body.Close() }() if err != nil { return nil, err } - response := &ImageToImageResponse{ + response := &HealthResponse{ Body: bodyBytes, HTTPResponse: rsp, } switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: - var dest ImageResponse + var dest HealthCheck if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err } response.JSON200 = &dest - case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: - var dest HTTPError - if err := json.Unmarshal(bodyBytes, &dest); err != nil { - return nil, err - } - response.JSON400 = &dest - - case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: - var dest HTTPValidationError - if err := json.Unmarshal(bodyBytes, &dest); err != nil { - return nil, err - } - response.JSON422 = &dest - - case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: - var dest HTTPError - if err := json.Unmarshal(bodyBytes, &dest); err != nil { - return nil, err - } - response.JSON500 = &dest - } return response, nil } -// ParseImageToVideoResponse parses an HTTP response from a ImageToVideoWithResponse call -func ParseImageToVideoResponse(rsp *http.Response) (*ImageToVideoResponse, error) { +// ParseImageToImageResponse parses an HTTP response from a ImageToImageWithResponse call +func ParseImageToImageResponse(rsp *http.Response) (*ImageToImageResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) defer func() { _ = rsp.Body.Close() }() if err != nil { return nil, err } - response := &ImageToVideoResponse{ + response := &ImageToImageResponse{ Body: bodyBytes, HTTPResponse: rsp, } switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: - var dest VideoResponse + var dest ImageResponse if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err } @@ -981,22 +981,22 @@ func ParseImageToVideoResponse(rsp *http.Response) (*ImageToVideoResponse, error return response, nil } -// ParseSpeechToTextResponse parses an HTTP response from a SpeechToTextWithResponse call -func ParseSpeechToTextResponse(rsp *http.Response) (*SpeechToTextResponse, error) { +// ParseImageToVideoResponse parses an HTTP response from a ImageToVideoWithResponse call +func ParseImageToVideoResponse(rsp *http.Response) (*ImageToVideoResponse, error) { bodyBytes, err := io.ReadAll(rsp.Body) defer func() { _ = rsp.Body.Close() }() if err != nil { return nil, err } - response := &SpeechToTextResponse{ + response := &ImageToVideoResponse{ Body: bodyBytes, HTTPResponse: rsp, } switch { case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200: - var dest TextResponse + var dest VideoResponse if err := json.Unmarshal(bodyBytes, &dest); err != nil { return nil, err } @@ -1124,6 +1124,9 @@ func ParseUpscaleResponse(rsp *http.Response) (*UpscaleResponse, error) { // ServerInterface represents all server handlers. type ServerInterface interface { + // Audio To Text + // (POST /audio-to-text) + AudioToText(w http.ResponseWriter, r *http.Request) // Health // (GET /health) Health(w http.ResponseWriter, r *http.Request) @@ -1133,9 +1136,6 @@ type ServerInterface interface { // Image To Video // (POST /image-to-video) ImageToVideo(w http.ResponseWriter, r *http.Request) - // Speech To Text - // (POST /speech-to-text) - SpeechToText(w http.ResponseWriter, r *http.Request) // Text To Image // (POST /text-to-image) TextToImage(w http.ResponseWriter, r *http.Request) @@ -1148,6 +1148,12 @@ type ServerInterface interface { type Unimplemented struct{} +// Audio To Text +// (POST /audio-to-text) +func (_ Unimplemented) AudioToText(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotImplemented) +} + // Health // (GET /health) func (_ Unimplemented) Health(w http.ResponseWriter, r *http.Request) { @@ -1166,12 +1172,6 @@ func (_ Unimplemented) ImageToVideo(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNotImplemented) } -// Speech To Text -// (POST /speech-to-text) -func (_ Unimplemented) SpeechToText(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusNotImplemented) -} - // Text To Image // (POST /text-to-image) func (_ Unimplemented) TextToImage(w http.ResponseWriter, r *http.Request) { @@ -1193,12 +1193,14 @@ type ServerInterfaceWrapper struct { type MiddlewareFunc func(http.Handler) http.Handler -// Health operation middleware -func (siw *ServerInterfaceWrapper) Health(w http.ResponseWriter, r *http.Request) { +// AudioToText operation middleware +func (siw *ServerInterfaceWrapper) AudioToText(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) + handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - siw.Handler.Health(w, r) + siw.Handler.AudioToText(w, r) })) for _, middleware := range siw.HandlerMiddlewares { @@ -1208,14 +1210,12 @@ func (siw *ServerInterfaceWrapper) Health(w http.ResponseWriter, r *http.Request handler.ServeHTTP(w, r.WithContext(ctx)) } -// ImageToImage operation middleware -func (siw *ServerInterfaceWrapper) ImageToImage(w http.ResponseWriter, r *http.Request) { +// Health operation middleware +func (siw *ServerInterfaceWrapper) Health(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) - handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - siw.Handler.ImageToImage(w, r) + siw.Handler.Health(w, r) })) for _, middleware := range siw.HandlerMiddlewares { @@ -1225,14 +1225,14 @@ func (siw *ServerInterfaceWrapper) ImageToImage(w http.ResponseWriter, r *http.R handler.ServeHTTP(w, r.WithContext(ctx)) } -// ImageToVideo operation middleware -func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.Request) { +// ImageToImage operation middleware +func (siw *ServerInterfaceWrapper) ImageToImage(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - siw.Handler.ImageToVideo(w, r) + siw.Handler.ImageToImage(w, r) })) for _, middleware := range siw.HandlerMiddlewares { @@ -1242,14 +1242,14 @@ func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.R handler.ServeHTTP(w, r.WithContext(ctx)) } -// SpeechToText operation middleware -func (siw *ServerInterfaceWrapper) SpeechToText(w http.ResponseWriter, r *http.Request) { +// ImageToVideo operation middleware +func (siw *ServerInterfaceWrapper) ImageToVideo(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctx = context.WithValue(ctx, HTTPBearerScopes, []string{}) handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - siw.Handler.SpeechToText(w, r) + siw.Handler.ImageToVideo(w, r) })) for _, middleware := range siw.HandlerMiddlewares { @@ -1406,6 +1406,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl ErrorHandlerFunc: options.ErrorHandlerFunc, } + r.Group(func(r chi.Router) { + r.Post(options.BaseURL+"/audio-to-text", wrapper.AudioToText) + }) r.Group(func(r chi.Router) { r.Get(options.BaseURL+"/health", wrapper.Health) }) @@ -1415,9 +1418,6 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/image-to-video", wrapper.ImageToVideo) }) - r.Group(func(r chi.Router) { - r.Post(options.BaseURL+"/speech-to-text", wrapper.SpeechToText) - }) r.Group(func(r chi.Router) { r.Post(options.BaseURL+"/text-to-image", wrapper.TextToImage) }) @@ -1431,31 +1431,31 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xZW3PbNhP9Kxx836NsyW7cdPRmu23iaZN4IiV9yHg0MLmSkJAAiosTjUf/vYMFL+At", - "ZMaXtBk/WSQXu2cX5wAL+JbEIpOCAzeazG+JjreQUfx5ennxm1JCud9SCQnKMMAvmd64P4aZFMicvNIb", - "MiFmJ92DNorxDdnvJ0TB35YpSMj8Aw65mpRDSt/lOHH9EWJD9hNyJpLdimV0Aysj8h+NRym0acPaWJZQ", - "HsNKx9RFuSUJrKlNDZk/Pzypgr/I7aIF2pUQuM2uQTkIGMU5WAuVUUPm5JpxqnakcnKBJq2087Grr2A5", - "CrGgm2gYUSYSSFcsqXkiAZ5XziC6SLogcdhQw25gJZXIpOn18Tq3iy69XZcrm/k50CsJqsvhUeDPZhEm", - "qKNLUC2vjBvY+PQqP8XYfgiarsHsVvEW4k+1yEZZqIIv0Cw6R7PSzbUQKVCOfgCSMOLCPXeB00YB35ht", - "Ldjs8JcgVmHRmrmGDGSRlWdYoIixrB8UzA1LQDQfuwWzlrqW088VnN+l7qzFFthmW5/wk+fBuJf+e9fQ", - "u4jqTvTPhGGCr65t/AlM08nR8fPQi7OMztCy5i3IgwumYUXtZtVDjNlxIAFnHJ3aTdTPke9A6c8sacA+", - "mh0/qyL9hd/bIxt0HmBxPxX7WKwlQLx1xga+mOZjN4upTZj4Oq9O0eS+eTWy2o2aebjNmo1IvK9mVuIe", - "U/7trtJ3U9+/eGn/tsW5s84dk/JyubzsaZwSMJSl7tf/FazJnPxvWrVf07z3mpbNURNgPjwAVsXqAfKe", - "piyhblUbhMQMZHoIW9PfvsLyq/dUAqFK0R3mEKJtOujCDTQ12/OCBHW82lBj69sWefMHCfdiNOhqSKtN", - "qgrQER918Ba0FFxDj5L06Iq9goTRsE6+I+qqU2tp1eFc12F14PaRWni5Xn8OxfDaPd9p97AqDe3eqXSw", - "/7doo71HRBRk5oF3ZLSEL6Z/IuKt5Z/GTwSahxNx7sc3J2JC3JobJuhgDGZovFEOKsiulkRPkkuBs3tJ", - "FfWJPNTJpurfRnRsP/ihA93yNSjA0hpotMMns4bXwjZaoO1/7iBTdn3f2OblSTU4XedsB7EH955UxDX1", - "Ur57sybzD7etWt22IF4FQv5TxBimQ8rNqwrQuqfJ8S8qU8QcLd3bIem7PHyo3DKo1Ij97r3ri/uXubWi", - "WWO/+caNp7m8FWc873hgI8rDhynV8HYk5FfaViLjllUXJwNtaCbDVAPcy/L7AHQTGrpgQRIeYws8yim2", - "ipndwtXRI3eNyxlQBaq8I0MN+lelk60xkuydD8bXwktax4pJJOecnPKISpkyz9bIiEhZHp1eRJJJSBn3", - "k1GQmt2ABFDu+1vLOQa6AaW9r9nh0eHMVUtI4FQyMic/4asJkdRsEfZ0i40O7iCApXfTgcEvkrIPIq5o", - "fjJx1PFshrur4AY4jgpATz9qF764KBziYNhpYWHqBVnYOAat1zaNSj7hFNgsc2eTEqJ7OcUt4MCIg/Is", - "Uxx16mnhspSvTsQzArRxPXwjr8ymhkmqzNQdig4Sauj41MZe2OzrrHRr+/4BK17vEsfWfEKe3eesl6eS", - "jvhnNIne+inBuMfH9xq3dUBpI6hMovIQc/JY6V9wA4rTNFqAugEVVSe9Yt3BDTBccT5c7a9CTfg746Xw", - "bU5DG3jJMqgNXMIfSxv910CPrI36xvWkjR9ZG57hqA1/mebEUXQg3dpYoN1S5H3JQ2pjxHXfI2ujdkp9", - "ksYPKA1Pb6cNJDhKw9FtREcVHPe+Koy70a9+oHzqm54EcL8CcBxrtE35XX4/89/lBg+7HXT+a+FJAE8C", - "uF8BFGTe+1HOjcZB9UjlzcN5KmwSnYsss5yZXfSCGvhMdyT/DwDed+j5dJoooNnBxn89TPPhh7EbTvZX", - "+38CAAD//4PFS/VlJAAA", + "H4sIAAAAAAAC/+xZW2/bOBb+KwR3H53YyTbbhd+S7EwbzLQNarfzUAQGIx3bbCWSw0taI/B/H/BQlqhb", + "5SCXzhR5iiWdy3cuH3nI3NJE5koKENbQ6S01yRpyhj9PLy9+0Vpq/1tpqUBbDvglNyv/x3KbAZ3SN2ZF", + "R9RulH8wVnOxotvtiGr403ENKZ1+QpWrUalS2i715PVnSCzdjuiZTDcL5lIuF1YuLHyzjScljW2DQhn/", + "Yyl1ziyd0msumN7QyCuKtKCOaC5TyBY89eopLJnLvH6k+cYLkIu0S9kApHE6Zv65lOPCwgp0KyEBbpSS", + "/cLuyxfP2Qq8aPjReOzO2MrxlIkEFiZhHkIU+8vDkwrZq0KOzFCuhCBcfu0jG1H08v3cX6BIR/oCwu9g", + "OYqxoBkyjOheFRWwYpbfwEJpmSvba+NtIUcug1yXKZeHGpiFAt1l8Ciy53KCARpyCbpltWylEa3s7HT7", + "IRi2BLtZJGtIvtQ8W+2gcj5DMXKOYqWZaykzYOIOfT7yrkGs7LrmbHL4v8jXTqJVuQZL1C6q0GFNuuzR", + "9YOEueEpyOZjN2GWytRi+m8F51dlOnOxBr5a1wt+8jLSex2+d6neh1T3av9cWi7F4tolX8A2jRwdv4yt", + "eElyhpI1a1EcQnIDC+ZWi57GmBxHFPDC5NStSH+P/ICW/srTBuyjyfGLytMf+H1w0R/o4v5W7Otip3C9", + "LP929+0P66S/8TJ1t4WmM88dRXk9n1/2DE0pWMYz/+vfGpZ0Sv81rkavcTF3jcvBqAmwUI+AVb56gHxk", + "GU+ZZ+ggJG4hN0PYmva2FZb/B0slEKY122AMMdqmgS7cwDK7Pt81QR2vscy6+hJM3/1G430FBbqG0WrB", + "rRx0+EcevAejpDDQwySzd8beQMpZnKewu3flqbVMmLjWdVgduIOnFl5hll9jMrz1z/daCZ3OYrkPOhuc", + "/R3KmGAREUWRBeAdEc3hm+0vRLJ24sv+hUDxuBDnQb9ZiBH1o3YcoIcxGKENQgWoKLpaED1BziVW95Jp", + "FgJ5rCm9mkX2mD5+8gEazYolaMDUWmiMdieThtWdLJmh7D9uKC8nmDuOLEVQjZ6u92xHYw/uPZlMauxl", + "YvNuSaefblu5um1BvIqI/LtM0E0HlZvXFGBMz5ATXlSiiJnM/dsh6vs4gqtCMsrUHvvdRz/j9S9zS83y", + "xn5zx42nubztzivB8MBGVLiPQ6rh7QgorLStQPZbVr2fHIxluYpDjXDPy+8D0G0s6J1FQQSMLfBIp8Rp", + "bjczn8eA3A8uZ8A06PJ+DDkYXpVG1tYquvU2uFjKQGmTaK6wOaf0VBCmVMZDtxIriXaCnF4QxRVkXIRi", + "7Jqa34AC0P77eycEOroBbYKtyeHR4cRnSyoQTHE6pf/BVyOqmF0j7DFeHh1YebBL/e5s4MuCIC7S3Z3Y", + "XBb18BkEY/3Mi7usFBYEauUus1wxbcf+EHGQMsuq+8Khdtzvbmtbr6FfCfFFaDaM6ngyaeCKkjr+bHx6", + "9gVV25vRd71iM5ckYMzSZaQSG9EXDwihGuE7/J+xlLwP9UC/x8cP6rc1zbcRVCKknPhPnir8C2FBC5aR", + "Gegb0KQ6Fu1IirtFTM9PV9urETUuz/3htuhtMpcEu9urjtc4/uNcBR1sCKcD+ohdF58/9m26bRxUARGj", + "wcHIc7w84XeTHDfrYs9+ZJbvcSX3xDyvn52eif4TEj38V2Auw/Df4AZeow1yAwebp+JG/0XfE3OjPs49", + "c+Nn5kbocOSGn7n22Daik953mXG/Gax+lnzeHJ4J8LAE8D3W2BuKa/z+zv9QCDzuftD5X4VnAjwT4GEJ", + "sGvmbdDyZgwq1T2Vlw7nmXQpOZd57gS3G/KKWfjKNrS4/MerDjMdj1MNLD9Yha+HWaF+mHh1ur3a/hUA", + "AP//VKmMfVwkAAA=", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index e991d987..b0761c5d 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -250,20 +250,20 @@ func (w *Worker) Upscale(ctx context.Context, req UpscaleMultipartRequestBody) ( return resp.JSON200, nil } -func (w *Worker) SpeechToText(ctx context.Context, req SpeechToTextMultipartRequestBody) (*TextResponse, error) { - c, err := w.borrowContainer(ctx, "speech-to-text", *req.ModelId) +func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartRequestBody) (*TextResponse, error) { + c, err := w.borrowContainer(ctx, "audio-to-text", *req.ModelId) if err != nil { return nil, err } defer w.returnContainer(c) var buf bytes.Buffer - mw, err := NewSpeechToTextMultipartWriter(&buf, req) + mw, err := NewAudioToTextMultipartWriter(&buf, req) if err != nil { return nil, err } - resp, err := c.Client.SpeechToTextWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) + resp, err := c.Client.AudioToTextWithBodyWithResponse(ctx, mw.FormDataContentType(), &buf) if err != nil { return nil, err } @@ -273,8 +273,8 @@ func (w *Worker) SpeechToText(ctx context.Context, req SpeechToTextMultipartRequ if err != nil { return nil, err } - slog.Error("speech-to-text container returned 422", slog.String("err", string(val))) - return nil, errors.New("speech-to-text container returned 422") + slog.Error("audio-to-text container returned 422", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 422") } if resp.JSON400 != nil { @@ -282,13 +282,13 @@ func (w *Worker) SpeechToText(ctx context.Context, req SpeechToTextMultipartRequ if err != nil { return nil, err } - slog.Error("speech-to-text container returned 400", slog.String("err", string(val))) - return nil, errors.New("speech-to-text container returned 400") + slog.Error("audio-to-text container returned 400", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 400") } if resp.StatusCode() == 413 { - msg := "speech-to-text container returned 413 file too large; max file size is 50MB" - slog.Error("speech-to-text container returned 400", slog.String("err", string(msg))) + msg := "audio-to-text container returned 413 file too large; max file size is 50MB" + slog.Error("audio-to-text container returned 400", slog.String("err", string(msg))) return nil, errors.New(msg) } @@ -297,8 +297,8 @@ func (w *Worker) SpeechToText(ctx context.Context, req SpeechToTextMultipartRequ if err != nil { return nil, err } - slog.Error("speech-to-text container returned 500", slog.String("err", string(val))) - return nil, errors.New("speech-to-text container returned 500") + slog.Error("audio-to-text container returned 500", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 500") } return resp.JSON200, nil