Skip to content

Commit

Permalink
Fix the learner to not include the full notebook in the example.
Browse files Browse the repository at this point in the history
  • Loading branch information
jlewi committed Oct 5, 2024
1 parent 0780ece commit 04cecc6
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 14 deletions.
24 changes: 24 additions & 0 deletions app/pkg/docs/blocks.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package docs

import (
"context"
"fmt"

Check failure on line 4 in app/pkg/docs/blocks.go

View workflow job for this annotation

GitHub Actions / golang test & build

File is not `goimports`-ed (goimports)
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
)

// CreateQuery selects the blocks of a GenerateRequest that should be used to
// query for similar documents.
//
// The selection is deliberately simple:
//  1. The selected block is always included.
//  2. Walking backwards from the selected block, each immediately preceding
//     block is included as long as it is a markup block; the walk stops at
//     the first block of any other kind or at the start of the document.
//
// The returned slice is a sub-slice of the request's blocks; no copy is made.
// An error is returned, rather than panicking on an out-of-range slice
// expression, when the document has no blocks or the selected index does not
// refer to one of them.
func CreateQuery(ctx context.Context, req *v1alpha1.GenerateRequest) ([]*v1alpha1.Block, error) {
	blocks := req.GetDoc().GetBlocks()
	selected := int(req.GetSelectedIndex())
	if selected < 0 || selected >= len(blocks) {
		return nil, fmt.Errorf("selected index %d is out of range for a document with %d blocks", selected, len(blocks))
	}

	// Extend the window backwards over the run of markup blocks that sits
	// directly in front of the selected block.
	start := selected
	for start > 0 && blocks[start-1].GetKind() == v1alpha1.BlockKind_MARKUP {
		start--
	}

	return blocks[start : selected+1], nil
}
85 changes: 85 additions & 0 deletions app/pkg/docs/blocks_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package docs

import (

Check failure on line 3 in app/pkg/docs/blocks_test.go

View workflow job for this annotation

GitHub Actions / golang test & build

File is not `goimports`-ed (goimports)
"github.com/google/go-cmp/cmp"
"github.com/jlewi/foyle/app/pkg/testutil"
"github.com/jlewi/foyle/protos/go/foyle/v1alpha1"
"testing"
)

func Test_CreateQuery(t *testing.T) {
doc1 := &v1alpha1.Doc{
Blocks: []*v1alpha1.Block{
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 0",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 1",
},
{
Kind: v1alpha1.BlockKind_CODE,
Contents: "cell 2",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 3",
},
{
Kind: v1alpha1.BlockKind_MARKUP,
Contents: "cell 4",
},
},
}

type testCase struct {
name string
input *v1alpha1.GenerateRequest
expected []*v1alpha1.Block
}

cases := []testCase{
{
name: "stop-at-start",
input: &v1alpha1.GenerateRequest{
Doc: doc1,
SelectedIndex: 1,
},
expected: doc1.Blocks[0:2],
},
{
name: "start-on-codeblock",
input: &v1alpha1.GenerateRequest{
Doc: doc1,
SelectedIndex: 2,
},
expected: doc1.Blocks[0:3],
},
{
name: "stop-on-code",
input: &v1alpha1.GenerateRequest{
Doc: doc1,
SelectedIndex: 4,
},
expected: doc1.Blocks[3:5],
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
blocks, err := CreateQuery(nil, tc.input)

Check failure on line 71 in app/pkg/docs/blocks_test.go

View workflow job for this annotation

GitHub Actions / golang test & build

SA1012: do not pass a nil Context, even if a function permits it; pass context.TODO if you are unsure about which Context to use (staticcheck)
if err != nil {
t.Fatalf("CreateQuery failed: %v", err)
}
if len(blocks) != len(tc.expected) {
t.Errorf("CreateQuery returned %d blocks; want %d", len(blocks), len(tc.expected))
}

if d := cmp.Diff(tc.expected, blocks, testutil.DocComparer, testutil.BlockComparer); d != "" {
t.Errorf("CreateQuery returned unexpected blocks:\n%v", d)
}

})
}
}
8 changes: 7 additions & 1 deletion app/pkg/learn/in_memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package learn

import (
"context"
"github.com/jlewi/foyle/app/pkg/docs"

Check failure on line 5 in app/pkg/learn/in_memory.go

View workflow job for this annotation

GitHub Actions / golang test & build

File is not `goimports`-ed (goimports)
"io"
"sort"
"sync"
Expand Down Expand Up @@ -78,8 +79,13 @@ func (db *InMemoryExampleDB) GetExamples(ctx context.Context, req *v1alpha1.Gene
return []*v1alpha1.Example{}, nil
}

blocks, err := docs.CreateQuery(ctx, req)
if err != nil {
return nil, errors.Wrap(err, "Failed to create query")
}

// Compute the embedding for the query.
qVecData, err := db.vectorizer.Embed(ctx, req)
qVecData, err := db.vectorizer.Embed(ctx, blocks)
if err != nil {
return nil, errors.Wrap(err, "Failed to compute embedding for query")
}
Expand Down
25 changes: 18 additions & 7 deletions app/pkg/learn/learner.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package learn
import (
"context"
"fmt"
"github.com/jlewi/foyle/app/pkg/docs"
"io"
"strings"
"sync"
Expand Down Expand Up @@ -182,17 +183,31 @@ func (l *Learner) Reconcile(ctx context.Context, id string) error {

if len(expectedFiles) == 0 {
cellsProcessed.WithLabelValues("noExampleFiles").Inc()
log.Error(err, "No training files found", "id", b.GetId())
log.Error(err, "No training files found", "blockId", b.GetId())
return errors.Wrapf(err, "No training files found for example %s", b.GetId())
}

// TODO(jeremy): Should we take into account execution status when looking for mistakes?

// Deep copy the original message
newDoc := proto.Clone(b.Doc).(*v1alpha1.Doc)
newBlock := proto.Clone(b.ExecutedBlock).(*v1alpha1.Block)
answer := []*v1alpha1.Block{newBlock}

req := &v1alpha1.GenerateRequest{
Doc: b.Doc,
SelectedIndex: int32(len(b.Doc.Blocks) - 1),
}
queryBlocks, err := docs.CreateQuery(ctx, req)

newDoc := &v1alpha1.Doc{
Blocks: queryBlocks,
}

if err != nil {
log.Error(err, "Failed to create query", "exampleId", b.GetId())
return errors.Wrapf(err, "Failed to create query for example %s", b.GetId())
}

example := &v1alpha1.Example{
Id: b.GetId(),
Query: newDoc,
Expand Down Expand Up @@ -277,11 +292,7 @@ func (l *Learner) computeEmbeddings(ctx context.Context, example *v1alpha1.Examp
return nil
}

selectedIndex := len(example.Query.GetBlocks()) - 1
qVec, err := l.vectorizer.Embed(ctx, &v1alpha1.GenerateRequest{
Doc: example.Query,
SelectedIndex: int32(selectedIndex),
})
qVec, err := l.vectorizer.Embed(ctx, example.Query.GetBlocks())

if err != nil {
return err
Expand Down
2 changes: 2 additions & 0 deletions app/pkg/llms/query.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
package llms

Check failure on line 2 in app/pkg/llms/query.go

View workflow job for this annotation

GitHub Actions / golang test & build

File is not `gofmt`-ed with `-s` (gofmt)
4 changes: 2 additions & 2 deletions app/pkg/llms/vectorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ type Vector []float32

// Vectorizer computes embedding representations of text.
type Vectorizer interface {
// Embed computes the embedding of the text
Embed(ctx context.Context, req *v1alpha1.GenerateRequest) (Vector, error)
// Embed computes the embedding of the blocks
Embed(ctx context.Context, blocks []*v1alpha1.Block) (Vector, error)
// Length returns the length of the embeddings
Length() int
}
Expand Down
6 changes: 2 additions & 4 deletions app/pkg/oai/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,8 @@ type Vectorizer struct {
client *openai.Client
}

func (v *Vectorizer) Embed(ctx context.Context, req *v1alpha1.GenerateRequest) (llms.Vector, error) {
// Should we use more than one block? Do we have examples where using more than one block would help?
qBlock := req.GetDoc().GetBlocks()[req.GetSelectedIndex()]
text := docs.BlocksToMarkdown([]*v1alpha1.Block{qBlock})
func (v *Vectorizer) Embed(ctx context.Context, blocks []*v1alpha1.Block) (llms.Vector, error) {
text := docs.BlocksToMarkdown(blocks)

// Compute the embedding for the query.
log := logs.FromContext(ctx)
Expand Down

0 comments on commit 04cecc6

Please sign in to comment.