Skip to content

Commit

Permalink
refactor: improve error codes and code formatting
Browse files Browse the repository at this point in the history
This commit applies several code formatting improvements and replaces
the hardcoded error codes by the error codes in the FastAPI status
module.
  • Loading branch information
rickstaa committed Jul 16, 2024
1 parent 9fc476e commit a48c153
Show file tree
Hide file tree
Showing 17 changed files with 266 additions and 136 deletions.
2 changes: 1 addition & 1 deletion runner/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from fastapi import Request
from app.pipelines.base import Pipeline
from fastapi import Request


def get_pipeline(request: Request) -> Pipeline:
Expand Down
11 changes: 6 additions & 5 deletions runner/app/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from fastapi import FastAPI
from fastapi.routing import APIRoute
from contextlib import asynccontextmanager
import os
import logging
from app.routes import health
import os
from contextlib import asynccontextmanager

from app.routes import health
from fastapi import FastAPI
from fastapi.routing import APIRoute

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,6 +50,7 @@ def load_pipeline(pipeline: str, model_id: str) -> any:
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
from app.pipelines.upscale import UpscalePipeline

return UpscalePipeline(model_id)
case _:
raise EnvironmentError(
Expand Down
27 changes: 13 additions & 14 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,29 @@
import logging
import os
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
get_torch_device,
get_model_dir,
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from enum import Enum

from diffusers import (
AutoPipelineForImage2Image,
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os

from PIL import ImageFile
from safetensors.torch import load_file

ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand Down
15 changes: 7 additions & 8 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir, SafetyChecker

from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os
import time
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import SafetyChecker, get_model_dir, get_torch_device
from diffusers import StableVideoDiffusionPipeline
from huggingface_hub import file_download
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
Expand Down
4 changes: 3 additions & 1 deletion runner/app/pipelines/optim/deepcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
For more information, see the DeepCache project on GitHub: https://github.com/horseee/DeepCache
"""
from DeepCache import DeepCacheSDHelper

import logging

from DeepCache import DeepCacheSDHelper

logger = logging.getLogger(__name__)


Expand Down
4 changes: 3 additions & 1 deletion runner/app/pipelines/optim/sfast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
For more information, see the DeepCache project on GitHub: https://github.com/chengzeyi/stable-fast
"""
from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

import logging

from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig, compile

logger = logging.getLogger(__name__)


Expand Down
21 changes: 10 additions & 11 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
import logging
import os
from typing import List, Tuple, Optional
from enum import Enum
from typing import List, Optional, Tuple

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
StableDiffusion3Pipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
StableDiffusion3Pipeline,
)
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.utils import (
get_model_dir,
get_torch_device,
SafetyChecker,
is_lightning_model,
is_turbo_model,
)

logger = logging.getLogger(__name__)


Expand Down
33 changes: 16 additions & 17 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from app.pipelines.base import Pipeline
from app.pipelines.utils import get_torch_device, get_model_dir, SafetyChecker, is_lightning_model, is_turbo_model

from diffusers import (
StableDiffusionUpscalePipeline
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
import torch
import PIL
from typing import List, Tuple, Optional
import logging
import os
from typing import List, Optional, Tuple

from PIL import ImageFile
from PIL import Image
from io import BytesIO
import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import StableDiffusionUpscalePipeline
from huggingface_hub import file_download
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

Expand Down Expand Up @@ -44,8 +43,8 @@ def __init__(self, model_id: str):
kwargs["variant"] = "fp16"

self.ldm = StableDiffusionUpscalePipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
model_id, **kwargs
).to(torch_device)

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(self, model_id: str):
self._safety_checker = SafetyChecker(device=safety_checker_device)

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)
Expand Down
2 changes: 0 additions & 2 deletions runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import logging
import os
import re
import tempfile
import uuid
from pathlib import Path
from typing import Optional

Expand Down
37 changes: 21 additions & 16 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,37 @@
from fastapi import Depends, APIRouter, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.pipelines.base import Pipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
from PIL import Image
from typing import Annotated
import logging
import random
import os
import random
from typing import Annotated

from PIL import ImageFile
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}

RESPONSES = {
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-image", response_model=ImageResponse, responses=responses)
@router.post("/image-to-image", response_model=ImageResponse, responses=RESPONSES)
@router.post(
"/image-to-image/",
response_model=ImageResponse,
responses=responses,
responses=RESPONSES,
include_in_schema=False,
)
async def image_to_image(
Expand All @@ -48,14 +52,14 @@ async def image_to_image(
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=401,
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=400,
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
Expand Down Expand Up @@ -90,7 +94,8 @@ async def image_to_image(
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("ImageToImagePipeline error"),
)

# TODO: Return None once Go codegen tool supports optional properties
Expand Down
38 changes: 21 additions & 17 deletions runner/app/routes/image_to_video.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,36 @@
from fastapi import Depends, APIRouter, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from app.pipelines.base import Pipeline
from app.dependencies import get_pipeline
from app.routes.util import image_to_data_url, VideoResponse, HTTPError, http_error
from PIL import Image
from typing import Annotated
import logging
import random
import os
import random
from typing import Annotated

from PIL import ImageFile
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.routes.util import HTTPError, VideoResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

router = APIRouter()

logger = logging.getLogger(__name__)

responses = {400: {"model": HTTPError}, 500: {"model": HTTPError}}
RESPONSES = {
status.HTTP_400_BAD_REQUEST: {"model": HTTPError},
status.HTTP_401_UNAUTHORIZED: {"model": HTTPError},
status.HTTP_500_INTERNAL_SERVER_ERROR: {"model": HTTPError},
}


# TODO: Make model_id and other None properties optional once Go codegen tool supports
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
@router.post("/image-to-video", response_model=VideoResponse, responses=responses)
@router.post("/image-to-video", response_model=VideoResponse, responses=RESPONSES)
@router.post(
"/image-to-video/",
response_model=VideoResponse,
responses=responses,
responses=RESPONSES,
include_in_schema=False,
)
async def image_to_video(
Expand All @@ -47,14 +50,14 @@ async def image_to_video(
if auth_token:
if not token or token.credentials != auth_token:
return JSONResponse(
status_code=401,
status_code=status.HTTP_401_UNAUTHORIZED,
headers={"WWW-Authenticate": "Bearer"},
content=http_error("Invalid bearer token"),
)

if model_id != "" and model_id != pipeline.model_id:
return JSONResponse(
status_code=400,
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"pipeline configured with {pipeline.model_id} but called with "
f"{model_id}"
Expand All @@ -63,7 +66,7 @@ async def image_to_video(

if height % 8 != 0 or width % 8 != 0:
return JSONResponse(
status_code=400,
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(
f"`height` and `width` have to be divisible by 8 but are {height} and "
f"{width}."
Expand All @@ -88,7 +91,8 @@ async def image_to_video(
logger.error(f"ImageToVideoPipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToVideoPipeline error")
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content=http_error("ImageToVideoPipeline error"),
)

output_frames = []
Expand Down
Loading

0 comments on commit a48c153

Please sign in to comment.