Skip to content

Commit

Permalink
change: unless KNOW_PREFER_NEW_EMBEDDING_MODEL is set, always use the…
Browse files Browse the repository at this point in the history
… model that's already set on the dataset (#231)
  • Loading branch information
iwilltry42 authored Nov 24, 2024
1 parent a6327b0 commit 1e8bdfc
Show file tree
Hide file tree
Showing 22 changed files with 161 additions and 49 deletions.
2 changes: 1 addition & 1 deletion knowledge/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ require (
github.com/knadh/koanf/v2 v2.1.1
github.com/ledongthuc/pdf v0.0.0-20240201131950-da5b75280b06
github.com/lu4p/cat v0.1.5
github.com/mitchellh/copystructure v1.2.0
github.com/mitchellh/mapstructure v1.5.0
github.com/ncruces/go-sqlite3 v0.19.0
github.com/pgvector/pgvector-go v0.2.2
Expand Down Expand Up @@ -125,7 +126,6 @@ require (
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-runewidth v0.0.16 // indirect
github.com/microcosm-cc/bluemonday v1.0.26 // indirect
github.com/mitchellh/copystructure v1.2.0 // indirect
github.com/mitchellh/reflectwalk v1.0.2 // indirect
github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect
github.com/ncruces/julianday v1.0.0 // indirect
Expand Down
1 change: 0 additions & 1 deletion knowledge/pkg/cmd/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ func (s *ClientLoad) run(ctx context.Context, input, output string) error {
return fmt.Errorf("failed to write output to %q: %w", output, err)
}
return nil

}

func dropCommon(target, common map[string]any) map[string]any {
Expand Down
1 change: 0 additions & 1 deletion knowledge/pkg/datastore/document.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
)

func (s *Datastore) DeleteDocument(ctx context.Context, documentID, datasetID string) error {

// Remove from Index
if err := s.Index.DeleteDocument(ctx, documentID, datasetID); err != nil {
return fmt.Errorf("failed to remove document from Index: %w", err)
Expand Down
2 changes: 0 additions & 2 deletions knowledge/pkg/datastore/documentloader/converter/soffice.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,10 @@ func (c *SofficeConverter) Name() string {
}

// NewSofficeConverter returns a pointer to a zero-value SofficeConverter.
// The returned error is always nil in this implementation.
func NewSofficeConverter() (*SofficeConverter, error) {

return &SofficeConverter{}, nil
}

func (c *SofficeConverter) Convert(ctx context.Context, reader io.Reader, sourceExt, outputFormat string) (io.Reader, error) {

// Convert the file using soffice
outputFormat = strings.ToLower(outputFormat)
sourceExt = strings.ToLower(sourceExt)
Expand Down
18 changes: 0 additions & 18 deletions knowledge/pkg/datastore/embeddings/adapter/adapter.go

This file was deleted.

11 changes: 10 additions & 1 deletion knowledge/pkg/datastore/embeddings/cohere/cohere.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package cohere

import (
"dario.cat/mergo"
"fmt"

"dario.cat/mergo"
"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings/load"
cg "github.com/philippgille/chromem-go"
)
Expand All @@ -14,6 +15,14 @@ type EmbeddingModelProviderCohere struct {
Model string `env:"COHERE_MODEL" koanf:"model" export:"required"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingModelProviderCohere) UseEmbeddingModel(model string) {
p.Model = model
}

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingModelProviderCohere) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingModelProviderCohereName.
func (p *EmbeddingModelProviderCohere) Name() string {
return EmbeddingModelProviderCohereName
}
Expand Down
2 changes: 1 addition & 1 deletion knowledge/pkg/datastore/embeddings/embeddings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func TestLoadConfVertex(t *testing.T) {

conf := p.Config().(*vertex.EmbeddingProviderVertex)

require.Equal(t, "foo-embedding-001", conf.Model)
require.Equal(t, "foo-embedding-001", conf.EmbeddingModelName)
require.Equal(t, "foo-project", conf.Project)
}

Expand Down
13 changes: 11 additions & 2 deletions knowledge/pkg/datastore/embeddings/jina/jina.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
package jina

import (
"dario.cat/mergo"
"fmt"
"strings"

"dario.cat/mergo"
"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings/load"
cg "github.com/philippgille/chromem-go"
"strings"
)

type EmbeddingProviderJina struct {
Expand All @@ -15,6 +16,14 @@ type EmbeddingProviderJina struct {

const EmbeddingProviderJinaName = "jina"

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingProviderJina) UseEmbeddingModel(model string) {
p.Model = model
}

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingProviderJina) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingProviderJinaName ("jina").
func (p *EmbeddingProviderJina) Name() string {
return EmbeddingProviderJinaName
}
Expand Down
13 changes: 11 additions & 2 deletions knowledge/pkg/datastore/embeddings/localai/localai.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
package localai

import (
"dario.cat/mergo"
"fmt"
"strings"

"dario.cat/mergo"
"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings/load"
cg "github.com/philippgille/chromem-go"
"strings"
)

// EmbeddingProviderLocalAI holds the configuration of the LocalAI embedding
// provider.
type EmbeddingProviderLocalAI struct {
// Model is the embedding model name; per its struct tags it maps to the
// koanf key "model" and the LOCALAI_MODEL environment variable.
// NOTE(review): the meaning of the export tag is defined outside this view.
Model string `koanf:"model" env:"LOCALAI_MODEL" export:"required"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingProviderLocalAI) UseEmbeddingModel(model string) {
p.Model = model
}

const EmbeddingProviderLocalAIName = "localai"

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingProviderLocalAI) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingProviderLocalAIName
// ("localai").
func (p *EmbeddingProviderLocalAI) Name() string {
return EmbeddingProviderLocalAIName
}
Expand Down
13 changes: 11 additions & 2 deletions knowledge/pkg/datastore/embeddings/mistral/mistral.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
package mistral

import (
"dario.cat/mergo"
"fmt"
"strings"

"dario.cat/mergo"
"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings/load"
cg "github.com/philippgille/chromem-go"
"strings"
)

// EmbeddingProviderMistral holds the configuration of the Mistral embedding
// provider.
type EmbeddingProviderMistral struct {
// APIKey maps to the koanf key "apiKey" and the MISTRAL_API_KEY
// environment variable (per its struct tags).
APIKey string `koanf:"apiKey" env:"MISTRAL_API_KEY" export:"false"`
// Model is the embedding model name; it maps to the koanf key "model" and
// the MISTRAL_MODEL environment variable (per its struct tags).
// NOTE(review): the meaning of the export tag is defined outside this view.
Model string `koanf:"model" env:"MISTRAL_MODEL" export:"required"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingProviderMistral) UseEmbeddingModel(model string) {
p.Model = model
}

const EmbeddingProviderMistralName = "mistral"

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingProviderMistral) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingProviderMistralName
// ("mistral").
func (p *EmbeddingProviderMistral) Name() string {
return EmbeddingProviderMistralName
}
Expand Down
13 changes: 11 additions & 2 deletions knowledge/pkg/datastore/embeddings/mixedbread/mixedbread.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
package mixedbread

import (
"dario.cat/mergo"
"fmt"
"strings"

"dario.cat/mergo"
"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings/load"
cg "github.com/philippgille/chromem-go"
"strings"
)

// EmbeddingProviderMixedbread holds the configuration of the Mixedbread
// embedding provider.
type EmbeddingProviderMixedbread struct {
// APIKey maps to the koanf key "apiKey" and the MIXEDBREAD_API_KEY
// environment variable (per its struct tags).
APIKey string `koanf:"apiKey" env:"MIXEDBREAD_API_KEY" export:"false"`
// Model is the embedding model name; it maps to the koanf key "model" and
// the MIXEDBREAD_MODEL environment variable (per its struct tags).
// NOTE(review): the meaning of the export tag is defined outside this view.
Model string `koanf:"model" env:"MIXEDBREAD_MODEL" export:"required"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingProviderMixedbread) UseEmbeddingModel(model string) {
p.Model = model
}

const EmbeddingProviderMixedbreadName = "mixedbread"

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingProviderMixedbread) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingProviderMixedbreadName
// ("mixedbread").
func (p *EmbeddingProviderMixedbread) Name() string {
return EmbeddingProviderMixedbreadName
}
Expand Down
8 changes: 8 additions & 0 deletions knowledge/pkg/datastore/embeddings/ollama/ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ type EmbeddingProviderOllama struct {
Model string `koanf:"model" env:"OLLAMA_MODEL" export:"required"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingProviderOllama) UseEmbeddingModel(model string) {
p.Model = model
}

const EmbeddingProviderOllamaName = "ollama"

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingProviderOllama) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingProviderOllamaName
// ("ollama").
func (p *EmbeddingProviderOllama) Name() string {
return EmbeddingProviderOllamaName
}
Expand Down
12 changes: 12 additions & 0 deletions knowledge/pkg/datastore/embeddings/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,22 @@ func (o OpenAIConfig) Name() string {
return EmbeddingModelProviderOpenAIName
}

// EmbeddingModelName returns the config's EmbeddingModel field.
func (o OpenAIConfig) EmbeddingModelName() string {
return o.EmbeddingModel
}

// AzureOpenAIConfig holds the Azure-specific OpenAI settings.
type AzureOpenAIConfig struct {
// Deployment is the Azure OpenAI deployment name. Per its usage tag it
// overrides openai-embedding-model when set; it maps to the
// OPENAI_AZURE_DEPLOYMENT environment variable and the koanf key
// "deployment", and defaults to the empty string.
Deployment string `usage:"Azure OpenAI deployment name (overrides openai-embedding-model, if set)" default:"" env:"OPENAI_AZURE_DEPLOYMENT" koanf:"deployment"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.EmbeddingModel.
func (p *EmbeddingModelProviderOpenAI) UseEmbeddingModel(model string) {
p.EmbeddingModel = model
}

// EmbeddingModelName returns the model currently stored in p.EmbeddingModel.
func (p *EmbeddingModelProviderOpenAI) EmbeddingModelName() string {
return p.EmbeddingModel
}

// Name returns the provider identifier EmbeddingModelProviderOpenAIName.
func (p *EmbeddingModelProviderOpenAI) Name() string {
return EmbeddingModelProviderOpenAIName
}
Expand Down
2 changes: 2 additions & 0 deletions knowledge/pkg/datastore/embeddings/types/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ type EmbeddingModelProvider interface {
EmbeddingFunc() (cg.EmbeddingFunc, error)
Configure() error
Config() any
EmbeddingModelName() string
UseEmbeddingModel(model string)
}
8 changes: 8 additions & 0 deletions knowledge/pkg/datastore/embeddings/vertex/vertex.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,16 @@ type EmbeddingProviderVertex struct {
Model string `koanf:"model" env:"VERTEX_MODEL" export:"required"`
}

// UseEmbeddingModel overrides the provider's configured model by assigning
// model to p.Model.
func (p *EmbeddingProviderVertex) UseEmbeddingModel(model string) {
p.Model = model
}

const EmbeddingProviderVertexName = "vertex"

// EmbeddingModelName returns the model currently stored in p.Model.
func (p *EmbeddingProviderVertex) EmbeddingModelName() string {
return p.Model
}

// Name returns the provider identifier EmbeddingProviderVertexName
// ("vertex").
func (p *EmbeddingProviderVertex) Name() string {
return EmbeddingProviderVertexName
}
Expand Down
21 changes: 16 additions & 5 deletions knowledge/pkg/datastore/ingest.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import (
"context"
"fmt"
"log/slog"
"os"
"time"

"github.com/gptscript-ai/knowledge/pkg/datastore/documentloader"
"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings"
"github.com/gptscript-ai/knowledge/pkg/index/types"
"github.com/gptscript-ai/knowledge/pkg/log"
"github.com/gptscript-ai/knowledge/pkg/output"
vs "github.com/gptscript-ai/knowledge/pkg/vectorstore/types"

"github.com/google/uuid"
Expand Down Expand Up @@ -73,11 +75,20 @@ func (s *Datastore) Ingest(ctx context.Context, datasetID string, filename strin
return nil, fmt.Errorf("failed to get embeddings model provider: %w", err)
}

// TODO: Use the dataset-provided config (merge with override)
err = embeddings.CompareRequiredFields(s.EmbeddingModelProvider.Config(), dsEmbeddingProvider.Config())
if err != nil {
slog.Info("Dataset has attached embeddings provider config", "config", ds.EmbeddingsProviderConfig)
return nil, fmt.Errorf("mismatching embedding provider configs: %w", err)
if s.EmbeddingModelProvider.EmbeddingModelName() != dsEmbeddingProvider.EmbeddingModelName() {
slog.Warn("Embeddings model mismatch", "dataset", datasetID, "attached", dsEmbeddingProvider.EmbeddingModelName(), "configured", s.EmbeddingModelProvider.EmbeddingModelName())
if os.Getenv("KNOW_PREFER_NEW_EMBEDDING_MODEL") == "" {
slog.Info("Using dataset's embeddings model", "model", dsEmbeddingProvider.EmbeddingModelName())
s.EmbeddingModelProvider.UseEmbeddingModel(dsEmbeddingProvider.EmbeddingModelName())
}
}

if os.Getenv("KNOW_STRICT_EMBEDDING_CONFIG_CHECK") != "" {
err = embeddings.CompareRequiredFields(s.EmbeddingModelProvider.Config(), dsEmbeddingProvider.Config())
if err != nil {
slog.Info("Dataset has attached embeddings provider config", "config", output.RedactSensitive(ds.EmbeddingsProviderConfig))
return nil, fmt.Errorf("mismatching embedding provider configs: %w", err)
}
}
}

Expand Down
36 changes: 35 additions & 1 deletion knowledge/pkg/datastore/retrieve.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@ package datastore
import (
	"context"
	"fmt"
	"log/slog"
	"os"

	"github.com/gptscript-ai/knowledge/pkg/datastore/embeddings"
	etypes "github.com/gptscript-ai/knowledge/pkg/datastore/embeddings/types"
	"github.com/gptscript-ai/knowledge/pkg/datastore/types"
	"github.com/gptscript-ai/knowledge/pkg/output"
	types2 "github.com/gptscript-ai/knowledge/pkg/vectorstore/types"
	"github.com/mitchellh/copystructure"
	"github.com/philippgille/chromem-go"

	"github.com/gptscript-ai/knowledge/pkg/datastore/defaults"
	"github.com/gptscript-ai/knowledge/pkg/flows"

	cg "github.com/philippgille/chromem-go"
)

type RetrieveOpts struct {
Expand Down Expand Up @@ -66,5 +73,32 @@ func (s *Datastore) Retrieve(ctx context.Context, datasetIDs []string, query str
}

// SimilaritySearch looks up the dataset identified by datasetID and delegates
// the query to the vector store. query, numDocuments, where and whereDocument
// are passed through unchanged.
//
// If the dataset has an embeddings provider config attached and its embedding
// model differs from the model of the datastore's configured provider, a
// warning is logged. Unless the KNOW_PREFER_NEW_EMBEDDING_MODEL environment
// variable is set to a non-empty value, the query is then embedded with the
// dataset's model: the configured provider is copied, switched to the
// dataset's model, and its embedding function is handed to the vector store.
// In every other case a nil embedding function is passed.
//
// An error is returned if the dataset cannot be loaded, its provider config
// cannot be turned into a provider, the configured provider cannot be copied,
// or the embedding function cannot be created.
func (s *Datastore) SimilaritySearch(ctx context.Context, query string, numDocuments int, datasetID string, where map[string]string, whereDocument []chromem.WhereDocument) ([]types2.Document, error) {
	ds, err := s.GetDataset(ctx, datasetID)
	if err != nil {
		return nil, fmt.Errorf("failed to get dataset %q: %w", datasetID, err)
	}

	// ef stays nil unless a dataset-specific embedding function is required.
	var ef cg.EmbeddingFunc
	if ds.EmbeddingsProviderConfig != nil {
		dsEmbeddingProvider, err := embeddings.ProviderFromConfig(*ds.EmbeddingsProviderConfig)
		if err != nil {
			return nil, fmt.Errorf("failed to get embeddings model provider for dataset %q: %w", datasetID, err)
		}

		configuredModel := s.EmbeddingModelProvider.EmbeddingModelName()
		datasetModel := dsEmbeddingProvider.EmbeddingModelName()
		if configuredModel != datasetModel {
			slog.Warn("Embeddings model mismatch", "dataset", datasetID, "attached", datasetModel, "configured", configuredModel)
			if os.Getenv("KNOW_PREFER_NEW_EMBEDDING_MODEL") == "" {
				slog.Info("Using dataset's embeddings model", "model", datasetModel)

				// Work on a copy so the provider shared via the datastore
				// keeps its configured model for other datasets.
				copied, err := copystructure.Copy(s.EmbeddingModelProvider)
				if err != nil {
					return nil, fmt.Errorf("failed to copy embeddings model provider: %w", err)
				}
				// Assert once, with a check, instead of panicking if the copy
				// does not satisfy the provider interface.
				provider, ok := copied.(etypes.EmbeddingModelProvider)
				if !ok {
					return nil, fmt.Errorf("copied embeddings model provider has unexpected type %T", copied)
				}

				provider.UseEmbeddingModel(datasetModel)
				ef, err = provider.EmbeddingFunc()
				if err != nil {
					return nil, fmt.Errorf("failed to create embedding function for model %q: %w", datasetModel, err)
				}
				// The "model" field carries the model name (it previously
				// carried the provider name).
				slog.Debug("Using dataset specific embedding function", "dataset", datasetID, "model", datasetModel, "newProviderConfig", output.RedactSensitive(provider))
			}
		}
	}

	return s.Vectorstore.SimilaritySearch(ctx, query, numDocuments, datasetID, where, whereDocument, ef)
}
2 changes: 0 additions & 2 deletions knowledge/pkg/index/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ func (db *DB) ListDatasets() ([]Dataset, error) {
}

func (db *DB) DeleteFile(ctx context.Context, datasetID, fileID string) error {

// Find file in database with associated documents
var file File
tx := db.WithContext(ctx).Preload("Documents").Where("id = ? AND dataset = ?", fileID, datasetID).First(&file)
Expand All @@ -115,7 +114,6 @@ func (db *DB) DeleteFile(ctx context.Context, datasetID, fileID string) error {

// Remove owned documents from VectorStore and Database
for _, doc := range file.Documents {

tx = db.WithContext(ctx).Delete(&doc)
if tx.Error != nil {
return fmt.Errorf("failed to delete document from DB: %w", tx.Error)
Expand Down
Loading

0 comments on commit 1e8bdfc

Please sign in to comment.