Skip to content

Commit

Permalink
Merge pull request #47 from philippgille/query-perf
Browse files Browse the repository at this point in the history
Improve query performance
  • Loading branch information
philippgille authored Mar 16, 2024
2 parents f59e9dc + 1410612 commit acb1e3f
Show file tree
Hide file tree
Showing 18 changed files with 227 additions and 105 deletions.
59 changes: 40 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ Because `chromem-go` is embeddable it enables you to add retrieval augmented gen

It's *not* a library to connect to Chroma and also not a reimplementation of it in Go. It's a database on its own.

The focus is not scale or number of features, but simplicity.

Performance has not been a priority yet. Without optimizations (except some parallelization with goroutines) querying 5,000 documents takes ~500ms on a mid-range laptop CPU (11th Gen Intel i5-1135G7, like in the first generation Framework Laptop 13).
The focus is not scale (millions of documents) or number of features, but simplicity and performance for the most common use cases. On a mid-range 2020 Intel laptop CPU you can query 1,000 documents in 0.5 ms and 100,000 documents in 56 ms, both with just 44 memory allocations. See [Benchmarks](#benchmarks) for details.

> ⚠️ The project is in beta, under heavy construction, and may introduce breaking changes in releases before `v1.0.0`. All changes are documented in the [`CHANGELOG`](./CHANGELOG.md).
Expand All @@ -23,8 +21,9 @@ Performance has not been a priority yet. Without optimizations (except some para
2. [Interface](#interface)
3. [Features](#features)
4. [Usage](#usage)
5. [Motivation](#motivation)
6. [Related projects](#related-projects)
5. [Benchmarks](#benchmarks)
6. [Motivation](#motivation)
7. [Related projects](#related-projects)

## Use cases

Expand Down Expand Up @@ -156,32 +155,54 @@ See the Godoc for details: <https://pkg.go.dev/github.com/philippgille/chromem-g
### Roadmap

- Performance:
- [ ] Add Go benchmark code
- [ ] Improve code based on CPU and memory profiles
- Add SIMD / Assembler to speed up dot product calculation
- Add [roaring bitmaps](https://github.com/RoaringBitmap/roaring) to speed up full text filtering
- Embedding creators:
- [ ] Add an `EmbeddingFunc` that downloads and shells out to [llamafile](https://github.com/Mozilla-Ocho/llamafile)
- Add an `EmbeddingFunc` that downloads and shells out to [llamafile](https://github.com/Mozilla-Ocho/llamafile)
- Similarity search:
- [ ] Approximate nearest neighbor search with index (ANN)
- [ ] Hierarchical Navigable Small World (HNSW)
- [ ] Inverted file flat (IVFFlat)
- Approximate nearest neighbor search with index (ANN)
- Hierarchical Navigable Small World (HNSW)
- Inverted file flat (IVFFlat)
- Filters:
- [ ] Operators (`$and`, `$or` etc.)
- Operators (`$and`, `$or` etc.)
- Storage:
- [ ] JSON as second encoding format
- [ ] Write-ahead log (WAL) as second file format
- [ ] Compression
- [ ] Encryption (at rest)
- [ ] Optional remote storage (S3, PostgreSQL, ...)
- JSON as second encoding format
- Write-ahead log (WAL) as second file format
- Compression
- Encryption (at rest)
- Optional remote storage (S3, PostgreSQL, ...)
- Data types:
- [ ] Images
- [ ] Videos
- Images
- Videos

## Usage

See the Godoc for a reference: <https://pkg.go.dev/github.com/philippgille/chromem-go>

For full, working examples, using the vector database for retrieval augmented generation (RAG) and semantic search and using either OpenAI or locally running the embeddings model and LLM (in Ollama), see the [example code](examples).

## Benchmarks

```console
$ go test -benchmem -run=^$ -bench .
goos: linux
goarch: amd64
pkg: github.com/philippgille/chromem-go
cpu: 11th Gen Intel(R) Core(TM) i5-1135G7 @ 2.40GHz
BenchmarkCollection_Query_NoContent_100-8 10000 110126 ns/op 6492 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_1000-8 2020 537416 ns/op 35669 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_5000-8 351 4264192 ns/op 166728 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_25000-8 75 16411744 ns/op 813928 B/op 44 allocs/op
BenchmarkCollection_Query_NoContent_100000-8 18 64670962 ns/op 3205962 B/op 44 allocs/op
BenchmarkCollection_Query_100-8 10923 109936 ns/op 6480 B/op 44 allocs/op
BenchmarkCollection_Query_1000-8 2184 562778 ns/op 35667 B/op 44 allocs/op
BenchmarkCollection_Query_5000-8 400 2986732 ns/op 166750 B/op 44 allocs/op
BenchmarkCollection_Query_25000-8 88 15433911 ns/op 813896 B/op 44 allocs/op
BenchmarkCollection_Query_100000-8 19 63696478 ns/op 3205982 B/op 44 allocs/op
PASS
ok github.com/philippgille/chromem-go 31.373s
```

## Motivation

In December 2023, when I wanted to play around with retrieval augmented generation (RAG) in a Go program, I looked for a vector database that could be embedded in the Go program, just like you would embed SQLite in order to not require any separate DB setup and maintenance. I was surprised when I didn't find any, given the abundance of embedded key-value stores in the Go ecosystem.
Expand Down
42 changes: 37 additions & 5 deletions collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,17 @@ func (c *Collection) AddDocument(ctx context.Context, doc Document) error {
m[k] = v
}

// Create embedding if they don't exist
// Create embedding if they don't exist, otherwise normalize if necessary
if len(doc.Embedding) == 0 {
embedding, err := c.embed(ctx, doc.Content)
if err != nil {
return fmt.Errorf("couldn't create embedding of document: %w", err)
}
doc.Embedding = embedding
} else {
if !isNormalized(doc.Embedding) {
doc.Embedding = normalizeVector(doc.Embedding)
}
}

c.documentsLock.Lock()
Expand Down Expand Up @@ -247,6 +251,19 @@ func (c *Collection) Count() int {
return len(c.documents)
}

// Result represents a single result from a query.
type Result struct {
ID string
Metadata map[string]string
Embedding []float32
Content string

// The cosine similarity between the query and the document.
// The higher the value, the more similar the document is to the query.
// The value is in the range [-1, 1].
Similarity float32
}

// Performs an exhaustive nearest neighbor search on the collection.
//
// - queryText: The text to search for.
Expand Down Expand Up @@ -288,17 +305,32 @@ func (c *Collection) Query(ctx context.Context, queryText string, nResults int,
}

// For the remaining documents, calculate cosine similarity.
res, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
docSim, err := calcDocSimilarity(ctx, queryVectors, filteredDocs)
if err != nil {
return nil, fmt.Errorf("couldn't calculate cosine similarity: %w", err)
}

// Sort by similarity
sort.Slice(res, func(i, j int) bool {
sort.Slice(docSim, func(i, j int) bool {
// The `less` function would usually use `<`, but we want to sort descending.
return res[i].Similarity > res[j].Similarity
return docSim[i].similarity > docSim[j].similarity
})

// Return the top nResults or len(docSim), whichever is smaller
if len(docSim) < nResults {
nResults = len(docSim)
}
res := make([]Result, 0, nResults)
for i := 0; i < nResults; i++ {
res = append(res, Result{
ID: docSim[i].docID,
Metadata: c.documents[docSim[i].docID].Metadata,
Embedding: c.documents[docSim[i].docID].Embedding,
Content: c.documents[docSim[i].docID].Content,
Similarity: docSim[i].similarity,
})
}

// Return the top nResults
return res[:nResults], nil
return res, nil
}
11 changes: 6 additions & 5 deletions collection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func TestCollection_Add(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestCollection_Add_Error(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func TestCollection_AddConcurrently(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func TestCollection_AddConcurrently_Error(t *testing.T) {
ctx := context.Background()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -313,8 +313,9 @@ func TestCollection_Count(t *testing.T) {
db := NewDB()
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return []float32{-0.1, 0.1, 0.2}, nil
return vectors, nil
}
c, err := db.CreateCollection(name, metadata, embeddingFunc)
if err != nil {
Expand Down
3 changes: 3 additions & 0 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (
// EmbeddingFunc is a function that creates embeddings for a given text.
// chromem-go will use OpenAI`s "text-embedding-3-small" model by default,
// but you can provide your own function, using any model you like.
// The function must return a *normalized* vector, i.e. the length of the vector
// must be 1. OpenAI's and Mistral's embedding models do this by default. Some
// others like Nomic's "nomic-embed-text-v1.5" don't.
type EmbeddingFunc func(ctx context.Context, text string) ([]float32, error)

// DB is the chromem-go database. It holds collections, which hold documents.
Expand Down
12 changes: 6 additions & 6 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestDB_CreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -81,7 +81,7 @@ func TestDB_ListCollections(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -147,7 +147,7 @@ func TestDB_GetCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -196,7 +196,7 @@ func TestDB_GetOrCreateCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -299,7 +299,7 @@ func TestDB_DeleteCollection(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down Expand Up @@ -331,7 +331,7 @@ func TestDB_Reset(t *testing.T) {
// Values in the collection
name := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
}
Expand Down
2 changes: 1 addition & 1 deletion document_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ func TestDocument_New(t *testing.T) {
ctx := context.Background()
id := "test"
metadata := map[string]string{"foo": "bar"}
vectors := []float32{-0.1, 0.1, 0.2}
vectors := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`
content := "hello world"
embeddingFunc := func(_ context.Context, _ string) ([]float32, error) {
return vectors, nil
Expand Down
12 changes: 8 additions & 4 deletions embed_compat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,13 @@ const (
// NewEmbeddingFuncMistral returns a function that creates embeddings for a text
// using the Mistral API.
func NewEmbeddingFuncMistral(apiKey string) EmbeddingFunc {
// Mistral embeddings are normalized, see section "Distance Measures" on
// https://docs.mistral.ai/guides/embeddings/.
normalized := true

// The Mistral API docs don't mention the `encoding_format` as optional,
// but it seems to be, just like OpenAI. So we reuse the OpenAI function.
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral)
return NewEmbeddingFuncOpenAICompat(baseURLMistral, apiKey, embeddingModelMistral, &normalized)
}

const baseURLJina = "https://api.jina.ai/v1"
Expand All @@ -28,7 +32,7 @@ const (
// NewEmbeddingFuncJina returns a function that creates embeddings for a text
// using the Jina API.
func NewEmbeddingFuncJina(apiKey string, model EmbeddingModelJina) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model))
return NewEmbeddingFuncOpenAICompat(baseURLJina, apiKey, string(model), nil)
}

const baseURLMixedbread = "https://api.mixedbread.ai"
Expand All @@ -49,7 +53,7 @@ const (
// NewEmbeddingFuncMixedbread returns a function that creates embeddings for a text
// using the mixedbread.ai API.
func NewEmbeddingFuncMixedbread(apiKey string, model EmbeddingModelMixedbread) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model))
return NewEmbeddingFuncOpenAICompat(baseURLMixedbread, apiKey, string(model), nil)
}

const baseURLLocalAI = "http://localhost:8080/v1"
Expand All @@ -64,5 +68,5 @@ const baseURLLocalAI = "http://localhost:8080/v1"
// But other embedding models are supported as well. See the LocalAI documentation
// for details.
func NewEmbeddingFuncLocalAI(model string) EmbeddingFunc {
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model)
return NewEmbeddingFuncOpenAICompat(baseURLLocalAI, "", model, nil)
}
18 changes: 17 additions & 1 deletion embed_ollama.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"io"
"net/http"
"sync"
)

// TODO: Turn into const and use as default, but allow user to pass custom URL
Expand All @@ -28,6 +29,9 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
// and it might have to be a long timeout, depending on the text length.
client := &http.Client{}

var checkedNormalized bool
checkNormalized := sync.Once{}

return func(ctx context.Context, text string) ([]float32, error) {
// Prepare the request body.
reqBody, err := json.Marshal(map[string]string{
Expand Down Expand Up @@ -74,6 +78,18 @@ func NewEmbeddingFuncOllama(model string) EmbeddingFunc {
return nil, errors.New("no embeddings found in the response")
}

return embeddingResponse.Embedding, nil
v := embeddingResponse.Embedding
checkNormalized.Do(func() {
if isNormalized(v) {
checkedNormalized = true
} else {
checkedNormalized = false
}
})
if !checkedNormalized {
v = normalizeVector(v)
}

return v, nil
}
}
2 changes: 1 addition & 1 deletion embed_ollama_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ func TestNewEmbeddingFuncOllama(t *testing.T) {
if err != nil {
t.Fatal("unexpected error:", err)
}
wantRes := []float32{-0.1, 0.1, 0.2}
wantRes := []float32{-0.40824828, 0.40824828, 0.81649655} // normalized version of `{-0.1, 0.1, 0.2}`

// Mock server
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down
Loading

0 comments on commit acb1e3f

Please sign in to comment.