From 1c1d1f043c24c66665e2063b084d9a733000b897 Mon Sep 17 00:00:00 2001 From: Meenakshi Sistla <85261163+msistla96@users.noreply.github.com> Date: Mon, 2 Dec 2024 12:07:45 -0600 Subject: [PATCH] feat: Implement Cassandra/ScyllaDB Go Online Store (#138) Implement Cassandra/ScyllaDB Go online store --------- Co-authored-by: Jose Acevedo --- go.mod | 5 + go.sum | 23 ++ go/infra/docker/feature-server/Dockerfile | 1 - .../feast/onlinestore/cassandraonlinestore.go | 377 ++++++++++++++++++ .../onlinestore/cassandraonlinestore_test.go | 78 ++++ go/internal/feast/onlinestore/onlinestore.go | 5 +- .../feast/onlinestore/redisonlinestore.go | 96 +---- .../feast/onlinestore/sqliteonlinestore.go | 19 +- go/internal/feast/utils/key_utils.go | 110 +++++ 9 files changed, 603 insertions(+), 111 deletions(-) create mode 100644 go/internal/feast/onlinestore/cassandraonlinestore.go create mode 100644 go/internal/feast/onlinestore/cassandraonlinestore_test.go create mode 100644 go/internal/feast/utils/key_utils.go diff --git a/go.mod b/go.mod index 26265273e2..c8f11d3b99 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ toolchain go1.22.5 require ( github.com/apache/arrow/go/v17 v17.0.0 github.com/ghodss/yaml v1.0.0 + github.com/gocql/gocql v1.6.0 github.com/golang/protobuf v1.5.4 github.com/google/uuid v1.6.0 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 @@ -56,6 +57,7 @@ require ( github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v24.3.25+incompatible // indirect github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 // indirect + github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8 // indirect github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 // indirect github.com/hashicorp/go-sockaddr v1.0.6 // indirect @@ -86,6 +88,9 @@ require ( golang.org/x/tools v0.25.0 // indirect golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240903143218-8af14fe29dc1 // indirect + gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/gocql/gocql => github.com/scylladb/gocql v1.14.4 diff --git a/go.sum b/go.sum index 9cb0906c98..eafff11eea 100644 --- a/go.sum +++ b/go.sum @@ -29,6 +29,10 @@ github.com/apache/thrift v0.21.0 h1:tdPmh/ptjE1IJnhbhrcl2++TauVjy242rkV/UzJChnE= github.com/apache/thrift v0.21.0/go.mod h1:W1H8aR/QRtYNvrPeFXBtobyRkd0/YVhTc6i07XIAgDw= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= +github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= +github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= @@ -60,10 +64,12 @@ github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/flatbuffers v24.3.25+incompatible h1:CX395cjN9Kke9mmalRoL3d81AtFUxJM+yDthflgJGkI= github.com/google/flatbuffers v24.3.25+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= @@ -76,6 +82,8 @@ github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1 h1:qnpS github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.0.1/go.mod h1:lXGCsh6c22WGtjr+qGHj1otzZpV/1kwTMAqkwZsnWRU= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 h1:pRhl55Yx1eC7BZ1N+BBWwnKaMyD8uC+34TLdndZMAKk= github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0/go.mod h1:XKMd7iuf/RGPSMJ/U4HP0zS2Z9Fh8Ps9a+6X26m/tmI= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= +github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8 h1:iBt4Ew4XEGLfh6/bPk4rSYmuZJGizr6/x/AEizP0CQc= github.com/hashicorp/go-secure-stdlib/parseutil v0.1.8/go.mod h1:aiJI+PIApBRQG7FZTEBx5GiiX+HbOHilUdNxUZi4eV0= github.com/hashicorp/go-secure-stdlib/strutil v0.1.2 h1:kes8mmyCpxJsI7FTwtzRqEy9CdjCtrXrXGuOpxEA7Ts= @@ -91,8 +99,11 @@ github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2 github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= github.com/klauspost/cpuid/v2 v2.2.8 h1:+StwCXwm9PdpiEkPyzBXIy+M9KUb4ODm0Zarf1kS5BM= github.com/klauspost/cpuid/v2 v2.2.8/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= @@ -153,6 +164,8 @@ github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8= github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/ryanuber/go-glob v1.0.0 h1:iQh3xXAumdQ+4Ufa5b25cRpC5TYKlno6hsv6Cb3pkBk= github.com/ryanuber/go-glob v1.0.0/go.mod h1:807d1WSdnB0XRJzKNil9Om6lcp/3a0v4qIHxIXzX/Yc= +github.com/scylladb/gocql v1.14.4 h1:MhevwCfyAraQ6RvZYFO3pF4Lt0YhvQlfg8Eo2HEqVQA= +github.com/scylladb/gocql v1.14.4/go.mod h1:ZLEJ0EVE5JhmtxIW2stgHq/v1P4fWap0qyyXSKyV8K0= github.com/secure-systems-lab/go-securesystemslib v0.8.0 h1:mr5An6X45Kb2nddcFlbmfHkLguCE9laoZCUzEEpIZXA= github.com/secure-systems-lab/go-securesystemslib v0.8.0/go.mod h1:UH2VZVuJfCYR8WgMlCU1uFsOUU+KeyrTWcSS73NBOzU= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= @@ -197,6 +210,7 @@ golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20220526153639-5463443f8c37/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.29.0 h1:5ORfpBpCs4HzDYoodCDBbwHzdR5UrLBZ3sOnUJmFoHo= golang.org/x/net v0.29.0/go.mod h1:gLkgy8jTGERgjzMic6DS9+SP0ajcu6Xu3Orq/SpETg0= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -210,6 +224,8 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220627191245-f75cf1eec38b/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -218,8 +234,10 @@ golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.18.0 h1:XvMDiNzPAl0jr17s6W9lcaIhGUfUORdGCNsuLmPG224= golang.org/x/text v0.18.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= golang.org/x/time v0.6.0 h1:eTDhh4ZXt5Qf0augr54TN6suAUudPcawVZeIAPU7D4U= @@ -231,6 +249,7 @@ golang.org/x/tools v0.25.0 h1:oFU9pkj/iJgs+0DT+VMHrx+oBKs/LJMV+Uvg78sl+fE= golang.org/x/tools v0.25.0/go.mod h1:/vtpO8WL1N9cQC3FN5zPqb//fRXskFHbLKk4OW1Q7rg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da h1:noIWHXmPHxILtqtCOPIhSt0ABwskkZKjD3bXGnZGpNY= golang.org/x/xerrors v0.0.0-20240903120638-7835f813f4da/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= @@ -247,6 +266,8 @@ gopkg.in/DataDog/dd-trace-go.v1 v1.68.0/go.mod h1:mkZpWVLO/ERW5NqlW+w5d8waQKNvMS gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= @@ -266,3 +287,5 @@ modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= +sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/go/infra/docker/feature-server/Dockerfile b/go/infra/docker/feature-server/Dockerfile index cf63bb4559..5ac71b93ef 100644 --- a/go/infra/docker/feature-server/Dockerfile +++ b/go/infra/docker/feature-server/Dockerfile @@ -22,7 +22,6 @@ RUN find ./protos -name "*.proto" \ # Build the Go application RUN go build -o feast ./go/main.go - # Expose ports EXPOSE 8080 diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go new file mode 100644 index 0000000000..897b2e13a6 --- /dev/null +++ b/go/internal/feast/onlinestore/cassandraonlinestore.go @@ -0,0 +1,377 @@ +package onlinestore + +import ( + "context" + "encoding/hex" + "errors" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/feast-dev/feast/go/internal/feast/registry" + "github.com/feast-dev/feast/go/internal/feast/utils" + "github.com/feast-dev/feast/go/protos/feast/serving" + "github.com/feast-dev/feast/go/protos/feast/types" + "github.com/gocql/gocql" + "github.com/golang/protobuf/proto" + "github.com/rs/zerolog/log" + + "google.golang.org/protobuf/types/known/timestamppb" + gocqltrace "gopkg.in/DataDog/dd-trace-go.v1/contrib/gocql/gocql" +) + +type CassandraOnlineStore struct { + project string + + // Cluster configurations for Cassandra/ScyllaDB + clusterConfigs *gocql.ClusterConfig + + // Session object that holds information about the connection to the cluster + session *gocql.Session + + config *registry.RepoConfig +} + +type CassandraConfig struct { + hosts []string + username string + password string + keyspace string + protocolVersion int + loadBalancingPolicy gocql.HostSelectionPolicy + connectionTimeoutMillis int64 + requestTimeoutMillis int64 +} + +func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) { + rawValue, ok := config[fieldName] + if !ok { + return defaultValue, nil + } + stringValue, ok := rawValue.(string) + if !ok { + return "", fmt.Errorf("failed to convert %s to string: %v", fieldName, rawValue) + } + return stringValue, nil +} + +func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig, error) { + cassandraConfig := CassandraConfig{} + + // parse hosts + cassandraHosts, ok := onlineStoreConfig["hosts"] + if !ok { + cassandraConfig.hosts = []string{"127.0.0.1"} + log.Warn().Msg("host not provided: Using 127.0.0.1 instead") + } else { + var rawCassandraHosts []any + if rawCassandraHosts, ok = cassandraHosts.([]any); !ok { + return nil, fmt.Errorf("didn't pass a list of hosts in the 'hosts' field") + } + var cassandraHostsStr = make([]string, len(rawCassandraHosts)) + for i, rawHost := range rawCassandraHosts { + hostStr, ok := rawHost.(string) + if !ok { + return nil, fmt.Errorf("failed to convert a host to a string: %+v", rawHost) + } + cassandraHostsStr[i] = hostStr + } + cassandraConfig.hosts = cassandraHostsStr + } + + // parse username + username, err := parseStringField(onlineStoreConfig, "username", "") + if err != nil { + return nil, err + } + cassandraConfig.username = username + + // parse password + password, err := parseStringField(onlineStoreConfig, "password", "") + if err != nil { + return nil, err + } + cassandraConfig.password = password + + // parse keyspace + keyspace, err := parseStringField(onlineStoreConfig, "keyspace", "feast_keyspace") + if err != nil { + return nil, err + } + cassandraConfig.keyspace = keyspace + + // parse protocolVersion + protocolVersion, ok := onlineStoreConfig["protocol_version"] + if !ok { + protocolVersion = 4.0 + log.Warn().Msg("protocol_version not specified: Using 4 instead") + } + cassandraConfig.protocolVersion = int(protocolVersion.(float64)) + + // parse loadBalancing + loadBalancingDict, ok := onlineStoreConfig["load_balancing"] + if !ok { + cassandraConfig.loadBalancingPolicy = gocql.RoundRobinHostPolicy() + log.Warn().Msg("no load balancing policy selected, defaulted to RoundRobinHostPolicy") + } else { + loadBalancingProps := loadBalancingDict.(map[string]any) + policy := loadBalancingProps["load_balancing_policy"].(string) + switch policy { + case "TokenAwarePolicy(DCAwareRoundRobinPolicy)": + rawLocalDC, ok := loadBalancingProps["local_dc"] + if !ok { + return nil, fmt.Errorf("a local_dc is needed for policy DCAwareRoundRobinPolicy") + } + localDc := rawLocalDC.(string) + cassandraConfig.loadBalancingPolicy = gocql.TokenAwareHostPolicy(gocql.DCAwareRoundRobinPolicy(localDc)) + case "DCAwareRoundRobinPolicy": + rawLocalDC, ok := loadBalancingProps["local_dc"] + if !ok { + return nil, fmt.Errorf("a local_dc is needed for policy DCAwareRoundRobinPolicy") + } + localDc := rawLocalDC.(string) + cassandraConfig.loadBalancingPolicy = gocql.DCAwareRoundRobinPolicy(localDc) + default: + log.Warn().Msg("defaulted to using RoundRobinHostPolicy") + cassandraConfig.loadBalancingPolicy = gocql.RoundRobinHostPolicy() + } + } + + // parse connectionTimeoutMillis + connectionTimeoutMillis, ok := onlineStoreConfig["connection_timeout_millis"] + if !ok { + connectionTimeoutMillis = 0.0 + log.Warn().Msg("connection_timeout_millis not specified, using gocql default") + } + cassandraConfig.connectionTimeoutMillis = int64(connectionTimeoutMillis.(float64)) + + // parse requestTimeoutMillis + requestTimeoutMillis, ok := onlineStoreConfig["request_timeout_millis"] + if !ok { + requestTimeoutMillis = 0.0 + log.Warn().Msg("request_timeout_millis not specified, using gocql default") + } + cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64)) + + return &cassandraConfig, nil +} + +func NewCassandraOnlineStore(project string, config *registry.RepoConfig, onlineStoreConfig map[string]any) (*CassandraOnlineStore, error) { + store := CassandraOnlineStore{ + project: project, + config: config, + } + + cassandraConfig, configError := extractCassandraConfig(onlineStoreConfig) + if configError != nil { + return nil, configError + } + + store.clusterConfigs = gocql.NewCluster(cassandraConfig.hosts...) + store.clusterConfigs.ProtoVersion = cassandraConfig.protocolVersion + store.clusterConfigs.Keyspace = cassandraConfig.keyspace + + store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy + + if cassandraConfig.username != "" && cassandraConfig.password != "" { + log.Warn().Msg("username/password not defined, will not be using authentication") + store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{ + Username: cassandraConfig.username, + Password: cassandraConfig.password, + } + } + + if cassandraConfig.connectionTimeoutMillis != 0 { + store.clusterConfigs.ConnectTimeout = time.Millisecond * time.Duration(cassandraConfig.connectionTimeoutMillis) + } + if cassandraConfig.requestTimeoutMillis != 0 { + store.clusterConfigs.Timeout = time.Millisecond * time.Duration(cassandraConfig.requestTimeoutMillis) + } + + store.clusterConfigs.RetryPolicy = &gocql.SimpleRetryPolicy{NumRetries: 3} + store.clusterConfigs.Consistency = gocql.LocalOne + + cassandraTraceServiceName := os.Getenv("DD_SERVICE") + "-cassandra" + if cassandraTraceServiceName == "" { + cassandraTraceServiceName = "cassandra.client" // default service name if DD_SERVICE is not set + } + createdSession, err := gocqltrace.CreateTracedSession(store.clusterConfigs, gocqltrace.WithServiceName(cassandraTraceServiceName)) + if err != nil { + return nil, fmt.Errorf("unable to connect to the ScyllaDB database") + } + store.session = createdSession + return &store, nil +} + +func (c *CassandraOnlineStore) getFqTableName(tableName string) string { + return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName) +} + +func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string) string { + // this prevents fetching unnecessary features + quotedFeatureNames := make([]string, len(featureNames)) + for i, featureName := range featureNames { + quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName) + } + + return fmt.Sprintf( + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`, + tableName, + strings.Join(quotedFeatureNames, ","), + ) +} + +func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.EntityKey) ([]any, map[string]int, error) { + cassandraKeys := make([]any, len(entityKeys)) + cassandraKeyToEntityIndex := make(map[string]int) + for i := 0; i < len(entityKeys); i++ { + var key, err = utils.SerializeEntityKey(entityKeys[i], c.config.EntityKeySerializationVersion) + if err != nil { + return nil, nil, err + } + encodedKey := hex.EncodeToString(*key) + cassandraKeys[i] = encodedKey + cassandraKeyToEntityIndex[encodedKey] = i + } + return cassandraKeys, cassandraKeyToEntityIndex, nil +} +func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) { + uniqueNames := make(map[string]int32) + for _, fvName := range featureViewNames { + uniqueNames[fvName] = 0 + } + if len(uniqueNames) != 1 { + return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once") + } + + serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys) + + if err != nil { + return nil, fmt.Errorf("error when serializing entity keys for Cassandra") + } + results := make([][]FeatureData, len(entityKeys)) + for i := range results { + results[i] = make([]FeatureData, len(featureNames)) + } + + featureNamesToIdx := make(map[string]int) + for idx, name := range featureNames { + featureNamesToIdx[name] = idx + } + + featureViewName := featureViewNames[0] + + // Prepare the query + tableName := c.getFqTableName(featureViewName) + cqlStatement := c.getCQLStatement(tableName, featureNames) + + var waitGroup sync.WaitGroup + waitGroup.Add(len(serializedEntityKeys)) + + errorsChannel := make(chan error, len(serializedEntityKeys)) + for _, serializedEntityKey := range serializedEntityKeys { + go func(serEntityKey any) { + defer waitGroup.Done() + + iter := c.session.Query(cqlStatement, serEntityKey).WithContext(ctx).Iter() + + rowIdx := serializedEntityKeyToIndex[serializedEntityKey.(string)] + + // fill the row with nulls if not found + if iter.NumRows() == 0 { + for _, featName := range featureNames { + results[rowIdx][featureNamesToIdx[featName]] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } + return + } + + scanner := iter.Scanner() + var entityKey string + var featureName string + var eventTs time.Time + var valueStr []byte + var deserializedValue types.Value + rowFeatures := make(map[string]FeatureData) + for scanner.Next() { + err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr) + if err != nil { + errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)") + return + } + if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil { + errorsChannel <- errors.New("error converting parsed Cassandra Value to types.Value") + return + } + + if deserializedValue.Val != nil { + // Convert the value to a FeatureData struct + rowFeatures[featureName] = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featureName, + }, + Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())}, + Value: types.Value{ + Val: deserializedValue.Val, + }, + } + } + } + + if err := scanner.Err(); err != nil { + errorsChannel <- errors.New("failed to scan features: " + err.Error()) + return + } + + for _, featName := range featureNames { + featureData, ok := rowFeatures[featName] + if !ok { + featureData = FeatureData{ + Reference: serving.FeatureReferenceV2{ + FeatureViewName: featureViewName, + FeatureName: featName, + }, + Value: types.Value{ + Val: &types.Value_NullVal{ + NullVal: types.Null_NULL, + }, + }, + } + } + results[rowIdx][featureNamesToIdx[featName]] = featureData + } + }(serializedEntityKey) + } + + // wait until all concurrent single-key queries are done + waitGroup.Wait() + close(errorsChannel) + + var collectedErrors []error + for err := range errorsChannel { + if err != nil { + collectedErrors = append(collectedErrors, err) + } + } + if len(collectedErrors) > 0 { + return nil, errors.Join(collectedErrors...) + } + + return results, nil +} + +func (c *CassandraOnlineStore) Destruct() { + c.session.Close() +} diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go new file mode 100644 index 0000000000..67a9eea548 --- /dev/null +++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go @@ -0,0 +1,78 @@ +package onlinestore + +import ( + "context" + "github.com/gocql/gocql" + "github.com/stretchr/testify/assert" + "reflect" + "testing" +) + +func TestExtractCassandraConfig_CorrectDefaults(t *testing.T) { + var config = map[string]interface{}{} + cassandraConfig, _ := extractCassandraConfig(config) + + assert.Equal(t, []string{"127.0.0.1"}, cassandraConfig.hosts) + assert.Equal(t, "", cassandraConfig.username) + assert.Equal(t, "", cassandraConfig.password) + assert.Equal(t, "feast_keyspace", cassandraConfig.keyspace) + assert.Equal(t, 4, cassandraConfig.protocolVersion) + assert.True(t, reflect.TypeOf(gocql.RoundRobinHostPolicy()) == reflect.TypeOf(cassandraConfig.loadBalancingPolicy)) + assert.Equal(t, int64(0), cassandraConfig.connectionTimeoutMillis) + assert.Equal(t, int64(0), cassandraConfig.requestTimeoutMillis) +} + +func TestExtractCassandraConfig_CorrectSettings(t *testing.T) { + var config = map[string]any{ + "hosts": []any{"0.0.0.0", "255.255.255.255"}, + "username": "scylladb", + "password": "scylladb", + "keyspace": "scylladb", + "protocol_version": 271.0, + "load_balancing": map[string]any{ + "load_balancing_policy": "DCAwareRoundRobinPolicy", + "local_dc": "aws-us-west-2", + }, + "connection_timeout_millis": 271.0, + "request_timeout_millis": 271.0, + } + cassandraConfig, _ := extractCassandraConfig(config) + + assert.Equal(t, []string{"0.0.0.0", "255.255.255.255"}, cassandraConfig.hosts) + assert.Equal(t, "scylladb", cassandraConfig.username) + assert.Equal(t, "scylladb", cassandraConfig.password) + assert.Equal(t, "scylladb", cassandraConfig.keyspace) + assert.Equal(t, 271, cassandraConfig.protocolVersion) + assert.True(t, reflect.TypeOf(gocql.DCAwareRoundRobinPolicy("aws-us-west-2")) == reflect.TypeOf(cassandraConfig.loadBalancingPolicy)) + assert.Equal(t, int64(271), cassandraConfig.connectionTimeoutMillis) + assert.Equal(t, int64(271), cassandraConfig.requestTimeoutMillis) +} + +func TestGetFqTableName(t *testing.T) { + store := CassandraOnlineStore{ + project: "dummy_project", + clusterConfigs: &gocql.ClusterConfig{ + Keyspace: "scylladb", + }, + } + + fqTableName := store.getFqTableName("dummy_fv") + assert.Equal(t, `"scylladb"."dummy_project_dummy_fv"`, fqTableName) +} + +func TestGetCQLStatement(t *testing.T) { + store := CassandraOnlineStore{} + fqTableName := `"scylladb"."dummy_project_dummy_fv"` + + cqlStatement := store.getCQLStatement(fqTableName, []string{"feat1", "feat2"}) + assert.Equal(t, + `SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" = ? AND "feature_name" IN ('feat1','feat2')`, + cqlStatement, + ) +} + +func TestOnlineRead_RejectsDifferentFeatureViewsInSameRead(t *testing.T) { + store := CassandraOnlineStore{} + _, err := store.OnlineRead(context.TODO(), nil, []string{"fv1", "fv2"}, []string{"feat1", "feat2"}) + assert.Error(t, err) +} diff --git a/go/internal/feast/onlinestore/onlinestore.go b/go/internal/feast/onlinestore/onlinestore.go index 2f30e16d67..5dcc8e5370 100644 --- a/go/internal/feast/onlinestore/onlinestore.go +++ b/go/internal/feast/onlinestore/onlinestore.go @@ -61,7 +61,10 @@ func NewOnlineStore(config *registry.RepoConfig) (OnlineStore, error) { } else if onlineStoreType == "redis" { onlineStore, err := NewRedisOnlineStore(config.Project, config, config.OnlineStore) return onlineStore, err + } else if onlineStoreType == "cassandra" || onlineStoreType == "scylladb" { + onlineStore, err := NewCassandraOnlineStore(config.Project, config, config.OnlineStore) + return onlineStore, err } else { - return nil, fmt.Errorf("%s online store type is currently not supported; only redis and sqlite are supported", onlineStoreType) + return nil, fmt.Errorf("%s online store type is currently not supported; only redis, scylladb, cassandra and sqlite are supported", onlineStoreType) } } diff --git a/go/internal/feast/onlinestore/redisonlinestore.go b/go/internal/feast/onlinestore/redisonlinestore.go index 3a57d0fcec..a3cc2a20aa 100644 --- a/go/internal/feast/onlinestore/redisonlinestore.go +++ b/go/internal/feast/onlinestore/redisonlinestore.go @@ -6,8 +6,8 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/feast-dev/feast/go/internal/feast/utils" "os" - "sort" "strconv" "strings" @@ -338,102 +338,10 @@ func (r *RedisOnlineStore) Destruct() { } func buildRedisKey(project string, entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { - serKey, err := serializeEntityKey(entityKey, entityKeySerializationVersion) + serKey, err := utils.SerializeEntityKey(entityKey, entityKeySerializationVersion) if err != nil { return nil, err } fullKey := append(*serKey, []byte(project)...) return &fullKey, nil } - -func serializeEntityKey(entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { - // Serialize entity key to a bytestring so that it can be used as a lookup key in a hash table. - - // Ensure that we have the right amount of join keys and entity values - if len(entityKey.JoinKeys) != len(entityKey.EntityValues) { - return nil, fmt.Errorf("the amount of join key names and entity values don't match: %s vs %s", entityKey.JoinKeys, entityKey.EntityValues) - } - - // Make sure that join keys are sorted so that we have consistent key building - m := make(map[string]*types.Value) - - for i := 0; i < len(entityKey.JoinKeys); i++ { - m[entityKey.JoinKeys[i]] = entityKey.EntityValues[i] - } - - keys := make([]string, 0, len(m)) - for k := range entityKey.JoinKeys { - keys = append(keys, entityKey.JoinKeys[k]) - } - sort.Strings(keys) - - // Build the key - length := 5 * len(keys) - bufferList := make([][]byte, length) - - for i := 0; i < len(keys); i++ { - offset := i * 2 - byteBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(byteBuffer, uint32(types.ValueType_Enum_value["STRING"])) - bufferList[offset] = byteBuffer - bufferList[offset+1] = []byte(keys[i]) - } - - for i := 0; i < len(keys); i++ { - offset := (2 * len(keys)) + (i * 3) - value := m[keys[i]].GetVal() - - valueBytes, valueTypeBytes, err := serializeValue(value, entityKeySerializationVersion) - if err != nil { - return valueBytes, err - } - - typeBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(typeBuffer, uint32(valueTypeBytes)) - - lenBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(lenBuffer, uint32(len(*valueBytes))) - - bufferList[offset+0] = typeBuffer - bufferList[offset+1] = lenBuffer - bufferList[offset+2] = *valueBytes - } - - // Convert from an array of byte arrays to a single byte array - var entityKeyBuffer []byte - for i := 0; i < len(bufferList); i++ { - entityKeyBuffer = append(entityKeyBuffer, bufferList[i]...) - } - - return &entityKeyBuffer, nil -} - -func serializeValue(value interface{}, entityKeySerializationVersion int64) (*[]byte, types.ValueType_Enum, error) { - // TODO: Implement support for other types (at least the major types like ints, strings, bytes) - switch x := (value).(type) { - case *types.Value_StringVal: - valueString := []byte(x.StringVal) - return &valueString, types.ValueType_STRING, nil - case *types.Value_BytesVal: - return &x.BytesVal, types.ValueType_BYTES, nil - case *types.Value_Int32Val: - valueBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int32Val)) - return &valueBuffer, types.ValueType_INT32, nil - case *types.Value_Int64Val: - if entityKeySerializationVersion <= 1 { - // We unfortunately have to use 32 bit here for backward compatibility :( - valueBuffer := make([]byte, 4) - binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int64Val)) - return &valueBuffer, types.ValueType_INT64, nil - } else { - valueBuffer := make([]byte, 8) - binary.LittleEndian.PutUint64(valueBuffer, uint64(x.Int64Val)) - return &valueBuffer, types.ValueType_INT64, nil - } - case nil: - return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) - default: - return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) - } -} diff --git a/go/internal/feast/onlinestore/sqliteonlinestore.go b/go/internal/feast/onlinestore/sqliteonlinestore.go index 6c37258e74..95f95610e5 100644 --- a/go/internal/feast/onlinestore/sqliteonlinestore.go +++ b/go/internal/feast/onlinestore/sqliteonlinestore.go @@ -1,10 +1,9 @@ package onlinestore import ( - "crypto/sha1" "database/sql" - "encoding/hex" "errors" + "github.com/feast-dev/feast/go/internal/feast/utils" "strings" "sync" "time" @@ -71,12 +70,12 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. in_query := make([]string, len(entityKeys)) serialized_entities := make([]interface{}, len(entityKeys)) for i := 0; i < len(entityKeys); i++ { - serKey, err := serializeEntityKey(entityKeys[i], s.repoConfig.EntityKeySerializationVersion) + serKey, err := utils.SerializeEntityKey(entityKeys[i], s.repoConfig.EntityKeySerializationVersion) if err != nil { return nil, err } // TODO: fix this, string conversion is not safe - entityNameToEntityIndex[hashSerializedEntityKey(serKey)] = i + entityNameToEntityIndex[utils.HashSerializedEntityKey(serKey)] = i // for IN clause in read query in_query[i] = "?" serialized_entities[i] = *serKey @@ -109,7 +108,7 @@ func (s *SqliteOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types. if err := proto.Unmarshal(valueString, &value); err != nil { return nil, errors.New("error converting parsed value to types.Value") } - rowIdx := entityNameToEntityIndex[hashSerializedEntityKey(&entity_key)] + rowIdx := entityNameToEntityIndex[utils.HashSerializedEntityKey(&entity_key)] if results[rowIdx] == nil { results[rowIdx] = make([]FeatureData, featureCount) } @@ -152,13 +151,3 @@ func initializeConnection(db_path string) (*sql.DB, error) { } return db, nil } - -func hashSerializedEntityKey(serializedEntityKey *[]byte) string { - if serializedEntityKey == nil { - return "" - } - h := sha1.New() - h.Write(*serializedEntityKey) - sha1_hash := hex.EncodeToString(h.Sum(nil)) - return sha1_hash -} diff --git a/go/internal/feast/utils/key_utils.go b/go/internal/feast/utils/key_utils.go new file mode 100644 index 0000000000..7cf9459455 --- /dev/null +++ b/go/internal/feast/utils/key_utils.go @@ -0,0 +1,110 @@ +package utils + +import ( + "crypto/sha1" + "encoding/binary" + "encoding/hex" + "fmt" + "github.com/feast-dev/feast/go/protos/feast/types" + "sort" +) + +func HashSerializedEntityKey(serializedEntityKey *[]byte) string { + if serializedEntityKey == nil { + return "" + } + h := sha1.New() + h.Write(*serializedEntityKey) + return hex.EncodeToString(h.Sum(nil)) +} + +// SerializeEntityKey Serialize entity key to a bytestring so that it can be used as a lookup key in a hash table. +func SerializeEntityKey(entityKey *types.EntityKey, entityKeySerializationVersion int64) (*[]byte, error) { + // Ensure that we have the right amount of join keys and entity values + if len(entityKey.JoinKeys) != len(entityKey.EntityValues) { + return nil, fmt.Errorf("the amount of join key names and entity values don't match: %s vs %s", entityKey.JoinKeys, entityKey.EntityValues) + } + + // Make sure that join keys are sorted so that we have consistent key building + m := make(map[string]*types.Value) + + for i := 0; i < len(entityKey.JoinKeys); i++ { + m[entityKey.JoinKeys[i]] = entityKey.EntityValues[i] + } + + keys := make([]string, 0, len(m)) + for k := range entityKey.JoinKeys { + keys = append(keys, entityKey.JoinKeys[k]) + } + sort.Strings(keys) + + // Build the key + length := 5 * len(keys) + bufferList := make([][]byte, length) + + for i := 0; i < len(keys); i++ { + offset := i * 2 + byteBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(byteBuffer, uint32(types.ValueType_Enum_value["STRING"])) + bufferList[offset] = byteBuffer + bufferList[offset+1] = []byte(keys[i]) + } + + for i := 0; i < len(keys); i++ { + offset := (2 * len(keys)) + (i * 3) + value := m[keys[i]].GetVal() + + valueBytes, valueTypeBytes, err := serializeValue(value, entityKeySerializationVersion) + if err != nil { + return valueBytes, err + } + + typeBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(typeBuffer, uint32(valueTypeBytes)) + + lenBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(lenBuffer, uint32(len(*valueBytes))) + + bufferList[offset+0] = typeBuffer + bufferList[offset+1] = lenBuffer + bufferList[offset+2] = *valueBytes + } + + // Convert from an array of byte arrays to a single byte array + var entityKeyBuffer []byte + for i := 0; i < len(bufferList); i++ { + entityKeyBuffer = append(entityKeyBuffer, bufferList[i]...) + } + + return &entityKeyBuffer, nil +} + +func serializeValue(value interface{}, entityKeySerializationVersion int64) (*[]byte, types.ValueType_Enum, error) { + // TODO: Implement support for other types (at least the major types like ints, strings, bytes) + switch x := (value).(type) { + case *types.Value_StringVal: + valueString := []byte(x.StringVal) + return &valueString, types.ValueType_STRING, nil + case *types.Value_BytesVal: + return &x.BytesVal, types.ValueType_BYTES, nil + case *types.Value_Int32Val: + valueBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int32Val)) + return &valueBuffer, types.ValueType_INT32, nil + case *types.Value_Int64Val: + if entityKeySerializationVersion <= 1 { + // We unfortunately have to use 32 bit here for backward compatibility :( + valueBuffer := make([]byte, 4) + binary.LittleEndian.PutUint32(valueBuffer, uint32(x.Int64Val)) + return &valueBuffer, types.ValueType_INT64, nil + } else { + valueBuffer := make([]byte, 8) + binary.LittleEndian.PutUint64(valueBuffer, uint64(x.Int64Val)) + return &valueBuffer, types.ValueType_INT64, nil + } + case nil: + return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) + default: + return nil, types.ValueType_INVALID, fmt.Errorf("could not detect type for %v", x) + } +}