Skip to content

Commit

Permalink
refactor(runner): improve consistency between I2I and T2I pipelines
Browse files Browse the repository at this point in the history
This commit enhances the consistency between the I2I and T2I pipelines,
making them easier to compare.
  • Loading branch information
rickstaa committed Jun 5, 2024
1 parent 6f15030 commit 21b8d4f
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 39 deletions.
64 changes: 26 additions & 38 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,20 @@ async def image_to_image(
),
)

seeds = []
if seed is None:
seeds = [random.randint(0, 2**32 - 1)]
if num_images_per_prompt > 1:
seeds = [
i for i in range(seeds[0], seeds[0] + num_images_per_prompt)
]
seed = seed if seed is not None else random.randint(0, 2**32 - 1)
seeds = [seed + i for i in range(num_images_per_prompt)]

img = Image.open(image.file).convert("RGB")

try:
images = []
has_nsfw_concept = []

for seed in seeds:
image_out, nsfw = pipeline(
image = Image.open(image.file).convert("RGB")

# TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again
# once LIV-243 and LIV-379 are resolved.
images = []
has_nsfw_concept = []
for seed in seeds:
try:
imgs, nsfw_checks = pipeline(
prompt=prompt,
image=img,
image=image,
strength=strength,
guidance_scale=guidance_scale,
image_guidance_scale=image_guidance_scale,
Expand All @@ -88,28 +84,20 @@ async def image_to_image(
seed=seed,
num_images_per_prompt=1,
)

images.extend(image_out)
has_nsfw_concept.extend(nsfw)

except Exception as e:
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
)

seeds = seed
if not isinstance(seeds, list):
seeds = [seeds]
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
except Exception as e:
logger.error(f"ImageToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("ImageToImagePipeline error")
)

output_images = []
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
is_nsfw = is_nsfw or False
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
output_images = [
{"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
]

return {"images": output_images}
2 changes: 1 addition & 1 deletion runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ async def text_to_image(
params.num_images_per_prompt = 1
for seed in seeds:
try:
params.seed = [seed]
params.seed = seed
imgs, nsfw_check = pipeline(**params.model_dump())
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
Expand Down

0 comments on commit 21b8d4f

Please sign in to comment.