Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve image to image generation - Add Pix2Pix model support #94

Merged
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion runner/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ RUN pyenv install $PYTHON_VERSION && \

# Upgrade pip and install your desired packages
ARG PIP_VERSION=23.3.2

# Pin pip, setuptools, and wheel to fixed versions; the setuptools pin is needed for the INSTRUCT_PIX2PIX model to run properly.
# These specific dependency versions are critical to the models, so change them with care.
RUN pip install --no-cache-dir --upgrade pip==${PIP_VERSION} setuptools==69.5.1 wheel==0.43.0 && \
pip install --no-cache-dir torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1
pip install --no-cache-dir torch==2.1.1 torchvision==0.16.1 torchaudio==2.1.1

WORKDIR /app
COPY ./requirements.txt /app
Expand Down
16 changes: 16 additions & 0 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
Expand All @@ -14,6 +16,7 @@
from typing import List, Tuple, Optional
import logging
import os
import random

from PIL import ImageFile

Expand All @@ -23,6 +26,8 @@

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"

# https://huggingface.co/timbrooks/instruct-pix2pix
INSTRUCT_PIX2PIX_MODEL_ID = "timbrooks/instruct-pix2pix"

class ImageToImagePipeline(Pipeline):
def __init__(self, model_id: str):
Expand Down Expand Up @@ -87,6 +92,17 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif INSTRUCT_PIX2PIX_MODEL_ID in model_id:
if "image_guidance_scale" not in kwargs:
kwargs["image_guidance_scale"] = round(random.uniform(1.2, 1.8), ndigits=2)
if "num_inference_steps" not in kwargs:
kwargs["num_inference_steps"] = 10
# Initialize the pipeline for the InstructPix2Pix model
self.ldm = StableDiffusionInstructPix2PixPipeline.from_pretrained(
model_id, **kwargs
).to(torch_device)
# Assign the scheduler for the InstructPix2Pix model
self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(self.ldm.scheduler.config)
else:
self.ldm = AutoPipelineForImage2Image.from_pretrained(
model_id, **kwargs
Expand Down
2 changes: 2 additions & 0 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def image_to_image(
model_id: Annotated[str, Form()] = "",
strength: Annotated[float, Form()] = 0.8,
guidance_scale: Annotated[float, Form()] = 7.5,
image_guidance_scale: Annotated[float, Form()] = 0,
rickstaa marked this conversation as resolved.
Show resolved Hide resolved
negative_prompt: Annotated[str, Form()] = "",
safety_check: Annotated[bool, Form()] = True,
seed: Annotated[int, Form()] = None,
Expand Down Expand Up @@ -82,6 +83,7 @@ async def image_to_image(
image=image,
strength=strength,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
negative_prompt=negative_prompt,
safety_check=safety_check,
seed=seed,
Expand Down
4 changes: 3 additions & 1 deletion runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ if [ "$MODE" = "alpha" ]; then

# Download text-to-image and image-to-image models.
huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models

huggingface-cli download timbrooks/instruct-pix2pix --include "*fp16.safetensors" --exclude "*lora.safetensors*" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
check_hf_auth
Expand All @@ -78,6 +79,7 @@ else
huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0_Lightning --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models
huggingface-cli download timbrooks/instruct-pix2pix --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models/

# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models
Expand Down
5 changes: 5 additions & 0 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,11 @@
"title": "Guidance Scale",
"default": 7.5
},
"image_guidance_scale": {
"type": "number",
"title": "Image Guidance Scale",
"default": 1
},
"negative_prompt": {
"type": "string",
"title": "Negative Prompt",
Expand Down
4 changes: 4 additions & 0 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ components:
type: number
title: Guidance Scale
default: 7.5
image_guidance_scale:
type: number
title: Image Guidance Scale
default: 1
negative_prompt:
type: string
title: Negative Prompt
Expand Down
1 change: 1 addition & 0 deletions runner/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ xformers==0.0.23
triton>=2.1.0
peft==0.11.1
deepcache==0.1.1
safetensors==0.4.3
43 changes: 22 additions & 21 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.