diff --git a/api/pkg/model/models.go b/api/pkg/model/models.go index 05c85c3c7..d941a5a9f 100644 --- a/api/pkg/model/models.go +++ b/api/pkg/model/models.go @@ -122,7 +122,13 @@ func ProcessModelName( } } case types.SessionTypeImage: - return Model_Diffusers_SDTurbo, nil + if modelName == "" { + // default image model for image inference + return Model_Diffusers_SDTurbo, nil + } + // allow user-provided model name (e.g. assume API users + // know what they're doing). + return modelName, nil } // shouldn't get here @@ -158,6 +164,7 @@ const ( Model_Cog_SDXL string = "stabilityai/stable-diffusion-xl-base-1.0" Model_Diffusers_SD35 string = "stabilityai/stable-diffusion-3.5-medium" Model_Diffusers_SDTurbo string = "stabilityai/sd-turbo" + Model_Diffusers_FluxDev string = "black-forest-labs/FLUX.1-dev" // We only need constants for _some_ ollama models that are hardcoded in // various places (backward compat). Other ones can be added dynamically now. @@ -177,6 +184,20 @@ func GetDefaultDiffusersModels() ([]*DiffusersGenericImage, error) { Description: "Turbo model, from Stability AI", Hide: false, }, + { + Id: Model_Diffusers_SD35, + Name: "Stable Diffusion 3.5 Medium", + Memory: GB * 24, + Description: "Medium model, from Stability AI", + Hide: false, + }, + { + Id: Model_Diffusers_FluxDev, + Name: "Flux 1 Dev", + Memory: GB * 24, + Description: "Dev model, from Black Forest Labs", + Hide: false, + }, }, nil } diff --git a/api/pkg/model/types.go b/api/pkg/model/types.go index 17619f6b3..3516c26d1 100644 --- a/api/pkg/model/types.go +++ b/api/pkg/model/types.go @@ -44,6 +44,7 @@ type OpenAIModel struct { Name string `json:"name,omitempty"` Description string `json:"description,omitempty"` Hide bool `json:"hide,omitempty"` + Type string `json:"type,omitempty"` } // ModelsList is a list of models, including those that belong to the user or organization. 
diff --git a/api/pkg/openai/helix_openai_client.go b/api/pkg/openai/helix_openai_client.go index b0b9f8cea..2e3c696c9 100644 --- a/api/pkg/openai/helix_openai_client.go +++ b/api/pkg/openai/helix_openai_client.go @@ -41,6 +41,23 @@ func ListModels(ctx context.Context) ([]model.OpenAIModel, error) { Name: m.GetHumanReadableName(), Description: m.GetDescription(), Hide: m.GetHidden(), + Type: "text", + }) + } + + diffusersModels, err := model.GetDefaultDiffusersModels() + if err != nil { + return nil, fmt.Errorf("failed to get Diffusers models: %w", err) + } + for _, m := range diffusersModels { + HelixModels = append(HelixModels, model.OpenAIModel{ + ID: m.ModelName().String(), + Object: "model", + OwnedBy: "helix", + Name: m.GetHumanReadableName(), + Description: m.GetDescription(), + Hide: m.GetHidden(), + Type: "image", }) } diff --git a/api/pkg/runner/diffusers_model_instance.go b/api/pkg/runner/diffusers_model_instance.go index cb3c373ce..da729b703 100644 --- a/api/pkg/runner/diffusers_model_instance.go +++ b/api/pkg/runner/diffusers_model_instance.go @@ -322,6 +322,7 @@ func (i *DiffusersModelInstance) Start(ctx context.Context) error { cmd.Dir = "/workspace/helix/runner/helix-diffusers" cmd.Env = append(cmd.Env, + fmt.Sprintf("MODEL_ID=%s", i.initialSession.ModelName), // Add the HF_TOKEN environment variable which is required by the diffusers library fmt.Sprintf("HF_TOKEN=%s", os.Getenv("HF_TOKEN")), // Set python to be unbuffered so we get logs in real time diff --git a/frontend/src/components/app/AppSettings.tsx b/frontend/src/components/app/AppSettings.tsx index 460e16872..000401c97 100644 --- a/frontend/src/components/app/AppSettings.tsx +++ b/frontend/src/components/app/AppSettings.tsx @@ -82,6 +82,7 @@ const AppSettings: React.FC = ({ Model diff --git a/frontend/src/components/create/ModelPicker.tsx b/frontend/src/components/create/ModelPicker.tsx index f313eb32b..36188ef76 100644 --- 
a/frontend/src/components/create/ModelPicker.tsx +++ b/frontend/src/components/create/ModelPicker.tsx @@ -1,16 +1,18 @@ -import React, { FC, useState, useEffect, useContext } from 'react' +import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown' import Box from '@mui/material/Box' -import Typography from '@mui/material/Typography' import Menu from '@mui/material/Menu' import MenuItem from '@mui/material/MenuItem' -import KeyboardArrowDownIcon from '@mui/icons-material/KeyboardArrowDown' -import useLightTheme from '../../hooks/useLightTheme' +import Typography from '@mui/material/Typography' +import React, { FC, useContext, useEffect, useState } from 'react' import { AccountContext } from '../../contexts/account' +import useLightTheme from '../../hooks/useLightTheme' const ModelPicker: FC<{ + type: string, model: string, onSetModel: (model: string) => void, }> = ({ + type, model, onSetModel }) => { @@ -18,13 +20,6 @@ const ModelPicker: FC<{ const [modelMenuAnchorEl, setModelMenuAnchorEl] = useState() const { models } = useContext(AccountContext) - useEffect(() => { - // Set the first model as default if current model is not set or not in the list - if (models.length > 0 && (!model || model === '' || !models.some(m => m.id === model))) { - onSetModel(models[0].id); - } - }, [models, model, onSetModel]) - const handleOpenMenu = (event: React.MouseEvent) => { setModelMenuAnchorEl(event.currentTarget) } @@ -34,6 +29,15 @@ const ModelPicker: FC<{ } const modelData = models.find(m => m.id === model) || models[0]; + + const filteredModels = models.filter(m => m.type === type) + + useEffect(() => { + // Set the first model as default if current model is not set or not in the list + if (filteredModels.length > 0 && (!model || model === '' || !filteredModels.some(m => m.id === model))) { + onSetModel(filteredModels[0].id); + } + }, [filteredModels, model, onSetModel]) return ( <> @@ -73,7 +77,7 @@ const ModelPicker: FC<{ }} > { - models.map(model => ( + 
filteredModels.map(model => ( { - !(app || appRequested) && mode === SESSION_MODE_INFERENCE && type === SESSION_TYPE_TEXT && ( + !(app || appRequested) && mode === SESSION_MODE_INFERENCE && ( diff --git a/frontend/src/contexts/account.tsx b/frontend/src/contexts/account.tsx index 4b0218a0e..e56b58033 100644 --- a/frontend/src/contexts/account.tsx +++ b/frontend/src/contexts/account.tsx @@ -1,19 +1,18 @@ -import React, { FC, useEffect, createContext, useMemo, useState, useCallback } from 'react' import bluebird from 'bluebird' import Keycloak from 'keycloak-js' +import { createContext, FC, useCallback, useEffect, useMemo, useState } from 'react' import useApi from '../hooks/useApi' -import useSnackbar from '../hooks/useSnackbar' +import { extractErrorMessage } from '../hooks/useErrorCallback' import useLoading from '../hooks/useLoading' import useRouter from '../hooks/useRouter' -import { extractErrorMessage } from '../hooks/useErrorCallback' +import useSnackbar from '../hooks/useSnackbar' import { - IKeycloakUser, - ISession, IApiKey, - IServerConfig, - IUserConfig, IHelixModel, + IKeycloakUser, + IServerConfig, + IUserConfig } from '../types' const REALM = 'helix' @@ -225,7 +224,8 @@ export const useAccountContext = (): IAccountContext => { id: m.id, name: m.name || m.id, description: m.description || '', - hide: m.hide || false + hide: m.hide || false, + type: m.type || 'text', })); // Filter out hidden models diff --git a/frontend/src/types.ts b/frontend/src/types.ts index ae7193175..726783436 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -89,6 +89,7 @@ export interface IUserConfig { export interface IHelixModel { id: string; + type: string; name: string; description: string; hide?: boolean;