From 5ab6f10217542e243395e7bacdfe02dce246214c Mon Sep 17 00:00:00 2001 From: Brad P Date: Sun, 21 Apr 2024 11:10:35 -0500 Subject: [PATCH 1/2] process text-to-image requested image count sequentially --- runner/app/routes/text_to_image.py | 42 ++++++++++++++++++------------ 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 3a747ae1..af197251 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -7,7 +7,7 @@ from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error import logging import random -import os +import os, json router = APIRouter() @@ -59,24 +59,32 @@ async def text_to_image( if params.seed is None: params.seed = random.randint(0, 2**32 - 1) - if params.num_images_per_prompt > 1: - params.seed = [ - i for i in range(params.seed, params.seed + params.num_images_per_prompt) - ] - - try: - images, has_nsfw_concept = pipeline(**params.model_dump()) - except Exception as e: - logger.error(f"TextToImagePipeline error: {e}") - logger.exception(e) - return JSONResponse( - status_code=500, content=http_error("TextToImagePipeline error") - ) + if params.num_images_per_prompt > 1: + params.seed = [ + i for i in range(params.seed, params.seed + params.num_images_per_prompt) + ] + if not isinstance(params.seed, list): + params.seed = [params.seed] + + num_images_per_prompt = params.num_images_per_prompt + params.num_images_per_prompt = 1 + images = [] + has_nsfw_concept = [] seeds = params.seed - if not isinstance(seeds, list): - seeds = [seeds] - + for i in range(num_images_per_prompt): + try: + params.seed = [seeds[i]] #pass one seed at a time + imgs, nsfw_check = pipeline(**params.model_dump()) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_check) + except Exception as e: + logger.error(f"TextToImagePipeline error: {e}") + logger.exception(e) + return JSONResponse( + status_code=500, content=http_error("TextToImagePipeline error") + ) + output_images = [] for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): # TODO: Return None once Go codegen tool supports optional properties From f4c43907a1d5a902e5910bbcb70dd8e4eff28f2d Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Tue, 21 May 2024 01:41:42 +0200 Subject: [PATCH 2/2] refactor: cleanup sequential images code This commit cleans up the sequential images code a bit. --- runner/app/routes/text_to_image.py | 37 +++++++++++------------------- 1 file changed, 14 insertions(+), 23 deletions(-) diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index af197251..51a28f71 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -57,24 +57,17 @@ async def text_to_image( ), ) - if params.seed is None: - params.seed = random.randint(0, 2**32 - 1) - if params.num_images_per_prompt > 1: - params.seed = [ - i for i in range(params.seed, params.seed + params.num_images_per_prompt) - ] + seed = params.seed if params.seed is not None else random.randint(0, 2**32 - 1) + seeds = [seed + i for i in range(params.num_images_per_prompt)] - if not isinstance(params.seed, list): - params.seed = [params.seed] - - num_images_per_prompt = params.num_images_per_prompt - params.num_images_per_prompt = 1 + # TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again + # once LIV-243 and LIV-379 are resolved. images = [] has_nsfw_concept = [] - seeds = params.seed - for i in range(num_images_per_prompt): + params.num_images_per_prompt = 1 + for seed in seeds: try: - params.seed = [seeds[i]] #pass one seed at a time + params.seed = [seed] imgs, nsfw_check = pipeline(**params.model_dump()) images.extend(imgs) has_nsfw_concept.extend(nsfw_check) @@ -84,14 +77,12 @@ async def text_to_image( return JSONResponse( status_code=500, content=http_error("TextToImagePipeline error") ) - - output_images = [] - for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): - # TODO: Return None once Go codegen tool supports optional properties - # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 - is_nsfw = is_nsfw or False - output_images.append( - {"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw} - ) + + # TODO: Return None once Go codegen tool supports optional properties + # OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373 + output_images = [ + {"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False} + for img, sd, nsfw in zip(images, seeds, has_nsfw_concept) + ] return {"images": output_images}