Usage Features (#863)
dave-gray101 authored Aug 18, 2023
1 parent 2bacd01 commit 8cb1061
Showing 40 changed files with 1,221 additions and 316 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -22,6 +22,8 @@ LocalAI
local-ai
# prevent above rules from omitting the helm chart
!charts/*
# prevent above rules from omitting the api/localai folder
!api/localai

# Ignore models
models/*
2 changes: 1 addition & 1 deletion Makefile
@@ -4,7 +4,7 @@ GOVET=$(GOCMD) vet
BINARY_NAME=local-ai

# llama.cpp versions
GOLLAMA_VERSION?=50cee7712066d9e38306eccadcfbb44ea87df4b7
GOLLAMA_VERSION?=f03869d188b72c8a617bea3a36cf8eb43f73445c

# gpt4all version
GPT4ALL_REPO?=https://github.com/nomic-ai/gpt4all
149 changes: 82 additions & 67 deletions api/api.go
@@ -2,6 +2,7 @@ package api

import (
"errors"
"fmt"
"strings"

config "github.com/go-skynet/LocalAI/api/config"
@@ -19,14 +20,73 @@ import (
"github.com/rs/zerolog/log"
)

func App(opts ...options.AppOption) (*fiber.App, error) {
func Startup(opts ...options.AppOption) (*options.Option, *config.ConfigLoader, error) {
options := options.NewOptions(opts...)

zerolog.SetGlobalLevel(zerolog.InfoLevel)
if options.Debug {
zerolog.SetGlobalLevel(zerolog.DebugLevel)
}

log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

cl := config.NewConfigLoader()
if err := cl.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}

if options.ConfigFile != "" {
if err := cl.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}

if options.Debug {
for _, v := range cl.ListConfigs() {
cfg, _ := cl.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}

if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}

if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cl, options.Galleries); err != nil {
return nil, nil, err
}
}

if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cl, options.Galleries); err != nil {
return nil, nil, err
}
}

// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopAllGRPC()
}()

return options, cl, nil
}

func App(opts ...options.AppOption) (*fiber.App, error) {

options, cl, err := Startup(opts...)
if err != nil {
return nil, fmt.Errorf("failed basic startup tasks with error %s", err.Error())
}

// Return errors as JSON responses
app := fiber.New(fiber.Config{
BodyLimit: options.UploadLimitMB * 1024 * 1024, // this is the default limit of 4MB
@@ -57,36 +117,6 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
}))
}

log.Info().Msgf("Starting LocalAI using %d threads, with models path: %s", options.Threads, options.Loader.ModelPath)
log.Info().Msgf("LocalAI version: %s", internal.PrintableVersion())

cm := config.NewConfigLoader()
if err := cm.LoadConfigs(options.Loader.ModelPath); err != nil {
log.Error().Msgf("error loading config files: %s", err.Error())
}

if options.ConfigFile != "" {
if err := cm.LoadConfigFile(options.ConfigFile); err != nil {
log.Error().Msgf("error loading config file: %s", err.Error())
}
}

if options.Debug {
for _, v := range cm.ListConfigs() {
cfg, _ := cm.GetConfig(v)
log.Debug().Msgf("Model: %s (config: %+v)", v, cfg)
}
}

if options.AssetsDestination != "" {
// Extract files from the embedded FS
err := assets.ExtractFiles(options.BackendAssets, options.AssetsDestination)
log.Debug().Msgf("Extracting backend assets files to %s", options.AssetsDestination)
if err != nil {
log.Warn().Msgf("Failed extracting backend assets files: %s (might be required for some backends to work properly, like gpt4all)", err)
}
}

// Default middleware config
app.Use(recover.New())

@@ -116,18 +146,6 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
return c.Next()
}

if options.PreloadJSONModels != "" {
if err := localai.ApplyGalleryFromString(options.Loader.ModelPath, options.PreloadJSONModels, cm, options.Galleries); err != nil {
return nil, err
}
}

if options.PreloadModelsFromPath != "" {
if err := localai.ApplyGalleryFromFile(options.Loader.ModelPath, options.PreloadModelsFromPath, cm, options.Galleries); err != nil {
return nil, err
}
}

if options.CORS {
var c func(ctx *fiber.Ctx) error
if options.CORSAllowOrigins == "" {
@@ -141,44 +159,44 @@ func App(opts ...options.AppOption) (*fiber.App, error) {

// LocalAI API endpoints
galleryService := localai.NewGalleryService(options.Loader.ModelPath)
galleryService.Start(options.Context, cm)
galleryService.Start(options.Context, cl)

app.Get("/version", auth, func(c *fiber.Ctx) error {
return c.JSON(struct {
Version string `json:"version"`
}{Version: internal.PrintableVersion()})
})

app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cm, galleryService.C, options.Galleries))
app.Post("/models/apply", auth, localai.ApplyModelGalleryEndpoint(options.Loader.ModelPath, cl, galleryService.C, options.Galleries))
app.Get("/models/available", auth, localai.ListModelFromGalleryEndpoint(options.Galleries, options.Loader.ModelPath))
app.Get("/models/jobs/:uuid", auth, localai.GetOpStatusEndpoint(galleryService))

// openAI compatible API endpoint

// chat
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cm, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cm, options))
app.Post("/v1/chat/completions", auth, openai.ChatEndpoint(cl, options))
app.Post("/chat/completions", auth, openai.ChatEndpoint(cl, options))

// edit
app.Post("/v1/edits", auth, openai.EditEndpoint(cm, options))
app.Post("/edits", auth, openai.EditEndpoint(cm, options))
app.Post("/v1/edits", auth, openai.EditEndpoint(cl, options))
app.Post("/edits", auth, openai.EditEndpoint(cl, options))

// completion
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cm, options))
app.Post("/v1/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/completions", auth, openai.CompletionEndpoint(cl, options))
app.Post("/v1/engines/:model/completions", auth, openai.CompletionEndpoint(cl, options))

// embeddings
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cm, options))
app.Post("/v1/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))
app.Post("/v1/engines/:model/embeddings", auth, openai.EmbeddingsEndpoint(cl, options))

// audio
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cm, options))
app.Post("/tts", auth, localai.TTSEndpoint(cm, options))
app.Post("/v1/audio/transcriptions", auth, openai.TranscriptEndpoint(cl, options))
app.Post("/tts", auth, localai.TTSEndpoint(cl, options))

// images
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cm, options))
app.Post("/v1/images/generations", auth, openai.ImageEndpoint(cl, options))

if options.ImageDir != "" {
app.Static("/generated-images", options.ImageDir)
@@ -196,16 +214,13 @@ func App(opts ...options.AppOption) (*fiber.App, error) {
app.Get("/healthz", ok)
app.Get("/readyz", ok)

// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cm))
// Experimental Backend Statistics Module
backendMonitor := localai.NewBackendMonitor(cl, options) // Split out for now
app.Get("/backend/monitor", localai.BackendMonitorEndpoint(backendMonitor))

// turn off any process that was started by GRPC if the context is canceled
go func() {
<-options.Context.Done()
log.Debug().Msgf("Context canceled, shutting down")
options.Loader.StopGRPC()
}()
// models
app.Get("/v1/models", auth, openai.ListModelsEndpoint(options.Loader, cl))
app.Get("/models", auth, openai.ListModelsEndpoint(options.Loader, cl))

return app, nil
}
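Note on the api/api.go change: option parsing, config loading, backend asset extraction, model gallery preloading and the context-cancellation gRPC shutdown are split out of App into the new Startup function, and the config loader variable is renamed from cm to cl. Below is a minimal sketch of how a caller might use the refactored entry point; it is illustrative only, assuming the import path matches this repository layout, that zero options fall back to NewOptions defaults, and that the listen address is chosen by the caller (none of this is part of the commit itself).

package main

import (
	"log"

	api "github.com/go-skynet/LocalAI/api" // assumed import path for this repo layout
)

func main() {
	// App() now calls Startup() internally, so the HTTP wiring below still gets
	// the loaded configs, preloaded models and the gRPC shutdown goroutine.
	app, err := api.App() // options elided; NewOptions defaults assumed
	if err != nil {
		log.Fatalf("startup failed: %v", err)
	}
	// Listen address is an assumption for illustration, not taken from this commit.
	if err := app.Listen(":8080"); err != nil {
		log.Fatal(err)
	}
}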
13 changes: 0 additions & 13 deletions api/backend/embeddings.go
@@ -2,7 +2,6 @@ package backend

import (
"fmt"
"sync"

config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
@@ -88,18 +87,6 @@ func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, c config.
}

return func() ([]float32, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()

embeds, err := fn()
if err != nil {
return embeds, err
18 changes: 1 addition & 17 deletions api/backend/image.go
@@ -1,8 +1,6 @@
package backend

import (
"sync"

config "github.com/go-skynet/LocalAI/api/config"
"github.com/go-skynet/LocalAI/api/options"
"github.com/go-skynet/LocalAI/pkg/grpc/proto"
@@ -67,19 +65,5 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
return err
}

return func() error {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[c.Backend]
if !ok {
m := &sync.Mutex{}
mutexes[c.Backend] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()

return fn()
}, nil
return fn, nil
}
68 changes: 47 additions & 21 deletions api/backend/llm.go
@@ -15,7 +15,17 @@ import (
"github.com/go-skynet/LocalAI/pkg/utils"
)

func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string) bool) (func() (string, error), error) {
type LLMResponse struct {
Response string // should this be []byte?
Usage TokenUsage
}

type TokenUsage struct {
Prompt int
Completion int
}

func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c config.Config, o *options.Option, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
modelFile := c.Model

grpcOpts := gRPCModelOpts(c)
@@ -70,40 +80,56 @@ func ModelInference(ctx context.Context, s string, loader *model.ModelLoader, c
}

// in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
fn := func() (string, error) {
fn := func() (LLMResponse, error) {
opts := gRPCPredictOpts(c, loader.ModelPath)
opts.Prompt = s

tokenUsage := TokenUsage{}

// check the per-model feature flag for usage, since tokenCallback may have a cost, but default to on.
if !c.FeatureFlag["usage"] {
userTokenCallback := tokenCallback
if userTokenCallback == nil {
userTokenCallback = func(token string, usage TokenUsage) bool {
return true
}
}

promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts)
if pErr == nil && promptInfo.Length > 0 {
tokenUsage.Prompt = int(promptInfo.Length)
}

tokenCallback = func(token string, usage TokenUsage) bool {
tokenUsage.Completion++
return userTokenCallback(token, tokenUsage)
}
}

if tokenCallback != nil {
ss := ""
err := inferenceModel.PredictStream(ctx, opts, func(s []byte) {
tokenCallback(string(s))
tokenCallback(string(s), tokenUsage)
ss += string(s)
})
return ss, err
return LLMResponse{
Response: ss,
Usage: tokenUsage,
}, err
} else {
// TODO: Is the chicken bit the only way to get here? is that acceptable?
reply, err := inferenceModel.Predict(ctx, opts)
if err != nil {
return "", err
return LLMResponse{}, err
}
return string(reply.Message), err
return LLMResponse{
Response: string(reply.Message),
Usage: tokenUsage,
}, err
}
}

return func() (string, error) {
// This is still needed, see: https://github.com/ggerganov/llama.cpp/discussions/784
mutexMap.Lock()
l, ok := mutexes[modelFile]
if !ok {
m := &sync.Mutex{}
mutexes[modelFile] = m
l = m
}
mutexMap.Unlock()
l.Lock()
defer l.Unlock()

return fn()
}, nil
return fn, nil
}

var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp)
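Note on the api/backend/llm.go change: inference results are now wrapped in an LLMResponse carrying a TokenUsage, the token callback signature becomes func(string, TokenUsage) bool, and the per-model mutex that previously serialized calls (see https://github.com/ggerganov/llama.cpp/discussions/784) is dropped. The following is a sketch of how a caller could consume the usage-aware stream; the backend and model import paths are inferred from the file layout rather than shown in this diff, and loader, cfg and opts are assumed to be constructed elsewhere.

package usagedemo // hypothetical package, for illustration only

import (
	"context"
	"fmt"

	backend "github.com/go-skynet/LocalAI/api/backend" // assumed path (package backend lives under api/backend)
	config "github.com/go-skynet/LocalAI/api/config"
	"github.com/go-skynet/LocalAI/api/options"
	model "github.com/go-skynet/LocalAI/pkg/model" // assumed path for ModelLoader
)

// streamWithUsage prints each streamed token with the running usage counts
// that ModelInference now threads through the callback.
func streamWithUsage(ctx context.Context, prompt string, loader *model.ModelLoader, cfg config.Config, opts *options.Option) (backend.LLMResponse, error) {
	cb := func(token string, usage backend.TokenUsage) bool {
		// Completion is incremented per token inside ModelInference; Prompt is
		// filled from TokenizeString when the backend reports a prompt length.
		fmt.Printf("token=%q prompt=%d completion=%d\n", token, usage.Prompt, usage.Completion)
		return true // the commit's default callback also returns true
	}
	fn, err := backend.ModelInference(ctx, prompt, loader, cfg, opts, cb)
	if err != nil {
		return backend.LLMResponse{}, err
	}
	return fn() // Response text plus the final TokenUsage
}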
