feat(worker): add optimization flags (#61)
* feat(worker): add optimization flags

This commit enables users to pass optimization flags to the AI inference
pipelines. Due to startup delays, we currently only pass these flags to
containers that are started warm.

* fix: decrease I2V SFAST warmup 'decode_chunk_size'

This commit decreases the `decode_chunk_size` used in the warmup passes
when SFAST is enabled. This was done to prevent `CUDA out of memory` errors
from occurring.

* fix: fix SFAST container timeout

This commit ensures that a longer timeout is used in the Runner
Container context so that pipelines with the SFAST optimization enabled
have enough time to start up. It also renames the `StringBool` type to
the more descriptive `EnvValue`.
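
As a rough sketch of the new call shape (not part of this commit's diff), assuming the code lives in the `worker` package: the `SFAST` flag name comes from the commit message, while the helper function, pipeline name, and model ID placeholder are hypothetical.

```go
package worker

import "context"

// warmWithSFAST sketches the Warm signature added in this commit. The pipeline
// name and model ID below are placeholders, not values from the commit.
func warmWithSFAST(ctx context.Context, w *Worker) error {
	flags := OptimizationFlags{
		// "SFAST" is the optimization named in the commit message; the exact
		// key and value the runner expects are assumptions here.
		"SFAST": "true",
	}
	// An empty RunnerEndpoint URL routes Warm to the DockerManager, which starts
	// a managed (warm) container — currently the only path that receives flags.
	return w.Warm(ctx, "image-to-video", "<model-id>", RunnerEndpoint{}, flags)
}
```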
rickstaa authored Apr 16, 2024
1 parent 97469ea commit db751cd
Showing 4 changed files with 65 additions and 18 deletions.
2 changes: 1 addition & 1 deletion runner/app/pipelines/image_to_video.py
@@ -66,7 +66,7 @@ def __init__(self, model_id: str):
             "fps": 6,
             "motion_bucket_id": 127,
             "noise_aug_strength": 0.02,
-            "decode_chunk_size": 25,
+            "decode_chunk_size": 4,
         }

         logger.info("Warming up ImageToVideoPipeline pipeline...")
9 changes: 5 additions & 4 deletions worker/container.go
@@ -33,9 +33,10 @@ type RunnerContainerConfig struct {
 	Endpoint RunnerEndpoint

 	// For managed containers only
-	ID       string
-	GPU      string
-	KeepWarm bool
+	ID               string
+	GPU              string
+	KeepWarm         bool
+	containerTimeout time.Duration
 }

 func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*RunnerContainer, error) {
@@ -54,7 +55,7 @@ func NewRunnerContainer(ctx context.Context, cfg RunnerContainerConfig) (*RunnerContainer, error) {
 		return nil, err
 	}

-	cctx, cancel := context.WithTimeout(ctx, containerTimeout)
+	cctx, cancel := context.WithTimeout(ctx, cfg.containerTimeout)
 	if err := runnerWaitUntilReady(cctx, client, pollingInterval); err != nil {
 		cancel()
 		return nil, err
38 changes: 27 additions & 11 deletions worker/docker.go
@@ -21,6 +21,7 @@ const containerModelDir = "/models"
 const containerPort = "8000/tcp"
 const pollingInterval = 500 * time.Millisecond
 const containerTimeout = 2 * time.Minute
+const optFlagsContainerTimeout = 5 * time.Minute
 const containerRemoveTimeout = 30 * time.Second
 const containerCreatorLabel = "creator"
 const containerCreator = "ai-worker"
@@ -70,11 +71,11 @@ func NewDockerManager(containerImageID string, gpus []string, modelDir string) (
 	}, nil
 }

-func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string) error {
+func (m *DockerManager) Warm(ctx context.Context, pipeline string, modelID string, optimizationFlags OptimizationFlags) error {
 	m.mu.Lock()
 	defer m.mu.Unlock()

-	_, err := m.createContainer(ctx, pipeline, modelID, true)
+	_, err := m.createContainer(ctx, pipeline, modelID, true, optimizationFlags)
 	return err
 }

@@ -110,7 +111,8 @@ func (m *DockerManager) Borrow(ctx context.Context, pipeline, modelID string) (*
 	if !ok {
 		// The container does not exist so try to create it
 		var err error
-		rc, err = m.createContainer(ctx, pipeline, modelID, false)
+		// TODO: Optimization flags for dynamically loaded (borrowed) containers are not currently supported due to startup delays.
+		rc, err = m.createContainer(ctx, pipeline, modelID, false, map[string]EnvValue{})
 		if err != nil {
 			return nil, err
 		}
@@ -127,7 +129,7 @@ func (m *DockerManager) Return(rc *RunnerContainer) {
 	m.containers[dockerContainerName(rc.Pipeline, rc.ModelID)] = rc
 }

-func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool) (*RunnerContainer, error) {
+func (m *DockerManager) createContainer(ctx context.Context, pipeline string, modelID string, keepWarm bool, optimizationFlags OptimizationFlags) (*RunnerContainer, error) {
 	containerName := dockerContainerName(pipeline, modelID)

 	gpu, err := m.allocGPU(ctx)
@@ -137,12 +139,18 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo

 	slog.Info("Starting managed container", slog.String("gpu", gpu), slog.String("name", containerName), slog.String("modelID", modelID))

+	// Add optimization flags as environment variables.
+	envVars := []string{
+		"PIPELINE=" + pipeline,
+		"MODEL_ID=" + modelID,
+	}
+	for key, value := range optimizationFlags {
+		envVars = append(envVars, key+"="+value.String())
+	}
+
 	containerConfig := &container.Config{
 		Image: m.containerImageID,
-		Env: []string{
-			"PIPELINE=" + pipeline,
-			"MODEL_ID=" + modelID,
-		},
+		Env:     envVars,
 		Volumes: map[string]struct{}{
 			containerModelDir: {},
 		},
@@ -200,16 +208,24 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
 	}
 	cancel()

+	// Extend runner container timeout when optimization flags are used, as these
+	// pipelines may require more startup time.
+	runnerContainerTimeout := containerTimeout
+	if len(optimizationFlags) > 0 {
+		runnerContainerTimeout = optFlagsContainerTimeout
+	}
+
 	cfg := RunnerContainerConfig{
 		Type:     Managed,
 		Pipeline: pipeline,
 		ModelID:  modelID,
 		Endpoint: RunnerEndpoint{
 			URL: "http://localhost:" + containerHostPort,
 		},
-		ID:       resp.ID,
-		GPU:      gpu,
-		KeepWarm: keepWarm,
+		ID:               resp.ID,
+		GPU:              gpu,
+		KeepWarm:         keepWarm,
+		containerTimeout: runnerContainerTimeout,
 	}

 	rc, err := NewRunnerContainer(ctx, cfg)
34 changes: 32 additions & 2 deletions worker/worker.go
@@ -6,9 +6,39 @@ import (
 	"encoding/json"
 	"errors"
 	"log/slog"
+	"strconv"
 	"sync"
 )

+// EnvValue unmarshals JSON booleans as strings for compatibility with env variables.
+type EnvValue string
+
+// UnmarshalJSON converts JSON booleans to strings for EnvValue.
+func (sb *EnvValue) UnmarshalJSON(b []byte) error {
+	var boolVal bool
+	err := json.Unmarshal(b, &boolVal)
+	if err == nil {
+		*sb = EnvValue(strconv.FormatBool(boolVal))
+		return nil
+	}
+
+	var strVal string
+	err = json.Unmarshal(b, &strVal)
+	if err == nil {
+		*sb = EnvValue(strVal)
+	}
+
+	return err
+}
+
+// String returns the string representation of the EnvValue.
+func (sb EnvValue) String() string {
+	return string(sb)
+}
+
+// OptimizationFlags is a map of optimization flags to be passed to the pipeline.
+type OptimizationFlags map[string]EnvValue
+
 type Worker struct {
 	manager *DockerManager

@@ -167,9 +197,9 @@ func (w *Worker) ImageToVideo(ctx context.Context, req ImageToVideoMultipartRequ
 	return resp.JSON200, nil
 }

-func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint) error {
+func (w *Worker) Warm(ctx context.Context, pipeline string, modelID string, endpoint RunnerEndpoint, optimizationFlags OptimizationFlags) error {
 	if endpoint.URL == "" {
-		return w.manager.Warm(ctx, pipeline, modelID)
+		return w.manager.Warm(ctx, pipeline, modelID, optimizationFlags)
 	}

 	w.mu.Lock()
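
To illustrate the `EnvValue` behavior introduced above, here is a standalone sketch with local copies of the types from `worker/worker.go` (the demo itself is not part of the commit): a JSON boolean and a JSON string both normalize to plain strings, which `docker.go` then joins into `KEY=VALUE` environment variables. The flag keys in the JSON are illustrative only.

```go
package main

import (
	"encoding/json"
	"fmt"
	"strconv"
)

// Local copies of the types added in worker/worker.go, so this demo runs standalone.
type EnvValue string

func (sb *EnvValue) UnmarshalJSON(b []byte) error {
	var boolVal bool
	if err := json.Unmarshal(b, &boolVal); err == nil {
		*sb = EnvValue(strconv.FormatBool(boolVal))
		return nil
	}
	var strVal string
	err := json.Unmarshal(b, &strVal)
	if err == nil {
		*sb = EnvValue(strVal)
	}
	return err
}

type OptimizationFlags map[string]EnvValue

func main() {
	// "SFAST" is named in the commit message; the second key is purely illustrative.
	raw := []byte(`{"SFAST": true, "some_other_flag": "2"}`)

	var flags OptimizationFlags
	if err := json.Unmarshal(raw, &flags); err != nil {
		panic(err)
	}

	// Both the boolean and the string come out as strings, mirroring how
	// docker.go joins them into KEY=VALUE environment variables.
	for key, value := range flags {
		fmt.Println(key + "=" + string(value))
	}
}
```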
