Skip to content

Commit

Permalink
feat(runner): add basic diffusers server
Browse files Browse the repository at this point in the history
  • Loading branch information
philwinder committed Nov 23, 2024
1 parent 00996bd commit c0f6434
Show file tree
Hide file tree
Showing 14 changed files with 2,483 additions and 8 deletions.
11 changes: 11 additions & 0 deletions Dockerfile.runner
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

ARG TAG=2024-11-21a-empty

FROM ghcr.io/astral-sh/uv:0.5.4 as uv

### BUILD

FROM golang:1.22 AS go-build-env
Expand Down Expand Up @@ -45,6 +47,15 @@ WORKDIR /workspace/helix
# Copy runner directory from the repo
COPY runner ./runner

# We need to set this environment variable so that uv knows where
# the virtual environment is to install packages
ENV UV_PROJECT_ENVIRONMENT=/workspace/helix/runner/helix-diffusers/venv

# Install the packages with uv using --mount=type=cache to cache the downloaded packages
RUN --mount=type=cache,target=/root/.cache/uv \
--mount=from=uv,source=/uv,target=/usr/bin/uv \
cd /workspace/helix/runner/helix-diffusers && uv sync --no-dev

# Copy the cog wrapper, cog and cog-sdxl is installed in the base image, this is just the cog server
COPY cog/helix_cog_wrapper.py /workspace/cog-sdxl/helix_cog_wrapper.py

Expand Down
68 changes: 68 additions & 0 deletions api/pkg/model/diffusers_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package model

import (
"context"
"fmt"
"os/exec"

"github.com/helixml/helix/api/pkg/types"
)

var _ Model = &DiffusersGenericImage{}

type DiffusersGenericImage struct {
Id string // e.g. "stabilityai/stable-diffusion-3.5-medium"
Name string // e.g. "Stable Diffusion 3.5 Medium"
Memory uint64
Description string
Hide bool
}

func (i *DiffusersGenericImage) GetMemoryRequirements(mode types.SessionMode) uint64 {
return i.Memory
}

func (i *DiffusersGenericImage) GetType() types.SessionType {
return types.SessionTypeImage
}

func (i *DiffusersGenericImage) GetID() string {
return i.Id
}

func (i *DiffusersGenericImage) ModelName() ModelName {
return NewModel(i.Id)
}

func (i *DiffusersGenericImage) GetTask(session *types.Session, fileManager ModelSessionFileManager) (*types.RunnerTask, error) {
task, err := getGenericTask(session)
if err != nil {
return nil, err
}

return task, nil
}

func (i *DiffusersGenericImage) GetCommand(ctx context.Context, sessionFilter types.SessionFilter, config types.RunnerProcessConfig) (*exec.Cmd, error) {
return nil, fmt.Errorf("not implemented 1")
}

func (i *DiffusersGenericImage) GetTextStreams(mode types.SessionMode, eventHandler WorkerEventHandler) (*TextStream, *TextStream, error) {
return nil, nil, fmt.Errorf("not implemented 2")
}

func (i *DiffusersGenericImage) PrepareFiles(session *types.Session, isInitialSession bool, fileManager ModelSessionFileManager) (*types.Session, error) {
return nil, fmt.Errorf("not implemented 3")
}

func (i *DiffusersGenericImage) GetDescription() string {
return i.Description
}

func (i *DiffusersGenericImage) GetHumanReadableName() string {
return i.Name
}

func (i *DiffusersGenericImage) GetHidden() bool {
return i.Hide
}
32 changes: 31 additions & 1 deletion api/pkg/model/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ func (m ModelName) InferenceRuntime() types.InferenceRuntime {
if m.String() == Model_Cog_SDXL {
return types.InferenceRuntimeCog
}
diffusersModels, err := GetDefaultDiffusersModels()
if err != nil {
return types.InferenceRuntimeAxolotl
}
for _, model := range diffusersModels {
if m.String() == model.Id {
return types.InferenceRuntimeDiffusers
}
}

// misnamed: axolotl runtime handles axolotl and cog/sd-scripts
return types.InferenceRuntimeAxolotl
}
Expand Down Expand Up @@ -112,7 +122,7 @@ func ProcessModelName(
}
}
case types.SessionTypeImage:
return Model_Cog_SDXL, nil
return Model_Diffusers_SD35, nil
}

// shouldn't get here
Expand All @@ -133,12 +143,20 @@ func GetModels() (map[string]Model, error) {
for _, model := range ollamaModels {
models[model.Id] = model
}
diffusersModels, err := GetDefaultDiffusersModels()
if err != nil {
return nil, err
}
for _, model := range diffusersModels {
models[model.Id] = model
}
return models, nil
}

const (
Model_Axolotl_Mistral7b string = "mistralai/Mistral-7B-Instruct-v0.1"
Model_Cog_SDXL string = "stabilityai/stable-diffusion-xl-base-1.0"
Model_Diffusers_SD35 string = "stabilityai/stable-diffusion-3.5-medium"

// We only need constants for _some_ ollama models that are hardcoded in
// various places (backward compat). Other ones can be added dynamically now.
Expand All @@ -149,6 +167,18 @@ const (
Model_Ollama_Phi3 string = "phi3:instruct"
)

func GetDefaultDiffusersModels() ([]*DiffusersGenericImage, error) {
return []*DiffusersGenericImage{
{
Id: Model_Diffusers_SD35,
Name: "Stable Diffusion 3.5 Medium",
Memory: GB * 21,
Description: "Medium model, from Stability AI",
Hide: false,
},
}, nil
}

// See also types/models.go for model name constants
func GetDefaultOllamaModels() ([]*OllamaGenericText, error) {
models := []*OllamaGenericText{
Expand Down
Loading

0 comments on commit c0f6434

Please sign in to comment.