From 8b5cd1ea0df3860a8ea9371e3b31e3a3f1c26e70 Mon Sep 17 00:00:00 2001
From: ad-astra-video <99882368+ad-astra-video@users.noreply.github.com>
Date: Mon, 20 May 2024 19:02:26 -0500
Subject: [PATCH] fix: process T2I batch sequentially to prevent CUDA out of memory errors (#66)

This commit ensures that batches in the T2I pipeline are processed sequentially. This change is necessary because we currently lack the ability to estimate a GPU's VRAM capacity and manage requests accordingly.

* process text-to-image requested image count sequentially

* refactor: cleanup sequential images code

This commit cleans up the sequential images code a bit.

---------

Co-authored-by: Rick Staa
---
 runner/app/routes/text_to_image.py | 53 +++++++++++++++---------------
 1 file changed, 26 insertions(+), 27 deletions(-)

diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py
index 3a747ae1..51a28f71 100644
--- a/runner/app/routes/text_to_image.py
+++ b/runner/app/routes/text_to_image.py
@@ -7,7 +7,7 @@
 from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
 import logging
 import random
-import os
+import os, json
 
 router = APIRouter()
 
@@ -57,33 +57,32 @@ async def text_to_image(
             ),
         )
 
-    if params.seed is None:
-        params.seed = random.randint(0, 2**32 - 1)
-    if params.num_images_per_prompt > 1:
-        params.seed = [
-            i for i in range(params.seed, params.seed + params.num_images_per_prompt)
-        ]
+    seed = params.seed if params.seed is not None else random.randint(0, 2**32 - 1)
+    seeds = [seed + i for i in range(params.num_images_per_prompt)]
 
-    try:
-        images, has_nsfw_concept = pipeline(**params.model_dump())
-    except Exception as e:
-        logger.error(f"TextToImagePipeline error: {e}")
-        logger.exception(e)
-        return JSONResponse(
-            status_code=500, content=http_error("TextToImagePipeline error")
-        )
-
-    seeds = params.seed
-    if not isinstance(seeds, list):
-        seeds = [seeds]
+    # TODO: Process one image at a time to avoid CUDA OOM errors. Can be removed again
+    # once LIV-243 and LIV-379 are resolved.
+    images = []
+    has_nsfw_concept = []
+    params.num_images_per_prompt = 1
+    for seed in seeds:
+        try:
+            params.seed = [seed]
+            imgs, nsfw_check = pipeline(**params.model_dump())
+            images.extend(imgs)
+            has_nsfw_concept.extend(nsfw_check)
+        except Exception as e:
+            logger.error(f"TextToImagePipeline error: {e}")
+            logger.exception(e)
+            return JSONResponse(
+                status_code=500, content=http_error("TextToImagePipeline error")
+            )
 
-    output_images = []
-    for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
-        # TODO: Return None once Go codegen tool supports optional properties
-        # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
-        is_nsfw = is_nsfw or False
-        output_images.append(
-            {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
-        )
+    # TODO: Return None once Go codegen tool supports optional properties
+    # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
+    output_images = [
+        {"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
+        for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
+    ]
 
     return {"images": output_images}
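
For reviewers who want to try the seed handling in isolation, the sequential loop in this patch can be approximated with the minimal sketch below. It is not the runner's code: `fake_pipeline` and `generate_sequentially` are hypothetical stand-ins invented for illustration, and the "images" are plain strings rather than PIL images, so nothing here touches a GPU. The sketch only assumes what the diff shows: consecutive seeds derived from one base seed, one pipeline call per seed, and missing NSFW flags coerced to `False`.

```python
import random
from typing import List, Optional, Tuple


def fake_pipeline(prompt: str, seeds: List[int]) -> Tuple[List[str], List[Optional[bool]]]:
    """Hypothetical stand-in for the real pipeline: one 'image' and one NSFW flag per seed."""
    images = [f"{prompt}#{s}" for s in seeds]
    nsfw_flags: List[Optional[bool]] = [None for _ in seeds]  # None mimics "no safety check ran"
    return images, nsfw_flags


def generate_sequentially(pipeline, prompt: str, num_images: int, seed: Optional[int] = None):
    """Call the pipeline once per image so only a single image is in flight at a time."""
    # Pick a random base seed when none is given, then derive consecutive per-image seeds.
    base = seed if seed is not None else random.randint(0, 2**32 - 1)
    seeds = [base + i for i in range(num_images)]

    images: List[str] = []
    nsfw_flags: List[Optional[bool]] = []
    for s in seeds:
        # Each call requests exactly one image, mirroring num_images_per_prompt = 1.
        imgs, flags = pipeline(prompt, [s])
        images.extend(imgs)
        nsfw_flags.extend(flags)

    # Coerce a missing NSFW flag to False, as the response construction in the patch does.
    return [
        {"url": img, "seed": sd, "nsfw": nsfw or False}
        for img, sd, nsfw in zip(images, seeds, nsfw_flags)
    ]


results = generate_sequentially(fake_pipeline, "a cat", num_images=3, seed=42)
```

The trade-off in this shape is throughput for memory: each request makes `num_images` pipeline calls instead of one batched call, which keeps peak usage at roughly the single-image level while total latency grows with the image count.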