From 5ab6f10217542e243395e7bacdfe02dce246214c Mon Sep 17 00:00:00 2001 From: Brad P Date: Sun, 21 Apr 2024 11:10:35 -0500 Subject: [PATCH] process text-to-image requested image count sequentially --- runner/app/routes/text_to_image.py | 42 ++++++++++++++++++------------ 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/runner/app/routes/text_to_image.py b/runner/app/routes/text_to_image.py index 3a747ae1..af197251 100644 --- a/runner/app/routes/text_to_image.py +++ b/runner/app/routes/text_to_image.py @@ -7,7 +7,7 @@ from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error import logging import random -import os +import os, json router = APIRouter() @@ -59,24 +59,32 @@ async def text_to_image( if params.seed is None: params.seed = random.randint(0, 2**32 - 1) - if params.num_images_per_prompt > 1: - params.seed = [ - i for i in range(params.seed, params.seed + params.num_images_per_prompt) - ] - - try: - images, has_nsfw_concept = pipeline(**params.model_dump()) - except Exception as e: - logger.error(f"TextToImagePipeline error: {e}") - logger.exception(e) - return JSONResponse( - status_code=500, content=http_error("TextToImagePipeline error") - ) + if params.num_images_per_prompt > 1: + params.seed = [ + i for i in range(params.seed, params.seed + params.num_images_per_prompt) + ] + if not isinstance(params.seed, list): + params.seed = [params.seed] + + num_images_per_prompt = params.num_images_per_prompt + params.num_images_per_prompt = 1 + images = [] + has_nsfw_concept = [] seeds = params.seed - if not isinstance(seeds, list): - seeds = [seeds] - + for i in range(num_images_per_prompt): + try: + params.seed = [seeds[i]] #pass one seed at a time + imgs, nsfw_check = pipeline(**params.model_dump()) + images.extend(imgs) + has_nsfw_concept.extend(nsfw_check) + except Exception as e: + logger.error(f"TextToImagePipeline error: {e}") + logger.exception(e) + return JSONResponse( + status_code=500, content=http_error("TextToImagePipeline error") + ) + output_images = [] for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept): # TODO: Return None once Go codegen tool supports optional properties