diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index feb11349..e2d6c692 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -5,6 +5,15 @@ import PIL import torch +from app.pipelines.base import Pipeline +from app.pipelines.utils import ( + SafetyChecker, + get_model_dir, + get_torch_device, + is_lightning_model, + is_turbo_model, + split_prompt, +) from diffusers import ( AutoPipelineForText2Image, EulerDiscreteScheduler, @@ -15,15 +24,6 @@ from huggingface_hub import file_download, hf_hub_download from safetensors.torch import load_file -from app.pipelines.base import Pipeline -from app.pipelines.utils import ( - SafetyChecker, - get_model_dir, - get_torch_device, - is_lightning_model, - is_turbo_model, -) - logger = logging.getLogger(__name__) @@ -222,7 +222,18 @@ def __call__( # Default to 8step kwargs["num_inference_steps"] = 8 - output = self.ldm(prompt, **kwargs) + # Allow users to specify multiple (negative) prompts using the '|' separator. + prompts = split_prompt(prompt, max_splits=3) + prompt = prompts.pop("prompt") + kwargs.update(prompts) + neg_prompts = split_prompt( + kwargs.pop("negative_prompt", ""), + key_prefix="negative_prompt", + max_splits=3, + ) + kwargs.update(neg_prompts) + + output = self.ldm(prompt=prompt, **kwargs) if safety_check: _, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images) diff --git a/runner/app/pipelines/utils/__init__.py b/runner/app/pipelines/utils/__init__.py index dd1b9573..844b86e9 100644 --- a/runner/app/pipelines/utils/__init__.py +++ b/runner/app/pipelines/utils/__init__.py @@ -7,5 +7,6 @@ get_torch_device, is_lightning_model, is_turbo_model, + split_prompt, validate_torch_device, ) diff --git a/runner/app/pipelines/utils/utils.py b/runner/app/pipelines/utils/utils.py index 31628357..dbc44d48 100644 --- a/runner/app/pipelines/utils/utils.py +++ b/runner/app/pipelines/utils/utils.py @@ -79,6 +79,39 @@ def is_turbo_model(model_id: str) -> bool: return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None +def split_prompt( + input_prompt: str, + separator: str = "|", + key_prefix: str = "prompt", + max_splits: int = -1, +) -> dict[str, str]: + """Splits an input prompt into prompts, including the main prompt, with customizable + key naming. + + Args: + input_prompt (str): The input prompt string to be split. + separator (str): The character used to split the input prompt. Defaults to '|'. + key_prefix (str): Prefix for keys in the returned dictionary for all prompts, + including the main prompt. Defaults to 'prompt'. + max_splits (int): Maximum number of splits to perform. Defaults to -1 (no limit). + + Returns: + Dict[str, str]: A dictionary of all prompts, including the main prompt. + """ + prompts = input_prompt.split(separator, max_splits - 1) + start_index = 1 if max_splits < 0 else max(1, len(prompts) - max_splits) + + prompt_dict = {f"{key_prefix}": prompts[0].strip()} + prompt_dict.update( + { + f"{key_prefix}_{i+1}": prompt.strip() + for i, prompt in enumerate(prompts[1:], start=start_index) + } + ) + + return prompt_dict + + class SafetyChecker: """Checks images for unsafe or inappropriate content using a pretrained model.