Skip to content

Commit

Permalink
Implement configurable batching + removed misleading log
Browse files Browse the repository at this point in the history
  • Loading branch information
acevedosharp committed Dec 20, 2024
1 parent 2c3f1de commit 535741b
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 11 deletions.
196 changes: 187 additions & 9 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ type CassandraOnlineStore struct {
session *gocql.Session

config *registry.RepoConfig

// The number of keys to include in a single CQL query for retrieval from the database
keyBatchSize int
}

type CassandraConfig struct {
Expand All @@ -44,6 +47,7 @@ type CassandraConfig struct {
loadBalancingPolicy gocql.HostSelectionPolicy
connectionTimeoutMillis int64
requestTimeoutMillis int64
keyBatchSize int
}

func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) {
Expand Down Expand Up @@ -156,6 +160,13 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig,
}
cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64))

keyBatchSize, ok := onlineStoreConfig["key_batch_size"]
if !ok {
keyBatchSize = 10.0
log.Warn().Msg("key_batch_size not specified, defaulting to batches of size 10")
}
cassandraConfig.keyBatchSize = int(keyBatchSize.(float64))

return &cassandraConfig, nil
}

Expand All @@ -176,8 +187,9 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online

store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy

if cassandraConfig.username != "" && cassandraConfig.password != "" {
log.Warn().Msg("username/password not defined, will not be using authentication")
if cassandraConfig.username == "" || cassandraConfig.password == "" {
log.Warn().Msg("username and/or password not defined, will not be using authentication")
} else {
store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{
Username: cassandraConfig.username,
Password: cassandraConfig.password,
Expand All @@ -203,14 +215,38 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online
return nil, fmt.Errorf("unable to connect to the ScyllaDB database")
}
store.session = createdSession

if cassandraConfig.keyBatchSize <= 0 {
return nil, fmt.Errorf("key_batch_size must be greater than zero")
} else if cassandraConfig.keyBatchSize == 1 {
log.Info().Msg("key batching is disabled")
} else {
log.Info().Msgf("key batching is enabled with a batch size of %d", cassandraConfig.keyBatchSize)
}
store.keyBatchSize = cassandraConfig.keyBatchSize

return &store, nil
}

// getFqTableName builds the fully qualified, double-quoted CQL table name for
// tableName: the configured keyspace, followed by the project name and
// tableName joined with an underscore.
func (c *CassandraOnlineStore) getFqTableName(tableName string) string {
	keyspace := c.clusterConfigs.Keyspace
	projectTable := fmt.Sprintf("%s_%s", c.project, tableName)
	return fmt.Sprintf(`"%s"."%s"`, keyspace, projectTable)
}

func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nkeys int) string {
// getSingleKeyCQLStatement returns the CQL SELECT used to read one entity key
// from tableName. The entity key is left as a single bind marker (?), while the
// feature names are inlined as single-quoted literals in an IN clause so that
// only the requested features are fetched.
func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string {
	// Restrict the query to the requested features only.
	quoted := make([]string, 0, len(featureNames))
	for _, name := range featureNames {
		quoted = append(quoted, fmt.Sprintf(`'%s'`, name))
	}
	featureList := strings.Join(quoted, ",")

	return fmt.Sprintf(
		`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`,
		tableName,
		featureList,
	)
}

func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string {
// this prevents fetching unnecessary features
quotedFeatureNames := make([]string, len(featureNames))
for i, featureName := range featureNames {
Expand Down Expand Up @@ -244,7 +280,143 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti
}
return cassandraKeys, cassandraKeyToEntityIndex, nil
}
func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {

// UnbatchedKeysOnlineRead reads the requested features for every entity key by
// issuing one single-key CQL query per serialized key, each in its own
// goroutine.
//
// All entries of featureViewNames must name the same feature view; otherwise
// the read is rejected. The result has one row per entry of entityKeys (row
// positions are taken from the index map returned by buildCassandraEntityKeys)
// and one column per entry of featureNames, in that order. A feature with no
// stored value is returned as an explicit NULL value. If any query fails, the
// joined errors are returned and the result is nil.
func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
	uniqueNames := make(map[string]int32)
	for _, fvName := range featureViewNames {
		uniqueNames[fvName] = 0
	}
	if len(uniqueNames) != 1 {
		return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once")
	}

	serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)
	if err != nil {
		return nil, fmt.Errorf("error when serializing entity keys for Cassandra: %w", err)
	}

	results := make([][]FeatureData, len(entityKeys))
	for i := range results {
		results[i] = make([]FeatureData, len(featureNames))
	}

	featureNamesToIdx := make(map[string]int, len(featureNames))
	for idx, name := range featureNames {
		featureNamesToIdx[name] = idx
	}

	featureViewName := featureViewNames[0]

	// Prepare the query; the same statement is reused for every key.
	tableName := c.getFqTableName(featureViewName)
	cqlStatement := c.getSingleKeyCQLStatement(tableName, featureNames)

	var waitGroup sync.WaitGroup
	waitGroup.Add(len(serializedEntityKeys))

	// Each goroutine sends at most one error, so this buffer never blocks.
	errorsChannel := make(chan error, len(serializedEntityKeys))
	for _, serializedEntityKey := range serializedEntityKeys {
		go func(serEntityKey any) {
			defer waitGroup.Done()

			// Look up the row with the goroutine's own argument. Before Go 1.22
			// the loop variable is shared across iterations, so reading it here
			// could resolve to another key's row index.
			rowIdx := serializedEntityKeyToIndex[serEntityKey.(string)]

			iter := c.session.Query(cqlStatement, serEntityKey).WithContext(ctx).Iter()

			// No special case for an empty result: the scan loop simply does not
			// run, scanner.Err() still surfaces a failed query, and the fill loop
			// below writes NULLs for every feature.
			scanner := iter.Scanner()
			var entityKey string
			var featureName string
			var eventTs time.Time
			var valueStr []byte
			var deserializedValue types.Value
			rowFeatures := make(map[string]FeatureData)
			for scanner.Next() {
				if err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr); err != nil {
					errorsChannel <- fmt.Errorf("could not read row in query for (entity key, feature name, value, event ts): %w", err)
					return
				}
				if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
					errorsChannel <- fmt.Errorf("error converting parsed Cassandra Value to types.Value: %w", err)
					return
				}

				if deserializedValue.Val != nil {
					// Convert the value to a FeatureData struct
					rowFeatures[featureName] = FeatureData{
						Reference: serving.FeatureReferenceV2{
							FeatureViewName: featureViewName,
							FeatureName:     featureName,
						},
						Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
						Value: types.Value{
							Val: deserializedValue.Val,
						},
					}
				}
			}

			if err := scanner.Err(); err != nil {
				errorsChannel <- fmt.Errorf("failed to scan features: %w", err)
				return
			}

			// Emit the features in the requested order, using NULL for any
			// feature the query did not return a value for.
			for _, featName := range featureNames {
				featureData, ok := rowFeatures[featName]
				if !ok {
					featureData = FeatureData{
						Reference: serving.FeatureReferenceV2{
							FeatureViewName: featureViewName,
							FeatureName:     featName,
						},
						Value: types.Value{
							Val: &types.Value_NullVal{
								NullVal: types.Null_NULL,
							},
						},
					}
				}
				results[rowIdx][featureNamesToIdx[featName]] = featureData
			}
		}(serializedEntityKey)
	}

	// wait until all concurrent single-key queries are done
	waitGroup.Wait()
	close(errorsChannel)

	var collectedErrors []error
	for err := range errorsChannel {
		if err != nil {
			collectedErrors = append(collectedErrors, err)
		}
	}
	if len(collectedErrors) > 0 {
		return nil, errors.Join(collectedErrors...)
	}

	return results, nil
}

func (c *CassandraOnlineStore) BatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
uniqueNames := make(map[string]int32)
for _, fvName := range featureViewNames {
uniqueNames[fvName] = 0
Expand Down Expand Up @@ -273,9 +445,9 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
// Prepare the query
tableName := c.getFqTableName(featureViewName)

// do batching
// Key batching
nKeys := len(serializedEntityKeys)
batchSize := 20
batchSize := c.keyBatchSize
nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))

batches := make([][]any, nBatches)
Expand All @@ -293,7 +465,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
waitGroup.Add(nBatches)

errorsChannel := make(chan error, nBatches)

var prevBatchLength int
var cqlStatement string
for _, batch := range batches {
Expand All @@ -302,7 +473,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ

// this caches the previous batch query if it had the same number of keys
if len(keyBatch) != prevBatchLength {
cqlStatement = c.getCQLStatement(tableName, featureNames, len(keyBatch))
cqlStatement = c.getMultiKeyCQLStatement(tableName, featureNames, len(keyBatch))
}

iter := c.session.Query(cqlStatement, keyBatch...).WithContext(ctx).Iter()
Expand All @@ -327,7 +498,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
}

if deserializedValue.Val != nil {
// Convert the value to a FeatureData struct
if batchFeatures[entityKey] == nil {
batchFeatures[entityKey] = make(map[string]FeatureData)
}
Expand Down Expand Up @@ -388,6 +558,14 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
return results, nil
}

// OnlineRead serves an online feature read by delegating to one of two
// strategies based on the configured key batch size: a batch size of exactly 1
// uses one query per key (UnbatchedKeysOnlineRead); any other value uses
// multi-key queries (BatchedKeysOnlineRead).
func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
	if c.keyBatchSize != 1 {
		return c.BatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames)
	}
	return c.UnbatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames)
}

// Destruct closes the store's database session. It is a no-op when no session
// was ever attached (for example on a zero-value CassandraOnlineStore), so it
// is safe to call on a store that was never connected.
func (c *CassandraOnlineStore) Destruct() {
	if c.session == nil {
		return
	}
	c.session.Close()
}
15 changes: 13 additions & 2 deletions go/internal/feast/onlinestore/cassandraonlinestore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,28 @@ func TestGetFqTableName(t *testing.T) {
assert.Equal(t, `"scylladb"."dummy_project_dummy_fv"`, fqTableName)
}

func TestGetCQLStatement(t *testing.T) {
func TestGetSingleKeyCQLStatement(t *testing.T) {
store := CassandraOnlineStore{}
fqTableName := `"scylladb"."dummy_project_dummy_fv"`

cqlStatement := store.getCQLStatement(fqTableName, []string{"feat1", "feat2"})
cqlStatement := store.getSingleKeyCQLStatement(fqTableName, []string{"feat1", "feat2"})
assert.Equal(t,
`SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" = ? AND "feature_name" IN ('feat1','feat2')`,
cqlStatement,
)
}

// TestGetMultiKeyCQLStatement verifies that the multi-key statement contains
// one bind marker per requested key and inlines the feature names as quoted
// literals.
func TestGetMultiKeyCQLStatement(t *testing.T) {
	const want = `SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" IN (?,?,?,?,?) AND "feature_name" IN ('feat1','feat2')`

	var store CassandraOnlineStore
	requestedFeatures := []string{"feat1", "feat2"}
	got := store.getMultiKeyCQLStatement(`"scylladb"."dummy_project_dummy_fv"`, requestedFeatures, 5)

	assert.Equal(t, want, got)
}

func TestOnlineRead_RejectsDifferentFeatureViewsInSameRead(t *testing.T) {
store := CassandraOnlineStore{}
_, err := store.OnlineRead(context.TODO(), nil, []string{"fv1", "fv2"}, []string{"feat1", "feat2"})
Expand Down

0 comments on commit 535741b

Please sign in to comment.