Skip to content

Commit

Permalink
fix: process T2I batch sequentially to prevent CUDA out of memory err…
Browse files Browse the repository at this point in the history
…ors (#66)

This commit ensures that batches in the T2I pipeline are processed sequentially. This change is necessary because we currently lack the ability to estimate a GPU's VRAM capacity and manage requests accordingly.

* process text-to-image requested image count sequentially

* refactor: cleanup sequential images code

This commit cleans up the sequential images code a bit.

---------

Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
ad-astra-video and rickstaa authored May 21, 2024
1 parent c05cccf commit 8b5cd1e
Showing 1 changed file with 26 additions and 27 deletions.
53 changes: 26 additions & 27 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from app.routes.util import image_to_data_url, ImageResponse, HTTPError, http_error
import logging
import random
import os
import os, json

router = APIRouter()

Expand Down Expand Up @@ -57,33 +57,32 @@ async def text_to_image(
),
)

if params.seed is None:
params.seed = random.randint(0, 2**32 - 1)
if params.num_images_per_prompt > 1:
params.seed = [
i for i in range(params.seed, params.seed + params.num_images_per_prompt)
]
seed = params.seed if params.seed is not None else random.randint(0, 2**32 - 1)
seeds = [seed + i for i in range(params.num_images_per_prompt)]

try:
images, has_nsfw_concept = pipeline(**params.model_dump())
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("TextToImagePipeline error")
)

seeds = params.seed
if not isinstance(seeds, list):
seeds = [seeds]
# TODO: Process one image at a time to avoid CUDA OEM errors. Can be removed again
# once LIV-243 and LIV-379 are resolved.
images = []
has_nsfw_concept = []
params.num_images_per_prompt = 1
for seed in seeds:
try:
params.seed = [seed]
imgs, nsfw_check = pipeline(**params.model_dump())
images.extend(imgs)
has_nsfw_concept.extend(nsfw_check)
except Exception as e:
logger.error(f"TextToImagePipeline error: {e}")
logger.exception(e)
return JSONResponse(
status_code=500, content=http_error("TextToImagePipeline error")
)

output_images = []
for img, sd, is_nsfw in zip(images, seeds, has_nsfw_concept):
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
is_nsfw = is_nsfw or False
output_images.append(
{"url": image_to_data_url(img), "seed": sd, "nsfw": is_nsfw}
)
# TODO: Return None once Go codegen tool supports optional properties
# OAPI 3.1 https://github.com/deepmap/oapi-codegen/issues/373
output_images = [
{"url": image_to_data_url(img), "seed": sd, "nsfw": nsfw or False}
for img, sd, nsfw in zip(images, seeds, has_nsfw_concept)
]

return {"images": output_images}

0 comments on commit 8b5cd1e

Please sign in to comment.