From f08aced13ae886ec91db44df48ee6298013a7297 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Wed, 4 Sep 2024 15:36:14 +0200 Subject: [PATCH] refactor(runner): add InferenceError to all pipelines This commit adds the inference error logic from the SAM2 pipeline to all pipelines so users are given a warning when they supply wrong arguments. --- runner/app/pipelines/audio_to_text.py | 8 +++++++- runner/app/pipelines/image_to_image.py | 12 ++++++++---- runner/app/pipelines/image_to_video.py | 8 +++++++- runner/app/pipelines/optim/sfast.py | 2 +- runner/app/pipelines/segment_anything_2.py | 4 ++-- runner/app/pipelines/text_to_image.py | 12 ++++++++---- runner/app/routes/audio_to_text.py | 6 +++++- runner/app/routes/image_to_image.py | 13 ++++++++++--- runner/app/routes/image_to_video.py | 9 ++++++++- runner/app/routes/segment_anything_2.py | 9 ++------- runner/app/routes/text_to_image.py | 17 ++++++++++++----- runner/app/routes/upscale.py | 2 +- runner/app/routes/{util.py => utils.py} | 16 ---------------- runner/app/utils/__init__.py | 0 runner/app/utils/errors.py | 17 +++++++++++++++++ runner/gen_openapi.py | 4 ++-- 16 files changed, 90 insertions(+), 49 deletions(-) rename runner/app/routes/{util.py => utils.py} (89%) create mode 100644 runner/app/utils/__init__.py create mode 100644 runner/app/utils/errors.py diff --git a/runner/app/pipelines/audio_to_text.py b/runner/app/pipelines/audio_to_text.py index 4c75ddfb..888f70ed 100644 --- a/runner/app/pipelines/audio_to_text.py +++ b/runner/app/pipelines/audio_to_text.py @@ -6,6 +6,7 @@ from app.pipelines.base import Pipeline from app.pipelines.utils import get_model_dir, get_torch_device from app.pipelines.utils.audio import AudioConverter +from app.utils.errors import InferenceError from fastapi import File, UploadFile from huggingface_hub import file_download from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline @@ -76,7 +77,12 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]: converted_bytes = audio_converter.convert(audio, "mp3") audio_converter.write_bytes_to_file(converted_bytes, audio) - return self.tm(audio.file.read(), **kwargs) + try: + outputs = self.tm(audio.file.read(), **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) + + return outputs def __str__(self) -> str: return f"AudioToTextPipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index 9e20ff03..9aeae4c9 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -13,6 +13,7 @@ is_lightning_model, is_turbo_model, ) +from app.utils.errors import InferenceError from diffusers import ( AutoPipelineForImage2Image, EulerAncestralDiscreteScheduler, @@ -222,14 +223,17 @@ def __call__( # Default to 8step kwargs["num_inference_steps"] = 8 - output = self.ldm(prompt, image=image, **kwargs) + try: + outputs = self.ldm(prompt, image=image, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) if safety_check: - _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images) else: - has_nsfw_concept = [None] * len(output.images) + has_nsfw_concept = [None] * len(outputs.images) - return output.images, has_nsfw_concept + return outputs.images, has_nsfw_concept def __str__(self) -> str: return f"ImageToImagePipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index f605cb2f..680800a5 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -7,6 +7,7 @@ import torch from app.pipelines.base import Pipeline from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device +from app.utils.errors import InferenceError from diffusers import StableVideoDiffusionPipeline from huggingface_hub import file_download from PIL import ImageFile @@ -134,7 +135,12 @@ def __call__( else: has_nsfw_concept = [None] - return self.ldm(image, **kwargs).frames, has_nsfw_concept + try: + outputs = self.ldm(image, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) + + return outputs.frames, has_nsfw_concept def __str__(self) -> str: return f"ImageToVideoPipeline model_id={self.model_id}" diff --git a/runner/app/pipelines/optim/sfast.py b/runner/app/pipelines/optim/sfast.py index 166e014e..15598562 100644 --- a/runner/app/pipelines/optim/sfast.py +++ b/runner/app/pipelines/optim/sfast.py @@ -31,7 +31,7 @@ def compile_model(pipe): except ImportError: logger.info("xformers not installed, skip") try: - import triton # noqa: F401 + import triton # noqa: F401 config.enable_triton = True except ImportError: diff --git a/runner/app/pipelines/segment_anything_2.py b/runner/app/pipelines/segment_anything_2.py index 64c4080d..cd5c852c 100644 --- a/runner/app/pipelines/segment_anything_2.py +++ b/runner/app/pipelines/segment_anything_2.py @@ -3,8 +3,8 @@ import PIL from app.pipelines.base import Pipeline -from app.pipelines.utils import get_torch_device, get_model_dir -from app.routes.util import InferenceError +from app.pipelines.utils import get_model_dir, get_torch_device +from app.utils.errors import InferenceError from PIL import ImageFile from sam2.sam2_image_predictor import SAM2ImagePredictor diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 10e4f485..a760aade 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -14,6 +14,7 @@ is_turbo_model, split_prompt, ) +from app.utils.errors import InferenceError from diffusers import ( AutoPipelineForText2Image, EulerDiscreteScheduler, @@ -263,14 +264,17 @@ def __call__( ) kwargs.update(neg_prompts) - output = self.ldm(prompt=prompt, **kwargs) + try: + outputs = self.ldm(prompt=prompt, **kwargs) + except Exception as e: + raise InferenceError(original_exception=e) if safety_check: - _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) + _, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images) else: - has_nsfw_concept = [None] * len(output.images) + has_nsfw_concept = [None] * len(outputs.images) - return output.images, has_nsfw_concept + return outputs.images, has_nsfw_concept def __str__(self) -> str: return f"TextToImagePipeline model_id={self.model_id}" diff --git a/runner/app/routes/audio_to_text.py b/runner/app/routes/audio_to_text.py index 7396e8b0..c92109d3 100644 --- a/runner/app/routes/audio_to_text.py +++ b/runner/app/routes/audio_to_text.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline from app.pipelines.utils.audio import AudioConversionError -from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error +from app.routes.utils import HTTPError, TextResponse, file_exceeds_max_size, http_error +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -37,6 +38,9 @@ def handle_pipeline_error(e: Exception) -> JSONResponse: ) or isinstance(e, AudioConversionError): status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE error_message = "Unsupported audio format or malformed file." + elif isinstance(e, InferenceError): + status_code = status.HTTP_400_BAD_REQUEST + error_message = str(e) else: status_code = status.HTTP_500_INTERNAL_SERVER_ERROR error_message = "Internal server error during audio processing." diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 63fc3b0f..ab730181 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -154,15 +155,21 @@ async def image_to_image( num_images_per_prompt=1, num_inference_steps=num_inference_steps, ) - images.extend(imgs) - has_nsfw_concept.extend(nsfw_checks) except Exception as e: logger.error(f"ImageToImagePipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("ImageToImagePipeline error"), ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_checks) # TODO: Return None once Go codegen tool supports optional properties # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 diff --git a/runner/app/routes/image_to_video.py b/runner/app/routes/image_to_video.py index a7c9350d..9fb7b7b0 100644 --- a/runner/app/routes/image_to_video.py +++ b/runner/app/routes/image_to_video.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, VideoResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -140,6 +141,12 @@ async def image_to_video( except Exception as e: logger.error(f"ImageToVideoPipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("ImageToVideoPipeline error"), diff --git a/runner/app/routes/segment_anything_2.py b/runner/app/routes/segment_anything_2.py index 70436432..d3ef6bec 100644 --- a/runner/app/routes/segment_anything_2.py +++ b/runner/app/routes/segment_anything_2.py @@ -5,13 +5,8 @@ import numpy as np from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import ( - HTTPError, - InferenceError, - MasksResponse, - http_error, - json_str_to_np_array, -) +from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index c72dae53..00f57a54 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -5,7 +5,8 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url +from app.utils.errors import InferenceError from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer @@ -142,19 +143,25 @@ async def text_to_image( has_nsfw_concept = [] params.num_images_per_prompt = 1 for seed in seeds: + params.seed = seed + kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"} try: - params.seed = seed - kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"} imgs, nsfw_check = pipeline(**kwargs) - images.extend(imgs) - has_nsfw_concept.extend(nsfw_check) except Exception as e: logger.error(f"TextToImagePipeline error: {e}") logger.exception(e) + if isinstance(e, InferenceError): + return JSONResponse( + status_code=status.HTTP_400_BAD_REQUEST, + content=http_error(str(e)), + ) + return JSONResponse( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, content=http_error("TextToImagePipeline error"), ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_check) # TODO: Return None once Go codegen tool supports optional properties # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 diff --git a/runner/app/routes/upscale.py b/runner/app/routes/upscale.py index 5aca1073..0e4788a7 100644 --- a/runner/app/routes/upscale.py +++ b/runner/app/routes/upscale.py @@ -5,7 +5,7 @@ from app.dependencies import get_pipeline from app.pipelines.base import Pipeline -from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url +from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url from fastapi import APIRouter, Depends, File, Form, UploadFile, status from fastapi.responses import JSONResponse from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer diff --git a/runner/app/routes/util.py b/runner/app/routes/utils.py similarity index 89% rename from runner/app/routes/util.py rename to runner/app/routes/utils.py index 8a319e84..6b223db7 100644 --- a/runner/app/routes/util.py +++ b/runner/app/routes/utils.py @@ -70,22 +70,6 @@ class HTTPError(BaseModel): detail: APIError = Field(..., description="Detailed error information.") -class InferenceError(Exception): - """Exception raised for errors during model inference.""" - - def __init__(self, message="Error during model execution", original_exception=None): - """Initialize the exception. - - Args: - message: The error message. - original_exception: The original exception that caused the error. - """ - if original_exception: - message = f"{message}: {original_exception}" - super().__init__(message) - self.original_exception = original_exception - - def http_error(msg: str) -> HTTPError: """Create an HTTP error response with the specified message. diff --git a/runner/app/utils/__init__.py b/runner/app/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/runner/app/utils/errors.py b/runner/app/utils/errors.py new file mode 100644 index 00000000..c02c769f --- /dev/null +++ b/runner/app/utils/errors.py @@ -0,0 +1,17 @@ +"""Custom exceptions for the application.""" + + +class InferenceError(Exception): + """Exception raised for errors during model inference.""" + + def __init__(self, message="Error during model execution", original_exception=None): + """Initialize the exception. + + Args: + message: The error message. + original_exception: The original exception that caused the error. + """ + if original_exception: + message = f"{message}: {original_exception}" + super().__init__(message) + self.original_exception = original_exception diff --git a/runner/gen_openapi.py b/runner/gen_openapi.py index e0c86e56..3b4f0d25 100644 --- a/runner/gen_openapi.py +++ b/runner/gen_openapi.py @@ -1,6 +1,8 @@ import argparse import copy import json +import logging +import subprocess import yaml from app.main import app, use_route_names_as_operation_ids @@ -14,8 +16,6 @@ upscale, ) from fastapi.openapi.utils import get_openapi -import subprocess -import logging logging.basicConfig( level=logging.INFO,