Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor postgresql scaler config #6262

Merged
merged 1 commit into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 63 additions & 95 deletions pkg/scalers/postgresql_scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -42,12 +41,46 @@ type postgreSQLScaler struct {
}

type postgreSQLMetadata struct {
targetQueryValue float64
activationTargetQueryValue float64
connection string
query string
TargetQueryValue float64 `keda:"name=targetQueryValue, order=triggerMetadata, optional"`
ActivationTargetQueryValue float64 `keda:"name=activationTargetQueryValue, order=triggerMetadata, optional"`
Connection string `keda:"name=connection, order=authParams;resolvedEnv, optional"`
Query string `keda:"name=query, order=triggerMetadata"`
triggerIndex int
azureAuthContext azureAuthContext

Host string `keda:"name=host, order=authParams;triggerMetadata, optional"`
Port string `keda:"name=port, order=authParams;triggerMetadata, optional"`
UserName string `keda:"name=userName, order=authParams;triggerMetadata, optional"`
DBName string `keda:"name=dbName, order=authParams;triggerMetadata, optional"`
SslMode string `keda:"name=sslmode, order=authParams;triggerMetadata, optional"`

Password string `keda:"name=password, order=authParams;resolvedEnv, optional"`
}

func (p *postgreSQLMetadata) Validate() error {
if p.Connection == "" {
wozniakjan marked this conversation as resolved.
Show resolved Hide resolved
if p.Host == "" {
return fmt.Errorf("no host given")
}

if p.Port == "" {
return fmt.Errorf("no port given")
}

if p.UserName == "" {
return fmt.Errorf("no userName given")
}

if p.DBName == "" {
return fmt.Errorf("no dbName given")
}

if p.SslMode == "" {
return fmt.Errorf("no sslmode given")
}
}

return nil
}

type azureAuthContext struct {
Expand Down Expand Up @@ -83,66 +116,26 @@ func NewPostgreSQLScaler(ctx context.Context, config *scalersconfig.ScalerConfig
}

func parsePostgreSQLMetadata(logger logr.Logger, config *scalersconfig.ScalerConfig) (*postgreSQLMetadata, kedav1alpha1.AuthPodIdentity, error) {
meta := postgreSQLMetadata{}

meta := &postgreSQLMetadata{}
authPodIdentity := kedav1alpha1.AuthPodIdentity{}

if val, ok := config.TriggerMetadata["query"]; ok {
meta.query = val
} else {
return nil, authPodIdentity, fmt.Errorf("no query given")
}

if val, ok := config.TriggerMetadata["targetQueryValue"]; ok {
targetQueryValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("queryValue parsing error %w", err)
}
meta.targetQueryValue = targetQueryValue
} else {
if config.AsMetricSource {
meta.targetQueryValue = 0
} else {
return nil, authPodIdentity, fmt.Errorf("no targetQueryValue given")
}
meta.triggerIndex = config.TriggerIndex
if err := config.TypedConfig(meta); err != nil {
return nil, authPodIdentity, fmt.Errorf("error parsing postgresql metadata: %w", err)
}

meta.activationTargetQueryValue = 0
if val, ok := config.TriggerMetadata["activationTargetQueryValue"]; ok {
activationTargetQueryValue, err := strconv.ParseFloat(val, 64)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("activationTargetQueryValue parsing error %w", err)
}
meta.activationTargetQueryValue = activationTargetQueryValue
if !config.AsMetricSource && meta.TargetQueryValue == 0 {
wozniakjan marked this conversation as resolved.
Show resolved Hide resolved
return nil, authPodIdentity, fmt.Errorf("no targetQueryValue given")
}

switch config.PodIdentity.Provider {
case "", kedav1alpha1.PodIdentityProviderNone:
switch {
case config.AuthParams["connection"] != "":
meta.connection = config.AuthParams["connection"]
case config.TriggerMetadata["connectionFromEnv"] != "":
meta.connection = config.ResolvedEnv[config.TriggerMetadata["connectionFromEnv"]]
default:
params, err := buildConnArray(config)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("failed to parse fields related to the connection")
}

var password string
if config.AuthParams["password"] != "" {
password = config.AuthParams["password"]
} else if config.TriggerMetadata["passwordFromEnv"] != "" {
password = config.ResolvedEnv[config.TriggerMetadata["passwordFromEnv"]]
}
params = append(params, "password="+escapePostgreConnectionParameter(password))
meta.connection = strings.Join(params, " ")
if meta.Connection == "" {
params := buildConnArray(meta)
params = append(params, "password="+escapePostgreConnectionParameter(meta.Password))
meta.Connection = strings.Join(params, " ")
}
case kedav1alpha1.PodIdentityProviderAzureWorkload:
params, err := buildConnArray(config)
if err != nil {
return nil, authPodIdentity, fmt.Errorf("failed to parse fields related to the connection")
}
params := buildConnArray(meta)

cred, err := azure.NewChainedCredential(logger, config.PodIdentity)
if err != nil {
Expand All @@ -152,59 +145,34 @@ func parsePostgreSQLMetadata(logger logr.Logger, config *scalersconfig.ScalerCon
authPodIdentity = kedav1alpha1.AuthPodIdentity{Provider: config.PodIdentity.Provider}

params = append(params, "%PASSWORD%")
meta.connection = strings.Join(params, " ")
meta.Connection = strings.Join(params, " ")
}
meta.triggerIndex = config.TriggerIndex

return &meta, authPodIdentity, nil
return meta, authPodIdentity, nil
}

func buildConnArray(config *scalersconfig.ScalerConfig) ([]string, error) {
func buildConnArray(meta *postgreSQLMetadata) []string {
var params []string
params = append(params, "host="+escapePostgreConnectionParameter(meta.Host))
params = append(params, "port="+escapePostgreConnectionParameter(meta.Port))
params = append(params, "user="+escapePostgreConnectionParameter(meta.UserName))
params = append(params, "dbname="+escapePostgreConnectionParameter(meta.DBName))
params = append(params, "sslmode="+escapePostgreConnectionParameter(meta.SslMode))

host, err := GetFromAuthOrMeta(config, "host")
if err != nil {
return nil, err
}

port, err := GetFromAuthOrMeta(config, "port")
if err != nil {
return nil, err
}

userName, err := GetFromAuthOrMeta(config, "userName")
if err != nil {
return nil, err
}

dbName, err := GetFromAuthOrMeta(config, "dbName")
if err != nil {
return nil, err
}

sslmode, err := GetFromAuthOrMeta(config, "sslmode")
if err != nil {
return nil, err
}
params = append(params, "host="+escapePostgreConnectionParameter(host))
params = append(params, "port="+escapePostgreConnectionParameter(port))
params = append(params, "user="+escapePostgreConnectionParameter(userName))
params = append(params, "dbname="+escapePostgreConnectionParameter(dbName))
params = append(params, "sslmode="+escapePostgreConnectionParameter(sslmode))

return params, nil
return params
}

func getConnection(ctx context.Context, meta *postgreSQLMetadata, podIdentity kedav1alpha1.AuthPodIdentity, logger logr.Logger) (*sql.DB, error) {
connectionString := meta.connection
connectionString := meta.Connection

if podIdentity.Provider == kedav1alpha1.PodIdentityProviderAzureWorkload {
accessToken, err := getAzureAccessToken(ctx, meta, azureDatabasePostgresResource)
if err != nil {
return nil, err
}
newPasswordField := "password=" + escapePostgreConnectionParameter(accessToken)
connectionString = passwordConnPattern.ReplaceAllString(meta.connection, newPasswordField)
connectionString = passwordConnPattern.ReplaceAllString(meta.Connection, newPasswordField)
}

db, err := sql.Open("pgx", connectionString)
Expand Down Expand Up @@ -245,7 +213,7 @@ func (s *postgreSQLScaler) getActiveNumber(ctx context.Context) (float64, error)
}
}

err := s.connection.QueryRowContext(ctx, s.metadata.query).Scan(&id)
err := s.connection.QueryRowContext(ctx, s.metadata.Query).Scan(&id)
if err != nil {
s.logger.Error(err, fmt.Sprintf("could not query postgreSQL: %s", err))
return 0, fmt.Errorf("could not query postgreSQL: %w", err)
Expand All @@ -259,7 +227,7 @@ func (s *postgreSQLScaler) GetMetricSpecForScaling(context.Context) []v2.MetricS
Metric: v2.MetricIdentifier{
Name: GenerateMetricNameWithIndex(s.metadata.triggerIndex, kedautil.NormalizeString("postgresql")),
},
Target: GetMetricTargetMili(s.metricType, s.metadata.targetQueryValue),
Target: GetMetricTargetMili(s.metricType, s.metadata.TargetQueryValue),
}
metricSpec := v2.MetricSpec{
External: externalMetric, Type: externalMetricType,
Expand All @@ -276,7 +244,7 @@ func (s *postgreSQLScaler) GetMetricsAndActivity(ctx context.Context, metricName

metric := GenerateMetricInMili(metricName, num)

return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.activationTargetQueryValue, nil
return []external_metrics.ExternalMetricValue{metric}, num > s.metadata.ActivationTargetQueryValue, nil
}

func escapePostgreConnectionParameter(str string) string {
Expand Down
8 changes: 4 additions & 4 deletions pkg/scalers/postgresql_scaler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ func TestPosgresSQLConnectionStringGeneration(t *testing.T) {
t.Fatal("Could not parse metadata:", err)
}

if meta.connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.connection)
if meta.Connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.Connection)
}
}
}
Expand All @@ -104,8 +104,8 @@ func TestPodIdentityAzureWorkloadPosgresSQLConnectionStringGeneration(t *testing
t.Fatal("Could not parse metadata:", err)
}

if meta.connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.connection)
if meta.Connection != testData.connectionString {
t.Errorf("Error generating connectionString, expected '%s' and get '%s'", testData.connectionString, meta.Connection)
}
}
}
Expand Down
Loading