diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index a2ec720f..e2a53b0a 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -6,12 +6,15 @@ is_lightning_model, is_turbo_model, ) +from enum import Enum from diffusers import ( AutoPipelineForImage2Image, StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + StableDiffusionInstructPix2PixPipeline, ) from safetensors.torch import load_file from huggingface_hub import file_download, hf_hub_download @@ -27,7 +30,17 @@ logger = logging.getLogger(__name__) -SDXL_LIGHTNING_MODEL_ID = "ByteDance/SDXL-Lightning" + +class ModelName(Enum): + """Enumeration mapping model names to their corresponding IDs.""" + + SDXL_LIGHTNING = "ByteDance/SDXL-Lightning" + INSTRUCT_PIX2PIX = "timbrooks/instruct-pix2pix" + + @classmethod + def list(cls): + """Return a list of all model IDs.""" + return list(map(lambda c: c.value, cls)) class ImageToImagePipeline(Pipeline): @@ -45,7 +58,7 @@ def __init__(self, model_id: str): for _, _, files in os.walk(folder_path) for fname in files ) - or SDXL_LIGHTNING_MODEL_ID in model_id + or ModelName.SDXL_LIGHTNING.value in model_id ) if torch_device != "cpu" and has_fp16_variant: logger.info("ImageToImagePipeline loading fp16 variant for %s", model_id) @@ -56,7 +69,7 @@ def __init__(self, model_id: str): self.model_id = model_id # Special case SDXL-Lightning because the unet for SDXL needs to be swapped - if SDXL_LIGHTNING_MODEL_ID in model_id: + if ModelName.SDXL_LIGHTNING.value in model_id: base = "stabilityai/stable-diffusion-xl-base-1.0" # ByteDance/SDXL-Lightning-2step @@ -78,7 +91,7 @@ def __init__(self, model_id: str): unet.load_state_dict( load_file( hf_hub_download( - SDXL_LIGHTNING_MODEL_ID, + ModelName.SDXL_LIGHTNING.value, f"{unet_id}.safetensors", cache_dir=kwargs["cache_dir"], ), @@ -93,6 +106,14 @@ def __init__(self, model_id: str): self.ldm.scheduler = EulerDiscreteScheduler.from_config( self.ldm.scheduler.config, timestep_spacing="trailing" ) + elif ModelName.INSTRUCT_PIX2PIX.value in model_id: + self.ldm = StableDiffusionInstructPix2PixPipeline.from_pretrained( + model_id, **kwargs + ).to(torch_device) + + self.ldm.scheduler = EulerAncestralDiscreteScheduler.from_config( + self.ldm.scheduler.config + ) else: self.ldm = AutoPipelineForImage2Image.from_pretrained( model_id, **kwargs @@ -176,7 +197,7 @@ def __call__( if "num_inference_steps" not in kwargs: kwargs["num_inference_steps"] = 2 - elif SDXL_LIGHTNING_MODEL_ID in self.model_id: + elif ModelName.SDXL_LIGHTNING.value in self.model_id: # SDXL-Lightning models should have guidance_scale = 0 and use # the correct number of inference steps for the unet checkpoint loaded kwargs["guidance_scale"] = 0.0 @@ -190,6 +211,10 @@ def __call__( else: # Default to 2step kwargs["num_inference_steps"] = 2 + elif ModelName.INSTRUCT_PIX2PIX.value in self.model_id: + if "num_inference_steps" not in kwargs: + # TODO: Currently set to recommended value make configurable later. + kwargs["num_inference_steps"] = 10 output = self.ldm(prompt, image=image, **kwargs) diff --git a/runner/app/routes/image_to_image.py b/runner/app/routes/image_to_image.py index 6bd212f6..d210f6e5 100644 --- a/runner/app/routes/image_to_image.py +++ b/runner/app/routes/image_to_image.py @@ -36,6 +36,7 @@ async def image_to_image( model_id: Annotated[str, Form()] = "", strength: Annotated[float, Form()] = 0.8, guidance_scale: Annotated[float, Form()] = 7.5, + image_guidance_scale: Annotated[float, Form()] = 1.5, negative_prompt: Annotated[str, Form()] = "", safety_check: Annotated[bool, Form()] = True, seed: Annotated[int, Form()] = None, @@ -82,6 +83,7 @@ async def image_to_image( image=image, strength=strength, guidance_scale=guidance_scale, + image_guidance_scale=image_guidance_scale, negative_prompt=negative_prompt, safety_check=safety_check, seed=seed, diff --git a/runner/dl_checkpoints.sh b/runner/dl_checkpoints.sh index 146b4d3e..3cc76158 100755 --- a/runner/dl_checkpoints.sh +++ b/runner/dl_checkpoints.sh @@ -60,7 +60,8 @@ if [ "$MODE" = "alpha" ]; then # Download text-to-image and image-to-image models. huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models - + huggingface-cli download timbrooks/instruct-pix2pix --include "*fp16.safetensors" --exclude "*lora.safetensors*" --cache-dir models + # Download image-to-video models (token-gated). printf "\nDownloading token-gated models...\n" check_hf_auth @@ -78,6 +79,7 @@ else huggingface-cli download ByteDance/SDXL-Lightning --include "*unet.safetensors" --exclude "*lora.safetensors*" --cache-dir models huggingface-cli download SG161222/RealVisXL_V4.0_Lightning --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models huggingface-cli download SG161222/RealVisXL_V4.0 --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models + huggingface-cli download timbrooks/instruct-pix2pix --include "*.fp16.safetensors" "*.json" "*.txt" --exclude ".onnx" ".onnx_data" --cache-dir models/ # Download image-to-video models. huggingface-cli download stabilityai/stable-video-diffusion-img2vid-xt --include "*.fp16.safetensors" "*.json" --cache-dir models diff --git a/runner/openapi.json b/runner/openapi.json index dbafb47f..5b8af588 100644 --- a/runner/openapi.json +++ b/runner/openapi.json @@ -261,6 +261,11 @@ "title": "Guidance Scale", "default": 7.5 }, + "image_guidance_scale": { + "type": "number", + "title": "Image Guidance Scale", + "default": 1.5 + }, "negative_prompt": { "type": "string", "title": "Negative Prompt", diff --git a/runner/openapi.yaml b/runner/openapi.yaml deleted file mode 100644 index 2f9b27ea..00000000 --- a/runner/openapi.yaml +++ /dev/null @@ -1,357 +0,0 @@ -openapi: 3.1.0 -info: - title: Livepeer AI Runner - description: An application to run AI pipelines - version: 0.1.0 -servers: -- url: https://dream-gateway.livepeer.cloud - description: Livepeer Cloud Community Gateway -paths: - /health: - get: - summary: Health - operationId: health - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: '#/components/schemas/HealthCheck' - /text-to-image: - post: - summary: Text To Image - operationId: text_to_image - requestBody: - content: - application/json: - schema: - $ref: '#/components/schemas/TextToImageParams' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: '#/components/schemas/ImageResponse' - '400': - description: Bad Request - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPError' - '500': - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPError' - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - security: - - HTTPBearer: [] - /image-to-image: - post: - summary: Image To Image - operationId: image_to_image - requestBody: - content: - multipart/form-data: - schema: - $ref: '#/components/schemas/Body_image_to_image_image_to_image_post' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: '#/components/schemas/ImageResponse' - '400': - description: Bad Request - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPError' - '500': - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPError' - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - security: - - HTTPBearer: [] - /image-to-video: - post: - summary: Image To Video - operationId: image_to_video - requestBody: - content: - multipart/form-data: - schema: - $ref: '#/components/schemas/Body_image_to_video_image_to_video_post' - required: true - responses: - '200': - description: Successful Response - content: - application/json: - schema: - $ref: '#/components/schemas/VideoResponse' - '400': - description: Bad Request - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPError' - '500': - description: Internal Server Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPError' - '422': - description: Validation Error - content: - application/json: - schema: - $ref: '#/components/schemas/HTTPValidationError' - security: - - HTTPBearer: [] -components: - schemas: - APIError: - properties: - msg: - type: string - title: Msg - type: object - required: - - msg - title: APIError - Body_image_to_image_image_to_image_post: - properties: - prompt: - type: string - title: Prompt - image: - type: string - format: binary - title: Image - model_id: - type: string - title: Model Id - default: '' - strength: - type: number - title: Strength - default: 0.8 - guidance_scale: - type: number - title: Guidance Scale - default: 7.5 - negative_prompt: - type: string - title: Negative Prompt - default: '' - safety_check: - type: boolean - title: Safety Check - default: true - seed: - type: integer - title: Seed - num_images_per_prompt: - type: integer - title: Num Images Per Prompt - default: 1 - type: object - required: - - prompt - - image - title: Body_image_to_image_image_to_image_post - Body_image_to_video_image_to_video_post: - properties: - image: - type: string - format: binary - title: Image - model_id: - type: string - title: Model Id - default: '' - height: - type: integer - title: Height - default: 576 - width: - type: integer - title: Width - default: 1024 - fps: - type: integer - title: Fps - default: 6 - motion_bucket_id: - type: integer - title: Motion Bucket Id - default: 127 - noise_aug_strength: - type: number - title: Noise Aug Strength - default: 0.02 - seed: - type: integer - title: Seed - safety_check: - type: boolean - title: Safety Check - default: true - type: object - required: - - image - title: Body_image_to_video_image_to_video_post - HTTPError: - properties: - detail: - $ref: '#/components/schemas/APIError' - type: object - required: - - detail - title: HTTPError - HTTPValidationError: - properties: - detail: - items: - $ref: '#/components/schemas/ValidationError' - type: array - title: Detail - type: object - title: HTTPValidationError - HealthCheck: - properties: - status: - type: string - title: Status - default: OK - type: object - title: HealthCheck - ImageResponse: - properties: - images: - items: - $ref: '#/components/schemas/Media' - type: array - title: Images - type: object - required: - - images - title: ImageResponse - Media: - properties: - url: - type: string - title: Url - seed: - type: integer - title: Seed - nsfw: - type: boolean - title: Nsfw - type: object - required: - - url - - seed - - nsfw - title: Media - TextToImageParams: - properties: - model_id: - type: string - title: Model Id - default: '' - prompt: - type: string - title: Prompt - height: - type: integer - title: Height - width: - type: integer - title: Width - guidance_scale: - type: number - title: Guidance Scale - default: 7.5 - negative_prompt: - type: string - title: Negative Prompt - default: '' - safety_check: - type: boolean - title: Safety Check - default: true - seed: - type: integer - title: Seed - num_inference_steps: - type: integer - title: Num Inference Steps - default: 50 - num_images_per_prompt: - type: integer - title: Num Images Per Prompt - default: 1 - type: object - required: - - prompt - title: TextToImageParams - ValidationError: - properties: - loc: - items: - anyOf: - - type: string - - type: integer - type: array - title: Location - msg: - type: string - title: Message - type: - type: string - title: Error Type - type: object - required: - - loc - - msg - - type - title: ValidationError - VideoResponse: - properties: - frames: - items: - items: - $ref: '#/components/schemas/Media' - type: array - type: array - title: Frames - type: object - required: - - frames - title: VideoResponse - securitySchemes: - HTTPBearer: - type: http - scheme: bearer diff --git a/runner/requirements.txt b/runner/requirements.txt index aaafda6d..ec9cbfe9 100644 --- a/runner/requirements.txt +++ b/runner/requirements.txt @@ -11,3 +11,4 @@ xformers==0.0.23 triton>=2.1.0 peft==0.11.1 deepcache==0.1.1 +safetensors==0.4.3 \ No newline at end of file diff --git a/worker/runner.gen.go b/worker/runner.gen.go index 8328447b..b1c51a6d 100644 --- a/worker/runner.gen.go +++ b/worker/runner.gen.go @@ -36,6 +36,7 @@ type APIError struct { type BodyImageToImageImageToImagePost struct { GuidanceScale *float32 `json:"guidance_scale,omitempty"` Image openapi_types.File `json:"image"` + ImageGuidanceScale *float32 `json:"image_guidance_scale,omitempty"` ModelId *string `json:"model_id,omitempty"` NegativePrompt *string `json:"negative_prompt,omitempty"` NumImagesPerPrompt *int `json:"num_images_per_prompt,omitempty"` @@ -1083,27 +1084,27 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl // Base64 encoded, gzipped, json marshaled Swagger object var swaggerSpec = []string{ - "H4sIAAAAAAAC/+xYW2/bNhT+KwS3R8d2vGYZ/JZkW2tsaYPY6x6CwGCkY5mtRHK8JDUC//eBh7ZEXTy7", - "W5piRZ5iSefyncPvXJhHmshCSQHCGjp+pCZZQsHw59nV5Betpfa/lZYKtOWAXwqT+T+W2xzomF6ajPao", - "XSn/YKzmIqPrdY9q+MtxDSkd36DKba9UKW2XevLuAySWrnv0XKarOS9YBnMrNz8aj0oa24aVOZ4ykcDc", - "JMx7eaQpLJjLLR2f9k8q5683cmSKciUE4Yo70B4CevEGFlIXzNIxveOC6RWtjExQpBV2jxYyhXzO05p/", - "GmleegEySbuUBWTM8nuYKy0LZXfaeLuRI1dBrsuUK0K2zFyB7jJ4HNlzBcGIDLkC3bLKhYUspKays9Xd", - "DcGwBdjVPFlC8rHm2WoHlfMpipELFCvN3EmZAxNoByCNPU79cxc4YzWIzC5rzob9nyJfW4nWqTcIq7ZR", - "BS5E3D2Un3upfc9TkM3HbmovlKnF9GMF51dlOnOxBJ4t6wd+chrpvQnfu1S/Gv0LabkU8zuXfATbNHI8", - "Oo2teElyjpI1a1EcQnIDc+ay+Q5iDEdRCXhhcuYyspsjX4HSDzxtwD4ejl5Vnv7E723NBp33sHg3FTtY", - "/GY2u9oxGVKwjOf+1/caFnRMvxtU82WwGS6Dsvs3UW7UI5iVrx1A3rOcp8yTYS8kbqEw+7A17a0rLD8H", - "SyUQpjVbYQwx2qaBLtzAcru82HKojtdYZl292um732jcwlCga+JWtV056PCPxXsNRklhoI0gTI2DM3YJ", - "KWdxnsIg6cpTi5EmPus6rA7cwVMLrzCLh7iW3vrn/1R0Tuex3B8637vgOJQxwSIiiiILwDsimsEnO5MY", - "+BXTLCT7S2011UQ4YAZ842sMmhUL0ICptdAYsCfDhtWtLJmi7P9uNSrnyGcOjk1QEZnbnO0g9t62nMuk", - "1mGYWL1b0PHNYytXjy2It1Gz+V0m6KbVbnqtawoYs2NpCS8qUcRMZv7tvrr3cQRXG8koUweMgvd+0u5u", - "xQvNikYr/sye3MhJuTUGw3t69MZ9HFINbysgZGTiNLerqYcSsPuxeA5Mgy6vmEjj8Ko0srRW0bW3wcVC", - "hqowieYKz3dMzwRhSuU8HDixkmgnyNmEKK4g5yLEs+UFvwcFoP33aycEOroHbYKtYf+4P/QJkQoEU5yO", - "6Q/4qkcVs0uEPVjiGMUmDFjX/mjQ+SQtpyz1KQv5QK3RcOj/JFJYEKgVgR58MN799p697xjjOY6JqSdk", - "6pIEjFm4nJRHgkfgisKv6yVE/3KAXfTIyqNyvd/eNephYWVvCpwGPoCxfl9sxFW43HLFtB34e8JRyiw7", - "PLRDb1HrOid9e1x/wYzXd5BDc96jr57y1Mudt8P/OUvJdTgS9DsaPanf1vrbRlCJkHJFPnmu8CfCghYs", - "J1PQ96BJdY/Y9h2cIXHHubld38Y1gUdMZjJsCo3awJvP3trALvhctbH7bvbMtVHv/S+18S3XRmA41oaF", - "T/aAsRGthf9YGf8++Pbi+TIcXgrgaQvAcyyeDajrjRlUrfsrd8yLXLqUXMiicILbFXnNLDywFd38JwE3", - "WzMeDFINrDjKwtd+vlHvJ17dX2v+DgAA//+bzF7djhkAAA==", + "H4sIAAAAAAAC/+xY32/bthP/Vwh+v4+O7XjNOvgtybbW2NIGsdc9BIHBSGeZrURyJJXUCPy/DzzKEvXD", + "lbukGVbkKZZ0Pz53vM/dMQ80kpmSAoQ1dPpATbSGjOHP08vZL1pL7X4rLRVoywG/ZCZxfyy3KdApvTAJ", + "HVC7Ue7BWM1FQrfbAdXwV841xHR6jSo3g1KltF3qyduPEFm6HdAzGW+WPGMJLK0sfjQelTS2DSvJecxE", + "BEsTMeflgcawYnlq6fT18KRy/qaQI3OUKyGIPLsF7SCgF2dgJXXGLJ3SWy6Y3tDKyAxFWmEXussvYDkO", + "saAZ0o8okzGkSx7XLNEAz4UTILO4C5KAhFl+B0ulZabsXhvvCjly6eW6TOWZPwOzVKC7DB4H9vKMYICG", + "XIJuWeXCQuLDq+zsdPdDMGwFdrOM1hB9qnm2OofK+RzFyDmKlWZupUyBCbQDEIce5+65C5yxGkRi1zVn", + "4+FPga+dROvkGjRQu6h8hQWMOLTqewlzx2OQzcduwqyUqcX0YwXnV2U6c7EGnqzrB37yOtB76793qT6G", + "VI8q/0xaLsXyNo8+gW0aOZ68Dq04SXKGkjVrQRxCcgNLlifLPYUxngQUcMLkNE/I/hr5F0r6nscN2Mfj", + "yavK05/4va3ZKOeeKt5fih1V/HaxuNwzb2KwjKfu1/81rOiU/m9UTa1RMbJG5UxpoizUA5iVrz1APrCU", + "x8wVQy8kbiEzfdia9rYVlp+9pRII05ptMIYQbdNAF25gqV2f72qojtdYZvM62+n732jYwlCga45X3K4c", + "dPhH8l6BUVIYaCPwU+PgjF1AzFmYJz9IuvLUqkgTnnUdVgdu76mFV5jVfcild+75UaTLdRrK/aHT3rUp", + "RxnjLSKiIDIPvCOiBXy2C4mBXzLNfLK/1a5UTYQDZsB3vsagWbECDZhaC40BezJuWN3JkjnK/udWo3KO", + "fOXgKIIKirldsx2F3duWUxnVOgwTm/crOr1+aOXqoQXxJmg2v8sI3bTazaB1+QFj9iwt/kUlipjJwr3t", + "472Lw7sqJINMHTAKPrhJu78VrzTLGq34K3tyIyfl1ugN9/Town0YUg1vKyCsyCjX3G7mDorH7sbiGTAN", + "ury4Yhn7V6WRtbWKbp0NLlbSs8JEmis83yk9FYQplXJ/4MRKonNBTmdEcQUpFz6eXV3wO1AA2n2/yoVA", + "R3egjbc1Hh4Pxy4hUoFgitMp/QFfDahido2wR2sco9iEAXntjgadz+JyylKXMp8P1JqMx+5PJIUFgVoB", + "6NFH49zvbu99xxjOcUxMPSHzPIrAmFWekvJI8AjyLHPregnRvRxhFz2y8qhc73d3jXpYyOyC4NTXAxjr", + "9sVGXFmeWq6YtiN3TziKmWWHh3boLWpbr0nXHrffMOP1HeTQnA/oq6c89XLn7fB/xmJy5Y8E/U4mT+q3", + "tf62EVQipFyRT54r/JmwoAVLyRz0HWhS3SN2fQdnSNhxrm+2NyEn/D9yFtJvCg1u4M2nlxvYBZ+LG/vv", + "Zs/MjXrvf+HG98wNX+HIDQuf7QFjI1gLv8iMfx58e/F8GQ4vBHhaArgaC2cD6jpjBlXr/sod8zyVeUzO", + "ZZblgtsNecMs3LMNLf6TgJutmY5GsQaWHSX+6zAt1IeRU3fXmr8DAAD//zD/cojkGQAA", } // GetSwagger returns the content of the embedded swagger specification file