From 9ea95c88a2a3ec0e9470b2234c1b5bcf27ca12fd Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Thu, 4 Jul 2024 13:49:18 +0200 Subject: [PATCH] feat: add ability to mock pipelines This commit applies the mock.patch file to the main branch as explained in https://github.com/livepeer/ai-worker/tree/main/runner/dev#mocking-the-pipelines. It not meant to be merged in the main branch but simply for ease of use. --- runner/app/pipelines/image_to_image.py | 7 +++++++ runner/app/pipelines/image_to_video.py | 12 ++++++++++++ runner/app/pipelines/text_to_image.py | 7 +++++++ runner/app/pipelines/upscale.py | 7 +++++++ worker/docker.go | 6 ++++++ 5 files changed, 39 insertions(+) diff --git a/runner/app/pipelines/image_to_image.py b/runner/app/pipelines/image_to_image.py index fda2dfbf..8256b64b 100644 --- a/runner/app/pipelines/image_to_image.py +++ b/runner/app/pipelines/image_to_image.py @@ -48,6 +48,10 @@ def __init__(self, model_id: str): self.model_id = model_id kwargs = {"cache_dir": get_model_dir()} + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + logger.info("Mocking ImageToImagePipeline for %s", model_id) + return + torch_device = get_torch_device() folder_name = file_download.repo_folder_name( repo_id=model_id, repo_type="model" @@ -171,6 +175,9 @@ def __init__(self, model_id: str): def __call__( self, prompt: str, image: PIL.Image, **kwargs ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + return [PIL.Image.new("RGB", (256, 256), (0, 0, 255))], [True] + seed = kwargs.pop("seed", None) safety_check = kwargs.pop("safety_check", True) diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py index c3a528ee..81dab053 100644 --- a/runner/app/pipelines/image_to_video.py +++ b/runner/app/pipelines/image_to_video.py @@ -24,6 +24,10 @@ def __init__(self, model_id: str): self.model_id = model_id kwargs = {"cache_dir": get_model_dir()} + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + logger.info("Mocking ImageToVideoPipeline for %s", model_id) + return + torch_device = get_torch_device() folder_name = file_download.repo_folder_name( repo_id=model_id, repo_type="model" @@ -109,6 +113,14 @@ def __init__(self, model_id: str): def __call__( self, image: PIL.Image, **kwargs ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + return [ + [ + PIL.Image.new("RGB", (256, 256), (0, 0, 255)), + PIL.Image.new("RGB", (256, 256), (0, 0, 255)), + ] + ], [True] + seed = kwargs.pop("seed", None) safety_check = kwargs.pop("safety_check", True) diff --git a/runner/app/pipelines/text_to_image.py b/runner/app/pipelines/text_to_image.py index 278c04e5..8a0c4b6b 100644 --- a/runner/app/pipelines/text_to_image.py +++ b/runner/app/pipelines/text_to_image.py @@ -32,6 +32,10 @@ def __init__(self, model_id: str): self.model_id = model_id kwargs = {"cache_dir": get_model_dir()} + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + logger.info("Mocking TextToImagePipeline for %s", model_id) + return + torch_device = get_torch_device() folder_name = file_download.repo_folder_name( repo_id=model_id, repo_type="model" @@ -167,6 +171,9 @@ def __init__(self, model_id: str): def __call__( self, prompt: str, **kwargs ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + return [PIL.Image.new("RGB", (256, 256), (0, 0, 255))], [True] + seed = kwargs.pop("seed", None) safety_check = kwargs.pop("safety_check", True) diff --git a/runner/app/pipelines/upscale.py b/runner/app/pipelines/upscale.py index bd2e4f2a..830d27c3 100644 --- a/runner/app/pipelines/upscale.py +++ b/runner/app/pipelines/upscale.py @@ -27,6 +27,10 @@ def __init__(self, model_id: str): self.model_id = model_id kwargs = {"cache_dir": get_model_dir()} + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + logger.info("Mocking UpscalePipeline for %s", model_id) + return + torch_device = get_torch_device() folder_name = file_download.repo_folder_name( repo_id=model_id, repo_type="model" @@ -97,6 +101,9 @@ def __init__(self, model_id: str): def __call__( self, prompt: str, image: PIL.Image, **kwargs ) -> Tuple[List[PIL.Image], List[Optional[bool]]]: + if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true": + return [PIL.Image.new("RGB", (256, 256), (0, 0, 255))], [True] + seed = kwargs.pop("seed", None) safety_check = kwargs.pop("safety_check", True) diff --git a/worker/docker.go b/worker/docker.go index 48131a71..c822efb8 100644 --- a/worker/docker.go +++ b/worker/docker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "log/slog" + "os" "strings" "sync" "time" @@ -168,6 +169,11 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo envVars = append(envVars, key+"="+value.String()) } + if value, ok := os.LookupEnv("MOCK_PIPELINE"); ok { + envVars = append(envVars, "MOCK_PIPELINE="+value) + slog.Info("MOCK_PIPELINE set to " + value + ", passing to runner container") + } + containerConfig := &container.Config{ Image: m.containerImageID, Env: envVars,