Skip to content

Commit

Permalink
centralized request middleware
Browse files Browse the repository at this point in the history
Signed-off-by: Dave Lee <[email protected]>
  • Loading branch information
dave-gray101 committed Oct 24, 2024
1 parent 835932e commit 85ed9ae
Show file tree
Hide file tree
Showing 44 changed files with 834 additions and 714 deletions.
6 changes: 3 additions & 3 deletions core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ type TokenUsage struct {
Completion int
}

func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model

var inferenceModel grpc.Backend
var err error

opts := ModelOptions(c, o, []model.Option{})
opts := ModelOptions(*c, o, []model.Option{})

if c.Backend != "" {
opts = append(opts, model.WithBackendString(c.Backend))
Expand Down Expand Up @@ -96,7 +96,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im

// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts := gRPCPredictOpts(*c, loader.ModelPath)
opts.Prompt = s
opts.Messages = protoMessages
opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate
Expand Down
4 changes: 2 additions & 2 deletions core/backend/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
model "github.com/mudler/LocalAI/pkg/model"
)

func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {

opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
opts := ModelOptions(backendConfig, appConfig, []model.Option{})
rerankModel, err := loader.BackendLoader(opts...)
if err != nil {
return nil, err
Expand Down
5 changes: 2 additions & 3 deletions core/backend/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
)

func SoundGeneration(
modelFile string,
text string,
duration *float32,
temperature *float32,
Expand All @@ -25,7 +24,7 @@ func SoundGeneration(
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {

opts := ModelOptions(backendConfig, appConfig, []model.Option{model.WithModel(modelFile)})
opts := ModelOptions(backendConfig, appConfig, []model.Option{})

soundGenModel, err := loader.BackendLoader(opts...)
if err != nil {
Expand All @@ -45,7 +44,7 @@ func SoundGeneration(

res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
Text: text,
Model: modelFile,
Model: backendConfig.Model,
Dst: filePath,
Sample: doSample,
Duration: duration,
Expand Down
6 changes: 1 addition & 5 deletions core/backend/tokenize.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@ import (

func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {

modelFile := backendConfig.Model

var inferenceModel grpc.Backend
var err error

opts := ModelOptions(backendConfig, appConfig, []model.Option{
model.WithModel(modelFile),
})
opts := ModelOptions(backendConfig, appConfig, []model.Option{})

if backendConfig.Backend == "" {
inferenceModel, err = loader.GreedyLoader(opts...)
Expand Down
41 changes: 16 additions & 25 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,23 @@ import (
)

func ModelTTS(
backend,
text,
modelFile,
voice,
language string,
loader *model.ModelLoader,
appConfig *config.ApplicationConfig,
backendConfig config.BackendConfig,
) (string, *proto.Result, error) {
bb := backend
if bb == "" {
bb = model.PiperBackend
}

opts := ModelOptions(config.BackendConfig{}, appConfig, []model.Option{
model.WithBackendString(bb),
model.WithModel(modelFile),
opts := ModelOptions(*&backendConfig, appConfig, []model.Option{
model.WithDefaultBackendString(model.PiperBackend),
})
ttsModel, err := loader.BackendLoader(opts...)
if err != nil {
return "", nil, err
}

if ttsModel == nil {
return "", nil, fmt.Errorf("could not load piper model")
return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
}

if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
Expand All @@ -48,22 +40,21 @@ func ModelTTS(
fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
filePath := filepath.Join(appConfig.AudioDir, fileName)

// If the model file is not empty, we pass it joined with the model path
// We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
// This should be addressed in a follow up PR soon.
// Copying it over nearly verbatim, as TTS backends are not functional without this.
modelPath := ""
if modelFile != "" {
// If the model file is not empty, we pass it joined with the model path
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, modelFile)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = modelFile
// Checking first that it exists and is not outside ModelPath
// TODO: we should actually first check if the modelFile is looking like
// a FS path
mp := filepath.Join(loader.ModelPath, backendConfig.Model)
if _, err := os.Stat(mp); err == nil {
if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
return "", nil, err
}
modelPath = mp
} else {
modelPath = backendConfig.Model // skip this step if it fails?????
}

res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{
Expand Down
3 changes: 2 additions & 1 deletion core/cli/soundgeneration.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model

var inputFile *string
if t.InputFile != "" {
inputFile = &t.InputFile
}

filePath, _, err := backend.SoundGeneration(t.Model, text,
filePath, _, err := backend.SoundGeneration(text,
parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

Expand Down
4 changes: 3 additions & 1 deletion core/cli/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {

options := config.BackendConfig{}
options.SetDefaults()
options.Backend = t.Backend
options.Model = t.Model

filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
if err != nil {
return err
}
Expand Down
31 changes: 20 additions & 11 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,19 +432,20 @@ func (c *BackendConfig) HasTemplate() bool {
type BackendConfigUsecases int

const (
FLAG_ANY BackendConfigUsecases = 0b000000000
FLAG_CHAT BackendConfigUsecases = 0b000000001
FLAG_COMPLETION BackendConfigUsecases = 0b000000010
FLAG_EDIT BackendConfigUsecases = 0b000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b000001000
FLAG_RERANK BackendConfigUsecases = 0b000010000
FLAG_IMAGE BackendConfigUsecases = 0b000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b001000000
FLAG_TTS BackendConfigUsecases = 0b010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
FLAG_ANY BackendConfigUsecases = 0b0000000000
FLAG_CHAT BackendConfigUsecases = 0b0000000001
FLAG_COMPLETION BackendConfigUsecases = 0b0000000010
FLAG_EDIT BackendConfigUsecases = 0b0000000100
FLAG_EMBEDDINGS BackendConfigUsecases = 0b0000001000
FLAG_RERANK BackendConfigUsecases = 0b0000010000
FLAG_IMAGE BackendConfigUsecases = 0b0000100000
FLAG_TRANSCRIPT BackendConfigUsecases = 0b0001000000
FLAG_TTS BackendConfigUsecases = 0b0010000000
FLAG_SOUND_GENERATION BackendConfigUsecases = 0b0100000000
FLAG_TOKENIZE BackendConfigUsecases = 0b1000000000

// Common Subsets
FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)

func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
Expand All @@ -459,6 +460,7 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
"FLAG_TRANSCRIPT": FLAG_TRANSCRIPT,
"FLAG_TTS": FLAG_TTS,
"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
"FLAG_TOKENIZE": FLAG_TOKENIZE,
"FLAG_LLM": FLAG_LLM,
}
}
Expand Down Expand Up @@ -544,5 +546,12 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
}
}

if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
if !slices.Contains(tokenizeCapableBackends, c.Backend) {
return false
}
}

return true
}
30 changes: 21 additions & 9 deletions core/config/backend_config_loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ func readMultipleBackendConfigsFromFile(file string, opts ...ConfigLoaderOption)
c := &[]*BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
return nil, fmt.Errorf("readMultipleBackendConfigsFromFile cannot unmarshal config file %q: %w", file, err)
}

for _, cc := range *c {
Expand All @@ -101,10 +101,10 @@ func readBackendConfigFromFile(file string, opts ...ConfigLoaderOption) (*Backen
c := &BackendConfig{}
f, err := os.ReadFile(file)
if err != nil {
return nil, fmt.Errorf("cannot read config file: %w", err)
return nil, fmt.Errorf("readBackendConfigFromFile cannot read config file %q: %w", file, err)
}
if err := yaml.Unmarshal(f, c); err != nil {
return nil, fmt.Errorf("cannot unmarshal config file: %w", err)
return nil, fmt.Errorf("readBackendConfigFromFile cannot unmarshal config file %q: %w", file, err)
}

c.SetDefaults(opts...)
Expand All @@ -117,7 +117,9 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
// Load a config file if present after the model name
cfg := &BackendConfig{
PredictionOptions: schema.PredictionOptions{
Model: modelName,
BasicModelRequest: schema.BasicModelRequest{
Model: modelName,
},
},
}

Expand Down Expand Up @@ -145,6 +147,15 @@ func (bcl *BackendConfigLoader) LoadBackendConfigFileByName(modelName, modelPath
return cfg, nil
}

func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
LoadOptionDebug(appConfig.Debug),
LoadOptionThreads(appConfig.Threads),
LoadOptionContextSize(appConfig.ContextSize),
LoadOptionF16(appConfig.F16),
ModelPath(appConfig.ModelPath))
}

// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile
func (bcl *BackendConfigLoader) LoadMultipleBackendConfigsSingleFile(file string, opts ...ConfigLoaderOption) error {
bcl.Lock()
Expand All @@ -167,7 +178,7 @@ func (bcl *BackendConfigLoader) LoadBackendConfig(file string, opts ...ConfigLoa
defer bcl.Unlock()
c, err := readBackendConfigFromFile(file, opts...)
if err != nil {
return fmt.Errorf("cannot read config file: %w", err)
return fmt.Errorf("LoadBackendConfig cannot read config file %q: %w", file, err)
}

if c.Validate() {
Expand Down Expand Up @@ -324,9 +335,10 @@ func (bcl *BackendConfigLoader) Preload(modelPath string) error {
func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...ConfigLoaderOption) error {
bcl.Lock()
defer bcl.Unlock()

entries, err := os.ReadDir(path)
if err != nil {
return fmt.Errorf("cannot read directory '%s': %w", path, err)
return fmt.Errorf("LoadBackendConfigsFromPath cannot read directory '%s': %w", path, err)
}
files := make([]fs.FileInfo, 0, len(entries))
for _, entry := range entries {
Expand All @@ -344,13 +356,13 @@ func (bcl *BackendConfigLoader) LoadBackendConfigsFromPath(path string, opts ...
}
c, err := readBackendConfigFromFile(filepath.Join(path, file.Name()), opts...)
if err != nil {
log.Error().Err(err).Msgf("cannot read config file: %s", file.Name())
log.Error().Err(err).Str("File Name", file.Name()).Msgf("LoadBackendConfigsFromPath cannot read config file")
continue
}
if c.Validate() {
bcl.configs[c.Name] = *c
} else {
log.Error().Err(err).Msgf("config is not valid")
log.Error().Err(err).Str("Name", c.Name).Msgf("config is not valid")
}
}

Expand Down
9 changes: 5 additions & 4 deletions core/config/guesser.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@ const (
type settingsConfig struct {
StopWords []string
TemplateConfig TemplateConfig
RepeatPenalty float64
RepeatPenalty float64
}

// default settings to adopt with a given model family
var defaultsSettings map[familyType]settingsConfig = map[familyType]settingsConfig{
Gemma: {
RepeatPenalty: 1.0,
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
StopWords: []string{"<|im_end|>", "<end_of_turn>", "<start_of_turn>"},
TemplateConfig: TemplateConfig{
Chat: "{{.Input }}\n<start_of_turn>model\n",
ChatMessage: "<start_of_turn>{{if eq .RoleName \"assistant\" }}model{{else}}{{ .RoleName }}{{end}}\n{{ if .Content -}}\n{{.Content -}}\n{{ end -}}<end_of_turn>",
Expand Down Expand Up @@ -161,10 +161,11 @@ func guessDefaultsFromFile(cfg *BackendConfig, modelPath string) {
}

// We try to guess only if we don't have a template defined already
f, err := gguf.ParseGGUFFile(filepath.Join(modelPath, cfg.ModelFileName()))
guessPath := filepath.Join(modelPath, cfg.ModelFileName())
f, err := gguf.ParseGGUFFile(guessPath)
if err != nil {
// Only valid for gguf files
log.Debug().Msgf("guessDefaultsFromFile: %s", "not a GGUF file")
log.Debug().Str("filePath", guessPath).Msg("guessDefaultsFromFile: not a GGUF file")
return
}

Expand Down
11 changes: 6 additions & 5 deletions core/http/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
return metricsService.Shutdown()
})
}

}
// Health Checks should always be exempt from auth, so register these first
routes.HealthRoutes(app)
Expand Down Expand Up @@ -158,13 +157,15 @@ func App(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *confi
galleryService := services.NewGalleryService(appConfig)
galleryService.Start(appConfig.Context, cl)

routes.RegisterElevenLabsRoutes(app, cl, ml, appConfig)
routes.RegisterLocalAIRoutes(app, cl, ml, appConfig, galleryService)
routes.RegisterOpenAIRoutes(app, cl, ml, appConfig)
requestExtractor := middleware.NewRequestExtractor(cl, ml, appConfig)

routes.RegisterElevenLabsRoutes(app, requestExtractor, cl, ml, appConfig)
routes.RegisterLocalAIRoutes(app, requestExtractor, cl, ml, appConfig, galleryService)
routes.RegisterOpenAIRoutes(app, requestExtractor, cl, ml, appConfig)
if !appConfig.DisableWebUI {
routes.RegisterUIRoutes(app, cl, ml, appConfig, galleryService)
}
routes.RegisterJINARoutes(app, cl, ml, appConfig)
routes.RegisterJINARoutes(app, requestExtractor, cl, ml, appConfig)

httpFS := http.FS(embedDirStatic)

Expand Down
Loading

0 comments on commit 85ed9ae

Please sign in to comment.