Skip to content

Commit

Permalink
feat(I2I): add Pix2Pix model support (#94)
Browse files Browse the repository at this point in the history
* Add support for the `timbrooks/instruct-pix2pix` model

* Remove text-to-vid model

* Tweak pix2pix

* Do not randomize guidance scale

* Make image guidance scale a param

* Improve the image-to-image processing by enabling the timbrooks/instruct-pix2pix model.

* merge all strong changes for image-to-image improvements

* minor update

* regenerated runner.gen.go

* refactor(I2I): improve pix2pix pipeline initialization

This commit removes arguments that are not used during the
initialization of the pix2pix pipeline. It also cleans up the codebase a
bit and changes the default to improve transparency.

* chore: update openapi spec

This commit updates the OpenAPI spec and removes the openapi.yaml file
since it is not directly used.

---------

Co-authored-by: Marco van Dijk <[email protected]>
Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
3 people authored Jun 4, 2024
1 parent d71ac95 commit f549695
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 384 deletions.
35 changes: 30 additions & 5 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,15 @@
is_lightning_model,
is_turbo_model,
)
from enum import Enum

from diffusers import (
AutoPipelineForImage2Image,
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
Expand All @@ -27,7 +30,17 @@

logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"

class ModelName(Enum):
    """Enumeration mapping model names to their corresponding Hugging Face IDs."""

    SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
    INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"

    @classmethod
    def list(cls):
        """Return a list of all model IDs.

        Returns:
            A list of the Hugging Face model ID strings for every member.
        """
        # A comprehension is the idiomatic equivalent of list(map(lambda ...)).
        return [model.value for model in cls]


class ImageToImagePipeline(Pipeline):
Expand All @@ -45,7 +58,7 @@ def __init__(self, model_id: str):
for _, _, files in os.walk(folder_path)
for fname in files
)
or SDXL_LIGHTNING_MODEL_ID in model_id
or ModelName.SDXL_LIGHTNING.value in model_id
)
if torch_device != "cpu" and has_fp16_variant:
logger.info("ImageToImagePipeline loading fp16 variant for %s", model_id)
Expand All @@ -56,7 +69,7 @@ def __init__(self, model_id: str):
self.model_id = model_id

# Special case SDXL-Lightning because the unet for SDXL needs to be swapped
if SDXL_LIGHTNING_MODEL_ID in model_id:
if ModelName.SDXL_LIGHTNING.value in model_id:
base = "stabilityai/stable-diffusion-xl-base-1.0"

# ByteDance/SDXL-Lightning-2step
Expand All @@ -78,7 +91,7 @@ def __init__(self, model_id: str):
unet.load_state_dict(
load_file(
hf_hub_download(
SDXL_LIGHTNING_MODEL_ID,
ModelName.SDXL_LIGHTNING.value,
f"{unet_id}.safetensors",
cache_dir=kwargs["cache_dir"],
),
Expand All @@ -93,6 +106,14 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif ModelName.INSTRUCT_PIX2PIX.value in model_id:
self.ldm = StableDiffusionInstructPix2PixPipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)

self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
self.ldm.scheduler.config
)
else:
self.ldm = AutoPipelineForImage2Image.from_pretrained(
model_id, **kwargs
Expand Down Expand Up @@ -176,7 +197,7 @@ def __call__(

if "num_inference_steps" not in kwargs:
kwargs["num_inference_steps"] = 2
elif SDXL_LIGHTNING_MODEL_ID in self.model_id:
elif ModelName.SDXL_LIGHTNING.value in self.model_id:
# SDXL-Lightning models should have guidance_scale = 0 and use
# the correct number of inference steps for the unet checkpoint loaded
kwargs["guidance_scale"] = 0.0
Expand All @@ -190,6 +211,10 @@ def __call__(
else:
# Default to 2step
kwargs["num_inference_steps"] = 2
elif ModelName.INSTRUCT_PIX2PIX.value in self.model_id:
if "num_inference_steps" not in kwargs:
# TODO: Currently set to recommended value make configurable later.
kwargs["num_inference_steps"] = 10

output = self.ldm(prompt, image=image, **kwargs)

Expand Down
2 changes: 2 additions & 0 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def image_to_image(
model_id: Annotated[str, Form()] = "",
strength: Annotated[float, Form()] = 0.8,
guidance_scale: Annotated[float, Form()] = 7.5,
image_guidance_scale: Annotated[float, Form()] = 1.5,
negative_prompt: Annotated[str, Form()] = "",
safety_check: Annotated[bool, Form()] = True,
seed: Annotated[int, Form()] = None,
Expand Down Expand Up @@ -82,6 +83,7 @@ async def image_to_image(
image=image,
strength=strength,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
negative_prompt=negative_prompt,
safety_check=safety_check,
seed=seed,
Expand Down
4 changes: 3 additions & 1 deletion runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ if [ "$MODE" = "alpha" ]; then

# Download text-to-image and image-to-image models.
huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models

huggingface-cli download timbrooks/instruct-pix2pix --include "*fp16.safetensors" --exclude "*lora.safetensors*" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
check_hf_auth
Expand All @@ -78,6 +79,7 @@ else
huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0_Lightning --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download timbrooks/instruct-pix2pix --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models/

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
Expand Down
5 changes: 5 additions & 0 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@
"title": "Guidance Scale",
"default": 7.5
},
"image_guidance_scale": {
"type": "number",
"title": "Image Guidance Scale",
"default": 1.5
},
"negative_prompt": {
"type": "string",
"title": "Negative Prompt",
Expand Down
Loading

0 comments on commit f549695

Please sign in to comment.