-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: add pipeline mock development guide
This commit adds all the components needed to be able to mock the pipelines.
- Loading branch information
Showing
3 changed files
with
57 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
commit bd77a8bc1ec9afbc86682beff5b71ef13703753c | ||
Author: Rick Staa <[email protected]> | ||
Date: Wed Jun 26 12:25:37 2024 +0100 | ||
|
||
refactor: improve code formating | ||
|
||
This commit improves the I2I code formatting. | ||
|
||
diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py | ||
index ac2334a..278c04e 100644 | ||
--- a/runner/app/pipelines/text_to_image.py | ||
+++ b/runner/app/pipelines/text_to_image.py | ||
@@ -14,7 +14,13 @@ from huggingface_hub import file_download, hf_hub_download | ||
from safetensors.torch import load_file | ||
|
||
from app.pipelines.base import Pipeline | ||
-from app.pipelines.util import get_model_dir, get_torch_device, SafetyChecker, is_lightning_model, is_turbo_model | ||
+from app.pipelines.util import ( | ||
+ get_model_dir, | ||
+ get_torch_device, | ||
+ SafetyChecker, | ||
+ is_lightning_model, | ||
+ is_turbo_model, | ||
+) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
@@ -52,7 +58,6 @@ class TextToImagePipeline(Pipeline): | ||
logger.info("TextToImagePipeline using bfloat16 precision for %s", model_id) | ||
kwargs["torch_dtype"] = torch.bfloat16 | ||
|
||
- | ||
# Special case SDXL-Lightning because the unet for SDXL needs to be swapped | ||
if SDXL_LIGHTNING_MODEL_ID in model_id: | ||
base = "stabilityai/stable-diffusion-xl-base-1.0" |