From a46989207f4f5e8877822947309a081878ca34f1 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Wed, 4 Sep 2024 15:36:14 +0200 Subject: [PATCH 01/14] refactor(runner): add InferenceError to all pipelines This commit adds the inference error logic from the SAM2 pipeline to all pipelines so users are given a warning when they supply wrong arguments. --- runner/app/pipelines/audio_to_text.py | 8 +++++++- runner/app/pipelines/image_to_image.py | 12 ++++++++---- runner/app/pipelines/image_to_video.py | 8 +++++++- runner/app/pipelines/optim/sfast.py | 2 +- runner/app/pipelines/segment_anything_2.py | 4 ++-- runner/app/pipelines/text_to_image.py | 12 ++++++++---- runner/app/routes/audio_to_text.py | 6 +++++- runner/app/routes/image_to_image.py | 13 ++++++++++--- runner/app/routes/image_to_video.py | 9 ++++++++- runner/app/routes/segment_anything_2.py | 9 ++------- runner/app/routes/text_to_image.py | 17 ++++++++++++----- runner/app/routes/upscale.py | 2 +- runner/app/routes/{util.py => utils.py} | 16 ---------------- runner/app/utils/__init__.py | 0 runner/app/utils/errors.py | 17 +++++++++++++++++ runner/gen_openapi.py | 4 ++-- 16 files changed, 90 insertions(+), 49 deletions(-) rename runner/app/routes/{util.py => utils.py} (89%) create mode 100644 runner/app/utils/__init__.py create mode 100644 runner/app/utils/errors.py diff --git a/runner/app/pipelines/audio_to_text.py b/runner/app/pipelines/audio_to_text.py index 4c75ddfb..888f70ed 100644 --- a/runner/app/pipelines/audio_to_text.py +++ b/runner/app/pipelines/audio_to_text.py @@ -6,6 +6,7 @@ from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device from app.pipelines.utils.audio import AudioConverter +from app.utils.errors import InferenceError from fastapi import File, UploadFile from huggingface_hub import file_download from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline @@ -76,7 +77,12 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]: converted_bytes = audio_converter.convert(audio, "mp3") audio_converter.write_bytes_to_file(converted_bytes, audio) - return self.tm(audio.file.read(), **kwargs) + try: + outputs = self.tm(audio.file.read(), **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) + + return outputs def __str__(self) -> str: return f"AudioToTextPipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 9e20ff03..9aeae4c9 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -13,6 +13,7 @@ is_lightning_model, is_turbo_model, ) +from app.utils.errors import InferenceError from diffusers import ( AutoPipelineForImage2Image, EulerAncestralDiscreteScheduler, @@ -222,14 +223,17 @@ def __call__( # Default to 8step kwargs["num_inference_steps"] = 8 - output = self.ldm(prompt, image=image, **kwargs) + try: + outputs = self.ldm(prompt, image=image, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) if safety_check: - _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images) else: - has_nsfw_concept = [None] * len(output.images) + has_nsfw_concept = [None] * len(outputs.images) - return output.images, has_nsfw_concept + return outputs.images, has_nsfw_concept def __str__(self) -> str: return f"ImageToImagePipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index f605cb2f..680800a5 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -7,6 +7,7 @@ import torch from app.pipelines.base import Pipeline from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device +from app.utils.errors import InferenceError from diffusers import StableVideoDiffusionPipeline from huggingface_hub import file_download from PIL import ImageFile @@ -134,7 +135,12 @@ def __call__( else: has_nsfw_concept = [None] - return self.ldm(image, **kwargs).frames, has_nsfw_concept + try: + outputs = self.ldm(image, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) + + return outputs.frames, has_nsfw_concept def __str__(self) -> str: return f"ImageToVideoPipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/optim/sfast.py b/runner/app/pipelines/optim/sfast.py index 166e014e..15598562 100644 --- a/runner/app/pipelines/optim/sfast.py +++ b/runner/app/pipelines/optim/sfast.py @@ -31,7 +31,7 @@ def compile_model(pipe): except ImportError: logger.info("xformers not installed, skip") try: - import triton # noqa: F401 + import triton # noqa: F401 config.enable_triton = True except ImportError: diff --git a/runner/app/pipelines/segment_anything_2.py b/runner/app/pipelines/segment_anything_2.py index 64c4080d..cd5c852c 100644 --- a/runner/app/pipelines/segment_anything_2.py +++ b/runner/app/pipelines/segment_anything_2.py @@ -3,8 +3,8 @@ import PIL from app.pipelines.base import Pipeline -from app.pipelines.utils import get_torch_device, get_model_dir -from app.routes.util import InferenceError +from app.pipelines.utils import get_model_dir, get_torch_device +from app.utils.errors import InferenceError from PIL import ImageFile from sam2.sam2_image_predictor import SAM2ImagePredictor diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 10e4f485..a760aade 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -14,6 +14,7 @@ is_turbo_model, split_prompt, ) +from app.utils.errors import InferenceError from diffusers import ( AutoPipelineForText2Image, EulerDiscreteScheduler, @@ -263,14 +264,17 @@ def __call__( ) kwargs.update(neg_prompts) - output = self.ldm(prompt=prompt, **kwargs) + try: + outputs = self.ldm(prompt=prompt, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) if safety_check: - _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images) else: - has_nsfw_concept = [None] * len(output.images) + has_nsfw_concept = [None] * len(outputs.images) - return output.images, has_nsfw_concept + return outputs.images, has_nsfw_concept def __str__(self) -> str: return f"TextToImagePipeline model_id={self.model_id}" diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index 7396e8b0..c92109d3 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.pipelines.utils.audio import AudioConversionError -from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error +from app.routes.utils import HTTPError, TextResponse, file_exceeds_max_size, http_error +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -37,6 +38,9 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) or isinstance(e, AudioConversionError): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE error_message = "Unsupported audio format or malformed file." + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) else: status_code = status.HTTP_500_INTERNAL_SERVER_ERROR error_message = "Internal server error during audio processing." diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 63fc3b0f..ab730181 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -154,15 +155,21 @@ async def image_to_image( num_images_per_prompt=1, num_inference_steps=num_inference_steps, ) - images.extend(imgs) - has_nsfw_concept.extend(nsfw_checks) except Exception as e: logger.error(f"ImageToImagePipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("ImageToImagePipeline error"), ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_checks) # TODO: Return None once Go codegen tool supports optional properties # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index a7c9350d..9fb7b7b0 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, VideoResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -140,6 +141,12 @@ async def image_to_video( except Exception as e: logger.error(f"ImageToVideoPipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("ImageToVideoPipeline error"), diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index 70436432..d3ef6bec 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -5,13 +5,8 @@ import numpy as np from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import ( - HTTPError, - InferenceError, - MasksResponse, - http_error, - json_str_to_np_array, -) +from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index c72dae53..00f57a54 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -142,19 +143,25 @@ async def text_to_image( has_nsfw_concept = [] params.num_images_per_prompt = 1 for seed in seeds: + params.seed = seed + kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"} try: - params.seed = seed - kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"} imgs, nsfw_check = pipeline(**kwargs) - images.extend(imgs) - has_nsfw_concept.extend(nsfw_check) except Exception as e: logger.error(f"TextToImagePipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("TextToImagePipeline error"), ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_check) # TODO: Return None once Go codegen tool supports optional properties # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 5aca1073..0e4788a7 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -5,7 +5,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/util.py b/runner/app/routes/utils.py similarity index 89% rename from runner/app/routes/util.py rename to runner/app/routes/utils.py index 8a319e84..6b223db7 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/utils.py @@ -70,22 +70,6 @@ class HTTPError(BaseModel): detail: APIError = Field(..., description="Detailed error information.") -class InferenceError(Exception): - """Exception raised for errors during model inference.""" - - def __init__(self, message="Error during model execution", original_exception=None): - """Initialize the exception. - - Args: - message: The error message. - original_exception: The original exception that caused the error. - """ - if original_exception: - message = f"{message}: {original_exception}" - super().__init__(message) - self.original_exception = original_exception - - def http_error(msg: str) -> HTTPError: """Create an HTTP error response with the specified message. diff --git a/runner/app/utils/__init__.py b/runner/app/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py new file mode 100644 index 00000000..1dfab466 --- /dev/null +++ b/runner/app/utils/errors.py @@ -0,0 +1,17 @@ +"""Custom exceptions used throughout the whole application.""" + + +class InferenceError(Exception): + """Exception raised for errors during model inference.""" + + def __init__(self, message="Error during model execution", original_exception=None): + """Initialize the exception. + + Args: + message: The error message. + original_exception: The original exception that caused the error. + """ + if original_exception: + message = f"{message}: {original_exception}" + super().__init__(message) + self.original_exception = original_exception diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index e0c86e56..3b4f0d25 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -1,6 +1,8 @@ import argparse import copy import json +import logging +import subprocess import yaml from app.main import app, use_route_names_as_operation_ids @@ -14,8 +16,6 @@ upscale, ) from fastapi.openapi.utils import get_openapi -import subprocess -import logging logging.basicConfig( level=logging.INFO, From 6ecc02434d1dfbd756385f68052f9c1f9a0b1011 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 5 Sep 2024 08:26:17 +0200 Subject: [PATCH 02/14] fixup! refactor(runner): add InferenceError to all pipelines --- runner/app/pipelines/upscale.py | 12 ++++++++---- runner/app/routes/upscale.py | 7 +++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py index e36e4606..6a80ab53 100644 --- a/runner/app/pipelines/upscale.py +++ b/runner/app/pipelines/upscale.py @@ -4,6 +4,7 @@ import PIL import torch +from app.utils.errors import InferenceError from app.pipelines.base import Pipeline from app.pipelines.utils import ( SafetyChecker, @@ -113,14 +114,17 @@ def __call__( if num_inference_steps is None or num_inference_steps < 1: del kwargs["num_inference_steps"] - output = self.ldm(prompt, image=image, **kwargs) + try: + outputs = self.ldm(prompt, image=image, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) if safety_check: - _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images) else: - has_nsfw_concept = [None] * len(output.images) + has_nsfw_concept = [None] * len(outputs.images) - return output.images, has_nsfw_concept + return outputs.images, has_nsfw_concept def __str__(self) -> str: return f"UpscalePipeline model_id={self.model_id}" diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 0e4788a7..46c473a9 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -3,6 +3,7 @@ import random from typing import Annotated +from app.utils.errors import InferenceError from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url @@ -107,6 +108,12 @@ async def upscale( except Exception as e: logger.error(f"UpscalePipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("UpscalePipeline error"), From d2475fa0df19e8cbc939e27922f4ba5c5feb4cd5 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 5 Sep 2024 09:06:31 +0200 Subject: [PATCH 03/14] refactor(runner): handle OOM error This commit ensures that users get a descriptive error message when the GPU runs out of memory. --- runner/app/pipelines/utils/audio.py | 2 +- runner/app/routes/audio_to_text.py | 17 +++++---- runner/app/routes/image_to_image.py | 47 +++++++++++++++++-------- runner/app/routes/image_to_video.py | 47 +++++++++++++++++-------- runner/app/routes/segment_anything_2.py | 45 +++++++++++++++-------- runner/app/routes/text_to_image.py | 45 +++++++++++++++-------- runner/app/routes/upscale.py | 45 +++++++++++++++-------- runner/app/utils/errors.py | 7 ++++ 8 files changed, 172 insertions(+), 83 deletions(-) diff --git a/runner/app/pipelines/utils/audio.py b/runner/app/pipelines/utils/audio.py index ccc15f04..a8e91bfb 100644 --- a/runner/app/pipelines/utils/audio.py +++ b/runner/app/pipelines/utils/audio.py @@ -11,7 +11,7 @@ class AudioConversionError(Exception): """Raised when an audio file cannot be converted.""" - def __init__(self, message="Audio conversion failed."): + def __init__(self, message="Audio conversion failed"): self.message = message super().__init__(self.message) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index c92109d3..99e49bee 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -6,7 +6,7 @@ from app.pipelines.base import Pipeline from app.pipelines.utils.audio import AudioConversionError from app.routes.utils import HTTPError, TextResponse, file_exceeds_max_size, http_error -from app.utils.errors import InferenceError +from app.utils.errors import InferenceError, OutOfMemoryError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -24,7 +24,7 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during audio processing. + """Handles exceptions raised during audio pipeline processing. Args: e: The exception raised during audio processing. @@ -32,18 +32,21 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error(f"Audio processing error: {str(e)}") # Log the detailed error + logger.error(f"AudioToText pipeline error: {str(e)}") # Log the detailed error if "Soundfile is either not in the correct format or is malformed" in str( e ) or isinstance(e, AudioConversionError): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE error_message = "Unsupported audio format or malformed file." + elif "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = "Out of memory error." elif isinstance(e, InferenceError): status_code = status.HTTP_400_BAD_REQUEST error_message = str(e) else: status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Internal server error during audio processing." + error_message = "Audio-to-text pipeline error." return JSONResponse( status_code=status_code, @@ -80,7 +83,7 @@ async def audio_to_text( return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, - content=http_error("Invalid bearer token"), + content=http_error("Invalid bearer token."), ) if model_id != "" and model_id != pipeline.model_id: @@ -88,14 +91,14 @@ async def audio_to_text( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"{model_id}." ), ) if file_exceeds_max_size(audio, 50 * 1024 * 1024): return JSONResponse( status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE, - content=http_error("File size exceeds limit"), + content=http_error("File size exceeds limit."), ) try: diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index ab730181..6d6dd06f 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -6,7 +6,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.utils.errors import InferenceError, OutOfMemoryError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -26,6 +26,34 @@ } +def handle_pipeline_error(e: Exception) -> JSONResponse: + """Handles exceptions raised during image-to-image pipeline processing. + + Args: + e: The exception raised during image-to-image processing. + + Returns: + A JSONResponse with the appropriate error message and status code. + """ + logger.error( + f"ImageToImagePipeline pipeline error: {str(e)}" + ) # Log the detailed error + if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = "Out of memory error. Try reducing input image resolution." + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + error_message = "Image-to-image pipeline error." + + return JSONResponse( + status_code=status_code, + content=http_error(error_message), + ) + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post( @@ -120,7 +148,7 @@ async def image_to_image( return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, - content=http_error("Invalid bearer token"), + content=http_error("Invalid bearer token."), ) if model_id != "" and model_id != pipeline.model_id: @@ -128,7 +156,7 @@ async def image_to_image( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"{model_id}." ), ) @@ -156,18 +184,7 @@ async def image_to_image( num_inference_steps=num_inference_steps, ) except Exception as e: - logger.error(f"ImageToImagePipeline error: {e}") - logger.exception(e) - if isinstance(e, InferenceError): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error(str(e)), - ) - - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=http_error("ImageToImagePipeline error"), - ) + return handle_pipeline_error(e) images.extend(imgs) has_nsfw_concept.extend(nsfw_checks) diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 9fb7b7b0..8355edf6 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -6,7 +6,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.utils.errors import InferenceError, OutOfMemoryError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -25,6 +25,34 @@ } +def handle_pipeline_error(e: Exception) -> JSONResponse: + """Handles exceptions raised during image-to-video pipeline processing. + + Args: + e: The exception raised during image-to-video processing. + + Returns: + A JSONResponse with the appropriate error message and status code. + """ + logger.error(f"ImageToVideo pipeline error: {str(e)}") # Log the detailed error + if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = ( + "Out of memory error. Try reducing input or output video resolution." + ) + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + error_message = "Image-to-video pipeline error." + + return JSONResponse( + status_code=status_code, + content=http_error(error_message), + ) + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post( @@ -102,7 +130,7 @@ async def image_to_video( return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, - content=http_error("Invalid bearer token"), + content=http_error("Invalid bearer token."), ) if model_id != "" and model_id != pipeline.model_id: @@ -110,7 +138,7 @@ async def image_to_video( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"{model_id}." ), ) @@ -139,18 +167,7 @@ async def image_to_video( seed=seed, ) except Exception as e: - logger.error(f"ImageToVideoPipeline error: {e}") - logger.exception(e) - if isinstance(e, InferenceError): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error(str(e)), - ) - - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=http_error("ImageToVideoPipeline error"), - ) + return handle_pipeline_error(e) output_frames = [] for frames in batch_frames: diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index d3ef6bec..7cc71a46 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -6,7 +6,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array -from app.utils.errors import InferenceError +from app.utils.errors import InferenceError, OutOfMemoryError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -25,6 +25,32 @@ } +def handle_pipeline_error(e: Exception) -> JSONResponse: + """Handles exceptions raised during segment-anything-2 pipeline processing. + + Args: + e: The exception raised during segment-anything-2 processing. + + Returns: + A JSONResponse with the appropriate error message and status code. + """ + logger.error(f"SegmentAnything2 pipeline error: {str(e)}") # Log the detailed error + if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = "Out of memory error. Try reducing input image resolution." + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + error_message = "Segment-anything-2 pipeline error." + + return JSONResponse( + status_code=status_code, + content=http_error(error_message), + ) + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373. @router.post( @@ -116,7 +142,7 @@ async def segment_anything_2( return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, - content=http_error("Invalid bearer token"), + content=http_error("Invalid bearer token."), ) if model_id != "" and model_id != pipeline.model_id: @@ -124,7 +150,7 @@ async def segment_anything_2( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"{model_id}." ), ) @@ -152,18 +178,7 @@ async def segment_anything_2( normalize_coords=normalize_coords, ) except Exception as e: - logger.error(f"Segment Anything 2 error: {e}") - logger.exception(e) - if isinstance(e, InferenceError): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error(str(e)), - ) - - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=http_error("Segment Anything 2 error"), - ) + return handle_pipeline_error(e) # Return masks sorted by descending score as string. sorted_ind = np.argsort(scores)[::-1] diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 00f57a54..acac61a3 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -6,7 +6,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.utils.errors import InferenceError, OutOfMemoryError from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -17,6 +17,32 @@ logger = logging.getLogger(__name__) +def handle_pipeline_error(e: Exception) -> JSONResponse: + """Handles exceptions raised during text-to-image pipeline processing. + + Args: + e: The exception raised during text-to-image processing. + + Returns: + A JSONResponse with the appropriate error message and status code. + """ + logger.error(f"TextToImage pipeline error: {str(e)}") # Log the detailed error + if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = "Out of memory error. Try reducing output image resolution." + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + error_message = "Text-to-image pipeline error." + + return JSONResponse( + status_code=status_code, + content=http_error(error_message), + ) + + class TextToImageParams(BaseModel): # TODO: Make model_id and other None properties optional once Go codegen tool # supports OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @@ -122,7 +148,7 @@ async def text_to_image( return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, - content=http_error("Invalid bearer token"), + content=http_error("Invalid bearer token."), ) if params.model_id != "" and params.model_id != pipeline.model_id: @@ -130,7 +156,7 @@ async def text_to_image( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( f"pipeline configured with {pipeline.model_id} but called with " - f"{params.model_id}" + f"{params.model_id}." ), ) @@ -148,18 +174,7 @@ async def text_to_image( try: imgs, nsfw_check = pipeline(**kwargs) except Exception as e: - logger.error(f"TextToImagePipeline error: {e}") - logger.exception(e) - if isinstance(e, InferenceError): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error(str(e)), - ) - - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=http_error("TextToImagePipeline error"), - ) + return handle_pipeline_error(e) images.extend(imgs) has_nsfw_concept.extend(nsfw_check) diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 46c473a9..ff31fdcd 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -3,7 +3,7 @@ import random from typing import Annotated -from app.utils.errors import InferenceError +from app.utils.errors import InferenceError, OutOfMemoryError from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url @@ -26,6 +26,32 @@ } +def handle_pipeline_error(e: Exception) -> JSONResponse: + """Handles exceptions raised during upscale pipeline processing. + + Args: + e: The exception raised during upscale processing. + + Returns: + A JSONResponse with the appropriate error message and status code. + """ + logger.error(f"TextToImage pipeline error: {str(e)}") # Log the detailed error + if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = "Out of memory error. Try reducing input image resolution." + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) + else: + status_code = status.HTTP_500_INTERNAL_SERVER_ERROR + error_message = "Upscale pipeline error." + + return JSONResponse( + status_code=status_code, + content=http_error(error_message), + ) + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post( @@ -81,7 +107,7 @@ async def upscale( return JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, headers={"WWW-Authenticate": "Bearer"}, - content=http_error("Invalid bearer token"), + content=http_error("Invalid bearer token."), ) if model_id != "" and model_id != pipeline.model_id: @@ -89,7 +115,7 @@ async def upscale( status_code=status.HTTP_400_BAD_REQUEST, content=http_error( f"pipeline configured with {pipeline.model_id} but called with " - f"{model_id}" + f"{model_id}." ), ) @@ -106,18 +132,7 @@ async def upscale( seed=seed, ) except Exception as e: - logger.error(f"UpscalePipeline error: {e}") - logger.exception(e) - if isinstance(e, InferenceError): - return JSONResponse( - status_code=status.HTTP_400_BAD_REQUEST, - content=http_error(str(e)), - ) - - return JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content=http_error("UpscalePipeline error"), - ) + return handle_pipeline_error(e) seeds = [seed] diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py index 1dfab466..04a70021 100644 --- a/runner/app/utils/errors.py +++ b/runner/app/utils/errors.py @@ -15,3 +15,10 @@ def __init__(self, message="Error during model execution", original_exception=No message = f"{message}: {original_exception}" super().__init__(message) self.original_exception = original_exception + +class OutOfMemoryError(Exception): + """Raised when the system runs out of memory.""" + + def __init__(self, message="GPU ran out of memory."): + self.message = message + super().__init__(self.message) From 8566106d0d2790f116323a2673756f4f766277ff Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 5 Sep 2024 09:08:25 +0200 Subject: [PATCH 04/14] chore: apply black formatter --- runner/app/utils/errors.py | 1 + 1 file changed, 1 insertion(+) diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py index 04a70021..2c97ef77 100644 --- a/runner/app/utils/errors.py +++ b/runner/app/utils/errors.py @@ -16,6 +16,7 @@ def __init__(self, message="Error during model execution", original_exception=No super().__init__(message) self.original_exception = original_exception + class OutOfMemoryError(Exception): """Raised when the system runs out of memory.""" From c5a5560bfab44f4b4bf99ec4e91f7a2d5c11dd2e Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 5 Sep 2024 09:26:20 +0200 Subject: [PATCH 05/14] refactor(runner): improve response errors This commit ensures that all response errors are known by FastAPI and therefore shown in the docs. --- runner/app/routes/audio_to_text.py | 16 +++++++++------- runner/app/routes/image_to_image.py | 14 +++++++------- runner/app/routes/image_to_video.py | 13 +++++++------ runner/app/routes/segment_anything_2.py | 13 +++++++------ runner/app/routes/upscale.py | 14 +++++++------- 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index 99e49bee..0b72d378 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -15,13 +15,6 @@ logger = logging.getLogger(__name__) -RESPONSES = { - status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, - status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, - status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError}, - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, -} - def handle_pipeline_error(e: Exception) -> JSONResponse: """Handles exceptions raised during audio pipeline processing. @@ -54,6 +47,15 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_413_REQUEST_ENTITY_TOO_LARGE: {"model": HTTPError}, + status.HTTP_415_UNSUPPORTED_MEDIA_TYPE: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + @router.post( "/audio-to-text", response_model=TextResponse, diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 6d6dd06f..ac3b75d9 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -19,13 +19,6 @@ logger = logging.getLogger(__name__) -RESPONSES = { - status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, - status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, -} - - def handle_pipeline_error(e: Exception) -> JSONResponse: """Handles exceptions raised during image-to-image pipeline processing. @@ -54,6 +47,13 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post( diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 8355edf6..b07b36c7 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -18,12 +18,6 @@ logger = logging.getLogger(__name__) -RESPONSES = { - status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, - status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, -} - def handle_pipeline_error(e: Exception) -> JSONResponse: """Handles exceptions raised during image-to-video pipeline processing. @@ -53,6 +47,13 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post( diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index 7cc71a46..70b89908 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -18,12 +18,6 @@ logger = logging.getLogger(__name__) -RESPONSES = { - status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, - status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, -} - def handle_pipeline_error(e: Exception) -> JSONResponse: """Handles exceptions raised during segment-anything-2 pipeline processing. @@ -51,6 +45,13 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373. @router.post( diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index ff31fdcd..597e57f9 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -19,13 +19,6 @@ logger = logging.getLogger(__name__) -RESPONSES = { - status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, - status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, - status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, -} - - def handle_pipeline_error(e: Exception) -> JSONResponse: """Handles exceptions raised during upscale pipeline processing. @@ -52,6 +45,13 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) +RESPONSES = { + status.HTTP_400_BAD_REQUEST: {"model": HTTPError}, + status.HTTP_401_UNAUTHORIZED: {"model": HTTPError}, + status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError}, +} + + # TODO: Make model_id and other None properties optional once Go codegen tool supports # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 @router.post( From b81a3ce61617f565ed10b776ca350ce4c4ca894f Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 5 Sep 2024 09:45:08 +0200 Subject: [PATCH 06/14] refactor(worker): add missing error handling This commit adds some missing error handling to the pipeline worker functions. --- runner/gateway.openapi.yaml | 8 +- runner/openapi.yaml | 8 +- worker/runner.gen.go | 111 +++++++++++++------------ worker/worker.go | 159 +++++++++++++++++++++++++----------- 4 files changed, 185 insertions(+), 101 deletions(-) diff --git a/runner/gateway.openapi.yaml b/runner/gateway.openapi.yaml index 5a3a1834..72220e47 100644 --- a/runner/gateway.openapi.yaml +++ b/runner/gateway.openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.1.3 + version: v0.2.0 servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -221,6 +221,12 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPError' + '415': + description: Unsupported Media Type + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' '500': description: Internal Server Error content: diff --git a/runner/openapi.yaml b/runner/openapi.yaml index dfd87a18..773781b6 100644 --- a/runner/openapi.yaml +++ b/runner/openapi.yaml @@ -2,7 +2,7 @@ openapi: 3.1.0 info: title: Livepeer AI Runner description: An application to run AI pipelines - version: v0.1.3 + version: v0.2.0 servers: - url: https://dream-gateway.livepeer.cloud description: Livepeer Cloud Community Gateway @@ -232,6 +232,12 @@ paths: application/json: schema: $ref: '#/components/schemas/HTTPError' + '415': + description: Unsupported Media Type + content: + application/json: + schema: + $ref: '#/components/schemas/HTTPError' '500': description: Internal Server Error content: diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 3ae7e692..b918f931 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -841,6 +841,7 @@ type AudioToTextResponse struct { JSON400 *HTTPError JSON401 *HTTPError JSON413 *HTTPError + JSON415 *HTTPError JSON422 *HTTPValidationError JSON500 *HTTPError } @@ -1126,6 +1127,13 @@ func ParseAudioToTextResponse(rsp *http.Response) (*AudioToTextResponse, error) } response.JSON413 = &dest + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 415: + var dest HTTPError + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON415 = &dest + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 422: var dest HTTPValidationError if err := json.Unmarshal(bodyBytes, &dest); err != nil { @@ -1779,57 +1787,58 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xbe2/bxrL/KgveCzQBZFt26ubCQP9w0jY2bpIasdO0yDWEFTkityF32X3YVnP93Q9m", - "dknxack5jovT6i+L5OzMb2Z2Hvvw5yhWRakkSGuio8+RiTMoOP08Pjv9UWul8XcCJtaitELJ6Ai/MMBP", - "TIMplTTACpVAvhtNolKrErQVQDwKk/aHX2QQhhdgDE8Bx1lhc4iOojcmxadliQ/GaiHT6PZ2Emn4wwkN", - "SXT0kbherobUQOtxav47xDa6nUQvVLKccZcINbNqZuHGdp5KZSxCbOMmmj7y92WueAIJo+9sIXJgVrE5", - "MKu5RMo5JKjNQumC2+gomgvJ9bKh3zFx7mk4iciCM5F4qQvuchwfTToQTlyaCpmyn3gcrM5Of2DOQMIW", - "Stc4iLxlV0+arDWuV71h3s1MOGZ7UfAUkNT/6DwOWz91IuEyhpmJOUJoGOT57mHXIj/KWDnNUzDBHlax", - "FCRoboGRGMPiXBnIlywX8hMkSGEzYIielVoVpWVPMpFmoNkVzx1y4kumIXFxYMH+cDwXdvm0adNXASc7", - "J5y1CaQr5qDRBDT2jnnkeVuFyMViya6FzQhaKUrIhYS7J9MpsR+YTN66d9hxv2/HHyDVQGCuMxF7GJUd", - "K6TCsNKZjEx4zXViiEpIYQXPPc1uFx9bb6aHmfseYoC88fSfRBJSbsUVzPxUWAPiYjVpnpinNNmcSIBd", - "Z9ziE9zEuUuALbQq+pDYaSqVRnsuWNs97P/cdPosZvtN2G8DNHbmoQ2hd4UPJjMrQQ/psN9V4S0ZnqlF", - "FR7NiClBB/VaQFzBTj3xGegeHCEtpN6XhEcuQAOpZqE0bTTT6TieBKQSBn1MA3fZG6XB/2bOOJ5jDAOn", - "CA4BGwKzUmXuLDO5ugbNahTIJnE5zeP5khmrQaY26+lX0bNzQj2kXdO8m8yKu+bkuE8NX4BdzuIM4k8t", - "41ntoGu9M9CYIRhnfhijYTQVjRUFZcFFN5INi5XLEyxdarEAaXCSKc0yrouFy5swzz3XlwSmBjtXKgcu", - "CS1A0rfIOYSw1FwmqmA+2kdMgcSD9q581bLCdPd/RpKXWvgK6FOmUJLxsszFKuVrqHzsPfNkil/2W2n9", - "vJLZy1SdUllWDvRpvlszNyh9a6vmlUhAdR+Hq+aiE2jfTQb6roXmBRgKcgOxkgmZrJXpSUbTHj+NxEIG", - "Is3aqebw+aBUT8mEZKW4gdxsIPTEMx+Su3FRrXMa9/wpJ39hRX2YEuVh3L9EFQqpZ3MXfwLbRbF/8LwL", - "430lEF0s8CWCQpPzQjlp0QGep+8as3aRIp/59IqfQujizwLzcRh5LfIcE4iQ9Knnwjee7AWBbinWLBdK", - "GJhxl85GQn160FXuuFaBBjOeJKsAbynsGxJ20mrtQlunwUAxz6kxGR3LuEyYkLEGbiq9W2WDABy7lI0n", - "jfUl8eDwP7gibmtVZYlrkXRm7/704NuhfEiU90qHH4h3X2qnIq0pROPVZKwQGUgLkHbG5dJmQqazg6FX", - "wwVprm4G9g5YTlONfcu41nzJUnEFknHDOJurm2o5FiKSMugELfXrb7/+xnzebtrlhboZXf/0hZ9WlSHo", - "8KW1gJtPMyFLZwf1U9c7GozKHaU/JGZE3FHKLksRU/zSUoGzUsOVUM7gj0TENFrYMAMnq7UhRdD+zcnN", - "B/bk5PsP3x8cfkeT9/z4TauPeYOSTwnmVytoX7rmKlyO8W4+zZSztSHvyByn2Nk5mKws6OuPBus0FiBs", - "/5ChIVy8mIvUoTG96f20MhOmFhYkPiYuRr3mYC3oMNJmXGJuEjLNoeGGllYVcvazRz6UZiROqlz8CbNY", - "KZ2Y+6lXKiEto5FCcgumLrU131VDy2UK7ON0sn8ZpgiNDnIZ3JQQW08+B0+gweBLfOXdl4gCs6qSpl3b", - "giz20uswpGhTWD8Y3t4chChXi6BVcEQnFq4z0MCAxwE+E+g49uTXyW9PV3mytXgisi6y1QTzwHI+h3wA", - "2Gt6X/c+LWgVmn0mZCJisj9HUki1cjIJ1NgZTFskcx5/apL04XqxQ3D9NJ7lKhX2HrPFDzPMyR2MAJOp", - "HHshmp6eFxPSWOwP1AIhUo6j701073wQvfbS+37eqMrco1SMVRtX0lZI/Xe4rvxVm2oPkzGDbsmXb1et", - "6SafH/6D9lc2suZ2o2Vd83q/jY3BMB2I6ZOLi7ORIyT8tOEZUgKWi5xOZfL850V09PFz9N8aFtFR9F97", - "q9OrvXB0tVcfB91e9neJkBUkQbKQ9T7Rbs8GQWxD95U6I7r+wnORELta6zFVhIWCXt2lSZff7QqL12QF", - "hKos6dBE22UwhBt4brOXVQC08RrLrWtnl+jn/21tlRHB0JnSaiNnJWBAPmXbd2EK9OfJu9bkGG05BwqE", - "GT547IYnjt7IGW8gEbzpAr8dPuSCXq00zWnU1njAJNi3m3uZxI+tVjUjVmn2Fl2raH49YU422stV82vY", - "Ez/0ad0vUbfcTCvdzqG9Vlrrih4/MsFgwo6VHnMt2eMbTLhyIRIqNJ6ccFNz2RbZSoye8dqj5wDMVOTB", - "qpcd7Hf6l2bSwMqxwA+VM2MlLRd+K0o29rPnCleSbfPhuL7DpVlc98V8yMBWG3te4DU3bJHzNIUE1+Fv", - "z3/60CrdyGbzcoSewC++42nuwtYSN9pNcTofZv7+3evQoa9UiLnECsvjGIzxR/CVgPc6X+tVRzTGQyGz", - "Nf1J7hrwI/Yk9wpTOm++K3HFmZPro4XYeNKNsxeRN7PXSy+qm70mEXLfBEHTxmiJtUa2nijoeNkefVe8", - "4PcLRWnzjGvulf273hh4yEOV3nn8HYcq2yP4f84R/OE/+gSenUPJyc60CVnSTp/flKKdgm/+/xucGsaV", - "pdIBcL1VtV1W/mVnIr1stuGZSJgwnYLTLigDVWftYi5XcWslx+UyrE678+FzD+LlbbN3jknMQC0O1ydX", - "nQhdlxyacf7FipQwswt8u64uox5eVKBsWGqDBeQvIgF1rzZo6BC8c5WBbims60KqM32kbTVC91zPdRug", - "6tqDB7FmfRegNm3WMsiAxXwvNtD/0wea+JjLKBlxZkUBxvKi7JtpvFUjBiGCiOv6bg2/B0kjPKvPPcaV", - "vRvGu6h5rbGfbRIisIYlvaF6FqSUFTst7PIcnemNcXJxcfYCuAZd32OmPOdf1Uwya8voFnngqmrAC+HK", - "kI9JzMLaSXZ8Wm8em+ayV1xBCaDx+zsnJQm6Am08r6vp7v7uMzStKkHyUkRH0bPd/d0pepLbjHDv0WXa", - "Hat2KndWu94dF9Q3jBu3j/05SujGcWoQ6tOkumh8oYKz0eRg7AuVLGmhoaQFSVJ8EeTa7mEV2km45auL", - "4OuCaLO7wbdtp2MRpBc+RMgKB9NpB1fDC3u/GzTBpqBa6wmS3alsjpaJC5ezFdkk+vYBIay2Cgfkv+AJ", - "e+f94eXuP47c95I7mykt/oSEBO8/exzBQVn2o7TYJ14oxV5znXqrHxw8KIjenmkfzoqE1fuqh4/l/FNp", - "QUues3PQV6ArBI2cRi1EM5t9vLy9nETGFQXXyyqy2YViFNs4dC+jTVZaCQOhb+cCvwcbfcWYa+7ybhpy", - "t02lAkTShjo9zIj1Sd9wSjwuy3xZHfe1rn1SXuTY12OT0Ogd23ahFjB0gl85SW5wGfSR02R7H3qbJ8fz", - "5DZF3TdF+YtWF8pvWHSimvr28ah+NXRrd/Ngpvb7sYJ5/ELdIwdze9GxDeZtMH+FYPahRcEcjhp3qps9", - "OwfjAX3uacPBFl3s4nIsigPxceB78JUj+R6Xlh45otvHiNuI3kb0w0V0FZFVlLEDH9W4ZN+g737VOSuj", - "At04GjP9sG5sud4Z0f/eqr+9qbvtp7cB+zcJWDrcarfT4fLdeJS+9wR1rWXzZfWvTHSpxBq2+jeFfsSG", - "4V+5/g5eJdwG7jZw/yaBW0XRrR+FbAwN6vw/QnWQ8DJXLmEvVVE4KeySveIWrvkyCveh6PjCHO3tJRp4", - "sZP6r7t5GL4b43A6cRzhf25p83CMbc3IEN0eL8XeHCzfq07dotvL238FAAD//x9j1/Y0RgAA", + "H4sIAAAAAAAC/+xb+2/bRvL/Vxb8foEmgGzLat0cDPQHJ21j45LUiJ2mRc4QVuSI2obcZfdhW83pfz/M", + "7JLi05Jzjotr9ZNFcnbmMzM7j334UxSrvFASpDXR8afIxAvIOf08OT/7QWul8XcCJtaisELJ6Bi/MMBP", + "TIMplDTAcpVAth+NokKrArQVQDxyk3aHXy4gDM/BGJ4CjrPCZhAdR69Nik/LAh+M1UKm0Wo1ijT87oSG", + "JDr+QFyv1kMqoNU4NfsNYhutRtFzlSyn3CVCTa2aWri1radCGYsQm7iJpov8XZEpnkDC6DubiwyYVWwG", + "zGoukXIGCWozVzrnNjqOZkJyvazpd0KcOxqOIrLgVCRe6py7DMdHoxaEU5emQqbsRx4Hq7Oz75kzkLC5", + "0hUOIm/Y1ZMmG43rVa+ZdzsTDtle5DwFJPU/Wo/91k+dSLiMYWpijhBqBnm2f9S2yA8yVk7zFEywh1Us", + "BQmaW2AkxrA4UwayJcuE/AgJUtgFMETPCq3ywrInC5EuQLNrnjnkxJdMQ+LiwIL97ngm7PJp3aYvA052", + "QTgrE0iXz0CjCWjsHfPI87YKkYv5kt0IuyBohSggExLunkxnxL5nMnnr3mHHw64dv4dUA4G5WYjYwyjt", + "WCIVhhXOLMiEN1wnhqiEFFbwzNPst/GxzWZ6mLnvIQbIW0//USQh5VZcw9RPhQ0gLteT5ol5SpPNiQTY", + "zYJbfILbOHMJsLlWeRcSO0ul0mjPOWu6h/3Ljcdfx+ywDvtNgMbOPbQ+9C73wWSmBeg+HQ7bKrwhwzM1", + "L8OjHjEF6KBeA4jL2ZknPgfdgSOkhdT7kvDIOWgg1SwUpolmPB7Gk4BUwqCPaeA+e600+N/MGcczjGHg", + "FMEhYENglqrMnGUmUzegWYUC2SQuo3k8WzJjNcjULjr6lfTsglD3aVc37zaz4q45OexTw+dgl9N4AfHH", + "hvGsdtC23jlozBCMMz+M0TCaisaKnLLgvB3JhsXKZQmWLjWfgzQ4yZRmC67zucvqMC881xcEpgI7UyoD", + "LgktQNK1yAWEsNRcJipnPtoHTIHEvfYufdWwwnj/HwPJS819BfQpUyjJeFFkYp3yNZQ+9p55MsYvh420", + "flHK7GSqVqksSgf6NN+umVuUvo1V81okoNqP/VVz3gq0b0c9fddc8xwMBbmBWMmETNbI9CSjbo8fB2Jh", + "ASJdNFPN0bNeqZ6SCckKcQuZ2ULoqWfeJ3frolrlNO75U07+zIr6MCXKw7h/icoVUk9nLv4Ito3icPKs", + "DeNdKRBdLPAlgkKT81w5adEBnqfvGhfNIkU+8+kVP4XQxZ855uMw8kZkGSYQIelTx4WvPdlzAt1QrF4u", + "lDAw5S6dDoT6eNJW7qRSgQYzniTrAG8o7BsSdtpo7UJbp8FAPsuoMRkcy7hMmJCxBm5KvRtlgwCcuJQN", + "J43NJXFy9D9cEXe1qrTEjUhas/dwPPmmLx8S5b3S4Xvi3ZXaqkgbCtFwNRkqRAbSHKSdcrm0CyHT6aTv", + "VX9Bmqnbnr0DltFUY98wrjVfslRcg2TcMM5m6rZcjoWIpAw6Qkv98usvvzKft+t2ea5uB9c/XeFnZWUI", + "OnxuLeDm41TIwtle/dTNngajMkfpD4kZEbeUsstCxBS/tFTgrNBwLZQz+CMRMY0WNszA0XptSBF0eHt6", + "+549Of3u/XeTo29p8l6cvG70Ma9R8hnB/GIF7XPXXLnLMN7Nx6lytjLkHZnjDDs7B6O1BX390WCdxgKE", + "7R8yNISL5zOROjSmN72fVmbE1NyCxMfExajXDKwFHUbaBZeYm4RMM6i5oaFViZz95JH3pRmJkyoTf8A0", + "Vkon5n7qFUpIy2ikkNyCqUptxXfd0HKZAvswHh1ehSlCo4NcBrcFxNaTz8ATaDD4El959yUix6yqpGnW", + "tiCLvfA69ClaF9YNhje3kxDlah60Co5oxcLNAjQw4HGAzwQ6jj35ZfTr03WebCyeiKyNbD3BPLCMzyDr", + "AfaK3le9TwNaieaQCZmImOzPkRRSrZxMAjV2BuMGyYzHH+skXbhebB9cP42nmUqFvcds8cMMc3IPI8As", + "VIa9EE1Pz4sJaSz2B2qOECnH0fc6urc+iF556V0/b1Vl7lEqhqqNK2grpPrbX1f+rE21h8mYQbfk87er", + "NnSTz47+RvsrW1lzt9GyqXm938ZGb5j2xPTp5eX5wBESftryDCkBy0VGpzJZ9tM8Ov7wKfp/DfPoOPq/", + "g/Xp1UE4ujqojoNWV91dImQFSZAsZLVPtN+xQRBb032tzoCuP/NMJMSu0npIFWEhp1d3adLmt1pj8Zqs", + "gVCVJR3qaNsM+nADz+ziRRkATbzGcuua2SX66Z+NrTIi6DtTWm/krAX0yKds+zZMge48eduYHIMtZ0+B", + "MP0Hj+3wxNFbOeM1JILXXeC3w/tc0KmVpj6Nmhr3mAT7dnMvk/ix5apmwCr13qJtFc1vRszJWnu5bn4N", + "e+KHPq36JeqW62ml3Tk010obXdHhRyboTdix0kOuJXt8hQlXzkVChcaTE25qLpsiG4nRM9549ByAmZI8", + "WPWqhf1O/9JM6lk55vihdGaspOXCb0XJ2n72TOFKsmk+HNd1uDTzm66Y9wuw5caeF3jDDZtnPE0hwXX4", + "m4sf3zdKN7LZvhyhJ/CL73jqu7CVxK12U5zO+pm/e/sqdOhrFWIuscLyOAZj/BF8KeCdzjZ61RGN8VDI", + "bHV/krt6/Ig9yb3ClM6b70pc8cLJzdFCbDzp1tmLyOvZ64UX1c5eowi5b4OgbmO0xEYjW08UdLxqjr4r", + "XvD7paK0ec4198r+VW8MPOShSuc8/o5Dld0R/N/nCP7ob30Czy6g4GRn2oQsaKfPb0rRTsFX//4Kp4Zx", + "RaF0AFxtVe2WlX/amUgnm215JhImTKvgNAtKT9XZuJjLVNxYyXG5DKvT9nz41IF4tar3zjGJ6anF4frk", + "uhOh65J9M86/WJMSZnaJbzfVZdTDiwqUNUttsYD8WSSg7tUG9R2Ct64y0C2FTV1IeaaPtI1G6J7ruXYD", + "VF578CA2rO8C1LrNGgbpsZjvxXr6f/pAEx9zGSUjzqzIwVieF10zDbdqxCBEEHHd3K3h9yBpgGf5ucO4", + "tHfNeJcVrw32s3VCBFazpDdUx4KUsmKnhV1eoDO9MU4vL8+fA9egq3vMlOf8q4rJwtoiWiEPXFX1eCFc", + "GfIxiVlYO8lOzqrNY1Nf9oprKAA0fn/rpCRB16CN53U93p/sj9G0qgDJCxEdR1/vH+6P0ZPcLgj3AV2m", + "3bNqr3RnuevdckF1w7h2+9ifo4RuHKcGoT5LyovGlyo4G00Oxj5XyZIWGkpakCTFF0Gu7QFWob2EW76+", + "CL4piLa7G7xqOh2LIL3wIUJWmIzHLVw1Lxz8ZtAE24JqrCdIdquyOVomzl3G1mSj6JsHhLDeKuyR/5wn", + "7K33h5d7+Dhy30nu7EJp8QckJPjw68cRHJRlP0iLfeKlUuwV16m3+uHRY2m/7ugo9/vqiBAmkweF0Nm2", + "7YJZk7Bqa/fosebfmbSgJc/YBehr0CWCWlqlLqaeUD9cra5GkXF5zvWyTC7sUjFKLzj0YEH7vLQYB0Lf", + "TEd+Gzj6gmFf32jeNupXdaUCRNKGmk1MytVhY39WPimKbFmeODZunlJq5ri0wD6l1r427UJdaGhGv3Ce", + "3uI+6iNn6uZW+C5VD6fqXYq6b4ryd70uld8zaUU1LR2Go/pl38Xh7YOZVgCPFczDd/oeOZib655dMO+C", + "+QsEsw8tCuZw2rlXXi7amwwH9IWnDWdrdLeMy6EoDsQnge/kC0fyPe5NPXJEN08ydxG9i+iHi+gyIsso", + "YxMf1RZu7RZ998vWcR0V6NrpnOmGdW3X986I/u82Hpr7yrt+ehewf5GApfO1Zjsd7v8NR+k7T1DVWjZb", + "lv9NRfdarGHr/5ToRmwY/oXrb+9txl3g7gL3LxK4ZRSt/ChkY2hQ618iyrOMF5lyCXuh8txJYZfsJbdw", + "w5dRuJJFJyjm+OAg0cDzvdR/3c/C8P0Yh9Oh5wD/C0ubh0NsK0aG6A54IQ5mYPlBefAXra5W/wkAAP//", + "QoG6BLdGAAA=", } // GetSwagger returns the content of the embedded swagger specification file diff --git a/worker/worker.go b/worker/worker.go index 06b0b946..b30cc852 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -70,22 +70,31 @@ func (w *Worker) TextToImage(ctx context.Context, req TextToImageJSONRequestBody return nil, err } - if resp.JSON422 != nil { - val, err := json.Marshal(resp.JSON422) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) if err != nil { return nil, err } - slog.Error("text-to-image container returned 422", slog.String("err", string(val))) - return nil, errors.New("text-to-image container returned 422") + slog.Error("text-to-image container returned 400", slog.String("err", string(val))) + return nil, errors.New("text-to-image container returned 400") } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) if err != nil { return nil, err } - slog.Error("text-to-image container returned 400", slog.String("err", string(val))) - return nil, errors.New("text-to-image container returned 400") + slog.Error("text-to-image container returned 401", slog.String("err", string(val))) + return nil, errors.New("text-to-image container returned 401") + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("text-to-image container returned 422", slog.String("err", string(val))) + return nil, errors.New("text-to-image container returned 422") } if resp.JSON500 != nil { @@ -118,22 +127,31 @@ func (w *Worker) ImageToImage(ctx context.Context, req ImageToImageMultipartRequ return nil, err } - if resp.JSON422 != nil { - val, err := json.Marshal(resp.JSON422) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) if err != nil { return nil, err } - slog.Error("image-to-image container returned 422", slog.String("err", string(val))) - return nil, errors.New("image-to-image container returned 422") + slog.Error("image-to-image container returned 400", slog.String("err", string(val))) + return nil, errors.New("image-to-image container returned 400") } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) if err != nil { return nil, err } - slog.Error("image-to-image container returned 400", slog.String("err", string(val))) - return nil, errors.New("image-to-image container returned 400") + slog.Error("image-to-image container returned 401", slog.String("err", string(val))) + return nil, errors.New("image-to-image container returned 401") + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("image-to-image container returned 422", slog.String("err", string(val))) + return nil, errors.New("image-to-image container returned 422") } if resp.JSON500 != nil { @@ -166,22 +184,31 @@ func (w *Worker) ImageToVideo(ctx context.Context, req ImageToVideoMultipartRequ return nil, err } - if resp.JSON422 != nil { - val, err := json.Marshal(resp.JSON422) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) if err != nil { return nil, err } - slog.Error("image-to-video container returned 422", slog.String("err", string(val))) - return nil, errors.New("image-to-video container returned 422") + slog.Error("image-to-video container returned 400", slog.String("err", string(val))) + return nil, errors.New("image-to-video container returned 400") } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) if err != nil { return nil, err } - slog.Error("image-to-video container returned 400", slog.String("err", string(val))) - return nil, errors.New("image-to-video container returned 400") + slog.Error("image-to-video container returned 401", slog.String("err", string(val))) + return nil, errors.New("image-to-video container returned 401") + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("image-to-video container returned 422", slog.String("err", string(val))) + return nil, errors.New("image-to-video container returned 422") } if resp.JSON500 != nil { @@ -219,22 +246,31 @@ func (w *Worker) Upscale(ctx context.Context, req UpscaleMultipartRequestBody) ( return nil, err } - if resp.JSON422 != nil { - val, err := json.Marshal(resp.JSON422) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) if err != nil { return nil, err } - slog.Error("upscale container returned 422", slog.String("err", string(val))) - return nil, errors.New("upscale container returned 422") + slog.Error("upscale container returned 400", slog.String("err", string(val))) + return nil, errors.New("upscale container returned 400") } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) if err != nil { return nil, err } - slog.Error("upscale container returned 400", slog.String("err", string(val))) - return nil, errors.New("upscale container returned 400") + slog.Error("upscale container returned 401", slog.String("err", string(val))) + return nil, errors.New("upscale container returned 401") + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("upscale container returned 422", slog.String("err", string(val))) + return nil, errors.New("upscale container returned 422") } if resp.JSON500 != nil { @@ -267,22 +303,22 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return nil, err } - if resp.JSON422 != nil { - val, err := json.Marshal(resp.JSON422) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) if err != nil { return nil, err } - slog.Error("audio-to-text container returned 422", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 422") + slog.Error("audio-to-text container returned 400", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 400") } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) if err != nil { return nil, err } - slog.Error("audio-to-text container returned 400", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 400") + slog.Error("audio-to-text container returned 401", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 401") } if resp.JSON413 != nil { @@ -291,6 +327,24 @@ func (w *Worker) AudioToText(ctx context.Context, req AudioToTextMultipartReques return nil, errors.New(msg) } + if resp.JSON415 != nil { + val, err := json.Marshal(resp.JSON415) + if err != nil { + return nil, err + } + slog.Error("audio-to-text container returned 415", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 415") + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("audio-to-text container returned 422", slog.String("err", string(val))) + return nil, errors.New("audio-to-text container returned 422") + } + if resp.JSON500 != nil { val, err := json.Marshal(resp.JSON500) if err != nil { @@ -321,22 +375,31 @@ func (w *Worker) SegmentAnything2(ctx context.Context, req SegmentAnything2Multi return nil, err } - if resp.JSON422 != nil { - val, err := json.Marshal(resp.JSON422) + if resp.JSON400 != nil { + val, err := json.Marshal(resp.JSON400) if err != nil { return nil, err } - slog.Error("segment anything 2 container returned 422", slog.String("err", string(val))) - return nil, errors.New("segment anything 2 container returned 422") + slog.Error("segment anything 2 container returned 400", slog.String("err", string(val))) + return nil, errors.New("segment anything 2 container returned 400") } - if resp.JSON400 != nil { - val, err := json.Marshal(resp.JSON400) + if resp.JSON401 != nil { + val, err := json.Marshal(resp.JSON401) if err != nil { return nil, err } - slog.Error("segment anything 2 container returned 400", slog.String("err", string(val))) - return nil, errors.New("segment anything 2 container returned 400") + slog.Error("segment anything 2 container returned 401", slog.String("err", string(val))) + return nil, errors.New("segment anything 2 container returned 401") + } + + if resp.JSON422 != nil { + val, err := json.Marshal(resp.JSON422) + if err != nil { + return nil, err + } + slog.Error("segment anything 2 container returned 422", slog.String("err", string(val))) + return nil, errors.New("segment anything 2 container returned 422") } if resp.JSON500 != nil { From 53d76ebf885cf5f675ed675d61c78ab6aa614d97 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 10:05:42 +0200 Subject: [PATCH 07/14] refactor: improve out of memory error handling This commit improves the out of memory error handling by using the native torch error. --- runner/app/pipelines/segment_anything_2.py | 2 +- runner/app/pipelines/upscale.py | 2 ++ runner/app/routes/audio_to_text.py | 6 ++++-- runner/app/routes/image_to_image.py | 7 ++----- runner/app/routes/image_to_video.py | 8 +++++--- runner/app/routes/segment_anything_2.py | 6 ++++-- runner/app/routes/text_to_image.py | 6 +++--- runner/app/routes/upscale.py | 8 +++++--- 8 files changed, 26 insertions(+), 19 deletions(-) diff --git a/runner/app/pipelines/segment_anything_2.py b/runner/app/pipelines/segment_anything_2.py index cd5c852c..8278fc28 100644 --- a/runner/app/pipelines/segment_anything_2.py +++ b/runner/app/pipelines/segment_anything_2.py @@ -3,7 +3,7 @@ import PIL from app.pipelines.base import Pipeline -from app.pipelines.utils import get_model_dir, get_torch_device +from app.pipelines.utils import get_torch_device, get_model_dir from app.utils.errors import InferenceError from PIL import ImageFile from sam2.sam2_image_predictor import SAM2ImagePredictor diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py index e78d92eb..c82e5508 100644 --- a/runner/app/pipelines/upscale.py +++ b/runner/app/pipelines/upscale.py @@ -117,6 +117,8 @@ def __call__( try: outputs = self.ldm(prompt, image=image, **kwargs) + except torch.cuda.OutOfMemoryError as e: + raise e except Exception as e: raise InferenceError(original_exception=e) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index 021cda93..1e5347a1 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -2,6 +2,7 @@ import os from typing import Annotated +import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.pipelines.utils.audio import AudioConversionError @@ -42,15 +43,15 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error(f"AudioToText pipeline error: {str(e)}") # Log the detailed error if "Soundfile is either not in the correct format or is malformed" in str( e ) or isinstance(e, AudioConversionError): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE error_message = "Unsupported audio format or malformed file." - elif "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + elif isinstance(e, torch.cuda.OutOfMemoryError): status_code = status.HTTP_400_BAD_REQUEST error_message = "Out of memory error." + torch.cuda.empty_cache() elif isinstance(e, InferenceError): status_code = status.HTTP_400_BAD_REQUEST error_message = str(e) @@ -118,4 +119,5 @@ async def audio_to_text( try: return pipeline(audio=audio) except Exception as e: + logger.error(f"AudioToText pipeline error: {str(e)}") return handle_pipeline_error(e) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 49b62446..43f91f9c 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -30,11 +30,7 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error( - f"ImageToImagePipeline pipeline error: {str(e)}" - ) # Log the detailed error - logger.exception(e) # TODO: Check if needed. - if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError) or isinstance(torch.cuda.OutOfMemoryError): # TODO: simplify condition. + if isinstance(e, torch.cuda.OutOfMemoryError): status_code = status.HTTP_400_BAD_REQUEST error_message = "Out of memory error. Try reducing input image resolution." torch.cuda.empty_cache() @@ -215,6 +211,7 @@ async def image_to_image( num_inference_steps=num_inference_steps, ) except Exception as e: + logger.error(f"ImageToImagePipeline pipeline error: {str(e)}") return handle_pipeline_error(e) images.extend(imgs) has_nsfw_concept.extend(nsfw_checks) diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 22b57fcd..a9185616 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -4,9 +4,10 @@ from typing import Annotated from app.dependencies import get_pipeline +import torch from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError, OutOfMemoryError +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -28,12 +29,12 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error(f"ImageToVideo pipeline error: {str(e)}") # Log the detailed error - if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + if isinstance(e, torch.cuda.OutOfMemoryError): status_code = status.HTTP_400_BAD_REQUEST error_message = ( "Out of memory error. Try reducing input or output video resolution." ) + torch.cuda.empty_cache() elif isinstance(e, InferenceError): status_code = status.HTTP_400_BAD_REQUEST error_message = str(e) @@ -181,6 +182,7 @@ async def image_to_video( seed=seed, ) except Exception as e: + logger.error(f"ImageToVideo pipeline error: {str(e)}") return handle_pipeline_error(e) output_frames = [] diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index ec386fbf..d2f5e9e5 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -3,6 +3,7 @@ from typing import Annotated import numpy as np +import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array @@ -28,10 +29,10 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error(f"SegmentAnything2 pipeline error: {str(e)}") # Log the detailed error - if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + if isinstance(e, torch.cuda.OutOfMemoryError): status_code = status.HTTP_400_BAD_REQUEST error_message = "Out of memory error. Try reducing input image resolution." + torch.cuda.empty_cache() elif isinstance(e, InferenceError): status_code = status.HTTP_400_BAD_REQUEST error_message = str(e) @@ -192,6 +193,7 @@ async def segment_anything_2( normalize_coords=normalize_coords, ) except Exception as e: + logger.error(f"SegmentAnything2 pipeline error: {str(e)}") return handle_pipeline_error(e) # Return masks sorted by descending score as string. diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index cf852492..23a09b77 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -7,7 +7,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError, OutOfMemoryError +from app.utils.errors import InferenceError from app.pipelines.utils.utils import LoraLoadingError from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse @@ -28,8 +28,7 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error(f"TextToImage pipeline error: {str(e)}") # Log the detailed error - if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError) or isinstance(e, torch.cuda.OutOfMemoryError): # TODO: Simplify. + if isinstance(e, torch.cuda.OutOfMemoryError): status_code = status.HTTP_400_BAD_REQUEST error_message = "Out of memory error. Try reducing output image resolution." torch.cuda.empty_cache() @@ -204,6 +203,7 @@ async def text_to_image( try: imgs, nsfw_check = pipeline(**kwargs) except Exception as e: + logger.error(f"TextToImage pipeline error: {str(e)}") return handle_pipeline_error(e) images.extend(imgs) has_nsfw_concept.extend(nsfw_check) diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 30d84343..412bcce9 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -3,7 +3,8 @@ import random from typing import Annotated -from app.utils.errors import InferenceError, OutOfMemoryError +import torch +from app.utils.errors import InferenceError from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url @@ -28,10 +29,10 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: Returns: A JSONResponse with the appropriate error message and status code. """ - logger.error(f"TextToImage pipeline error: {str(e)}") # Log the detailed error - if "CUDA out of memory" in str(e) or isinstance(e, OutOfMemoryError): + if isinstance(e, torch.cuda.OutOfMemoryError): status_code = status.HTTP_400_BAD_REQUEST error_message = "Out of memory error. Try reducing input image resolution." + torch.cuda.empty_cache() elif isinstance(e, InferenceError): status_code = status.HTTP_400_BAD_REQUEST error_message = str(e) @@ -145,6 +146,7 @@ async def upscale( seed=seed, ) except Exception as e: + logger.error(f"TextToImage pipeline error: {str(e)}") return handle_pipeline_error(e) seeds = [seed] From 25a7700dac620a28cadf3df80f3a1808475435b1 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 11:18:19 +0200 Subject: [PATCH 08/14] feat: forward runner errors upstream This commit ensures that errors thrown by the runner are forwarded to the orchestrator. It applies the logic used by the SAM2 and audio-to-text pipelines to the other pipelines. --- runner/app/routes/audio_to_text.py | 2 +- runner/app/routes/image_to_image.py | 2 +- runner/app/routes/segment_anything_2.py | 2 +- runner/app/utils/errors.py | 8 ---- worker/worker.go | 54 ++++++++++++------------- 5 files changed, 30 insertions(+), 38 deletions(-) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index 1e5347a1..dc1133d9 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -7,7 +7,7 @@ from app.pipelines.base import Pipeline from app.pipelines.utils.audio import AudioConversionError from app.routes.utils import HTTPError, TextResponse, file_exceeds_max_size, http_error -from app.utils.errors import InferenceError, OutOfMemoryError +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 43f91f9c..ba8fbc9c 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -7,7 +7,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError, OutOfMemoryError +from app.utils.errors import InferenceError from app.pipelines.utils.utils import LoraLoadingError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index d2f5e9e5..9640de6b 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -7,7 +7,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array -from app.utils.errors import InferenceError, OutOfMemoryError +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py index 2c97ef77..1dfab466 100644 --- a/runner/app/utils/errors.py +++ b/runner/app/utils/errors.py @@ -15,11 +15,3 @@ def __init__(self, message="Error during model execution", original_exception=No message = f"{message}: {original_exception}" super().__init__(message) self.original_exception = original_exception - - -class OutOfMemoryError(Exception): - """Raised when the system runs out of memory.""" - - def __init__(self, message="GPU ran out of memory."): - self.message = message - super().__init__(self.message) diff --git a/worker/worker.go b/worker/worker.go index 33199be4..8993a641 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -89,7 +89,7 @@ func (w *Worker) TextToImage(ctx context.Context, req GenTextToImageJSONRequestB return nil, err } slog.Error("text-to-image container returned 401", slog.String("err", string(val))) - return nil, errors.New("text-to-image container returned 401") + return nil, errors.New("text-to-image container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON422 != nil { @@ -98,7 +98,7 @@ func (w *Worker) TextToImage(ctx context.Context, req GenTextToImageJSONRequestB return nil, err } slog.Error("text-to-image container returned 422", slog.String("err", string(val))) - return nil, errors.New("text-to-image container returned 422") + return nil, errors.New("text-to-image container returned 422: " + string(val)) } if resp.JSON500 != nil { @@ -107,7 +107,7 @@ func (w *Worker) TextToImage(ctx context.Context, req GenTextToImageJSONRequestB return nil, err } slog.Error("text-to-image container returned 500", slog.String("err", string(val))) - return nil, errors.New("text-to-image container returned 500") + return nil, errors.New("text-to-image container returned 500: " + resp.JSON500.Detail.Msg) } return resp.JSON200, nil @@ -146,7 +146,7 @@ func (w *Worker) ImageToImage(ctx context.Context, req GenImageToImageMultipartR return nil, err } slog.Error("image-to-image container returned 401", slog.String("err", string(val))) - return nil, errors.New("image-to-image container returned 401") + return nil, errors.New("image-to-image container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON422 != nil { @@ -155,7 +155,7 @@ func (w *Worker) ImageToImage(ctx context.Context, req GenImageToImageMultipartR return nil, err } slog.Error("image-to-image container returned 422", slog.String("err", string(val))) - return nil, errors.New("image-to-image container returned 422") + return nil, errors.New("image-to-image container returned 422: " + string(val)) } if resp.JSON500 != nil { @@ -164,7 +164,7 @@ func (w *Worker) ImageToImage(ctx context.Context, req GenImageToImageMultipartR return nil, err } slog.Error("image-to-image container returned 500", slog.String("err", string(val))) - return nil, errors.New("image-to-image container returned 500") + return nil, errors.New("image-to-image container returned 500: " + resp.JSON500.Detail.Msg) } return resp.JSON200, nil @@ -194,7 +194,7 @@ func (w *Worker) ImageToVideo(ctx context.Context, req GenImageToVideoMultipartR return nil, err } slog.Error("image-to-video container returned 400", slog.String("err", string(val))) - return nil, errors.New("image-to-video container returned 400") + return nil, errors.New("image-to-video container returned 400: " + resp.JSON400.Detail.Msg) } if resp.JSON401 != nil { @@ -203,7 +203,7 @@ func (w *Worker) ImageToVideo(ctx context.Context, req GenImageToVideoMultipartR return nil, err } slog.Error("image-to-video container returned 401", slog.String("err", string(val))) - return nil, errors.New("image-to-video container returned 401") + return nil, errors.New("image-to-video container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON422 != nil { @@ -212,7 +212,7 @@ func (w *Worker) ImageToVideo(ctx context.Context, req GenImageToVideoMultipartR return nil, err } slog.Error("image-to-video container returned 422", slog.String("err", string(val))) - return nil, errors.New("image-to-video container returned 422") + return nil, errors.New("image-to-video container returned 422: " + string(val)) } if resp.JSON500 != nil { @@ -221,7 +221,7 @@ func (w *Worker) ImageToVideo(ctx context.Context, req GenImageToVideoMultipartR return nil, err } slog.Error("image-to-video container returned 500", slog.String("err", string(val))) - return nil, errors.New("image-to-video container returned 500") + return nil, errors.New("image-to-video container returned 500: " + resp.JSON500.Detail.Msg) } if resp.JSON200 == nil { @@ -256,7 +256,7 @@ func (w *Worker) Upscale(ctx context.Context, req GenUpscaleMultipartRequestBody return nil, err } slog.Error("upscale container returned 400", slog.String("err", string(val))) - return nil, errors.New("upscale container returned 400") + return nil, errors.New("upscale container returned 400: " + resp.JSON400.Detail.Msg) } if resp.JSON401 != nil { @@ -265,7 +265,7 @@ func (w *Worker) Upscale(ctx context.Context, req GenUpscaleMultipartRequestBody return nil, err } slog.Error("upscale container returned 401", slog.String("err", string(val))) - return nil, errors.New("upscale container returned 401") + return nil, errors.New("upscale container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON422 != nil { @@ -274,7 +274,7 @@ func (w *Worker) Upscale(ctx context.Context, req GenUpscaleMultipartRequestBody return nil, err } slog.Error("upscale container returned 422", slog.String("err", string(val))) - return nil, errors.New("upscale container returned 422") + return nil, errors.New("upscale container returned 422: " + string(val)) } if resp.JSON500 != nil { @@ -283,7 +283,7 @@ func (w *Worker) Upscale(ctx context.Context, req GenUpscaleMultipartRequestBody return nil, err } slog.Error("upscale container returned 500", slog.String("err", string(val))) - return nil, errors.New("upscale container returned 500") + return nil, errors.New("upscale container returned 500: " + resp.JSON500.Detail.Msg) } return resp.JSON200, nil @@ -313,7 +313,7 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return nil, err } slog.Error("audio-to-text container returned 400", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 400") + return nil, errors.New("audio-to-text container returned 400: " + resp.JSON400.Detail.Msg) } if resp.JSON401 != nil { @@ -322,11 +322,11 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return nil, err } slog.Error("audio-to-text container returned 401", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 401") + return nil, errors.New("audio-to-text container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON413 != nil { - msg := "audio-to-text container returned 413 file too large; max file size is 50MB" + msg := "audio-to-text container returned 413: file too large; max file size is 50MB" slog.Error("audio-to-text container returned 413", slog.String("err", string(msg))) return nil, errors.New(msg) } @@ -337,7 +337,7 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return nil, err } slog.Error("audio-to-text container returned 415", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 415") + return nil, errors.New("audio-to-text container returned 415: " + resp.JSON415.Detail.Msg) } if resp.JSON422 != nil { @@ -346,7 +346,7 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return nil, err } slog.Error("audio-to-text container returned 422", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 422") + return nil, errors.New("audio-to-text container returned 422: " + string(val)) } if resp.JSON500 != nil { @@ -355,7 +355,7 @@ func (w *Worker) AudioToText(ctx context.Context, req GenAudioToTextMultipartReq return nil, err } slog.Error("audio-to-text container returned 500", slog.String("err", string(val))) - return nil, errors.New("audio-to-text container returned 500") + return nil, errors.New("audio-to-text container returned 500: " + resp.JSON500.Detail.Msg) } return resp.JSON200, nil @@ -420,7 +420,7 @@ func (w *Worker) SegmentAnything2(ctx context.Context, req GenSegmentAnything2Mu return nil, err } slog.Error("segment anything 2 container returned 400", slog.String("err", string(val))) - return nil, errors.New("segment anything 2 container returned 400") + return nil, errors.New("segment anything 2 container returned 400: " + resp.JSON400.Detail.Msg) } if resp.JSON401 != nil { @@ -429,7 +429,7 @@ func (w *Worker) SegmentAnything2(ctx context.Context, req GenSegmentAnything2Mu return nil, err } slog.Error("segment anything 2 container returned 401", slog.String("err", string(val))) - return nil, errors.New("segment anything 2 container returned 401") + return nil, errors.New("segment anything 2 container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON422 != nil { @@ -438,7 +438,7 @@ func (w *Worker) SegmentAnything2(ctx context.Context, req GenSegmentAnything2Mu return nil, err } slog.Error("segment anything 2 container returned 422", slog.String("err", string(val))) - return nil, errors.New("segment anything 2 container returned 422") + return nil, errors.New("segment anything 2 container returned 422: " + string(val)) } if resp.JSON500 != nil { @@ -447,7 +447,7 @@ func (w *Worker) SegmentAnything2(ctx context.Context, req GenSegmentAnything2Mu return nil, err } slog.Error("segment anything 2 container returned 500", slog.String("err", string(val))) - return nil, errors.New("segment anything 2 container returned 500") + return nil, errors.New("segment anything 2 container returned 500: " + resp.JSON500.Detail.Msg) } return resp.JSON200, nil @@ -544,7 +544,7 @@ func (w *Worker) handleNonStreamingResponse(c *RunnerContainer, resp *GenLLMResp return nil, err } slog.Error("LLM container returned 400", slog.String("err", string(val))) - return nil, errors.New("LLM container returned 400") + return nil, errors.New("LLM container returned 400: " + resp.JSON400.Detail.Msg) } if resp.JSON401 != nil { @@ -553,7 +553,7 @@ func (w *Worker) handleNonStreamingResponse(c *RunnerContainer, resp *GenLLMResp return nil, err } slog.Error("LLM container returned 401", slog.String("err", string(val))) - return nil, errors.New("LLM container returned 401") + return nil, errors.New("LLM container returned 401: " + resp.JSON401.Detail.Msg) } if resp.JSON500 != nil { @@ -562,7 +562,7 @@ func (w *Worker) handleNonStreamingResponse(c *RunnerContainer, resp *GenLLMResp return nil, err } slog.Error("LLM container returned 500", slog.String("err", string(val))) - return nil, errors.New("LLM container returned 500") + return nil, errors.New("LLM container returned 500: " + resp.JSON500.Detail.Msg) } return resp.JSON200, nil From f6e715ef795bef5939591496eb5f8abb4d28a9e2 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 12:02:30 +0200 Subject: [PATCH 09/14] refactor: apply black formatter This commit applies the black formatter to the PR files. --- runner/app/pipelines/segment_anything_2.py | 2 +- runner/app/pipelines/upscale.py | 2 +- runner/app/routes/image_to_image.py | 2 +- runner/app/routes/image_to_video.py | 2 +- runner/app/routes/text_to_image.py | 2 +- runner/app/routes/upscale.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/runner/app/pipelines/segment_anything_2.py b/runner/app/pipelines/segment_anything_2.py index 8278fc28..cd5c852c 100644 --- a/runner/app/pipelines/segment_anything_2.py +++ b/runner/app/pipelines/segment_anything_2.py @@ -3,7 +3,7 @@ import PIL from app.pipelines.base import Pipeline -from app.pipelines.utils import get_torch_device, get_model_dir +from app.pipelines.utils import get_model_dir, get_torch_device from app.utils.errors import InferenceError from PIL import ImageFile from sam2.sam2_image_predictor import SAM2ImagePredictor diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py index c82e5508..97f888ff 100644 --- a/runner/app/pipelines/upscale.py +++ b/runner/app/pipelines/upscale.py @@ -4,7 +4,6 @@ import PIL import torch -from app.utils.errors import InferenceError from app.pipelines.base import Pipeline from app.pipelines.utils import ( SafetyChecker, @@ -13,6 +12,7 @@ is_lightning_model, is_turbo_model, ) +from app.utils.errors import InferenceError from diffusers import StableDiffusionUpscalePipeline from huggingface_hub import file_download from PIL import ImageFile diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index ba8fbc9c..c05e1a2e 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -6,9 +6,9 @@ import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline +from app.pipelines.utils.utils import LoraLoadingError from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url from app.utils.errors import InferenceError -from app.pipelines.utils.utils import LoraLoadingError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index a9185616..8b2b71d7 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -3,8 +3,8 @@ import random from typing import Annotated -from app.dependencies import get_pipeline import torch +from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url from app.utils.errors import InferenceError diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 23a09b77..7eebee96 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -6,9 +6,9 @@ import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline +from app.pipelines.utils.utils import LoraLoadingError from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url from app.utils.errors import InferenceError -from app.pipelines.utils.utils import LoraLoadingError from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 412bcce9..efb4ccaf 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -4,10 +4,10 @@ from typing import Annotated import torch -from app.utils.errors import InferenceError from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer From b9d025cb9dec0b796f86943f3197c6116b9e82a2 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 12:15:45 +0200 Subject: [PATCH 10/14] refactor: improve error string handling This commit removes the redundant str call. --- runner/app/routes/audio_to_text.py | 2 +- runner/app/routes/image_to_image.py | 2 +- runner/app/routes/image_to_video.py | 2 +- runner/app/routes/segment_anything_2.py | 2 +- runner/app/routes/text_to_image.py | 2 +- runner/app/routes/upscale.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index dc1133d9..bfa55a2e 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -119,5 +119,5 @@ async def audio_to_text( try: return pipeline(audio=audio) except Exception as e: - logger.error(f"AudioToText pipeline error: {str(e)}") + logger.error(f"AudioToText pipeline error: {e}") return handle_pipeline_error(e) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index c05e1a2e..e3b390b6 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -211,7 +211,7 @@ async def image_to_image( num_inference_steps=num_inference_steps, ) except Exception as e: - logger.error(f"ImageToImagePipeline pipeline error: {str(e)}") + logger.error(f"ImageToImagePipeline pipeline error: {e}") return handle_pipeline_error(e) images.extend(imgs) has_nsfw_concept.extend(nsfw_checks) diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 8b2b71d7..14a7f9d4 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -182,7 +182,7 @@ async def image_to_video( seed=seed, ) except Exception as e: - logger.error(f"ImageToVideo pipeline error: {str(e)}") + logger.error(f"ImageToVideo pipeline error: {e}") return handle_pipeline_error(e) output_frames = [] diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index 9640de6b..01e18e20 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -193,7 +193,7 @@ async def segment_anything_2( normalize_coords=normalize_coords, ) except Exception as e: - logger.error(f"SegmentAnything2 pipeline error: {str(e)}") + logger.error(f"SegmentAnything2 pipeline error: {e}") return handle_pipeline_error(e) # Return masks sorted by descending score as string. diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 7eebee96..e29ce16b 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -203,7 +203,7 @@ async def text_to_image( try: imgs, nsfw_check = pipeline(**kwargs) except Exception as e: - logger.error(f"TextToImage pipeline error: {str(e)}") + logger.error(f"TextToImage pipeline error: {e}") return handle_pipeline_error(e) images.extend(imgs) has_nsfw_concept.extend(nsfw_check) diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index efb4ccaf..2da1ec58 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -146,7 +146,7 @@ async def upscale( seed=seed, ) except Exception as e: - logger.error(f"TextToImage pipeline error: {str(e)}") + logger.error(f"TextToImage pipeline error: {e}") return handle_pipeline_error(e) seeds = [seed] From ae8df749e107cf2e13ab8fb86f5a884675370e1c Mon Sep 17 00:00:00 2001 From: gioelecerati Date: Mon, 14 Oct 2024 16:07:51 +0200 Subject: [PATCH 11/14] feat(runner): add global pipeline error handling logic This commit introduces a global error handling configuration and function to streamline error management across different pipelines. The new `handle_pipeline_exception` function centralizes error handling logic, allowing pipelines to override it if necessary. This change reduces code duplication and improves maintainability. Co-authored-by: rickstaa --- runner/app/routes/audio_to_text.py | 65 ++++++++++------------ runner/app/routes/image_to_image.py | 54 ++++++++----------- runner/app/routes/image_to_video.py | 52 ++++++++---------- runner/app/routes/segment_anything_2.py | 49 ++++++++--------- runner/app/routes/text_to_image.py | 53 ++++++++---------- runner/app/routes/upscale.py | 51 ++++++++---------- runner/app/routes/utils.py | 71 ++++++++++++++++++++++++- 7 files changed, 208 insertions(+), 187 deletions(-) diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index bfa55a2e..c1aa23c3 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -1,13 +1,17 @@ import logging import os -from typing import Annotated +from typing import Annotated, Dict, Tuple, Union import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.pipelines.utils.audio import AudioConversionError -from app.routes.utils import HTTPError, TextResponse, file_exceeds_max_size, http_error -from app.utils.errors import InferenceError +from app.routes.utils import ( + HTTPError, + TextResponse, + file_exceeds_max_size, + http_error, + handle_pipeline_exception, +) from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -16,6 +20,20 @@ logger = logging.getLogger(__name__) +# Pipeline specific error handling configuration. +AUDIO_FORMAT_ERROR_MESSAGE = "Unsupported audio format or malformed file." +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "AudioConversionError": ( + AUDIO_FORMAT_ERROR_MESSAGE, + status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + ), + "Soundfile is either not in the correct format or is malformed": ( + AUDIO_FORMAT_ERROR_MESSAGE, + status.HTTP_415_UNSUPPORTED_MEDIA_TYPE, + ), +} + RESPONSES = { status.HTTP_200_OK: { "content": { @@ -34,37 +52,6 @@ } -def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during audio pipeline processing. - - Args: - e: The exception raised during audio processing. - - Returns: - A JSONResponse with the appropriate error message and status code. - """ - if "Soundfile is either not in the correct format or is malformed" in str( - e - ) or isinstance(e, AudioConversionError): - status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE - error_message = "Unsupported audio format or malformed file." - elif isinstance(e, torch.cuda.OutOfMemoryError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = "Out of memory error." - torch.cuda.empty_cache() - elif isinstance(e, InferenceError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Audio-to-text pipeline error." - - return JSONResponse( - status_code=status_code, - content=http_error(error_message), - ) - - @router.post( "/audio-to-text", response_model=TextResponse, @@ -119,5 +106,11 @@ async def audio_to_text( try: return pipeline(audio=audio) except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() logger.error(f"AudioToText pipeline error: {e}") - return handle_pipeline_error(e) + return handle_pipeline_exception( + e, + default_error_message="Audio-to-text pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index e3b390b6..76a29b5e 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -1,14 +1,18 @@ import logging import os import random -from typing import Annotated +from typing import Annotated, Dict, Tuple, Union import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.pipelines.utils.utils import LoraLoadingError -from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.routes.utils import ( + HTTPError, + ImageResponse, + http_error, + image_to_data_url, + handle_pipeline_exception, +) from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -21,34 +25,14 @@ logger = logging.getLogger(__name__) -def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during image-to-image pipeline processing. - - Args: - e: The exception raised during image-to-image processing. - - Returns: - A JSONResponse with the appropriate error message and status code. - """ - if isinstance(e, torch.cuda.OutOfMemoryError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = "Out of memory error. Try reducing input image resolution." - torch.cuda.empty_cache() - elif isinstance(e, LoraLoadingError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - elif isinstance(e, InferenceError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Image-to-image pipeline error." - - return JSONResponse( - status_code=status_code, - content=http_error(error_message), +# Pipeline specific error handling configuration. +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "OutOfMemoryError": ( + "Out of memory error. Try reducing input image resolution.", + status.HTTP_500_INTERNAL_SERVER_ERROR, ) - +} RESPONSES = { status.HTTP_200_OK: { @@ -211,8 +195,14 @@ async def image_to_image( num_inference_steps=num_inference_steps, ) except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() logger.error(f"ImageToImagePipeline pipeline error: {e}") - return handle_pipeline_error(e) + return handle_pipeline_exception( + e, + default_error_message="Image-to-image pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) images.extend(imgs) has_nsfw_concept.extend(nsfw_checks) diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index 14a7f9d4..88e2cb28 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -1,13 +1,18 @@ import logging import os import random -from typing import Annotated +from typing import Annotated, Dict, Tuple, Union import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.routes.utils import ( + HTTPError, + VideoResponse, + http_error, + image_to_data_url, + handle_pipeline_exception, +) from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -19,33 +24,14 @@ logger = logging.getLogger(__name__) - -def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during image-to-video pipeline processing. - - Args: - e: The exception raised during image-to-video processing. - - Returns: - A JSONResponse with the appropriate error message and status code. - """ - if isinstance(e, torch.cuda.OutOfMemoryError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = ( - "Out of memory error. Try reducing input or output video resolution." - ) - torch.cuda.empty_cache() - elif isinstance(e, InferenceError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Image-to-video pipeline error." - - return JSONResponse( - status_code=status_code, - content=http_error(error_message), +# Pipeline specific error handling configuration. +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "OutOfMemoryError": ( + "Out of memory error. Try reducing input or output video resolution.", + status.HTTP_500_INTERNAL_SERVER_ERROR, ) +} RESPONSES = { @@ -182,8 +168,14 @@ async def image_to_video( seed=seed, ) except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() logger.error(f"ImageToVideo pipeline error: {e}") - return handle_pipeline_error(e) + return handle_pipeline_exception( + e, + default_error_message="Image-to-video pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) output_frames = [] for frames in batch_frames: diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index 01e18e20..787af5b9 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -1,13 +1,18 @@ import logging import os -from typing import Annotated +from typing import Annotated, Dict, Tuple, Union import numpy as np import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array -from app.utils.errors import InferenceError +from app.routes.utils import ( + HTTPError, + MasksResponse, + http_error, + json_str_to_np_array, + handle_pipeline_exception, +) from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -20,30 +25,14 @@ logger = logging.getLogger(__name__) -def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during segment-anything-2 pipeline processing. - - Args: - e: The exception raised during segment-anything-2 processing. - - Returns: - A JSONResponse with the appropriate error message and status code. - """ - if isinstance(e, torch.cuda.OutOfMemoryError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = "Out of memory error. Try reducing input image resolution." - torch.cuda.empty_cache() - elif isinstance(e, InferenceError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Segment-anything-2 pipeline error." - - return JSONResponse( - status_code=status_code, - content=http_error(error_message), +# Pipeline specific error handling configuration. +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "OutOfMemoryError": ( + "Out of memory error. Try reducing input image resolution.", + status.HTTP_500_INTERNAL_SERVER_ERROR, ) +} RESPONSES = { @@ -193,8 +182,14 @@ async def segment_anything_2( normalize_coords=normalize_coords, ) except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() logger.error(f"SegmentAnything2 pipeline error: {e}") - return handle_pipeline_error(e) + return handle_pipeline_exception( + e, + default_error_message="Segment-anything-2 pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) # Return masks sorted by descending score as string. sorted_ind = np.argsort(scores)[::-1] diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index e29ce16b..01058197 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -1,14 +1,18 @@ import logging import os import random -from typing import Annotated +from typing import Annotated, Dict, Tuple, Union import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.pipelines.utils.utils import LoraLoadingError -from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.routes.utils import ( + HTTPError, + ImageResponse, + http_error, + image_to_data_url, + handle_pipeline_exception, +) from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -19,33 +23,14 @@ logger = logging.getLogger(__name__) -def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during text-to-image pipeline processing. - - Args: - e: The exception raised during text-to-image processing. - - Returns: - A JSONResponse with the appropriate error message and status code. - """ - if isinstance(e, torch.cuda.OutOfMemoryError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = "Out of memory error. Try reducing output image resolution." - torch.cuda.empty_cache() - elif isinstance(e, LoraLoadingError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - elif isinstance(e, InferenceError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Text-to-image pipeline error." - - return JSONResponse( - status_code=status_code, - content=http_error(error_message), +# Pipeline specific error handling configuration. +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "OutOfMemoryError": ( + "Out of memory error. Try reducing output image resolution.", + status.HTTP_500_INTERNAL_SERVER_ERROR, ) +} class TextToImageParams(BaseModel): @@ -203,8 +188,14 @@ async def text_to_image( try: imgs, nsfw_check = pipeline(**kwargs) except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() logger.error(f"TextToImage pipeline error: {e}") - return handle_pipeline_error(e) + return handle_pipeline_exception( + e, + default_error_message="Text-to-image pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) images.extend(imgs) has_nsfw_concept.extend(nsfw_check) diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 2da1ec58..706dce4c 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -1,13 +1,18 @@ import logging import os import random -from typing import Annotated +from typing import Annotated, Dict, Tuple, Union import torch from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url -from app.utils.errors import InferenceError +from app.routes.utils import ( + HTTPError, + ImageResponse, + http_error, + image_to_data_url, + handle_pipeline_exception, +) from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -19,32 +24,14 @@ logger = logging.getLogger(__name__) - -def handle_pipeline_error(e: Exception) -> JSONResponse: - """Handles exceptions raised during upscale pipeline processing. - - Args: - e: The exception raised during upscale processing. - - Returns: - A JSONResponse with the appropriate error message and status code. - """ - if isinstance(e, torch.cuda.OutOfMemoryError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = "Out of memory error. Try reducing input image resolution." - torch.cuda.empty_cache() - elif isinstance(e, InferenceError): - status_code = status.HTTP_400_BAD_REQUEST - error_message = str(e) - else: - status_code = status.HTTP_500_INTERNAL_SERVER_ERROR - error_message = "Upscale pipeline error." - - return JSONResponse( - status_code=status_code, - content=http_error(error_message), +# Pipeline specific error handling configuration. +PIPELINE_ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "OutOfMemoryError": ( + "Out of memory error. Try reducing input image resolution.", + status.HTTP_500_INTERNAL_SERVER_ERROR, ) - +} RESPONSES = { status.HTTP_200_OK: { @@ -146,8 +133,14 @@ async def upscale( seed=seed, ) except Exception as e: + if isinstance(e, torch.cuda.OutOfMemoryError): + torch.cuda.empty_cache() logger.error(f"TextToImage pipeline error: {e}") - return handle_pipeline_error(e) + return handle_pipeline_exception( + e, + default_error_message="Upscale pipeline error.", + custom_error_config=PIPELINE_ERROR_CONFIG, + ) seeds = [seed] diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 85c5cd12..a240ff85 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -2,10 +2,11 @@ import io import json import os -from typing import List, Optional +from typing import Dict, List, Optional, Tuple, Union import numpy as np -from fastapi import UploadFile +from fastapi import UploadFile, status +from fastapi.responses import JSONResponse from PIL import Image from pydantic import BaseModel, Field @@ -165,3 +166,69 @@ def json_str_to_np_array( error_message += f": {e}" raise ValueError(error_message) return None + + +# Global error handling configuration. +ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { + # Specific error types. + "LoraLoadingError": (None, status.HTTP_400_BAD_REQUEST), + "InferenceError": (None, status.HTTP_400_BAD_REQUEST), + "ValueError": ("Pipeline error.", status.HTTP_400_BAD_REQUEST), + "OutOfMemoryError": ("GPU out of memory.", status.HTTP_500_INTERNAL_SERVER_ERROR), + # General error patterns. + "out of memory": ("Out of memory.", status.HTTP_500_INTERNAL_SERVER_ERROR), + "CUDA out of memory": ("GPU out of memory.", status.HTTP_500_INTERNAL_SERVER_ERROR), +} + + +def handle_pipeline_exception( + e: object, + default_error_message: Union[str, Dict[str, object]] = "Pipeline error.", + default_status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR, + custom_error_config: Dict[str, Tuple[str, int]] = None, +) -> JSONResponse: + """Handles pipeline exceptions by returning a JSON response with the appropriate + error message and status code. + + Args: + e (object): The exception to handle. Can be any type of object. + default_error_message (Union[str, Dict[str, Any]]): The default error message + or content dictionary. Default will be used if no specific error type is + matched. + default_status_code (int): The default status code to use if no specific error + type is matched. Defaults to HTTP_500_INTERNAL_SERVER_ERROR. + custom_error_config (Dict[str, Tuple[str, int]]): Custom error configuration + to override the application error configuration. + + Returns: + JSONResponse: The JSON response with appropriate status code and error message. + """ + error_config = ERROR_CONFIG.copy() + + # Update error_config with custom_error_config if provided. + if custom_error_config: + error_config.update(custom_error_config) + + error_message = default_error_message + status_code = default_status_code + + error_type = type(e).__name__ + if error_type in error_config: + message, status_code = error_config[error_type] + error_message = str(e) if message is None or message == "" else message + else: + for error_pattern, (message, code) in error_config.items(): + if error_pattern.lower() in str(e).lower(): + error_message = str(e) if message is None or message == "" else message + status_code = code + break + + if isinstance(error_message, str): + content = http_error(error_message) + else: + content = error_message + + return JSONResponse( + status_code=status_code, + content=content, + ) From 870c2f079af2a9f03f56aa2c01e7b87b9ea143b7 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 16:25:04 +0200 Subject: [PATCH 12/14] refactor(runner): change default error message behavoir This commit ensures that pipelines can overwrite the default error message when the Global error configuration contains a empty string. --- runner/app/routes/utils.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index a240ff85..2eeaef25 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -169,11 +169,12 @@ def json_str_to_np_array( # Global error handling configuration. +# NOTE: "" for default message, None for exception message. ERROR_CONFIG: Dict[str, Tuple[Union[str, None], int]] = { # Specific error types. "LoraLoadingError": (None, status.HTTP_400_BAD_REQUEST), "InferenceError": (None, status.HTTP_400_BAD_REQUEST), - "ValueError": ("Pipeline error.", status.HTTP_400_BAD_REQUEST), + "ValueError": ("", status.HTTP_400_BAD_REQUEST), "OutOfMemoryError": ("GPU out of memory.", status.HTTP_500_INTERNAL_SERVER_ERROR), # General error patterns. "out of memory": ("Out of memory.", status.HTTP_500_INTERNAL_SERVER_ERROR), @@ -215,14 +216,17 @@ def handle_pipeline_exception( error_type = type(e).__name__ if error_type in error_config: message, status_code = error_config[error_type] - error_message = str(e) if message is None or message == "" else message + error_message = str(e) if message is None else message else: for error_pattern, (message, code) in error_config.items(): if error_pattern.lower() in str(e).lower(): - error_message = str(e) if message is None or message == "" else message + error_message = str(e) if message is None else message status_code = code break + if error_message == "": + error_message = default_error_message + if isinstance(error_message, str): content = http_error(error_message) else: From 6702956074a009230769f76b5163b349553c4adc Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 21:17:09 +0200 Subject: [PATCH 13/14] test(runner): add handle_pipeline_exception test This commit adds a test for the 'handle_pipeline_exception' route utility function. It also fixes some errors into that function. --- runner/app/pipelines/utils/__init__.py | 2 + runner/app/routes/utils.py | 37 +++++----- runner/tests/__init__.py | 0 runner/tests/routes/__init__.py | 0 runner/tests/routes/test_utils.py | 99 ++++++++++++++++++++++++++ 5 files changed, 118 insertions(+), 20 deletions(-) create mode 100644 runner/tests/__init__.py create mode 100644 runner/tests/routes/__init__.py create mode 100644 runner/tests/routes/test_utils.py diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py index 872fd313..99e06686 100644 --- a/runner/app/pipelines/utils/__init__.py +++ b/runner/app/pipelines/utils/__init__.py @@ -4,12 +4,14 @@ from app.pipelines.utils.utils import ( LoraLoader, + LoraLoadingError, SafetyChecker, get_model_dir, get_model_path, get_torch_device, is_lightning_model, is_turbo_model, + is_numeric, split_prompt, validate_torch_device, ) diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 2eeaef25..32b55062 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -192,45 +192,42 @@ def handle_pipeline_exception( error message and status code. Args: - e (object): The exception to handle. Can be any type of object. - default_error_message (Union[str, Dict[str, Any]]): The default error message - or content dictionary. Default will be used if no specific error type is - matched. - default_status_code (int): The default status code to use if no specific error - type is matched. Defaults to HTTP_500_INTERNAL_SERVER_ERROR. - custom_error_config (Dict[str, Tuple[str, int]]): Custom error configuration - to override the application error configuration. + e(int): The exception to handle. Can be any type of object. + default_error_message: The default error message or content dictionary. Default + will be used if no specific error type ismatched. + default_status_code: The default status code to use if no specific error type is + matched. Defaults to HTTP_500_INTERNAL_SERVER_ERROR. + custom_error_config: Custom error configuration to override the application + error configuration. Returns: - JSONResponse: The JSON response with appropriate status code and error message. + The JSON response with appropriate status code and error message. """ error_config = ERROR_CONFIG.copy() - - # Update error_config with custom_error_config if provided. if custom_error_config: error_config.update(custom_error_config) - error_message = default_error_message status_code = default_status_code + error_message = default_error_message error_type = type(e).__name__ if error_type in error_config: - message, status_code = error_config[error_type] - error_message = str(e) if message is None else message + error_message, status_code = error_config[error_type] else: for error_pattern, (message, code) in error_config.items(): if error_pattern.lower() in str(e).lower(): - error_message = str(e) if message is None else message status_code = code + error_message = message break - if error_message == "": + if error_message is None: + error_message = f"{e}." + elif error_message == "": error_message = default_error_message - if isinstance(error_message, str): - content = http_error(error_message) - else: - content = error_message + content = ( + http_error(error_message) if isinstance(error_message, str) else error_message + ) return JSONResponse( status_code=status_code, diff --git a/runner/tests/__init__.py b/runner/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runner/tests/routes/__init__.py b/runner/tests/routes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runner/tests/routes/test_utils.py b/runner/tests/routes/test_utils.py new file mode 100644 index 00000000..51bd052d --- /dev/null +++ b/runner/tests/routes/test_utils.py @@ -0,0 +1,99 @@ +import pytest +from app.routes.utils import handle_pipeline_exception +from app.pipelines.utils import LoraLoadingError +import torch +from fastapi import status +from fastapi.responses import JSONResponse +import json + + +class TestHandlePipelineException: + """Tests for the handle_pipeline_exception function.""" + + @staticmethod + def parse_response(response: JSONResponse): + """Parses the JSON response body from a FastAPI JSONResponse object.""" + return json.loads(response.body) + + @pytest.mark.parametrize( + "exception, expected_status, expected_message, description", + [ + ( + Exception("Unknown error"), + status.HTTP_500_INTERNAL_SERVER_ERROR, + "Pipeline error.", + "Returns default message and status code for unknown error.", + ), + ( + torch.cuda.OutOfMemoryError("Some Message"), + status.HTTP_500_INTERNAL_SERVER_ERROR, + "GPU out of memory.", + "Returns global message and status code for type match.", + ), + ( + Exception("CUDA out of memory"), + status.HTTP_500_INTERNAL_SERVER_ERROR, + "Out of memory.", + "Returns global message and status code for pattern match.", + ), + ( + LoraLoadingError("A custom error message"), + status.HTTP_400_BAD_REQUEST, + "A custom error message.", + "Forwards exception message if configured with None.", + ), + ( + ValueError("A custom error message"), + status.HTTP_400_BAD_REQUEST, + "Pipeline error.", + "Returns default message if configured with ''.", + ), + ], + ) + def test_handle_pipeline_exception( + self, exception, expected_status, expected_message, description + ): + """Test that the handle_pipeline_exception function returns the correct status + code and error message for different types of exceptions. + """ + response = handle_pipeline_exception(exception) + response_body = self.parse_response(response) + assert response.status_code == expected_status, f"Failed: {description}" + assert ( + response_body["detail"]["msg"] == expected_message + ), f"Failed: {description}" + + def test_handle_pipeline_exception_custom_default_message(self): + """Test that a custom default error message is used when provided.""" + exception = ValueError("Some value error") + response = handle_pipeline_exception( + exception, default_error_message="A custom error message." + ) + response_body = self.parse_response(response) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response_body["detail"]["msg"] == "A custom error message." + + def test_handle_pipeline_exception_custom_status_code(self): + """Test that a custom default status code is used when provided.""" + exception = Exception("Some value error") + response = handle_pipeline_exception( + exception, default_status_code=status.HTTP_404_NOT_FOUND + ) + response_body = self.parse_response(response) + assert response.status_code == status.HTTP_404_NOT_FOUND + assert response_body["detail"]["msg"] == "Pipeline error." + + def test_handle_pipeline_exception_custom_error_config(self): + """Test that custom error configuration overrides the global error + configuration, which prints the exception message. + """ + exception = LoraLoadingError("Some error message.") + response = handle_pipeline_exception( + exception, + custom_error_config={ + "LoraLoadingError": ("Custom message.", status.HTTP_400_BAD_REQUEST) + }, + ) + response_body = self.parse_response(response) + assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response_body["detail"]["msg"] == "Custom message." From e3798408dcb184a69f7cfc375b043160cdd0dd86 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 14 Oct 2024 21:20:42 +0200 Subject: [PATCH 14/14] fixup! test(runner): add handle_pipeline_exception test --- runner/app/routes/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runner/app/routes/utils.py b/runner/app/routes/utils.py index 32b55062..c7717783 100644 --- a/runner/app/routes/utils.py +++ b/runner/app/routes/utils.py @@ -207,8 +207,8 @@ def handle_pipeline_exception( if custom_error_config: error_config.update(custom_error_config) - status_code = default_status_code error_message = default_error_message + status_code = default_status_code error_type = type(e).__name__ if error_type in error_config: