forked from livepeer/ai-worker
-
Notifications
You must be signed in to change notification settings - Fork 0
/
image_to_image.py
256 lines (221 loc) · 9.51 KB
/
image_to_image.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
import logging
import os
from enum import Enum
import time
from typing import List, Optional, Tuple
import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import (
AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from huggingface_hub import file_download, hf_hub_download
from PIL import ImageFile
from safetensors.torch import load_file
# Let PIL load images whose data stream is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Module-level logger shared by the pipeline below.
logger = logging.getLogger(__name__)

# Number of warm-up passes run after stable-fast (SFAST) compilation.
SFAST_WARMUP_ITERATIONS: int = 2
class ModelName(Enum):
    """Hugging Face repo ids that get special-cased loading in this module."""

    SDXL_LIGHTNING = "ByteDance/SDXL-Lightning"
    INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix"

    @classmethod
    def list(cls):
        """Return every member's repo id, in declaration order."""
        return [member.value for member in cls]
class ImageToImagePipeline(Pipeline):
    """Image-to-image diffusion pipeline with model-specific loading and defaults.

    The underlying diffusers pipeline is chosen from ``model_id``:

    * ids containing ``ByteDance/SDXL-Lightning``: the SDXL base pipeline with
      the Lightning UNet checkpoint (2, 4 or 8 steps; 8 when unspecified)
      swapped in, plus an ``EulerDiscreteScheduler`` with trailing timestep
      spacing.
    * ids containing ``timbrooks/instruct-pix2pix``:
      ``StableDiffusionInstructPix2PixPipeline`` with an
      ``EulerAncestralDiscreteScheduler``.
    * anything else: ``AutoPipelineForImage2Image``.

    Optional behaviour is controlled by environment variables: ``SFAST``
    (stable-fast compilation, followed by a warm-up unless ``SFAST_WARMUP`` is
    not "true"), ``DEEPCACHE`` (ignored for Lightning/Turbo models) and
    ``SAFETY_CHECKER_DEVICE`` (defaults to "cuda").
    """

    def __init__(self, model_id: str):
        """Load ``model_id`` from the model cache directory onto the torch device.

        Args:
            model_id: Hugging Face repo id of the model to serve.

        Raises:
            Exception: any error raised by the pipeline during the optional
                stable-fast warm-up is logged and re-raised unchanged.
        """
        self.model_id = model_id
        kwargs = {"cache_dir": get_model_dir()}

        torch_device = get_torch_device()
        folder_name = file_download.repo_folder_name(
            repo_id=model_id, repo_type="model"
        )
        folder_path = os.path.join(get_model_dir(), folder_name)
        # Load the fp16 variant if fp16 'safetensors' files are present in the cache.
        # NOTE: Exception for SDXL-Lightning model: despite having fp16 'safetensors'
        # files, they are not named according to the standard convention.
        has_fp16_variant = (
            any(
                ".fp16.safetensors" in fname
                for _, _, files in os.walk(folder_path)
                for fname in files
            )
            or ModelName.SDXL_LIGHTNING.value in model_id
        )
        # Compare the string form: a torch.device never compares equal to a
        # plain str, so `torch_device != "cpu"` would select fp16 on CPU too.
        if str(torch_device) != "cpu" and has_fp16_variant:
            logger.info("ImageToImagePipeline loading fp16 variant for %s", model_id)

            kwargs["torch_dtype"] = torch.float16
            kwargs["variant"] = "fp16"

        # Special case SDXL-Lightning because the unet for SDXL needs to be swapped
        if ModelName.SDXL_LIGHTNING.value in model_id:
            base = "stabilityai/stable-diffusion-xl-base-1.0"
            steps = self._sdxl_lightning_steps(model_id)
            unet_id = f"sdxl_lightning_{steps}step_unet"

            unet_config = UNet2DConditionModel.load_config(
                pretrained_model_name_or_path=base,
                subfolder="unet",
                cache_dir=kwargs["cache_dir"],
            )
            # "torch_dtype" is absent when the fp16 branch above was skipped
            # (e.g. on CPU); fall back to float32 instead of raising KeyError.
            unet = UNet2DConditionModel.from_config(unet_config).to(
                torch_device, kwargs.get("torch_dtype", torch.float32)
            )
            unet.load_state_dict(
                load_file(
                    hf_hub_download(
                        ModelName.SDXL_LIGHTNING.value,
                        f"{unet_id}.safetensors",
                        cache_dir=kwargs["cache_dir"],
                    ),
                    device=str(torch_device),
                )
            )

            # NOTE(review): StableDiffusionXLPipeline is diffusers' text-to-image
            # SDXL pipeline and does not declare an `image` argument, yet
            # __call__ passes `image=` to it. Confirm the input image is really
            # used for SDXL-Lightning; StableDiffusionXLImg2ImgPipeline may have
            # been intended.
            self.ldm = StableDiffusionXLPipeline.from_pretrained(
                base, unet=unet, **kwargs
            ).to(torch_device)

            self.ldm.scheduler = EulerDiscreteScheduler.from_config(
                self.ldm.scheduler.config, timestep_spacing="trailing"
            )
        elif ModelName.INSTRUCT_PIX2PIX.value in model_id:
            self.ldm = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                model_id, **kwargs
            ).to(torch_device)

            self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config(
                self.ldm.scheduler.config
            )
        else:
            self.ldm = AutoPipelineForImage2Image.from_pretrained(
                model_id, **kwargs
            ).to(torch_device)

        sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
        deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
        if sfast_enabled and deepcache_enabled:
            logger.warning(
                "Both 'SFAST' and 'DEEPCACHE' are enabled. This is not recommended "
                "as it may lead to suboptimal performance. Please disable one of them."
            )

        if sfast_enabled:
            logger.info(
                "ImageToImagePipeline will be dynamically compiled with stable-fast "
                "for %s",
                model_id,
            )
            from app.pipelines.optim.sfast import compile_model

            self.ldm = compile_model(self.ldm)

            # Warm up the compiled pipeline unless SFAST_WARMUP is not "true".
            # NOTE(review): an earlier TODO said warm-up was "not yet supported"
            # for this pipeline; confirm every pipeline class loaded above
            # accepts these warm-up kwargs (e.g. `strength`).
            if os.getenv("SFAST_WARMUP", "true").strip().lower() == "true":
                warmup_kwargs = {
                    "prompt": "A warmed up pipeline is a happy pipeline a short poem by ricksta",
                    "image": PIL.Image.new("RGB", (576, 1024)),
                    "strength": 0.8,
                    "negative_prompt": "No blurry or weird artifacts",
                    "num_images_per_prompt": 4,
                }

                logger.info("Warming up ImageToImagePipeline pipeline...")
                total_time = 0.0
                for ii in range(SFAST_WARMUP_ITERATIONS):
                    t = time.perf_counter()
                    try:
                        self.ldm(**warmup_kwargs)
                    except Exception as e:
                        logger.error("ImageToImagePipeline warmup error: %s", e)
                        # Bare raise keeps the original traceback intact.
                        raise
                    iteration_time = time.perf_counter() - t
                    total_time += iteration_time
                    logger.info(
                        "Warmup iteration %s took %s seconds", ii + 1, iteration_time
                    )
                logger.info("Total warmup time: %s seconds", total_time)

        if deepcache_enabled and not (
            is_lightning_model(model_id) or is_turbo_model(model_id)
        ):
            logger.info(
                "ImageToImagePipeline will be optimized with DeepCache for %s",
                model_id,
            )
            from app.pipelines.optim.deepcache import enable_deepcache

            self.ldm = enable_deepcache(self.ldm)
        elif deepcache_enabled:
            logger.warning(
                "DeepCache is not supported for Lightning or Turbo models. "
                "ImageToImagePipeline will NOT be optimized with DeepCache for %s",
                model_id,
            )

        safety_checker_device = os.getenv("SAFETY_CHECKER_DEVICE", "cuda").lower()
        self._safety_checker = SafetyChecker(device=safety_checker_device)

    @staticmethod
    def _sdxl_lightning_steps(model_id: str) -> int:
        """Return the step count (2, 4 or 8) encoded in an SDXL-Lightning model id.

        Ids look like ``ByteDance/SDXL-Lightning-4step``; ids without a
        recognised ``<n>step`` marker default to the 8-step checkpoint.
        """
        for steps in (2, 4, 8):
            if f"{steps}step" in model_id:
                return steps
        return 8

    def __call__(
        self, prompt: str, image: PIL.Image.Image, **kwargs
    ) -> Tuple[List[PIL.Image.Image], List[Optional[bool]]]:
        """Run image-to-image generation and optionally safety-check the output.

        Args:
            prompt: Text prompt, passed positionally to the underlying pipeline.
            image: Input image.
            **kwargs: Extra pipeline arguments. ``seed`` (an int, or a list of
                ints giving one generator each) and ``safety_check`` (default
                True) are consumed here. ``num_inference_steps`` is dropped when
                None or < 1 so the pipeline default applies. Turbo and
                SDXL-Lightning models override ``guidance_scale``; Lightning also
                overrides ``num_inference_steps``.

        Returns:
            The generated images, and the NSFW result from the safety checker's
            ``check_nsfw_images`` (or ``None`` per image when ``safety_check``
            is False).
        """
        seed = kwargs.pop("seed", None)
        safety_check = kwargs.pop("safety_check", True)

        if seed is not None:
            device = get_torch_device()
            if isinstance(seed, int):
                kwargs["generator"] = torch.Generator(device).manual_seed(seed)
            elif isinstance(seed, list):
                kwargs["generator"] = [
                    torch.Generator(device).manual_seed(s) for s in seed
                ]

        if "num_inference_steps" in kwargs and (
            kwargs["num_inference_steps"] is None or kwargs["num_inference_steps"] < 1
        ):
            del kwargs["num_inference_steps"]

        if self.model_id in ("stabilityai/sdxl-turbo", "stabilityai/sd-turbo"):
            # SD turbo models were trained without guidance_scale so
            # it should be set to 0
            kwargs["guidance_scale"] = 0.0

            # Ensure num_inference_steps * strength >= 1 for minimum pipeline
            # execution steps.
            if "num_inference_steps" in kwargs:
                strength = kwargs.get("strength")
                if strength is None:
                    # Treat an explicit None like an absent value instead of
                    # letting max() raise TypeError.
                    strength = 0.5
                kwargs["strength"] = max(
                    1.0 / kwargs["num_inference_steps"], strength
                )
        elif ModelName.SDXL_LIGHTNING.value in self.model_id:
            # SDXL-Lightning models should have guidance_scale = 0 and use
            # the correct number of inference steps for the unet checkpoint loaded
            kwargs["guidance_scale"] = 0.0
            kwargs["num_inference_steps"] = self._sdxl_lightning_steps(self.model_id)

        output = self.ldm(prompt, image=image, **kwargs)

        if safety_check:
            _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
        else:
            has_nsfw_concept = [None] * len(output.images)

        return output.images, has_nsfw_concept

    def __str__(self) -> str:
        return f"ImageToImagePipeline model_id={self.model_id}"