Skip to content

Commit

Permalink
feat(ui): add image model selection to UI
Browse files Browse the repository at this point in the history
  • Loading branch information
philwinder committed Nov 27, 2024
1 parent babf62e commit 8b2c3c4
Show file tree
Hide file tree
Showing 9 changed files with 77 additions and 31 deletions.
23 changes: 22 additions & 1 deletion api/pkg/model/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,13 @@ func ProcessModelName(
}
}
case types.SessionTypeImage:
return Model_Diffusers_SDTurbo, nil
if modelName == "" {
// default image model for image inference
return Model_Diffusers_SDTurbo, nil
}
// allow user-provided model name (e.g. assume API users
// know what they're doing).
return modelName, nil
}

// shouldn't get here
Expand Down Expand Up @@ -158,6 +164,7 @@ const (
Model_Cog_SDXL string = "stabilityai/stable-diffusion-xl-base-1.0"
Model_Diffusers_SD35 string = "stabilityai/stable-diffusion-3.5-medium"
Model_Diffusers_SDTurbo string = "stabilityai/sd-turbo"
Model_Diffusers_FluxDev string = "black-forest-labs/FLUX.1-dev"

// We only need constants for _some_ ollama models that are hardcoded in
// various places (backward compat). Other ones can be added dynamically now.
Expand All @@ -177,6 +184,20 @@ func GetDefaultDiffusersModels() ([]*DiffusersGenericImage, error) {
Description: "Turbo model, from Stability AI",
Hide: false,
},
{
Id: Model_Diffusers_SD35,
Name: "Stable Diffusion 3.5 Medium",
Memory: GB * 24,
Description: "Medium model, from Stability AI",
Hide: false,
},
{
Id: Model_Diffusers_FluxDev,
Name: "Flux 1 Dev",
Memory: GB * 24,
Description: "Dev model, from Black Forest Labs",
Hide: false,
},
}, nil
}

Expand Down
1 change: 1 addition & 0 deletions api/pkg/model/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type OpenAIModel struct {
Name string `json:"name,omitempty"`
Description string `json:"description,omitempty"`
Hide bool `json:"hide,omitempty"`
Type string `json:"type,omitempty"`
}

// ModelsList is a list of models, including those that belong to the user or organization.
Expand Down
17 changes: 17 additions & 0 deletions api/pkg/openai/helix_openai_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ func ListModels(ctx context.Context) ([]model.OpenAIModel, error) {
Name: m.GetHumanReadableName(),
Description: m.GetDescription(),
Hide: m.GetHidden(),
Type: "text",
})
}

diffusersModels, err := model.GetDefaultDiffusersModels()
if err != nil {
return nil, fmt.Errorf("failed to get Diffusers models: %w", err)
}
for _, m := range diffusersModels {
HelixModels = append(HelixModels, model.OpenAIModel{
ID: m.ModelName().String(),
Object: "model",
OwnedBy: "helix",
Name: m.GetHumanReadableName(),
Description: m.GetDescription(),
Hide: m.GetHidden(),
Type: "image",
})
}

Expand Down
1 change: 1 addition & 0 deletions api/pkg/runner/diffusers_model_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ func (i *DiffusersModelInstance) Start(ctx context.Context) error {
cmd.Dir = "/workspace/helix/runner/helix-diffusers"

cmd.Env = append(cmd.Env,
fmt.Sprintf("MODEL_ID=%s", i.initialSession.ModelName),
// Add the HF_TOKEN environment variable which is required by the diffusers library
fmt.Sprintf("HF_TOKEN=hf_ISxQhTIkdWkfZgUFPNUwVtHrCpMiwOYPIEKEN=%s", os.Getenv("HF_TOKEN")),
// Set python to be unbuffered so we get logs in real time
Expand Down
1 change: 1 addition & 0 deletions frontend/src/components/app/AppSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ const AppSettings: React.FC<AppSettingsProps> = ({
<Box sx={{ mb: 3 }}>
<Typography variant="subtitle1" sx={{ mb: 1 }}>Model</Typography>
<ModelPicker
type="text"
model={model}
onSetModel={setModel}
/>
Expand Down
28 changes: 16 additions & 12 deletions frontend/src/components/create/ModelPicker.tsx
Original file line number Diff line number Diff line change
@@ -1,30 +1,25 @@
import React, { FC, useState, useEffect, useContext } from 'react'
import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown'
import Box from '@mui/material/Box'
import Typography from '@mui/material/Typography'
import Menu from '@mui/material/Menu'
import MenuItem from '@mui/material/MenuItem'
import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown'
import useLightTheme from '../../hooks/useLightTheme'
import Typography from '@mui/material/Typography'
import React, { FC, useContext, useEffect, useState } from 'react'
import { AccountContext } from '../../contexts/account'
import useLightTheme from '../../hooks/useLightTheme'

const ModelPicker: FC<{
type: string,
model: string,
onSetModel: (model: string) => void,
}> = ({
type,
model,
onSetModel
}) => {
const lightTheme = useLightTheme()
const [modelMenuAnchorEl, setModelMenuAnchorEl] = useState<HTMLElement>()
const { models } = useContext(AccountContext)

useEffect(() => {
// Set the first model as default if current model is not set or not in the list
if (models.length > 0 && (!model || model === '' || !models.some(m => m.id === model))) {
onSetModel(models[0].id);
}
}, [models, model, onSetModel])

const handleOpenMenu = (event: React.MouseEvent<HTMLElement>) => {
setModelMenuAnchorEl(event.currentTarget)
}
Expand All @@ -34,6 +29,15 @@ const ModelPicker: FC<{
}

const modelData = models.find(m => m.id === model) || models[0];

const filteredModels = models.filter(m => m.type === type)

useEffect(() => {
// Set the first model as default if current model is not set or not in the list
if (filteredModels.length > 0 && (!model || model === '' || !filteredModels.some(m => m.id === model))) {
onSetModel(filteredModels[0].id);
}
}, [filteredModels, model, onSetModel])

return (
<>
Expand Down Expand Up @@ -73,7 +77,7 @@ const ModelPicker: FC<{
}}
>
{
models.map(model => (
filteredModels.map(model => (
<MenuItem
key={ model.id }
sx={{fontSize: "large"}}
Expand Down
20 changes: 10 additions & 10 deletions frontend/src/components/create/Toolbar.tsx
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
import React, { FC } from 'react'
import Button from '@mui/material/Button'
import IconButton from '@mui/material/IconButton'
import ConstructionIcon from '@mui/icons-material/Construction'
import LoginIcon from '@mui/icons-material/Login'
import Button from '@mui/material/Button'
import IconButton from '@mui/material/IconButton'
import { FC } from 'react'

import Cell from '../widgets/Cell'
import Row from '../widgets/Row'
import SessionModeSwitch from './SessionModeSwitch'
import SessionModeDropdown from './SessionModeDropdown'
import ModelPicker from './ModelPicker'
import SessionModeDropdown from './SessionModeDropdown'
import SessionModeSwitch from './SessionModeSwitch'

import useIsBigScreen from '../../hooks/useIsBigScreen'
import useAccount from '../../hooks/useAccount'
import useIsBigScreen from '../../hooks/useIsBigScreen'

import {
IApp,
ISessionMode,
ISessionType,
IApp,
SESSION_MODE_INFERENCE,
SESSION_TYPE_TEXT,
SESSION_MODE_INFERENCE
} from '../../types'

const CreateToolbar: FC<{
Expand All @@ -45,8 +44,9 @@ const CreateToolbar: FC<{
<Row>
<Cell>
{
!(app || appRequested) && mode === SESSION_MODE_INFERENCE && type === SESSION_TYPE_TEXT && (
!(app || appRequested) && mode === SESSION_MODE_INFERENCE && (
<ModelPicker
type={type}
model={model || ''}
onSetModel={onSetModel}
/>
Expand Down
16 changes: 8 additions & 8 deletions frontend/src/contexts/account.tsx
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import React, { FC, useEffect, createContext, useMemo, useState, useCallback } from 'react'
import bluebird from 'bluebird'
import Keycloak from 'keycloak-js'
import { createContext, FC, useCallback, useEffect, useMemo, useState } from 'react'
import useApi from '../hooks/useApi'
import useSnackbar from '../hooks/useSnackbar'
import { extractErrorMessage } from '../hooks/useErrorCallback'
import useLoading from '../hooks/useLoading'
import useRouter from '../hooks/useRouter'
import { extractErrorMessage } from '../hooks/useErrorCallback'
import useSnackbar from '../hooks/useSnackbar'

import {
IKeycloakUser,
ISession,
IApiKey,
IServerConfig,
IUserConfig,
IHelixModel,
IKeycloakUser,
IServerConfig,
IUserConfig
} from '../types'

const REALM = 'helix'
Expand Down Expand Up @@ -225,7 +224,8 @@ export const useAccountContext = (): IAccountContext => {
id: m.id,
name: m.name || m.id,
description: m.description || '',
hide: m.hide || false
hide: m.hide || false,
type: m.type || 'text',
}));

// Filter out hidden models
Expand Down
1 change: 1 addition & 0 deletions frontend/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ export interface IUserConfig {

export interface IHelixModel {
id: string;
type: string;
name: string;
description: string;
hide?: boolean;
Expand Down

0 comments on commit 8b2c3c4

Please sign in to comment.