diff --git a/runner/app/pipelines/image_to_video.py b/runner/app/pipelines/image_to_video.py
index d2e8a907..4a936841 100644
--- a/runner/app/pipelines/image_to_video.py
+++ b/runner/app/pipelines/image_to_video.py
@@ -66,7 +66,7 @@ def __init__(self, model_id: str):
                 "fps": 6,
                 "motion_bucket_id": 127,
                 "noise_aug_strength": 0.02,
-                "decode_chunk_size": 25,
+                "decode_chunk_size": 4,
             }
 
             logger.info("Warming up ImageToVideoPipeline pipeline...")
diff --git a/worker/container.go b/worker/container.go
index 014030f2..16f93249 100644
--- a/worker/container.go
+++ b/worker/container.go
@@ -33,9 +33,10 @@ type RunnerContainerConfig struct {
 	Endpoint RunnerEndpoint
 
 	// For managed containers only
-	ID       string
-	GPU      string
-	KeepWarm bool
+	ID               string
+	GPU              string
+	KeepWarm         bool
+	containerTimeout time.Duration
 }
 
 func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*RunnerContainer, error) {
@@ -54,7 +55,14 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*Runner
 		return nil, err
 	}
 
-	cctx, cancel := context.WithTimeout(ctx, containerTimeout)
+	// A config that never sets containerTimeout carries the zero value, and a
+	// zero timeout yields a context that is already expired, so the readiness
+	// wait would fail immediately. Fall back to the package default instead.
+	timeout := cfg.containerTimeout
+	if timeout <= 0 {
+		timeout = containerTimeout
+	}
+	cctx, cancel := context.WithTimeout(ctx, timeout)
 	if err := runnerWaitUntilReady(cctx, client, pollingInterval); err != nil {
 		cancel()
 		return nil, err
diff --git a/worker/docker.go b/worker/docker.go
index e7dcca14..c82cf280 100644
--- a/worker/docker.go
+++ b/worker/docker.go
@@ -21,6 +21,7 @@ const containerModelDir = "/models"
 const containerPort = "8000/tcp"
 const pollingInterval = 500 * time.Millisecond
 const containerTimeout = 2 * time.Minute
+const optFlagsContainerTimeout = 5 * time.Minute
 const containerRemoveTimeout = 30 * time.Second
 const containerCreatorLabel = "creator"
 const containerCreator = "ai-worker"
@@ -70,11 +71,11 @@ func NewDockerManager(containerImageID string, gpus []string, modelDir string) (
 	}, nil
 }
 
-func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string) error {
+func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string, optimizationFlags OptimizationFlags) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()
 
-	_, err := m.createContainer(ctx, pipeline, modelID, true)
+	_, err := m.createContainer(ctx, pipeline, modelID, true, optimizationFlags)
 	return err
 }
 
@@ -110,7 +111,8 @@ func (m *DockerManager) Borrow(ctx context.Context, pipeline, modelID string) (*
 	if !ok {
 		// The container does not exist so try to create it
 		var err error
-		rc, err = m.createContainer(ctx, pipeline, modelID, false)
+		// TODO: Optimization flags for dynamically loaded (borrowed) containers are not currently supported due to startup delays.
+		rc, err = m.createContainer(ctx, pipeline, modelID, false, OptimizationFlags{})
 		if err != nil {
 			return nil, err
 		}
@@ -127,7 +129,7 @@ func (m *DockerManager) Return(rc *RunnerContainer) {
 	m.containers[dockerContainerName(rc.Pipeline, rc.ModelID)] = rc
 }
 
-func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool) (*RunnerContainer, error) {
+func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) {
 	containerName := dockerContainerName(pipeline, modelID)
 
 	gpu, err := m.allocGPU(ctx)
@@ -137,12 +139,18 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 
 	slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID))
 
+	// Add optimization flags as environment variables.
+	envVars := []string{
+		"PIPELINE=" + pipeline,
+		"MODEL_ID=" + modelID,
+	}
+	for key, value := range optimizationFlags {
+		envVars = append(envVars, key+"="+value.String())
+	}
+
 	containerConfig := &container.Config{
 		Image: m.containerImageID,
-		Env: []string{
-			"PIPELINE=" + pipeline,
-			"MODEL_ID=" + modelID,
-		},
+		Env:   envVars,
 		Volumes: map[string]struct{}{
 			containerModelDir: {},
 		},
@@ -200,6 +208,13 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 	}
 	cancel()
 
+	// Extend runner container timeout when optimization flags are used, as these
+	// pipelines may require more startup time.
+	runnerContainerTimeout := containerTimeout
+	if len(optimizationFlags) > 0 {
+		runnerContainerTimeout = optFlagsContainerTimeout
+	}
+
 	cfg := RunnerContainerConfig{
 		Type:     Managed,
 		Pipeline: pipeline,
@@ -207,9 +222,10 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 		Endpoint: RunnerEndpoint{
 			URL: "http://localhost:" + containerHostPort,
 		},
-		ID:       resp.ID,
-		GPU:      gpu,
-		KeepWarm: keepWarm,
+		ID:               resp.ID,
+		GPU:              gpu,
+		KeepWarm:         keepWarm,
+		containerTimeout: runnerContainerTimeout,
 	}
 
 	rc, err := NewRunnerContainer(ctx, cfg)
diff --git a/worker/worker.go b/worker/worker.go
index ad15910d..5a9b9f38 100644
--- a/worker/worker.go
+++ b/worker/worker.go
@@ -6,9 +6,39 @@ import (
 	"encoding/json"
 	"errors"
 	"log/slog"
+	"strconv"
 	"sync"
 )
 
+// EnvValue unmarshals JSON booleans as strings for compatibility with env variables.
+type EnvValue string
+
+// UnmarshalJSON converts JSON booleans to strings for EnvValue.
+func (sb *EnvValue) UnmarshalJSON(b []byte) error {
+	var boolVal bool
+	err := json.Unmarshal(b, &boolVal)
+	if err == nil {
+		*sb = EnvValue(strconv.FormatBool(boolVal))
+		return nil
+	}
+
+	var strVal string
+	err = json.Unmarshal(b, &strVal)
+	if err == nil {
+		*sb = EnvValue(strVal)
+	}
+
+	return err
+}
+
+// String returns the string representation of the EnvValue.
+func (sb EnvValue) String() string {
+	return string(sb)
+}
+
+// OptimizationFlags is a map of optimization flags to be passed to the pipeline.
+type OptimizationFlags map[string]EnvValue
+
 type Worker struct {
 	manager *DockerManager
 
@@ -167,9 +197,9 @@ func (w *Worker) ImageToVideo(ctx context.Context, req ImageToVideoMultipartRequ
 	return resp.JSON200, nil
 }
 
-func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint) error {
+func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error {
 	if endpoint.URL == "" {
-		return w.manager.Warm(ctx, pipeline, modelID)
+		return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)
 	}
 
 	w.mu.Lock()