Skip to content

Commit

Permalink
refactor: improve error codes and code formatting
Browse files Browse the repository at this point in the history
This commit applies several code formatting improvements and replaces
the hardcoded error codes by the error codes in the FastAPI status
module.
  • Loading branch information
rickstaa committed Jul 16, 2024
1 parent 0d03040 commit cf59a6e
Show file tree
Hide file tree
Showing 18 changed files with 274 additions and 166 deletions.
2 changes: 1 addition & 1 deletion runner/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi import Request
from app.pipelines.base import Pipeline
from fastapi import Request


def get_pipeline(request: Request) -> Pipeline:
Expand Down
11 changes: 6 additions & 5 deletions runner/app/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from fastapi import FastAPI
from fastapi.routing import APIRoute
from contextlib import asynccontextmanager
import os
import logging
from app.routes import health
import os
from contextlib import asynccontextmanager

from app.routes import health
from fastapi import FastAPI
from fastapi.routing import APIRoute

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -46,6 +46,7 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case _:
raise EnvironmentError(
Expand Down
27 changes: 13 additions & 14 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
import logging
import os
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.util import (
get_torch_device,
get_model_dir,
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from enum import Enum

from diffusers import (
AutoPipelineForImage2Image,
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os

from PIL import ImageFile
from safetensors.torch import load_file

ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand Down
15 changes: 7 additions & 8 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker

from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os
import time
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.util import SafetyChecker, get_model_dir, get_torch_device
from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
Expand Down
4 changes: 3 additions & 1 deletion runner/app/pipelines/optim/deepcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
For more information, see the DeepCache project on GitHub: https://github.com/horseee/DeepCache
"""
from DeepCache import DeepCacheSDHelper

import logging

from DeepCache import DeepCacheSDHelper

logger = logging.getLogger(__name__)


Expand Down
4 changes: 3 additions & 1 deletion runner/app/pipelines/optim/sfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
For more information, see the DeepCache project on GitHub: https://github.com/chengzeyi/stable-fast
"""
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

import logging

from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile

logger = logging.getLogger(__name__)


Expand Down
21 changes: 10 additions & 11 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import logging
import os
from typing import List, Tuple, Optional
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.util import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
StableDiffusion3Pipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
StableDiffusion3Pipeline,
)
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.util import (
get_model_dir,
get_torch_device,
SafetyChecker,
is_lightning_model,
is_turbo_model,
)

logger = logging.getLogger(__name__)


Expand Down
33 changes: 16 additions & 17 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from app.pipelines.base import Pipeline
from app.pipelines.util import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model

from diffusers import (
StableDiffusionUpscalePipeline
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os
from typing import List, Optional, Tuple

from PIL import ImageFile
from PIL import Image
from io import BytesIO
import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.util import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand Down Expand Up @@ -44,8 +43,8 @@ def __init__(self, model_id: str):
kwargs["variant"] = "fp16"

self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
model_id, **kwargs
).to(torch_device)

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(self, model_id: str):
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)
Expand Down
15 changes: 8 additions & 7 deletions runner/app/pipelines/util.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import torch
import logging
import os
import numpy as np
from torch import dtype as TorchDtype
import re
from pathlib import Path
from PIL import Image
from typing import Optional

import numpy as np
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torch import dtype as TorchDtype
from transformers import CLIPFeatureExtractor
from typing import Optional
import logging
import re

logger = logging.getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions runner/app/pipelines/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""This module contains several utility functions that are used across the pipelines module."""

from app.pipelines.utils.utils import (
SafetyChecker,
get_model_dir,
get_model_path,
get_temp_file,
get_torch_device,
validate_torch_device,
is_lightning_model,
is_turbo_model,
get_temp_file,
SafetyChecker,
validate_torch_device,
)
37 changes: 21 additions & 16 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
from fastapi import Depends, APIRouter, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.pipelines.base import Pipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
from PIL import Image
from typing import Annotated
import logging
import random
import os
import random
from typing import Annotated

from PIL import ImageFile
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}

RESPONSES = {
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-image", response_model=ImageResponse, responses=responses)
@router.post("/image-to-image", response_model=ImageResponse, responses=RESPONSES)
@router.post(
"/image-to-image/",
response_model=ImageResponse,
responses=responses,
responses=RESPONSES,
include_in_schema=False,
)
async def image_to_image(
Expand All @@ -48,14 +52,14 @@ async def image_to_image(
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=401,
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=400,
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
Expand Down Expand Up @@ -90,7 +94,8 @@ async def image_to_image(
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("ImageToImagePipeline error"),
)

# TODO: Return None once Go codegen tool supports optional properties
Expand Down
Loading

0 comments on commit cf59a6e

Please sign in to comment.