Skip to content

Commit

Permalink
chore: extract realtime models into two categories
Browse files Browse the repository at this point in the history
One is anyToAny models, which require a VAD model; the other is
wrappedModel, which requires a VAD model as well as the other models in
the pipeline.

Signed-off-by: Ettore Di Giacinto <[email protected]>
  • Loading branch information
mudler committed Nov 13, 2024
1 parent 62cce6a commit 9234e24
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 91 deletions.
5 changes: 5 additions & 0 deletions core/config/backend_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,11 @@ type Pipeline struct {
TTS string `yaml:"tts"`
LLM string `yaml:"llm"`
Transcription string `yaml:"transcription"`
VAD string `yaml:"vad"`
}

// IsNotConfigured reports whether the pipeline is incomplete: it returns true
// when at least one of the LLM, TTS or Transcription model names is empty.
// The VAD field is not taken into account.
func (p Pipeline) IsNotConfigured() bool {
	for _, stage := range [...]string{p.LLM, p.TTS, p.Transcription} {
		if stage == "" {
			return true
		}
	}
	return false
}

type File struct {
Expand Down
95 changes: 4 additions & 91 deletions core/http/endpoints/openai/realtime.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package openai

import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"sync"

"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
grpc "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"google.golang.org/grpc"

"github.com/rs/zerolog/log"
)
Expand Down Expand Up @@ -114,95 +115,7 @@ var sessionLock sync.Mutex

// TODO: implement interface as we start to define usages
type Model interface {
}

// wrappedModel bundles, for each pipeline stage (TTS, transcription, LLM),
// the backend config and the loaded backend client. newModel fills it in when
// the model config defines a complete pipeline.
type wrappedModel struct {
TTSConfig *config.BackendConfig
TranscriptionConfig *config.BackendConfig
LLMConfig *config.BackendConfig
TTSClient grpc.Backend
TranscriptionClient grpc.Backend
LLMClient grpc.Backend
}

// returns and loads either a wrapped model or a model that support audio-to-audio
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {

cfg, err := cl.LoadBackendConfigFileByName(modelName, ml.ModelPath)
if err != nil {
return nil, fmt.Errorf("failed to load backend config: %w", err)
}

if !cfg.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}

if cfg.Pipeline.LLM == "" || cfg.Pipeline.TTS == "" || cfg.Pipeline.Transcription == "" {
// If we don't have Wrapped model definitions, just return a standard model
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend),
model.WithModel(cfg.Model))
return ml.Load(opts...)
}

log.Debug().Msg("Loading a wrapped model")

// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
if err != nil {

return nil, fmt.Errorf("failed to load backend config: %w", err)
}

if !cfg.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}

cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
if err != nil {

return nil, fmt.Errorf("failed to load backend config: %w", err)
}

if !cfg.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}

cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
if err != nil {

return nil, fmt.Errorf("failed to load backend config: %w", err)
}

if !cfg.Validate() {
return nil, fmt.Errorf("failed to validate config: %w", err)
}

opts := backend.ModelOptions(*cfgTTS, appConfig)
ttsClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load tts model: %w", err)
}

opts = backend.ModelOptions(*cfgSST, appConfig)
transcriptionClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load SST model: %w", err)
}

opts = backend.ModelOptions(*cfgLLM, appConfig)
llmClient, err := ml.Load(opts...)
if err != nil {
return nil, fmt.Errorf("failed to load LLM model: %w", err)
}

return &wrappedModel{
TTSConfig: cfgTTS,
TranscriptionConfig: cfgSST,
LLMConfig: cfgLLM,
TTSClient: ttsClient,
TranscriptionClient: transcriptionClient,
LLMClient: llmClient,
}, nil
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
}

func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
Expand Down
169 changes: 169 additions & 0 deletions core/http/endpoints/openai/realtime_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
package openai

import (
"context"
"fmt"

"github.com/mudler/LocalAI/core/backend"
"github.com/mudler/LocalAI/core/config"
grpcClient "github.com/mudler/LocalAI/pkg/grpc"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
"github.com/rs/zerolog/log"
"google.golang.org/grpc"
)

// wrappedModel represents a model that does not support Any-to-Any operations.
// It fakes an Any-to-Any model by overriding some of the gRPC client methods
// meant for Any-to-Any models and calling a pipeline of models instead
// (e.g. STT->LLM->TTS).
type wrappedModel struct {
// Backend config and loaded backend client for each pipeline stage;
// populated by newModel.
TTSConfig *config.BackendConfig
TranscriptionConfig *config.BackendConfig
LLMConfig *config.BackendConfig
TTSClient grpcClient.Backend
TranscriptionClient grpcClient.Backend
LLMClient grpcClient.Backend

// Backend config and client of the VAD model; VADClient serves the VAD method.
VADConfig *config.BackendConfig
VADClient grpcClient.Backend
}

// anyToAnyModel represents a model that supports Any-to-Any operations.
// It still needs a wrapper because two models are loaded: one for VAD and one
// for the actual model.
// In the future there could be models that accept continuous audio input only,
// so this design will be useful for that.
type anyToAnyModel struct {
// Backend config and loaded client of the any-to-any model itself;
// populated by newModel.
LLMConfig *config.BackendConfig
LLMClient grpcClient.Backend

// Backend config and client of the VAD model; VADClient serves the VAD method.
VADConfig *config.BackendConfig
VADClient grpcClient.Backend
}

// VAD forwards the VAD request to the dedicated VAD backend client of the
// wrapped pipeline and returns its response unchanged.
//
// The gRPC call options are passed through to the backend; previously they
// were accepted but silently dropped.
func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
	return m.VADClient.VAD(ctx, in, opts...)
}

// VAD forwards the VAD request to the dedicated VAD backend client loaded
// next to the any-to-any model and returns its response unchanged.
//
// The gRPC call options are passed through to the backend; previously they
// were accepted but silently dropped.
func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
	return m.VADClient.VAD(ctx, in, opts...)
}

// newModel loads the realtime model registered under modelName and returns it
// as a Model.
//
// The VAD model named by cfg.Pipeline.VAD is always loaded first, because both
// kinds of model need it. Then:
//   - if the pipeline is not fully configured (LLM, TTS or Transcription is
//     empty), the model named by cfg.Model is loaded and an anyToAnyModel is
//     returned;
//   - otherwise the LLM, TTS and transcription models of the pipeline are
//     loaded and a wrappedModel is returned.
//
// An error is returned as soon as a config fails to load or validate, or a
// backend fails to load.
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
	cfg, err := loadRealtimeConfig(cl, ml, modelName, "realtime")
	if err != nil {
		return nil, err
	}

	// Both model kinds need VAD, so prepare it up front.
	cfgVAD, err := loadRealtimeConfig(cl, ml, cfg.Pipeline.VAD, "VAD")
	if err != nil {
		return nil, err
	}

	vadClient, err := loadRealtimeBackend(ml, appConfig, cfgVAD, "VAD")
	if err != nil {
		return nil, err
	}

	// Without a complete pipeline definition the model is treated as an
	// any-to-any model that only needs the VAD model next to it.
	if cfg.Pipeline.IsNotConfigured() {
		// NOTE(review): the any-to-any config is looked up by cfg.Model rather
		// than by modelName — confirm this is intended.
		cfgAnyToAny, err := loadRealtimeConfig(cl, ml, cfg.Model, "any-to-any")
		if err != nil {
			return nil, err
		}

		anyToAnyClient, err := loadRealtimeBackend(ml, appConfig, cfgAnyToAny, "any-to-any")
		if err != nil {
			return nil, err
		}

		return &anyToAnyModel{
			LLMConfig: cfgAnyToAny,
			LLMClient: anyToAnyClient,
			VADConfig: cfgVAD,
			VADClient: vadClient,
		}, nil
	}

	log.Debug().Msg("Loading a wrapped model")

	// Otherwise build a wrapped model: a "virtual" model that re-uses other
	// models to perform operations. All stage configs are loaded and validated
	// before any backend is started, so a bad config fails fast.
	cfgLLM, err := loadRealtimeConfig(cl, ml, cfg.Pipeline.LLM, "LLM")
	if err != nil {
		return nil, err
	}

	cfgTTS, err := loadRealtimeConfig(cl, ml, cfg.Pipeline.TTS, "TTS")
	if err != nil {
		return nil, err
	}

	cfgSTT, err := loadRealtimeConfig(cl, ml, cfg.Pipeline.Transcription, "STT")
	if err != nil {
		return nil, err
	}

	ttsClient, err := loadRealtimeBackend(ml, appConfig, cfgTTS, "TTS")
	if err != nil {
		return nil, err
	}

	transcriptionClient, err := loadRealtimeBackend(ml, appConfig, cfgSTT, "STT")
	if err != nil {
		return nil, err
	}

	llmClient, err := loadRealtimeBackend(ml, appConfig, cfgLLM, "LLM")
	if err != nil {
		return nil, err
	}

	return &wrappedModel{
		TTSConfig:           cfgTTS,
		TranscriptionConfig: cfgSTT,
		LLMConfig:           cfgLLM,
		TTSClient:           ttsClient,
		TranscriptionClient: transcriptionClient,
		LLMClient:           llmClient,

		VADConfig: cfgVAD,
		VADClient: vadClient,
	}, nil
}

// loadRealtimeConfig loads the backend config registered under name and
// validates that same config. role (e.g. "VAD", "LLM") only labels errors.
func loadRealtimeConfig(cl *config.BackendConfigLoader, ml *model.ModelLoader, name, role string) (*config.BackendConfig, error) {
	cfg, err := cl.LoadBackendConfigFileByName(name, ml.ModelPath)
	if err != nil {
		return nil, fmt.Errorf("failed to load %s backend config %q: %w", role, name, err)
	}

	// Validate reports only a bool, so there is no underlying error to wrap.
	if !cfg.Validate() {
		return nil, fmt.Errorf("failed to validate %s backend config %q", role, name)
	}

	return cfg, nil
}

// loadRealtimeBackend loads the backend described by cfg through the model
// loader. role (e.g. "VAD", "LLM") only labels errors.
func loadRealtimeBackend(ml *model.ModelLoader, appConfig *config.ApplicationConfig, cfg *config.BackendConfig, role string) (grpcClient.Backend, error) {
	client, err := ml.Load(backend.ModelOptions(*cfg, appConfig)...)
	if err != nil {
		return nil, fmt.Errorf("failed to load %s model: %w", role, err)
	}

	return client, nil
}

0 comments on commit 9234e24

Please sign in to comment.