Skip to content

Commit

Permalink
runner: Separate load_pipeline + load_route
Browse files Browse the repository at this point in the history
  • Loading branch information
yondonfu committed Feb 13, 2024
1 parent d659bd8 commit df80da8
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 24 deletions.
49 changes: 27 additions & 22 deletions runner/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from contextlib import asynccontextmanager
import os
import logging
from typing import Any
from pydantic import BaseModel
from app.routes import health


Expand All @@ -20,9 +18,8 @@ async def lifespan(app: FastAPI):
pipeline = os.environ["PIPELINE"]
model_id = os.environ["MODEL_ID"]

config = load_pipeline(pipeline, model_id)
app.pipeline = config["pipeline"]
app.include_router(config["route"])
app.pipeline = load_pipeline(pipeline, model_id)
app.include_router(load_route(pipeline))

use_route_names_as_operation_ids(app)

Expand All @@ -31,32 +28,20 @@ async def lifespan(app: FastAPI):
logger.info("Shutting down")


class PipelineConfig(BaseModel):
pipeline: Any
route: Any


def load_pipeline(pipeline: str, model_id: str) -> PipelineConfig:
config = {}
def load_pipeline(pipeline: str, model_id: str) -> any:
match pipeline:
case "text-to-image":
from app.pipelines import TextToImagePipeline
from app.routes import text_to_image

config["pipeline"] = TextToImagePipeline(model_id)
config["route"] = text_to_image.router
return TextToImagePipeline(model_id)
case "image-to-image":
from app.pipelines import ImageToImagePipeline
from app.routes import image_to_image

config["pipeline"] = ImageToImagePipeline(model_id)
config["route"] = image_to_image.router
return ImageToImagePipeline(model_id)
case "image-to-video":
from app.pipelines import ImageToVideoPipeline
from app.routes import image_to_video

config["pipeline"] = ImageToVideoPipeline(model_id)
config["route"] = image_to_video.router
return ImageToVideoPipeline(model_id)
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
Expand All @@ -66,7 +51,27 @@ def load_pipeline(pipeline: str, model_id: str) -> PipelineConfig:
f"{pipeline} is not a valid pipeline for model {model_id}"
)

return config

def load_route(pipeline: str) -> any:
match pipeline:
case "text-to-image":
from app.routes import text_to_image

return text_to_image.router
case "image-to-image":
from app.routes import image_to_image

return image_to_image.router
case "image-to-video":
from app.routes import image_to_video

return image_to_video.router
case "frame-interpolation":
raise NotImplementedError("frame-interpolation pipeline not implemented")
case "upscale":
raise NotImplementedError("upscale pipeline not implemented")
case _:
raise EnvironmentError(f"{pipeline} is not a valid pipeline")


def config_logging():
Expand Down
3 changes: 1 addition & 2 deletions runner/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ def bench_pipeline(pipeline: Pipeline, batch_size=1, runs=1) -> BenchMetrics:
print(f"{args.pipeline=} {args.model_id=} {args.runs=} {args.batch_size=}")

start = time()
config = load_pipeline(args.pipeline, args.model_id)
pipeline = config["pipeline"]
pipeline = load_pipeline(args.pipeline, args.model_id)

# Collect pipeline load metrics
load_time = time() - start
Expand Down

0 comments on commit df80da8

Please sign in to comment.