diff --git a/pkg/scalers/mongo_scaler.go b/pkg/scalers/mongo_scaler.go index f30b8fb97ec..09d8aea056e 100644 --- a/pkg/scalers/mongo_scaler.go +++ b/pkg/scalers/mongo_scaler.go @@ -6,8 +6,6 @@ import ( "fmt" "net" "net/url" - "strconv" - "strings" "time" "github.com/go-logr/logr" @@ -22,60 +20,45 @@ import ( kedautil "github.com/kedacore/keda/v2/pkg/util" ) -// mongoDBScaler is support for mongoDB in keda. type mongoDBScaler struct { metricType v2.MetricTargetType - metadata *mongoDBMetadata + metadata mongoDBMetadata client *mongo.Client logger logr.Logger } -// mongoDBMetadata specify mongoDB scaler params. type mongoDBMetadata struct { - // The string is used by connected with mongoDB. - // +optional - connectionString string - // Specify the prefix to connect to the mongoDB server, default value `mongodb`, if the connectionString be provided, don't need to specify this param. - // +optional - scheme string - // Specify the host to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - host string - // Specify the port to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - port string - // Specify the username to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - username string - // Specify the password to connect to the mongoDB server,if the connectionString be provided, don't need to specify this param. - // +optional - password string - - // The name of the database to be queried. - // +required - dbName string - // The name of the collection to be queried. - // +required - collection string - // A mongoDB filter doc,used by specify DB. - // +required - query string - // A threshold that is used as targetAverageValue in HPA - // +required - queryValue int64 - // A threshold that is used to check if scaler is active - // +optional - activationQueryValue int64 - - // The index of the scaler inside the ScaledObject - // +internal - triggerIndex int + ConnectionString string `keda:"name=connectionString,order=authParams;triggerMetadata;resolvedEnv,optional"` + Scheme string `keda:"name=scheme,order=authParams;triggerMetadata,default=mongodb,optional"` + Host string `keda:"name=host,order=authParams;triggerMetadata,optional"` + Port string `keda:"name=port,order=authParams;triggerMetadata,optional"` + Username string `keda:"name=username,order=authParams;triggerMetadata,optional"` + Password string `keda:"name=password,order=authParams;triggerMetadata;resolvedEnv,optional"` + DBName string `keda:"name=dbName,order=authParams;triggerMetadata"` + Collection string `keda:"name=collection,order=triggerMetadata"` + Query string `keda:"name=query,order=triggerMetadata"` + QueryValue int64 `keda:"name=queryValue,order=triggerMetadata"` + ActivationQueryValue int64 `keda:"name=activationQueryValue,order=triggerMetadata,default=0"` + TriggerIndex int } -// Default variables and settings -const ( - mongoDBDefaultTimeOut = 10 * time.Second -) +func (m *mongoDBMetadata) Validate() error { + if m.ConnectionString == "" { + if m.Host == "" { + return fmt.Errorf("no host given") + } + if m.Port == "" && m.Scheme != "mongodb+srv" { + return fmt.Errorf("no port given") + } + if m.Username == "" { + return fmt.Errorf("no username given") + } + if m.Password == "" { + return fmt.Errorf("no password given") + } + } + return nil +} // NewMongoDBScaler creates a new mongoDB scaler func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) (Scaler, error) { @@ -84,22 +67,14 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) ( return nil, fmt.Errorf("error getting scaler metric type: %w", err) } - ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut) - defer cancel() - - meta, connStr, err := parseMongoDBMetadata(config) + meta, err := parseMongoDBMetadata(config) if err != nil { - return nil, fmt.Errorf("failed to parsing mongoDB metadata, because of %w", err) + return nil, fmt.Errorf("error parsing mongodb metadata: %w", err) } - opt := options.Client().ApplyURI(connStr) - client, err := mongo.Connect(ctx, opt) + client, err := createMongoDBClient(ctx, meta) if err != nil { - return nil, fmt.Errorf("failed to establish connection with mongoDB, because of %w", err) - } - - if err = client.Ping(ctx, readpref.Primary()); err != nil { - return nil, fmt.Errorf("failed to ping mongoDB, because of %w", err) + return nil, fmt.Errorf("error creating mongodb client: %w", err) } return &mongoDBScaler{ @@ -110,171 +85,107 @@ func NewMongoDBScaler(ctx context.Context, config *scalersconfig.ScalerConfig) ( }, nil } -func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (*mongoDBMetadata, string, error) { - var connStr string - var err error - // setting default metadata +func parseMongoDBMetadata(config *scalersconfig.ScalerConfig) (mongoDBMetadata, error) { meta := mongoDBMetadata{} - - // parse metaData from ScaledJob config - if val, ok := config.TriggerMetadata["collection"]; ok { - meta.collection = val - } else { - return nil, "", fmt.Errorf("no collection given") + err := config.TypedConfig(&meta) + if err != nil { + return meta, fmt.Errorf("error parsing mongodb metadata: %w", err) } - if val, ok := config.TriggerMetadata["query"]; ok { - meta.query = val - } else { - return nil, "", fmt.Errorf("no query given") - } + meta.TriggerIndex = config.TriggerIndex + return meta, nil +} - if val, ok := config.TriggerMetadata["queryValue"]; ok { - queryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err) - } - meta.queryValue = queryValue +func createMongoDBClient(ctx context.Context, meta mongoDBMetadata) (*mongo.Client, error) { + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + var connString string + if meta.ConnectionString != "" { + connString = meta.ConnectionString } else { - if config.AsMetricSource { - meta.queryValue = 0 + if meta.Scheme == "mongodb+srv" { + u := &url.URL{ + Scheme: meta.Scheme, + User: url.UserPassword(meta.Username, meta.Password), + Host: meta.Host, + Path: meta.DBName, + } + connString = u.String() } else { - return nil, "", fmt.Errorf("no queryValue given") - } - } - - meta.activationQueryValue = 0 - if val, ok := config.TriggerMetadata["activationQueryValue"]; ok { - activationQueryValue, err := strconv.ParseInt(val, 10, 64) - if err != nil { - return nil, "", fmt.Errorf("failed to convert %v to int, because of %w", val, err) + u := &url.URL{ + Scheme: meta.Scheme, + User: url.UserPassword(meta.Username, meta.Password), + Host: net.JoinHostPort(meta.Host, meta.Port), + Path: meta.DBName, + } + connString = u.String() } - meta.activationQueryValue = activationQueryValue } - dbName, err := GetFromAuthOrMeta(config, "dbName") + client, err := mongo.Connect(ctx, options.Client().ApplyURI(connString)) if err != nil { - return nil, "", err + return nil, fmt.Errorf("failed to create mongodb client: %w", err) } - meta.dbName = dbName - // Resolve connectionString - switch { - case config.AuthParams["connectionString"] != "": - meta.connectionString = config.AuthParams["connectionString"] - case config.TriggerMetadata["connectionStringFromEnv"] != "": - meta.connectionString = config.ResolvedEnv[config.TriggerMetadata["connectionStringFromEnv"]] - default: - meta.connectionString = "" - scheme, err := GetFromAuthOrMeta(config, "scheme") - if err != nil { - meta.scheme = "mongodb" - } else { - meta.scheme = scheme - } - - host, err := GetFromAuthOrMeta(config, "host") - if err != nil { - return nil, "", err - } - meta.host = host - - if !strings.Contains(scheme, "mongodb+srv") { - port, err := GetFromAuthOrMeta(config, "port") - if err != nil { - return nil, "", err - } - meta.port = port - } - - username, err := GetFromAuthOrMeta(config, "username") - if err != nil { - return nil, "", err - } - meta.username = username - - if config.AuthParams["password"] != "" { - meta.password = config.AuthParams["password"] - } else if config.TriggerMetadata["passwordFromEnv"] != "" { - meta.password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]] - } - if len(meta.password) == 0 { - return nil, "", fmt.Errorf("no password given") - } - } - - switch { - case meta.connectionString != "": - connStr = meta.connectionString - case meta.scheme == "mongodb+srv": - // nosemgrep: db-connection-string - connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), meta.host, meta.dbName) - default: - addr := net.JoinHostPort(meta.host, meta.port) - // nosemgrep: db-connection-string - connStr = fmt.Sprintf("%s://%s:%s@%s/%s", meta.scheme, url.QueryEscape(meta.username), url.QueryEscape(meta.password), addr, meta.dbName) + err = client.Ping(ctx, readpref.Primary()) + if err != nil { + return nil, fmt.Errorf("failed to ping mongodb: %w", err) } - meta.triggerIndex = config.TriggerIndex - return &meta, connStr, nil + return client, nil } -// Close disposes of mongoDB connections func (s *mongoDBScaler) Close(ctx context.Context) error { if s.client != nil { err := s.client.Disconnect(ctx) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to close mongoDB connection, because of %v", err)) + s.logger.Error(err, "Error closing mongodb connection") return err } } - return nil } -// getQueryResult query mongoDB by meta.query func (s *mongoDBScaler) getQueryResult(ctx context.Context) (int64, error) { - ctx, cancel := context.WithTimeout(ctx, mongoDBDefaultTimeOut) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - filter, err := json2BsonDoc(s.metadata.query) + collection := s.client.Database(s.metadata.DBName).Collection(s.metadata.Collection) + + filter, err := json2BsonDoc(s.metadata.Query) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to convert query param to bson.Doc, because of %v", err)) - return 0, err + return 0, fmt.Errorf("failed to parse query: %w", err) } - docsNum, err := s.client.Database(s.metadata.dbName).Collection(s.metadata.collection).CountDocuments(ctx, filter) + count, err := collection.CountDocuments(ctx, filter) if err != nil { - s.logger.Error(err, fmt.Sprintf("failed to query %v in %v, because of %v", s.metadata.dbName, s.metadata.collection, err)) - return 0, err + return 0, fmt.Errorf("failed to execute query: %w", err) } - return docsNum, nil + return count, nil } -// GetMetricsAndActivity query from mongoDB,and return to external metrics func (s *mongoDBScaler) GetMetricsAndActivity(ctx context.Context, metricName string) ([]external_metrics.ExternalMetricValue, bool, error) { num, err := s.getQueryResult(ctx) if err != nil { - return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect momgoDB, because of %w", err) + return []external_metrics.ExternalMetricValue{}, false, fmt.Errorf("failed to inspect mongodb: %w", err) } metric := GenerateMetricInMili(metricName, float64(num)) - return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationQueryValue, nil + return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationQueryValue, nil } -// GetMetricSpecForScaling get the query value for scaling func (s *mongoDBScaler) GetMetricSpecForScaling(context.Context) []v2.MetricSpec { + metricName := kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.Collection)) externalMetric := &v2.ExternalMetricSource{ Metric: v2.MetricIdentifier{ - Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString(fmt.Sprintf("mongodb-%s", s.metadata.collection))), + Name: GenerateMetricNameWithIndex(s.metadata.TriggerIndex, metricName), }, - Target: GetMetricTarget(s.metricType, s.metadata.queryValue), - } - metricSpec := v2.MetricSpec{ - External: externalMetric, Type: externalMetricType, + Target: GetMetricTarget(s.metricType, s.metadata.QueryValue), } + metricSpec := v2.MetricSpec{External: externalMetric, Type: externalMetricType} return []v2.MetricSpec{metricSpec} } diff --git a/pkg/scalers/mongo_scaler_test.go b/pkg/scalers/mongo_scaler_test.go index fd9f54f8337..52b63cdaa20 100644 --- a/pkg/scalers/mongo_scaler_test.go +++ b/pkg/scalers/mongo_scaler_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/go-logr/logr" - "github.com/stretchr/testify/assert" "go.mongodb.org/mongo-driver/mongo" + v2 "k8s.io/api/autoscaling/v2" "github.com/kedacore/keda/v2/pkg/scalers/scalersconfig" ) @@ -100,33 +100,57 @@ var mongoDBMetricIdentifiers = []mongoDBMetricIdentifier{ func TestParseMongoDBMetadata(t *testing.T) { for _, testData := range testMONGODBMetadata { - _, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) + meta, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.resolvedEnv, TriggerMetadata: testData.metadata, AuthParams: testData.authParams}) if err != nil && !testData.raisesError { t.Error("Expected success but got error:", err) } if err == nil && testData.raisesError { t.Error("Expected error but got success") } + if err == nil { + err = meta.Validate() + if err != nil && !testData.raisesError { + t.Error("Expected success but got error:", err) + } + if err == nil && testData.raisesError { + t.Error("Expected error but got success") + } + } } } func TestParseMongoDBConnectionString(t *testing.T) { for _, testData := range mongoDBConnectionStringTestDatas { - _, connStr, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, TriggerMetadata: testData.metadataTestData.metadata, AuthParams: testData.metadataTestData.authParams}) + meta, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ + ResolvedEnv: testData.metadataTestData.resolvedEnv, + TriggerMetadata: testData.metadataTestData.metadata, + AuthParams: testData.metadataTestData.authParams, + }) + if err != nil { + t.Error("Expected success but got error:", err) + continue + } + + client, err := createMongoDBClient(context.Background(), meta) if err != nil { t.Error("Expected success but got error:", err) + continue + } + + err = client.Disconnect(context.Background()) + if err != nil { + t.Errorf("Failed to disconnect client: %v", err) } - assert.Equal(t, testData.connectionString, connStr) } } func TestMongoDBGetMetricSpecForScaling(t *testing.T) { for _, testData := range mongoDBMetricIdentifiers { - meta, _, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex}) + meta, err := parseMongoDBMetadata(&scalersconfig.ScalerConfig{ResolvedEnv: testData.metadataTestData.resolvedEnv, AuthParams: testData.metadataTestData.authParams, TriggerMetadata: testData.metadataTestData.metadata, TriggerIndex: testData.triggerIndex}) if err != nil { t.Fatal("Could not parse metadata:", err) } - mockMongoDBScaler := mongoDBScaler{"", meta, &mongo.Client{}, logr.Discard()} + mockMongoDBScaler := mongoDBScaler{metricType: v2.AverageValueMetricType, metadata: meta, client: &mongo.Client{}, logger: logr.Discard()} metricSpec := mockMongoDBScaler.GetMetricSpecForScaling(context.Background()) metricName := metricSpec[0].External.Metric.Name