From a38df8c32bf8876ff652cb91db03338c2c526172 Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Fri, 25 Oct 2024 20:40:19 +0200 Subject: [PATCH 1/2] Refactor cassandra scaler Signed-off-by: rickbrouwer --- pkg/scalers/cassandra_scaler.go | 292 ++++++++------------- pkg/scalers/cassandra_scaler_test.go | 368 +++++++++++++++++++-------- 2 files changed, 367 insertions(+), 293 deletions(-) diff --git a/pkg/scalers/cassandra_scaler.go b/pkg/scalers/cassandra_scaler.go index 6e8705d2d8d..24abfc59494 100644 --- a/pkg/scalers/cassandra_scaler.go +++ b/pkg/scalers/cassandra_scaler.go @@ -18,53 +18,66 @@ import ( kedautil "github.com/kedacore/keda/v2/pkg/util" ) -// cassandraScaler exposes a data pointer to CassandraMetadata and gocql.Session connection. type cassandraScaler struct { metricType v2.MetricTargetType - metadata *CassandraMetadata + metadata cassandraMetadata session *gocql.Session logger logr.Logger } -// CassandraMetadata defines metadata used by KEDA to query a Cassandra table. -type CassandraMetadata struct { - username string - password string - enableTLS bool - cert string - key string - ca string - clusterIPAddress string - port int - consistency gocql.Consistency - protocolVersion int - keyspace string - query string - targetQueryValue int64 - activationTargetQueryValue int64 - triggerIndex int +type cassandraMetadata struct { + Username string `keda:"name=username, order=triggerMetadata"` + Password string `keda:"name=password, order=authParams"` + TLS string `keda:"name=tls, order=authParams, enum=enable;disable, default=disable, optional"` + Cert string `keda:"name=cert, order=authParams, optional"` + Key string `keda:"name=key, order=authParams, optional"` + CA string `keda:"name=ca, order=authParams, optional"` + ClusterIPAddress string `keda:"name=clusterIPAddress, order=triggerMetadata"` + Port int `keda:"name=port, order=triggerMetadata, optional"` + Consistency string `keda:"name=consistency, order=triggerMetadata, default=one"` + ProtocolVersion int `keda:"name=protocolVersion, order=triggerMetadata, default=4"` + Keyspace string `keda:"name=keyspace, order=triggerMetadata"` + Query string `keda:"name=query, order=triggerMetadata"` + TargetQueryValue int64 `keda:"name=targetQueryValue, order=triggerMetadata"` + ActivationTargetQueryValue int64 `keda:"name=activationTargetQueryValue, order=triggerMetadata, default=0"` + TriggerIndex int } const ( - tlsEnable = "enable" - tlsDisable = "disable" + tlsEnable = "enable" ) -// NewCassandraScaler creates a new Cassandra scaler. +func (m *cassandraMetadata) Validate() error { + // Handle port in ClusterIPAddress + splitVal := strings.Split(m.ClusterIPAddress, ":") + if len(splitVal) == 2 { + if port, err := strconv.Atoi(splitVal[1]); err == nil { + m.Port = port + return nil + } + } + + if m.Port == 0 { + return fmt.Errorf("no port given") + } + + m.ClusterIPAddress = net.JoinHostPort(m.ClusterIPAddress, fmt.Sprintf("%d", m.Port)) + return nil +} + +// NewCassandraScaler creates a new Cassandra scaler func NewCassandraScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { metricType, err := GetMetricTargetType(config) if err != nil { return nil, fmt.Errorf("error getting scaler metric type: %w", err) } - logger := InitializeLogger(config, "cassandra_scaler") - meta, err := parseCassandraMetadata(config) if err != nil { return nil, fmt.Errorf("error parsing cassandra metadata: %w", err) } - session, err := newCassandraSession(meta, logger) + session, err := newCassandraSession(meta, InitializeLogger(config, "cassandra_scaler")) if err != nil { return nil, fmt.Errorf("error establishing cassandra session: %w", err) } @@ -73,108 +86,32 @@ func NewCassandraScaler(config *scalersconfig.ScalerConfig) (Scaler, error) { metricType: metricType, metadata: meta, session: session, - logger: logger, + logger: InitializeLogger(config, "cassandra_scaler"), }, nil } -// parseCassandraMetadata parses the metadata and returns a CassandraMetadata or an error if the ScalerConfig is invalid. -func parseCassandraMetadata(config *scalersconfig.ScalerConfig) (*CassandraMetadata, error) { - meta := &CassandraMetadata{} - var err error - - if val, ok := config.TriggerMetadata["query"]; ok { - meta.query = val - } else { - return nil, fmt.Errorf("no query given") - } - - if val, ok := config.TriggerMetadata["targetQueryValue"]; ok { - targetQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, fmt.Errorf("targetQueryValue parsing error %w", err) - } - meta.targetQueryValue = targetQueryValue - } else { - if config.AsMetricSource { - meta.targetQueryValue = 0 - } else { - return nil, fmt.Errorf("no targetQueryValue given") - } - } - - meta.activationTargetQueryValue = 0 - if val, ok := config.TriggerMetadata["activationTargetQueryValue"]; ok { - activationTargetQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, fmt.Errorf("activationTargetQueryValue parsing error %w", err) - } - meta.activationTargetQueryValue = activationTargetQueryValue - } - - if val, ok := config.TriggerMetadata["username"]; ok { - meta.username = val - } else { - return nil, fmt.Errorf("no username given") - } - - if val, ok := config.TriggerMetadata["port"]; ok { - port, err := strconv.Atoi(val) - if err != nil { - return nil, fmt.Errorf("port parsing error %w", err) - } - meta.port = port - } - - if val, ok := config.TriggerMetadata["clusterIPAddress"]; ok { - splitval := strings.Split(val, ":") - port := splitval[len(splitval)-1] - - _, err := strconv.Atoi(port) - switch { - case err == nil: - meta.clusterIPAddress = val - case meta.port > 0: - meta.clusterIPAddress = net.JoinHostPort(val, fmt.Sprintf("%d", meta.port)) - default: - return nil, fmt.Errorf("no port given") - } - } else { - return nil, fmt.Errorf("no cluster IP address given") - } - - if val, ok := config.TriggerMetadata["protocolVersion"]; ok { - protocolVersion, err := strconv.Atoi(val) - if err != nil { - return nil, fmt.Errorf("protocolVersion parsing error %w", err) - } - meta.protocolVersion = protocolVersion - } else { - meta.protocolVersion = 4 +func parseCassandraMetadata(config *scalersconfig.ScalerConfig) (cassandraMetadata, error) { + meta := cassandraMetadata{} + err := config.TypedConfig(&meta) + if err != nil { + return meta, fmt.Errorf("error parsing cassandra metadata: %w", err) } - if val, ok := config.TriggerMetadata["consistency"]; ok { - meta.consistency = gocql.ParseConsistency(val) - } else { - meta.consistency = gocql.One + if config.AsMetricSource { + meta.TargetQueryValue = 0 } - if val, ok := config.TriggerMetadata["keyspace"]; ok { - meta.keyspace = val - } else { - return nil, fmt.Errorf("no keyspace given") - } - if val, ok := config.AuthParams["password"]; ok { - meta.password = val - } else { - return nil, fmt.Errorf("no password given") + err = meta.Validate() + if err != nil { + return meta, err } - if err = parseCassandraTLS(config, meta); err != nil { + err = parseCassandraTLS(&meta) + if err != nil { return meta, err } - meta.triggerIndex = config.TriggerIndex - + meta.TriggerIndex = config.TriggerIndex return meta, nil } @@ -182,8 +119,8 @@ func createTempFile(prefix string, content string) (string, error) { tempCassandraDir := fmt.Sprintf("%s%c%s", os.TempDir(), os.PathSeparator, "cassandra") err := os.MkdirAll(tempCassandraDir, 0700) if err != nil { - return "", fmt.Errorf(`error creating temporary directory: %s. Error: %w - Note, when running in a container a writable /tmp/cassandra emptyDir must be mounted. Refer to documentation`, tempCassandraDir, err) + return "", fmt.Errorf(`error creating temporary directory: %s. Error: %w + Note, when running in a container a writable /tmp/cassandra emptyDir must be mounted. Refer to documentation`, tempCassandraDir, err) } f, err := os.CreateTemp(tempCassandraDir, prefix+"-*.pem") @@ -200,72 +137,52 @@ func createTempFile(prefix string, content string) (string, error) { return f.Name(), nil } -func parseCassandraTLS(config *scalersconfig.ScalerConfig, meta *CassandraMetadata) error { - meta.enableTLS = false - if val, ok := config.AuthParams["tls"]; ok { - val = strings.TrimSpace(val) - if val == tlsEnable { - certGiven := config.AuthParams["cert"] != "" - keyGiven := config.AuthParams["key"] != "" - caCertGiven := config.AuthParams["ca"] != "" - if certGiven && !keyGiven { - return errors.New("no key given") - } - if keyGiven && !certGiven { - return errors.New("no cert given") - } - if !keyGiven && !certGiven { - return errors.New("no cert/key given") - } +func parseCassandraTLS(meta *cassandraMetadata) error { + if meta.TLS == tlsEnable { + if meta.Cert == "" || meta.Key == "" { + return errors.New("both cert and key are required when TLS is enabled") + } - certFilePath, err := createTempFile("cert", config.AuthParams["cert"]) - if err != nil { - // handle error - return errors.New("Error creating cert file: " + err.Error()) - } + // Create temp files for certs + certFilePath, err := createTempFile("cert", meta.Cert) + if err != nil { + return fmt.Errorf("error creating cert file: %w", err) + } + meta.Cert = certFilePath - keyFilePath, err := createTempFile("key", config.AuthParams["key"]) - if err != nil { - // handle error - return errors.New("Error creating key file: " + err.Error()) - } + keyFilePath, err := createTempFile("key", meta.Key) + if err != nil { + return fmt.Errorf("error creating key file: %w", err) + } + meta.Key = keyFilePath - meta.cert = certFilePath - meta.key = keyFilePath - meta.ca = config.AuthParams["ca"] - if !caCertGiven { - meta.ca = "" - } else { - caCertFilePath, err := createTempFile("caCert", config.AuthParams["ca"]) - meta.ca = caCertFilePath - if err != nil { - // handle error - return errors.New("Error creating ca file: " + err.Error()) - } + // If CA cert is given, make also file + if meta.CA != "" { + caCertFilePath, err := createTempFile("caCert", meta.CA) + if err != nil { + return fmt.Errorf("error creating ca file: %w", err) } - meta.enableTLS = true - } else if val != tlsDisable { - return fmt.Errorf("err incorrect value for TLS given: %s", val) + meta.CA = caCertFilePath } } return nil } -// newCassandraSession returns a new Cassandra session for the provided CassandraMetadata. -func newCassandraSession(meta *CassandraMetadata, logger logr.Logger) (*gocql.Session, error) { - cluster := gocql.NewCluster(meta.clusterIPAddress) - cluster.ProtoVersion = meta.protocolVersion - cluster.Consistency = meta.consistency +// newCassandraSession returns a new Cassandra session for the provided CassandraMetadata +func newCassandraSession(meta cassandraMetadata, logger logr.Logger) (*gocql.Session, error) { + cluster := gocql.NewCluster(meta.ClusterIPAddress) + cluster.ProtoVersion = meta.ProtocolVersion + cluster.Consistency = gocql.ParseConsistency(meta.Consistency) cluster.Authenticator = gocql.PasswordAuthenticator{ - Username: meta.username, - Password: meta.password, + Username: meta.Username, + Password: meta.Password, } - if meta.enableTLS { + if meta.TLS == tlsEnable { cluster.SslOpts = &gocql.SslOptions{ - CertPath: meta.cert, - KeyPath: meta.key, - CaPath: meta.ca, + CertPath: meta.Cert, + KeyPath: meta.Key, + CaPath: meta.CA, } } @@ -278,22 +195,19 @@ func newCassandraSession(meta *CassandraMetadata, logger logr.Logger) (*gocql.Se return session, nil } -// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler. +// GetMetricSpecForScaling returns the MetricSpec for the Horizontal Pod Autoscaler func (s *cassandraScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("cassandra-%s", s.metadata.keyspace))), + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, kedautil.NormalizeString(fmt.Sprintf("cassandra-%s", s.metadata.Keyspace))), }, - Target: GetMetricTarget(s.metricType, s.metadata.targetQueryValue), - } - metricSpec := v2.MetricSpec{ - External: externalMetric, Type: externalMetricType, + Target: GetMetricTarget(s.metricType, s.metadata.TargetQueryValue), } - + metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType} return []v2.MetricSpec{metricSpec} } -// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric. +// GetMetricsAndActivity returns a value for a supported metric or an error if there is a problem getting the metric func (s *cassandraScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { num, err := s.GetQueryResult(ctx) if err != nil { @@ -301,38 +215,36 @@ func (s *cassandraScaler) GetMetricsAndActivity(ctx context.Context, metricName } metric := GenerateMetricInMili(metricName, float64(num)) - - return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetQueryValue, nil + return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetQueryValue, nil } -// GetQueryResult returns the result of the scaler query. +// GetQueryResult returns the result of the scaler query func (s *cassandraScaler) GetQueryResult(ctx context.Context) (int64, error) { var value int64 - if err := s.session.Query(s.metadata.query).WithContext(ctx).Scan(&value); err != nil { + if err := s.session.Query(s.metadata.Query).WithContext(ctx).Scan(&value); err != nil { if err != gocql.ErrNotFound { s.logger.Error(err, "query failed") return 0, err } } - return value, nil } -// Close closes the Cassandra session connection. +// Close closes the Cassandra session connection func (s *cassandraScaler) Close(_ context.Context) error { // clean up any temporary files - if strings.TrimSpace(s.metadata.cert) != "" { - if err := os.Remove(s.metadata.cert); err != nil { + if s.metadata.Cert != "" { + if err := os.Remove(s.metadata.Cert); err != nil { return err } } - if strings.TrimSpace(s.metadata.key) != "" { - if err := os.Remove(s.metadata.key); err != nil { + if s.metadata.Key != "" { + if err := os.Remove(s.metadata.Key); err != nil { return err } } - if strings.TrimSpace(s.metadata.ca) != "" { - if err := os.Remove(s.metadata.ca); err != nil { + if s.metadata.CA != "" { + if err := os.Remove(s.metadata.CA); err != nil { return err } } diff --git a/pkg/scalers/cassandra_scaler_test.go b/pkg/scalers/cassandra_scaler_test.go index 39930946a56..d2e892b8c32 100644 --- a/pkg/scalers/cassandra_scaler_test.go +++ b/pkg/scalers/cassandra_scaler_test.go @@ -2,156 +2,318 @@ package scalers import ( "context" - "fmt" "os" "testing" "github.com/go-logr/logr" "github.com/gocql/gocql" + "github.com/stretchr/testify/assert" + v2 "k8s.io/api/autoscaling/v2" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) type parseCassandraMetadataTestData struct { + name string metadata map[string]string - isError bool authParams map[string]string + isError bool } type parseCassandraTLSTestData struct { + name string authParams map[string]string isError bool - enableTLS bool + tlsEnabled bool } type cassandraMetricIdentifier struct { + name string metadataTestData *parseCassandraMetadataTestData triggerIndex int - name string + metricName string } var testCassandraMetadata = []parseCassandraMetadataTestData{ - // nothing passed - {map[string]string{}, true, map[string]string{}}, - // everything is passed in verbatim - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "port": "9042", "clusterIPAddress": "cassandra.test", "keyspace": "test_keyspace", "TriggerIndex": "0"}, false, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // metricName is generated from keyspace - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, false, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no query passed - {map[string]string{"targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no targetQueryValue passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no username passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no port passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no clusterIPAddress passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "port": "9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no keyspace passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "TriggerIndex": "0"}, true, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, - // no password passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"}, true, map[string]string{}}, - // fix issue[4110] passed - {map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "port": "9042", "clusterIPAddress": "https://cassandra.test", "keyspace": "test_keyspace", "TriggerIndex": "0"}, false, map[string]string{"password": "Y2Fzc2FuZHJhCg=="}}, + { + name: "nothing passed", + metadata: map[string]string{}, + authParams: map[string]string{}, + isError: true, + }, + { + name: "everything passed verbatim", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "port": "9042", + "clusterIPAddress": "cassandra.test", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: false, + }, + { + name: "metricName from keyspace", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: false, + }, + { + name: "no query", + metadata: map[string]string{ + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no targetQueryValue", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no username", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no port", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no clusterIPAddress", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "port": "9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no keyspace", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: true, + }, + { + name: "no password", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{}, + isError: true, + }, + { + name: "with https prefix", + metadata: map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "port": "9042", + "clusterIPAddress": "https://cassandra.test", + "keyspace": "test_keyspace", + }, + authParams: map[string]string{"password": "Y2Fzc2FuZHJhCg=="}, + isError: false, + }, } var tlsAuthParamsTestData = []parseCassandraTLSTestData{ - // success, TLS cert/key - {map[string]string{"tls": "enable", "cert": "ceert", "key": "keey", "password": "Y2Fzc2FuZHJhCg=="}, false, true}, - // failure, TLS missing cert - {map[string]string{"tls": "enable", "key": "keey", "password": "Y2Fzc2FuZHJhCg=="}, true, false}, - // failure, TLS missing key - {map[string]string{"tls": "enable", "cert": "ceert", "password": "Y2Fzc2FuZHJhCg=="}, true, false}, - // failure, TLS invalid - {map[string]string{"tls": "yes", "cert": "ceert", "key": "keeey", "password": "Y2Fzc2FuZHJhCg=="}, true, false}, + { + name: "success with cert/key", + authParams: map[string]string{ + "tls": "enable", + "cert": "test-cert", + "key": "test-key", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: false, + tlsEnabled: true, + }, + { + name: "failure missing cert", + authParams: map[string]string{ + "tls": "enable", + "key": "test-key", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: true, + tlsEnabled: false, + }, + { + name: "failure missing key", + authParams: map[string]string{ + "tls": "enable", + "cert": "test-cert", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: true, + tlsEnabled: false, + }, + { + name: "failure invalid tls value", + authParams: map[string]string{ + "tls": "yes", + "cert": "test-cert", + "key": "test-key", + "password": "Y2Fzc2FuZHJhCg==", + }, + isError: true, + tlsEnabled: false, + }, } var cassandraMetricIdentifiers = []cassandraMetricIdentifier{ - {&testCassandraMetadata[1], 0, "s0-cassandra-test_keyspace"}, - {&testCassandraMetadata[2], 1, "s1-cassandra-test_keyspace"}, + { + name: "everything passed verbatim", + metadataTestData: &testCassandraMetadata[1], + triggerIndex: 0, + metricName: "s0-cassandra-test_keyspace", + }, + { + name: "metricName from keyspace", + metadataTestData: &testCassandraMetadata[2], + triggerIndex: 1, + metricName: "s1-cassandra-test_keyspace", + }, +} + +var successMetaData = map[string]string{ + "query": "SELECT COUNT(*) FROM test_keyspace.test_table;", + "targetQueryValue": "1", + "username": "cassandra", + "clusterIPAddress": "cassandra.test:9042", + "keyspace": "test_keyspace", } func TestCassandraParseMetadata(t *testing.T) { - testCaseNum := 1 for _, testData := range testCassandraMetadata { - _, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) - if err != nil && !testData.isError { - t.Errorf("Expected success but got error for unit test # %v", testCaseNum) - } - if testData.isError && err == nil { - t.Errorf("Expected error but got success for unit test # %v", testCaseNum) - } - testCaseNum++ + t.Run(testData.name, func(t *testing.T) { + _, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadata, + AuthParams: testData.authParams, + }) + if err != nil && !testData.isError { + t.Error("Expected success but got error", err) + } + if testData.isError && err == nil { + t.Error("Expected error but got success") + } + }) } } func TestCassandraGetMetricSpecForScaling(t *testing.T) { for _, testData := range cassandraMetricIdentifiers { - meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex, AuthParams: testData.metadataTestData.authParams}) - if err != nil { - t.Fatal("Could not parse metadata:", err) - } - cluster := gocql.NewCluster(meta.clusterIPAddress) - session, _ := cluster.CreateSession() - mockCassandraScaler := cassandraScaler{"", meta, session, logr.Discard()} - - metricSpec := mockCassandraScaler.GetMetricSpecForScaling(context.Background()) - metricName := metricSpec[0].External.Metric.Name - if metricName != testData.name { - t.Errorf("Wrong External metric source name: %s, expected: %s", metricName, testData.name) - } - } -} + t.Run(testData.name, func(t *testing.T) { + meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: testData.metadataTestData.metadata, + TriggerIndex: testData.triggerIndex, + AuthParams: testData.metadataTestData.authParams, + }) + if err != nil { + t.Fatal("Could not parse metadata:", err) + } + mockCassandraScaler := cassandraScaler{ + metricType: v2.AverageValueMetricType, + metadata: meta, + session: &gocql.Session{}, + logger: logr.Discard(), + } -func assertCertContents(testData parseCassandraTLSTestData, meta *CassandraMetadata, prop string) error { - if testData.authParams[prop] != "" { - var path string - switch prop { - case "cert": - path = meta.cert - case "key": - path = meta.key - } - data, err := os.ReadFile(path) - if err != nil { - return fmt.Errorf("expected to find '%v' file at %v", prop, path) - } - contents := string(data) - if contents != testData.authParams[prop] { - return fmt.Errorf("expected value: '%v' but got '%v'", testData.authParams[prop], contents) - } + metricSpec := mockCassandraScaler.GetMetricSpecForScaling(context.Background()) + metricName := metricSpec[0].External.Metric.Name + assert.Equal(t, testData.metricName, metricName) + }) } - return nil } -var successMetaData = map[string]string{"query": "SELECT COUNT(*) FROM test_keyspace.test_table;", "targetQueryValue": "1", "username": "cassandra", "clusterIPAddress": "cassandra.test:9042", "keyspace": "test_keyspace", "TriggerIndex": "0"} - func TestParseCassandraTLS(t *testing.T) { for _, testData := range tlsAuthParamsTestData { - meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{TriggerMetadata: successMetaData, AuthParams: testData.authParams}) - - if err != nil && !testData.isError { - t.Error("Expected success but got error", err) - } - if testData.isError && err == nil { - t.Error("Expected error but got success") - } - if meta.enableTLS != testData.enableTLS { - t.Errorf("Expected enableTLS to be set to %v but got %v\n", testData.enableTLS, meta.enableTLS) - } - if meta.enableTLS { - if meta.cert != testData.authParams["cert"] { - err := assertCertContents(testData, meta, "cert") - if err != nil { - t.Errorf(err.Error()) - } - } - if meta.key != testData.authParams["key"] { - err := assertCertContents(testData, meta, "key") - if err != nil { - t.Errorf(err.Error()) + t.Run(testData.name, func(t *testing.T) { + meta, err := parseCassandraMetadata(&scalersconfig.ScalerConfig{ + TriggerMetadata: successMetaData, + AuthParams: testData.authParams, + }) + + if testData.isError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, testData.tlsEnabled, meta.TLS == "enable") + + if meta.TLS == "enable" { + // Verify cert contents + if testData.authParams["cert"] != "" { + data, err := os.ReadFile(meta.Cert) + assert.NoError(t, err) + assert.Equal(t, testData.authParams["cert"], string(data)) + // Cleanup + defer os.Remove(meta.Cert) + } + + // Verify key contents + if testData.authParams["key"] != "" { + data, err := os.ReadFile(meta.Key) + assert.NoError(t, err) + assert.Equal(t, testData.authParams["key"], string(data)) + // Cleanup + defer os.Remove(meta.Key) + } + + // Verify CA contents if present + if testData.authParams["ca"] != "" { + data, err := os.ReadFile(meta.CA) + assert.NoError(t, err) + assert.Equal(t, testData.authParams["ca"], string(data)) + // Cleanup + defer os.Remove(meta.CA) + } } } - } + }) } } From e8996a85e6be1fdde52166a20b49f46b8a33fb65 Mon Sep 17 00:00:00 2001 From: rickbrouwer Date: Tue, 29 Oct 2024 10:22:14 +0100 Subject: [PATCH 2/2] Update feedback Signed-off-by: rickbrouwer --- pkg/scalers/cassandra_scaler.go | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/pkg/scalers/cassandra_scaler.go b/pkg/scalers/cassandra_scaler.go index 24abfc59494..b41dddb9dec 100644 --- a/pkg/scalers/cassandra_scaler.go +++ b/pkg/scalers/cassandra_scaler.go @@ -48,6 +48,10 @@ const ( ) func (m *cassandraMetadata) Validate() error { + if m.TLS == tlsEnable && (m.Cert == "" || m.Key == "") { + return errors.New("both cert and key are required when TLS is enabled") + } + // Handle port in ClusterIPAddress splitVal := strings.Split(m.ClusterIPAddress, ":") if len(splitVal) == 2 { @@ -101,11 +105,6 @@ func parseCassandraMetadata(config *scalersconfig.ScalerConfig) (cassandraMetada meta.TargetQueryValue = 0 } - err = meta.Validate() - if err != nil { - return meta, err - } - err = parseCassandraTLS(&meta) if err != nil { return meta, err @@ -139,10 +138,6 @@ func createTempFile(prefix string, content string) (string, error) { func parseCassandraTLS(meta *cassandraMetadata) error { if meta.TLS == tlsEnable { - if meta.Cert == "" || meta.Key == "" { - return errors.New("both cert and key are required when TLS is enabled") - } - // Create temp files for certs certFilePath, err := createTempFile("cert", meta.Cert) if err != nil {