From 6807ffd730a6309447c508f573d5d17f13b78d0e Mon Sep 17 00:00:00 2001 From: msistla96 Date: Wed, 11 Sep 2024 16:03:24 -0500 Subject: [PATCH 01/18] first set of changes for ScyllaDB Online Store --- Makefile | 2 +- go.mod | 7 + go.sum | 23 ++ go/infra/docker/feature-server/Dockerfile | 6 + .../feast/onlinestore/cassandraonlinestore.go | 339 ++++++++++++++++++ go/internal/feast/onlinestore/onlinestore.go | 5 +- go/internal/feast/utils/aws_utils.go | 117 ++++++ .../online_stores/aws_utils_online_store.py | 61 ++++ .../cassandra_online_store.py | 34 +- sdk/python/feast/repo_config.py | 1 + setup.py | 5 + 11 files changed, 593 insertions(+), 7 deletions(-) create mode 100644 go/internal/feast/onlinestore/cassandraonlinestore.go create mode 100644 go/internal/feast/utils/aws_utils.go create mode 100644 sdk/python/feast/infra/online_stores/aws_utils_online_store.py diff --git a/Makefile b/Makefile index 2ee3b60771..9d5721f5a6 100644 --- a/Makefile +++ b/Makefile @@ -486,7 +486,7 @@ build-java-docker-dev: build-go-docker-dev: docker buildx build --build-arg VERSION=dev \ -t feastdev/feature-server-go:dev \ - -f go/infra/docker/feature-server/Dockerfile --load . + -f SCYLLADB_SHARD_AWARE_ENABLED=$(SCYLLADB_SHARD_AWARE_ENABLED) go/infra/docker/feature-server/Dockerfile --load . 
# Documentation diff --git a/go.mod b/go.mod index 02b528cc19..39bb23a2a8 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,8 @@ require ( github.com/goccy/go-json v0.10.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect + github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect + github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect @@ -73,6 +75,11 @@ require ( golang.org/x/tools v0.22.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect + gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +require github.com/aws/aws-sdk-go v1.34.28 + +require github.com/gocql/gocql v1.6.0 diff --git a/go.sum b/go.sum index 2c3d5b28c9..542a21cc7b 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,12 @@ github.com/apache/arrow/go/v17 v17.0.0 h1:RRR2bdqKcdbss9Gxy2NS/hK8i4LDMh23L6BbkN github.com/apache/arrow/go/v17 v17.0.0/go.mod h1:jR7QHkODl15PfYyjM2nU+yTLScZ/qfj7OSUZmJ8putc= github.com/apache/thrift v0.20.0 h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI= github.com/apache/thrift v0.20.0/go.mod h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8= +github.com/aws/aws-sdk-go v1.34.28 h1:sscPpn/Ns3i0F4HPEWAVcwdIRaZZCuL7llJ2/60yPIk= +github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert 
v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -49,13 +55,17 @@ github.com/ebitengine/purego v0.5.0-alpha h1:pNZNC8WofBTN3Nm196An50C5taL/87BhFR/ github.com/ebitengine/purego v0.5.0-alpha/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= +github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= @@ -69,14 +79,23 @@ github.com/google/pprof v0.0.0-20230509042627-b1315fad0c5a 
h1:PEOGDI1kkyW37YqPWH github.com/google/pprof v0.0.0-20230509042627-b1315fad0c5a/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= +github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.2.1 h1:Fmg33tUaq4/8ym9TJN1x7sLJnHVwhP33CNkpYV/7rwI= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod 
h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/mattn/go-sqlite3 v1.14.16 h1:yOQRA0RpS5PFz/oikGwBEqvAWhWg5ufRz4ETLjwpU1Y= @@ -157,6 +176,7 @@ golang.org/x/mod v0.18.0 h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= @@ -229,6 +249,9 @@ gopkg.in/DataDog/dd-trace-go.v1 v1.54.0/go.mod h1:1JqaWiPl1+vHNYuVNmHOG4HDyHbF84 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/infra/docker/feature-server/Dockerfile b/go/infra/docker/feature-server/Dockerfile index cf63bb4559..389925802a 100644 --- 
a/go/infra/docker/feature-server/Dockerfile +++ b/go/infra/docker/feature-server/Dockerfile @@ -1,5 +1,7 @@ FROM golang:1.22.5 +ARG SCYLLADB_SHARD_AWARE_ENABLED=False +ENV SCYLLADB_SHARD_AWARE_ENABLED=$SCYLLADB_SHARD_AWARE_ENABLED # Update the package list and install the ca-certificates package RUN apt-get update && apt-get install -y ca-certificates RUN apt install -y protobuf-compiler @@ -23,6 +25,10 @@ RUN find ./protos -name "*.proto" \ # Build the Go application RUN go build -o feast ./go/main.go +RUN if ["$SCYLLADB_SHARD_AWARE_ENABLED"== True]; then \ + go mod edit -replace=github.com/gocql/gocql=github.com/scylladb/gocql@latest; \ + go mod tidy; fi + # Expose ports EXPOSE 8080 diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go new file mode 100644 index 0000000000..1436182374 --- /dev/null +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -0,0 +1,339 @@ +package onlinestore + +import ( + "context" + "crypto/sha1" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "github.com/feast-dev/feast/go/internal/feast/registry" + "github.com/feast-dev/feast/go/internal/feast/utils" + "github.com/feast-dev/feast/go/protos/feast/serving" + "github.com/feast-dev/feast/go/protos/feast/types" + "github.com/gocql/gocql" + "github.com/golang/protobuf/proto" + "github.com/rs/zerolog/log" + _ "google.golang.org/protobuf/types/known/timestamppb" + _ "net" + "sort" + "time" + _ "time" +) + +type CassandraOnlineStore struct { + project string + + // Cluster configurations for Cassandra/ScyllaDB + clusterConfigs *gocql.ClusterConfig + + // Session object that holds information about the connection to the cluster + session *gocql.Session + + // keyspace of the table. 
Defaulted to using the project name + keyspace string + + // Host IP addresses of the cluster + hostIps []string + + config *registry.RepoConfig +} + +func NewCassandraOnlineStore(project string, config *registry.RepoConfig, onlineStoreConfig map[string]interface{}) (*CassandraOnlineStore, error) { + store := CassandraOnlineStore{ + project: project, + config: config, + } + + var username string + var password string + // Parse host_name and Ips + cassandraHostNames, ok1 := onlineStoreConfig["host_names"] + cassandraHostIps, ok2 := onlineStoreConfig["hosts"] + if !ok1 && !ok2 { + cassandraHostIps = "127.0.0.1" + + } + var cassandraHostNameStr string + if cassandraHostNameStr, ok1 = cassandraHostNames.(string); !ok1 { + return nil, fmt.Errorf("failed to convert host_names to string: %+v", cassandraHostNameStr) + } + + var cassandraHostIpsStr string + if cassandraHostIpsStr, ok2 = cassandraHostIps.(string); !ok2 { + return nil, fmt.Errorf("failed to convert hosts(ip addresses) to string: %+v", cassandraHostIpsStr) + } + + username = onlineStoreConfig["username"].(string) + password = onlineStoreConfig["password"].(string) + + if len(username) == 0 { + username = "scylla" + log.Warn().Msg("Username not defined: Using default username instead") + } + if len(username) == 0 { + password = "scylla" + log.Warn().Msg("Password not defined: Using default password instead") + } + + keyspace, ok := onlineStoreConfig["keyspace"] + if !ok { + keyspace = project + log.Warn().Msgf("Keyspace not defined: Using project name %s as keyspace instead", project) + } + var keyspaceStr string + if keyspaceStr, ok = keyspace.(string); !ok { + return nil, fmt.Errorf("failed to convert keyspace to string: %+v", keyspaceStr) + } + + // If you're using host_names, it means that you will need a host resolver to get the IP of the cluster the DB is in + if len(cassandraHostNameStr) > 0 { + ec2Instance := utils.NewEC2Instance(cassandraHostNameStr, "us-west-2") + hostIps, err := 
ec2Instance.ResolveHostNameToIp() + if err != nil { + return nil, fmt.Errorf("Unable to resolve host name %+v to ip address: ", cassandraHostNameStr) + } + store.hostIps = hostIps + } else { + store.hostIps = []string{cassandraHostNameStr} + } + + store.clusterConfigs = gocql.NewCluster(store.hostIps...) + // TODO: Figure out if we need to offer users the ability to tune the timeouts + //store.clusterConfigs.ConnectTimeout = 1 + //store.clusterConfigs.Timeout = 1 + store.clusterConfigs.ProtoVersion = onlineStoreConfig["protocol"].(int) + //store.clusterConfigs.Consistency = gocql.Quorum + store.clusterConfigs.Keyspace = keyspaceStr + loadBalancingPolicy, ok := onlineStoreConfig["load_balancing"] + if !ok { + return nil, nil + } + loadBalancingPolicyStr, _ := loadBalancingPolicy.(string) + if loadBalancingPolicyStr == "DCAwareRoundRobinPolicy" { + store.clusterConfigs.PoolConfig.HostSelectionPolicy = gocql.RoundRobinHostPolicy() + } else if loadBalancingPolicyStr == "TokenAwarePolicy(DCAwareRoundRobinPolicy)" { + // Configure fallback policy if unable to reach the shard + fallback := gocql.RoundRobinHostPolicy() + // If using ScyllaDB and setting this policy, this makes the driver shard aware to improve performance + store.clusterConfigs.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(fallback) + if config.OnlineStore["type"] == "scylladb" { + store.clusterConfigs.Port = 19042 + } else { + store.clusterConfigs.Port = 9042 + } + } else { + return nil, fmt.Errorf("No load balancing policy specified") + } + + store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: username, Password: password} + createdSession, err1 := gocql.NewSession(*store.clusterConfigs) + if err1 != nil { + return nil, fmt.Errorf("Unable to connect to ScyllaDB") + } + store.session = createdSession + return &store, nil +} + +func (c *CassandraOnlineStore) getTableName(tableName string) string { + return fmt.Sprintf(`"%s"_"%s"_""%s"`, c.keyspace, c.project, tableName) 
+} +func (c *CassandraOnlineStore) getCQLStatement(tableName string) string { + selectStatement := "SELECT entity_key, feature_name, value, event_ts FROM %s WHERE entity_key IN ?;" + return fmt.Sprintf(selectStatement, tableName) + +} + +func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]interface{}, map[string]int, error) { + cassandraKeys := make([]interface{}, len(entityKeys)) + cassandraKeyToEntityIndex := make(map[string]int) + for i := 0; i < len(entityKeys); i++ { + var key, err = serializeCassandraEntityKey(entityKeys[i], c.config.EntityKeySerializationVersion) + if err != nil { + return nil, nil, err + } + // encoding to hex + cassandraKeys[i] = hex.EncodeToString(*key) + cassandraKeyToEntityIndex[hashSerializedEntityKey(key)] = i + } + return cassandraKeys, cassandraKeyToEntityIndex, nil +} +func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + + serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) + if err != nil { + return nil, fmt.Errorf("Error when serializing entity keys for Cassandra") + } + results := make([][]FeatureData, len(entityKeys)) + featureNamesToIdx := make(map[string]int) + for idx, name := range featureNames { + featureNamesToIdx[name] = idx + } + + for _, featureViewName := range featureViewNames { + // Prepare the query + tableName := c.getTableName(featureViewName) + cqlStatement := c.getCQLStatement(tableName) + // Bundle the entity keys in one statement (gocql handles this as concurrent queries) + query := c.session.Query(cqlStatement, serializedEntityKeys...) 
+ scan := query.Iter().Scanner() + + // Process the results + var entityKey []byte + var featureName string + var eventTs time.Time + var valueStr []byte + var deserializedValue types.Value + for scan.Next() { + err := scan.Scan(&entityKey, &featureName, &valueStr, &eventTs) + if err != nil { + return nil, errors.New("Could not read row in query for (entity key, feature name, value, event ts)") + } + if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { + return nil, errors.New("error converting parsed Cassandra Value to types.Value") + } + + var featureValues FeatureData + if deserializedValue.Val != nil { + // Convert the value to a FeatureData struct + featureValues = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Value: types.Value{ + Val: deserializedValue.Val, + }, + } + } else { + // Return FeatureData with a null value + featureValues = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + + } + // Add the FeatureData to the results + rowIndx := serializedEntityKeyToIndex[hashSerializedEntityKey(&entityKey)] + results[rowIndx][featureNamesToIdx[featureName]] = featureValues + } + // Check for errors from the Scanner + if err := scan.Err(); err != nil { + return nil, err + } + } + + return results, nil + +} + +// Serialize entity key to a bytestring so that it can be used as a lookup key in a hash table. 
+func serializeCassandraEntityKey(entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { + // Ensure that we have the right amount of join keys and entity values + if len(entityKey.JoinKeys) != len(entityKey.EntityValues) { + return nil, fmt.Errorf("the amount of join key names and entity values don't match: %s vs %s", entityKey.JoinKeys, entityKey.EntityValues) + } + // Make sure that join keys are sorted so that we have consistent key building + m := make(map[string]*types.Value) + + for i := 0; i < len(entityKey.JoinKeys); i++ { + m[entityKey.JoinKeys[i]] = entityKey.EntityValues[i] + } + + keys := make([]string, 0, len(m)) + for k := range entityKey.JoinKeys { + keys = append(keys, entityKey.JoinKeys[k]) + } + sort.Strings(keys) + + // Build the key + length := 5 * len(keys) + bufferList := make([][]byte, length) + + for i := 0; i < len(keys); i++ { + offset := i * 2 + byteBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(byteBuffer, uint32(types.ValueType_Enum_value["STRING"])) + bufferList[offset] = byteBuffer + bufferList[offset+1] = []byte(keys[i]) + } + + for i := 0; i < len(keys); i++ { + offset := (2 * len(keys)) + (i * 3) + value := m[keys[i]].GetVal() + + valueBytes, valueTypeBytes, err := serializeCassandraValue(value, entityKeySerializationVersion) + if err != nil { + return valueBytes, err + } + + typeBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(typeBuffer, uint32(valueTypeBytes)) + + lenBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(lenBuffer, uint32(len(*valueBytes))) + + bufferList[offset+0] = typeBuffer + bufferList[offset+1] = lenBuffer + bufferList[offset+2] = *valueBytes + } + + // Convert from an array of byte arrays to a single byte array + var entityKeyBuffer []byte + for i := 0; i < len(bufferList); i++ { + entityKeyBuffer = append(entityKeyBuffer, bufferList[i]...) 
+ } + + return &entityKeyBuffer, nil +} + +func serializeCassandraValue(value interface{}, entityKeySerializationVersion int64) (*[]byte, types.ValueType_Enum, error) { + // TODO: Implement support for other types (at least the major types like ints, strings, bytes) + switch x := (value).(type) { + case *types.Value_StringVal: + valueString := []byte(x.StringVal) + return &valueString, types.ValueType_STRING, nil + case *types.Value_BytesVal: + return &x.BytesVal, types.ValueType_BYTES, nil + case *types.Value_Int32Val: + valueBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int32Val)) + return &valueBuffer, types.ValueType_INT32, nil + case *types.Value_Int64Val: + if entityKeySerializationVersion <= 1 { + // We unfortunately have to use 32 bit here for backward compatibility :( + valueBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int64Val)) + return &valueBuffer, types.ValueType_INT64, nil + } else { + valueBuffer := make([]byte, 8) + binary.LittleEndian.PutUint64(valueBuffer, uint64(x.Int64Val)) + return &valueBuffer, types.ValueType_INT64, nil + } + case nil: + return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) + default: + return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) + } +} + +func hashSerializedEntityKey(serializedEntityKey *[]byte) string { + if serializedEntityKey == nil { + return "" + } + h := sha1.New() + h.Write(*serializedEntityKey) + sha1_hash := hex.EncodeToString(h.Sum(nil)) + return sha1_hash +} + +func (c *CassandraOnlineStore) Destruct() { + +} diff --git a/go/internal/feast/onlinestore/onlinestore.go b/go/internal/feast/onlinestore/onlinestore.go index 2f30e16d67..5dcc8e5370 100644 --- a/go/internal/feast/onlinestore/onlinestore.go +++ b/go/internal/feast/onlinestore/onlinestore.go @@ -61,7 +61,10 @@ func NewOnlineStore(config *registry.RepoConfig) (OnlineStore, error) { } else if onlineStoreType == "redis" { 
onlineStore, err := NewRedisOnlineStore(config.Project, config, config.OnlineStore) return onlineStore, err + } else if onlineStoreType == "cassandra" || onlineStoreType == "scylladb" { + onlineStore, err := NewCassandraOnlineStore(config.Project, config, config.OnlineStore) + return onlineStore, err } else { - return nil, fmt.Errorf("%s online store type is currently not supported; only redis and sqlite are supported", onlineStoreType) + return nil, fmt.Errorf("%s online store type is currently not supported; only redis, scylladb, cassandra and sqlite are supported", onlineStoreType) } } diff --git a/go/internal/feast/utils/aws_utils.go b/go/internal/feast/utils/aws_utils.go new file mode 100644 index 0000000000..66e27d635e --- /dev/null +++ b/go/internal/feast/utils/aws_utils.go @@ -0,0 +1,117 @@ +package utils + +import ( + "fmt" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ec2" + "log" + "sort" + "strings" +) + +type EC2Instance struct { + EC2Conn *ec2.EC2 + AutoScalingGroupName string + Instances []*ec2.Instance + region string +} + +func NewEC2Instance(autoScalingGroupName string, region string) *EC2Instance { + sess := session.Must(session.NewSession(&aws.Config{ + Region: aws.String(region), + })) + + svc := ec2.New(sess) + + return &EC2Instance{ + EC2Conn: svc, + AutoScalingGroupName: autoScalingGroupName, + region: region, + } +} + +func (e *EC2Instance) GetIPAddress(instanceID string) (*string, error) { + input := &ec2.DescribeInstancesInput{ + InstanceIds: []*string{ + aws.String(instanceID), + }, + } + + result, err := e.EC2Conn.DescribeInstances(input) + if err != nil { + log.Println("Error getting IP address:", err) + return nil, err + } + + for _, reservation := range result.Reservations { + for _, instance := range reservation.Instances { + return instance.PrivateIpAddress, nil + } + } + + return nil, nil +} + +func (e *EC2Instance) GetNodeInfo(nodeNum int) (*ec2.Instance, error) { 
+ if nodeNum > len(e.Instances) { + return nil, fmt.Errorf("node number is out of range") + } + + return e.Instances[nodeNum-1], nil +} + +func (e *EC2Instance) GetAllNodes(instanceName string) ([]*ec2.Instance, error) { + input := &ec2.DescribeInstancesInput{ + Filters: []*ec2.Filter{ + { + Name: aws.String("tag:region"), + Values: []*string{aws.String(e.region)}, + }, + { + Name: aws.String("tag:aws:autoscaling:groupName"), + Values: []*string{aws.String(e.AutoScalingGroupName)}, + }, + }, + } + + result, err := e.EC2Conn.DescribeInstances(input) + if err != nil { + log.Println("Error getting all nodes:", err) + return nil, err + } + + var instances []*ec2.Instance + for _, reservation := range result.Reservations { + for _, instance := range reservation.Instances { + instances = append(instances, instance) + } + } + + sort.Slice(instances, func(i, j int) bool { + return strings.Compare(*instances[i].InstanceId, *instances[j].InstanceId) < 0 + }) + + e.Instances = instances + + return instances, nil +} + +func (e *EC2Instance) GetAllNodesIP() ([]string, error) { + instancesInfo := make([]string, 0) + for _, instance := range e.Instances { + instancesInfo = append(instancesInfo, *instance.PrivateIpAddress) + } + + return instancesInfo, nil +} + +func (e *EC2Instance) ResolveHostNameToIp() ([]string, error) { + + e.GetAllNodes(e.AutoScalingGroupName) + nodeIps, err := e.GetAllNodesIP() + if err != nil { + return nil, fmt.Errorf("Unable to get Node Ips") + } + return nodeIps, nil +} diff --git a/sdk/python/feast/infra/online_stores/aws_utils_online_store.py b/sdk/python/feast/infra/online_stores/aws_utils_online_store.py new file mode 100644 index 0000000000..a3bc7cd1c6 --- /dev/null +++ b/sdk/python/feast/infra/online_stores/aws_utils_online_store.py @@ -0,0 +1,61 @@ +import logging +from abc import ABC +import boto3 + +class Ec2_instance(ABC): + + def __init__(self, autoscalinggroup_name, region='us-east-1'): + self.ec2_conn = boto3.resource('ec2', 
region_name=region) + self.autoscalinggroup_name = autoscalinggroup_name + self.region = region + self.instances = [] + + def get_ipaddress(self, instanceid): + try: + logging.debug("Getting Ip Address") + instances = self.ec2_conn.instances.filter(InstanceIds=[instanceid]) + for instance in instances: + return instance.private_ip_address + except Exception as e: + logging.error("Exception raised while getting ip address - {0} ".format(e)) + return None + + def get_node_info(self, nodenum): + try: + logging.debug("Getting IP Address of node {0}".format(nodenum)) + if nodenum > len(self.instances): + raise Exception('Node number is out of range.') + return {'id': self.instances[nodenum - 1].id, 'ip': self.instances[nodenum - 1].private_ip_address} + except Exception as e: + logging.error("Exception raised while getting information on {0} node - {1}".format(nodenum, e)) + return None + + def get_all_nodes(self): + try: + logging.debug("Get all AutoScaling group nodes") + running_state_filter = {'Name': 'tag:region', 'Values': [self.region]} + asg_filter = {'Name': 'tag:aws:autoscaling:groupName', 'Values': [self.autoscalinggroup_name]} + instances = self.ec2_conn.instances.filter(Filters=[asg_filter, running_state_filter]) + sorted_instances = sorted(instances, key=lambda instance: (instance.launch_time, instance.id)) + self.instances = sorted_instances + return sorted_instances + except Exception as e: + logging.error(e) + return None + + def get_all_nodes_ip(self): + try: + logging.debug("Getting Node ID and IP Address of all nodes ") + instances_info = [] + for instance in self.instances: + instances_info.append(instance.private_ip_address) + return instances_info + except Exception as e: + logging.error(e) + + return + + def resolve_host_to_ip_address(self) -> [list]: + self.get_all_nodes() + nodeips = self.get_all_nodes_ip() + return nodeips diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py 
b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index 0870bc709d..fcfaa69561 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -51,6 +51,7 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel +from feast.infra.online_stores.aws_utils_online_store import HostResolver # Error messages E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS = ( @@ -88,7 +89,7 @@ event_ts TIMESTAMP, created_ts TIMESTAMP, PRIMARY KEY ((entity_key), feature_name) - ) WITH CLUSTERING ORDER BY (feature_name ASC); + ) WITH CLUSTERING ORDER BY (feature_name ASC) {optional_ttl_clause}; """ DROP_TABLE_CQL_TEMPLATE = "DROP TABLE IF EXISTS {fqtable};" @@ -132,6 +133,9 @@ class CassandraOnlineStoreConfig(FeastConfigBaseModel): hosts: Optional[List[StrictStr]] = None """List of host addresses to reach the cluster.""" + host_names: Optional[List[StrictStr]] = None + """List of host names to reach the clusters""" + secure_bundle_path: Optional[StrictStr] = None """Path to the secure connect bundle (for Astra DB; replaces hosts).""" @@ -153,6 +157,9 @@ class CassandraOnlineStoreConfig(FeastConfigBaseModel): request_timeout: Optional[StrictFloat] = None """Request timeout in seconds.""" + ttl: Optional[StrictInt] = None + '''Time to live option''' + class CassandraLoadBalancingPolicy(FeastConfigBaseModel): """ Configuration block related to the Cluster's load-balancing policy. 
@@ -225,14 +232,22 @@ def _get_session(self, config: RepoConfig): return self._session if not self._session: # configuration consistency checks - hosts = online_store_config.hosts secure_bundle_path = online_store_config.secure_bundle_path - port = online_store_config.port or 9042 + port = 19042 if online_store_config.type == "scylladb" else (online_store_config.port or 9042) keyspace = online_store_config.keyspace username = online_store_config.username password = online_store_config.password protocol_version = online_store_config.protocol_version + if online_store_config.type == "scylladb": + # Using the shard aware functionality + if online_store_config.load_balancing is None: + online_store_config.load_balancing = "TokenAwarePolicy(DCAwareRoundRobinPolicy)" + if online_store_config.protocol_version is None: + protocol_version = 4 + if not online_store_config.hosts and online_store_config.host_names: + online_store_config.hosts = HostResolver.resolve_host_to_ip_address(host_names=online_store_config.host_names) + hosts = online_store_config.hosts db_directions = hosts or secure_bundle_path if not db_directions or not keyspace: raise CassandraInvalidConfig(E_CASSANDRA_NOT_CONFIGURED) @@ -562,7 +577,10 @@ def _create_table(self, config: RepoConfig, project: str, table: FeatureView): fqtable = CassandraOnlineStore._fq_table_name(keyspace, project, table) create_cql = self._get_cql_statement(config, "create", fqtable) logger.info(f"Creating table {fqtable}.") - session.execute(create_cql) + if config.online_config.ttl: + session.execute(create_cql, parameters=config.online_config.ttl) + else: + session.execute(create_cql) def _get_cql_statement( self, config: RepoConfig, op_name: str, fqtable: str, **kwargs @@ -579,9 +597,15 @@ def _get_cql_statement( """ session: Session = self._get_session(config) template, prepare = CQL_TEMPLATE_MAP[op_name] + if op_name == "create" and config.online_config.ttl: + ttl_clause = " USING TTL ?" 
+ else: + ttl_clause = None + statement = template.format( fqtable=fqtable, - **kwargs, + optional_ttl_clause=ttl_clause, + **kwargs ) if prepare: # using the statement itself as key (no problem with that) diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index b46c980ba6..c98e35e09b 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -59,6 +59,7 @@ "postgres": "feast.infra.online_stores.contrib.postgres.PostgreSQLOnlineStore", "hbase": "feast.infra.online_stores.contrib.hbase_online_store.hbase.HbaseOnlineStore", "cassandra": "feast.infra.online_stores.contrib.cassandra_online_store.cassandra_online_store.CassandraOnlineStore", + "scylladb": "feast.infra.online_stores.contrib.cassandra_online_store.cassandra_online_store.CassandraOnlineStore", "mysql": "feast.infra.online_stores.contrib.mysql_online_store.mysql.MySQLOnlineStore", "rockset": "feast.infra.online_stores.contrib.rockset_online_store.rockset.RocksetOnlineStore", "hazelcast": "feast.infra.online_stores.contrib.hazelcast_online_store.hazelcast_online_store.HazelcastOnlineStore", diff --git a/setup.py b/setup.py index 6ea0f76d37..510cb2e9a8 100644 --- a/setup.py +++ b/setup.py @@ -112,6 +112,10 @@ "cassandra-driver>=3.24.0,<4", ] +SCYLLADB_REQUIRED = [ + "scylla-driver>=3.24.0,<4", +] + GE_REQUIRED = ["great_expectations>=0.15.41"] AZURE_REQUIRED = [ @@ -472,6 +476,7 @@ def run(self): "hbase": HBASE_REQUIRED, "docs": DOCS_REQUIRED, "cassandra": CASSANDRA_REQUIRED, + "scylladb": SCYLLADB_REQUIRED, "hazelcast": HAZELCAST_REQUIRED, "grpcio": GRPCIO_REQUIRED, "rockset": ROCKSET_REQUIRED, From 51549145f41c8aeff1f55b821ab48bc470316809 Mon Sep 17 00:00:00 2001 From: msistla96 Date: Tue, 17 Sep 2024 17:49:59 -0500 Subject: [PATCH 02/18] Modify TTL clause for CREATE TABLE --- .../contrib/cassandra_online_store/cassandra_online_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index fcfaa69561..4b10e4cada 100644 --- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -598,7 +598,7 @@ def _get_cql_statement( session: Session = self._get_session(config) template, prepare = CQL_TEMPLATE_MAP[op_name] if op_name == "create" and config.online_config.ttl: - ttl_clause = " USING TTL ?" + ttl_clause = " WITH default_time_to_live = ?" else: ttl_clause = None From 5d3bf5fa8635c572036b5b932f263f1b3ded2e1a Mon Sep 17 00:00:00 2001 From: msistla96 Date: Fri, 27 Sep 2024 15:37:47 -0500 Subject: [PATCH 03/18] Remove using AWS utils for querying EC2 instances for host name; Cleanup code a bit --- .../feast/onlinestore/cassandraonlinestore.go | 66 +++---- .../onlinestore/cassandraonlinestore_test.go | 163 ++++++++++++++++++ .../cassandra_online_store.py | 6 - 3 files changed, 190 insertions(+), 45 deletions(-) create mode 100644 go/internal/feast/onlinestore/cassandraonlinestore_test.go diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 1436182374..6f3fb39f17 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "github.com/feast-dev/feast/go/internal/feast/registry" - "github.com/feast-dev/feast/go/internal/feast/utils" "github.com/feast-dev/feast/go/protos/feast/serving" "github.com/feast-dev/feast/go/protos/feast/types" "github.com/gocql/gocql" @@ -34,7 +33,7 @@ type CassandraOnlineStore struct { keyspace string // Host IP addresses of the cluster - hostIps []string + hosts []string config *registry.RepoConfig } @@ -45,37 +44,39 @@ 
func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online config: config, } - var username string - var password string // Parse host_name and Ips - cassandraHostNames, ok1 := onlineStoreConfig["host_names"] - cassandraHostIps, ok2 := onlineStoreConfig["hosts"] - if !ok1 && !ok2 { - cassandraHostIps = "127.0.0.1" + cassandraHosts, ok := onlineStoreConfig["hosts"] + if !ok { + cassandraHosts = "127.0.0.1" + log.Warn().Msg("Host not provided: Using localhost instead") } - var cassandraHostNameStr string - if cassandraHostNameStr, ok1 = cassandraHostNames.(string); !ok1 { - return nil, fmt.Errorf("failed to convert host_names to string: %+v", cassandraHostNameStr) - } - - var cassandraHostIpsStr string - if cassandraHostIpsStr, ok2 = cassandraHostIps.(string); !ok2 { - return nil, fmt.Errorf("failed to convert hosts(ip addresses) to string: %+v", cassandraHostIpsStr) + var cassandraHostStr string + if cassandraHostStr, ok = cassandraHosts.(string); !ok { + return nil, fmt.Errorf("failed to convert hosts to string: %+v", cassandraHostStr) } - username = onlineStoreConfig["username"].(string) - password = onlineStoreConfig["password"].(string) - - if len(username) == 0 { + username, ok := onlineStoreConfig["username"] + if !ok { username = "scylla" log.Warn().Msg("Username not defined: Using default username instead") } - if len(username) == 0 { + password, ok := onlineStoreConfig["password"] + if !ok { password = "scylla" log.Warn().Msg("Password not defined: Using default password instead") } + var usernameStr string + if usernameStr, ok = username.(string); !ok { + return nil, fmt.Errorf("failed to convert username to string: %+v", usernameStr) + } + + var passwordStr string + if passwordStr, ok = password.(string); !ok { + return nil, fmt.Errorf("failed to convert password to string: %+v", passwordStr) + } + keyspace, ok := onlineStoreConfig["keyspace"] if !ok { keyspace = project @@ -86,19 +87,7 @@ func NewCassandraOnlineStore(project string, 
config *registry.RepoConfig, online return nil, fmt.Errorf("failed to convert keyspace to string: %+v", keyspaceStr) } - // If you're using host_names, it means that you will need a host resolver to get the IP of the cluster the DB is in - if len(cassandraHostNameStr) > 0 { - ec2Instance := utils.NewEC2Instance(cassandraHostNameStr, "us-west-2") - hostIps, err := ec2Instance.ResolveHostNameToIp() - if err != nil { - return nil, fmt.Errorf("Unable to resolve host name %+v to ip address: ", cassandraHostNameStr) - } - store.hostIps = hostIps - } else { - store.hostIps = []string{cassandraHostNameStr} - } - - store.clusterConfigs = gocql.NewCluster(store.hostIps...) + store.clusterConfigs = gocql.NewCluster(store.hosts...) // TODO: Figure out if we need to offer users the ability to tune the timeouts //store.clusterConfigs.ConnectTimeout = 1 //store.clusterConfigs.Timeout = 1 @@ -107,7 +96,8 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online store.clusterConfigs.Keyspace = keyspaceStr loadBalancingPolicy, ok := onlineStoreConfig["load_balancing"] if !ok { - return nil, nil + loadBalancingPolicy = gocql.RoundRobinHostPolicy() + log.Warn().Msg("No load balancing policy selected; setting Round Robin Host Policy") } loadBalancingPolicyStr, _ := loadBalancingPolicy.(string) if loadBalancingPolicyStr == "DCAwareRoundRobinPolicy" { @@ -122,14 +112,12 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online } else { store.clusterConfigs.Port = 9042 } - } else { - return nil, fmt.Errorf("No load balancing policy specified") } - store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: username, Password: password} + store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr} createdSession, err1 := gocql.NewSession(*store.clusterConfigs) if err1 != nil { - return nil, fmt.Errorf("Unable to connect to ScyllaDB") + return nil, fmt.Errorf("Unable to 
connect to the ScyllaDB database") } store.session = createdSession return &store, nil diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go new file mode 100644 index 0000000000..d5461eefc9 --- /dev/null +++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go @@ -0,0 +1,163 @@ +package onlinestore + +import ( + "context" + "testing" + + "github.com/feast-dev/feast/go/internal/feast/registry" + "github.com/feast-dev/feast/go/protos/feast/types" + + "github.com/stretchr/testify/assert" +) + +func TestNewCassandraOnlineStore(t *testing.T) { + var config = map[string]interface{}{ + "username": "", + "password": "", + "hosts": "", + "keyspace": "", + "load_balancing": "", + } + rc := ®istry.RepoConfig{ + OnlineStore: config, + EntityKeySerializationVersion: 2, + } + store, err := NewCassandraOnlineStore("test", rc, config) + assert.Nil(t, err) + assert.Equal(t, store.hosts, "127.0.0.1") + assert.Equal(t, store.keyspace, "test") + assert.Equal(t, store.username, "scylla") + assert.Equal(t, store.password, "scylla") + assert.Nil(t, store.session) +} + +func TestNewCassandraOnlineStoreWithPassword(t *testing.T) { + var config = map[string]interface{}{ + "connection_string": "", + } + rc := ®istry.RepoConfig{ + OnlineStore: config, + EntityKeySerializationVersion: 2, + } + store, err := NewRedisOnlineStore("test", rc, config) + assert.Nil(t, err) + var opts = store.client.Options() + assert.Equal(t, opts.Addr, "redis://localhost:6379") + assert.Equal(t, opts.Password, "secret") +} + +func TestCassandraOnlineStore_SerializeCassandraEntityKey(t *testing.T) { + store := CassandraOnlineStore{} + entityKey := &types.EntityKey{ + JoinKeys: []string{"key1", "key2"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + } + _, err := store.serializeCassandraEntityKey(entityKey, 2) + assert.Nil(t, err) +} + +func 
TestCassandraOnlineStore_SerializeCassandraEntityKey_InvalidEntityKey(t *testing.T) { + store := CassandraOnlineStore{} + entityKey := &types.EntityKey{ + JoinKeys: []string{"key1"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + } + _, err := store.serializeCassandraEntityKey(entityKey, 2) + assert.NotNil(t, err) +} + +func TestCassandraOnlineStore_SerializeValue(t *testing.T) { + store := CassandraOnlineStore{} + _, _, err := store.serializeValue(&types.Value_StringVal{StringVal: "value1"}, 2) + assert.Nil(t, err) +} + +func TestCassandraOnlineStore_SerializeValue_InvalidValue(t *testing.T) { + store := CassandraOnlineStore{} + _, _, err := store.serializeValue(nil, 2) + assert.NotNil(t, err) +} + +func TestCassandraOnlineStore_BuildCassandraKeys(t *testing.T) { + store := CassandraOnlineStore{} + entityKeys := []*types.EntityKey{ + { + JoinKeys: []string{"key1", "key2"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + }, + } + _, _, err := store.buildCassandraKeys(entityKeys) + assert.Nil(t, err) +} + +func TestCassandraOnlineStore_BuildCassandraKeys_InvalidEntityKeys(t *testing.T) { + store := CassandraOnlineStore{} + entityKeys := []*types.EntityKey{ + { + JoinKeys: []string{"key1"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + }, + } + _, _, err := store.buildCassandraKeys(entityKeys) + assert.NotNil(t, err) +} + +func TestCassandraOnlineStore_OnlineRead_HappyPath(t *testing.T) { + store := CassandraOnlineStore{} + entityKeys := []*types.EntityKey{ + { + JoinKeys: []string{"key1", "key2"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + }, + } + featureViewNames := []string{"featureView1"} + 
featureNames := []string{"feature1"} + + _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) + assert.Nil(t, err) +} + +func TestCassandraOnlineStore_OnlineRead_InvalidEntityKey(t *testing.T) { + store := CassandraOnlineStore{} + entityKeys := []*types.EntityKey{ + { + JoinKeys: []string{"key1"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + }, + } + featureViewNames := []string{"featureView1"} + featureNames := []string{"feature1"} + + _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) + assert.NotNil(t, err) +} + +func TestCassandraOnlineStore_OnlineRead_NoFeatureViewNames(t *testing.T) { + store := CassandraOnlineStore{} + entityKeys := []*types.EntityKey{ + { + JoinKeys: []string{"key1", "key2"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + }, + } + featureViewNames := []string{} + featureNames := []string{"feature1"} + + _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) + assert.NotNil(t, err) +} + +func TestCassandraOnlineStore_OnlineRead_NoFeatureNames(t *testing.T) { + store := CassandraOnlineStore{} + entityKeys := []*types.EntityKey{ + { + JoinKeys: []string{"key1", "key2"}, + EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, + }, + } + featureViewNames := []string{"featureView1"} + featureNames := []string{} + + _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) + assert.NotNil(t, err) +} diff --git a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py index 4b10e4cada..e8952ca280 100644 
--- a/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py +++ b/sdk/python/feast/infra/online_stores/contrib/cassandra_online_store/cassandra_online_store.py @@ -51,7 +51,6 @@ from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel -from feast.infra.online_stores.aws_utils_online_store import HostResolver # Error messages E_CASSANDRA_UNEXPECTED_CONFIGURATION_CLASS = ( @@ -133,9 +132,6 @@ class CassandraOnlineStoreConfig(FeastConfigBaseModel): hosts: Optional[List[StrictStr]] = None """List of host addresses to reach the cluster.""" - host_names: Optional[List[StrictStr]] = None - """List of host names to reach the clusters""" - secure_bundle_path: Optional[StrictStr] = None """Path to the secure connect bundle (for Astra DB; replaces hosts).""" @@ -244,8 +240,6 @@ def _get_session(self, config: RepoConfig): online_store_config.load_balancing = "TokenAwarePolicy(DCAwareRoundRobinPolicy)" if online_store_config.protocol_version is None: protocol_version = 4 - if not online_store_config.hosts and online_store_config.host_names: - online_store_config.hosts = HostResolver.resolve_host_to_ip_address(host_names=online_store_config.host_names) hosts = online_store_config.hosts db_directions = hosts or secure_bundle_path From 95e0544855db7f8c49f8dc247727c73fbde40cec Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Fri, 15 Nov 2024 18:02:36 -0500 Subject: [PATCH 04/18] Fix config parsing, and multi key retrieval --- .../feast/onlinestore/cassandraonlinestore.go | 198 +++++++++++------- .../feast/onlinestore/sqliteonlinestore.go | 17 +- 2 files changed, 130 insertions(+), 85 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 6f3fb39f17..e03aa39d61 100644 --- 
a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -2,7 +2,6 @@ package onlinestore import ( "context" - "crypto/sha1" "encoding/binary" "encoding/hex" "errors" @@ -13,9 +12,11 @@ import ( "github.com/gocql/gocql" "github.com/golang/protobuf/proto" "github.com/rs/zerolog/log" + "google.golang.org/protobuf/types/known/timestamppb" _ "google.golang.org/protobuf/types/known/timestamppb" _ "net" "sort" + "strings" "time" _ "time" ) @@ -49,11 +50,21 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online if !ok { cassandraHosts = "127.0.0.1" log.Warn().Msg("Host not provided: Using localhost instead") + } + var rawCassandraHosts []interface{} + if rawCassandraHosts, ok = cassandraHosts.([]interface{}); !ok { + return nil, fmt.Errorf("didn't pass a list of hosts in the 'hosts' field") } - var cassandraHostStr string - if cassandraHostStr, ok = cassandraHosts.(string); !ok { - return nil, fmt.Errorf("failed to convert hosts to string: %+v", cassandraHostStr) + + var cassandraHostsStr = make([]string, len(rawCassandraHosts)) + for i, rawHost := range rawCassandraHosts { + hostStr, ok := rawHost.(string) + if !ok { + return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost) + } + fmt.Printf("\tsingle host: %s\n", hostStr) + cassandraHostsStr[i] = hostStr } username, ok := onlineStoreConfig["username"] @@ -79,19 +90,28 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online keyspace, ok := onlineStoreConfig["keyspace"] if !ok { - keyspace = project - log.Warn().Msgf("Keyspace not defined: Using project name %s as keyspace instead", project) + keyspace = "scylladb" + log.Warn().Msg("Keyspace not defined: Using 'scylladb' as keyspace instead") } + store.keyspace = keyspace.(string) + var keyspaceStr string if keyspaceStr, ok = keyspace.(string); !ok { return nil, fmt.Errorf("failed to convert keyspace to string: %+v", keyspaceStr) } - 
store.clusterConfigs = gocql.NewCluster(store.hosts...) + protocolVersion, ok := onlineStoreConfig["protocol_version"] + if !ok { + protocolVersion = 4.0 + log.Warn().Msg("protocol_version not specified: Using 4 instead") + } + protocolVersionInt := int(protocolVersion.(float64)) + + store.clusterConfigs = gocql.NewCluster(cassandraHostsStr...) // TODO: Figure out if we need to offer users the ability to tune the timeouts //store.clusterConfigs.ConnectTimeout = 1 //store.clusterConfigs.Timeout = 1 - store.clusterConfigs.ProtoVersion = onlineStoreConfig["protocol"].(int) + store.clusterConfigs.ProtoVersion = protocolVersionInt //store.clusterConfigs.Consistency = gocql.Quorum store.clusterConfigs.Keyspace = keyspaceStr loadBalancingPolicy, ok := onlineStoreConfig["load_balancing"] @@ -115,21 +135,36 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online } store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr} - createdSession, err1 := gocql.NewSession(*store.clusterConfigs) - if err1 != nil { + createdSession, err := gocql.NewSession(*store.clusterConfigs) + if err != nil { return nil, fmt.Errorf("Unable to connect to the ScyllaDB database") } store.session = createdSession return &store, nil } -func (c *CassandraOnlineStore) getTableName(tableName string) string { - return fmt.Sprintf(`"%s"_"%s"_""%s"`, c.keyspace, c.project, tableName) +func (c *CassandraOnlineStore) getFqTableName(tableName string) string { + return fmt.Sprintf(`"%s"."%s_%s"`, c.keyspace, c.project, tableName) } -func (c *CassandraOnlineStore) getCQLStatement(tableName string) string { - selectStatement := "SELECT entity_key, feature_name, value, event_ts FROM %s WHERE entity_key IN ?;" - return fmt.Sprintf(selectStatement, tableName) +func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string { + // TODO: Compare with multiple single-key concurrent queries like in 
the Python feature server + keyPlaceholders := make([]string, nKeys) + for i := 0; i < nKeys; i++ { + keyPlaceholders[i] = "?" + } + // this prevents fetching unnecessary features + quotedFeatureNames := make([]string, len(featureNames)) + for i, featureName := range featureNames { + quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) + } + + return fmt.Sprintf( + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`, + tableName, + strings.Join(keyPlaceholders, ","), + strings.Join(quotedFeatureNames, ","), + ) } func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]interface{}, map[string]int, error) { @@ -141,64 +176,103 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti return nil, nil, err } // encoding to hex - cassandraKeys[i] = hex.EncodeToString(*key) - cassandraKeyToEntityIndex[hashSerializedEntityKey(key)] = i + encodedKey := hex.EncodeToString(*key) + cassandraKeys[i] = encodedKey + cassandraKeyToEntityIndex[encodedKey] = i } return cassandraKeys, cassandraKeyToEntityIndex, nil } func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + uniqueNames := make(map[string]int32) + for _, fvName := range featureViewNames { + uniqueNames[fvName] = 0 + } + if len(uniqueNames) != 1 { + return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once") + } serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) + if err != nil { return nil, fmt.Errorf("Error when serializing entity keys for Cassandra") } results := make([][]FeatureData, len(entityKeys)) + for i := range results { + results[i] = make([]FeatureData, len(featureNames)) + } + featureNamesToIdx := make(map[string]int) for idx, name := range featureNames { 
featureNamesToIdx[name] = idx } - for _, featureViewName := range featureViewNames { - // Prepare the query - tableName := c.getTableName(featureViewName) - cqlStatement := c.getCQLStatement(tableName) - // Bundle the entity keys in one statement (gocql handles this as concurrent queries) - query := c.session.Query(cqlStatement, serializedEntityKeys...) - scan := query.Iter().Scanner() - - // Process the results - var entityKey []byte - var featureName string - var eventTs time.Time - var valueStr []byte - var deserializedValue types.Value - for scan.Next() { - err := scan.Scan(&entityKey, &featureName, &valueStr, &eventTs) - if err != nil { - return nil, errors.New("Could not read row in query for (entity key, feature name, value, event ts)") + featureViewName := featureViewNames[0] + + // Prepare the query + tableName := c.getFqTableName(featureViewName) + cqlStatement := c.getCQLStatement(tableName, featureNames, len(serializedEntityKeys)) + // Bundle the entity keys in one statement (gocql handles this as concurrent queries) + scanner := c.session.Query(cqlStatement, serializedEntityKeys...).Iter().Scanner() + + // Process the results + var entityKey string + var featureName string + var eventTs time.Time + var valueStr []byte + var deserializedValue types.Value + for scanner.Next() { + err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) + if err != nil { + return nil, errors.New("could not read row in query for (entity key, feature name, value, event ts)") + } + if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { + return nil, errors.New("error converting parsed Cassandra Value to types.Value") + } + + var featureValues FeatureData + if deserializedValue.Val != nil { + // Convert the value to a FeatureData struct + featureValues = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: 
int32(eventTs.Nanosecond())}, + Value: types.Value{ + Val: deserializedValue.Val, + }, } - if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { - return nil, errors.New("error converting parsed Cassandra Value to types.Value") + } else { + // Return FeatureData with a null value + featureValues = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, } + } - var featureValues FeatureData - if deserializedValue.Val != nil { - // Convert the value to a FeatureData struct - featureValues = FeatureData{ - Reference: serving.FeatureReferenceV2{ - FeatureViewName: featureViewName, - FeatureName: featureName, - }, - Value: types.Value{ - Val: deserializedValue.Val, - }, - } - } else { - // Return FeatureData with a null value - featureValues = FeatureData{ + // Add the FeatureData to the results + rowIndx := serializedEntityKeyToIndex[entityKey] + results[rowIndx][featureNamesToIdx[featureName]] = featureValues + } + // Check for errors from the Scanner + if err := scanner.Err(); err != nil { + return nil, errors.New("failed to scan features: " + err.Error()) + } + + for i := 0; i < len(entityKeys); i++ { + for j := 0; j < len(featureNames); j++ { + if results[i][j].Timestamp.GetSeconds() == 0 { + results[i][j] = FeatureData{ Reference: serving.FeatureReferenceV2{ FeatureViewName: featureViewName, - FeatureName: featureName, + FeatureName: featureViewNames[j], }, Value: types.Value{ Val: &types.Value_NullVal{ @@ -206,15 +280,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ }, }, } - } - // Add the FeatureData to the results - rowIndx := serializedEntityKeyToIndex[hashSerializedEntityKey(&entityKey)] - results[rowIndx][featureNamesToIdx[featureName]] = featureValues - } - // Check for errors from the Scanner - if err := scan.Err(); err != nil { - return nil, err } } 
@@ -312,16 +378,6 @@ func serializeCassandraValue(value interface{}, entityKeySerializationVersion in } } -func hashSerializedEntityKey(serializedEntityKey *[]byte) string { - if serializedEntityKey == nil { - return "" - } - h := sha1.New() - h.Write(*serializedEntityKey) - sha1_hash := hex.EncodeToString(h.Sum(nil)) - return sha1_hash -} - func (c *CassandraOnlineStore) Destruct() { } diff --git a/go/internal/feast/onlinestore/sqliteonlinestore.go b/go/internal/feast/onlinestore/sqliteonlinestore.go index 6c37258e74..a329dc1be6 100644 --- a/go/internal/feast/onlinestore/sqliteonlinestore.go +++ b/go/internal/feast/onlinestore/sqliteonlinestore.go @@ -1,10 +1,9 @@ package onlinestore import ( - "crypto/sha1" "database/sql" - "encoding/hex" "errors" + "github.com/feast-dev/feast/go/internal/feast/utils" "strings" "sync" "time" @@ -76,7 +75,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. return nil, err } // TODO: fix this, string conversion is not safe - entityNameToEntityIndex[hashSerializedEntityKey(serKey)] = i + entityNameToEntityIndex[utils.HashSerializedEntityKey(serKey)] = i // for IN clause in read query in_query[i] = "?" serialized_entities[i] = *serKey @@ -109,7 +108,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. 
if err := proto.Unmarshal(valueString, &value); err != nil { return nil, errors.New("error converting parsed value to types.Value") } - rowIdx := entityNameToEntityIndex[hashSerializedEntityKey(&entity_key)] + rowIdx := entityNameToEntityIndex[utils.HashSerializedEntityKey(&entity_key)] if results[rowIdx] == nil { results[rowIdx] = make([]FeatureData, featureCount) } @@ -152,13 +151,3 @@ func initializeConnection(db_path string) (*sql.DB, error) { } return db, nil } - -func hashSerializedEntityKey(serializedEntityKey *[]byte) string { - if serializedEntityKey == nil { - return "" - } - h := sha1.New() - h.Write(*serializedEntityKey) - sha1_hash := hex.EncodeToString(h.Sum(nil)) - return sha1_hash -} From 254074c4b20e6543995b2a87fc7db86b23000f1f Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Fri, 15 Nov 2024 18:07:17 -0500 Subject: [PATCH 05/18] Add small comment --- go/internal/feast/onlinestore/cassandraonlinestore.go | 1 + 1 file changed, 1 insertion(+) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index e03aa39d61..d6a36ffec2 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -266,6 +266,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ return nil, errors.New("failed to scan features: " + err.Error()) } + // Will fill feature slots that were left empty with null values for i := 0; i < len(entityKeys); i++ { for j := 0; j < len(featureNames); j++ { if results[i][j].Timestamp.GetSeconds() == 0 { From aa1b41bb343fcdc580d7c4d1aad49dfffd5b893b Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Mon, 18 Nov 2024 10:54:29 -0500 Subject: [PATCH 06/18] Fix build --- go/internal/feast/utils/key_utils.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 go/internal/feast/utils/key_utils.go diff --git a/go/internal/feast/utils/key_utils.go 
b/go/internal/feast/utils/key_utils.go new file mode 100644 index 0000000000..a3e20580f7 --- /dev/null +++ b/go/internal/feast/utils/key_utils.go @@ -0,0 +1,15 @@ +package utils + +import ( + "crypto/sha1" + "encoding/hex" +) + +func HashSerializedEntityKey(serializedEntityKey *[]byte) string { + if serializedEntityKey == nil { + return "" + } + h := sha1.New() + h.Write(*serializedEntityKey) + return hex.EncodeToString(h.Sum(nil)) +} From 69d46fbeba3a724bc924481d2266276662690f69 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Mon, 18 Nov 2024 14:02:43 -0500 Subject: [PATCH 07/18] Added DD tracing to Go Cassandra store --- .../feast/onlinestore/cassandraonlinestore.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index d6a36ffec2..aec88af008 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -14,7 +14,9 @@ import ( "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" _ "google.golang.org/protobuf/types/known/timestamppb" + gocqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gocql/gocql" _ "net" + "os" "sort" "strings" "time" @@ -25,10 +27,10 @@ type CassandraOnlineStore struct { project string // Cluster configurations for Cassandra/ScyllaDB - clusterConfigs *gocql.ClusterConfig + clusterConfigs *gocqltrace.ClusterConfig // Session object that holds information about the connection to the cluster - session *gocql.Session + session *gocqltrace.Session // keyspace of the table. Defaulted to using the project name keyspace string @@ -107,7 +109,11 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online } protocolVersionInt := int(protocolVersion.(float64)) - store.clusterConfigs = gocql.NewCluster(cassandraHostsStr...) 
+ redisTraceServiceName := os.Getenv("DD_SERVICE") + "-cassandra" + if redisTraceServiceName == "" { + redisTraceServiceName = "cassandra.client" // default service name if DD_SERVICE is not set + } + store.clusterConfigs = gocqltrace.NewCluster(cassandraHostsStr, gocqltrace.WithServiceName(redisTraceServiceName)) // TODO: Figure out if we need to offer users the ability to tune the timeouts //store.clusterConfigs.ConnectTimeout = 1 //store.clusterConfigs.Timeout = 1 @@ -135,7 +141,7 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online } store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr} - createdSession, err := gocql.NewSession(*store.clusterConfigs) + createdSession, err := store.clusterConfigs.CreateSession() if err != nil { return nil, fmt.Errorf("Unable to connect to the ScyllaDB database") } From 4b560b72bd69f616658005bd6cccedab6923f4b4 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Mon, 18 Nov 2024 15:41:50 -0500 Subject: [PATCH 08/18] Refactored common code across go servers --- .../feast/onlinestore/cassandraonlinestore.go | 125 +++--------------- .../feast/onlinestore/redisonlinestore.go | 96 +------------- .../feast/onlinestore/sqliteonlinestore.go | 2 +- go/internal/feast/utils/key_utils.go | 95 +++++++++++++ 4 files changed, 115 insertions(+), 203 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index aec88af008..db77a14c22 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -2,25 +2,23 @@ package onlinestore import ( "context" - "encoding/binary" "encoding/hex" "errors" "fmt" + "os" + "strings" + "time" + "github.com/feast-dev/feast/go/internal/feast/registry" + "github.com/feast-dev/feast/go/internal/feast/utils" "github.com/feast-dev/feast/go/protos/feast/serving" 
"github.com/feast-dev/feast/go/protos/feast/types" "github.com/gocql/gocql" "github.com/golang/protobuf/proto" "github.com/rs/zerolog/log" + "google.golang.org/protobuf/types/known/timestamppb" - _ "google.golang.org/protobuf/types/known/timestamppb" gocqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gocql/gocql" - _ "net" - "os" - "sort" - "strings" - "time" - _ "time" ) type CassandraOnlineStore struct { @@ -50,8 +48,8 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online // Parse host_name and Ips cassandraHosts, ok := onlineStoreConfig["hosts"] if !ok { - cassandraHosts = "127.0.0.1" - log.Warn().Msg("Host not provided: Using localhost instead") + cassandraHosts = []interface{}{"127.0.0.1"} + log.Warn().Msg("host not provided: Using 127.0.0.1 instead") } var rawCassandraHosts []interface{} @@ -65,19 +63,19 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online if !ok { return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost) } - fmt.Printf("\tsingle host: %s\n", hostStr) cassandraHostsStr[i] = hostStr } username, ok := onlineStoreConfig["username"] if !ok { - username = "scylla" - log.Warn().Msg("Username not defined: Using default username instead") + username = "cassandra" + log.Warn().Msg("username not defined: Using default username instead") } + password, ok := onlineStoreConfig["password"] if !ok { - password = "scylla" - log.Warn().Msg("Password not defined: Using default password instead") + password = "cassandra" + log.Warn().Msg("password not defined: Using default password instead") } var usernameStr string @@ -143,7 +141,7 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr} createdSession, err := store.clusterConfigs.CreateSession() if err != nil { - return nil, fmt.Errorf("Unable to connect to the ScyllaDB database") + return 
nil, fmt.Errorf("unable to connect to the ScyllaDB database") } store.session = createdSession return &store, nil @@ -152,6 +150,7 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online func (c *CassandraOnlineStore) getFqTableName(tableName string) string { return fmt.Sprintf(`"%s"."%s_%s"`, c.keyspace, c.project, tableName) } + func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string { // TODO: Compare with multiple single-key concurrent queries like in the Python feature server keyPlaceholders := make([]string, nKeys) @@ -177,7 +176,7 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti cassandraKeys := make([]interface{}, len(entityKeys)) cassandraKeyToEntityIndex := make(map[string]int) for i := 0; i < len(entityKeys); i++ { - var key, err = serializeCassandraEntityKey(entityKeys[i], c.config.EntityKeySerializationVersion) + var key, err = utils.SerializeEntityKey(entityKeys[i], c.config.EntityKeySerializationVersion) if err != nil { return nil, nil, err } @@ -200,7 +199,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) if err != nil { - return nil, fmt.Errorf("Error when serializing entity keys for Cassandra") + return nil, fmt.Errorf("error when serializing entity keys for Cassandra") } results := make([][]FeatureData, len(entityKeys)) for i := range results { @@ -295,96 +294,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ } -// Serialize entity key to a bytestring so that it can be used as a lookup key in a hash table. 
-func serializeCassandraEntityKey(entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { - // Ensure that we have the right amount of join keys and entity values - if len(entityKey.JoinKeys) != len(entityKey.EntityValues) { - return nil, fmt.Errorf("the amount of join key names and entity values don't match: %s vs %s", entityKey.JoinKeys, entityKey.EntityValues) - } - // Make sure that join keys are sorted so that we have consistent key building - m := make(map[string]*types.Value) - - for i := 0; i < len(entityKey.JoinKeys); i++ { - m[entityKey.JoinKeys[i]] = entityKey.EntityValues[i] - } - - keys := make([]string, 0, len(m)) - for k := range entityKey.JoinKeys { - keys = append(keys, entityKey.JoinKeys[k]) - } - sort.Strings(keys) - - // Build the key - length := 5 * len(keys) - bufferList := make([][]byte, length) - - for i := 0; i < len(keys); i++ { - offset := i * 2 - byteBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(byteBuffer, uint32(types.ValueType_Enum_value["STRING"])) - bufferList[offset] = byteBuffer - bufferList[offset+1] = []byte(keys[i]) - } - - for i := 0; i < len(keys); i++ { - offset := (2 * len(keys)) + (i * 3) - value := m[keys[i]].GetVal() - - valueBytes, valueTypeBytes, err := serializeCassandraValue(value, entityKeySerializationVersion) - if err != nil { - return valueBytes, err - } - - typeBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(typeBuffer, uint32(valueTypeBytes)) - - lenBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(lenBuffer, uint32(len(*valueBytes))) - - bufferList[offset+0] = typeBuffer - bufferList[offset+1] = lenBuffer - bufferList[offset+2] = *valueBytes - } - - // Convert from an array of byte arrays to a single byte array - var entityKeyBuffer []byte - for i := 0; i < len(bufferList); i++ { - entityKeyBuffer = append(entityKeyBuffer, bufferList[i]...) 
- } - - return &entityKeyBuffer, nil -} - -func serializeCassandraValue(value interface{}, entityKeySerializationVersion int64) (*[]byte, types.ValueType_Enum, error) { - // TODO: Implement support for other types (at least the major types like ints, strings, bytes) - switch x := (value).(type) { - case *types.Value_StringVal: - valueString := []byte(x.StringVal) - return &valueString, types.ValueType_STRING, nil - case *types.Value_BytesVal: - return &x.BytesVal, types.ValueType_BYTES, nil - case *types.Value_Int32Val: - valueBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int32Val)) - return &valueBuffer, types.ValueType_INT32, nil - case *types.Value_Int64Val: - if entityKeySerializationVersion <= 1 { - // We unfortunately have to use 32 bit here for backward compatibility :( - valueBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int64Val)) - return &valueBuffer, types.ValueType_INT64, nil - } else { - valueBuffer := make([]byte, 8) - binary.LittleEndian.PutUint64(valueBuffer, uint64(x.Int64Val)) - return &valueBuffer, types.ValueType_INT64, nil - } - case nil: - return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) - default: - return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) - } -} - func (c *CassandraOnlineStore) Destruct() { } diff --git a/go/internal/feast/onlinestore/redisonlinestore.go b/go/internal/feast/onlinestore/redisonlinestore.go index f38ee92029..747c37a180 100644 --- a/go/internal/feast/onlinestore/redisonlinestore.go +++ b/go/internal/feast/onlinestore/redisonlinestore.go @@ -6,8 +6,8 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/feast-dev/feast/go/internal/feast/utils" "os" - "sort" "strconv" "strings" @@ -329,102 +329,10 @@ func (r *RedisOnlineStore) Destruct() { } func buildRedisKey(project string, entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { - serKey, err := 
serializeEntityKey(entityKey, entityKeySerializationVersion) + serKey, err := utils.SerializeEntityKey(entityKey, entityKeySerializationVersion) if err != nil { return nil, err } fullKey := append(*serKey, []byte(project)...) return &fullKey, nil } - -func serializeEntityKey(entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { - // Serialize entity key to a bytestring so that it can be used as a lookup key in a hash table. - - // Ensure that we have the right amount of join keys and entity values - if len(entityKey.JoinKeys) != len(entityKey.EntityValues) { - return nil, fmt.Errorf("the amount of join key names and entity values don't match: %s vs %s", entityKey.JoinKeys, entityKey.EntityValues) - } - - // Make sure that join keys are sorted so that we have consistent key building - m := make(map[string]*types.Value) - - for i := 0; i < len(entityKey.JoinKeys); i++ { - m[entityKey.JoinKeys[i]] = entityKey.EntityValues[i] - } - - keys := make([]string, 0, len(m)) - for k := range entityKey.JoinKeys { - keys = append(keys, entityKey.JoinKeys[k]) - } - sort.Strings(keys) - - // Build the key - length := 5 * len(keys) - bufferList := make([][]byte, length) - - for i := 0; i < len(keys); i++ { - offset := i * 2 - byteBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(byteBuffer, uint32(types.ValueType_Enum_value["STRING"])) - bufferList[offset] = byteBuffer - bufferList[offset+1] = []byte(keys[i]) - } - - for i := 0; i < len(keys); i++ { - offset := (2 * len(keys)) + (i * 3) - value := m[keys[i]].GetVal() - - valueBytes, valueTypeBytes, err := serializeValue(value, entityKeySerializationVersion) - if err != nil { - return valueBytes, err - } - - typeBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(typeBuffer, uint32(valueTypeBytes)) - - lenBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(lenBuffer, uint32(len(*valueBytes))) - - bufferList[offset+0] = typeBuffer - bufferList[offset+1] = lenBuffer - 
bufferList[offset+2] = *valueBytes - } - - // Convert from an array of byte arrays to a single byte array - var entityKeyBuffer []byte - for i := 0; i < len(bufferList); i++ { - entityKeyBuffer = append(entityKeyBuffer, bufferList[i]...) - } - - return &entityKeyBuffer, nil -} - -func serializeValue(value interface{}, entityKeySerializationVersion int64) (*[]byte, types.ValueType_Enum, error) { - // TODO: Implement support for other types (at least the major types like ints, strings, bytes) - switch x := (value).(type) { - case *types.Value_StringVal: - valueString := []byte(x.StringVal) - return &valueString, types.ValueType_STRING, nil - case *types.Value_BytesVal: - return &x.BytesVal, types.ValueType_BYTES, nil - case *types.Value_Int32Val: - valueBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int32Val)) - return &valueBuffer, types.ValueType_INT32, nil - case *types.Value_Int64Val: - if entityKeySerializationVersion <= 1 { - // We unfortunately have to use 32 bit here for backward compatibility :( - valueBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int64Val)) - return &valueBuffer, types.ValueType_INT64, nil - } else { - valueBuffer := make([]byte, 8) - binary.LittleEndian.PutUint64(valueBuffer, uint64(x.Int64Val)) - return &valueBuffer, types.ValueType_INT64, nil - } - case nil: - return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) - default: - return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) - } -} diff --git a/go/internal/feast/onlinestore/sqliteonlinestore.go b/go/internal/feast/onlinestore/sqliteonlinestore.go index a329dc1be6..95f95610e5 100644 --- a/go/internal/feast/onlinestore/sqliteonlinestore.go +++ b/go/internal/feast/onlinestore/sqliteonlinestore.go @@ -70,7 +70,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. 
in_query := make([]string, len(entityKeys)) serialized_entities := make([]interface{}, len(entityKeys)) for i := 0; i < len(entityKeys); i++ { - serKey, err := serializeEntityKey(entityKeys[i], s.repoConfig.EntityKeySerializationVersion) + serKey, err := utils.SerializeEntityKey(entityKeys[i], s.repoConfig.EntityKeySerializationVersion) if err != nil { return nil, err } diff --git a/go/internal/feast/utils/key_utils.go b/go/internal/feast/utils/key_utils.go index a3e20580f7..7cf9459455 100644 --- a/go/internal/feast/utils/key_utils.go +++ b/go/internal/feast/utils/key_utils.go @@ -2,7 +2,11 @@ package utils import ( "crypto/sha1" + "encoding/binary" "encoding/hex" + "fmt" + "github.com/feast-dev/feast/go/protos/feast/types" + "sort" ) func HashSerializedEntityKey(serializedEntityKey *[]byte) string { @@ -13,3 +17,94 @@ func HashSerializedEntityKey(serializedEntityKey *[]byte) string { h.Write(*serializedEntityKey) return hex.EncodeToString(h.Sum(nil)) } + +// SerializeEntityKey Serialize entity key to a bytestring so that it can be used as a lookup key in a hash table. 
+func SerializeEntityKey(entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { + // Ensure that we have the right amount of join keys and entity values + if len(entityKey.JoinKeys) != len(entityKey.EntityValues) { + return nil, fmt.Errorf("the amount of join key names and entity values don't match: %s vs %s", entityKey.JoinKeys, entityKey.EntityValues) + } + + // Make sure that join keys are sorted so that we have consistent key building + m := make(map[string]*types.Value) + + for i := 0; i < len(entityKey.JoinKeys); i++ { + m[entityKey.JoinKeys[i]] = entityKey.EntityValues[i] + } + + keys := make([]string, 0, len(m)) + for k := range entityKey.JoinKeys { + keys = append(keys, entityKey.JoinKeys[k]) + } + sort.Strings(keys) + + // Build the key + length := 5 * len(keys) + bufferList := make([][]byte, length) + + for i := 0; i < len(keys); i++ { + offset := i * 2 + byteBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(byteBuffer, uint32(types.ValueType_Enum_value["STRING"])) + bufferList[offset] = byteBuffer + bufferList[offset+1] = []byte(keys[i]) + } + + for i := 0; i < len(keys); i++ { + offset := (2 * len(keys)) + (i * 3) + value := m[keys[i]].GetVal() + + valueBytes, valueTypeBytes, err := serializeValue(value, entityKeySerializationVersion) + if err != nil { + return valueBytes, err + } + + typeBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(typeBuffer, uint32(valueTypeBytes)) + + lenBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(lenBuffer, uint32(len(*valueBytes))) + + bufferList[offset+0] = typeBuffer + bufferList[offset+1] = lenBuffer + bufferList[offset+2] = *valueBytes + } + + // Convert from an array of byte arrays to a single byte array + var entityKeyBuffer []byte + for i := 0; i < len(bufferList); i++ { + entityKeyBuffer = append(entityKeyBuffer, bufferList[i]...) 
+ } + + return &entityKeyBuffer, nil +} + +func serializeValue(value interface{}, entityKeySerializationVersion int64) (*[]byte, types.ValueType_Enum, error) { + // TODO: Implement support for other types (at least the major types like ints, strings, bytes) + switch x := (value).(type) { + case *types.Value_StringVal: + valueString := []byte(x.StringVal) + return &valueString, types.ValueType_STRING, nil + case *types.Value_BytesVal: + return &x.BytesVal, types.ValueType_BYTES, nil + case *types.Value_Int32Val: + valueBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int32Val)) + return &valueBuffer, types.ValueType_INT32, nil + case *types.Value_Int64Val: + if entityKeySerializationVersion <= 1 { + // We unfortunately have to use 32 bit here for backward compatibility :( + valueBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int64Val)) + return &valueBuffer, types.ValueType_INT64, nil + } else { + valueBuffer := make([]byte, 8) + binary.LittleEndian.PutUint64(valueBuffer, uint64(x.Int64Val)) + return &valueBuffer, types.ValueType_INT64, nil + } + case nil: + return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) + default: + return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) + } +} From 3ded707df1e9ed7147ef6fd6b9582bd3c923e830 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Thu, 21 Nov 2024 15:03:49 -0500 Subject: [PATCH 09/18] Address most of the PR comments --- Makefile | 2 +- go.mod | 6 +- go.sum | 17 +- go/infra/docker/feature-server/Dockerfile | 7 - .../feast/onlinestore/cassandraonlinestore.go | 219 +++++++++----- .../onlinestore/cassandraonlinestore_test.go | 270 ++++++++---------- go/internal/feast/utils/aws_utils.go | 117 -------- 7 files changed, 275 insertions(+), 363 deletions(-) delete mode 100644 go/internal/feast/utils/aws_utils.go diff --git a/Makefile b/Makefile index 9d5721f5a6..2ee3b60771 100644 --- a/Makefile +++ 
b/Makefile @@ -486,7 +486,7 @@ build-java-docker-dev: build-go-docker-dev: docker buildx build --build-arg VERSION=dev \ -t feastdev/feature-server-go:dev \ - -f SCYLLADB_SHARD_AWARE_ENABLED=$(SCYLLADB_SHARD_AWARE_ENABLED) go/infra/docker/feature-server/Dockerfile --load . + -f go/infra/docker/feature-server/Dockerfile --load . # Documentation diff --git a/go.mod b/go.mod index 39bb23a2a8..b24f011568 100644 --- a/go.mod +++ b/go.mod @@ -41,6 +41,7 @@ require ( require ( github.com/apache/arrow/go/v17 v17.0.0 + github.com/gocql/gocql v0.0.0-20220224095938-0eacd3183625 github.com/rs/zerolog v1.21.0 ) @@ -55,7 +56,6 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/klauspost/asmfmt v1.3.2 // indirect github.com/klauspost/compress v1.17.9 // indirect github.com/klauspost/cpuid/v2 v2.2.8 // indirect @@ -80,6 +80,4 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -require github.com/aws/aws-sdk-go v1.34.28 - -require github.com/gocql/gocql v1.6.0 +replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.4 diff --git a/go.sum b/go.sum index 542a21cc7b..5906274b37 100644 --- a/go.sum +++ b/go.sum @@ -26,8 +26,6 @@ github.com/apache/arrow/go/v17 v17.0.0 h1:RRR2bdqKcdbss9Gxy2NS/hK8i4LDMh23L6BbkN github.com/apache/arrow/go/v17 v17.0.0/go.mod h1:jR7QHkODl15PfYyjM2nU+yTLScZ/qfj7OSUZmJ8putc= github.com/apache/thrift v0.20.0 h1:631+KvYbsBZxmuJjYwhezVsrfc/TbqtZV4QcxOX1fOI= github.com/apache/thrift v0.20.0/go.mod h1:hOk1BQqcp2OLzGsyVXdfMk7YFlMxK3aoEVhjD06QhB8= -github.com/aws/aws-sdk-go v1.34.28 h1:sscPpn/Ns3i0F4HPEWAVcwdIRaZZCuL7llJ2/60yPIk= -github.com/aws/aws-sdk-go v1.34.28/go.mod h1:H7NKnBqNVzoTJpGfLrQkkD+ytBA93eiDYi/+8rV9s48= github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= 
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= @@ -55,11 +53,8 @@ github.com/ebitengine/purego v0.5.0-alpha h1:pNZNC8WofBTN3Nm196An50C5taL/87BhFR/ github.com/ebitengine/purego v0.5.0-alpha/go.mod h1:ah1In8AOtksoNK6yk5z1HTJeUkC1Ez4Wk2idgGslMwQ= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= -github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= -github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU= -github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8= github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= @@ -70,6 +65,7 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod 
h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= @@ -81,10 +77,6 @@ github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= -github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= -github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/klauspost/asmfmt v1.3.2 h1:4Ri7ox3EwapiOjCki+hw14RyKk201CN4rzyCJRFLpK4= github.com/klauspost/asmfmt v1.3.2/go.mod h1:AG8TuvYojzulgDAMCnYn50l/5QV3Bs/tp6j0HLHbNSE= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= @@ -123,6 +115,8 @@ github.com/richardartoul/molecule v1.0.1-0.20221107223329-32cfee06a052/go.mod h1 github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/zerolog v1.21.0 h1:Q3vdXlfLNT+OftyBHsU0Y445MD+8m8axjKgf2si0QcM= github.com/rs/zerolog v1.21.0/go.mod h1:ZPhntP/xmq1nnND05hhpAh2QMhSsA4UN3MGZ6O2J3hM= +github.com/scylladb/gocql v1.14.4 h1:MhevwCfyAraQ6RvZYFO3pF4Lt0YhvQlfg8Eo2HEqVQA= +github.com/scylladb/gocql v1.14.4/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0= github.com/secure-systems-lab/go-securesystemslib v0.7.0 h1:OwvJ5jQf9LnIAS83waAjPbcMsODrTQUpJ02eNLUoxBg= github.com/secure-systems-lab/go-securesystemslib v0.7.0/go.mod h1:/2gYnlnHVQ6xeGtfIqFy7Do03K4cdCY0A/GlJLDKLHI= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= @@ -176,10 +170,10 @@ golang.org/x/mod v0.18.0 
h1:5+9lSbEzPSdWkH32vYPBwEpX8KwDbM52Ud9xBUvNlb0= golang.org/x/mod v0.18.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200202094626-16171245cfb2/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/net v0.26.0 h1:soB7SVo0PWrY4vPW/+ay0jKDNScG2X9wFeYlXIvJsOQ= @@ -251,7 +245,6 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntN gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -259,3 +252,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod 
h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= inet.af/netaddr v0.0.0-20220811202034-502d2d690317 h1:U2fwK6P2EqmopP/hFLTOAjWTki0qgd4GMJn5X8wOleU= inet.af/netaddr v0.0.0-20220811202034-502d2d690317/go.mod h1:OIezDfdzOgFhuw4HuWapWq2e9l0H9tK4F1j+ETRtF3k= +sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= +sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/go/infra/docker/feature-server/Dockerfile b/go/infra/docker/feature-server/Dockerfile index 389925802a..5ac71b93ef 100644 --- a/go/infra/docker/feature-server/Dockerfile +++ b/go/infra/docker/feature-server/Dockerfile @@ -1,7 +1,5 @@ FROM golang:1.22.5 -ARG SCYLLADB_SHARD_AWARE_ENABLED=False -ENV SCYLLADB_SHARD_AWARE_ENABLED=$SCYLLADB_SHARD_AWARE_ENABLED # Update the package list and install the ca-certificates package RUN apt-get update && apt-get install -y ca-certificates RUN apt install -y protobuf-compiler @@ -24,11 +22,6 @@ RUN find ./protos -name "*.proto" \ # Build the Go application RUN go build -o feast ./go/main.go - -RUN if ["$SCYLLADB_SHARD_AWARE_ENABLED"== True]; then \ - go mod edit -replace=github.com/gocql/gocql=github.com/scylladb/gocql@latest; \ - go mod tidy; fi - # Expose ports EXPOSE 8080 diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index db77a14c22..428afe827b 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -30,115 +30,179 @@ type CassandraOnlineStore struct { // Session object that holds information about the connection to the cluster session *gocqltrace.Session - // keyspace of the table. 
Defaulted to using the project name - keyspace string - - // Host IP addresses of the cluster - hosts []string - config *registry.RepoConfig } -func NewCassandraOnlineStore(project string, config *registry.RepoConfig, onlineStoreConfig map[string]interface{}) (*CassandraOnlineStore, error) { - store := CassandraOnlineStore{ - project: project, - config: config, - } +type CassandraConfig struct { + hosts []string + username string + password string + keyspace string + protocolVersion int + loadBalancingPolicy gocql.HostSelectionPolicy + connectionTimeoutMillis int64 + requestTimeoutMillis int64 + numConnections int +} + +func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, error) { + cassandraConfig := CassandraConfig{} - // Parse host_name and Ips + // parse hosts cassandraHosts, ok := onlineStoreConfig["hosts"] if !ok { - cassandraHosts = []interface{}{"127.0.0.1"} + cassandraConfig.hosts = []string{"127.0.0.1"} log.Warn().Msg("host not provided: Using 127.0.0.1 instead") + } else { + var rawCassandraHosts []any + if rawCassandraHosts, ok = cassandraHosts.([]any); !ok { + return nil, fmt.Errorf("didn't pass a list of hosts in the 'hosts' field") + } + var cassandraHostsStr = make([]string, len(rawCassandraHosts)) + for i, rawHost := range rawCassandraHosts { + hostStr, ok := rawHost.(string) + if !ok { + return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost) + } + cassandraHostsStr[i] = hostStr + } + cassandraConfig.hosts = cassandraHostsStr } - var rawCassandraHosts []interface{} - if rawCassandraHosts, ok = cassandraHosts.([]interface{}); !ok { - return nil, fmt.Errorf("didn't pass a list of hosts in the 'hosts' field") - } - - var cassandraHostsStr = make([]string, len(rawCassandraHosts)) - for i, rawHost := range rawCassandraHosts { - hostStr, ok := rawHost.(string) + // parse username + rawUsername, ok := onlineStoreConfig["username"] + if !ok { + cassandraConfig.username = "cassandra" + 
log.Warn().Msg("username not defined: Using default username instead") + } else { + cassandraConfig.username, ok = rawUsername.(string) if !ok { - return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost) + return nil, fmt.Errorf("failed to convert username to string: %v", rawUsername) } - cassandraHostsStr[i] = hostStr } - username, ok := onlineStoreConfig["username"] + // parse password + rawPassword, ok := onlineStoreConfig["password"] if !ok { - username = "cassandra" - log.Warn().Msg("username not defined: Using default username instead") + cassandraConfig.password = "cassandra" + log.Warn().Msg("password not defined: Using default password instead") + } else { + cassandraConfig.password, ok = rawPassword.(string) + if !ok { + return nil, fmt.Errorf("failed to convert password to string: %v", rawPassword) + } } - password, ok := onlineStoreConfig["password"] + // parse keyspace + rawKeyspace, ok := onlineStoreConfig["keyspace"] if !ok { - password = "cassandra" - log.Warn().Msg("password not defined: Using default password instead") + cassandraConfig.keyspace = "feast_keyspace" + log.Warn().Msg("keyspace not defined: Using 'feast_keyspace' as keyspace instead") + } else { + cassandraConfig.keyspace, ok = rawKeyspace.(string) + if !ok { + return nil, fmt.Errorf("failed to convert keyspace to string: %v", rawKeyspace) + } } - var usernameStr string - if usernameStr, ok = username.(string); !ok { - return nil, fmt.Errorf("failed to convert username to string: %+v", usernameStr) + // parse protocolVersion + protocolVersion, ok := onlineStoreConfig["protocol_version"] + if !ok { + protocolVersion = 4.0 + log.Warn().Msg("protocol_version not specified: Using 4 instead") } + cassandraConfig.protocolVersion = int(protocolVersion.(float64)) - var passwordStr string - if passwordStr, ok = password.(string); !ok { - return nil, fmt.Errorf("failed to convert password to string: %+v", passwordStr) + // parse loadBalancing + loadBalancingDict, ok := 
onlineStoreConfig["load_balancing"] + if !ok { + loadBalancingDict = gocql.RoundRobinHostPolicy() + log.Warn().Msg("no load balancing policy selected, defaulted to RoundRobinHostPolicy") + } else { + loadBalancingProps := loadBalancingDict.(map[string]any) + policy := loadBalancingProps["load_balancing_policy"].(string) + switch policy { + case "TokenAwarePolicy(DCAwareRoundRobinPolicy)": + rawLocalDC, ok := loadBalancingProps["local_dc"] + if !ok { + return nil, fmt.Errorf("a local_dc is needed for policy DCAwareRoundRobinPolicy") + } + localDc := rawLocalDC.(string) + cassandraConfig.loadBalancingPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy(localDc)) + case "DCAwareRoundRobinPolicy": + rawLocalDC, ok := loadBalancingProps["local_dc"] + if !ok { + return nil, fmt.Errorf("a local_dc is needed for policy DCAwareRoundRobinPolicy") + } + localDc := rawLocalDC.(string) + cassandraConfig.loadBalancingPolicy = gocql.DCAwareRoundRobinPolicy(localDc) + default: + log.Warn().Msg("defaulted to using RoundRobinHostPolicy") + cassandraConfig.loadBalancingPolicy = gocql.RoundRobinHostPolicy() + } } - keyspace, ok := onlineStoreConfig["keyspace"] + // parse connectionTimeoutMillis + connectionTimeoutMillis, ok := onlineStoreConfig["connection_timeout_millis"] if !ok { - keyspace = "scylladb" - log.Warn().Msg("Keyspace not defined: Using 'scylladb' as keyspace instead") + connectionTimeoutMillis = 8000.0 + log.Warn().Msg("connection_timeout_millis not specified: Defaulted to 8000ms") } - store.keyspace = keyspace.(string) + cassandraConfig.connectionTimeoutMillis = int64(connectionTimeoutMillis.(float64)) - var keyspaceStr string - if keyspaceStr, ok = keyspace.(string); !ok { - return nil, fmt.Errorf("failed to convert keyspace to string: %+v", keyspaceStr) + // parse requestTimeoutMillis + requestTimeoutMillis, ok := onlineStoreConfig["request_timeout_millis"] + if !ok { + requestTimeoutMillis = 1000.0 + log.Warn().Msg("request_timeout_millis not specified: 
Defaulted to 1000ms") } + cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64)) - protocolVersion, ok := onlineStoreConfig["protocol_version"] + // parse numConnections + numConnections, ok := onlineStoreConfig["num_connections"] if !ok { - protocolVersion = 4.0 - log.Warn().Msg("protocol_version not specified: Using 4 instead") + numConnections = 2.0 + log.Warn().Msg("num_connections not specified: Defaulted to 2") } - protocolVersionInt := int(protocolVersion.(float64)) + cassandraConfig.numConnections = int(numConnections.(float64)) - redisTraceServiceName := os.Getenv("DD_SERVICE") + "-cassandra" - if redisTraceServiceName == "" { - redisTraceServiceName = "cassandra.client" // default service name if DD_SERVICE is not set + return &cassandraConfig, nil +} + +func NewCassandraOnlineStore(project string, config *registry.RepoConfig, onlineStoreConfig map[string]any) (*CassandraOnlineStore, error) { + store := CassandraOnlineStore{ + project: project, + config: config, } - store.clusterConfigs = gocqltrace.NewCluster(cassandraHostsStr, gocqltrace.WithServiceName(redisTraceServiceName)) - // TODO: Figure out if we need to offer users the ability to tune the timeouts - //store.clusterConfigs.ConnectTimeout = 1 - //store.clusterConfigs.Timeout = 1 - store.clusterConfigs.ProtoVersion = protocolVersionInt - //store.clusterConfigs.Consistency = gocql.Quorum - store.clusterConfigs.Keyspace = keyspaceStr - loadBalancingPolicy, ok := onlineStoreConfig["load_balancing"] - if !ok { - loadBalancingPolicy = gocql.RoundRobinHostPolicy() - log.Warn().Msg("No load balancing policy selected; setting Round Robin Host Policy") + + cassandraConfig, configError := extractCassandraConfig(onlineStoreConfig) + if configError != nil { + return nil, configError } - loadBalancingPolicyStr, _ := loadBalancingPolicy.(string) - if loadBalancingPolicyStr == "DCAwareRoundRobinPolicy" { - store.clusterConfigs.PoolConfig.HostSelectionPolicy = gocql.RoundRobinHostPolicy() - 
} else if loadBalancingPolicyStr == "TokenAwarePolicy(DCAwareRoundRobinPolicy)" { - // Configure fallback policy if unable to reach the shard - fallback := gocql.RoundRobinHostPolicy() - // If using ScyllaDB and setting this policy, this makes the driver shard aware to improve performance - store.clusterConfigs.PoolConfig.HostSelectionPolicy = gocql.TokenAwareHostPolicy(fallback) - if config.OnlineStore["type"] == "scylladb" { - store.clusterConfigs.Port = 19042 - } else { - store.clusterConfigs.Port = 9042 - } + + cassandraTraceServiceName := os.Getenv("DD_SERVICE") + "-cassandra" + if cassandraTraceServiceName == "" { + cassandraTraceServiceName = "cassandra.client" // default service name if DD_SERVICE is not set } + store.clusterConfigs = gocqltrace.NewCluster(cassandraConfig.hosts, gocqltrace.WithServiceName(cassandraTraceServiceName)) + store.clusterConfigs.ProtoVersion = cassandraConfig.protocolVersion + store.clusterConfigs.Keyspace = cassandraConfig.keyspace - store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{Username: usernameStr, Password: passwordStr} + store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy + + store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ + Username: cassandraConfig.username, + Password: cassandraConfig.password, + } + store.clusterConfigs.ConnectTimeout = time.Millisecond * time.Duration(cassandraConfig.connectionTimeoutMillis) + store.clusterConfigs.Timeout = time.Millisecond * time.Duration(cassandraConfig.requestTimeoutMillis) + store.clusterConfigs.NumConns = cassandraConfig.numConnections + store.clusterConfigs.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} + store.clusterConfigs.Consistency = gocql.LocalOne + + //store.clusterConfigs.SslOpts = &gocql.SslOptions{ + // EnableHostVerification: true, + //} createdSession, err := store.clusterConfigs.CreateSession() if err != nil { return nil, fmt.Errorf("unable to connect to the ScyllaDB database") @@ 
-148,7 +212,7 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online } func (c *CassandraOnlineStore) getFqTableName(tableName string) string { - return fmt.Sprintf(`"%s"."%s_%s"`, c.keyspace, c.project, tableName) + return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName) } func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string { @@ -172,15 +236,14 @@ func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames [] ) } -func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]interface{}, map[string]int, error) { - cassandraKeys := make([]interface{}, len(entityKeys)) +func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]any, map[string]int, error) { + cassandraKeys := make([]any, len(entityKeys)) cassandraKeyToEntityIndex := make(map[string]int) for i := 0; i < len(entityKeys); i++ { var key, err = utils.SerializeEntityKey(entityKeys[i], c.config.EntityKeySerializationVersion) if err != nil { return nil, nil, err } - // encoding to hex encodedKey := hex.EncodeToString(*key) cassandraKeys[i] = encodedKey cassandraKeyToEntityIndex[encodedKey] = i diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go index d5461eefc9..04300ad279 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore_test.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go @@ -1,163 +1,143 @@ package onlinestore import ( - "context" + "github.com/gocql/gocql" "testing" "github.com/feast-dev/feast/go/internal/feast/registry" - "github.com/feast-dev/feast/go/protos/feast/types" - "github.com/stretchr/testify/assert" ) -func TestNewCassandraOnlineStore(t *testing.T) { - var config = map[string]interface{}{ - "username": "", - "password": "", - "hosts": "", - "keyspace": "", - "load_balancing": "", - } +func 
TestNewCassandraOnlineStoreDefaults(t *testing.T) { + var config = map[string]interface{}{} rc := &registry.RepoConfig{ OnlineStore: config, - EntityKeySerializationVersion: 2, + EntityKeySerializationVersion: 4, } store, err := NewCassandraOnlineStore("test", rc, config) assert.Nil(t, err) assert.Equal(t, store.hosts, "127.0.0.1") - assert.Equal(t, store.keyspace, "test") - assert.Equal(t, store.username, "scylla") - assert.Equal(t, store.password, "scylla") + assert.Equal(t, store.keyspace, "scylladb") + assert.Equal(t, store.clusterConfigs.Authenticator, gocql.PasswordAuthenticator{ + Username: "cassandra", + Password: "cassandra", + }) + assert.Equal(t, store.clusterConfigs.ProtoVersion, 4) assert.Nil(t, store.session) } -func TestNewCassandraOnlineStoreWithPassword(t *testing.T) { - var config = map[string]interface{}{ - "connection_string": "", - } - rc := &registry.RepoConfig{ - OnlineStore: config, - EntityKeySerializationVersion: 2, - } - store, err := NewRedisOnlineStore("test", rc, config) - assert.Nil(t, err) - var opts = store.client.Options() - assert.Equal(t, opts.Addr, "redis://localhost:6379") - assert.Equal(t, opts.Password, "secret") -} - -func TestCassandraOnlineStore_SerializeCassandraEntityKey(t *testing.T) { - store := CassandraOnlineStore{} - entityKey := &types.EntityKey{ - JoinKeys: []string{"key1", "key2"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - } - _, err := store.serializeCassandraEntityKey(entityKey, 2) - assert.Nil(t, err) -} - -func TestCassandraOnlineStore_SerializeCassandraEntityKey_InvalidEntityKey(t *testing.T) { - store := CassandraOnlineStore{} - entityKey := &types.EntityKey{ - JoinKeys: []string{"key1"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - } - _, err := store.serializeCassandraEntityKey(entityKey, 2) - assert.NotNil(t, err) -} - -func
TestCassandraOnlineStore_SerializeValue(t *testing.T) { - store := CassandraOnlineStore{} - _, _, err := store.serializeValue(&types.Value_StringVal{StringVal: "value1"}, 2) - assert.Nil(t, err) -} - -func TestCassandraOnlineStore_SerializeValue_InvalidValue(t *testing.T) { - store := CassandraOnlineStore{} - _, _, err := store.serializeValue(nil, 2) - assert.NotNil(t, err) -} - -func TestCassandraOnlineStore_BuildCassandraKeys(t *testing.T) { - store := CassandraOnlineStore{} - entityKeys := []*types.EntityKey{ - { - JoinKeys: []string{"key1", "key2"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - }, - } - _, _, err := store.buildCassandraKeys(entityKeys) - assert.Nil(t, err) -} - -func TestCassandraOnlineStore_BuildCassandraKeys_InvalidEntityKeys(t *testing.T) { - store := CassandraOnlineStore{} - entityKeys := []*types.EntityKey{ - { - JoinKeys: []string{"key1"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - }, - } - _, _, err := store.buildCassandraKeys(entityKeys) - assert.NotNil(t, err) -} - -func TestCassandraOnlineStore_OnlineRead_HappyPath(t *testing.T) { - store := CassandraOnlineStore{} - entityKeys := []*types.EntityKey{ - { - JoinKeys: []string{"key1", "key2"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - }, - } - featureViewNames := []string{"featureView1"} - featureNames := []string{"feature1"} - - _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) - assert.Nil(t, err) -} - -func TestCassandraOnlineStore_OnlineRead_InvalidEntityKey(t *testing.T) { - store := CassandraOnlineStore{} - entityKeys := []*types.EntityKey{ - { - JoinKeys: []string{"key1"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: 
&types.Value_StringVal{StringVal: "value2"}}}, - }, - } - featureViewNames := []string{"featureView1"} - featureNames := []string{"feature1"} - - _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) - assert.NotNil(t, err) -} - -func TestCassandraOnlineStore_OnlineRead_NoFeatureViewNames(t *testing.T) { - store := CassandraOnlineStore{} - entityKeys := []*types.EntityKey{ - { - JoinKeys: []string{"key1", "key2"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - }, - } - featureViewNames := []string{} - featureNames := []string{"feature1"} - - _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) - assert.NotNil(t, err) -} - -func TestCassandraOnlineStore_OnlineRead_NoFeatureNames(t *testing.T) { - store := CassandraOnlineStore{} - entityKeys := []*types.EntityKey{ - { - JoinKeys: []string{"key1", "key2"}, - EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, - }, - } - featureViewNames := []string{"featureView1"} - featureNames := []string{} - - _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) - assert.NotNil(t, err) -} +//func TestCassandraOnlineStore_SerializeCassandraEntityKey(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKey := &types.EntityKey{ +// JoinKeys: []string{"key1", "key2"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// } +// _, err := store.serializeCassandraEntityKey(entityKey, 2) +// assert.Nil(t, err) +//} +// +//func TestCassandraOnlineStore_SerializeCassandraEntityKey_InvalidEntityKey(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKey := &types.EntityKey{ +// JoinKeys: []string{"key1"}, +// EntityValues: []*types.Value{{Val: 
&types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// } +// _, err := store.serializeCassandraEntityKey(entityKey, 2) +// assert.NotNil(t, err) +//} +// +//func TestCassandraOnlineStore_SerializeValue(t *testing.T) { +// store := CassandraOnlineStore{} +// _, _, err := store.serializeValue(&types.Value_StringVal{StringVal: "value1"}, 2) +// assert.Nil(t, err) +//} +// +//func TestCassandraOnlineStore_SerializeValue_InvalidValue(t *testing.T) { +// store := CassandraOnlineStore{} +// _, _, err := store.serializeValue(nil, 2) +// assert.NotNil(t, err) +//} +// +//func TestCassandraOnlineStore_BuildCassandraKeys(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKeys := []*types.EntityKey{ +// { +// JoinKeys: []string{"key1", "key2"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// }, +// } +// _, _, err := store.buildCassandraKeys(entityKeys) +// assert.Nil(t, err) +//} +// +//func TestCassandraOnlineStore_BuildCassandraKeys_InvalidEntityKeys(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKeys := []*types.EntityKey{ +// { +// JoinKeys: []string{"key1"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// }, +// } +// _, _, err := store.buildCassandraKeys(entityKeys) +// assert.NotNil(t, err) +//} +// +//func TestCassandraOnlineStore_OnlineRead_HappyPath(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKeys := []*types.EntityKey{ +// { +// JoinKeys: []string{"key1", "key2"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// }, +// } +// featureViewNames := []string{"featureView1"} +// featureNames := []string{"feature1"} +// +// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, 
featureNames) +// assert.Nil(t, err) +//} +// +//func TestCassandraOnlineStore_OnlineRead_InvalidEntityKey(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKeys := []*types.EntityKey{ +// { +// JoinKeys: []string{"key1"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// }, +// } +// featureViewNames := []string{"featureView1"} +// featureNames := []string{"feature1"} +// +// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) +// assert.NotNil(t, err) +//} +// +//func TestCassandraOnlineStore_OnlineRead_NoFeatureViewNames(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKeys := []*types.EntityKey{ +// { +// JoinKeys: []string{"key1", "key2"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// }, +// } +// featureViewNames := []string{} +// featureNames := []string{"feature1"} +// +// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) +// assert.NotNil(t, err) +//} +// +//func TestCassandraOnlineStore_OnlineRead_NoFeatureNames(t *testing.T) { +// store := CassandraOnlineStore{} +// entityKeys := []*types.EntityKey{ +// { +// JoinKeys: []string{"key1", "key2"}, +// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, +// }, +// } +// featureViewNames := []string{"featureView1"} +// featureNames := []string{} +// +// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) +// assert.NotNil(t, err) +//} diff --git a/go/internal/feast/utils/aws_utils.go b/go/internal/feast/utils/aws_utils.go deleted file mode 100644 index 66e27d635e..0000000000 --- a/go/internal/feast/utils/aws_utils.go +++ /dev/null @@ -1,117 +0,0 @@ -package utils - -import ( - "fmt" - 
"github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "log" - "sort" - "strings" -) - -type EC2Instance struct { - EC2Conn *ec2.EC2 - AutoScalingGroupName string - Instances []*ec2.Instance - region string -} - -func NewEC2Instance(autoScalingGroupName string, region string) *EC2Instance { - sess := session.Must(session.NewSession(&aws.Config{ - Region: aws.String(region), - })) - - svc := ec2.New(sess) - - return &EC2Instance{ - EC2Conn: svc, - AutoScalingGroupName: autoScalingGroupName, - region: region, - } -} - -func (e *EC2Instance) GetIPAddress(instanceID string) (*string, error) { - input := &ec2.DescribeInstancesInput{ - InstanceIds: []*string{ - aws.String(instanceID), - }, - } - - result, err := e.EC2Conn.DescribeInstances(input) - if err != nil { - log.Println("Error getting IP address:", err) - return nil, err - } - - for _, reservation := range result.Reservations { - for _, instance := range reservation.Instances { - return instance.PrivateIpAddress, nil - } - } - - return nil, nil -} - -func (e *EC2Instance) GetNodeInfo(nodeNum int) (*ec2.Instance, error) { - if nodeNum > len(e.Instances) { - return nil, fmt.Errorf("node number is out of range") - } - - return e.Instances[nodeNum-1], nil -} - -func (e *EC2Instance) GetAllNodes(instanceName string) ([]*ec2.Instance, error) { - input := &ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ - { - Name: aws.String("tag:region"), - Values: []*string{aws.String(e.region)}, - }, - { - Name: aws.String("tag:aws:autoscaling:groupName"), - Values: []*string{aws.String(e.AutoScalingGroupName)}, - }, - }, - } - - result, err := e.EC2Conn.DescribeInstances(input) - if err != nil { - log.Println("Error getting all nodes:", err) - return nil, err - } - - var instances []*ec2.Instance - for _, reservation := range result.Reservations { - for _, instance := range reservation.Instances { - instances = append(instances, instance) - } - } - - 
sort.Slice(instances, func(i, j int) bool { - return strings.Compare(*instances[i].InstanceId, *instances[j].InstanceId) < 0 - }) - - e.Instances = instances - - return instances, nil -} - -func (e *EC2Instance) GetAllNodesIP() ([]string, error) { - instancesInfo := make([]string, 0) - for _, instance := range e.Instances { - instancesInfo = append(instancesInfo, *instance.PrivateIpAddress) - } - - return instancesInfo, nil -} - -func (e *EC2Instance) ResolveHostNameToIp() ([]string, error) { - - e.GetAllNodes(e.AutoScalingGroupName) - nodeIps, err := e.GetAllNodesIP() - if err != nil { - return nil, fmt.Errorf("Unable to get Node Ips") - } - return nodeIps, nil -} From c6eccd40d920c85bd9e7374ccf087f583fc003c6 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Thu, 21 Nov 2024 15:08:56 -0500 Subject: [PATCH 10/18] Run go mod tidy --- go.mod | 3 +++ go.sum | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/go.mod b/go.mod index aa8d6fc9bb..c8f11d3b99 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.22.5 require ( github.com/apache/arrow/go/v17 v17.0.0 github.com/ghodss/yaml v1.0.0 + github.com/gocql/gocql v1.6.0 github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 @@ -56,6 +57,7 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect + github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8 // indirect github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.6 // indirect @@ -86,6 +88,7 @@ require ( golang.org/x/tools v0.25.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + 
gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 9cb0906c98..eafff11eea 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,10 @@ github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -60,10 +64,12 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= 
github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= @@ -76,6 +82,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 h1:qnpS github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1/go.mod h1:lXGCsh6c22WGtjr+qGHj1otzZpV/1kwTMAqkwZsnWRU= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 h1:pRhl55Yx1eC7BZ1N+BBWwnKaMyD8uC+34TLdndZMAKk= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0/go.mod h1:XKMd7iuf/RGPSMJ/U4HP0zS2Z9Fh8Ps9a+6X26m/tmI= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8 h1:iBt4Ew4XEGLfh6/bPk4rSYmuZJGizr6/x/AEizP0CQc= github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8/go.mod h1:aiJI+PIApBRQG7FZTEBx5GiiX+HbOHilUdNxUZi4eV0= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= @@ -91,8 +99,11 @@ github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2 github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod 
h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -153,6 +164,8 @@ github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/scylladb/gocql v1.14.4 h1:MhevwCfyAraQ6RvZYFO3pF4Lt0YhvQlfg8Eo2HEqVQA= +github.com/scylladb/gocql v1.14.4/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0= github.com/secure-systems-lab/go-securesystemslib v0.8.0 h1:mr5An6X45Kb2nddcFlbmfHkLguCE9laoZCUzEEpIZXA= github.com/secure-systems-lab/go-securesystemslib v0.8.0/go.mod h1:UH2VZVuJfCYR8WgMlCU1uFsOUU+KeyrTWcSS73NBOzU= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= @@ -197,6 +210,7 @@ golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/sync 
v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -210,6 +224,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -218,8 +234,10 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= @@ -231,6 +249,7 @@ golang.org/x/tools 
v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= @@ -247,6 +266,8 @@ gopkg.in/DataDog/dd-trace-go.v1 v1.68.0/go.mod h1:mkZpWVLO/ERW5NqlW+w5d8waQKNvMS gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -266,3 +287,5 @@ modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +sigs.k8s.io/yaml v1.3.0 
h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= +sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= From f89a56c94dedf38642075dda0b201c07aef5cc72 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Thu, 21 Nov 2024 17:42:55 -0500 Subject: [PATCH 11/18] Implement multi-key queries using parallel single-key queries --- .../feast/onlinestore/cassandraonlinestore.go | 130 ++++++++++-------- 1 file changed, 70 insertions(+), 60 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 428afe827b..03df91ee8e 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -7,6 +7,7 @@ import ( "fmt" "os" "strings" + "sync" "time" "github.com/feast-dev/feast/go/internal/feast/registry" @@ -215,13 +216,7 @@ func (c *CassandraOnlineStore) getFqTableName(tableName string) string { return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName) } -func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nKeys int) string { - // TODO: Compare with multiple single-key concurrent queries like in the Python feature server - keyPlaceholders := make([]string, nKeys) - for i := 0; i < nKeys; i++ { - keyPlaceholders[i] = "?" - } - +func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string { // this prevents fetching unnecessary features quotedFeatureNames := make([]string, len(featureNames)) for i, featureName := range featureNames { @@ -229,9 +224,8 @@ func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames [] } return fmt.Sprintf( - `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" IN (%s) AND "feature_name" IN (%s)`, + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? 
AND "feature_name" IN (%s)`, tableName, - strings.Join(keyPlaceholders, ","), strings.Join(quotedFeatureNames, ","), ) } @@ -278,60 +272,77 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ // Prepare the query tableName := c.getFqTableName(featureViewName) - cqlStatement := c.getCQLStatement(tableName, featureNames, len(serializedEntityKeys)) - // Bundle the entity keys in one statement (gocql handles this as concurrent queries) - scanner := c.session.Query(cqlStatement, serializedEntityKeys...).Iter().Scanner() - - // Process the results - var entityKey string - var featureName string - var eventTs time.Time - var valueStr []byte - var deserializedValue types.Value - for scanner.Next() { - err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) - if err != nil { - return nil, errors.New("could not read row in query for (entity key, feature name, value, event ts)") - } - if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { - return nil, errors.New("error converting parsed Cassandra Value to types.Value") - } + cqlStatement := c.getCQLStatement(tableName, featureNames) + + var waitGroup sync.WaitGroup + waitGroup.Add(len(serializedEntityKeys)) + + errorsChannel := make(chan error, len(serializedEntityKeys)) + for _, serializedEntityKey := range serializedEntityKeys { + go func(serEntityKey any) { + defer waitGroup.Done() + + scanner := c.session.Query(cqlStatement, serEntityKey).Iter().Scanner() + var entityKey string + var featureName string + var eventTs time.Time + var valueStr []byte + var deserializedValue types.Value + for scanner.Next() { + err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) + if err != nil { + errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)") + return + } + if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { + errorsChannel <- errors.New("error converting parsed Cassandra Value to 
types.Value") + return + } - var featureValues FeatureData - if deserializedValue.Val != nil { - // Convert the value to a FeatureData struct - featureValues = FeatureData{ - Reference: serving.FeatureReferenceV2{ - FeatureViewName: featureViewName, - FeatureName: featureName, - }, - Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}, - Value: types.Value{ - Val: deserializedValue.Val, - }, - } - } else { - // Return FeatureData with a null value - featureValues = FeatureData{ - Reference: serving.FeatureReferenceV2{ - FeatureViewName: featureViewName, - FeatureName: featureName, - }, - Value: types.Value{ - Val: &types.Value_NullVal{ - NullVal: types.Null_NULL, - }, - }, + rowIdx := serializedEntityKeyToIndex[entityKey] + if deserializedValue.Val != nil { + // Convert the value to a FeatureData struct + results[rowIdx][featureNamesToIdx[featureName]] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}, + Value: types.Value{ + Val: deserializedValue.Val, + }, + } + } else { + // Return FeatureData with a null value + results[rowIdx][featureNamesToIdx[featureName]] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } } - } - // Add the FeatureData to the results - rowIndx := serializedEntityKeyToIndex[entityKey] - results[rowIndx][featureNamesToIdx[featureName]] = featureValues + if err := scanner.Err(); err != nil { + errorsChannel <- errors.New("failed to scan features: " + err.Error()) + return + } + }(serializedEntityKey) } - // Check for errors from the Scanner - if err := scanner.Err(); err != nil { - return nil, errors.New("failed to scan features: " + err.Error()) + + // wait until 
all concurrent single-key queries are done + waitGroup.Wait() + close(errorsChannel) + + for err := range errorsChannel { + if err != nil { + return nil, err + } } // Will fill feature slots that were left empty with null values @@ -354,7 +365,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ } return results, nil - } func (c *CassandraOnlineStore) Destruct() { From 41e451db6af8bda55ae44fca7d76fdb8baadb9ef Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Thu, 21 Nov 2024 18:28:24 -0500 Subject: [PATCH 12/18] Small details --- .../online_stores/aws_utils_online_store.py | 61 ------------------- setup.py | 4 -- 2 files changed, 65 deletions(-) delete mode 100644 sdk/python/feast/infra/online_stores/aws_utils_online_store.py diff --git a/sdk/python/feast/infra/online_stores/aws_utils_online_store.py b/sdk/python/feast/infra/online_stores/aws_utils_online_store.py deleted file mode 100644 index a3bc7cd1c6..0000000000 --- a/sdk/python/feast/infra/online_stores/aws_utils_online_store.py +++ /dev/null @@ -1,61 +0,0 @@ -import logging -from abc import ABC -import boto3 - -class Ec2_instance(ABC): - - def __init__(self, autoscalinggroup_name, region='us-east-1'): - self.ec2_conn = boto3.resource('ec2', region_name=region) - self.autoscalinggroup_name = autoscalinggroup_name - self.region = region - self.instances = [] - - def get_ipaddress(self, instanceid): - try: - logging.debug("Getting Ip Address") - instances = self.ec2_conn.instances.filter(InstanceIds=[instanceid]) - for instance in instances: - return instance.private_ip_address - except Exception as e: - logging.error("Exception raised while getting ip address - {0} ".format(e)) - return None - - def get_node_info(self, nodenum): - try: - logging.debug("Getting IP Address of node {0}".format(nodenum)) - if nodenum > len(self.instances): - raise Exception('Node number is out of range.') - return {'id': self.instances[nodenum - 1].id, 'ip': self.instances[nodenum - 
1].private_ip_address} - except Exception as e: - logging.error("Exception raised while getting information on {0} node - {1}".format(nodenum, e)) - return None - - def get_all_nodes(self): - try: - logging.debug("Get all AutoScaling group nodes") - running_state_filter = {'Name': 'tag:region', 'Values': [self.region]} - asg_filter = {'Name': 'tag:aws:autoscaling:groupName', 'Values': [self.autoscalinggroup_name]} - instances = self.ec2_conn.instances.filter(Filters=[asg_filter, running_state_filter]) - sorted_instances = sorted(instances, key=lambda instance: (instance.launch_time, instance.id)) - self.instances = sorted_instances - return sorted_instances - except Exception as e: - logging.error(e) - return None - - def get_all_nodes_ip(self): - try: - logging.debug("Getting Node ID and IP Address of all nodes ") - instances_info = [] - for instance in self.instances: - instances_info.append(instance.private_ip_address) - return instances_info - except Exception as e: - logging.error(e) - - return - - def resolve_host_to_ip_address(self) -> [list]: - self.get_all_nodes() - nodeips = self.get_all_nodes_ip() - return nodeips diff --git a/setup.py b/setup.py index ea354e84ca..d298f74c97 100644 --- a/setup.py +++ b/setup.py @@ -117,10 +117,6 @@ "cassandra-driver>=3.24.0,<4", ] -SCYLLADB_REQUIRED = [ - "scylla-driver>=3.24.0,<4", -] - GE_REQUIRED = ["great_expectations>=0.15.41"] SCYLLADB_REQUIRED = [ From 2a07d1eac00c0ff532a5efb4f682946f2bd1894c Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Fri, 22 Nov 2024 12:21:40 -0500 Subject: [PATCH 13/18] Config changes, optimize data writing and use observer dd trace --- .../feast/onlinestore/cassandraonlinestore.go | 132 +++++++++++------- 1 file changed, 82 insertions(+), 50 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 03df91ee8e..f874b23cbd 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ 
b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -26,10 +26,10 @@ type CassandraOnlineStore struct { project string // Cluster configurations for Cassandra/ScyllaDB - clusterConfigs *gocqltrace.ClusterConfig + clusterConfigs *gocql.ClusterConfig // Session object that holds information about the connection to the cluster - session *gocqltrace.Session + session *gocql.Session config *registry.RepoConfig } @@ -73,8 +73,8 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, // parse username rawUsername, ok := onlineStoreConfig["username"] if !ok { - cassandraConfig.username = "cassandra" - log.Warn().Msg("username not defined: Using default username instead") + cassandraConfig.username = "" + log.Warn().Msg("username not defined, will not be using authentication") } else { cassandraConfig.username, ok = rawUsername.(string) if !ok { @@ -85,8 +85,8 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, // parse password rawPassword, ok := onlineStoreConfig["password"] if !ok { - cassandraConfig.password = "cassandra" - log.Warn().Msg("password not defined: Using default password instead") + cassandraConfig.password = "" + log.Warn().Msg("password not defined, will not be using authentication") } else { cassandraConfig.password, ok = rawPassword.(string) if !ok { @@ -146,24 +146,24 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, // parse connectionTimeoutMillis connectionTimeoutMillis, ok := onlineStoreConfig["connection_timeout_millis"] if !ok { - connectionTimeoutMillis = 8000.0 - log.Warn().Msg("connection_timeout_millis not specified: Defaulted to 8000ms") + connectionTimeoutMillis = 0.0 + log.Warn().Msg("connection_timeout_millis not specified, using gocql default") } cassandraConfig.connectionTimeoutMillis = int64(connectionTimeoutMillis.(float64)) // parse requestTimeoutMillis requestTimeoutMillis, ok := onlineStoreConfig["request_timeout_millis"] if !ok { - 
requestTimeoutMillis = 1000.0 - log.Warn().Msg("request_timeout_millis not specified: Defaulted to 1000ms") + requestTimeoutMillis = 0.0 + log.Warn().Msg("request_timeout_millis not specified, using gocql default") } cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64)) // parse numConnections numConnections, ok := onlineStoreConfig["num_connections"] if !ok { - numConnections = 2.0 - log.Warn().Msg("num_connections not specified: Defaulted to 2") + numConnections = 0.0 + log.Warn().Msg("num_connections not specified, using gocql default") } cassandraConfig.numConnections = int(numConnections.(float64)) @@ -181,30 +181,37 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online return nil, configError } - cassandraTraceServiceName := os.Getenv("DD_SERVICE") + "-cassandra" - if cassandraTraceServiceName == "" { - cassandraTraceServiceName = "cassandra.client" // default service name if DD_SERVICE is not set - } - store.clusterConfigs = gocqltrace.NewCluster(cassandraConfig.hosts, gocqltrace.WithServiceName(cassandraTraceServiceName)) + store.clusterConfigs = gocql.NewCluster(cassandraConfig.hosts...) 
store.clusterConfigs.ProtoVersion = cassandraConfig.protocolVersion store.clusterConfigs.Keyspace = cassandraConfig.keyspace store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy - store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ - Username: cassandraConfig.username, - Password: cassandraConfig.password, + if cassandraConfig.username != "" && cassandraConfig.password != "" { + store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ + Username: cassandraConfig.username, + Password: cassandraConfig.password, + } + } + + if cassandraConfig.connectionTimeoutMillis != 0 { + store.clusterConfigs.ConnectTimeout = time.Millisecond * time.Duration(cassandraConfig.connectionTimeoutMillis) + } + if cassandraConfig.requestTimeoutMillis != 0 { + store.clusterConfigs.Timeout = time.Millisecond * time.Duration(cassandraConfig.requestTimeoutMillis) + } + if cassandraConfig.numConnections != 0 { + store.clusterConfigs.NumConns = cassandraConfig.numConnections } - store.clusterConfigs.ConnectTimeout = time.Millisecond * time.Duration(cassandraConfig.connectionTimeoutMillis) - store.clusterConfigs.Timeout = time.Millisecond * time.Duration(cassandraConfig.requestTimeoutMillis) - store.clusterConfigs.NumConns = cassandraConfig.numConnections + store.clusterConfigs.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} store.clusterConfigs.Consistency = gocql.LocalOne - //store.clusterConfigs.SslOpts = &gocql.SslOptions{ - // EnableHostVerification: true, - //} - createdSession, err := store.clusterConfigs.CreateSession() + cassandraTraceServiceName := os.Getenv("DD_SERVICE") + "-cassandra" + if cassandraTraceServiceName == "" { + cassandraTraceServiceName = "cassandra.client" // default service name if DD_SERVICE is not set + } + createdSession, err := gocqltrace.CreateTracedSession(store.clusterConfigs, gocqltrace.WithServiceName(cassandraTraceServiceName)) if err != nil { return nil, fmt.Errorf("unable to connect to the 
ScyllaDB database") } @@ -282,12 +289,35 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ go func(serEntityKey any) { defer waitGroup.Done() - scanner := c.session.Query(cqlStatement, serEntityKey).Iter().Scanner() + iter := c.session.Query(cqlStatement, serEntityKey).WithContext(ctx).Iter() + + rowIdx := serializedEntityKeyToIndex[serializedEntityKey.(string)] + + // fill the row with nulls if not found + if iter.NumRows() == 0 { + for _, featName := range featureNames { + results[rowIdx][featureNamesToIdx[featName]] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } + return + } + + scanner := iter.Scanner() var entityKey string var featureName string var eventTs time.Time var valueStr []byte var deserializedValue types.Value + rowFeatures := make(map[string]FeatureData) for scanner.Next() { err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) if err != nil { @@ -299,10 +329,9 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ return } - rowIdx := serializedEntityKeyToIndex[entityKey] if deserializedValue.Val != nil { // Convert the value to a FeatureData struct - results[rowIdx][featureNamesToIdx[featureName]] = FeatureData{ + rowFeatures[featureName] = FeatureData{ Reference: serving.FeatureReferenceV2{ FeatureViewName: featureViewName, FeatureName: featureName, @@ -314,7 +343,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ } } else { // Return FeatureData with a null value - results[rowIdx][featureNamesToIdx[featureName]] = FeatureData{ + rowFeatures[featureName] = FeatureData{ Reference: serving.FeatureReferenceV2{ FeatureViewName: featureViewName, FeatureName: featureName, @@ -332,6 +361,24 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ 
errorsChannel <- errors.New("failed to scan features: " + err.Error()) return } + + for _, featName := range featureNames { + featureData, ok := rowFeatures[featName] + if !ok { + featureData = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } + results[rowIdx][featureNamesToIdx[featName]] = featureData + } }(serializedEntityKey) } @@ -339,29 +386,14 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ waitGroup.Wait() close(errorsChannel) + var collectedErrors []error for err := range errorsChannel { if err != nil { - return nil, err + collectedErrors = append(collectedErrors, err) } } - - // Will fill feature slots that were left empty with null values - for i := 0; i < len(entityKeys); i++ { - for j := 0; j < len(featureNames); j++ { - if results[i][j].Timestamp.GetSeconds() == 0 { - results[i][j] = FeatureData{ - Reference: serving.FeatureReferenceV2{ - FeatureViewName: featureViewName, - FeatureName: featureViewNames[j], - }, - Value: types.Value{ - Val: &types.Value_NullVal{ - NullVal: types.Null_NULL, - }, - }, - } - } - } + if len(collectedErrors) > 0 { + return nil, errors.Join(collectedErrors...) 
} return results, nil From 13b6bc8bd30857c4e49f20c4b1209d8c270220b7 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Tue, 26 Nov 2024 15:25:04 -0500 Subject: [PATCH 14/18] Added tests --- .../feast/onlinestore/cassandraonlinestore.go | 2 +- .../onlinestore/cassandraonlinestore_test.go | 202 ++++++------------ 2 files changed, 71 insertions(+), 133 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index f874b23cbd..d8ff636b86 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -117,7 +117,7 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, // parse loadBalancing loadBalancingDict, ok := onlineStoreConfig["load_balancing"] if !ok { - loadBalancingDict = gocql.RoundRobinHostPolicy() + cassandraConfig.loadBalancingPolicy = gocql.RoundRobinHostPolicy() log.Warn().Msg("no load balancing policy selected, defaulted to RoundRobinHostPolicy") } else { loadBalancingProps := loadBalancingDict.(map[string]any) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go index 04300ad279..0df37d12ad 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore_test.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go @@ -1,143 +1,81 @@ package onlinestore import ( + "context" "github.com/gocql/gocql" - "testing" - - "github.com/feast-dev/feast/go/internal/feast/registry" "github.com/stretchr/testify/assert" + "reflect" + "testing" ) -func TestNewCassandraOnlineStoreDefaults(t *testing.T) { +func TestExtractCassandraConfig_CorrectDefaults(t *testing.T) { var config = map[string]interface{}{} - rc := ®istry.RepoConfig{ - OnlineStore: config, - EntityKeySerializationVersion: 4, + cassandraConfig, _ := extractCassandraConfig(config) + + assert.Equal(t, []string{"127.0.0.1"}, cassandraConfig.hosts) + 
assert.Equal(t, "", cassandraConfig.username) + assert.Equal(t, "", cassandraConfig.password) + assert.Equal(t, "feast_keyspace", cassandraConfig.keyspace) + assert.Equal(t, 4, cassandraConfig.protocolVersion) + assert.True(t, reflect.TypeOf(gocql.RoundRobinHostPolicy()) == reflect.TypeOf(cassandraConfig.loadBalancingPolicy)) + assert.Equal(t, int64(0), cassandraConfig.connectionTimeoutMillis) + assert.Equal(t, int64(0), cassandraConfig.requestTimeoutMillis) + assert.Equal(t, 0, cassandraConfig.numConnections) +} + +func TestExtractCassandraConfig_CorrectSettings(t *testing.T) { + var config = map[string]any{ + "hosts": []any{"0.0.0.0", "255.255.255.255"}, + "username": "scylladb", + "password": "scylladb", + "keyspace": "scylladb", + "protocol_version": 271.0, + "load_balancing": map[string]any{ + "load_balancing_policy": "DCAwareRoundRobinPolicy", + "local_dc": "aws-us-west-2", + }, + "connection_timeout_millis": 271.0, + "request_timeout_millis": 271.0, + "num_connections": 2.0, } - store, err := NewCassandraOnlineStore("test", rc, config) - assert.Nil(t, err) - assert.Equal(t, store.hosts, "127.0.0.1") - assert.Equal(t, store.keyspace, "scylladb") - assert.Equal(t, store.clusterConfigs.Authenticator, gocql.PasswordAuthenticator{ - Username: "cassandra", - Password: "cassandra", - }) - assert.Equal(t, store.clusterConfigs.ProtoVersion, 4) - assert.Nil(t, store.session) + cassandraConfig, _ := extractCassandraConfig(config) + + assert.Equal(t, []string{"0.0.0.0", "255.255.255.255"}, cassandraConfig.hosts) + assert.Equal(t, "scylladb", cassandraConfig.username) + assert.Equal(t, "scylladb", cassandraConfig.password) + assert.Equal(t, "scylladb", cassandraConfig.keyspace) + assert.Equal(t, 271, cassandraConfig.protocolVersion) + assert.True(t, reflect.TypeOf(gocql.DCAwareRoundRobinPolicy("aws-us-west-2")) == reflect.TypeOf(cassandraConfig.loadBalancingPolicy)) + assert.Equal(t, int64(271), cassandraConfig.connectionTimeoutMillis) + assert.Equal(t, int64(271), 
cassandraConfig.requestTimeoutMillis) + assert.Equal(t, 2, cassandraConfig.numConnections) } -//func TestCassandraOnlineStore_SerializeCassandraEntityKey(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKey := &types.EntityKey{ -// JoinKeys: []string{"key1", "key2"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// } -// _, err := store.serializeCassandraEntityKey(entityKey, 2) -// assert.Nil(t, err) -//} -// -//func TestCassandraOnlineStore_SerializeCassandraEntityKey_InvalidEntityKey(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKey := &types.EntityKey{ -// JoinKeys: []string{"key1"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// } -// _, err := store.serializeCassandraEntityKey(entityKey, 2) -// assert.NotNil(t, err) -//} -// -//func TestCassandraOnlineStore_SerializeValue(t *testing.T) { -// store := CassandraOnlineStore{} -// _, _, err := store.serializeValue(&types.Value_StringVal{StringVal: "value1"}, 2) -// assert.Nil(t, err) -//} -// -//func TestCassandraOnlineStore_SerializeValue_InvalidValue(t *testing.T) { -// store := CassandraOnlineStore{} -// _, _, err := store.serializeValue(nil, 2) -// assert.NotNil(t, err) -//} -// -//func TestCassandraOnlineStore_BuildCassandraKeys(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKeys := []*types.EntityKey{ -// { -// JoinKeys: []string{"key1", "key2"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// }, -// } -// _, _, err := store.buildCassandraKeys(entityKeys) -// assert.Nil(t, err) -//} -// -//func TestCassandraOnlineStore_BuildCassandraKeys_InvalidEntityKeys(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKeys := []*types.EntityKey{ -// { -// JoinKeys: []string{"key1"}, -// 
EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// }, -// } -// _, _, err := store.buildCassandraKeys(entityKeys) -// assert.NotNil(t, err) -//} -// -//func TestCassandraOnlineStore_OnlineRead_HappyPath(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKeys := []*types.EntityKey{ -// { -// JoinKeys: []string{"key1", "key2"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// }, -// } -// featureViewNames := []string{"featureView1"} -// featureNames := []string{"feature1"} -// -// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) -// assert.Nil(t, err) -//} -// -//func TestCassandraOnlineStore_OnlineRead_InvalidEntityKey(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKeys := []*types.EntityKey{ -// { -// JoinKeys: []string{"key1"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// }, -// } -// featureViewNames := []string{"featureView1"} -// featureNames := []string{"feature1"} -// -// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) -// assert.NotNil(t, err) -//} -// -//func TestCassandraOnlineStore_OnlineRead_NoFeatureViewNames(t *testing.T) { -// store := CassandraOnlineStore{} -// entityKeys := []*types.EntityKey{ -// { -// JoinKeys: []string{"key1", "key2"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// }, -// } -// featureViewNames := []string{} -// featureNames := []string{"feature1"} -// -// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) -// assert.NotNil(t, err) -//} -// -//func TestCassandraOnlineStore_OnlineRead_NoFeatureNames(t 
*testing.T) { -// store := CassandraOnlineStore{} -// entityKeys := []*types.EntityKey{ -// { -// JoinKeys: []string{"key1", "key2"}, -// EntityValues: []*types.Value{{Val: &types.Value_StringVal{StringVal: "value1"}}, {Val: &types.Value_StringVal{StringVal: "value2"}}}, -// }, -// } -// featureViewNames := []string{"featureView1"} -// featureNames := []string{} -// -// _, err := store.OnlineRead(context.Background(), entityKeys, featureViewNames, featureNames) -// assert.NotNil(t, err) -//} +func TestGetFqTableName(t *testing.T) { + store := CassandraOnlineStore{ + project: "dummy_project", + clusterConfigs: &gocql.ClusterConfig{ + Keyspace: "scylladb", + }, + } + + fqTableName := store.getFqTableName("dummy_fv") + assert.Equal(t, `"scylladb"."dummy_project_dummy_fv"`, fqTableName) +} + +func TestGetCQLStatement(t *testing.T) { + store := CassandraOnlineStore{} + fqTableName := `"scylladb"."dummy_project_dummy_fv"` + + cqlStatement := store.getCQLStatement(fqTableName, []string{"feat1", "feat2"}) + assert.Equal(t, + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" = ? 
AND "feature_name" IN ('feat1','feat2')`, + cqlStatement, + ) +} + +func TestOnlineRead_RejectsDifferentFeatureViewsInSameRead(t *testing.T) { + store := CassandraOnlineStore{} + _, err := store.OnlineRead(context.TODO(), nil, []string{"fv1", "fv2"}, []string{"feat1", "feat2"}) + assert.Error(t, err) +} From 1f7cb20482b33f8ac0dea87cba9515d7d03e98c1 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Tue, 26 Nov 2024 15:45:22 -0500 Subject: [PATCH 15/18] Removed num_connections as it is ignored --- .../feast/onlinestore/cassandraonlinestore.go | 12 ------------ .../feast/onlinestore/cassandraonlinestore_test.go | 3 --- 2 files changed, 15 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index d8ff636b86..5b6ffb8c55 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -43,7 +43,6 @@ type CassandraConfig struct { loadBalancingPolicy gocql.HostSelectionPolicy connectionTimeoutMillis int64 requestTimeoutMillis int64 - numConnections int } func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, error) { @@ -159,14 +158,6 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, } cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64)) - // parse numConnections - numConnections, ok := onlineStoreConfig["num_connections"] - if !ok { - numConnections = 0.0 - log.Warn().Msg("num_connections not specified, using gocql default") - } - cassandraConfig.numConnections = int(numConnections.(float64)) - return &cassandraConfig, nil } @@ -200,9 +191,6 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online if cassandraConfig.requestTimeoutMillis != 0 { store.clusterConfigs.Timeout = time.Millisecond * time.Duration(cassandraConfig.requestTimeoutMillis) } - if cassandraConfig.numConnections != 0 { - 
store.clusterConfigs.NumConns = cassandraConfig.numConnections - } store.clusterConfigs.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} store.clusterConfigs.Consistency = gocql.LocalOne diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go index 0df37d12ad..67a9eea548 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore_test.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go @@ -20,7 +20,6 @@ func TestExtractCassandraConfig_CorrectDefaults(t *testing.T) { assert.True(t, reflect.TypeOf(gocql.RoundRobinHostPolicy()) == reflect.TypeOf(cassandraConfig.loadBalancingPolicy)) assert.Equal(t, int64(0), cassandraConfig.connectionTimeoutMillis) assert.Equal(t, int64(0), cassandraConfig.requestTimeoutMillis) - assert.Equal(t, 0, cassandraConfig.numConnections) } func TestExtractCassandraConfig_CorrectSettings(t *testing.T) { @@ -36,7 +35,6 @@ func TestExtractCassandraConfig_CorrectSettings(t *testing.T) { }, "connection_timeout_millis": 271.0, "request_timeout_millis": 271.0, - "num_connections": 2.0, } cassandraConfig, _ := extractCassandraConfig(config) @@ -48,7 +46,6 @@ func TestExtractCassandraConfig_CorrectSettings(t *testing.T) { assert.True(t, reflect.TypeOf(gocql.DCAwareRoundRobinPolicy("aws-us-west-2")) == reflect.TypeOf(cassandraConfig.loadBalancingPolicy)) assert.Equal(t, int64(271), cassandraConfig.connectionTimeoutMillis) assert.Equal(t, int64(271), cassandraConfig.requestTimeoutMillis) - assert.Equal(t, 2, cassandraConfig.numConnections) } func TestGetFqTableName(t *testing.T) { From 3e65b569427eec3505c3a58792e2c5189f014fa1 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Mon, 2 Dec 2024 12:01:43 -0500 Subject: [PATCH 16/18] Simplify null feature --- .../feast/onlinestore/cassandraonlinestore.go | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go 
b/go/internal/feast/onlinestore/cassandraonlinestore.go index 5b6ffb8c55..3b7697c715 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -329,19 +329,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ Val: deserializedValue.Val, }, } - } else { - // Return FeatureData with a null value - rowFeatures[featureName] = FeatureData{ - Reference: serving.FeatureReferenceV2{ - FeatureViewName: featureViewName, - FeatureName: featureName, - }, - Value: types.Value{ - Val: &types.Value_NullVal{ - NullVal: types.Null_NULL, - }, - }, - } } } From 03b1c261a8415a6f0db028059eb2adb483c01f3c Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Mon, 2 Dec 2024 12:13:40 -0500 Subject: [PATCH 17/18] Addressing more comments --- go/internal/feast/onlinestore/cassandraonlinestore.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index 3b7697c715..eb7926fa2e 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -375,5 +375,5 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ } func (c *CassandraOnlineStore) Destruct() { - + c.session.Close() } From a2f54070eca3305280f09217cd56d9132a9f13c3 Mon Sep 17 00:00:00 2001 From: Jose Acevedo Date: Mon, 2 Dec 2024 12:54:42 -0500 Subject: [PATCH 18/18] Added common string parse func --- .../feast/onlinestore/cassandraonlinestore.go | 52 +++++++++---------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go index eb7926fa2e..897b2e13a6 100644 --- a/go/internal/feast/onlinestore/cassandraonlinestore.go +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -45,6 +45,18 @@ type 
CassandraConfig struct { requestTimeoutMillis int64 } +func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) { + rawValue, ok := config[fieldName] + if !ok { + return defaultValue, nil + } + stringValue, ok := rawValue.(string) + if !ok { + return "", fmt.Errorf("failed to convert %s to string: %v", fieldName, rawValue) + } + return stringValue, nil +} + func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, error) { cassandraConfig := CassandraConfig{} @@ -70,40 +82,25 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, } // parse username - rawUsername, ok := onlineStoreConfig["username"] - if !ok { - cassandraConfig.username = "" - log.Warn().Msg("username not defined, will not be using authentication") - } else { - cassandraConfig.username, ok = rawUsername.(string) - if !ok { - return nil, fmt.Errorf("failed to convert username to string: %v", rawUsername) - } + username, err := parseStringField(onlineStoreConfig, "username", "") + if err != nil { + return nil, err } + cassandraConfig.username = username // parse password - rawPassword, ok := onlineStoreConfig["password"] - if !ok { - cassandraConfig.password = "" - log.Warn().Msg("password not defined, will not be using authentication") - } else { - cassandraConfig.password, ok = rawPassword.(string) - if !ok { - return nil, fmt.Errorf("failed to convert password to string: %v", rawPassword) - } + password, err := parseStringField(onlineStoreConfig, "password", "") + if err != nil { + return nil, err } + cassandraConfig.password = password // parse keyspace - rawKeyspace, ok := onlineStoreConfig["keyspace"] - if !ok { - cassandraConfig.keyspace = "feast_keyspace" - log.Warn().Msg("keyspace not defined: Using 'feast_keyspace' as keyspace instead") - } else { - cassandraConfig.keyspace, ok = rawKeyspace.(string) - if !ok { - return nil, fmt.Errorf("failed to convert keyspace to string: %v", rawKeyspace) - } + 
keyspace, err := parseStringField(onlineStoreConfig, "keyspace", "feast_keyspace") + if err != nil { + return nil, err } + cassandraConfig.keyspace = keyspace // parse protocolVersion protocolVersion, ok := onlineStoreConfig["protocol_version"] @@ -179,6 +176,7 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy if cassandraConfig.username != "" && cassandraConfig.password != "" { + log.Warn().Msg("username/password not defined, will not be using authentication") store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ Username: cassandraConfig.username, Password: cassandraConfig.password,