feat: add LoRA support to the txt2img and img2img pipelines (#119)
Add LoRA support to the `txt2img` and `img2img` routes, optimize loading and error handling, and introduce a memory-cleanup mechanism.

Co-authored-by: Elite Encoder <[email protected]>
Co-authored-by: Rick Staa <[email protected]>
3 people authored Sep 22, 2024
1 parent e38db92 commit bcd929d
Showing 10 changed files with 346 additions and 53 deletions.
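
For orientation, here is a minimal client sketch of the new `loras` field on the text-to-image route. The host, port, and route path are assumptions inferred from the FastAPI modules in this commit, and the adapter IDs are the examples from the parameter description below.

```python
# Hedged client sketch; the endpoint URL (POST /text-to-image) is an assumption.
import json

import requests

payload = {
    "prompt": "pixel art castle, golden hour",
    # `loras` is a JSON *string* mapping a LoRA repository to its strength
    # (a float >= 0.0); at most LORA_LIMIT (4) adapters per request.
    "loras": json.dumps(
        {"latent-consistency/lcm-lora-sdxl": 1.0, "nerijs/pixel-art-xl": 1.2}
    ),
}

resp = requests.post("http://localhost:8000/text-to-image", json=payload)
resp.raise_for_status()
print(resp.json())  # image response with generated images and NSFW flags
```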
10 changes: 10 additions & 0 deletions runner/app/pipelines/image_to_image.py
@@ -7,6 +7,7 @@
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
LoraLoader,
SafetyChecker,
get_model_dir,
get_torch_device,
@@ -172,11 +173,14 @@ def __init__(self, model_id: str):
safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

self._lora_loader = LoraLoader(self.ldm)

def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)
loras_json = kwargs.pop("loras", "")

if seed is not None:
if isinstance(seed, int):
@@ -188,6 +192,12 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

# Dynamically (un)load LoRas.
if not loras_json:
self._lora_loader.disable_loras()
else:
self._lora_loader.load_loras(loras_json)

if "num_inference_steps" in kwargs and (
kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
):
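
The pipeline-level change above pops a `loras` kwarg before inference and routes it through `LoraLoader`. A rough sketch of direct invocation, with an illustrative `model_id` and file names (not prescribed by this commit):

```python
# Direct pipeline call sketch; model_id and file names are illustrative.
from PIL import Image

from app.pipelines.image_to_image import ImageToImagePipeline

pipeline = ImageToImagePipeline(model_id="timbrooks/instruct-pix2pix")
images, nsfw_flags = pipeline(
    prompt="make it look like winter",
    image=Image.open("input.png").convert("RGB"),
    loras='{"nerijs/pixel-art-xl": 1.2}',  # an empty string disables LoRAs
)
images[0].save("output.png")
```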
10 changes: 10 additions & 0 deletions runner/app/pipelines/text_to_image.py
@@ -7,6 +7,7 @@
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
LoraLoader,
SafetyChecker,
get_model_dir,
get_torch_device,
@@ -202,11 +203,14 @@ def __init__(self, model_id: str):
safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
self._safety_checker = SafetyChecker(device=safety_checker_device)

self._lora_loader = LoraLoader(self.ldm)

def __call__(
self, prompt: str, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)
loras_json = kwargs.pop("loras", "")

if seed is not None:
if isinstance(seed, int):
@@ -218,6 +222,12 @@ def __call__(
torch.Generator(get_torch_device()).manual_seed(s) for s in seed
]

# Dynamically (un)load LoRas.
if not loras_json:
self._lora_loader.disable_loras()
else:
self._lora_loader.load_loras(loras_json)

if "num_inference_steps" in kwargs and (
kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
):
1 change: 1 addition & 0 deletions runner/app/pipelines/utils/__init__.py
@@ -3,6 +3,7 @@
"""

from app.pipelines.utils.utils import (
LoraLoader,
SafetyChecker,
get_model_dir,
get_model_path,
191 changes: 190 additions & 1 deletion runner/app/pipelines/utils/utils.py
@@ -1,20 +1,26 @@
"""This module contains several utility functions."""

import json
import logging
import os
import re
from pathlib import Path
from typing import Dict, Optional
from typing import Any, Dict, List, Optional

import numpy as np
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from PIL import Image
from torch import dtype as TorchDtype
from transformers import CLIPImageProcessor

logger = logging.getLogger(__name__)

LORA_LIMIT = 4 # Max number of LoRas that can be requested at once.
LORA_MAX_LOADED = 12 # Number of LoRas to keep in memory.
LORA_FREE_VRAM_THRESHOLD = 2.0 # VRAM threshold (GB) to start evicting LoRas.


def get_model_dir() -> Path:
return Path(os.environ["MODEL_DIR"])
@@ -175,3 +181,186 @@ def check_nsfw_images(
clip_input=safety_checker_input.pixel_values.to(self._dtype),
)
return images, has_nsfw_concept


def is_numeric(val: Any) -> bool:
"""Check if the given value is numeric.
Args:
val: Value to check.
Returns:
True if the value is numeric, False otherwise.
"""
try:
float(val)
return True
except (ValueError, TypeError):
return False


class LoraLoadingError(Exception):
"""Exception raised for errors during LoRa loading."""

def __init__(self, message="Error loading LoRas", original_exception=None):
"""Initialize the exception.
Args:
message: The error message.
original_exception: The original exception that caused the error.
"""
if original_exception:
message = f"{message}: {original_exception}"
super().__init__(message)
self.original_exception = original_exception


class LoraLoader:
"""Utility class to load LoRas and set their weights into a given pipeline.
Attributes:
pipeline: Diffusion pipeline on which the LoRas are loaded.
loras_enabled: Flag to enable or disable LoRas.
"""

def __init__(self, pipeline: DiffusionPipeline):
"""Initializes the LoraLoader.
Args:
pipeline: Diffusion pipeline to load LoRas into.
"""
self.pipeline = pipeline
self.loras_enabled = False

def _get_loaded_loras(self) -> List[str]:
"""Returns the names of the loaded LoRas.
Returns:
List of loaded LoRa names.
"""
loaded_loras_dict = self.pipeline.get_list_adapters()
seen = set()
return [
lora
for loras in loaded_loras_dict.values()
for lora in loras
if lora not in seen and not seen.add(lora)
]

def _evict_loras_if_needed(self, request_loras: List[str]) -> None:
"""Evict the oldest unused LoRa until free memory is above the threshold or the
number of loaded LoRas is below the maximum allowed.
Args:
request_loras: Names of the LoRas in the current request.
"""
while True:
free_memory_gb = (
torch.cuda.mem_get_info(device=self.pipeline.device)[0] / 1024**3
)
loaded_loras = self._get_loaded_loras()
memory_limit_reached = free_memory_gb < LORA_FREE_VRAM_THRESHOLD

# Break if memory is sufficient, LoRas within limit, or no LoRas to evict.
if (
not memory_limit_reached
and len(loaded_loras) < LORA_MAX_LOADED
or not any(lora not in request_loras for lora in loaded_loras)
):
break

# Evict the oldest unused LoRa.
for lora in loaded_loras:
if lora not in request_loras:
self.pipeline.delete_adapters(lora)
break
if memory_limit_reached:
torch.cuda.empty_cache()

def load_loras(self, loras_json: str) -> None:
"""Loads LoRas and sets their weights into the pipeline managed by this
LoraLoader.
Args:
loras_json: A JSON string containing key-value pairs, where the key is the
repository to load LoRas from and the value is the strength (a float
with a minimum value of 0.0) to assign to the LoRa.
Raises:
LoraLoadingError: If an error occurs during LoRa loading.
"""
try:
lora_dict = json.loads(loras_json)
except json.JSONDecodeError:
error_message = f"Unable to parse '{loras_json}' as JSON."
logger.warning(error_message)
raise LoraLoadingError(error_message)

# Parse Lora strengths and check for invalid values.
invalid_loras = {
adapter: val
for adapter, val in lora_dict.items()
if not is_numeric(val) or float(val) < 0.0
}
if invalid_loras:
error_message = (
"All strengths must be numbers greater than or equal to 0.0."
)
logger.warning(error_message)
raise LoraLoadingError(error_message)
lora_dict = {adapter: float(val) for adapter, val in lora_dict.items()}

# Disable LoRas if none are provided.
if not lora_dict:
self.disable_loras()
return

# Limit the number of active loras to prevent pipeline slowdown.
if len(lora_dict) > LORA_LIMIT:
raise LoraLoadingError(f"Too many LoRas provided. Maximum is {LORA_LIMIT}.")

# Re-enable LoRas if they were disabled.
self.enable_loras()

# Load new LoRa adapters.
loaded_loras = self._get_loaded_loras()
try:
for adapter in lora_dict.keys():
# Load new Lora weights and evict the oldest unused Lora if necessary.
if adapter not in loaded_loras:
self.pipeline.load_lora_weights(adapter, adapter_name=adapter)
self._evict_loras_if_needed(list(lora_dict.keys()))
except Exception as e:
# Delete failed adapter and log the error.
self.pipeline.delete_adapters(adapter)
torch.cuda.empty_cache()
if "not found in the base model" in str(e):
error_message = (
"LoRa incompatible with base model: "
f"'{self.pipeline.name_or_path}'"
)
elif getattr(e, "server_message", "") == "Repository not found":
error_message = f"LoRa repository '{adapter}' not found"
else:
error_message = f"Unable to load LoRas for adapter '{adapter}'"
logger.exception(e)
raise LoraLoadingError(error_message)

# Set unused LoRas strengths to 0.0.
for lora in loaded_loras:
if lora not in lora_dict:
lora_dict[lora] = 0.0

# Set the lora adapter strengths.
self.pipeline.set_adapters(*map(list, zip(*lora_dict.items())))

def disable_loras(self) -> None:
"""Disables all LoRas in the pipeline."""
if self.loras_enabled:
self.pipeline.disable_lora()
self.loras_enabled = False

def enable_loras(self) -> None:
"""Enables all LoRas in the pipeline."""
if not self.loras_enabled:
self.pipeline.enable_lora()
self.loras_enabled = True
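
`LoraLoader` can also be driven on its own. A hedged sketch, assuming the pipeline is a standard `diffusers` checkpoint with LoRA support and a CUDA device is available (the base model and adapter choices are illustrative):

```python
# Standalone LoraLoader sketch; base model and adapter are illustrative.
import torch
from diffusers import DiffusionPipeline

from app.pipelines.utils import LoraLoader

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
loader = LoraLoader(pipe)

# Load an adapter with a strength; repeated calls reuse cached adapters and
# evict the oldest unused ones when VRAM or the adapter budget runs low.
loader.load_loras('{"nerijs/pixel-art-xl": 1.2}')
image = pipe(prompt="pixel art castle").images[0]

# Disabling leaves adapters cached but inactive; a later load re-enables them.
loader.disable_loras()
```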
21 changes: 21 additions & 0 deletions runner/app/routes/image_to_image.py
@@ -3,8 +3,10 @@
import random
from typing import Annotated

import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.utils import LoraLoadingError
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends, File, Form, UploadFile, status
from fastapi.responses import JSONResponse
@@ -65,6 +67,16 @@ async def image_to_image(
str,
Form(description="Hugging Face model ID used for image generation."),
] = "",
loras: Annotated[
str,
Form(
description=(
"A LoRA (Low-Rank Adaptation) model and its corresponding weight for "
'image generation. Example: { "latent-consistency/lcm-lora-sdxl": '
'1.0, "nerijs/pixel-art-xl": 1.2}.'
)
),
] = "",
strength: Annotated[
float,
Form(
@@ -159,6 +171,7 @@ async def image_to_image(
prompt=prompt,
image=image,
strength=strength,
loras=loras,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
negative_prompt=negative_prompt,
@@ -169,7 +182,15 @@
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
except LoraLoadingError as e:
logger.error(f"ImageToImagePipeline error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
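
Because the image-to-image route takes multipart form data, the new `loras` value travels as a plain form field alongside the uploaded image. A hedged client sketch (endpoint URL assumed, as before):

```python
# Multipart request sketch; the endpoint URL (POST /image-to-image) is an assumption.
import requests

with open("input.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/image-to-image",
        files={"image": ("input.png", f, "image/png")},
        data={
            "prompt": "pixel art rendition of the input",
            "strength": 0.7,
            "loras": '{"nerijs/pixel-art-xl": 1.2}',
        },
    )
print(resp.status_code, resp.json())
```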
21 changes: 21 additions & 0 deletions runner/app/routes/text_to_image.py
@@ -3,8 +3,10 @@
import random
from typing import Annotated

import torch
from app.dependencies import get_pipeline
from app.pipelines.base import Pipeline
from app.pipelines.utils.utils import LoraLoadingError
from app.routes.util import HTTPError, ImageResponse, http_error, image_to_data_url
from fastapi import APIRouter, Depends, status
from fastapi.responses import JSONResponse
@@ -25,6 +27,17 @@ class TextToImageParams(BaseModel):
default="", description="Hugging Face model ID used for image generation."
),
]
loras: Annotated[
str,
Field(
default="",
description=(
"A LoRA (Low-Rank Adaptation) model and its corresponding weight for "
'image generation. Example: { "latent-consistency/lcm-lora-sdxl": '
'1.0, "nerijs/pixel-art-xl": 1.2}.'
),
),
]
prompt: Annotated[
str,
Field(
@@ -161,7 +174,15 @@ async def text_to_image(
imgs, nsfw_check = pipeline(**kwargs)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
except LoraLoadingError as e:
logger.error(f"TextToImagePipeline error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content=http_error(str(e)),
)
except Exception as e:
if isinstance(e, torch.cuda.OutOfMemoryError):
torch.cuda.empty_cache()
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
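
With the new handler, invalid `loras` input (unparseable JSON, negative strengths, or more than `LORA_LIMIT` adapters) surfaces as a 400 instead of a generic 500. A sketch of the expected behavior, assuming `http_error` wraps the message in a `detail` object and using hypothetical adapter names:

```python
# Hypothetical failure case: five adapters exceeds LORA_LIMIT (4).
import requests

resp = requests.post(
    "http://localhost:8000/text-to-image",
    json={
        "prompt": "test",
        "loras": '{"a/1": 1, "b/2": 1, "c/3": 1, "d/4": 1, "e/5": 1}',
    },
)
assert resp.status_code == 400
print(resp.json())  # e.g. {"detail": {"msg": "Too many LoRas provided. Maximum is 4."}}
```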