Skip to content

Commit

Permalink
Implement key batching for cassandra online store in go
Browse files Browse the repository at this point in the history
  • Loading branch information
acevedosharp committed Dec 6, 2024
1 parent 1c1d1f0 commit 118f947
Showing 1 changed file with 55 additions and 45 deletions.
100 changes: 55 additions & 45 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/hex"
"errors"
"fmt"
"math"
"os"
"strings"
"sync"
Expand Down Expand Up @@ -209,16 +210,22 @@ func (c *CassandraOnlineStore) getFqTableName(tableName string) string {
return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName)
}

func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string {
func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nkeys int) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName)
}

keyPlaceholders := make([]string, nkeys)
for i := 0; i < nkeys; i++ {
keyPlaceholders[i] = "?"
}

return fmt.Sprintf(
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`,
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`,
tableName,
strings.Join(keyPlaceholders, ","),
strings.Join(quotedFeatureNames, ","),
)
}
Expand Down Expand Up @@ -265,45 +272,43 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ

// Prepare the query
tableName := c.getFqTableName(featureViewName)
cqlStatement := c.getCQLStatement(tableName, featureNames)

var waitGroup sync.WaitGroup
waitGroup.Add(len(serializedEntityKeys))
// do batching
nKeys := len(serializedEntityKeys)
batchSize := 2
nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))

batches := make([][]any, nBatches)
nAssigned := 0
for i := 0; i < nBatches; i++ {
thisBatchSize := int(math.Min(float64(batchSize), float64(nKeys-nAssigned)))
nAssigned += thisBatchSize
batches[i] = make([]any, thisBatchSize)
for j := 0; j < thisBatchSize; j++ {
batches[i][j] = serializedEntityKeys[i*batchSize+j]
}
}

errorsChannel := make(chan error, len(serializedEntityKeys))
for _, serializedEntityKey := range serializedEntityKeys {
go func(serEntityKey any) {
defer waitGroup.Done()
var waitGroup sync.WaitGroup
waitGroup.Add(nBatches)

iter := c.session.Query(cqlStatement, serEntityKey).WithContext(ctx).Iter()
errorsChannel := make(chan error, nBatches)

rowIdx := serializedEntityKeyToIndex[serializedEntityKey.(string)]
for _, batch := range batches {
go func(keyBatch []any) {
defer waitGroup.Done()

// fill the row with nulls if not found
if iter.NumRows() == 0 {
for _, featName := range featureNames {
results[rowIdx][featureNamesToIdx[featName]] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}
return
}
cqlStatement := c.getCQLStatement(tableName, featureNames, len(keyBatch))
iter := c.session.Query(cqlStatement, keyBatch...).WithContext(ctx).Iter()

scanner := iter.Scanner()
var entityKey string
var featureName string
var eventTs time.Time
var valueStr []byte
var deserializedValue types.Value
rowFeatures := make(map[string]FeatureData)
// key 1: entityKey - key 2: featureName
batchFeatures := make(map[string]map[string]FeatureData)
for scanner.Next() {
err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
if err != nil {
Expand All @@ -317,7 +322,10 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ

if deserializedValue.Val != nil {
// Convert the value to a FeatureData struct
rowFeatures[featureName] = FeatureData{
if batchFeatures[entityKey] == nil {
batchFeatures[entityKey] = make(map[string]FeatureData)
}
batchFeatures[entityKey][featureName] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
Expand All @@ -335,26 +343,28 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
return
}

for _, featName := range featureNames {
featureData, ok := rowFeatures[featName]
if !ok {
featureData = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
for _, serializedEntityKey := range keyBatch {
for _, featName := range featureNames {
keyString := serializedEntityKey.(string)
featureData, ok := batchFeatures[keyString][featName]
if !ok {
featureData = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featName,
},
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}
results[serializedEntityKeyToIndex[keyString]][featureNamesToIdx[featName]] = featureData
}
results[rowIdx][featureNamesToIdx[featName]] = featureData
}
}(serializedEntityKey)
}(batch)
}

// wait until all concurrent single-key queries are done
waitGroup.Wait()
close(errorsChannel)
Expand Down

0 comments on commit 118f947

Please sign in to comment.