Skip to content

Commit

Permalink
Experiment with combined base+refiner for SDXL
Browse files: browse the repository at this point in the history
  • Loading branch information
stronk-dev committed Apr 12, 2024
1 parent 3c297ca commit 8ab28a4
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
25 changes: 25 additions & 0 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
StableDiffusionXLPipeline,
UNet2DConditionModel,
EulerDiscreteScheduler,
StableDiffusionXLImg2ImgPipeline
)
from safetensors.torch import load_file
from huggingface_hub import file_download, hf_hub_download
Expand All @@ -18,6 +19,7 @@
logger = logging.getLogger(__name__)

SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning"
SDXL_BASE_MODEL_ID = "stabilityai/stable-diffusion-xl-base-1.0"


class TextToImagePipeline(Pipeline):
Expand Down Expand Up @@ -90,6 +92,20 @@ def __init__(self, model_id: str):
self.ldm.scheduler = EulerDiscreteScheduler.from_config(
self.ldm.scheduler.config, timestep_spacing="trailing"
)
elif SDXL_BASE_MODEL_ID in self.model_id:
kwargs["torch_dtype"] = torch.float16
kwargs["variant"] = "fp16"
kwargs["use_safetensors"] = True
self.ldm = StableDiffusionXLPipeline.from_pretrained(model_id, **kwargs).to("cuda")
self.refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.ldm.text_encoder_2,
vae=self.ldm.vae,
torch_dtype=kwargs["torch_dtype"],
use_safetensors=True,
variant=kwargs["variant"],
).to("cuda")

else:
self.ldm = AutoPipelineForText2Image.from_pretrained(model_id, **kwargs).to(
torch_device
Expand Down Expand Up @@ -156,6 +172,15 @@ def __call__(self, prompt: str, **kwargs) -> List[PIL.Image]:
else:
# Default to 2step
kwargs["num_inference_steps"] = 2
elif SDXL_BASE_MODEL_ID in self.model_id:
kwargs["num_inference_steps"] = 40
image = self.ldm(prompt=prompt,
num_inference_steps=40,
denoising_end=0.8,
output_type="latent",).images
kwargs["image"] = image
kwargs["denoising_start"] = 0.8
return self.refiner(prompt, **kwargs).images

return self.ldm(prompt, **kwargs).images

Expand Down
3 changes: 3 additions & 0 deletions runner/dl_checkpoints.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ else
# Download image-to-video models.
huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download the SDXL refiner model (fp16 safetensors weights and JSON configs).
huggingface-cli download stabilityai/stable-diffusion-xl-refiner-1.0 --include "*.fp16.safetensors" "*.json" --cache-dir models

# Download image-to-video models (token-gated).
printf "\nDownloading token-gated models...\n"
check_hf_auth
Expand Down

0 comments on commit 8ab28a4

Please sign in to comment.