Skip to content

Commit

Permalink
feat: add ability to mock pipelines
Browse files Browse the repository at this point in the history
This commit applies the mock.patch file to the main branch as explained
in
https://github.com/livepeer/ai-worker/tree/main/runner/dev#mocking-the-pipelines.
It not meant to be merged in the main branch but simply for ease of use.
  • Loading branch information
rickstaa committed Jul 4, 2024
1 parent ea24a70 commit 9ea95c8
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 0 deletions.
7 changes: 7 additions & 0 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
logger.info("Mocking ImageToImagePipeline for %s", model_id)
return

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
Expand Down Expand Up @@ -171,6 +175,9 @@ def __init__(self, model_id: str):
def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
return [PIL.Image.new("RGB", (256, 256), (0, 0, 255))], [True]

seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

Expand Down
12 changes: 12 additions & 0 deletions runner/app/pipelines/image_to_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
logger.info("Mocking ImageToVideoPipeline for %s", model_id)
return

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
Expand Down Expand Up @@ -109,6 +113,14 @@ def __init__(self, model_id: str):
def __call__(
self, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
return [
[
PIL.Image.new("RGB", (256, 256), (0, 0, 255)),
PIL.Image.new("RGB", (256, 256), (0, 0, 255)),
]
], [True]

seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

Expand Down
7 changes: 7 additions & 0 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
logger.info("Mocking TextToImagePipeline for %s", model_id)
return

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
Expand Down Expand Up @@ -167,6 +171,9 @@ def __init__(self, model_id: str):
def __call__(
self, prompt: str, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
return [PIL.Image.new("RGB", (256, 256), (0, 0, 255))], [True]

seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

Expand Down
7 changes: 7 additions & 0 deletions runner/app/pipelines/upscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def __init__(self, model_id: str):
self.model_id = model_id
kwargs = {"cache_dir": get_model_dir()}

if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
logger.info("Mocking UpscalePipeline for %s", model_id)
return

torch_device = get_torch_device()
folder_name = file_download.repo_folder_name(
repo_id=model_id, repo_type="model"
Expand Down Expand Up @@ -97,6 +101,9 @@ def __init__(self, model_id: str):
def __call__(
self, prompt: str, image: PIL.Image, **kwargs
) -> Tuple[List[PIL.Image], List[Optional[bool]]]:
if os.getenv("MOCK_PIPELINE", "").strip().lower() == "true":
return [PIL.Image.new("RGB", (256, 256), (0, 0, 255))], [True]

seed = kwargs.pop("seed", None)
safety_check = kwargs.pop("safety_check", True)

Expand Down
6 changes: 6 additions & 0 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"log/slog"
"os"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -168,6 +169,11 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
envVars = append(envVars, key+"="+value.String())
}

if value, ok := os.LookupEnv("MOCK_PIPELINE"); ok {
envVars = append(envVars, "MOCK_PIPELINE="+value)
slog.Info("MOCK_PIPELINE set to " + value + ", passing to runner container")
}

containerConfig := &container.Config{
Image: m.containerImageID,
Env: envVars,
Expand Down

0 comments on commit 9ea95c8

Please sign in to comment.