Skip to content

Commit

Permalink
Fix config parsing and multi-key retrieval
Browse files Browse the repository at this point in the history
  • Loading branch information
Jose Acevedo committed Nov 15, 2024
1 parent 5d3bf5f commit ed81825
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 85 deletions.
198 changes: 127 additions & 71 deletions go/internal/feast/onlinestore/cassandraonlinestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package onlinestore

import (
"context"
"crypto/sha1"
"encoding/binary"
"encoding/hex"
"errors"
Expand All @@ -13,9 +12,11 @@ import (
"github.com/gocql/gocql"
"github.com/golang/protobuf/proto"
"github.com/rs/zerolog/log"
"google.golang.org/protobuf/types/known/timestamppb"
_ "google.golang.org/protobuf/types/known/timestamppb"
_ "net"
"sort"
"strings"
"time"
_ "time"
)
Expand Down Expand Up @@ -49,11 +50,21 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online
if !ok {
cassandraHosts = "127.0.0.1"
log.Warn().Msg("Host not provided: Using localhost instead")
}

var rawCassandraHosts []interface{}
if rawCassandraHosts, ok = cassandraHosts.([]interface{}); !ok {
return nil, fmt.Errorf("didn't pass a list of hosts in the 'hosts' field")
}
var cassandraHostStr string
if cassandraHostStr, ok = cassandraHosts.(string); !ok {
return nil, fmt.Errorf("failed to convert hosts to string: %+v", cassandraHostStr)

var cassandraHostsStr = make([]string, len(rawCassandraHosts))
for i, rawHost := range rawCassandraHosts {
hostStr, ok := rawHost.(string)
if !ok {
return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost)
}
fmt.Printf("\tsingle host: %s\n", hostStr)
cassandraHostsStr[i] = hostStr
}

username, ok := onlineStoreConfig["username"]
Expand All @@ -79,19 +90,28 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online

keyspace, ok := onlineStoreConfig["keyspace"]
if !ok {
keyspace = project
log.Warn().Msgf("Keyspace not defined: Using project name %s as keyspace instead", project)
keyspace = "scylladb"
log.Warn().Msg("Keyspace not defined: Using 'scylladb' as keyspace instead")
}
store.keyspace = keyspace.(string)

var keyspaceStr string
if keyspaceStr, ok = keyspace.(string); !ok {
return nil, fmt.Errorf("failed to convert keyspace to string: %+v", keyspaceStr)
}

store.clusterConfigs = gocql.NewCluster(store.hosts...)
protocolVersion, ok := onlineStoreConfig["protocol_version"]
if !ok {
protocolVersion = 4.0
log.Warn().Msg("protocol_version not specified: Using 4 instead")
}
protocolVersionInt := int(protocolVersion.(float64))

store.clusterConfigs = gocql.NewCluster(cassandraHostsStr...)
// TODO: Figure out if we need to offer users the ability to tune the timeouts
//store.clusterConfigs.ConnectTimeout = 1
//store.clusterConfigs.Timeout = 1
store.clusterConfigs.ProtoVersion = onlineStoreConfig["protocol"].(int)
store.clusterConfigs.ProtoVersion = protocolVersionInt
//store.clusterConfigs.Consistency = gocql.Quorum
store.clusterConfigs.Keyspace = keyspaceStr
loadBalancingPolicy, ok := onlineStoreConfig["load_balancing"]
Expand All @@ -115,21 +135,36 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online
}

store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr}
createdSession, err1 := gocql.NewSession(*store.clusterConfigs)
if err1 != nil {
createdSession, err := gocql.NewSession(*store.clusterConfigs)
if err != nil {
return nil, fmt.Errorf("Unable to connect to the ScyllaDB database")
}
store.session = createdSession
return &store, nil
}

func (c *CassandraOnlineStore) getTableName(tableName string) string {
return fmt.Sprintf(`"%s"_"%s"_""%s"`, c.keyspace, c.project, tableName)
// getFqTableName returns the fully qualified, double-quoted CQL name of the
// table that stores a feature view: "<keyspace>"."<project>_<tableName>".
func (c *CassandraOnlineStore) getFqTableName(tableName string) string {
	// The table itself is named after the project and the feature view.
	table := fmt.Sprintf("%s_%s", c.project, tableName)
	return fmt.Sprintf(`"%s"."%s"`, c.keyspace, table)
}
func (c *CassandraOnlineStore) getCQLStatement(tableName string) string {
selectStatement := "SELECT entity_key, feature_name, value, event_ts FROM %s WHERE entity_key IN ?;"
return fmt.Sprintf(selectStatement, tableName)
// getCQLStatement builds the SELECT statement used to read feature rows.
//
// tableName is inserted verbatim as the FROM target. featureNames are inlined
// as CQL string literals in the "feature_name" IN (...) restriction. nKeys is
// the number of "?" bind markers emitted in the "entity_key" IN (...)
// restriction; the caller binds one value per marker when executing the query.
//
// The returned statement selects, in order: entity_key, feature_name,
// event_ts, value.
func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string {
	// TODO: Compare with multiple single-key concurrent queries like in the Python feature server
	keyPlaceholders := make([]string, nKeys)
	for i := range keyPlaceholders {
		keyPlaceholders[i] = "?"
	}

	// Restricting on feature_name prevents fetching unnecessary features.
	// The names are spliced into the statement text rather than bound, so any
	// embedded single quote is doubled (the CQL escape for a quote inside a
	// string literal) to keep a name from terminating its literal early.
	quotedFeatureNames := make([]string, len(featureNames))
	for i, featureName := range featureNames {
		quotedFeatureNames[i] = "'" + strings.ReplaceAll(featureName, "'", "''") + "'"
	}

	return fmt.Sprintf(
		`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`,
		tableName,
		strings.Join(keyPlaceholders, ","),
		strings.Join(quotedFeatureNames, ","),
	)
}

func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]interface{}, map[string]int, error) {
Expand All @@ -141,80 +176,111 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti
return nil, nil, err
}
// encoding to hex
cassandraKeys[i] = hex.EncodeToString(*key)
cassandraKeyToEntityIndex[hashSerializedEntityKey(key)] = i
encodedKey := hex.EncodeToString(*key)
cassandraKeys[i] = encodedKey
cassandraKeyToEntityIndex[encodedKey] = i
}
return cassandraKeys, cassandraKeyToEntityIndex, nil
}
func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
uniqueNames := make(map[string]int32)
for _, fvName := range featureViewNames {
uniqueNames[fvName] = 0
}
if len(uniqueNames) != 1 {
return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once")
}

serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)

if err != nil {
return nil, fmt.Errorf("Error when serializing entity keys for Cassandra")
}
results := make([][]FeatureData, len(entityKeys))
for i := range results {
results[i] = make([]FeatureData, len(featureNames))
}

featureNamesToIdx := make(map[string]int)
for idx, name := range featureNames {
featureNamesToIdx[name] = idx
}

for _, featureViewName := range featureViewNames {
// Prepare the query
tableName := c.getTableName(featureViewName)
cqlStatement := c.getCQLStatement(tableName)
// Bundle the entity keys in one statement (gocql handles this as concurrent queries)
query := c.session.Query(cqlStatement, serializedEntityKeys...)
scan := query.Iter().Scanner()

// Process the results
var entityKey []byte
var featureName string
var eventTs time.Time
var valueStr []byte
var deserializedValue types.Value
for scan.Next() {
err := scan.Scan(&entityKey, &featureName, &valueStr, &eventTs)
if err != nil {
return nil, errors.New("Could not read row in query for (entity key, feature name, value, event ts)")
featureViewName := featureViewNames[0]

// Prepare the query
tableName := c.getFqTableName(featureViewName)
cqlStatement := c.getCQLStatement(tableName, featureNames, len(serializedEntityKeys))
// Bundle the entity keys in one statement (gocql handles this as concurrent queries)
scanner := c.session.Query(cqlStatement, serializedEntityKeys...).Iter().Scanner()

// Process the results
var entityKey string
var featureName string
var eventTs time.Time
var valueStr []byte
var deserializedValue types.Value
for scanner.Next() {
err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
if err != nil {
return nil, errors.New("could not read row in query for (entity key, feature name, value, event ts)")
}
if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
return nil, errors.New("error converting parsed Cassandra Value to types.Value")
}

var featureValues FeatureData
if deserializedValue.Val != nil {
// Convert the value to a FeatureData struct
featureValues = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
Value: types.Value{
Val: deserializedValue.Val,
},
}
if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
return nil, errors.New("error converting parsed Cassandra Value to types.Value")
} else {
// Return FeatureData with a null value
featureValues = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}
}

var featureValues FeatureData
if deserializedValue.Val != nil {
// Convert the value to a FeatureData struct
featureValues = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
},
Value: types.Value{
Val: deserializedValue.Val,
},
}
} else {
// Return FeatureData with a null value
featureValues = FeatureData{
// Add the FeatureData to the results
rowIndx := serializedEntityKeyToIndex[entityKey]
results[rowIndx][featureNamesToIdx[featureName]] = featureValues
}
// Check for errors from the Scanner
if err := scanner.Err(); err != nil {
return nil, errors.New("failed to scan features: " + err.Error())
}

for i := 0; i < len(entityKeys); i++ {
for j := 0; j < len(featureNames); j++ {
if results[i][j].Timestamp.GetSeconds() == 0 {
results[i][j] = FeatureData{
Reference: serving.FeatureReferenceV2{
FeatureViewName: featureViewName,
FeatureName: featureName,
FeatureName: featureViewNames[j],
},
Value: types.Value{
Val: &types.Value_NullVal{
NullVal: types.Null_NULL,
},
},
}

}
// Add the FeatureData to the results
rowIndx := serializedEntityKeyToIndex[hashSerializedEntityKey(&entityKey)]
results[rowIndx][featureNamesToIdx[featureName]] = featureValues
}
// Check for errors from the Scanner
if err := scan.Err(); err != nil {
return nil, err
}
}

Expand Down Expand Up @@ -312,16 +378,6 @@ func serializeCassandraValue(value interface{}, entityKeySerializationVersion in
}
}

// hashSerializedEntityKey returns the hex-encoded SHA-1 digest of the
// serialized entity key, or the empty string when the pointer is nil.
func hashSerializedEntityKey(serializedEntityKey *[]byte) string {
	if serializedEntityKey == nil {
		return ""
	}
	digest := sha1.Sum(*serializedEntityKey)
	return hex.EncodeToString(digest[:])
}

// Destruct is currently a no-op: its body is empty and it releases nothing.
//
// NOTE(review): NewCassandraOnlineStore stores a gocql session on the store,
// and nothing here closes it — confirm whether the session is meant to live
// for the whole process or should be closed in this method.
func (c *CassandraOnlineStore) Destruct() {

}
17 changes: 3 additions & 14 deletions go/internal/feast/onlinestore/sqliteonlinestore.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
package onlinestore

import (
"crypto/sha1"
"database/sql"
"encoding/hex"
"errors"
"github.com/feast-dev/feast/go/internal/feast/utils"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -76,7 +75,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.
return nil, err
}
// TODO: fix this, string conversion is not safe
entityNameToEntityIndex[hashSerializedEntityKey(serKey)] = i
entityNameToEntityIndex[utils.HashSerializedEntityKey(serKey)] = i

Check failure on line 78 in go/internal/feast/onlinestore/sqliteonlinestore.go

View workflow job for this annotation

GitHub Actions / lint-go

undefined: utils.HashSerializedEntityKey
// for IN clause in read query
in_query[i] = "?"
serialized_entities[i] = *serKey
Expand Down Expand Up @@ -109,7 +108,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.
if err := proto.Unmarshal(valueString, &value); err != nil {
return nil, errors.New("error converting parsed value to types.Value")
}
rowIdx := entityNameToEntityIndex[hashSerializedEntityKey(&entity_key)]
rowIdx := entityNameToEntityIndex[utils.HashSerializedEntityKey(&entity_key)]

Check failure on line 111 in go/internal/feast/onlinestore/sqliteonlinestore.go

View workflow job for this annotation

GitHub Actions / lint-go

undefined: utils.HashSerializedEntityKey
if results[rowIdx] == nil {
results[rowIdx] = make([]FeatureData, featureCount)
}
Expand Down Expand Up @@ -152,13 +151,3 @@ func initializeConnection(db_path string) (*sql.DB, error) {
}
return db, nil
}

// hashSerializedEntityKey returns the hex-encoded SHA-1 digest of the
// serialized entity key, or the empty string when the pointer is nil.
func hashSerializedEntityKey(serializedEntityKey *[]byte) string {
	if serializedEntityKey == nil {
		return ""
	}
	digest := sha1.Sum(*serializedEntityKey)
	return hex.EncodeToString(digest[:])
}

0 comments on commit ed81825

Please sign in to comment.