diff --git a/README.md b/README.md
index cb6881a..3ce683f 100644
--- a/README.md
+++ b/README.md
@@ -11,9 +11,7 @@ Because `chromem-go` is embeddable it enables you to add retrieval augmented gen
 It's *not* a library to connect to Chroma and also not a reimplementation of it in Go. It's a database on its own.
 
-The focus is not scale or number of features, but simplicity.
-
-Performance has not been a priority yet. Without optimizations (except some parallelization with goroutines) querying 5,000 documents takes ~500ms on a mid-range laptop CPU (11th Gen Intel i5-1135G7, like in the first generation Framework Laptop 13).
+The focus is not scale (millions of documents) or number of features, but simplicity and performance for the most common use cases. On a mid-range 2020 Intel laptop CPU you can query 1,000 documents in 0.5 ms and 100,000 documents in 56 ms, both with just 44 memory allocations. See [Benchmarks](#benchmarks) for details.
 
 > ⚠️ The project is in beta, under heavy construction, and may introduce breaking changes in releases before `v1.0.0`. All changes are documented in the [`CHANGELOG`](./CHANGELOG.md).
 
@@ -23,8 +21,9 @@ Performance has not been a priority yet. Without optimizations (except some para
 2. [Interface](#interface)
 3. [Features](#features)
 4. [Usage](#usage)
-5. [Motivation](#motivation)
-6. [Related projects](#related-projects)
+5. [Benchmarks](#benchmarks)
+6. [Motivation](#motivation)
+7. [Related projects](#related-projects)
 
 ## Use cases
@@ -156,25 +155,25 @@ See the Godoc for details:
diff --git a/collection.go b/collection.go
--- a/collection.go
+++ b/collection.go
@@ ... @@ func (c *Collection) Query(
-		return res[i].Similarity > res[j].Similarity
+		return docSim[i].similarity > docSim[j].similarity
 	})
 
+	// Return the top nResults or len(docSim), whichever is smaller
+	if len(docSim) < nResults {
+		nResults = len(docSim)
+	}
+	res := make([]Result, 0, nResults)
+	for i := 0; i < nResults; i++ {
+		res = append(res, Result{
+			ID:         docSim[i].docID,
+			Metadata:   c.documents[docSim[i].docID].Metadata,
+			Embedding:  c.documents[docSim[i].docID].Embedding,
+			Content:    c.documents[docSim[i].docID].Content,
+			Similarity: docSim[i].similarity,
+		})
+	}
+
 	// Return the top nResults
-	return res[:nResults], nil
+	return res, nil
 }
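The reworked `Query` tail above sorts lightweight ID/similarity pairs and only materializes full `Result` values for the winners, clamping `nResults` to the number of matching documents instead of slicing blindly. A minimal standalone sketch of that sort-and-clamp pattern (the names are illustrative, not chromem-go's API):

```go
package main

import (
	"fmt"
	"sort"
)

// docSim mirrors the private pair introduced in query.go: a document ID
// plus its similarity to the query.
type docSim struct {
	docID      string
	similarity float32
}

// topK sorts descending by similarity and clamps k to the slice length,
// so asking for more results than exist returns everything.
func topK(sims []docSim, k int) []docSim {
	sort.Slice(sims, func(i, j int) bool {
		return sims[i].similarity > sims[j].similarity
	})
	if len(sims) < k {
		k = len(sims)
	}
	return sims[:k]
}

func main() {
	sims := []docSim{{"a", 0.1}, {"b", 0.9}, {"c", 0.5}}
	fmt.Println(topK(sims, 5)) // [{b 0.9} {c 0.5} {a 0.1}]
}
```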
"test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) { db := NewDB() name := "test" metadata := map[string]string{"foo": "bar"} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { - return []float32{-0.1, 0.1, 0.2}, nil + return vectors, nil } c, err := db.CreateCollection(name, metadata, embeddingFunc) if err != nil { diff --git a/db.go b/db.go index 42ab698..7c78b39 100644 --- a/db.go +++ b/db.go @@ -13,6 +13,9 @@ import ( // EmbeddingFunc is a function that creates embeddings for a given text. // chromem-go will use OpenAI`s "text-embedding-3-small" model by default, // but you can provide your own function, using any model you like. +// The function must return a *normalized* vector, i.e. the length of the vector +// must be 1. OpenAI's and Mistral's embedding models do this by default. Some +// others like Nomic's "nomic-embed-text-v1.5" don't. type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error) // DB is the chromem-go database. It holds collections, which hold documents. diff --git a/db_test.go b/db_test.go index ad04bbe..fb25373 100644 --- a/db_test.go +++ b/db_test.go @@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) { // Values in the collection name := "test" metadata := map[string]string{"foo": "bar"} - vectors := []float32{-0.1, 0.1, 0.2} + vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}` embeddingFunc := func(_ context.Context, _ string) ([]float32, error) { return vectors, nil } @@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) { // Values in the 
diff --git a/db_test.go b/db_test.go
index ad04bbe..fb25373 100644
--- a/db_test.go
+++ b/db_test.go
@@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) {
 	// Values in the collection
 	name := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
 	}
@@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) {
 	// Values in the collection
 	name := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
 	}
@@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) {
 	// Values in the collection
 	name := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
 	}
@@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
 	// Values in the collection
 	name := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
 	}
@@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) {
 	// Values in the collection
 	name := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
 	}
@@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) {
 	// Values in the collection
 	name := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
 	}
diff --git a/document_test.go b/document_test.go
index 668d6b0..4290c5a 100644
--- a/document_test.go
+++ b/document_test.go
@@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) {
 	ctx := context.Background()
 	id := "test"
 	metadata := map[string]string{"foo": "bar"}
-	vectors := []float32{-0.1, 0.1, 0.2}
+	vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 	content := "hello world"
 	embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
 		return vectors, nil
diff --git a/embed_compat.go b/embed_compat.go
index a3d18c9..c3cf4ed 100644
--- a/embed_compat.go
+++ b/embed_compat.go
@@ -9,9 +9,13 @@ const (
 // NewEmbeddingFuncMistral returns a function that creates embeddings for a text
 // using the Mistral API.
 func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
+	// Mistral embeddings are normalized, see section "Distance Measures" on
+	// https://docs.mistral.ai/guides/embeddings/.
+	normalized := true
+
 	// The Mistral API docs don't mention the `encoding_format` as optional,
 	// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
-	return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
+	return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
 }
 
 const baseURLJina = "https://api.jina.ai/v1"
@@ -28,7 +32,7 @@ const (
 // NewEmbeddingFuncJina returns a function that creates embeddings for a text
 // using the Jina API.
 func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
+	return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
 }
 
 const baseURLMixedbread = "https://api.mixedbread.ai"
@@ -49,7 +53,7 @@ const (
 // NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
 // using the mixedbread.ai API.
 func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
+	return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
 }
 
 const baseURLLocalAI = "http://localhost:8080/v1"
@@ -64,5 +68,5 @@ const baseURLLocalAI = "http://localhost:8080/v1"
 // But other embedding models are supported as well. See the LocalAI documentation
 // for details.
 func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
+	return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
 }
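All the wrappers above now thread the new fourth argument through to `NewEmbeddingFuncOpenAICompat`: Mistral asserts `&normalized`, while Jina, mixedbread, and LocalAI pass `nil` to opt into autodetection. A hypothetical call against a self-hosted OpenAI-compatible server (the base URL and model name below are placeholders):

```go
package main

import (
	"context"
	"fmt"

	"github.com/philippgille/chromem-go"
)

func main() {
	// If you know your model returns unit-length vectors, say so and skip
	// the first-request autodetection entirely.
	normalized := true
	f := chromem.NewEmbeddingFuncOpenAICompat("http://localhost:8080/v1", "", "my-embedding-model", &normalized)

	// If unsure, pass nil: the first response is inspected, and chromem-go
	// normalizes every vector itself if that check fails.
	_ = chromem.NewEmbeddingFuncOpenAICompat("http://localhost:8080/v1", "", "my-embedding-model", nil)

	v, err := f(context.Background(), "hello world")
	fmt.Println(v, err)
}
```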
diff --git a/embed_ollama.go b/embed_ollama.go
index 7ed0573..d2231f7 100644
--- a/embed_ollama.go
+++ b/embed_ollama.go
@@ -8,6 +8,7 @@ import (
 	"fmt"
 	"io"
 	"net/http"
+	"sync"
 )
 
 // TODO: Turn into const and use as default, but allow user to pass custom URL
@@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
 	// and it might have to be a long timeout, depending on the text length.
 	client := &http.Client{}
 
+	var checkedNormalized bool
+	checkNormalized := sync.Once{}
+
 	return func(ctx context.Context, text string) ([]float32, error) {
 		// Prepare the request body.
 		reqBody, err := json.Marshal(map[string]string{
@@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
 			return nil, errors.New("no embeddings found in the response")
 		}
 
-		return embeddingResponse.Embedding, nil
+		v := embeddingResponse.Embedding
+		checkNormalized.Do(func() {
+			if isNormalized(v) {
+				checkedNormalized = true
+			} else {
+				checkedNormalized = false
+			}
+		})
+		if !checkedNormalized {
+			v = normalizeVector(v)
+		}
+
+		return v, nil
 	}
 }
diff --git a/embed_ollama_test.go b/embed_ollama_test.go
index 5ce6121..a3af70a 100644
--- a/embed_ollama_test.go
+++ b/embed_ollama_test.go
@@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
 	if err != nil {
 		t.Fatal("unexpected error:", err)
 	}
-	wantRes := []float32{-0.1, 0.1, 0.2}
+	wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 
 	// Mock server
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
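The `sync.Once` pair above decides once, on the first embedding, whether the model's output needs normalizing, then reuses that verdict for every later call; the trade-off is that a model must consistently normalize or consistently not. A stripped-down sketch of the pattern with local helpers:

```go
package main

import (
	"fmt"
	"math"
	"sync"
)

func main() {
	var firstWasNormalized bool
	check := sync.Once{}

	// process mimics the embedding closures above: it inspects only the
	// first vector it ever sees, then reuses that verdict.
	process := func(v []float32) []float32 {
		check.Do(func() { firstWasNormalized = isNormalized(v) })
		if !firstWasNormalized {
			return normalize(v)
		}
		return v
	}

	fmt.Println(process([]float32{0.6, 0.8})) // unit-length; verdict is "normalized"
	fmt.Println(process([]float32{3, 4}))     // passed through unchanged, because the
	// first verdict is trusted; this is exactly the consistency assumption.
}

func isNormalized(v []float32) bool {
	var sqSum float64
	for _, x := range v {
		sqSum += float64(x) * float64(x)
	}
	return math.Abs(math.Sqrt(sqSum)-1) < 1e-6
}

func normalize(v []float32) []float32 {
	var sqSum float64
	for _, x := range v {
		sqSum += float64(x) * float64(x)
	}
	norm := float32(math.Sqrt(sqSum))
	out := make([]float32, len(v))
	for i, x := range v {
		out[i] = x / norm
	}
	return out
}
```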
diff --git a/embed_openai.go b/embed_openai.go
index 2d301c7..1d09e01 100644
--- a/embed_openai.go
+++ b/embed_openai.go
@@ -9,6 +9,7 @@ import (
 	"io"
 	"net/http"
 	"os"
+	"sync"
 )
 
 const BaseURLOpenAI = "https://api.openai.com/v1"
@@ -39,7 +40,9 @@ func NewEmbeddingFuncDefault() EmbeddingFunc {
 // NewEmbeddingFuncOpenAI returns a function that creates embeddings for a text
 // using the OpenAI API.
 func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) EmbeddingFunc {
-	return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model))
+	// OpenAI embeddings are normalized
+	normalized := true
+	return NewEmbeddingFuncOpenAICompat(BaseURLOpenAI, apiKey, string(model), &normalized)
 }
 
 // NewEmbeddingFuncOpenAICompat returns a function that creates embeddings for a text
@@ -48,12 +51,20 @@ func NewEmbeddingFuncOpenAI(apiKey string, model EmbeddingModelOpenAI) Embedding
 // - LiteLLM: https://github.com/BerriAI/litellm
 // - Ollama: https://github.com/ollama/ollama/blob/main/docs/openai.md
 // - etc.
-func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
+//
+// The `normalized` parameter indicates whether the vectors returned by the embedding
+// model are already normalized, as is the case for OpenAI's and Mistral's models.
+// The flag is optional. If it's nil, it will be autodetected on the first request
+// (which bears a small risk that the vector just happens to have a length of 1).
+func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string, normalized *bool) EmbeddingFunc {
 	// We don't set a default timeout here, although it's usually a good idea.
 	// In our case though, the library user can set the timeout on the context,
 	// and it might have to be a long timeout, depending on the text length.
 	client := &http.Client{}
 
+	var checkedNormalized bool
+	checkNormalized := sync.Once{}
+
 	return func(ctx context.Context, text string) ([]float32, error) {
 		// Prepare the request body.
 		reqBody, err := json.Marshal(map[string]string{
@@ -101,6 +112,24 @@ func NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model string) EmbeddingFunc {
 			return nil, errors.New("no embeddings found in the response")
 		}
 
-		return embeddingResponse.Data[0].Embedding, nil
+		v := embeddingResponse.Data[0].Embedding
+		if normalized != nil {
+			if *normalized {
+				return v, nil
+			}
+			return normalizeVector(v), nil
+		}
+		checkNormalized.Do(func() {
+			if isNormalized(v) {
+				checkedNormalized = true
+			} else {
+				checkedNormalized = false
+			}
+		})
+		if !checkedNormalized {
+			v = normalizeVector(v)
+		}
+
+		return v, nil
 	}
 }
diff --git a/embed_openai_test.go b/embed_openai_test.go
index 03049b0..5243b81 100644
--- a/embed_openai_test.go
+++ b/embed_openai_test.go
@@ -33,7 +33,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
 	if err != nil {
 		t.Fatal("unexpected error:", err)
 	}
-	wantRes := []float32{-0.1, 0.1, 0.2}
+	wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
 
 	// Mock server
 	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -75,7 +75,7 @@ func TestNewEmbeddingFuncOpenAICompat(t *testing.T) {
 	defer ts.Close()
 	baseURL := ts.URL + baseURLSuffix
 
-	f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model)
+	f := chromem.NewEmbeddingFuncOpenAICompat(baseURL, apiKey, model, nil)
 
 	res, err := f(context.Background(), input)
 	if err != nil {
 		t.Fatal("expected nil, got", err)
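The test above exercises the `nil` (autodetect) path against a canned response. A compact sketch along the same lines, using `httptest` outside the test suite (the key and model are dummies, and the response body is the minimal OpenAI-style shape the function parses):

```go
package main

import (
	"context"
	"fmt"
	"net/http"
	"net/http/httptest"

	"github.com/philippgille/chromem-go"
)

func main() {
	// Serve a minimal OpenAI-style embeddings response with a vector that
	// is deliberately not unit-length.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		fmt.Fprint(w, `{"data":[{"embedding":[-0.1,0.1,0.2]}]}`)
	}))
	defer ts.Close()

	// normalized=nil: the function detects on the first call that the
	// vector isn't normalized and normalizes it itself.
	f := chromem.NewEmbeddingFuncOpenAICompat(ts.URL, "dummy-key", "dummy-model", nil)
	v, err := f(context.Background(), "hello world")
	fmt.Println(v, err) // ≈ [-0.40824828 0.40824828 0.81649655] <nil>
}
```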
diff --git a/examples/rag-wikipedia-ollama/README.md b/examples/rag-wikipedia-ollama/README.md
index 539e90c..60612c8 100644
--- a/examples/rag-wikipedia-ollama/README.md
+++ b/examples/rag-wikipedia-ollama/README.md
@@ -29,7 +29,7 @@ The output can differ slightly on each run, but it's along the lines of:
 2024/03/02 20:02:34 Reading JSON lines...
 2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API...
 2024/03/02 20:03:11 Querying chromem-go...
-2024/03/02 20:03:11 Search took 231.672667ms
+2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms
 2024/03/02 20:03:11 Document 1 (similarity: 0.723627): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch."
 2024/03/02 20:03:11 Document 2 (similarity: 0.550584): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design."
 2024/03/02 20:03:11 Asking LLM with augmented question...
diff --git a/examples/rag-wikipedia-ollama/main.go b/examples/rag-wikipedia-ollama/main.go
index 4d451ae..ce1cc47 100644
--- a/examples/rag-wikipedia-ollama/main.go
+++ b/examples/rag-wikipedia-ollama/main.go
@@ -104,7 +104,7 @@ func main() {
 	if err != nil {
 		panic(err)
 	}
-	log.Println("Search took", time.Since(start))
+	log.Println("Search (incl query embedding) took", time.Since(start))
 
 	// Here you could filter out any documents whose similarity is below a certain threshold.
 	// if docRes[...].Similarity < 0.5 { ...
@@ -129,7 +129,7 @@ func main() {
 	2024/03/02 20:02:34 Reading JSON lines...
 	2024/03/02 20:02:34 Adding documents to chromem-go, including creating their embeddings via Ollama API...
 	2024/03/02 20:03:11 Querying chromem-go...
-	2024/03/02 20:03:11 Search took 231.672667ms
+	2024/03/02 20:03:11 Search (incl query embedding) took 231.672667ms
 	2024/03/02 20:03:11 Document 1 (similarity: 0.723627): "Malleable Iron Range Company was a company that existed from 1896 to 1985 and primarily produced kitchen ranges made of malleable iron but also produced a variety of other related products. The company's primary trademark was 'Monarch' and was colloquially often referred to as the Monarch Company or just Monarch."
 	2024/03/02 20:03:11 Document 2 (similarity: 0.550584): "The American Motor Car Company was a short-lived company in the automotive industry founded in 1906 lasting until 1913. It was based in Indianapolis Indiana United States. The American Motor Car Company pioneered the underslung design."
 	2024/03/02 20:03:11 Asking LLM with augmented question...
diff --git a/examples/semantic-search-arxiv-openai/README.md b/examples/semantic-search-arxiv-openai/README.md
index 1292fcf..829024f 100644
--- a/examples/semantic-search-arxiv-openai/README.md
+++ b/examples/semantic-search-arxiv-openai/README.md
@@ -12,8 +12,8 @@ This is not a retrieval augmented generation (RAG) app, because after *retrievin
 1. Ensure you have [ripgrep](https://github.com/BurntSushi/ripgrep) installed, or adapt the following commands to use grep
 2. Run `rg '"categories":"cs.CL"' ~/Downloads/arxiv-metadata-oai-snapshot.json | rg '"update_date":"2023' > /tmp/arxiv_cs-cl_2023.jsonl` (adapt input file path if necessary)
 3. Check the data
-   1. `wc -l arxiv_cs-cl_2023.jsonl` should show ~5,000 lines
-   2. `du -h arxiv_cs-cl_2023.jsonl` should show ~8.8 MB
+   1. `wc -l /tmp/arxiv_cs-cl_2023.jsonl` should show ~5,000 lines
+   2. `du -h /tmp/arxiv_cs-cl_2023.jsonl` should show ~8.8 MB
 2. Set the OpenAI API key in your env as `OPENAI_API_KEY`
 3. Run the example: `go run .`
@@ -27,7 +27,7 @@ The output can differ slightly on each run, but it's along the lines of:
 2024/03/10 18:23:55 Read and parsed 5006 documents.
 2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API...
 2024/03/10 18:28:12 Querying chromem-go...
-2024/03/10 18:28:12 Search took 529.451163ms
+2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms
 2024/03/10 18:28:12 Search results:
 	1) Similarity 0.488895:
 		URL: https://arxiv.org/abs/2209.15469
diff --git a/examples/semantic-search-arxiv-openai/main.go b/examples/semantic-search-arxiv-openai/main.go
index e0d341b..c52c366 100644
--- a/examples/semantic-search-arxiv-openai/main.go
+++ b/examples/semantic-search-arxiv-openai/main.go
@@ -93,7 +93,7 @@ func main() {
 	if err != nil {
 		panic(err)
 	}
-	log.Println("Search took", time.Since(start))
+	log.Println("Search (incl query embedding) took", time.Since(start))
 
 	// Here you could filter out any documents whose similarity is below a certain threshold.
 	// if docRes[...].Similarity < 0.5 { ...
@@ -117,7 +117,7 @@ func main() {
 	2024/03/10 18:23:55 Read and parsed 5006 documents.
 	2024/03/10 18:23:55 Adding documents to chromem-go, including creating their embeddings via OpenAI API...
 	2024/03/10 18:28:12 Querying chromem-go...
-	2024/03/10 18:28:12 Search took 529.451163ms
+	2024/03/10 18:28:12 Search (incl query embedding) took 529.451163ms
 	2024/03/10 18:28:12 Search results:
 	1) Similarity 0.488895:
 		URL: https://arxiv.org/abs/2209.15469
diff --git a/persistence_test.go b/persistence_test.go
index 469126e..c5799af 100644
--- a/persistence_test.go
+++ b/persistence_test.go
@@ -23,7 +23,7 @@ func TestPersistence(t *testing.T) {
 	}
 	obj := s{
 		Foo: "test",
-		Bar: []float32{-0.1, 0.1, 0.2},
+		Bar: []float32{-0.40824828, 0.40824828, 0.81649655}, // normalized version of `{-0.1, 0.1, 0.2}`
 	}
 	persist(tempDir, obj)
diff --git a/query.go b/query.go
index e7e5cd3..9da1be2 100644
--- a/query.go
+++ b/query.go
@@ -10,17 +10,9 @@ import (
 
 var supportedFilters = []string{"$contains", "$not_contains"}
 
-// Result represents a single result from a query.
-type Result struct {
-	ID        string
-	Metadata  map[string]string
-	Embedding []float32
-	Content   string
-
-	// The cosine similarity between the query and the document.
-	// The higher the value, the more similar the document is to the query.
-	// The value is in the range [-1, 1].
-	Similarity float32
+type docSim struct {
+	docID      string
+	similarity float32
 }
 
 // filterDocs filters a map of documents by metadata and content.
@@ -103,9 +95,9 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string]
 	return true
 }
 
-func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]Result, error) {
-	res := make([]Result, 0, len(docs))
-	resLock := sync.Mutex{}
+func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) {
+	similarities := make([]docSim, 0, len(docs))
+	similaritiesLock := sync.Mutex{}
 
 	// Determine concurrency. Use number of docs or CPUs, whichever is smaller.
 	numCPUs := runtime.NumCPU()
@@ -131,58 +123,48 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu
 	}
 
 	wg := sync.WaitGroup{}
-	docChan := make(chan *Document, concurrency*2)
+	// Instead of using a channel to pass documents into the goroutines, we just
+	// split the slice into sub-slices and pass those to the goroutines.
+	// This turned out to be faster in the query benchmarks.
+	subSliceSize := len(docs) / concurrency // Can leave remainder, e.g. 10/3 = 3; leaves 1
+	rem := len(docs) % concurrency
 	for i := 0; i < concurrency; i++ {
+		start := i * subSliceSize
+		end := start + subSliceSize
+		// Add remainder to last goroutine
+		if i == concurrency-1 {
+			end += rem
+		}
+
 		wg.Add(1)
-		go func() {
+		go func(subSlice []*Document) {
 			defer wg.Done()
-			for doc := range docChan {
+			for _, doc := range subSlice {
 				// Stop work if another goroutine encountered an error.
 				if ctx.Err() != nil {
 					return
 				}
 
-				sim, err := cosineSimilarity(queryVectors, doc.Embedding)
+				// As the vectors are normalized, the dot product is the cosine similarity.
+				sim, err := dotProduct(queryVectors, doc.Embedding)
 				if err != nil {
 					setSharedErr(fmt.Errorf("couldn't calculate similarity for document '%s': %w", doc.ID, err))
 					return
 				}
 
-				resLock.Lock()
+				similaritiesLock.Lock()
 				// We don't defer the unlock because we want to unlock much earlier.
-				res = append(res, Result{
-					ID:        doc.ID,
-					Metadata:  doc.Metadata,
-					Embedding: doc.Embedding,
-					Content:   doc.Content,
-
-					Similarity: sim,
-				})
-				resLock.Unlock()
+				similarities = append(similarities, docSim{docID: doc.ID, similarity: sim})
+				similaritiesLock.Unlock()
 			}
-		}()
+		}(docs[start:end])
 	}
 
-OuterLoop:
-	for _, doc := range docs {
-		// The doc channel has limited capacity, so writing to the channel blocks
-		// when a goroutine runs into an error and then all goroutines stop processing
-		// the channel and it gets full.
-		// To avoid a deadlock we check for ctx.Done() here, which is closed by
-		// the goroutine that encountered the error.
-		select {
-		case docChan <- doc:
-		case <-ctx.Done():
-			break OuterLoop
-		}
-	}
-	close(docChan)
-
 	wg.Wait()
 
 	if sharedErr != nil {
 		return nil, sharedErr
 	}
 
-	return res, nil
+	return similarities, nil
 }
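The channel-free fan-out in `calcDocSimilarity` is plain index arithmetic: each goroutine receives a contiguous sub-slice, and the division remainder goes to the last one. The same partitioning in isolation, with the per-document similarity work reduced to a sum:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	docs := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
	concurrency := 3

	subSliceSize := len(docs) / concurrency // 10/3 = 3 ...
	rem := len(docs) % concurrency          // ... with remainder 1

	var wg sync.WaitGroup
	sums := make([]int, concurrency)
	for i := 0; i < concurrency; i++ {
		start := i * subSliceSize
		end := start + subSliceSize
		if i == concurrency-1 {
			end += rem // the last goroutine also takes the remainder
		}
		wg.Add(1)
		go func(worker int, subSlice []int) {
			defer wg.Done()
			for _, d := range subSlice {
				sums[worker] += d // disjoint index per worker, so no lock needed
			}
		}(i, docs[start:end])
	}
	wg.Wait()
	fmt.Println(sums) // [6 15 34]
}
```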
diff --git a/vector.go b/vector.go
index 674b584..972b6b2 100644
--- a/vector.go
+++ b/vector.go
@@ -2,9 +2,12 @@ package chromem
 
 import (
 	"errors"
+	"fmt"
 	"math"
 )
 
+const isNormalizedPrecisionTolerance = 1e-6
+
 // cosineSimilarity calculates the cosine similarity between two vectors.
 // Vectors are normalized first.
 // The resulting value represents the similarity, so a higher value means the
@@ -15,16 +18,37 @@ func cosineSimilarity(a, b []float32) (float32, error) {
 		return 0, errors.New("vectors must have the same length")
 	}
 
-	x, y := normalizeVector(a), normalizeVector(b)
-	var dotProduct float32
-	for i := range x {
-		dotProduct += x[i] * y[i]
+	if !isNormalized(a) || !isNormalized(b) {
+		a, b = normalizeVector(a), normalizeVector(b)
+	}
+	dotProduct, err := dotProduct(a, b)
+	if err != nil {
+		return 0, fmt.Errorf("couldn't calculate dot product: %w", err)
 	}
 
+	// Vectors are already normalized, so no need to divide by magnitudes
 	return dotProduct, nil
 }
 
+// dotProduct calculates the dot product between two vectors.
+// It's the same as cosine similarity for normalized vectors.
+// The resulting value represents the similarity, so a higher value means the
+// vectors are more similar.
+func dotProduct(a, b []float32) (float32, error) {
+	// The vectors must have the same length
+	if len(a) != len(b) {
+		return 0, errors.New("vectors must have the same length")
+	}
+
+	var dotProduct float32
+	for i := range a {
+		dotProduct += a[i] * b[i]
+	}
+
+	return dotProduct, nil
+}
+
 func normalizeVector(v []float32) []float32 {
 	var norm float32
 	for _, val := range v {
@@ -39,3 +63,13 @@ func normalizeVector(v []float32) []float32 {
 
 	return res
 }
+
+// isNormalized checks if the vector is normalized.
+func isNormalized(v []float32) bool {
+	var sqSum float64
+	for _, val := range v {
+		sqSum += float64(val) * float64(val)
+	}
+	magnitude := math.Sqrt(sqSum)
+	return math.Abs(magnitude-1) < isNormalizedPrecisionTolerance
+}
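The reason the hot path can use `dotProduct` at all: cosine similarity is a·b / (|a|·|b|), so once both vectors have magnitude 1 the denominator vanishes. A quick numeric check of that equivalence:

```go
package main

import (
	"fmt"
	"math"
)

func main() {
	a := []float64{-0.1, 0.1, 0.2}
	b := []float64{0.2, 0.1, -0.1}

	// Full cosine similarity on the raw vectors.
	cos := dot(a, b) / (mag(a) * mag(b))

	// Dot product of the normalized vectors yields the same number, which is
	// why query.go can skip the two divisions per document.
	fmt.Println(cos, dot(norm(a), norm(b))) // both -0.5
}

func dot(a, b []float64) float64 {
	var s float64
	for i := range a {
		s += a[i] * b[i]
	}
	return s
}

func mag(v []float64) float64 { return math.Sqrt(dot(v, v)) }

func norm(v []float64) []float64 {
	m := mag(v)
	out := make([]float64, len(v))
	for i, x := range v {
		out[i] = x / m
	}
	return out
}
```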