Skip to content

Commit

Permalink
refactor(runner): add InferenceError to all pipelines
Browse files Browse the repository at this point in the history
This commit adds the inference error logic from the SAM2 pipeline to all
pipelines so users receive a descriptive error response when they supply wrong arguments.
  • Loading branch information
rickstaa committed Sep 4, 2024
1 parent 7a707b0 commit f08aced
Show file tree
Hide file tree
Showing 16 changed files with 90 additions and 49 deletions.
8 changes: 7 additions & 1 deletion runner/app/pipelines/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils.audio import AudioConverter
from app.utils.errors import InferenceError
from fastapi import File, UploadFile
from huggingface_hub import file_download
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
Expand Down Expand Up @@ -76,7 +77,12 @@ def __call__(self, audio: UploadFile, **kwargs) -> List[File]:
converted_bytes = audio_converter.convert(audio, "mp3")
audio_converter.write_bytes_to_file(converted_bytes, audio)

return self.tm(audio.file.read(), **kwargs)
try:
outputs = self.tm(audio.file.read(), **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

return outputs

def __str__(self) -> str:
    """Return a short description naming this pipeline and its ``model_id``."""
    return f"AudioToTextPipeline model_id={self.model_id}"
12 changes: 8 additions & 4 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
is_lightning_model,
is_turbo_model,
)
from app.utils.errors import InferenceError
from diffusers import (
AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler,
Expand Down Expand Up @@ -222,14 +223,17 @@ def __call__(
# Default to 8step
kwargs["num_inference_steps"] = 8

output = self.ldm(prompt, image=image, **kwargs)
try:
outputs = self.ldm(prompt, image=image, **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
else:
has_nsfw_concept = [None] * len(output.images)
has_nsfw_concept = [None] * len(outputs.images)

return output.images, has_nsfw_concept
return outputs.images, has_nsfw_concept

def __str__(self) -> str:
    """Return a short description naming this pipeline and its ``model_id``."""
    return f"ImageToImagePipeline model_id={self.model_id}"
8 changes: 7 additions & 1 deletion runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
from app.utils.errors import InferenceError
from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile
Expand Down Expand Up @@ -134,7 +135,12 @@ def __call__(
else:
has_nsfw_concept = [None]

return self.ldm(image, **kwargs).frames, has_nsfw_concept
try:
outputs = self.ldm(image, **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

return outputs.frames, has_nsfw_concept

def __str__(self) -> str:
    """Return a short description naming this pipeline and its ``model_id``."""
    return f"ImageToVideoPipeline model_id={self.model_id}"
2 changes: 1 addition & 1 deletion runner/app/pipelines/optim/sfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def compile_model(pipe):
except ImportError:
logger.info("xformers not installed, skip")
try:
import triton # noqa: F401
import triton # noqa: F401

config.enable_triton = True
except ImportError:
Expand Down
4 changes: 2 additions & 2 deletions runner/app/pipelines/segment_anything_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

import PIL
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir
from app.routes.util import InferenceError
from app.pipelines.utils import get_model_dir, get_torch_device
from app.utils.errors import InferenceError
from PIL import ImageFile
from sam2.sam2_image_predictor import SAM2ImagePredictor

Expand Down
12 changes: 8 additions & 4 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
is_turbo_model,
split_prompt,
)
from app.utils.errors import InferenceError
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
Expand Down Expand Up @@ -263,14 +264,17 @@ def __call__(
)
kwargs.update(neg_prompts)

output = self.ldm(prompt=prompt, **kwargs)
try:
outputs = self.ldm(prompt=prompt, **kwargs)
except Exception as e:
raise InferenceError(original_exception=e)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(outputs.images)
else:
has_nsfw_concept = [None] * len(output.images)
has_nsfw_concept = [None] * len(outputs.images)

return output.images, has_nsfw_concept
return outputs.images, has_nsfw_concept

def __str__(self) -> str:
    """Return a short description naming this pipeline and its ``model_id``."""
    return f"TextToImagePipeline model_id={self.model_id}"
6 changes: 5 additions & 1 deletion runner/app/routes/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.audio import AudioConversionError
from app.routes.util import HTTPError, TextResponse, file_exceeds_max_size, http_error
from app.routes.utils import HTTPError, TextResponse, file_exceeds_max_size, http_error
from app.utils.errors import InferenceError
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down Expand Up @@ -37,6 +38,9 @@ def handle_pipeline_error(e: Exception) -> JSONResponse:
) or isinstance(e, AudioConversionError):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
error_message = "Unsupported audio format or malformed file."
elif isinstance(e, InferenceError):
status_code = status.HTTP_400_BAD_REQUEST
error_message = str(e)
else:
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
error_message = "Internal server error during audio processing."
Expand Down
13 changes: 10 additions & 3 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url
from app.utils.errors import InferenceError
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down Expand Up @@ -154,15 +155,21 @@ async def image_to_image(
num_images_per_prompt=1,
num_inference_steps=num_inference_steps,
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
except Exception as e:
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
if isinstance(e, InferenceError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("ImageToImagePipeline error"),
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)

# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
Expand Down
9 changes: 8 additions & 1 deletion runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, VideoResponse, http_error, image_to_data_url
from app.routes.utils import HTTPError, VideoResponse, http_error, image_to_data_url
from app.utils.errors import InferenceError
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down Expand Up @@ -140,6 +141,12 @@ async def image_to_video(
except Exception as e:
logger.error(f"ImageToVideoPipeline error: {e}")
logger.exception(e)
if isinstance(e, InferenceError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("ImageToVideoPipeline error"),
Expand Down
9 changes: 2 additions & 7 deletions runner/app/routes/segment_anything_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,8 @@
import numpy as np
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import (
HTTPError,
InferenceError,
MasksResponse,
http_error,
json_str_to_np_array,
)
from app.routes.utils import HTTPError, MasksResponse, http_error, json_str_to_np_array
from app.utils.errors import InferenceError
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down
17 changes: 12 additions & 5 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url
from app.utils.errors import InferenceError
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down Expand Up @@ -142,19 +143,25 @@ async def text_to_image(
has_nsfw_concept = []
params.num_images_per_prompt = 1
for seed in seeds:
params.seed = seed
kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"}
try:
params.seed = seed
kwargs = {k: v for k, v in params.model_dump().items() if k != "model_id"}
imgs, nsfw_check = pipeline(**kwargs)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
if isinstance(e, InferenceError):
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)

return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("TextToImagePipeline error"),
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)

# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
Expand Down
2 changes: 1 addition & 1 deletion runner/app/routes/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from app.routes.utils import HTTPError, ImageResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
Expand Down
16 changes: 0 additions & 16 deletions runner/app/routes/util.py → runner/app/routes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,6 @@ class HTTPError(BaseModel):
detail: APIError = Field(..., description="Detailed error information.")


class InferenceError(Exception):
    """Exception raised for errors during model inference.

    Attributes:
        original_exception: The exception that triggered this error, or
            ``None`` when no underlying exception was supplied.
    """

    def __init__(self, message="Error during model execution", original_exception=None):
        """Initialize the exception.

        Args:
            message: The error message.
            original_exception: The original exception that caused the error.
                When provided, its string form is appended to ``message``.
        """
        # Compare against None explicitly: an exception class may define
        # __bool__/__len__ and evaluate as falsy, which would otherwise drop
        # its details from the final message.
        if original_exception is not None:
            message = f"{message}: {original_exception}"
        super().__init__(message)
        self.original_exception = original_exception


def http_error(msg: str) -> HTTPError:
"""Create an HTTP error response with the specified message.
Expand Down
Empty file added runner/app/utils/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions runner/app/utils/errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Custom exceptions for the application."""


class InferenceError(Exception):
    """Exception raised for errors during model inference.

    Attributes:
        original_exception: The exception that triggered this error, or
            ``None`` when no underlying exception was supplied.
    """

    def __init__(self, message="Error during model execution", original_exception=None):
        """Initialize the exception.

        Args:
            message: The error message.
            original_exception: The original exception that caused the error.
                When provided, its string form is appended to ``message``.
        """
        # Compare against None explicitly: an exception class may define
        # __bool__/__len__ and evaluate as falsy, which would otherwise drop
        # its details from the final message.
        if original_exception is not None:
            message = f"{message}: {original_exception}"
        super().__init__(message)
        self.original_exception = original_exception
4 changes: 2 additions & 2 deletions runner/gen_openapi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import argparse
import copy
import json
import logging
import subprocess

import yaml
from app.main import app, use_route_names_as_operation_ids
Expand All @@ -14,8 +16,6 @@
upscale,
)
from fastapi.openapi.utils import get_openapi
import subprocess
import logging

logging.basicConfig(
level=logging.INFO,
Expand Down

0 comments on commit f08aced

Please sign in to comment.