From fd155ec69aac6e08b41a47794719e63263bae16f Mon Sep 17 00:00:00 2001 From: ad-astra-video <99882368+ad-astra-video@users.noreply.github.com> Date: Sat, 8 Jun 2024 04:24:15 -0500 Subject: [PATCH] fix(runner): update img2img to do sequential processing of batch request (#95) * update img2img to do sequential processing of batch request * refactor(runner): improve consistency between I2I and T2I pipelines This commit enhances the consistency between the I2I and T2I pipelines, making them easier to compare. --------- Co-authored-by: Rick Staa --- runner/app/routes/image_to_image.py | 77 +++++++++++++---------------- runner/app/routes/text_to_image.py | 2 +- 2 files changed, 35 insertions(+), 44 deletions(-) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index d210f6e5..df8ac336 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -62,51 +62,42 @@ async def image_to_image( ), ) - if seed is None: - seed = random.randint(0, 2**32 - 1) - if num_images_per_prompt > 1: - seed = [ - i for i in range(seed, seed + num_images_per_prompt) - ] + seed = seed if seed is not None else random.randint(0, 2**32 - 1) + seeds = [seed + i for i in range(num_images_per_prompt)] - img = Image.open(image.file).convert("RGB") - # If a list of seeds/generators is passed, diffusers wants a list of images - # https://github.com/huggingface/diffusers/blob/17808a091e2d5615c2ed8a63d7ae6f2baea11e1e/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py#L715 - if isinstance(seed, list): - image = [img] * num_images_per_prompt - else: - image = img + image = Image.open(image.file).convert("RGB") - try: - images, has_nsfw_concept = pipeline( - prompt=prompt, - image=image, - strength=strength, - guidance_scale=guidance_scale, - image_guidance_scale=image_guidance_scale, - negative_prompt=negative_prompt, - safety_check=safety_check, - seed=seed, - num_images_per_prompt=num_images_per_prompt, - ) - except Exception as e: - logger.error(f"ImageToImagePipeline error: {e}") - logger.exception(e) - return JSONResponse( - status_code=500, content=http_error("ImageToImagePipeline error") - ) - - seeds = seed - if not isinstance(seeds, list): - seeds = [seeds] + # TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again + # once LIV-243 and LIV-379 are resolved. + images = [] + has_nsfw_concept = [] + for seed in seeds: + try: + imgs, nsfw_checks = pipeline( + prompt=prompt, + image=image, + strength=strength, + guidance_scale=guidance_scale, + image_guidance_scale=image_guidance_scale, + negative_prompt=negative_prompt, + safety_check=safety_check, + seed=seed, + num_images_per_prompt=1, + ) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_checks) + except Exception as e: + logger.error(f"ImageToImagePipeline error: {e}") + logger.exception(e) + return JSONResponse( + status_code=500, content=http_error("ImageToImagePipeline error") + ) - output_images = [] - for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): - # TODO: Return None once Go codegen tool supports optional properties - # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 - is_nsfw = is_nsfw or False - output_images.append( - {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} - ) + # TODO: Return None once Go codegen tool supports optional properties + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + output_images = [ + {"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False} + for img, sd, nsfw in zip(images, seeds, has_nsfw_concept) + ] return {"images": output_images} diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 51a28f71..3f52e36d 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -67,7 +67,7 @@ async def text_to_image( params.num_images_per_prompt = 1 for seed in seeds: try: - params.seed = [seed] + params.seed = seed imgs, nsfw_check = pipeline(**params.model_dump()) images.extend(imgs) has_nsfw_concept.extend(nsfw_check)