From 04cecc609d7d2c978e85ea433f70a5234b69c797 Mon Sep 17 00:00:00 2001 From: Jeremy Lewi Date: Sat, 5 Oct 2024 13:25:51 -0700 Subject: [PATCH] Fix the learner to not include the full notebook in the example. --- app/pkg/docs/blocks.go | 24 +++++++++++ app/pkg/docs/blocks_test.go | 85 +++++++++++++++++++++++++++++++++++++ app/pkg/learn/in_memory.go | 8 +++- app/pkg/learn/learner.go | 25 ++++++++--- app/pkg/llms/query.go | 2 + app/pkg/llms/vectorizer.go | 4 +- app/pkg/oai/embeddings.go | 6 +-- 7 files changed, 140 insertions(+), 14 deletions(-) create mode 100644 app/pkg/docs/blocks.go create mode 100644 app/pkg/docs/blocks_test.go create mode 100644 app/pkg/llms/query.go diff --git a/app/pkg/docs/blocks.go b/app/pkg/docs/blocks.go new file mode 100644 index 00000000..541dab86 --- /dev/null +++ b/app/pkg/docs/blocks.go @@ -0,0 +1,24 @@ +package docs + +import ( + "context" + "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" +) + +// CreateQuery creates a query from a GenerateRequest +// It returns the blocks that should be used to query for similar documents +func CreateQuery(ctx context.Context, req *v1alpha1.GenerateRequest) ([]*v1alpha1.Block, error) { + // Use a simple algorithm. + // 1. Always select at least the current block + // 2. Walk backwards from the selected block, also including the contiguous run of preceding markup blocks; stop at the first non-markup block.
+ startIndex := req.GetSelectedIndex() - 1 + + for ; startIndex >= 0; startIndex-- { + if req.GetDoc().GetBlocks()[startIndex].Kind != v1alpha1.BlockKind_MARKUP { + break + } + } + + blocks := req.GetDoc().GetBlocks()[startIndex+1 : req.GetSelectedIndex()+1] + return blocks, nil +} diff --git a/app/pkg/docs/blocks_test.go b/app/pkg/docs/blocks_test.go new file mode 100644 index 00000000..5c955673 --- /dev/null +++ b/app/pkg/docs/blocks_test.go @@ -0,0 +1,85 @@ +package docs + +import ( + "github.com/google/go-cmp/cmp" + "github.com/jlewi/foyle/app/pkg/testutil" + "github.com/jlewi/foyle/protos/go/foyle/v1alpha1" + "testing" +) + +func Test_CreateQuery(t *testing.T) { + doc1 := &v1alpha1.Doc{ + Blocks: []*v1alpha1.Block{ + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 0", + }, + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 1", + }, + { + Kind: v1alpha1.BlockKind_CODE, + Contents: "cell 2", + }, + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 3", + }, + { + Kind: v1alpha1.BlockKind_MARKUP, + Contents: "cell 4", + }, + }, + } + + type testCase struct { + name string + input *v1alpha1.GenerateRequest + expected []*v1alpha1.Block + } + + cases := []testCase{ + { + name: "stop-at-start", + input: &v1alpha1.GenerateRequest{ + Doc: doc1, + SelectedIndex: 1, + }, + expected: doc1.Blocks[0:2], + }, + { + name: "start-on-codeblock", + input: &v1alpha1.GenerateRequest{ + Doc: doc1, + SelectedIndex: 2, + }, + expected: doc1.Blocks[0:3], + }, + { + name: "stop-on-code", + input: &v1alpha1.GenerateRequest{ + Doc: doc1, + SelectedIndex: 4, + }, + expected: doc1.Blocks[3:5], + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + blocks, err := CreateQuery(nil, tc.input) + if err != nil { + t.Fatalf("CreateQuery failed: %v", err) + } + if len(blocks) != len(tc.expected) { + t.Errorf("CreateQuery returned %d blocks; want %d", len(blocks), len(tc.expected)) + } + + if d := cmp.Diff(tc.expected, blocks, testutil.DocComparer, 
testutil.BlockComparer); d != "" { + t.Errorf("CreateQuery returned unexpected blocks:\n%v", d) + } + + }) + } +} diff --git a/app/pkg/learn/in_memory.go b/app/pkg/learn/in_memory.go index 81382afa..723dbf11 100644 --- a/app/pkg/learn/in_memory.go +++ b/app/pkg/learn/in_memory.go @@ -2,6 +2,7 @@ package learn import ( "context" + "github.com/jlewi/foyle/app/pkg/docs" "io" "sort" "sync" @@ -78,8 +79,13 @@ func (db *InMemoryExampleDB) GetExamples(ctx context.Context, req *v1alpha1.Gene return []*v1alpha1.Example{}, nil } + blocks, err := docs.CreateQuery(ctx, req) + if err != nil { + return nil, errors.Wrap(err, "Failed to create query") + } + // Compute the embedding for the query. - qVecData, err := db.vectorizer.Embed(ctx, req) + qVecData, err := db.vectorizer.Embed(ctx, blocks) if err != nil { return nil, errors.Wrap(err, "Failed to compute embedding for query") } diff --git a/app/pkg/learn/learner.go b/app/pkg/learn/learner.go index eac2a6f0..02e11a3c 100644 --- a/app/pkg/learn/learner.go +++ b/app/pkg/learn/learner.go @@ -3,6 +3,7 @@ package learn import ( "context" "fmt" + "github.com/jlewi/foyle/app/pkg/docs" "io" "strings" "sync" @@ -182,17 +183,31 @@ func (l *Learner) Reconcile(ctx context.Context, id string) error { if len(expectedFiles) == 0 { cellsProcessed.WithLabelValues("noExampleFiles").Inc() - log.Error(err, "No training files found", "id", b.GetId()) + log.Error(err, "No training files found", "blockId", b.GetId()) return errors.Wrapf(err, "No training files found for example %s", b.GetId()) } // TODO(jeremy): Should we take into account execution status when looking for mistakes? 
// Deep copy the original message - newDoc := proto.Clone(b.Doc).(*v1alpha1.Doc) newBlock := proto.Clone(b.ExecutedBlock).(*v1alpha1.Block) answer := []*v1alpha1.Block{newBlock} + req := &v1alpha1.GenerateRequest{ + Doc: b.Doc, + SelectedIndex: int32(len(b.Doc.Blocks) - 1), + } + queryBlocks, err := docs.CreateQuery(ctx, req) + + newDoc := &v1alpha1.Doc{ + Blocks: queryBlocks, + } + + if err != nil { + log.Error(err, "Failed to create query", "exampleId", b.GetId()) + return errors.Wrapf(err, "Failed to create query for example %s", b.GetId()) + } + example := &v1alpha1.Example{ Id: b.GetId(), Query: newDoc, @@ -277,11 +292,7 @@ func (l *Learner) computeEmbeddings(ctx context.Context, example *v1alpha1.Examp return nil } - selectedIndex := len(example.Query.GetBlocks()) - 1 - qVec, err := l.vectorizer.Embed(ctx, &v1alpha1.GenerateRequest{ - Doc: example.Query, - SelectedIndex: int32(selectedIndex), - }) + qVec, err := l.vectorizer.Embed(ctx, example.Query.GetBlocks()) if err != nil { return err diff --git a/app/pkg/llms/query.go b/app/pkg/llms/query.go new file mode 100644 index 00000000..7e544ee6 --- /dev/null +++ b/app/pkg/llms/query.go @@ -0,0 +1,2 @@ +package llms + diff --git a/app/pkg/llms/vectorizer.go b/app/pkg/llms/vectorizer.go index a981efcf..169505d5 100644 --- a/app/pkg/llms/vectorizer.go +++ b/app/pkg/llms/vectorizer.go @@ -12,8 +12,8 @@ type Vector []float32 // Vectorizer computes embedding representations of text. 
type Vectorizer interface { - // Embed computes the embedding of the text - Embed(ctx context.Context, req *v1alpha1.GenerateRequest) (Vector, error) + // Embed computes the embedding of the blocks + Embed(ctx context.Context, blocks []*v1alpha1.Block) (Vector, error) // Length returns the length of the embeddings Length() int } diff --git a/app/pkg/oai/embeddings.go b/app/pkg/oai/embeddings.go index fe88409d..37c7d0d7 100644 --- a/app/pkg/oai/embeddings.go +++ b/app/pkg/oai/embeddings.go @@ -22,10 +22,8 @@ type Vectorizer struct { client *openai.Client } -func (v *Vectorizer) Embed(ctx context.Context, req *v1alpha1.GenerateRequest) (llms.Vector, error) { - // Should we use more than one block? Do we have examples where using moe than 1 one block would help - qBlock := req.GetDoc().GetBlocks()[req.GetSelectedIndex()] - text := docs.BlocksToMarkdown([]*v1alpha1.Block{qBlock}) +func (v *Vectorizer) Embed(ctx context.Context, blocks []*v1alpha1.Block) (llms.Vector, error) { + text := docs.BlocksToMarkdown(blocks) // Compute the embedding for the query. log := logs.FromContext(ctx)