Skip to content

Commit

Permalink
Fix failures to compute embeddings due to too much context.
Browse files Browse the repository at this point in the history
* Prior to this change we were computing the embeddings using the entire notebook.
  This would lead to context exceeded errors on longer documents.

* This had two negative impacts

  1. We stopped learning from long documents because we could no longer compute embeddings for them
  1. When making suggestions we don't end up retrieving any documents from RAG because we can't compute the embeddings
     for the current document

* This PR also refactors the code to share code for computing embeddings between the learner and the Agent
  to minimize risk of training and serving skew.

* Fix #260
  • Loading branch information
jlewi committed Oct 5, 2024
1 parent e769e0d commit e5ab380
Show file tree
Hide file tree
Showing 10 changed files with 124 additions and 49 deletions.
2 changes: 1 addition & 1 deletion app/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (a *Agent) Generate(ctx context.Context, req *v1alpha1.GenerateRequest) (*v
var examples []*v1alpha1.Example
if a.config.UseRAG() {
var err error
examples, err = a.db.GetExamples(ctx, req.Doc, a.config.RagMaxResults())
examples, err = a.db.GetExamples(ctx, req, a.config.RagMaxResults())
if err != nil {
// Fail gracefully; keep going without examples
log.Error(err, "Failed to get examples")
Expand Down
3 changes: 2 additions & 1 deletion app/pkg/eval/evaluator.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@ package eval

import (
"context"
"github.com/go-logr/logr"
"os"
"path/filepath"
"sort"
"time"

"github.com/go-logr/logr"

"connectrpc.com/connect"
"github.com/jlewi/foyle/app/pkg/agent"
"github.com/jlewi/foyle/app/pkg/oai"
Expand Down
11 changes: 5 additions & 6 deletions app/pkg/learn/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"k8s.io/client-go/util/workqueue"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/logs"
"github.com/jlewi/foyle/app/pkg/oai"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand Down Expand Up @@ -71,21 +70,21 @@ func NewInMemoryExampleDB(cfg config.Config, vectorizer llms.Vectorizer) (*InMem
return db, nil
}

func (db *InMemoryExampleDB) GetExamples(ctx context.Context, doc *v1alpha1.Doc, maxResults int) ([]*v1alpha1.Example, error) {
func (db *InMemoryExampleDB) GetExamples(ctx context.Context, req *v1alpha1.GenerateRequest, maxResults int) ([]*v1alpha1.Example, error) {
log := logs.FromContext(ctx)
query := docs.DocToMarkdown(doc)

if len(db.examples) == 0 {
// TODO(jeremy): What should we do in this case?
return nil, errors.New("No examples available")
// Since there are no examples just return an empty list
return []*v1alpha1.Example{}, nil
}

// Compute the embedding for the query.
qVec, err := db.vectorizer.Embed(ctx, query)
qVecData, err := db.vectorizer.Embed(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "Failed to compute embedding for query")
}

qVec := llms.VectorToVecDense(qVecData)
// Acquire a lock on the data so we can safely read it.
db.lock.RLock()
defer db.lock.RUnlock()
Expand Down
6 changes: 5 additions & 1 deletion app/pkg/learn/in_memory_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,11 @@ func Test_InMemoryDB(t *testing.T) {
},
},
}
examples, err := db.GetExamples(context.Background(), doc, 1)
req := &v1alpha1.GenerateRequest{
Doc: doc,
SelectedIndex: 0,
}
examples, err := db.GetExamples(context.Background(), req, 1)
if err != nil {
t.Fatalf("Error getting examples; %v", err)
}
Expand Down
51 changes: 22 additions & 29 deletions app/pkg/learn/learner.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"io"
"strings"
"sync"
"time"

"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
Expand All @@ -20,7 +19,6 @@ import (
logspb "github.com/jlewi/foyle/protos/go/foyle/logs"

"github.com/jlewi/foyle/app/pkg/config"
"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/logs"
"github.com/jlewi/foyle/app/pkg/oai"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
Expand Down Expand Up @@ -59,18 +57,22 @@ type Learner struct {
postFunc PostLearnEvent
eventLoopIsDone sync.WaitGroup
factory *files.Factory
vectorizer *oai.Vectorizer
}

// NewLearner constructs a Learner that watches for new examples and computes
// their embeddings via a shared Vectorizer so training matches serving.
//
// The OpenAI client is required; a nil client is rejected up front rather than
// failing later inside the event loop.
func NewLearner(cfg config.Config, client *openai.Client, blocksDB *dbutil.LockingDB[*logspb.BlockLog]) (*Learner, error) {
	if client == nil {
		return nil, errors.New("OpenAI client is required")
	}

	// Share the same vectorizer implementation as the Agent to avoid
	// training/serving skew in how embeddings are computed.
	vectorizer := oai.NewVectorizer(client)
	return &Learner{
		Config:     cfg,
		client:     client,
		blocksDB:   blocksDB,
		queue:      workqueue.NewDelayingQueue(),
		factory:    &files.Factory{},
		vectorizer: vectorizer,
	}, nil
}

Expand Down Expand Up @@ -113,10 +115,10 @@ func (l *Learner) eventLoop(ctx context.Context) {
}

if err := l.Reconcile(ctx, exampleId); err != nil {
// N.B. Right now we treat learning errors as permanent and don't retry.
// The most likely source of retryable errors the vectorizer endpoint should already be handled
// by using a retryable HTTP client.
log.Error(err, "Error learning from example", "example", exampleId)
// Requeue the item so we will try again.
// TODO(jeremy): should we use a rate limiting queue so we eventually give up?
l.queue.AddAfter(exampleId, 30*time.Second)
return
}
}()
Expand Down Expand Up @@ -275,30 +277,21 @@ func (l *Learner) computeEmbeddings(ctx context.Context, example *v1alpha1.Examp
return nil
}

query := docs.DocToMarkdown(example.Query)
selectedIndex := len(example.Query.GetBlocks()) - 1
qVec, err := l.vectorizer.Embed(ctx, &v1alpha1.GenerateRequest{
Doc: example.Query,
SelectedIndex: int32(selectedIndex),
})

request := openai.EmbeddingRequestStrings{
Input: []string{query},
Model: openai.SmallEmbedding3,
User: "",
EncodingFormat: "float",
}
resp, err := l.client.CreateEmbeddings(ctx, request)
if err != nil {
log.Error(err, "Failed to create embeddings", "id", example.Id, "query", query)
return errors.Wrapf(err, "Failed to create embeddings")
}

if len(resp.Data) != 1 {
log.Error(err, "Expected exactly 1 embedding", "id", example.Id, "query", query, "got", len(resp.Data))
return errors.Errorf("Expected exactly 1 embedding but got %d", len(resp.Data))
return err
}

if len(resp.Data[0].Embedding) != oai.SmallEmbeddingsDims {
log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", query, "got", len(resp.Data[0].Embedding), "want", oai.SmallEmbeddingsDims)
return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(resp.Data[0].Embedding), oai.SmallEmbeddingsDims)
if len(qVec) != oai.SmallEmbeddingsDims {
log.Error(err, "Embeddings have wrong dimension", "id", example.Id, "query", example.Query, "got", len(qVec), "want", oai.SmallEmbeddingsDims)
return errors.Wrapf(err, "Embeddings have wrong dimension; got %v, want %v", len(qVec), oai.SmallEmbeddingsDims)
}

example.Embedding = resp.Data[0].Embedding
example.Embedding = qVec
return nil
}
2 changes: 1 addition & 1 deletion app/pkg/learn/learner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func Test_Learner(t *testing.T) {

blocksDB, err := pebble.Open(cfg.GetBlocksDBDir(), &pebble.Options{})
if err != nil {
t.Fatalf("could not open blocks database %s", cfg.GetBlocksDBDir())
t.Fatalf("could not open blocks database %+v", cfg.GetBlocksDBDir())
}
defer blocksDB.Close()

Expand Down
16 changes: 15 additions & 1 deletion app/pkg/llms/vectorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,27 @@ package llms
import (
"context"

"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"

"gonum.org/v1/gonum/mat"
)

type Vector []float32

// Vectorizer computes embedding representations of text.
type Vectorizer interface {
// Embed computes the embedding of the text
Embed(ctx context.Context, text string) (*mat.VecDense, error)
Embed(ctx context.Context, req *v1alpha1.GenerateRequest) (Vector, error)
// Length returns the length of the embeddings
Length() int
}

// VectorToVecDense converts a Vector (a float32 slice) into a gonum
// *mat.VecDense, widening each element to float64.
func VectorToVecDense(v Vector) *mat.VecDense {
	// gonum stores vector data as float64, so copy the elements into a
	// widened backing slice and build the dense vector from it.
	data := make([]float64, len(v))
	for i, f := range v {
		data[i] = float64(f)
	}
	return mat.NewVecDense(len(v), data)
}
20 changes: 12 additions & 8 deletions app/pkg/oai/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package oai
import (
"context"

"github.com/jlewi/foyle/app/pkg/docs"
"github.com/jlewi/foyle/app/pkg/llms"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"

"github.com/jlewi/foyle/app/pkg/logs"
"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
"gonum.org/v1/gonum/mat"
)

func NewVectorizer(client *openai.Client) *Vectorizer {
Expand All @@ -19,7 +22,12 @@ type Vectorizer struct {
client *openai.Client
}

func (v *Vectorizer) Embed(ctx context.Context, text string) (*mat.VecDense, error) {
func (v *Vectorizer) Embed(ctx context.Context, req *v1alpha1.GenerateRequest) (llms.Vector, error) {
// Should we use more than one block? Do we have examples where using more than one block would help?
qBlock := req.GetDoc().GetBlocks()[req.GetSelectedIndex()]
text := docs.BlocksToMarkdown([]*v1alpha1.Block{qBlock})

// Compute the embedding for the query.
log := logs.FromContext(ctx)
log.Info("RAG Query", "query", text)
request := openai.EmbeddingRequestStrings{
Expand All @@ -29,6 +37,7 @@ func (v *Vectorizer) Embed(ctx context.Context, text string) (*mat.VecDense, err
EncodingFormat: "float",
}

// N.B. regarding retries. We should already be doing retries in the HTTP client.
resp, err := v.client.CreateEmbeddings(ctx, request)
if err != nil {
return nil, errors.Errorf("Failed to create embeddings")
Expand All @@ -42,12 +51,7 @@ func (v *Vectorizer) Embed(ctx context.Context, text string) (*mat.VecDense, err
return nil, errors.Errorf("Embeddings have wrong dimension; got %v, want %v", len(resp.Data[0].Embedding), SmallEmbeddingsDims)
}

// Compute the cosine similarity between the query and each example.
qVec := mat.NewVecDense(SmallEmbeddingsDims, nil)
for i := 0; i < SmallEmbeddingsDims; i++ {
qVec.SetVec(i, float64(resp.Data[0].Embedding[i]))
}
return qVec, nil
return resp.Data[0].Embedding, nil
}

func (v *Vectorizer) Length() int {
Expand Down
15 changes: 14 additions & 1 deletion app/pkg/oai/errors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package oai

import "github.com/sashabaranov/go-openai"
import (
"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
)

const (
// ContextLengthExceededCode the error code returned by OpenAI to indicate the context length was exceeded
Expand All @@ -21,3 +24,13 @@ func ErrorIs(err error, oaiCode string) bool {

return val == oaiCode
}

// HTTPStatusCode extracts the HTTP status code from err when err wraps an
// *openai.APIError anywhere in its chain. It returns -1 when the error is not
// (and does not wrap) an APIError.
func HTTPStatusCode(err error) int {
	var apiErr *openai.APIError
	if errors.As(err, &apiErr) {
		return apiErr.HTTPStatusCode
	}
	return -1
}
47 changes: 47 additions & 0 deletions app/pkg/oai/errors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package oai

import (
"testing"

"github.com/pkg/errors"
"github.com/sashabaranov/go-openai"
)

// Test_HTTPStatusCode verifies that HTTPStatusCode extracts the status code
// from bare and wrapped *openai.APIError values and returns -1 for errors that
// are unrelated to the OpenAI API.
func Test_HTTPStatusCode(t *testing.T) {
	cases := []struct {
		name string
		err  error
		want int
	}{
		{
			name: "basic",
			err:  &openai.APIError{HTTPStatusCode: 404},
			want: 404,
		},
		{
			name: "wrapped",
			err:  errors.Wrapf(&openai.APIError{HTTPStatusCode: 509}, "wrapped"),
			want: 509,
		},
		{
			name: "not api error",
			err:  errors.New("not an api error"),
			want: -1,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			if got := HTTPStatusCode(c.err); got != c.want {
				t.Errorf("HTTPStatusCode() = %v, want %v", got, c.want)
			}
		})
	}
}

0 comments on commit e5ab380

Please sign in to comment.