Skip to content

Commit

Permalink
feat(worker): add num_inference_steps to I2I, I2V and Upscale pipelines
Browse files Browse the repository at this point in the history
This commit adds the `num_inference_steps` to the I2I, I2V and Upscale
pipelines in the worker codebase.
  • Loading branch information
rickstaa committed Jul 17, 2024
1 parent 97fca3a commit 5c5b198
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 30 deletions.
12 changes: 9 additions & 3 deletions runner/gen_openapi.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import argparse
import copy
import json
import os
import copy

import yaml
from app.main import app, use_route_names_as_operation_ids
from app.routes import health, image_to_image, image_to_video, text_to_image, upscale, audio_to_text

from app.routes import (
audio_to_text,
health,
image_to_image,
image_to_video,
text_to_image,
upscale,
)
from fastapi.openapi.utils import get_openapi

# Specify Endpoints for OpenAPI schema generation.
Expand Down
4 changes: 2 additions & 2 deletions runner/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -488,7 +488,7 @@
"num_inference_steps": {
"type": "integer",
"title": "Num Inference Steps",
"default": 25
"default": 100
},
"num_images_per_prompt": {
"type": "integer",
Expand Down Expand Up @@ -589,7 +589,7 @@
"num_inference_steps": {
"type": "integer",
"title": "Num Inference Steps",
"default": 50
"default": 75
}
},
"type": "object",
Expand Down
7 changes: 7 additions & 0 deletions worker/docker.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
},
ExposedPorts: nat.PortSet{
containerPort: struct{}{},
"5678/tcp": struct{}{},
},
Labels: map[string]string{
containerCreatorLabel: containerCreator,
Expand Down Expand Up @@ -205,6 +206,12 @@ func (m *DockerManager) createContainer(ctx context.Context, pipeline string, mo
HostPort: containerHostPort,
},
},
"5678/tcp": []nat.PortBinding{
{
HostIP: "0.0.0.0",
HostPort: "5678",
},
},
},
}

Expand Down
15 changes: 15 additions & 0 deletions worker/multipart.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func NewImageToImageMultipartWriter(w io.Writer, req ImageToImageMultipartReques
return nil, err
}
}
if req.NumInferenceSteps != nil {
if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil {
return nil, err
}
}

if err := mw.Close(); err != nil {
return nil, err
Expand Down Expand Up @@ -142,6 +147,11 @@ func NewImageToVideoMultipartWriter(w io.Writer, req ImageToVideoMultipartReques
return nil, err
}
}
if req.NumInferenceSteps != nil {
if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil {
return nil, err
}
}

if err := mw.Close(); err != nil {
return nil, err
Expand Down Expand Up @@ -187,6 +197,11 @@ func NewUpscaleMultipartWriter(w io.Writer, req UpscaleMultipartRequestBody) (*m
return nil, err
}
}
if req.NumInferenceSteps != nil {
if err := mw.WriteField("num_inference_steps", strconv.Itoa(*req.NumInferenceSteps)); err != nil {
return nil, err
}
}

if err := mw.Close(); err != nil {
return nil, err
Expand Down
51 changes: 26 additions & 25 deletions worker/runner.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5c5b198

Please sign in to comment.