diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 6f3fb39f17d..e03aa39d61b 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -2,7 +2,6 @@ package onlinestore import ( "context" - "crypto/sha1" "encoding/binary" "encoding/hex" "errors" @@ -13,9 +12,11 @@ import ( "github.com/gocql/gocql" "github.com/golang/protobuf/proto" "github.com/rs/zerolog/log" + "google.golang.org/protobuf/types/known/timestamppb" _ "google.golang.org/protobuf/types/known/timestamppb" _ "net" "sort" + "strings" "time" _ "time" ) @@ -49,11 +50,21 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online if !ok { cassandraHosts = "127.0.0.1" log.Warn().Msg("Host not provided: Using localhost instead") + } + var rawCassandraHosts []interface{} + if rawCassandraHosts, ok = cassandraHosts.([]interface{}); !ok { + return nil, fmt.Errorf("didn't pass a list of hosts in the 'hosts' field") } - var cassandraHostStr string - if cassandraHostStr, ok = cassandraHosts.(string); !ok { - return nil, fmt.Errorf("failed to convert hosts to string: %+v", cassandraHostStr) + + var cassandraHostsStr = make([]string, len(rawCassandraHosts)) + for i, rawHost := range rawCassandraHosts { + hostStr, ok := rawHost.(string) + if !ok { + return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost) + } + fmt.Printf("\tsingle host: %s\n", hostStr) + cassandraHostsStr[i] = hostStr } username, ok := onlineStoreConfig["username"] @@ -79,19 +90,28 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online keyspace, ok := onlineStoreConfig["keyspace"] if !ok { - keyspace = project - log.Warn().Msgf("Keyspace not defined: Using project name %s as keyspace instead", project) + keyspace = "scylladb" + log.Warn().Msg("Keyspace not defined: Using 'scylladb' as keyspace instead") } + store.keyspace = keyspace.(string) + var keyspaceStr string if keyspaceStr, ok = keyspace.(string); !ok { return nil, fmt.Errorf("failed to convert keyspace to string: %+v", keyspaceStr) } - store.clusterConfigs = gocql.NewCluster(store.hosts...) + protocolVersion, ok := onlineStoreConfig["protocol_version"] + if !ok { + protocolVersion = 4.0 + log.Warn().Msg("protocol_version not specified: Using 4 instead") + } + protocolVersionInt := int(protocolVersion.(float64)) + + store.clusterConfigs = gocql.NewCluster(cassandraHostsStr...) // TODO: Figure out if we need to offer users the ability to tune the timeouts //store.clusterConfigs.ConnectTimeout = 1 //store.clusterConfigs.Timeout = 1 - store.clusterConfigs.ProtoVersion = onlineStoreConfig["protocol"].(int) + store.clusterConfigs.ProtoVersion = protocolVersionInt //store.clusterConfigs.Consistency = gocql.Quorum store.clusterConfigs.Keyspace = keyspaceStr loadBalancingPolicy, ok := onlineStoreConfig["load_balancing"] @@ -115,21 +135,36 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online } store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr} - createdSession, err1 := gocql.NewSession(*store.clusterConfigs) - if err1 != nil { + createdSession, err := gocql.NewSession(*store.clusterConfigs) + if err != nil { return nil, fmt.Errorf("Unable to connect to the ScyllaDB database") } store.session = createdSession return &store, nil } -func (c *CassandraOnlineStore) getTableName(tableName string) string { - return fmt.Sprintf(`"%s"_"%s"_""%s"`, c.keyspace, c.project, tableName) +func (c *CassandraOnlineStore) getFqTableName(tableName string) string { + return fmt.Sprintf(`"%s"."%s_%s"`, c.keyspace, c.project, tableName) } -func (c *CassandraOnlineStore) getCQLStatement(tableName string) string { - selectStatement := "SELECT entity_key, feature_name, value, event_ts FROM %s WHERE entity_key IN ?;" - return fmt.Sprintf(selectStatement, tableName) +func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string { + // TODO: Compare with multiple single-key concurrent queries like in the Python feature server + keyPlaceholders := make([]string, nKeys) + for i := 0; i < nKeys; i++ { + keyPlaceholders[i] = "?" + } + // this prevents fetching unnecessary features + quotedFeatureNames := make([]string, len(featureNames)) + for i, featureName := range featureNames { + quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) + } + + return fmt.Sprintf( + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`, + tableName, + strings.Join(keyPlaceholders, ","), + strings.Join(quotedFeatureNames, ","), + ) } func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]interface{}, map[string]int, error) { @@ -141,64 +176,103 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti return nil, nil, err } // encoding to hex - cassandraKeys[i] = hex.EncodeToString(*key) - cassandraKeyToEntityIndex[hashSerializedEntityKey(key)] = i + encodedKey := hex.EncodeToString(*key) + cassandraKeys[i] = encodedKey + cassandraKeyToEntityIndex[encodedKey] = i } return cassandraKeys, cassandraKeyToEntityIndex, nil } func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + uniqueNames := make(map[string]int32) + for _, fvName := range featureViewNames { + uniqueNames[fvName] = 0 + } + if len(uniqueNames) != 1 { + return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once") + } serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) + if err != nil { return nil, fmt.Errorf("Error when serializing entity keys for Cassandra") } results := make([][]FeatureData, len(entityKeys)) + for i := range results { + results[i] = make([]FeatureData, len(featureNames)) + } + featureNamesToIdx := make(map[string]int) for idx, name := range featureNames { featureNamesToIdx[name] = idx } - for _, featureViewName := range featureViewNames { - // Prepare the query - tableName := c.getTableName(featureViewName) - cqlStatement := c.getCQLStatement(tableName) - // Bundle the entity keys in one statement (gocql handles this as concurrent queries) - query := c.session.Query(cqlStatement, serializedEntityKeys...) - scan := query.Iter().Scanner() - - // Process the results - var entityKey []byte - var featureName string - var eventTs time.Time - var valueStr []byte - var deserializedValue types.Value - for scan.Next() { - err := scan.Scan(&entityKey, &featureName, &valueStr, &eventTs) - if err != nil { - return nil, errors.New("Could not read row in query for (entity key, feature name, value, event ts)") + featureViewName := featureViewNames[0] + + // Prepare the query + tableName := c.getFqTableName(featureViewName) + cqlStatement := c.getCQLStatement(tableName, featureNames, len(serializedEntityKeys)) + // Bundle the entity keys in one statement (gocql handles this as concurrent queries) + scanner := c.session.Query(cqlStatement, serializedEntityKeys...).Iter().Scanner() + + // Process the results + var entityKey string + var featureName string + var eventTs time.Time + var valueStr []byte + var deserializedValue types.Value + for scanner.Next() { + err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) + if err != nil { + return nil, errors.New("could not read row in query for (entity key, feature name, value, event ts)") + } + if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { + return nil, errors.New("error converting parsed Cassandra Value to types.Value") + } + + var featureValues FeatureData + if deserializedValue.Val != nil { + // Convert the value to a FeatureData struct + featureValues = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}, + Value: types.Value{ + Val: deserializedValue.Val, + }, } - if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { - return nil, errors.New("error converting parsed Cassandra Value to types.Value") + } else { + // Return FeatureData with a null value + featureValues = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, } + } - var featureValues FeatureData - if deserializedValue.Val != nil { - // Convert the value to a FeatureData struct - featureValues = FeatureData{ - Reference: serving.FeatureReferenceV2{ - FeatureViewName: featureViewName, - FeatureName: featureName, - }, - Value: types.Value{ - Val: deserializedValue.Val, - }, - } - } else { - // Return FeatureData with a null value - featureValues = FeatureData{ + // Add the FeatureData to the results + rowIndx := serializedEntityKeyToIndex[entityKey] + results[rowIndx][featureNamesToIdx[featureName]] = featureValues + } + // Check for errors from the Scanner + if err := scanner.Err(); err != nil { + return nil, errors.New("failed to scan features: " + err.Error()) + } + + for i := 0; i < len(entityKeys); i++ { + for j := 0; j < len(featureNames); j++ { + if results[i][j].Timestamp.GetSeconds() == 0 { + results[i][j] = FeatureData{ Reference: serving.FeatureReferenceV2{ FeatureViewName: featureViewName, - FeatureName: featureName, + FeatureName: featureViewNames[j], }, Value: types.Value{ Val: &types.Value_NullVal{ @@ -206,15 +280,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ }, }, } - } - // Add the FeatureData to the results - rowIndx := serializedEntityKeyToIndex[hashSerializedEntityKey(&entityKey)] - results[rowIndx][featureNamesToIdx[featureName]] = featureValues - } - // Check for errors from the Scanner - if err := scan.Err(); err != nil { - return nil, err } } @@ -312,16 +378,6 @@ func serializeCassandraValue(value interface{}, entityKeySerializationVersion in } } -func hashSerializedEntityKey(serializedEntityKey *[]byte) string { - if serializedEntityKey == nil { - return "" - } - h := sha1.New() - h.Write(*serializedEntityKey) - sha1_hash := hex.EncodeToString(h.Sum(nil)) - return sha1_hash -} - func (c *CassandraOnlineStore) Destruct() { } diff --git a/go/internal/feast/onlinestore/sqliteonlinestore.go b/go/internal/feast/onlinestore/sqliteonlinestore.go index 6c37258e740..a329dc1be6c 100644 --- a/go/internal/feast/onlinestore/sqliteonlinestore.go +++ b/go/internal/feast/onlinestore/sqliteonlinestore.go @@ -1,10 +1,9 @@ package onlinestore import ( - "crypto/sha1" "database/sql" - "encoding/hex" "errors" + "github.com/feast-dev/feast/go/internal/feast/utils" "strings" "sync" "time" @@ -76,7 +75,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. return nil, err } // TODO: fix this, string conversion is not safe - entityNameToEntityIndex[hashSerializedEntityKey(serKey)] = i + entityNameToEntityIndex[utils.HashSerializedEntityKey(serKey)] = i // for IN clause in read query in_query[i] = "?" serialized_entities[i] = *serKey @@ -109,7 +108,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. if err := proto.Unmarshal(valueString, &value); err != nil { return nil, errors.New("error converting parsed value to types.Value") } - rowIdx := entityNameToEntityIndex[hashSerializedEntityKey(&entity_key)] + rowIdx := entityNameToEntityIndex[utils.HashSerializedEntityKey(&entity_key)] if results[rowIdx] == nil { results[rowIdx] = make([]FeatureData, featureCount) } @@ -152,13 +151,3 @@ func initializeConnection(db_path string) (*sql.DB, error) { } return db, nil } - -func hashSerializedEntityKey(serializedEntityKey *[]byte) string { - if serializedEntityKey == nil { - return "" - } - h := sha1.New() - h.Write(*serializedEntityKey) - sha1_hash := hex.EncodeToString(h.Sum(nil)) - return sha1_hash -}