From effed8620953ff3c236933a5fd0f083dfb7250e0 Mon Sep 17 00:00:00 2001 From: William Alexander Date: Wed, 1 Jul 2020 14:43:08 -0400 Subject: [PATCH] YAML configuration support (#44) * Add basic YAML configuration support * Redo multi-document support and get tests running * Fix incorrect config filename in tests * Get tests passing * Add deprecation message to get-* commands * Remove set deprecation messages from get commands * Get tests working with deprecated commands * Don't apply global rate limit during Jails test * Add full support for new config format * Add changes to CLI * Update unit tests to use new Deprecated method names * Finish updating unit tests to use Deprecated method names * Remove wayward Println * Add JSON tags and flags to indicate if deprecated fields are in use * Add deprecated to main.go function names * Fix incorrect Kind in e2e jail config YAML * Add deletion tests for new CLI * Persist useDeprecated flags to Redis * Separate config files per test * Fix deprecated tests using wrong applyGuardianConfig() * Set GlobalRateLimit and GlobalSettings for deprecated Set* tests * Add GlobalRateLimit and GlobalSettings to more tests * Avoid running E2E tests in parallel * Update CLI help and fix get command bug Added a deprecated note to the CLI help for the deprecated CLI interface. Fixed Fetch*Configs() initializing slice with len instead of cap * Implement easy CR changes and add working stdin `apply` * Avoid using opaque structs as map keys * Make unit tests use paths instead of URLs * Fix get RateLimit not printing * Use distinct YAML spec keys to simplify decoding * Fix incorrect spec key for TestDeleteRateLimit * Cycle through available Redis DBs when testing * Remove deprecated RedisConfStore methods and fields and update unit tests * Readability/style changes for CR * Fix incorrect metric reporting method Co-authored-by: Will Alexander --- cmd/guardian-cli/main.go | 490 +++++++--- .../TestBlacklist/globalratelimitconfig.yml | 9 + .../TestBlacklist/globalsettingsconfig.yml | 6 + .../TestDeleteJail/globalratelimitconfig.yml | 9 + .../TestDeleteJail/globalsettingsconfig.yml | 6 + e2e/config/TestDeleteJail/jailconfig.yml | 25 + .../globalratelimitconfig.yml | 9 + .../globalsettingsconfig.yml | 6 + .../TestDeleteRateLimit/ratelimitconfig.yml | 23 + .../globalratelimitconfig.yml | 9 + .../globalsettingsconfig.yml | 6 + .../TestJails/globalratelimitconfig.yml | 9 + e2e/config/TestJails/globalsettingsconfig.yml | 6 + e2e/config/TestJails/jailconfig.yml | 25 + .../{ => TestJailsDeprecated}/jailconfig.yml | 4 +- .../TestRateLimit/globalratelimitconfig.yml | 9 + .../TestRateLimit/globalsettingsconfig.yml | 6 + e2e/config/TestRateLimit/ratelimitconfig.yml | 23 + .../TestRemoveJailDeprecated/jailconfig.yml | 15 + .../routeratelimitconfig.yml | 11 + .../routeratelimitconfig.yml | 11 + .../TestSetJailsDeprecated/jailconfig.yml | 15 + .../routeratelimitconfig.yml | 11 + .../TestWhitelist/globalratelimitconfig.yml | 9 + .../TestWhitelist/globalsettingsconfig.yml | 6 + e2e/config/routeratelimitconfig.yml | 11 - e2e/e2e_test.go | 580 +++++++++--- e2e/scripts/circleci-e2e-runner-docker.sh | 3 +- e2e/scripts/run-e2e.sh | 2 +- pkg/guardian/config.go | 117 +++ pkg/guardian/guardian_test.go | 32 +- pkg/guardian/jail.go | 7 +- pkg/guardian/redis_conf_store.go | 837 ++++++++---------- pkg/guardian/redis_conf_store_test.go | 543 ++++++++---- pkg/guardian/routeratelimit_test.go | 79 +- vendor/gopkg.in/yaml.v2/go.mod | 8 +- vendor/gopkg.in/yaml.v2/go.sum | 1 + 37 files changed, 2026 insertions(+), 952 deletions(-) create mode 100644 e2e/config/TestBlacklist/globalratelimitconfig.yml create mode 100644 e2e/config/TestBlacklist/globalsettingsconfig.yml create mode 100644 e2e/config/TestDeleteJail/globalratelimitconfig.yml create mode 100644 e2e/config/TestDeleteJail/globalsettingsconfig.yml create mode 100644 e2e/config/TestDeleteJail/jailconfig.yml create mode 100644 e2e/config/TestDeleteRateLimit/globalratelimitconfig.yml create mode 100644 e2e/config/TestDeleteRateLimit/globalsettingsconfig.yml create mode 100644 e2e/config/TestDeleteRateLimit/ratelimitconfig.yml create mode 100644 e2e/config/TestGlobalRateLimit/globalratelimitconfig.yml create mode 100644 e2e/config/TestGlobalRateLimit/globalsettingsconfig.yml create mode 100644 e2e/config/TestJails/globalratelimitconfig.yml create mode 100644 e2e/config/TestJails/globalsettingsconfig.yml create mode 100644 e2e/config/TestJails/jailconfig.yml rename e2e/config/{ => TestJailsDeprecated}/jailconfig.yml (60%) create mode 100644 e2e/config/TestRateLimit/globalratelimitconfig.yml create mode 100644 e2e/config/TestRateLimit/globalsettingsconfig.yml create mode 100644 e2e/config/TestRateLimit/ratelimitconfig.yml create mode 100644 e2e/config/TestRemoveJailDeprecated/jailconfig.yml create mode 100644 e2e/config/TestRemoveRouteRateLimitsDeprecated/routeratelimitconfig.yml create mode 100644 e2e/config/TestRouteRateLimitDeprecated/routeratelimitconfig.yml create mode 100644 e2e/config/TestSetJailsDeprecated/jailconfig.yml create mode 100644 e2e/config/TestSetRouteRateLimitsDeprecated/routeratelimitconfig.yml create mode 100644 e2e/config/TestWhitelist/globalratelimitconfig.yml create mode 100644 e2e/config/TestWhitelist/globalsettingsconfig.yml delete mode 100644 e2e/config/routeratelimitconfig.yml create mode 100644 pkg/guardian/config.go create mode 100644 vendor/gopkg.in/yaml.v2/go.sum diff --git a/cmd/guardian-cli/main.go b/cmd/guardian-cli/main.go index 652a1f4..edbb54d 100644 --- a/cmd/guardian-cli/main.go +++ b/cmd/guardian-cli/main.go @@ -2,9 +2,9 @@ package main import ( "fmt" + "io" "io/ioutil" "net" - "net/url" "os" "strings" @@ -21,6 +21,19 @@ func main() { logLevel := app.Flag("log-level", "log level.").Short('l').Default("error").OverrideDefaultFromEnvar("LOG_LEVEL").String() redisAddress := app.Flag("redis-address", "host:port.").Short('r').OverrideDefaultFromEnvar("REDIS_ADDRESS").Required().String() + // Configuration + applyCmd := app.Command("apply", "Apply configuration resources from a YAML file") + applyConfigFilePaths := applyCmd.Arg("config-file", "Path to configuration file (if omitted, configuration is read from stdin)").Strings() + + // Getting configuration data + getCmd := app.Command("get", "Get configuration resources of a certain kind") + getConfigKind := getCmd.Arg("kind", "kind of resource").Required().String() + + // Removing configuration data + deleteCmd := app.Command("delete", "Delete configuration resources") + deleteConfigKind := deleteCmd.Arg("kind", "kind of resource").Required().String() + deleteConfigName := deleteCmd.Arg("name", "name of resource").Required().String() + // Whitelisting addWhitelistCmd := app.Command("add-whitelist", "Add CIDRs to the IP Whitelist") addCidrStrings := addWhitelistCmd.Arg("cidr", "CIDR").Required().Strings() @@ -39,36 +52,36 @@ func main() { getBlacklistCmd := app.Command("get-blacklist", "Get blacklisted CIDRs") - // Rate limiting - setLimitCmd := app.Command("set-limit", "Sets the IP rate limit") + // Rate limiting (deprecated CLI) + setLimitCmd := app.Command("set-limit", "Sets the IP rate limit (deprecated)") limitCount := setLimitCmd.Arg("count", "limit count").Required().Uint64() limitDuration := setLimitCmd.Arg("duration", "limit duration").Required().Duration() limitEnabled := setLimitCmd.Arg("enabled", "limit enabled").Required().Bool() - getLimitCmd := app.Command("get-limit", "Gets the IP rate limit") + getLimitCmd := app.Command("get-limit", "Gets the IP rate limit (deprecated)") - // Route rate limitting - setRouteRateLimitsCmd := app.Command("set-route-rate-limits", "Sets rate limits for provided routes") + // Route rate limiting (deprecated CLI) + setRouteRateLimitsCmd := app.Command("set-route-rate-limits", "Sets rate limits for provided routes (deprecated)") configFilePath := setRouteRateLimitsCmd.Arg("route-rate-limit-config-file", "path to configuration file").Required().String() - removeRouteRateLimitsCmd := app.Command("remove-route-rate-limits", "Removes rate limits for provided routes") + removeRouteRateLimitsCmd := app.Command("remove-route-rate-limits", "Removes rate limits for provided routes (deprecated)") removeRouteRateLimitStrings := removeRouteRateLimitsCmd.Arg("routes", "Comma seperated list of routes to remove").Required().String() - getRouteRateLimitsCmd := app.Command("get-route-rate-limits", "Gets the IP rate limits for each route") + getRouteRateLimitsCmd := app.Command("get-route-rate-limits", "Gets the IP rate limits for each route (deprecated)") - // Jails - setJailsCmd := app.Command("set-jails", "Sets rate limits for provided routes") + // Jails (deprecated CLI) + setJailsCmd := app.Command("set-jails", "Sets rate limits for provided routes (deprecated)") jailsConfigFilePath := setJailsCmd.Arg("jail-config-file", "Path to configuration file").Required().String() - removeJailsCmd := app.Command("remove-jails", "Removes rate limits for provided routes") + removeJailsCmd := app.Command("remove-jails", "Removes rate limits for provided routes (deprecated)") removeJailsArgs := removeJailsCmd.Arg("jail-routes", "Comma separated list of jails to remove. Use the name of the route").Required().String() - getJailsCmd := app.Command("get-jails", "Lists all of the jails") + getJailsCmd := app.Command("get-jails", "Lists all of the jails (deprecated)") getPrisonersCmd := app.Command("get-prisoners", "List all prisoners") removePrisonersCmd := app.Command("remove-prisoners", "Removes prisoners from") prisoners := removePrisonersCmd.Arg("prisoners", "Comma separated list of ip address to remove").Required().String() - // Report Only - setReportOnlyCmd := app.Command("set-report-only", "Sets the report only flag") + // Report Only (deprecated CLI) + setReportOnlyCmd := app.Command("set-report-only", "Sets the report only flag (deprecated)") reportOnly := setReportOnlyCmd.Arg("report-only", "report only enabled").Required().Bool() - getReportOnlyCmd := app.Command("get-report-only", "Gets the report only flag") + getReportOnlyCmd := app.Command("get-report-only", "Gets the report only flag (deprecated)") selectedCmd := kingpin.MustParse(app.Parse(os.Args[1:])) redisOpts := &redis.Options{Addr: *redisAddress} @@ -86,24 +99,35 @@ func main() { logger.SetLevel(level) switch selectedCmd { + case applyCmd.FullCommand(): + err := applyConfigs(redisConfStore, *applyConfigFilePaths, logger) + if err != nil { + fatalerror(fmt.Errorf("error applying configuration: %v", err)) + } + case getCmd.FullCommand(): + err := getConfig(redisConfStore, *getConfigKind, logger) + if err != nil { + fatalerror(fmt.Errorf("error getting configuration: %v", err)) + } + case deleteCmd.FullCommand(): + err := deleteConfig(redisConfStore, *deleteConfigKind, *deleteConfigName, logger) + if err != nil { + fatalerror(fmt.Errorf("error deleting configuration: %v", err)) + } case addWhitelistCmd.FullCommand(): err := addWhitelist(redisConfStore, *addCidrStrings, logger) if err != nil { - fmt.Fprintf(os.Stderr, "error adding CIDRS: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error adding CIDRS: %v", err)) } - case removeWhitelistCmd.FullCommand(): err := removeWhitelist(redisConfStore, *removeCidrStrings, logger) if err != nil { - fmt.Fprintf(os.Stderr, "error removing CIDRS: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error removing CIDRS: %v", err)) } case getWhitelistCmd.FullCommand(): whitelist, err := getWhitelist(redisConfStore, logger) if err != nil { - fmt.Fprintf(os.Stderr, "error listing CIDRS: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error listing CIDRS: %v", err)) } for _, cidr := range whitelist { @@ -112,99 +136,99 @@ func main() { case addBlacklistCmd.FullCommand(): err := addBlacklist(redisConfStore, *addBlacklistCidrStrings, logger) if err != nil { - fmt.Fprintf(os.Stderr, "error adding CIDRS: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error adding CIDRS: %v", err)) } case removeBlacklistCmd.FullCommand(): err := removeBlacklist(redisConfStore, *removeBlacklistCidrStrings, logger) if err != nil { - fmt.Fprintf(os.Stderr, "error removing CIDRS: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error removing CIDRS: %v", err)) } case getBlacklistCmd.FullCommand(): blacklist, err := getBlacklist(redisConfStore, logger) if err != nil { - fmt.Fprintf(os.Stderr, "error listing CIDRS: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error listing CIDRS: %v", err)) } for _, cidr := range blacklist { fmt.Println(cidr.String()) } case setLimitCmd.FullCommand(): + fmt.Fprintf(os.Stderr, "%s is deprecated: apply a GlobalRateLimit config instead\n", setLimitCmd.FullCommand()) limit := guardian.Limit{Count: *limitCount, Duration: *limitDuration, Enabled: *limitEnabled} - err := setLimit(redisConfStore, limit) + err := setLimitDeprecated(redisConfStore, limit) if err != nil { - fmt.Fprintf(os.Stderr, "error setting limit: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error setting limit: %v", err)) } case getLimitCmd.FullCommand(): - limit, err := getLimit(redisConfStore) + fmt.Fprintf(os.Stderr, "%s is deprecated: get GlobalRateLimit instead\n", getLimitCmd.FullCommand()) + limit, err := getLimitDeprecated(redisConfStore) if err != nil { - fmt.Fprintf(os.Stderr, "error getting limit: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error getting limit: %v", err)) } fmt.Printf("%v\n", limit) + case setRouteRateLimitsCmd.FullCommand(): + fmt.Fprintf(os.Stderr, "%s is deprecated: apply a RateLimit config instead\n", setRouteRateLimitsCmd.FullCommand()) + err := setRouteRateLimitsDeprecated(redisConfStore, *configFilePath) + if err != nil { + fatalerror(fmt.Errorf("error setting route rate limits: %v", err)) + } case getRouteRateLimitsCmd.FullCommand(): - routeRateLimits, err := getRouteRateLimits(redisConfStore) + fmt.Fprintf(os.Stderr, "%s is deprecated: get RateLimit instead\n", getRouteRateLimitsCmd.FullCommand()) + routeRateLimits, err := getRouteRateLimitsDeprecated(redisConfStore) if err != nil { - fmt.Fprintf(os.Stderr, "error getting route rate limits: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error getting route rate limits: %v", err)) } - config := guardian.RouteRateLimitConfig{} - for url, limit := range routeRateLimits { - entry := guardian.RouteRateLimitConfigEntry{ - Route: url.EscapedPath(), + config := guardian.RouteRateLimitConfigDeprecated{} + for path, limit := range routeRateLimits { + entry := guardian.RouteRateLimitConfigEntryDeprecated{ + Route: path, Limit: limit, } - config.RouteRatelimits = append(config.RouteRatelimits, entry) + config.RouteRateLimits = append(config.RouteRateLimits, entry) } configYaml, err := yaml.Marshal(config) if err != nil { fatalerror(fmt.Errorf("error marshaling route limit yaml: %v", err)) } fmt.Println(string(configYaml)) - case setRouteRateLimitsCmd.FullCommand(): - err := setRouteRateLimits(redisConfStore, *configFilePath) - if err != nil { - fatalerror(fmt.Errorf("error setting route rate limits: %v", err)) - } case removeRouteRateLimitsCmd.FullCommand(): - err := removeRouteRateLimits(redisConfStore, *removeRouteRateLimitStrings) + err := removeRouteRateLimitsDeprecated(redisConfStore, *removeRouteRateLimitStrings) if err != nil { fatalerror(fmt.Errorf("error remove route rate limits: %v", err)) } case setReportOnlyCmd.FullCommand(): - err := setReportOnly(redisConfStore, *reportOnly) + fmt.Fprintf(os.Stderr, "%s is deprecated: apply a GlobalSettings config instead", setReportOnlyCmd.FullCommand()) + err := setReportOnlyDeprecated(redisConfStore, *reportOnly) if err != nil { - fmt.Fprintf(os.Stderr, "error setting report only flag: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error setting report only flag: %v", err)) } case getReportOnlyCmd.FullCommand(): - reportOnly, err := getReportOnly(redisConfStore) + fmt.Fprintf(os.Stderr, "%s is deprecated: get GlobalSettings instead\n", setReportOnlyCmd.FullCommand()) + reportOnly, err := getReportOnlyDeprecated(redisConfStore) if err != nil { - fmt.Fprintf(os.Stderr, "error getting report only flag: %v\n", err) - os.Exit(1) + fatalerror(fmt.Errorf("error getting report only flag: %v", err)) } fmt.Println(reportOnly) case setJailsCmd.FullCommand(): - err := setJails(redisConfStore, *jailsConfigFilePath) + fmt.Fprintf(os.Stderr, "%s is deprecated: apply a Jail config instead", setJailsCmd.FullCommand()) + err := setJailsDeprecated(redisConfStore, *jailsConfigFilePath) if err != nil { fatalerror(fmt.Errorf("error setting jails: %v", err)) } case removeJailsCmd.FullCommand(): - err := removeJails(redisConfStore, *removeJailsArgs) + err := removeJailsDeprecated(redisConfStore, *removeJailsArgs) if err != nil { fatalerror(err) } case getJailsCmd.FullCommand(): - jails, err := getJails(redisConfStore) - config := guardian.JailConfig{} - for u, j := range jails { - entry := guardian.JailConfigEntry{ - Route: u.EscapedPath(), - Jail: j, + fmt.Fprintf(os.Stderr, "%s is deprecated: get Jail instead\n", getJailsCmd.FullCommand()) + jails, err := getJailsDeprecated(redisConfStore) + config := guardian.JailConfigDeprecated{} + for path, jail := range jails { + entry := guardian.JailConfigEntryDeprecated{ + Route: path, + Jail: jail, } config.Jails = append(config.Jails, entry) } @@ -222,14 +246,174 @@ func main() { case getPrisonersCmd.FullCommand(): prisoners, err := getPrisoners(redisConfStore) if err != nil { - fatalerror(fmt.Errorf("error fetching prisoners: %v")) + fatalerror(fmt.Errorf("error fetching prisoners: %v", err)) } - prisonersJson, err := yaml.Marshal(prisoners) + prisonersJSON, err := yaml.Marshal(prisoners) if err != nil { fatalerror(fmt.Errorf("error marshaling prisoners: %v", err)) } - fmt.Println(string(prisonersJson)) + fmt.Println(string(prisonersJSON)) + } +} + +func applyConfigFromReader(store *guardian.RedisConfStore, r io.Reader, logger logrus.FieldLogger) error { + dec := yaml.NewDecoder(r) + for { + cfg := guardian.Config{} + err := dec.Decode(&cfg) + if err != nil { + if err == io.EOF { // No more YAML documents to read + break + } + return fmt.Errorf("error decoding yaml: %v", err) + } + switch cfg.Kind { + case guardian.GlobalRateLimitConfigKind: + if cfg.GlobalRateLimitSpec == nil { + return fmt.Errorf("Kind is %v but did not decode a corresponding spec", cfg.Kind) + } + cfg := guardian.GlobalRateLimitConfig{ + ConfigMetadata: cfg.ConfigMetadata, + Spec: *cfg.GlobalRateLimitSpec, + } + if err := applyGlobalRateLimitConfig(store, cfg); err != nil { + return err + } + case guardian.GlobalSettingsConfigKind: + if cfg.GlobalSettingsSpec == nil { + return fmt.Errorf("Kind is %v but did not decode a corresponding spec", cfg.Kind) + } + cfg := guardian.GlobalSettingsConfig{ + ConfigMetadata: cfg.ConfigMetadata, + Spec: *cfg.GlobalSettingsSpec, + } + if err := applyGlobalSettingsConfig(store, cfg); err != nil { + return err + } + case guardian.RateLimitConfigKind: + if cfg.RateLimitSpec == nil { + return fmt.Errorf("Kind is %v but did not decode a corresponding spec", cfg.Kind) + } + cfg := guardian.RateLimitConfig{ + ConfigMetadata: cfg.ConfigMetadata, + Spec: *cfg.RateLimitSpec, + } + if err := applyRateLimitConfig(store, cfg); err != nil { + return err + } + case guardian.JailConfigKind: + if cfg.JailSpec == nil { + return fmt.Errorf("Kind is %v but did not decode a corresponding spec", cfg.Kind) + } + cfg := guardian.JailConfig{ + ConfigMetadata: cfg.ConfigMetadata, + Spec: *cfg.JailSpec, + } + if err := applyJailConfig(store, cfg); err != nil { + return err + } + default: + return fmt.Errorf("unrecognized config file kind: %v", cfg.Kind) + } + } + return nil +} + +func applyConfigs(store *guardian.RedisConfStore, configFilePaths []string, logger logrus.FieldLogger) error { + if len(configFilePaths) == 0 { + err := applyConfigFromReader(store, os.Stdin, logger) + if err != nil { + return fmt.Errorf("error applying config from stdin: %v", err) + } + } + for _, configFilePath := range configFilePaths { + f, err := os.Open(configFilePath) + if err != nil { + return fmt.Errorf("error opening config file: %v", err) + } + defer f.Close() + err = applyConfigFromReader(store, f, logger) + if err != nil { + return fmt.Errorf("error applying config file %v: %v", configFilePath, err) + } + } + return nil +} + +func getConfig(store *guardian.RedisConfStore, configKind string, logger logrus.FieldLogger) error { + switch configKind { + case guardian.GlobalRateLimitConfigKind: + config, err := store.FetchGlobalRateLimitConfig() + if err != nil { + return fmt.Errorf("error getting global rate limit config: %v", err) + } + configYaml, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("error marshaling yaml: %v", err) + } + fmt.Println(string(configYaml)) + case guardian.GlobalSettingsConfigKind: + config, err := store.FetchGlobalSettingsConfig() + if err != nil { + return fmt.Errorf("error getting global settings config: %v", err) + } + configYaml, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("error marshaling yaml: %v", err) + } + fmt.Println(string(configYaml)) + case guardian.RateLimitConfigKind: + configs := store.FetchRateLimitConfigs() + for i, config := range configs { + if i > 0 { + fmt.Println("---") + } + configYaml, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("error marshaling yaml: %v", err) + } + fmt.Println(string(configYaml)) + } + case guardian.JailConfigKind: + configs := store.FetchJailConfigs() + for i, config := range configs { + if i > 0 { + fmt.Println("---") + } + configYaml, err := yaml.Marshal(config) + if err != nil { + return fmt.Errorf("error marshaling yaml: %v", err) + } + fmt.Println(string(configYaml)) + + } + default: + return fmt.Errorf("unrecognized config kind: %v", configKind) + } + return nil +} + +func deleteConfig(store *guardian.RedisConfStore, configKind string, configName string, logger logrus.FieldLogger) error { + switch configKind { + case guardian.RateLimitConfigKind: + err := store.DeleteRateLimitConfig(configName) + if err != nil { + return fmt.Errorf("error deleting rate limit config: %v", err) + } + case guardian.JailConfigKind: + err := store.DeleteJailConfig(configName) + if err != nil { + return fmt.Errorf("error deleting jail config: %v", err) + } + case guardian.GlobalRateLimitConfigKind: + fallthrough + case guardian.GlobalSettingsConfigKind: + return fmt.Errorf("config kind %v does not support deletion", configKind) + default: + return fmt.Errorf("unrecognized config kind: %v", configKind) + } + return nil } func addWhitelist(store *guardian.RedisConfStore, cidrStrings []string, logger logrus.FieldLogger) error { @@ -340,94 +524,180 @@ func convertCIDRStrings(cidrStrings []string) ([]net.IPNet, error) { return cidrs, nil } -func setLimit(store *guardian.RedisConfStore, limit guardian.Limit) error { - return store.SetLimit(limit) +func applyGlobalRateLimitConfig(store *guardian.RedisConfStore, config guardian.GlobalRateLimitConfig) error { + return store.ApplyGlobalRateLimitConfig(config) } -func getLimit(store *guardian.RedisConfStore) (guardian.Limit, error) { - return store.FetchLimit() +func applyGlobalSettingsConfig(store *guardian.RedisConfStore, config guardian.GlobalSettingsConfig) error { + return store.ApplyGlobalSettingsConfig(config) } -func setReportOnly(store *guardian.RedisConfStore, reportOnly bool) error { - return store.SetReportOnly(reportOnly) +func applyRateLimitConfig(store *guardian.RedisConfStore, config guardian.RateLimitConfig) error { + return store.ApplyRateLimitConfig(config) } -func getReportOnly(store *guardian.RedisConfStore) (bool, error) { - return store.FetchReportOnly() +func applyJailConfig(store *guardian.RedisConfStore, config guardian.JailConfig) error { + return store.ApplyJailConfig(config) } -func getRouteRateLimits(store *guardian.RedisConfStore) (map[url.URL]guardian.Limit, error) { - return store.FetchRouteRateLimits() +func setLimitDeprecated(store *guardian.RedisConfStore, limit guardian.Limit) error { + config := guardian.GlobalRateLimitConfig{ + ConfigMetadata: guardian.ConfigMetadata{ + Version: guardian.GlobalRateLimitConfigVersion, + Kind: guardian.GlobalRateLimitConfigKind, + Name: guardian.GlobalRateLimitConfigKind, + Description: "Metadata automatically created from deprecated CLI", + }, + Spec: guardian.GlobalRateLimitSpec{ + Limit: limit, + }, + } + return store.ApplyGlobalRateLimitConfig(config) } -func removeRouteRateLimits(store *guardian.RedisConfStore, routes string) error { - var urls []url.URL - for _, route := range strings.Split(routes, ",") { - unwantedURL, err := url.Parse(route) - if err != nil { - return fmt.Errorf("error parsing route: %v", err) +func getLimitDeprecated(store *guardian.RedisConfStore) (guardian.Limit, error) { + config, err := store.FetchGlobalRateLimitConfig() + if err != nil { + return guardian.Limit{}, err + } + return config.Spec.Limit, nil +} + +func setReportOnlyDeprecated(store *guardian.RedisConfStore, reportOnly bool) error { + config := guardian.GlobalSettingsConfig{ + ConfigMetadata: guardian.ConfigMetadata{ + Version: guardian.GlobalSettingsConfigVersion, + Kind: guardian.GlobalSettingsConfigKind, + Name: guardian.GlobalSettingsConfigKind, + Description: "Metadata automatically created from deprecated CLI", + }, + Spec: guardian.GlobalSettingsSpec{ + ReportOnly: reportOnly, + }, + } + return store.ApplyGlobalSettingsConfig(config) +} + +func getReportOnlyDeprecated(store *guardian.RedisConfStore) (bool, error) { + config, err := store.FetchGlobalSettingsConfig() + if err != nil { + return false, err + } + return config.Spec.ReportOnly, nil +} + +func getRouteRateLimitsDeprecated(store *guardian.RedisConfStore) (map[string]guardian.Limit, error) { + routeRateLimits := make(map[string]guardian.Limit) + for _, config := range store.FetchRateLimitConfigs() { + routeRateLimits[config.Name] = config.Spec.Limit + } + return routeRateLimits, nil +} + +func removeRouteRateLimitsDeprecated(store *guardian.RedisConfStore, routes string) error { + configsByPath := make(map[string]guardian.RateLimitConfig) + for _, config := range store.FetchRateLimitConfigs() { + configsByPath[config.Spec.Conditions.Path] = config + } + for _, path := range strings.Split(routes, ",") { + config, ok := configsByPath[path] + if !ok { + continue + } + if err := store.DeleteRateLimitConfig(config.Name); err != nil { + return err } - urls = append(urls, *unwantedURL) } - return store.RemoveRouteRateLimits(urls) + return nil } -func setRouteRateLimits(store *guardian.RedisConfStore, configFilePath string) error { - routeRateLimits := make(map[url.URL]guardian.Limit) +func setRouteRateLimitsDeprecated(store *guardian.RedisConfStore, configFilePath string) error { content, err := ioutil.ReadFile(configFilePath) if err != nil { return fmt.Errorf("error reading config file: %v", err) } - config := guardian.RouteRateLimitConfig{} + config := guardian.RouteRateLimitConfigDeprecated{} err = yaml.Unmarshal(content, &config) if err != nil { return fmt.Errorf("error unmarshaling yaml: %v", err) } - for _, routeRateLimitEntry := range config.RouteRatelimits { - configuredURL, err := url.Parse(routeRateLimitEntry.Route) - if err != nil { - return fmt.Errorf("error parsing route: %v", err) + for _, routeRateLimitEntry := range config.RouteRateLimits { + config := guardian.RateLimitConfig{ + ConfigMetadata: guardian.ConfigMetadata{ + Version: guardian.RateLimitConfigVersion, + Kind: guardian.RateLimitConfigKind, + Name: routeRateLimitEntry.Route, + Description: "Metadata automatically created from deprecated CLI", + }, + Spec: guardian.RateLimitSpec{ + Limit: routeRateLimitEntry.Limit, + Conditions: guardian.Conditions{ + Path: routeRateLimitEntry.Route, + }, + }, + } + if err := store.ApplyRateLimitConfig(config); err != nil { + return err } - routeRateLimits[*configuredURL] = routeRateLimitEntry.Limit } - return store.SetRouteRateLimits(routeRateLimits) + return nil } -func setJails(store *guardian.RedisConfStore, configFilePath string) error { - jails := make(map[url.URL]guardian.Jail) +func setJailsDeprecated(store *guardian.RedisConfStore, configFilePath string) error { content, err := ioutil.ReadFile(configFilePath) if err != nil { return fmt.Errorf("error reading config file: %v", err) } - config := guardian.JailConfig{} + config := guardian.JailConfigDeprecated{} err = yaml.Unmarshal(content, &config) if err != nil { return fmt.Errorf("error unmarshaling yaml: %v", err) } for _, jailEntry := range config.Jails { - configuredURL, err := url.Parse(jailEntry.Route) - if err != nil { - return fmt.Errorf("error parsing route: %v", err) + config := guardian.JailConfig{ + ConfigMetadata: guardian.ConfigMetadata{ + Version: guardian.JailConfigVersion, + Kind: guardian.JailConfigKind, + Name: jailEntry.Route, + Description: "Metadata automatically created from deprecated CLI", + }, + Spec: guardian.JailSpec{ + Jail: jailEntry.Jail, + Conditions: guardian.Conditions{ + Path: jailEntry.Route, + }, + }, + } + if err := store.ApplyJailConfig(config); err != nil { + return err } - jails[*configuredURL] = jailEntry.Jail } - return store.SetJails(jails) + return nil } -func removeJails(store *guardian.RedisConfStore, routes string) error { - var urls []url.URL - for _, route := range strings.Split(routes, ",") { - unwantedURL, err := url.Parse(route) - if err != nil { - return fmt.Errorf("error parsing route: %v", err) +func removeJailsDeprecated(store *guardian.RedisConfStore, routes string) error { + configsByPath := make(map[string]guardian.JailConfig) + for _, config := range store.FetchJailConfigs() { + configsByPath[config.Spec.Conditions.Path] = config + } + for _, path := range strings.Split(routes, ",") { + config, ok := configsByPath[path] + if !ok { + continue + } + if err := store.DeleteJailConfig(config.Name); err != nil { + return err } - urls = append(urls, *unwantedURL) } - return store.RemoveJails(urls) + return nil } -func getJails(store *guardian.RedisConfStore) (map[url.URL]guardian.Jail, error) { - return store.FetchJails() +func getJailsDeprecated(store *guardian.RedisConfStore) (map[string]guardian.Jail, error) { + jails := make(map[string]guardian.Jail) + for _, config := range store.FetchJailConfigs() { + jails[config.Name] = config.Spec.Jail + } + return jails, nil } func getPrisoners(store *guardian.RedisConfStore) ([]guardian.Prisoner, error) { diff --git a/e2e/config/TestBlacklist/globalratelimitconfig.yml b/e2e/config/TestBlacklist/globalratelimitconfig.yml new file mode 100644 index 0000000..05ef782 --- /dev/null +++ b/e2e/config/TestBlacklist/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 5 + duration: 1s + enabled: true diff --git a/e2e/config/TestBlacklist/globalsettingsconfig.yml b/e2e/config/TestBlacklist/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestBlacklist/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/TestDeleteJail/globalratelimitconfig.yml b/e2e/config/TestDeleteJail/globalratelimitconfig.yml new file mode 100644 index 0000000..35c7a89 --- /dev/null +++ b/e2e/config/TestDeleteJail/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 5 + duration: 1m + enabled: false diff --git a/e2e/config/TestDeleteJail/globalsettingsconfig.yml b/e2e/config/TestDeleteJail/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestDeleteJail/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/TestDeleteJail/jailconfig.yml b/e2e/config/TestDeleteJail/jailconfig.yml new file mode 100644 index 0000000..9f24016 --- /dev/null +++ b/e2e/config/TestDeleteJail/jailconfig.yml @@ -0,0 +1,25 @@ +version: "v0" +kind: Jail +name: "/foo/bar" +description: "/foo/bar" +jailSpec: + limit: + count: 10 + duration: 10s + enabled: true + conditions: + path: "/foo/bar" + banDuration: 30s # Keep this duration short as it's used in tests +--- +version: "v0" +kind: Jail +name: "/foo/baz" +description: "/foo/baz" +jailSpec: + limit: + count: 5 + duration: 1m + enabled: false + conditions: + path: "/foo/baz" + banDuration: 30s # Keep this duration short as it's used in tests diff --git a/e2e/config/TestDeleteRateLimit/globalratelimitconfig.yml b/e2e/config/TestDeleteRateLimit/globalratelimitconfig.yml new file mode 100644 index 0000000..5a32157 --- /dev/null +++ b/e2e/config/TestDeleteRateLimit/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 5 + duration: 1m + enabled: false \ No newline at end of file diff --git a/e2e/config/TestDeleteRateLimit/globalsettingsconfig.yml b/e2e/config/TestDeleteRateLimit/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestDeleteRateLimit/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/TestDeleteRateLimit/ratelimitconfig.yml b/e2e/config/TestDeleteRateLimit/ratelimitconfig.yml new file mode 100644 index 0000000..4bbab82 --- /dev/null +++ b/e2e/config/TestDeleteRateLimit/ratelimitconfig.yml @@ -0,0 +1,23 @@ +version: "v0" +kind: RateLimit +name: "/foo/bar" +description: "/foo/bar" +rateLimitSpec: + limit: + count: 10 + duration: 1m + enabled: true + conditions: + path: "/foo/bar" +--- +version: "v0" +kind: RateLimit +name: "/foo/baz" +description: "/foo/baz" +rateLimitSpec: + limit: + count: 5 + duration: 1m + enabled: false + conditions: + path: "/foo/baz" \ No newline at end of file diff --git a/e2e/config/TestGlobalRateLimit/globalratelimitconfig.yml b/e2e/config/TestGlobalRateLimit/globalratelimitconfig.yml new file mode 100644 index 0000000..7716bf2 --- /dev/null +++ b/e2e/config/TestGlobalRateLimit/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 5 + duration: 1m + enabled: true diff --git a/e2e/config/TestGlobalRateLimit/globalsettingsconfig.yml b/e2e/config/TestGlobalRateLimit/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestGlobalRateLimit/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/TestJails/globalratelimitconfig.yml b/e2e/config/TestJails/globalratelimitconfig.yml new file mode 100644 index 0000000..35c7a89 --- /dev/null +++ b/e2e/config/TestJails/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 5 + duration: 1m + enabled: false diff --git a/e2e/config/TestJails/globalsettingsconfig.yml b/e2e/config/TestJails/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestJails/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/TestJails/jailconfig.yml b/e2e/config/TestJails/jailconfig.yml new file mode 100644 index 0000000..9f24016 --- /dev/null +++ b/e2e/config/TestJails/jailconfig.yml @@ -0,0 +1,25 @@ +version: "v0" +kind: Jail +name: "/foo/bar" +description: "/foo/bar" +jailSpec: + limit: + count: 10 + duration: 10s + enabled: true + conditions: + path: "/foo/bar" + banDuration: 30s # Keep this duration short as it's used in tests +--- +version: "v0" +kind: Jail +name: "/foo/baz" +description: "/foo/baz" +jailSpec: + limit: + count: 5 + duration: 1m + enabled: false + conditions: + path: "/foo/baz" + banDuration: 30s # Keep this duration short as it's used in tests diff --git a/e2e/config/jailconfig.yml b/e2e/config/TestJailsDeprecated/jailconfig.yml similarity index 60% rename from e2e/config/jailconfig.yml rename to e2e/config/TestJailsDeprecated/jailconfig.yml index 5dc50aa..6b31408 100644 --- a/e2e/config/jailconfig.yml +++ b/e2e/config/TestJailsDeprecated/jailconfig.yml @@ -5,11 +5,11 @@ jails: duration: 10s enabled: true count: 10 - ban_duration: 30s # Keep this duration short as it's used in tests + banDuration: 30s # Keep this duration short as it's used in tests - route: "/foo/baz" jail: limit: duration: 1m enabled: false count: 5 - ban_duration: 30s # Keep this duration short as it's used in tests + banDuration: 30s # Keep this duration short as it's used in tests \ No newline at end of file diff --git a/e2e/config/TestRateLimit/globalratelimitconfig.yml b/e2e/config/TestRateLimit/globalratelimitconfig.yml new file mode 100644 index 0000000..8676419 --- /dev/null +++ b/e2e/config/TestRateLimit/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 100 + duration: 1s + enabled: false diff --git a/e2e/config/TestRateLimit/globalsettingsconfig.yml b/e2e/config/TestRateLimit/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestRateLimit/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/TestRateLimit/ratelimitconfig.yml b/e2e/config/TestRateLimit/ratelimitconfig.yml new file mode 100644 index 0000000..4bbab82 --- /dev/null +++ b/e2e/config/TestRateLimit/ratelimitconfig.yml @@ -0,0 +1,23 @@ +version: "v0" +kind: RateLimit +name: "/foo/bar" +description: "/foo/bar" +rateLimitSpec: + limit: + count: 10 + duration: 1m + enabled: true + conditions: + path: "/foo/bar" +--- +version: "v0" +kind: RateLimit +name: "/foo/baz" +description: "/foo/baz" +rateLimitSpec: + limit: + count: 5 + duration: 1m + enabled: false + conditions: + path: "/foo/baz" \ No newline at end of file diff --git a/e2e/config/TestRemoveJailDeprecated/jailconfig.yml b/e2e/config/TestRemoveJailDeprecated/jailconfig.yml new file mode 100644 index 0000000..6b31408 --- /dev/null +++ b/e2e/config/TestRemoveJailDeprecated/jailconfig.yml @@ -0,0 +1,15 @@ +jails: + - route: "/foo/bar" + jail: + limit: + duration: 10s + enabled: true + count: 10 + banDuration: 30s # Keep this duration short as it's used in tests + - route: "/foo/baz" + jail: + limit: + duration: 1m + enabled: false + count: 5 + banDuration: 30s # Keep this duration short as it's used in tests \ No newline at end of file diff --git a/e2e/config/TestRemoveRouteRateLimitsDeprecated/routeratelimitconfig.yml b/e2e/config/TestRemoveRouteRateLimitsDeprecated/routeratelimitconfig.yml new file mode 100644 index 0000000..956a533 --- /dev/null +++ b/e2e/config/TestRemoveRouteRateLimitsDeprecated/routeratelimitconfig.yml @@ -0,0 +1,11 @@ +route_rate_limits: + - route: "/foo/bar" + limit: + duration: 1m + enabled: true + count: 10 + - route: "/foo/baz" + limit: + duration: 1m + enabled: false + count: 5 diff --git a/e2e/config/TestRouteRateLimitDeprecated/routeratelimitconfig.yml b/e2e/config/TestRouteRateLimitDeprecated/routeratelimitconfig.yml new file mode 100644 index 0000000..956a533 --- /dev/null +++ b/e2e/config/TestRouteRateLimitDeprecated/routeratelimitconfig.yml @@ -0,0 +1,11 @@ +route_rate_limits: + - route: "/foo/bar" + limit: + duration: 1m + enabled: true + count: 10 + - route: "/foo/baz" + limit: + duration: 1m + enabled: false + count: 5 diff --git a/e2e/config/TestSetJailsDeprecated/jailconfig.yml b/e2e/config/TestSetJailsDeprecated/jailconfig.yml new file mode 100644 index 0000000..6b31408 --- /dev/null +++ b/e2e/config/TestSetJailsDeprecated/jailconfig.yml @@ -0,0 +1,15 @@ +jails: + - route: "/foo/bar" + jail: + limit: + duration: 10s + enabled: true + count: 10 + banDuration: 30s # Keep this duration short as it's used in tests + - route: "/foo/baz" + jail: + limit: + duration: 1m + enabled: false + count: 5 + banDuration: 30s # Keep this duration short as it's used in tests \ No newline at end of file diff --git a/e2e/config/TestSetRouteRateLimitsDeprecated/routeratelimitconfig.yml b/e2e/config/TestSetRouteRateLimitsDeprecated/routeratelimitconfig.yml new file mode 100644 index 0000000..956a533 --- /dev/null +++ b/e2e/config/TestSetRouteRateLimitsDeprecated/routeratelimitconfig.yml @@ -0,0 +1,11 @@ +route_rate_limits: + - route: "/foo/bar" + limit: + duration: 1m + enabled: true + count: 10 + - route: "/foo/baz" + limit: + duration: 1m + enabled: false + count: 5 diff --git a/e2e/config/TestWhitelist/globalratelimitconfig.yml b/e2e/config/TestWhitelist/globalratelimitconfig.yml new file mode 100644 index 0000000..05ef782 --- /dev/null +++ b/e2e/config/TestWhitelist/globalratelimitconfig.yml @@ -0,0 +1,9 @@ +version: "v0" +kind: GlobalRateLimit +name: GlobalRateLimit +description: GlobalRateLimit +globalRateLimitSpec: + limit: + count: 5 + duration: 1s + enabled: true diff --git a/e2e/config/TestWhitelist/globalsettingsconfig.yml b/e2e/config/TestWhitelist/globalsettingsconfig.yml new file mode 100644 index 0000000..d3528ba --- /dev/null +++ b/e2e/config/TestWhitelist/globalsettingsconfig.yml @@ -0,0 +1,6 @@ +version: "v0" +kind: GlobalSettings +name: GlobalSettings +description: GlobalSettings +globalSettingsSpec: + reportOnly: false \ No newline at end of file diff --git a/e2e/config/routeratelimitconfig.yml b/e2e/config/routeratelimitconfig.yml deleted file mode 100644 index 13441e8..0000000 --- a/e2e/config/routeratelimitconfig.yml +++ /dev/null @@ -1,11 +0,0 @@ -route_rate_limits: -- route: "/foo/bar" - limit: - duration: 1m - enabled: true - count: 10 -- route: "/foo/baz" - limit: - duration: 1m - enabled: false - count: 5 diff --git a/e2e/e2e_test.go b/e2e/e2e_test.go index 17e75fa..151138b 100644 --- a/e2e/e2e_test.go +++ b/e2e/e2e_test.go @@ -4,6 +4,7 @@ import ( "flag" "fmt" "go/build" + "io" "io/ioutil" "net/http" "os" @@ -11,6 +12,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "testing" "time" @@ -24,50 +26,16 @@ import ( var redisAddr = flag.String("redis-addr", "localhost:6379", "redis address") var envoyAddr = flag.String("envoy-addr", "localhost:8080", "envoy address") -func TestRateLimit(t *testing.T) { - resetRedis(*redisAddr) - - config := guardianConfig{ - whitelist: []string{}, - blacklist: []string{}, - limitCount: 5, - limitDuration: time.Minute, - limitEnabled: true, - reportOnly: false, - } - applyGuardianConfig(t, *redisAddr, config) - - for i := 0; i < 10; i++ { - if len(os.Getenv("SYNC")) == 0 { - time.Sleep(100 * time.Millisecond) // helps prevents races due asynchronous rate limiting - } - - res := GET(t, "192.168.1.234", "/") - res.Body.Close() - - want := 200 - if i >= config.limitCount { - want = 429 - } - - if res.StatusCode != want { - t.Fatalf("wanted %v, got %v, iteration %v", want, res.StatusCode, i) - } - } -} - func TestWhitelist(t *testing.T) { resetRedis(*redisAddr) IP := "192.168.1.234" CIDR := fmt.Sprintf("%v/32", IP) config := guardianConfig{ - whitelist: []string{CIDR}, - blacklist: []string{}, - limitCount: 5, - limitDuration: time.Second, - limitEnabled: true, - reportOnly: false, + whitelist: []string{CIDR}, + blacklist: []string{}, + globalRateLimitConfigPath: "./config/TestWhitelist/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestWhitelist/globalsettingsconfig.yml", } applyGuardianConfig(t, *redisAddr, config) @@ -88,12 +56,10 @@ func TestBlacklist(t *testing.T) { IP := "192.168.1.234" CIDR := fmt.Sprintf("%v/32", IP) config := guardianConfig{ - whitelist: []string{}, - blacklist: []string{CIDR}, - limitCount: 5, - limitDuration: time.Second, - limitEnabled: true, - reportOnly: false, + whitelist: []string{}, + blacklist: []string{CIDR}, + globalRateLimitConfigPath: "./config/TestBlacklist/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestBlacklist/globalsettingsconfig.yml", } applyGuardianConfig(t, *redisAddr, config) @@ -108,42 +74,83 @@ func TestBlacklist(t *testing.T) { } } -func TestRouteRateLimit(t *testing.T) { +func TestGlobalRateLimit(t *testing.T) { resetRedis(*redisAddr) - configFilePath := "./config/routeratelimitconfig.yml" - config := guardianConfig{ - whitelist: []string{}, - blacklist: []string{}, - limitCount: 100, - limitDuration: time.Second, - limitEnabled: false, - reportOnly: false, - routeRateLimitConfigPath: configFilePath, + guardianConfig := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + globalRateLimitConfigPath: "./config/TestGlobalRateLimit/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestGlobalRateLimit/globalsettingsconfig.yml", } - applyGuardianConfig(t, *redisAddr, config) + applyGuardianConfig(t, *redisAddr, guardianConfig) - rrlConfig := guardian.RouteRateLimitConfig{} - rrlConfigBytes, err := ioutil.ReadFile(configFilePath) + f, err := os.Open(guardianConfig.globalRateLimitConfigPath) if err != nil { - t.Fatalf("unable to read config file: %v", err) + t.Fatalf("error opening config file: %v", err) } - err = yaml.Unmarshal(rrlConfigBytes, &rrlConfig) + defer f.Close() + config := guardian.GlobalRateLimitConfig{} + err = yaml.NewDecoder(f).Decode(&config) if err != nil { - t.Fatalf("error unmarshaling expected result string: %v", err) + t.Fatalf("error decoding yaml: %v", err) } - for _, routeRateLimit := range rrlConfig.RouteRatelimits { - for i := uint64(0); i < routeRateLimit.Limit.Count+5; i++ { + for i := uint64(0); i < 10; i++ { + if len(os.Getenv("SYNC")) == 0 { + time.Sleep(100 * time.Millisecond) // helps prevents races due asynchronous rate limiting + } + + res := GET(t, "192.168.1.234", "/") + res.Body.Close() + + want := 200 + if i >= config.Spec.Limit.Count { + want = 429 + } + + if res.StatusCode != want { + t.Fatalf("wanted %v, got %v, iteration %v", want, res.StatusCode, i) + } + } +} + +func TestRateLimit(t *testing.T) { + resetRedis(*redisAddr) + + guardianConfig := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + globalRateLimitConfigPath: "./config/TestRateLimit/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestRateLimit/globalsettingsconfig.yml", + rateLimitConfigPath: "./config/TestRateLimit/ratelimitconfig.yml", + } + applyGuardianConfig(t, *redisAddr, guardianConfig) + + f, err := os.Open(guardianConfig.rateLimitConfigPath) + if err != nil { + t.Fatalf("error opening config file: %v", err) + } + defer f.Close() + dec := yaml.NewDecoder(f) + for { + config := guardian.RateLimitConfig{} + err := dec.Decode(&config) + if err == io.EOF { + break + } else if err != nil { + t.Fatalf("error decoding yaml: %v", err) + } + for i := uint64(0); i < config.Spec.Limit.Count+5; i++ { if len(os.Getenv("SYNC")) == 0 { time.Sleep(100 * time.Millisecond) // helps prevents races due asynchronous rate limiting } - res := GET(t, "192.168.1.234", routeRateLimit.Route) + res := GET(t, "192.168.1.234", config.Spec.Conditions.Path) res.Body.Close() want := 200 - if i >= routeRateLimit.Limit.Count && routeRateLimit.Limit.Enabled { + if i >= config.Spec.Limit.Count && config.Spec.Limit.Enabled { want = 429 } @@ -152,91 +159,218 @@ func TestRouteRateLimit(t *testing.T) { } } } - } -func TestSetRouteRateLimits(t *testing.T) { +func TestJails(t *testing.T) { resetRedis(*redisAddr) - configFilePath := "./config/routeratelimitconfig.yml" - config := guardianConfig{ - routeRateLimitConfigPath: configFilePath, - } - applyGuardianConfig(t, *redisAddr, config) - getCmd := "get-route-rate-limits" - resStr := runGuardianCLI(t, *redisAddr, getCmd) - expectedResStr, err := ioutil.ReadFile(configFilePath) + whitelistedIP := "192.168.1.1" - res := guardian.RouteRateLimitConfig{} - expectedRes := guardian.RouteRateLimitConfig{} - err = yaml.Unmarshal([]byte(resStr), &res) - if err != nil { - t.Fatalf("error unmarshaling result string: %v", err) + guardianConfig := guardianConfig{ + whitelist: []string{whitelistedIP + "/32"}, + blacklist: []string{}, + globalRateLimitConfigPath: "./config/TestJails/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestJails/globalsettingsconfig.yml", + jailConfigPath: "./config/TestJails/jailconfig.yml", } - err = yaml.Unmarshal(expectedResStr, &expectedRes) + + applyGuardianConfig(t, *redisAddr, guardianConfig) + + f, err := os.Open(guardianConfig.jailConfigPath) if err != nil { - t.Fatalf("error unmarshaling expected result string: %v", err) + t.Fatalf("error opening config file: %v", err) } + defer f.Close() + dec := yaml.NewDecoder(f) - // Since the ordering of the slice returned from the cli can be different - // than the original config, we just want to verify that both configs contain - // the same entries in no particular order. - expectedResSet := make(map[string]guardian.Limit) - resSet := make(map[string]guardian.Limit) - for _, entry := range expectedRes.RouteRatelimits { - expectedResSet[entry.Route] = entry.Limit + // Assumes that any BanDuration in the Jail Config is greater than the time it takes + // to execute this particular test. + for { + config := guardian.JailConfig{} + err := dec.Decode(&config) + if err != nil { + if err == io.EOF { + break + } + t.Fatalf("error decoding yaml: %v", err) + } + banned := false + resetRedis(*redisAddr) + applyGuardianConfig(t, *redisAddr, guardianConfig) + for i := 0; uint64(i) <= config.Spec.Limit.Count; i++ { + if os.Getenv("SYNC") != "" { + time.Sleep(150 * time.Millisecond) // helps prevents races due asynchronous rate limiting + } + + res := GET(t, "192.168.1.43", config.Spec.Conditions.Path) + whitelistedRes := GET(t, whitelistedIP, config.Spec.Conditions.Path) + res.Body.Close() + whitelistedRes.Body.Close() + + want := 200 + if (uint64(i) >= config.Spec.Limit.Count && config.Spec.Limit.Enabled) || banned { + banned = true + want = 429 + } + + if res.StatusCode != want { + t.Fatalf("wanted %v, got %v, iteration %v, route: %v", want, res.StatusCode, i, config.Spec.Conditions.Path) + } + + if whitelistedRes.StatusCode != 200 { + t.Fatalf("whitelisted ip received unexpected status code: wanted %v, got %v, iteration %d, route: %v", 200, whitelistedRes.StatusCode, i, config.Spec.Conditions.Path) + } + } + if config.Spec.Limit.Enabled { + t.Logf("sleeping for banDuration: %v + 2 seconds to ensure the prisoner is removed", config.Spec.BanDuration) + // ensure that we sleep for an additional confUpdateInterval so that the configuration is updated + time.Sleep(config.Spec.BanDuration + (2 * time.Second)) + res := GET(t, "192.168.1.43", config.Spec.Conditions.Path) + if res.StatusCode != 200 { + t.Fatalf("prisoner was never removed, received unexpected status code: %d, %v", res.StatusCode, config.Spec.Jail) + } + } } - for _, entry := range res.RouteRatelimits { - resSet[entry.Route] = entry.Limit +} + +func TestDeleteRateLimit(t *testing.T) { + resetRedis(*redisAddr) + + config := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + globalRateLimitConfigPath: "./config/TestDeleteRateLimit/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestDeleteRateLimit/globalsettingsconfig.yml", + rateLimitConfigPath: "./config/TestDeleteRateLimit/ratelimitconfig.yml", } + applyGuardianConfig(t, *redisAddr, config) + delCmd := "delete" + runGuardianCLI(t, *redisAddr, delCmd, "RateLimit", "/foo/bar") + runGuardianCLI(t, *redisAddr, delCmd, "RateLimit", "/foo/baz") - if !cmp.Equal(resSet, expectedResSet) { - t.Fatalf("expected: %v, received: %v", expectedResSet, resSet) + getCmd := "get" + resStr := runGuardianCLI(t, *redisAddr, getCmd, "RateLimit") + + if len(resStr) != 0 { + t.Fatalf("get RateLimit returned non-empty output %v", resStr) } } -func TestRemoveRouteRateLimits(t *testing.T) { +func TestDeleteJail(t *testing.T) { resetRedis(*redisAddr) - configFilePath := "./config/routeratelimitconfig.yml" + config := guardianConfig{ - routeRateLimitConfigPath: configFilePath, + whitelist: []string{}, + blacklist: []string{}, + globalRateLimitConfigPath: "./config/TestDeleteJail/globalratelimitconfig.yml", + globalSettingsConfigPath: "./config/TestDeleteJail/globalsettingsconfig.yml", + jailConfigPath: "./config/TestDeleteJail/jailconfig.yml", } applyGuardianConfig(t, *redisAddr, config) - rmCmd := "remove-route-rate-limits" - runGuardianCLI(t, *redisAddr, rmCmd, "/foo/bar,/foo/baz") + delCmd := "delete" + runGuardianCLI(t, *redisAddr, delCmd, "Jail", "/foo/bar") + runGuardianCLI(t, *redisAddr, delCmd, "Jail", "/foo/baz") - getCmd := "get-route-rate-limits" - resStr := runGuardianCLI(t, *redisAddr, getCmd) + getCmd := "get" + resStr := runGuardianCLI(t, *redisAddr, getCmd, "Jail") - res := guardian.RouteRateLimitConfig{} - err := yaml.Unmarshal([]byte(resStr), &res) + if len(resStr) != 0 { + t.Fatalf("get Jail returned non-empty output %v", resStr) + } +} + +func TestRateLimitDeprecated(t *testing.T) { + resetRedis(*redisAddr) + + config := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + limitCountDeprecated: 5, + limitDurationDeprecated: time.Minute, + limitEnabledDeprecated: true, + reportOnlyDeprecated: false, + } + applyGuardianConfigDeprecated(t, *redisAddr, config) + + for i := 0; i < 10; i++ { + if len(os.Getenv("SYNC")) == 0 { + time.Sleep(100 * time.Millisecond) // helps prevents races due asynchronous rate limiting + } + + res := GET(t, "192.168.1.234", "/") + res.Body.Close() + + want := 200 + if i >= config.limitCountDeprecated { + want = 429 + } + + if res.StatusCode != want { + t.Fatalf("wanted %v, got %v, iteration %v", want, res.StatusCode, i) + } + } +} + +func TestRouteRateLimitDeprecated(t *testing.T) { + resetRedis(*redisAddr) + config := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + limitCountDeprecated: 100, + limitDurationDeprecated: time.Second, + limitEnabledDeprecated: false, + reportOnlyDeprecated: false, + routeRateLimitConfigPathDeprecated: "./config/TestRouteRateLimitDeprecated/routeratelimitconfig.yml", + } + applyGuardianConfigDeprecated(t, *redisAddr, config) + + rrlConfig := guardian.RouteRateLimitConfigDeprecated{} + b, err := ioutil.ReadFile(config.routeRateLimitConfigPathDeprecated) if err != nil { - t.Fatalf("error unmarshaling result string: %v", err) + t.Fatalf("unable to read config file: %v", err) + } + err = yaml.Unmarshal(b, &rrlConfig) + if err != nil { + t.Fatalf("error unmarshaling expected result string: %v", err) } - if len(res.RouteRatelimits) != 0 { - t.Fatalf("expected route rate limits to be empty after removing them") + for _, routeRateLimit := range rrlConfig.RouteRateLimits { + for i := uint64(0); i < routeRateLimit.Limit.Count+5; i++ { + if len(os.Getenv("SYNC")) == 0 { + time.Sleep(100 * time.Millisecond) // helps prevents races due asynchronous rate limiting + } + + res := GET(t, "192.168.1.234", routeRateLimit.Route) + res.Body.Close() + + want := 200 + if i >= routeRateLimit.Limit.Count && routeRateLimit.Limit.Enabled { + want = 429 + } + + if res.StatusCode != want { + t.Fatalf("wanted %v, got %v, iteration %v", want, res.StatusCode, i) + } + } } } -func TestJails(t *testing.T) { +func TestJailsDeprecated(t *testing.T) { resetRedis(*redisAddr) - configFilePath := "./config/jailconfig.yml" whitelistedIP := "192.168.1.1" config := guardianConfig{ whitelist: []string{whitelistedIP + "/32"}, blacklist: []string{}, - limitCount: 5, - limitDuration: time.Minute, - limitEnabled: false, - reportOnly: false, - routeRateLimitConfigPath: "", - jailConfigPath: configFilePath, + limitCountDeprecated: 5, + limitDurationDeprecated: time.Minute, + limitEnabledDeprecated: false, + reportOnlyDeprecated: false, + jailConfigPathDeprecated: "./config/TestJailsDeprecated/jailconfig.yml", } - applyGuardianConfig(t, *redisAddr, config) - jailConfig := &guardian.JailConfig{} - jailConfigContents, err := ioutil.ReadFile(config.jailConfigPath) + applyGuardianConfigDeprecated(t, *redisAddr, config) + jailConfig := &guardian.JailConfigDeprecated{} + jailConfigContents, err := ioutil.ReadFile(config.jailConfigPathDeprecated) if err != nil { t.Fatalf("unable to read config file: %v", err) } @@ -250,7 +384,7 @@ func TestJails(t *testing.T) { for _, j := range jailConfig.Jails { banned := false resetRedis(*redisAddr) - applyGuardianConfig(t, *redisAddr, config) + applyGuardianConfigDeprecated(t, *redisAddr, config) for i := uint64(0); i < j.Jail.Limit.Count+1; i++ { if len(os.Getenv("SYNC")) == 0 { time.Sleep(150 * time.Millisecond) // helps prevents races due asynchronous rate limiting @@ -276,9 +410,9 @@ func TestJails(t *testing.T) { } } if j.Jail.Limit.Enabled { - t.Logf("sleeping for ban_duration: %v + 2 seconds to ensure the prisoner is removed", j.Jail.BanDuration) + t.Logf("sleeping for banDuration: %v + 2 seconds to ensure the prisoner is removed", j.Jail.BanDuration) time.Sleep(j.Jail.BanDuration) - time.Sleep(2 * time.Second) // ensure that we sleep for an additional confUpdateInterval so that the configuration is updated + time.Sleep(2 * time.Second) // ensure that we sleep for an additional confUpdateInterval so that the configuration is updated res := GET(t, "192.168.1.43", j.Route) if res.StatusCode != 200 { t.Fatalf("prisoner was never removed, received unexpected status code: %d, %v", res.StatusCode, j.Jail) @@ -287,20 +421,107 @@ func TestJails(t *testing.T) { } } -func TestSetJails(t *testing.T) { +func routeRateLimitsEqual(rrl1, rrl2 []guardian.RouteRateLimitConfigEntryDeprecated) bool { + if len(rrl1) != len(rrl2) { + return false + } + m := make(map[guardian.RouteRateLimitConfigEntryDeprecated]struct{}, len(rrl1)) + for i := range rrl1 { + m[rrl1[i]] = struct{}{} + } + for _, e := range rrl2 { + if _, ok := m[e]; !ok { + return false + } + } + return true +} + +func TestSetRouteRateLimitsDeprecated(t *testing.T) { resetRedis(*redisAddr) - configFilePath := "./config/jailconfig.yml" config := guardianConfig{ - jailConfigPath: configFilePath, + whitelist: []string{}, + blacklist: []string{}, + limitCountDeprecated: 100, + limitDurationDeprecated: time.Second, + limitEnabledDeprecated: false, + reportOnlyDeprecated: false, + routeRateLimitConfigPathDeprecated: "./config/TestSetRouteRateLimitsDeprecated/routeratelimitconfig.yml", + } + applyGuardianConfigDeprecated(t, *redisAddr, config) + getCmd := "get-route-rate-limits" + resStr := runGuardianCLI(t, *redisAddr, getCmd) + expectedResStr, err := ioutil.ReadFile(config.routeRateLimitConfigPathDeprecated) + + res := guardian.RouteRateLimitConfigDeprecated{} + expectedRes := guardian.RouteRateLimitConfigDeprecated{} + err = yaml.Unmarshal([]byte(resStr), &res) + if err != nil { + t.Fatalf("error unmarshaling result string: %v", err) + } + err = yaml.Unmarshal(expectedResStr, &expectedRes) + if err != nil { + t.Fatalf("error unmarshaling expected result string: %v", err) } - applyGuardianConfig(t, *redisAddr, config) - getCmd := "get-jails" + // Since the ordering of the slice returned from the cli can be different + // than the original config, we just want to verify that both configs contain + // the same entries in no particular order. + got := res.RouteRateLimits + expected := expectedRes.RouteRateLimits + if !routeRateLimitsEqual(got, expected) { + t.Fatalf("expected: %v, received: %v", expected, got) + } +} + +func TestRemoveRouteRateLimitsDeprecated(t *testing.T) { + resetRedis(*redisAddr) + config := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + limitCountDeprecated: 100, + limitDurationDeprecated: time.Second, + limitEnabledDeprecated: false, + reportOnlyDeprecated: false, + routeRateLimitConfigPathDeprecated: "./config/TestRemoveRouteRateLimitsDeprecated/routeratelimitconfig.yml", + } + applyGuardianConfigDeprecated(t, *redisAddr, config) + rmCmd := "remove-route-rate-limits" + runGuardianCLI(t, *redisAddr, rmCmd, "/foo/bar,/foo/baz") + + getCmd := "get-route-rate-limits" resStr := runGuardianCLI(t, *redisAddr, getCmd) - expectedResStr, err := ioutil.ReadFile(configFilePath) - res := guardian.JailConfig{} - expectedRes := guardian.JailConfig{} + res := guardian.RouteRateLimitConfigDeprecated{} + err := yaml.Unmarshal([]byte(resStr), &res) + if err != nil { + t.Fatalf("error unmarshaling result string: %v", err) + } + + if len(res.RouteRateLimits) != 0 { + t.Fatalf("expected route rate limits to be empty after removing them") + } +} + +func TestSetJailsDeprecated(t *testing.T) { + resetRedis(*redisAddr) + config := guardianConfig{ + whitelist: []string{}, + blacklist: []string{}, + limitCountDeprecated: 100, + limitDurationDeprecated: time.Second, + limitEnabledDeprecated: false, + reportOnlyDeprecated: false, + jailConfigPathDeprecated: "./config/TestSetJailsDeprecated/jailconfig.yml", + } + applyGuardianConfigDeprecated(t, *redisAddr, config) + + getCmd := "get-jails" + resStr := runGuardianCLI(t, *redisAddr, getCmd) + t.Logf("Got result: %v", resStr) + expectedResStr, err := ioutil.ReadFile(config.jailConfigPathDeprecated) + res := guardian.JailConfigDeprecated{} + expectedRes := guardian.JailConfigDeprecated{} err = yaml.Unmarshal([]byte(resStr), &res) if err != nil { t.Fatalf("error unmarshaling result string: %v", err) @@ -327,20 +548,26 @@ func TestSetJails(t *testing.T) { } } -func TestRemoveJail(t *testing.T) { +func TestRemoveJailDeprecated(t *testing.T) { resetRedis(*redisAddr) - configFilePath := "./config/jailconfig.yml" + config := guardianConfig{ - jailConfigPath: configFilePath, + whitelist: []string{}, + blacklist: []string{}, + limitCountDeprecated: 100, + limitDurationDeprecated: time.Second, + limitEnabledDeprecated: false, + reportOnlyDeprecated: false, + jailConfigPathDeprecated: "./config/TestRemoveJailDeprecated/jailconfig.yml", } - applyGuardianConfig(t, *redisAddr, config) + applyGuardianConfigDeprecated(t, *redisAddr, config) rmCmd := "remove-jails" runGuardianCLI(t, *redisAddr, rmCmd, "/foo/bar,/foo/baz") getCmd := "get-jails" resStr := runGuardianCLI(t, *redisAddr, getCmd) - res := guardian.JailConfig{} + res := guardian.JailConfigDeprecated{} err := yaml.Unmarshal([]byte(resStr), &res) if err != nil { t.Fatalf("error unmarshaling result string: %v", err) @@ -371,31 +598,58 @@ func GET(t *testing.T, sourceIP string, path string) *http.Response { return res } +type redisDBIndex struct { + sync.Mutex + Index int +} + +var currentRedisDBIndex = redisDBIndex{ + Index: 0, +} + func resetRedis(redisAddr string) { + currentRedisDBIndex.Lock() redisOpts := &redis.Options{ Addr: redisAddr, + DB: currentRedisDBIndex.Index, + } + currentRedisDBIndex.Index++ + maxDBIndex := 15 + if currentRedisDBIndex.Index > maxDBIndex { + currentRedisDBIndex.Index = 0 } + currentRedisDBIndex.Unlock() redis := redis.NewClient(redisOpts) redis.FlushAll() } type guardianConfig struct { - whitelist []string - blacklist []string - limitCount int - limitDuration time.Duration - limitEnabled bool - reportOnly bool - routeRateLimitConfigPath string - jailConfigPath string + whitelist []string + blacklist []string + globalRateLimitConfigPath string + globalSettingsConfigPath string + rateLimitConfigPath string + jailConfigPath string + // Fields associated with deprecated CLI + limitCountDeprecated int + limitDurationDeprecated time.Duration + limitEnabledDeprecated bool + reportOnlyDeprecated bool + routeRateLimitConfigPathDeprecated string + jailConfigPathDeprecated string } func applyGuardianConfig(t *testing.T, redisAddr string, c guardianConfig) { t.Helper() - runGuardianCLI(t, redisAddr, "set-limit", strconv.Itoa(c.limitCount), c.limitDuration.String(), strconv.FormatBool(c.limitEnabled)) - runGuardianCLI(t, redisAddr, "set-report-only", strconv.FormatBool(c.reportOnly)) + if len(c.globalRateLimitConfigPath) > 0 { + runGuardianCLI(t, redisAddr, "apply", c.globalRateLimitConfigPath) + } + + if len(c.globalSettingsConfigPath) > 0 { + runGuardianCLI(t, redisAddr, "apply", c.globalSettingsConfigPath) + } clearXList(t, redisAddr, "blacklist") clearXList(t, redisAddr, "whitelist") @@ -408,12 +662,48 @@ func applyGuardianConfig(t *testing.T, redisAddr string, c guardianConfig) { runGuardianCLI(t, redisAddr, "add-blacklist", strings.Join(c.blacklist, " ")) } - if len(c.routeRateLimitConfigPath) > 0 { - runGuardianCLI(t, redisAddr, "set-route-rate-limits", c.routeRateLimitConfigPath) + if len(c.rateLimitConfigPath) > 0 { + runGuardianCLI(t, redisAddr, "apply", c.rateLimitConfigPath) } if len(c.jailConfigPath) > 0 { - runGuardianCLI(t, redisAddr, "set-jails", c.jailConfigPath) + runGuardianCLI(t, redisAddr, "apply", c.jailConfigPath) + } + + time.Sleep(2 * time.Second) +} + +func applyGuardianConfigDeprecated(t *testing.T, redisAddr string, c guardianConfig) { + t.Helper() + + runGuardianCLI( + t, + redisAddr, + "set-limit", + strconv.Itoa(c.limitCountDeprecated), + c.limitDurationDeprecated.String(), + strconv.FormatBool(c.limitEnabledDeprecated), + ) + + runGuardianCLI(t, redisAddr, "set-report-only", strconv.FormatBool(c.reportOnlyDeprecated)) + + clearXList(t, redisAddr, "blacklist") + clearXList(t, redisAddr, "whitelist") + + if len(c.whitelist) > 0 { + runGuardianCLI(t, redisAddr, "add-whitelist", strings.Join(c.whitelist, " ")) + } + + if len(c.blacklist) > 0 { + runGuardianCLI(t, redisAddr, "add-blacklist", strings.Join(c.blacklist, " ")) + } + + if len(c.routeRateLimitConfigPathDeprecated) > 0 { + runGuardianCLI(t, redisAddr, "set-route-rate-limits", c.routeRateLimitConfigPathDeprecated) + } + + if len(c.jailConfigPathDeprecated) > 0 { + runGuardianCLI(t, redisAddr, "set-jails", c.jailConfigPathDeprecated) } time.Sleep(2 * time.Second) @@ -450,7 +740,7 @@ func runGuardianCLI(t *testing.T, redisAddr string, command string, args ...stri cmdArgs := append([]string{command, "-r", redisAddr}, args...) c := exec.Command(cliPath, cmdArgs...) - output, err := c.CombinedOutput() + output, err := c.Output() if err != nil { t.Fatalf("error running guardian-cli: %v %v", err, string(output)) } diff --git a/e2e/scripts/circleci-e2e-runner-docker.sh b/e2e/scripts/circleci-e2e-runner-docker.sh index d77dbe8..750dad1 100755 --- a/e2e/scripts/circleci-e2e-runner-docker.sh +++ b/e2e/scripts/circleci-e2e-runner-docker.sh @@ -3,4 +3,5 @@ cd /go/src/github.com/dollarshaveclub/guardian COMMIT="e2e" make cli -go test ./e2e -redis-addr="redis:6379" -envoy-addr="envoy:8080" \ No newline at end of file +# TODO: Refactor tests to be less brittle and not require sequential runs +go test ./e2e -p 1 -redis-addr="redis:6379" -envoy-addr="envoy:8080" diff --git a/e2e/scripts/run-e2e.sh b/e2e/scripts/run-e2e.sh index 73cf8d3..3cfec82 100755 --- a/e2e/scripts/run-e2e.sh +++ b/e2e/scripts/run-e2e.sh @@ -11,7 +11,7 @@ while [ $secs -ge 0 ]; do done printf "\nRunning tests...\n" -go test ./e2e/ +go test -p 1 ./e2e/ results=$? echo "Stopping async guardian..." diff --git a/pkg/guardian/config.go b/pkg/guardian/config.go new file mode 100644 index 0000000..06581dc --- /dev/null +++ b/pkg/guardian/config.go @@ -0,0 +1,117 @@ +package guardian + +// ConfigKind identifies a kind of configuration resource +type ConfigKind string + +// Config kinds +const ( + // GlobalRateLimitConfigKind identifies a Global Rate Limit config resource + GlobalRateLimitConfigKind = "GlobalRateLimit" + // RateLimitConfigKind identifies a Rate Limit config resource + RateLimitConfigKind = "RateLimit" + // JailConfigKind identifies a Jail config resource + JailConfigKind = "Jail" + // GlobalSettingsConfigKind identifies a Global Settings config resource + GlobalSettingsConfigKind = "GlobalSettings" +) + +// ConfigMetadata represents metadata associated with a configuration resource. +// Every configuration resource begins with this metadata. +type ConfigMetadata struct { + // Version identifies the version of the configuration resource format + Version string `yaml:"version" json:"version"` + // Kind identifies the kind of the configuration resource + Kind ConfigKind `yaml:"kind" json:"kind"` + // Name uniquely identifies a configuration resource within the resource's Kind + Name string `yaml:"name" json:"name"` + // Description is a description to add context to a resource + Description string `yaml:"description" json:"description"` +} + +// Conditions represents conditions required for a Limit to be applied to +// a Request. Currently, Guardian only filters requests based on URL path, +// via RedisConfStore.GetRouteRateLimit(url.URL) or .GetJail(url.URL) +type Conditions struct { + Path string `yaml:"path" json:"path"` +} + +// GlobalRateLimitSpec represents the specification for a GlobalRateLimitConfig +type GlobalRateLimitSpec struct { + Limit Limit `yaml:"limit" json:"limit"` +} + +// GlobalSettingsSpec represents the specification for a GlobalSettingsConfig +type GlobalSettingsSpec struct { + ReportOnly bool `yaml:"reportOnly" json:"reportOnly"` +} + +// RateLimitSpec represents the specification for a RateLimitConfig +type RateLimitSpec struct { + Limit Limit `yaml:"limit" json:"limit"` + Conditions Conditions `yaml:"conditions" json:"conditions"` +} + +// JailSpec represents the specification for a JailConfig +type JailSpec struct { + Jail `yaml:",inline" json:",inline"` + Conditions Conditions `yaml:"conditions" json:"conditions"` +} + +// GlobalRateLimitConfig represents a resource that configures the global rate limit +type GlobalRateLimitConfig struct { + ConfigMetadata `yaml:",inline" json:",inline"` + Spec GlobalRateLimitSpec `yaml:"globalRateLimitSpec" json:"globalRateLimitSpec"` +} + +// GlobalSettingsConfig represents a resource that configures global settings +type GlobalSettingsConfig struct { + ConfigMetadata `yaml:",inline" json:",inline"` + Spec GlobalSettingsSpec `yaml:"globalSettingsSpec" json:"globalSettingsSpec"` +} + +// RateLimitConfig represents a resource that configures a conditional rate limit +type RateLimitConfig struct { + ConfigMetadata `yaml:",inline" json:",inline"` + Spec RateLimitSpec `yaml:"rateLimitSpec" json:"rateLimitSpec"` +} + +// JailConfig represents a resource that configures a jail +type JailConfig struct { + ConfigMetadata `yaml:",inline" json:",inline"` + Spec JailSpec `yaml:"jailSpec" json:"jailSpec"` +} + +// Config represents a generic configuration file +type Config struct { + ConfigMetadata `yaml:",inline"` + GlobalRateLimitSpec *GlobalRateLimitSpec `yaml:"globalRateLimitSpec"` + GlobalSettingsSpec *GlobalSettingsSpec `yaml:"globalSettingsSpec"` + RateLimitSpec *RateLimitSpec `yaml:"rateLimitSpec"` + JailSpec *JailSpec `yaml:"jailSpec"` +} + +// JailConfigEntryDeprecated represents an entry in the jail configuration format +// associated with the deprecated CLI +type JailConfigEntryDeprecated struct { + Route string `yaml:"route"` + Jail Jail `yaml:"jail"` +} + +// JailConfigEntryDeprecated represents the jail configuration format associated +// with the deprecated CLI +type JailConfigDeprecated struct { + Jails []JailConfigEntryDeprecated `yaml:"jails"` +} + +// RouteRateLimitConfigEntryDeprecated represents an entry in the conditional +// rate limit configuration format associated with the deprecated CLI +type RouteRateLimitConfigEntryDeprecated struct { + Route string `yaml:"route"` + Limit Limit `yaml:"limit"` +} + +// RouteRateLimitConfigDeprecated represents the conditional rate limit +// configuration format associated with the deprecated CLI +type RouteRateLimitConfigDeprecated struct { + RouteRateLimits []RouteRateLimitConfigEntryDeprecated `yaml:"route_rate_limits"` +} diff --git a/pkg/guardian/guardian_test.go b/pkg/guardian/guardian_test.go index ddc21d1..d4e64e7 100644 --- a/pkg/guardian/guardian_test.go +++ b/pkg/guardian/guardian_test.go @@ -126,8 +126,36 @@ func TestBasicFunctionality(t *testing.T) { redisConfStore.AddWhitelistCidrs([]net.IPNet{ipStringToIPNet(t, whitelistedIP)}) redisConfStore.AddBlacklistCidrs([]net.IPNet{ipStringToIPNet(t, blacklistedIP)}) - redisConfStore.SetLimit(Limit{Count: 5, Duration: time.Minute, Enabled: true}) - redisConfStore.SetReportOnly(false) + + redisConfStore.ApplyGlobalRateLimitConfig( + GlobalRateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalRateLimitConfigVersion, + Kind: GlobalRateLimitConfigKind, + Name: GlobalRateLimitConfigKind, + }, + Spec: GlobalRateLimitSpec{ + Limit: Limit{ + Count: 5, + Duration: time.Minute, + Enabled: true, + }, + }, + }, + ) + + redisConfStore.ApplyGlobalSettingsConfig( + GlobalSettingsConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalSettingsConfigVersion, + Kind: GlobalSettingsConfigKind, + Name: GlobalSettingsConfigKind, + }, + Spec: GlobalSettingsSpec{ + ReportOnly: false, + }, + }, + ) time.Sleep(2 * time.Second) // let conf changes take effect diff --git a/pkg/guardian/jail.go b/pkg/guardian/jail.go index 187f671..3014bc7 100644 --- a/pkg/guardian/jail.go +++ b/pkg/guardian/jail.go @@ -3,10 +3,11 @@ package guardian import ( "context" "fmt" - "github.com/pkg/errors" - "github.com/sirupsen/logrus" "net/url" "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" ) // Jail is a namesake of a similar concept from fail2ban @@ -15,7 +16,7 @@ import ( // If the Limit is reached, the the IP will be banned for the BanDuration. type Jail struct { Limit Limit `yaml:"limit"" json:"limit"` - BanDuration time.Duration `yaml:"ban_duration" json:"ban_duration"` + BanDuration time.Duration `yaml:"banDuration" json:"banDuration"` } func (j Jail) String() string { diff --git a/pkg/guardian/redis_conf_store.go b/pkg/guardian/redis_conf_store.go index 65a925e..18ba6bf 100644 --- a/pkg/guardian/redis_conf_store.go +++ b/pkg/guardian/redis_conf_store.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/url" - "strconv" "sync" "time" @@ -16,21 +15,25 @@ import ( "github.com/sirupsen/logrus" ) -const redisIPWhitelistKey = "guardian_conf:whitelist" -const redisIPBlacklistKey = "guardian_conf:blacklist" -const redisPrisonersKey = "guardian_conf:prisoners" -const redisPrisonersLockKey = "guardian_conf:prisoners_lock" -const redisLimitCountKey = "guardian_conf:limit_count" -const redisLimitDurationKey = "guardian_conf:limit_duration" -const redisLimitEnabledKey = "guardian_conf:limit_enabled" -const redisReportOnlyKey = "guardian_conf:reportOnly" -const redisRouteRateLimitsEnabledKey = "guardian_conf:route_limits:enabled" -const redisRouteRateLimitsDurationKey = "guardian_conf:route_limits:duration" -const redisRouteRateLimitsCountKey = "guardian_conf:route_limits:count" -const redisJailLimitsEnabledKey = "guardian_conf:jail:limits:enabled" -const redisJailLimitsDurationKey = "guardian_conf:jail:limits:duration" -const redisJailLimitsCountKey = "guardian_conf:jail:limits:count" -const redisJailBanDurationKey = "guardian_conf:jail:ban:duration" +// Config file format versions +const ( + GlobalRateLimitConfigVersion = "v0" + GlobalSettingsConfigVersion = "v0" + RateLimitConfigVersion = "v0" + JailConfigVersion = "v0" +) + +// Redis keys +const ( + redisIPWhitelistKey = "guardian_conf:whitelist" + redisIPBlacklistKey = "guardian_conf:blacklist" + redisPrisonersKey = "guardian_conf:prisoners" + redisPrisonersLockKey = "guardian_conf:prisoners_lock" + redisGlobalRateLimitConfigKey = "guardian_conf:global_rate_limit" + redisGlobalSettingsConfigKey = "guardian_conf:global_settings" + redisRateLimitsConfigKey = "guardian_conf:rate_limits" + redisJailsConfigKey = "guardian_conf:jails" +) // NewRedisConfStore creates a new RedisConfStore func NewRedisConfStore(redis *redis.Client, defaultWhitelist []net.IPNet, defaultBlacklist []net.IPNet, defaultLimit Limit, defaultReportOnly, initConfig bool, maxPrisonerCacheSize uint16, logger logrus.FieldLogger, mr MetricReporter) (*RedisConfStore, error) { @@ -53,12 +56,34 @@ func NewRedisConfStore(redis *redis.Client, defaultWhitelist []net.IPNet, defaul } defaultConf := conf{ - whitelist: defaultWhitelist, - blacklist: defaultBlacklist, - limit: defaultLimit, - reportOnly: defaultReportOnly, - routeRateLimits: make(map[url.URL]Limit), - jails: make(map[url.URL]Jail), + whitelist: defaultWhitelist, + blacklist: defaultBlacklist, + globalRateLimit: GlobalRateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalRateLimitConfigVersion, + Kind: GlobalRateLimitConfigKind, + Name: GlobalRateLimitConfigKind, + Description: "", + }, + Spec: GlobalRateLimitSpec{ + Limit: defaultLimit, + }, + }, + globalSettings: GlobalSettingsConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalSettingsConfigVersion, + Kind: GlobalSettingsConfigKind, + Name: GlobalSettingsConfigKind, + Description: "", + }, + Spec: GlobalSettingsSpec{ + ReportOnly: defaultReportOnly, + }, + }, + rateLimitsByName: make(map[string]RateLimitConfig), + rateLimitsByPath: make(map[string]RateLimitConfig), + jailsByName: make(map[string]JailConfig), + jailsByPath: make(map[string]JailConfig), // There are a couple limitations in redis that require us to need to keep track of prisoners and manage when they should expire. // 1. Redis Hashmaps and Sets to not support setting TTL of individual keys // 2. We could just rely upon unique keys in the global keyspace and do prefix lookups with the KEYS function, however, these lookups would be O(N). @@ -86,13 +111,15 @@ type RedisConfStore struct { } type conf struct { - whitelist []net.IPNet - blacklist []net.IPNet - limit Limit - reportOnly bool - routeRateLimits map[url.URL]Limit - jails map[url.URL]Jail - prisoners *prisonersCache + whitelist []net.IPNet + blacklist []net.IPNet + prisoners *prisonersCache + globalSettings GlobalSettingsConfig + globalRateLimit GlobalRateLimitConfig + rateLimitsByName map[string]RateLimitConfig + rateLimitsByPath map[string]RateLimitConfig + jailsByName map[string]JailConfig + jailsByPath map[string]JailConfig } type lockingConf struct { @@ -100,28 +127,6 @@ type lockingConf struct { conf } -// JailConfigEntry models an entry in a config file for adding a jail -type JailConfigEntry struct { - Route string `yaml:"route"` - Jail Jail `yaml:"jail"` -} - -// JailConfig -type JailConfig struct { - Jails []JailConfigEntry `yaml:"jails"` -} - -// RouteRateLimitConfigEntry models an entry in a config file for adding route rate limits. -type RouteRateLimitConfigEntry struct { - Route string `yaml:"route"` - Limit Limit `yaml:"limit"` -} - -// RouteRateLimitConfig models the config file for adding route rate limits. -type RouteRateLimitConfig struct { - RouteRatelimits []RouteRateLimitConfigEntry `yaml:"route_rate_limits"` -} - func (rs *RedisConfStore) GetWhitelist() []net.IPNet { rs.conf.RLock() defer rs.conf.RUnlock() @@ -287,229 +292,148 @@ func (rs *RedisConfStore) RemoveBlacklistCidrs(cidrs []net.IPNet) error { return nil } -// SetJails configures all of the jails -func (rs *RedisConfStore) SetJails(jails map[url.URL]Jail) error { - for url, jail := range jails { - route := url.EscapedPath() - limitCountStr := strconv.FormatUint(jail.Limit.Count, 10) - limitDurationStr := jail.Limit.Duration.String() - limitEnabledStr := strconv.FormatBool(jail.Limit.Enabled) - jailBanDuration := jail.BanDuration.String() - - pipe := rs.redis.TxPipeline() - pipe.HSet(redisJailLimitsCountKey, route, limitCountStr) - pipe.HSet(redisJailLimitsDurationKey, route, limitDurationStr) - pipe.HSet(redisJailLimitsEnabledKey, route, limitEnabledStr) - pipe.HSet(redisJailBanDurationKey, route, jailBanDuration) - _, err := pipe.Exec() - if err != nil { - return err - } +func (rs *RedisConfStore) ApplyJailConfig(cfg JailConfig) error { + b, err := json.Marshal(cfg) + if err != nil { + return fmt.Errorf("error marshaling json: %v", err) } - return nil -} + pipe := rs.redis.TxPipeline() + pipe.HSet(redisJailsConfigKey, cfg.Name, string(b)) -func (rs *RedisConfStore) RemoveJails(urls []url.URL) error { - for _, url := range urls { - route := url.EscapedPath() - rs.logger.Debugf("Sending HDel for key %v field %v", redisJailLimitsCountKey, route) - rs.logger.Debugf("Sending HDel for key %v field %v", redisJailLimitsDurationKey, route) - rs.logger.Debugf("Sending HDel for key %v field %v", redisJailLimitsEnabledKey, route) - rs.logger.Debugf("Sending HDel for key %v field %v", redisJailBanDurationKey, route) - - pipe := rs.redis.TxPipeline() - pipe.HDel(redisJailLimitsCountKey, route) - pipe.HDel(redisJailLimitsDurationKey, route) - pipe.HDel(redisJailLimitsEnabledKey, route) - pipe.HDel(redisJailBanDurationKey, route) - _, err := pipe.Exec() - if err != nil { - return err - } + _, err = pipe.Exec() + if err != nil { + return err } return nil } +func (rs *RedisConfStore) DeleteJailConfig(name string) error { + rs.logger.Debugf("Sending HDel for key %v field %v", redisJailsConfigKey, name) + return rs.redis.HDel(redisJailsConfigKey, name).Err() +} + func (rs *RedisConfStore) GetJail(url url.URL) Jail { rs.conf.RLock() defer rs.conf.RUnlock() - jail, _ := rs.conf.jails[url] - return jail + conf, ok := rs.conf.jailsByPath[url.Path] + if !ok { + return Jail{} + } + return conf.Spec.Jail } -func (rs *RedisConfStore) FetchJail(url url.URL) (Jail, error) { +func (rs *RedisConfStore) FetchJailConfigs() []JailConfig { c := rs.pipelinedFetchConf() - count, _ := c.jailLimitCounts[url] - duration, _ := c.jailLimitDurations[url] - enabled, _ := c.jailLimitsEnabled[url] - banDuration, _ := c.jailBanDuration[url] - if count == nil || duration == nil || enabled == nil || banDuration == nil { - return Jail{}, fmt.Errorf("jail not found") - } - return Jail{ - Limit: Limit{ - Count: *count, - Duration: *duration, - Enabled: *enabled, - }, - BanDuration: *banDuration, - }, nil + cfgs := make([]JailConfig, 0, len(c.jailsByName)) + for _, cfg := range c.jailsByName { + cfgs = append(cfgs, cfg) + } + return cfgs } -func (rs *RedisConfStore) FetchJails() (map[url.URL]Jail, error) { +func (rs *RedisConfStore) FetchJailConfig(name string) (JailConfig, error) { c := rs.pipelinedFetchConf() - jails := make(map[url.URL]Jail) - - for url, count := range c.jailLimitCounts { - duration, _ := c.jailLimitDurations[url] - enabled, _ := c.jailLimitsEnabled[url] - banDuration, _ := c.jailBanDuration[url] - if count != nil && duration != nil && enabled != nil && banDuration != nil { - jails[url] = Jail{ - Limit: Limit{ - Duration: *duration, - Enabled: *enabled, - Count: *count, - }, - BanDuration: *banDuration, - } - } + cfg, ok := c.jailsByName[name] + if !ok { + return JailConfig{}, fmt.Errorf("no jail exists with name %v", name) } - return jails, nil + return cfg, nil } -// SetRouteRateLimits set the limit for each route. -// If the route limit is already defined in the store, it will be overwritten. -func (rs *RedisConfStore) SetRouteRateLimits(routeRateLimits map[url.URL]Limit) error { - for url, limit := range routeRateLimits { - route := url.EscapedPath() - limitCountStr := strconv.FormatUint(limit.Count, 10) - limitDurationStr := limit.Duration.String() - limitEnabledStr := strconv.FormatBool(limit.Enabled) - - pipe := rs.redis.TxPipeline() - pipe.HSet(redisRouteRateLimitsCountKey, route, limitCountStr) - pipe.HSet(redisRouteRateLimitsDurationKey, route, limitDurationStr) - pipe.HSet(redisRouteRateLimitsEnabledKey, route, limitEnabledStr) - _, err := pipe.Exec() - if err != nil { - return err - } +func (rs *RedisConfStore) ApplyRateLimitConfig(cfg RateLimitConfig) error { + b, err := json.Marshal(cfg) + if err != nil { + return fmt.Errorf("error marshaling json: %v", err) } - return nil + pipe := rs.redis.TxPipeline() + pipe.HSet(redisRateLimitsConfigKey, cfg.Name, string(b)) + _, err = pipe.Exec() + return err } -// RemoveRouteRateLimits will iterate through the slice and delete the limit for each route. -// If one of the routes has not be set in the conf store, it will be treated by Redis as an empty hash and it will effectively be a no-op. -// This function will continue to iterate through the slice and delete the remaining routes contained in the slice. -func (rs *RedisConfStore) RemoveRouteRateLimits(urls []url.URL) error { - for _, url := range urls { - route := url.EscapedPath() - rs.logger.Debugf("Sending HDel for key %v field %v", redisRouteRateLimitsCountKey, route) - rs.logger.Debugf("Sending HDel for key %v field %v", redisRouteRateLimitsDurationKey, route) - rs.logger.Debugf("Sending HDel for key %v field %v", redisRouteRateLimitsEnabledKey, route) - - pipe := rs.redis.TxPipeline() - pipe.HDel(redisRouteRateLimitsCountKey, route) - pipe.HDel(redisRouteRateLimitsDurationKey, route) - pipe.HDel(redisRouteRateLimitsEnabledKey, route) - _, err := pipe.Exec() - if err != nil { - return err - } - } - return nil +func (rs *RedisConfStore) DeleteRateLimitConfig(name string) error { + rs.logger.Debugf("Sending HDel for key %v field %v", redisJailsConfigKey, name) + return rs.redis.HDel(redisRateLimitsConfigKey, name).Err() } // GetRouteRateLimit gets a route limit from the local cache. func (rs *RedisConfStore) GetRouteRateLimit(url url.URL) Limit { rs.conf.RLock() defer rs.conf.RUnlock() - limit, _ := rs.conf.routeRateLimits[url] - return limit + conf, ok := rs.conf.rateLimitsByPath[url.Path] + if !ok { + return Limit{} + } + return conf.Spec.Limit } -// FetchRouteRateLimit fetches the route limit from the conf store. -func (rs *RedisConfStore) FetchRouteRateLimit(url url.URL) (Limit, error) { +func (rs *RedisConfStore) FetchRateLimitConfigs() []RateLimitConfig { c := rs.pipelinedFetchConf() - count, _ := c.routeLimitCounts[url] - duration, _ := c.routeLimitDurations[url] - enabled, _ := c.routeRateLimitsEnabled[url] - if count == nil || duration == nil || enabled == nil { - return Limit{}, fmt.Errorf("route limit not found") + cfgs := make([]RateLimitConfig, 0, len(c.rateLimitsByName)) + for _, cfg := range c.rateLimitsByName { + cfgs = append(cfgs, cfg) } - return Limit{Count: *count, Duration: *duration, Enabled: *enabled}, nil + return cfgs } -// FetchRouteRateLimits fetches all of the route rate limits from the conf store. -func (rs *RedisConfStore) FetchRouteRateLimits() (map[url.URL]Limit, error) { +func (rs *RedisConfStore) FetchRateLimitConfig(name string) (RateLimitConfig, error) { c := rs.pipelinedFetchConf() - res := make(map[url.URL]Limit) - - for url, count := range c.routeLimitCounts { - duration, _ := c.routeLimitDurations[url] - enabled, _ := c.routeRateLimitsEnabled[url] - if duration != nil && enabled != nil && count != nil { - res[url] = Limit{ - Count: *count, - Duration: *duration, - Enabled: *enabled, - } - } + cfg, ok := c.rateLimitsByName[name] + if !ok { + return RateLimitConfig{}, fmt.Errorf("no rate limit exists with name %v", name) } - return res, nil + return cfg, nil +} + +func (rs *RedisConfStore) ApplyGlobalRateLimitConfig(cfg GlobalRateLimitConfig) error { + b, err := json.Marshal(cfg) + if err != nil { + return fmt.Errorf("error marshaling json: %v", err) + } + pipe := rs.redis.TxPipeline() + pipe.Set(redisGlobalRateLimitConfigKey, string(b), 0) + _, err = pipe.Exec() + return err } func (rs *RedisConfStore) GetLimit() Limit { rs.conf.RLock() defer rs.conf.RUnlock() - - return rs.conf.limit + return rs.conf.globalRateLimit.Spec.Limit } -func (rs *RedisConfStore) FetchLimit() (Limit, error) { +func (rs *RedisConfStore) FetchGlobalRateLimitConfig() (GlobalRateLimitConfig, error) { c := rs.pipelinedFetchConf() - if c.limitCount == nil || c.limitDuration == nil || c.limitEnabled == nil { - return Limit{}, fmt.Errorf("error fetching limit") + if c.globalRateLimit == nil { + return GlobalRateLimitConfig{}, fmt.Errorf("error fetching global rate limit config") } - - return Limit{Count: *c.limitCount, Duration: *c.limitDuration, Enabled: *c.limitEnabled}, nil + return *c.globalRateLimit, nil } -func (rs *RedisConfStore) SetLimit(limit Limit) error { - limitCountStr := strconv.FormatUint(limit.Count, 10) - limitDurationStr := limit.Duration.String() - limitEnabledStr := strconv.FormatBool(limit.Enabled) - +func (rs *RedisConfStore) ApplyGlobalSettingsConfig(cfg GlobalSettingsConfig) error { + b, err := json.Marshal(cfg) + if err != nil { + return fmt.Errorf("error marshaling json: %v", err) + } pipe := rs.redis.TxPipeline() - pipe.Set(redisLimitCountKey, limitCountStr, 0) - pipe.Set(redisLimitDurationKey, limitDurationStr, 0) - pipe.Set(redisLimitEnabledKey, limitEnabledStr, 0) - - _, err := pipe.Exec() + pipe.Set(redisGlobalSettingsConfigKey, string(b), 0) + _, err = pipe.Exec() return err } func (rs *RedisConfStore) GetReportOnly() bool { rs.conf.RLock() defer rs.conf.RUnlock() - - return rs.conf.reportOnly + return rs.conf.globalSettings.Spec.ReportOnly } -func (rs *RedisConfStore) FetchReportOnly() (bool, error) { +func (rs *RedisConfStore) FetchGlobalSettingsConfig() (GlobalSettingsConfig, error) { c := rs.pipelinedFetchConf() - if c.reportOnly == nil { - return false, fmt.Errorf("error fetching report only flag") + if c.globalSettings == nil { + return GlobalSettingsConfig{}, fmt.Errorf("error fetching global settings config") } - - return *c.reportOnly, nil -} - -func (rs *RedisConfStore) SetReportOnly(reportOnly bool) error { - reportOnlyStr := strconv.FormatBool(reportOnly) - return rs.redis.Set(redisReportOnlyKey, reportOnlyStr, 0).Err() + return *c.globalSettings, nil } func (rs *RedisConfStore) RunSync(updateInterval time.Duration, stop <-chan struct{}) { @@ -545,31 +469,33 @@ func (rs *RedisConfStore) init() error { } } - if rs.redis.Get(redisLimitEnabledKey).Err() == redis.Nil || - rs.redis.Get(redisLimitDurationKey).Err() == redis.Nil || - rs.redis.Get(redisLimitCountKey).Err() == redis.Nil { - rs.logger.Debug("Initializing limit") - if err := rs.SetLimit(rs.conf.limit); err != nil { - return errors.Wrap(err, "error initializing limit") + if rs.redis.Get(redisGlobalRateLimitConfigKey).Err() == redis.Nil { + if err := rs.ApplyGlobalRateLimitConfig(rs.conf.globalRateLimit); err != nil { + return errors.Wrap(err, "error initializing global settings") } } - if rs.redis.Get(redisReportOnlyKey).Err() == redis.Nil { - rs.logger.Debug("Initializing report only") - if err := rs.SetReportOnly(rs.conf.reportOnly); err != nil { - return errors.Wrap(err, "error initializing report only") + if rs.redis.Get(redisGlobalSettingsConfigKey).Err() == redis.Nil { + if err := rs.ApplyGlobalSettingsConfig(rs.conf.globalSettings); err != nil { + return errors.Wrap(err, "error initializing global settings") } } - if rs.redis.Get(redisRouteRateLimitsEnabledKey).Err() == redis.Nil || - rs.redis.Get(redisRouteRateLimitsCountKey).Err() == redis.Nil || - rs.redis.Get(redisRouteRateLimitsDurationKey).Err() == redis.Nil { - rs.logger.Debug("Initializing route rate limits") - if err := rs.SetRouteRateLimits(rs.conf.routeRateLimits); err != nil { - return errors.Wrap(err, "error initializing route rate limits") + if rs.redis.Get(redisRateLimitsConfigKey).Err() == redis.Nil { + for _, config := range rs.conf.rateLimitsByName { + if err := rs.ApplyRateLimitConfig(config); err != nil { + return errors.Wrap(err, "error initializing rate limit") + } } } + if rs.redis.Get(redisJailsConfigKey).Err() == redis.Nil { + for _, config := range rs.conf.jailsByName { + if err := rs.ApplyJailConfig(config); err != nil { + return errors.Wrap(err, "error initializing jail") + } + } + } rs.logger.Debug("Success initializing conf") return nil } @@ -592,74 +518,51 @@ func (rs *RedisConfStore) UpdateCachedConf() { rs.conf.blacklist = fetched.blacklist } - if fetched.limitCount != nil && - fetched.limitDuration != nil && - fetched.limitEnabled != nil { - rs.conf.limit.Count = *fetched.limitCount - rs.conf.limit.Duration = *fetched.limitDuration - rs.conf.limit.Enabled = *fetched.limitEnabled + if fetched.globalRateLimit != nil { + rs.conf.globalRateLimit = *fetched.globalRateLimit + rs.reporter.CurrentGlobalLimit(rs.conf.globalRateLimit.Spec.Limit) } - if fetched.reportOnly != nil { - rs.conf.reportOnly = *fetched.reportOnly + if fetched.globalSettings != nil { + rs.conf.globalSettings = *fetched.globalSettings + rs.reporter.CurrentReportOnlyMode(rs.conf.globalSettings.Spec.ReportOnly) } - rs.conf.routeRateLimits = make(map[url.URL]Limit, len(fetched.routeLimitCounts)) - for url, count := range fetched.routeLimitCounts { - duration, _ := fetched.routeLimitDurations[url] - enabled, _ := fetched.routeRateLimitsEnabled[url] - if duration != nil && enabled != nil && count != nil { - l := Limit{ - Count: *count, - Duration: *duration, - Enabled: *enabled, - } - rs.conf.routeRateLimits[url] = l - rs.reporter.CurrentRouteLimit(url.Path, l) - } + rs.conf.rateLimitsByName = make(map[string]RateLimitConfig) + rs.conf.rateLimitsByPath = make(map[string]RateLimitConfig) + + for name, config := range fetched.rateLimitsByName { + rs.conf.rateLimitsByName[name] = config + path := config.Spec.Conditions.Path + rs.conf.rateLimitsByPath[path] = config + rs.reporter.CurrentRouteLimit(path, config.Spec.Limit) } - rs.conf.jails = make(map[url.URL]Jail, len(fetched.jailLimitCounts)) - for url, count := range fetched.jailLimitCounts { - duration, _ := fetched.jailLimitDurations[url] - enabled, _ := fetched.jailLimitsEnabled[url] - banDuration, _ := fetched.jailBanDuration[url] - if duration != nil && enabled != nil && count != nil && banDuration != nil { - j := Jail{ - Limit: Limit{ - Count: *count, - Duration: *duration, - Enabled: *enabled, - }, - BanDuration: *banDuration, - } - rs.conf.jails[url] = j - rs.reporter.CurrentRouteJail(url.Path, j) - } + rs.conf.jailsByName = make(map[string]JailConfig, len(fetched.jailsByName)) + rs.conf.jailsByPath = make(map[string]JailConfig, len(fetched.jailsByName)) + + for name, config := range fetched.jailsByName { + rs.conf.jailsByName[name] = config + path := config.Spec.Conditions.Path + rs.conf.jailsByPath[path] = config + rs.reporter.CurrentRouteJail(path, config.Spec.Jail) } - rs.reporter.CurrentGlobalLimit(rs.conf.limit) rs.reporter.CurrentWhitelist(rs.conf.whitelist) rs.reporter.CurrentBlacklist(rs.conf.blacklist) - rs.reporter.CurrentReportOnlyMode(rs.conf.reportOnly) rs.reporter.CurrentPrisoners(rs.conf.prisoners.length()) rs.logger.Debug("Updated conf") } type fetchConf struct { - whitelist []net.IPNet - blacklist []net.IPNet - limitCount *uint64 - limitDuration *time.Duration - limitEnabled *bool - reportOnly *bool - routeLimitDurations map[url.URL]*time.Duration - routeLimitCounts map[url.URL]*uint64 - routeRateLimitsEnabled map[url.URL]*bool - jailLimitDurations map[url.URL]*time.Duration - jailLimitCounts map[url.URL]*uint64 - jailLimitsEnabled map[url.URL]*bool - jailBanDuration map[url.URL]*time.Duration + whitelist []net.IPNet + blacklist []net.IPNet + globalRateLimit *GlobalRateLimitConfig + globalSettings *GlobalSettingsConfig + rateLimitsByName map[string]RateLimitConfig + rateLimitsByPath map[string]RateLimitConfig + jailsByName map[string]JailConfig + jailsByPath map[string]JailConfig } func (rs *RedisConfStore) obtainRedisPrisonersLock() (*redislock.Lock, error) { @@ -682,238 +585,204 @@ func (rs *RedisConfStore) releaseRedisPrisonersLock(lock *redislock.Lock, start rs.reporter.RedisReleaseLock(time.Since(start), err != nil) } -func (rs *RedisConfStore) pipelinedFetchConf() fetchConf { - newConf := fetchConf{ - routeLimitDurations: make(map[url.URL]*time.Duration), - routeLimitCounts: make(map[url.URL]*uint64), - routeRateLimitsEnabled: make(map[url.URL]*bool), - jailLimitDurations: make(map[url.URL]*time.Duration), - jailLimitCounts: make(map[url.URL]*uint64), - jailLimitsEnabled: make(map[url.URL]*bool), - jailBanDuration: make(map[url.URL]*time.Duration), +func (fc *fetchConf) setWhitelist(cmd *redis.StringSliceCmd, logger logrus.FieldLogger) { + whitelistStrs, err := cmd.Result() + if err != nil { + logger.WithError(err).Warnf("error send HKEYS for key %v", redisIPWhitelistKey) + return } + fc.whitelist = IPNetsFromStrings(whitelistStrs, logger) +} - rs.logger.Debugf("Sending HKEYS for key %v", redisIPWhitelistKey) - rs.logger.Debugf("Sending HKEYS for key %v", redisIPBlacklistKey) - rs.logger.Debugf("Sending GET for key %v", redisLimitCountKey) - rs.logger.Debugf("Sending GET for key %v", redisLimitDurationKey) - rs.logger.Debugf("Sending GET for key %v", redisLimitEnabledKey) - rs.logger.Debugf("Sending GET for key %v", redisReportOnlyKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisRouteRateLimitsCountKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisRouteRateLimitsDurationKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisRouteRateLimitsEnabledKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisJailLimitsCountKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisJailLimitsDurationKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisJailLimitsEnabledKey) - rs.logger.Debugf("Sending HGETALL for key %v", redisJailBanDurationKey) - - pipe := rs.redis.Pipeline() - defer pipe.Close() - whitelistKeysCmd := pipe.HKeys(redisIPWhitelistKey) - blacklistKeysCmd := pipe.HKeys(redisIPBlacklistKey) - limitCountCmd := pipe.Get(redisLimitCountKey) - limitDurationCmd := pipe.Get(redisLimitDurationKey) - limitEnabledCmd := pipe.Get(redisLimitEnabledKey) - reportOnlyCmd := pipe.Get(redisReportOnlyKey) - routeRateLimitsCountCmd := pipe.HGetAll(redisRouteRateLimitsCountKey) - routeRateLimitsDurationCmd := pipe.HGetAll(redisRouteRateLimitsDurationKey) - routeRateLimitsEnabledCmd := pipe.HGetAll(redisRouteRateLimitsEnabledKey) - jailLimitCountCmd := pipe.HGetAll(redisJailLimitsCountKey) - jailLimitDurationCmd := pipe.HGetAll(redisJailLimitsDurationKey) - jailLimitEnabledCmd := pipe.HGetAll(redisJailLimitsEnabledKey) - jailBanDurationCmd := pipe.HGetAll(redisJailBanDurationKey) - - pipe.Exec() - - if whitelistStrs, err := whitelistKeysCmd.Result(); err == nil { - newConf.whitelist = IPNetsFromStrings(whitelistStrs, rs.logger) - } else { - rs.logger.WithError(err).Warnf("error send HKEYS for key %v", redisIPWhitelistKey) +func (fc *fetchConf) setBlacklist(cmd *redis.StringSliceCmd, logger logrus.FieldLogger) { + blacklistStrs, err := cmd.Result() + if err != nil { + logger.WithError(err).Warnf("error send HKEYS for key %v", redisIPWhitelistKey) + return } + fc.blacklist = IPNetsFromStrings(blacklistStrs, logger) +} - if blacklistStrs, err := blacklistKeysCmd.Result(); err == nil { - newConf.blacklist = IPNetsFromStrings(blacklistStrs, rs.logger) - } else { - rs.logger.WithError(err).Warnf("error send HKEYS for key %v", redisIPWhitelistKey) +func (fc *fetchConf) setGlobalRateLimit(cmd *redis.StringCmd, logger logrus.FieldLogger) { + b, err := cmd.Bytes() + if err != nil { + if err != redis.Nil { + logger.WithError(err).Warnf("error sending GET for key %v", redisGlobalRateLimitConfigKey) + } + return } - - if limitCount, err := limitCountCmd.Uint64(); err == nil { - newConf.limitCount = &limitCount - } else { - rs.logger.WithError(err).Warnf("error sending GET for key %v", redisLimitCountKey) + cfg := GlobalRateLimitConfig{} + if err := json.Unmarshal(b, &cfg); err != nil { + logger.WithError(err).Warnf("error unmarshaling json for key %v", redisGlobalRateLimitConfigKey) + return } - - if limitDurationStr, err := limitDurationCmd.Result(); err == nil { - limitDuration, err := time.ParseDuration(limitDurationStr) - if err != nil { - rs.logger.WithError(err).Warnf("error parsing limit duration") - } else { - newConf.limitDuration = &limitDuration - } - } else { - rs.logger.WithError(err).Errorf("error sending GET for key %v", redisLimitDurationKey) + if cfg.Version != GlobalRateLimitConfigVersion { + logger.Warnf( + "stored global rate limit config version %v does not match current version %v; skipping", + cfg.Version, + GlobalRateLimitConfigVersion, + ) + return } + fc.globalRateLimit = &cfg +} - if limitEnabledStr, err := limitEnabledCmd.Result(); err == nil { - limitEnabled, err := strconv.ParseBool(limitEnabledStr) - if err != nil { - rs.logger.WithError(err).Warnf("error parsing limit enabled") - } else { - newConf.limitEnabled = &limitEnabled +func (fc *fetchConf) setGlobalSettings(cmd *redis.StringCmd, logger logrus.FieldLogger) { + b, err := cmd.Bytes() + if err != nil { + if err != redis.Nil { + logger.WithError(err).Warnf("error sending GET for key %v", redisGlobalSettingsConfigKey) } - } else { - rs.logger.WithError(err).Errorf("error sending GET for key %v", redisLimitEnabledKey) + return + } + cfg := GlobalSettingsConfig{} + if err := json.Unmarshal(b, &cfg); err != nil { + logger.WithError(err).Warnf("error unmarshaling json for key %v", redisGlobalSettingsConfigKey) + return + } + if cfg.Version != GlobalSettingsConfigVersion { + logger.Warnf( + "stored global settings config version %v does not match current version %v; skipping", + cfg.Version, + GlobalSettingsConfigVersion, + ) + return } + fc.globalSettings = &cfg +} - if reportOnlyStr, err := reportOnlyCmd.Result(); err == nil { - reportOnly, err := strconv.ParseBool(reportOnlyStr) - if err != nil { - rs.logger.WithError(err).Warnf("error parsing report only") - } else { - newConf.reportOnly = &reportOnly - } - } else { - rs.logger.WithError(err).Warnf("error sending GET for key %v", redisReportOnlyKey) - } - - if routeRateLimitsCounts, err := routeRateLimitsCountCmd.Result(); err == nil { - for route, countStr := range routeRateLimitsCounts { - parsedURL, urlParseErr := url.Parse(route) - count, intParseErr := strconv.ParseUint(countStr, 10, 64) - if urlParseErr != nil || intParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(intParseErr).Warnf("error parsing route limit duration for %v", route) - } else { - newConf.routeLimitCounts[*parsedURL] = &count - } - } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisRouteRateLimitsCountKey) - } - - if routeRateLimitsDurations, err := routeRateLimitsDurationCmd.Result(); err == nil { - for route, durationStr := range routeRateLimitsDurations { - parsedURL, urlParseErr := url.Parse(route) - duration, durationParseErr := time.ParseDuration(durationStr) - if urlParseErr != nil || durationParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(durationParseErr).Warnf("error parsing route limit duration for %v", route) - } else { - newConf.routeLimitDurations[*parsedURL] = &duration - } +func (fc *fetchConf) setRateLimits(cmd *redis.StringStringMapCmd, logger logrus.FieldLogger) { + rateLimits, err := cmd.Result() + if err != nil { + if err != redis.Nil { + logger.WithError(err).Warnf("error sending GET for key %v", redisRateLimitsConfigKey) } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisRouteRateLimitsDurationKey) - } - - if routeRateLimitsEnabled, err := routeRateLimitsEnabledCmd.Result(); err == nil { - for route, enabled := range routeRateLimitsEnabled { - parsedURL, urlParseErr := url.Parse(route) - enabled, boolParseErr := strconv.ParseBool(enabled) - if urlParseErr != nil || boolParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(boolParseErr).Warnf("error parsing route limit enabled for %v", route) - } else { - newConf.routeRateLimitsEnabled[*parsedURL] = &enabled - } + return + } + for name, configString := range rateLimits { + cfg := RateLimitConfig{} + if err := json.Unmarshal([]byte(configString), &cfg); err != nil { + logger.WithError(err).Warnf("error unmarshaling json for key %v value %v", redisRateLimitsConfigKey, name) + continue } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisRouteRateLimitsEnabledKey) - } - - if jailLimitCounts, err := jailLimitCountCmd.Result(); err == nil { - for route, countStr := range jailLimitCounts { - parsedURL, urlParseErr := url.Parse(route) - count, intParseErr := strconv.ParseUint(countStr, 10, 64) - if urlParseErr != nil || intParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(intParseErr).Warnf("error parsing jail limit count for %v", route) - } else { - newConf.jailLimitCounts[*parsedURL] = &count - } + if cfg.Version != RateLimitConfigVersion { + logger.Warnf( + "stored rate limit config version %v does not match current version %v; skipping", + cfg.Version, + RateLimitConfigVersion, + ) + continue } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisJailLimitsCountKey) - } - - if jailLimitDurations, err := jailLimitDurationCmd.Result(); err == nil { - for route, durationStr := range jailLimitDurations { - parsedURL, urlParseErr := url.Parse(route) - duration, durationParseErr := time.ParseDuration(durationStr) - if urlParseErr != nil || durationParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(durationParseErr).Warnf("error parsing jail limit duration for %v", route) - } else { - newConf.jailLimitDurations[*parsedURL] = &duration - } + fc.rateLimitsByName[name] = cfg + fc.rateLimitsByPath[cfg.Spec.Conditions.Path] = cfg + } +} + +func (fc *fetchConf) setJails(cmd *redis.StringStringMapCmd, logger logrus.FieldLogger) { + jails, err := cmd.Result() + if err != nil { + if err != redis.Nil { + logger.WithError(err).Warnf("error sending GET for key %v", redisJailsConfigKey) } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisJailLimitsDurationKey) - } - - if jailLimitsEnabled, err := jailLimitEnabledCmd.Result(); err == nil { - for route, enabled := range jailLimitsEnabled { - parsedURL, urlParseErr := url.Parse(route) - enabled, boolParseErr := strconv.ParseBool(enabled) - if urlParseErr != nil || boolParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(boolParseErr).Warnf("error parsing route limit enabled for %v", route) - } else { - newConf.jailLimitsEnabled[*parsedURL] = &enabled - } + return + } + for name, configString := range jails { + cfg := JailConfig{} + if err := json.Unmarshal([]byte(configString), &cfg); err != nil { + logger.WithError(err).Warnf("error unmarshaling json for key %v value %v", redisJailsConfigKey, name) + continue } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisJailLimitsEnabledKey) - } - - if jailBanDurations, err := jailBanDurationCmd.Result(); err == nil { - for route, durationStr := range jailBanDurations { - parsedURL, urlParseErr := url.Parse(route) - duration, durationParseErr := time.ParseDuration(durationStr) - if urlParseErr != nil || durationParseErr != nil { - rs.logger.WithError(urlParseErr).WithError(durationParseErr).Warnf("error parsing jail ban duration for %v", route) - } else { - newConf.jailBanDuration[*parsedURL] = &duration - } + if cfg.Version != JailConfigVersion { + logger.Warnf( + "stored jail config version %v does not match current version %v; skipping", + cfg.Version, + JailConfigVersion, + ) + continue } - } else { - rs.logger.WithError(err).Warnf("error sending HGETALL for key %v", redisJailLimitsEnabledKey) + fc.jailsByName[name] = cfg + fc.jailsByPath[cfg.Spec.Conditions.Path] = cfg } +} +func (rs *RedisConfStore) fetchPrisoners() { lock, err := rs.obtainRedisPrisonersLock() + if err != nil { rs.logger.Errorf("error obtaining lock in pipelined fetch: %v", err) - return newConf + return } defer rs.releaseRedisPrisonersLock(lock, time.Now().UTC()) expiredPrisoners := []string{} - prisonersCmd := rs.redis.HGetAll(redisPrisonersKey) + cmd := rs.redis.HGetAll(redisPrisonersKey) // Note: In order to match the rest of the configuration, we purposely purge the prisoners regardless of whether we // can connect or update the data in Redis. This way, Guardian continues to "fail open" rs.conf.prisoners.purge() - if prisoners, err := prisonersCmd.Result(); err == nil { - for ip, prisonerJson := range prisoners { - var prisoner Prisoner - err := json.Unmarshal([]byte(prisonerJson), &prisoner) - if err != nil { - rs.logger.WithError(err).Warnf("unable to unmarshal json: %v", err) - continue - } - if time.Now().UTC().Before(prisoner.Expiry) { - rs.logger.Debugf("adding %v to prisoners\n", prisoner.IP.String()) - rs.conf.prisoners.addPrisonerFromStore(prisoner) - } else { - rs.logger.Debugf("removing %v from prisoners\n", prisoner.IP.String()) - expiredPrisoners = append(expiredPrisoners, ip) - } - } - } else { + prisoners, err := cmd.Result() + if err != nil { rs.logger.Errorf("error getting prisoners from redis: %v", err) + return } - - if len(expiredPrisoners) > 0 { - removeExpiredPrisonersCmd := rs.redis.HDel(redisPrisonersKey, expiredPrisoners...) - if n, err := removeExpiredPrisonersCmd.Result(); err == nil { - rs.logger.Debugf("removed %d expired prisoners: %v", n, expiredPrisoners) - } else { - rs.logger.Errorf("error removing expired prisoners: %v", err) + for ip, prisonerJSON := range prisoners { + var prisoner Prisoner + err := json.Unmarshal([]byte(prisonerJSON), &prisoner) + if err != nil { + rs.logger.WithError(err).Warnf("unable to unmarshal json: %v", err) + continue + } + if time.Now().UTC().Before(prisoner.Expiry) { + rs.logger.Debugf("adding %v to prisoners\n", prisoner.IP.String()) + rs.conf.prisoners.addPrisonerFromStore(prisoner) + continue } + rs.logger.Debugf("removing %v from prisoners\n", prisoner.IP.String()) + expiredPrisoners = append(expiredPrisoners, ip) } + if len(expiredPrisoners) == 0 { + return + } + removeExpiredPrisonersCmd := rs.redis.HDel(redisPrisonersKey, expiredPrisoners...) + n, err := removeExpiredPrisonersCmd.Result() + if err != nil { + rs.logger.Errorf("error removing expired prisoners: %v", err) + return + } + rs.logger.Debugf("removed %d expired prisoners: %v", n, expiredPrisoners) +} + +func (rs *RedisConfStore) pipelinedFetchConf() fetchConf { + newConf := fetchConf{ + rateLimitsByName: make(map[string]RateLimitConfig), + rateLimitsByPath: make(map[string]RateLimitConfig), + jailsByName: make(map[string]JailConfig), + jailsByPath: make(map[string]JailConfig), + } + + rs.logger.Debugf("Sending GET for key %v", redisGlobalRateLimitConfigKey) + rs.logger.Debugf("Sending GET for key %v", redisGlobalSettingsConfigKey) + rs.logger.Debugf("Sending HGETALL for key %v", redisRateLimitsConfigKey) + rs.logger.Debugf("Sending HGETALL for key %v", redisJailsConfigKey) + rs.logger.Debugf("Sending HKEYS for key %v", redisIPWhitelistKey) + rs.logger.Debugf("Sending HKEYS for key %v", redisIPBlacklistKey) + + pipe := rs.redis.Pipeline() + defer pipe.Close() + whitelistKeysCmd := pipe.HKeys(redisIPWhitelistKey) + blacklistKeysCmd := pipe.HKeys(redisIPBlacklistKey) + globalRateLimitCmd := pipe.Get(redisGlobalRateLimitConfigKey) + globalSettingsCmd := pipe.Get(redisGlobalSettingsConfigKey) + rateLimitsCmd := pipe.HGetAll(redisRateLimitsConfigKey) + jailsCmd := pipe.HGetAll(redisJailsConfigKey) + pipe.Exec() + + newConf.setWhitelist(whitelistKeysCmd, rs.logger) + newConf.setBlacklist(blacklistKeysCmd, rs.logger) + newConf.setGlobalRateLimit(globalRateLimitCmd, rs.logger) + newConf.setGlobalSettings(globalSettingsCmd, rs.logger) + newConf.setRateLimits(rateLimitsCmd, rs.logger) + newConf.setJails(jailsCmd, rs.logger) + rs.fetchPrisoners() return newConf } diff --git a/pkg/guardian/redis_conf_store_test.go b/pkg/guardian/redis_conf_store_test.go index 26a0efe..68b93b4 100644 --- a/pkg/guardian/redis_conf_store_test.go +++ b/pkg/guardian/redis_conf_store_test.go @@ -81,24 +81,66 @@ func TestConfStoreFetchesSets(t *testing.T) { expectedWhitelist := parseCIDRs([]string{"10.0.0.1/8"}) expectedBlacklist := parseCIDRs([]string{"12.0.0.1/8"}) - expectedLimit := Limit{Count: 20, Duration: time.Second, Enabled: true} - expectedReportOnly := true - fooBarURL, _ := url.Parse("/foo/bar") - expectedRouteRateLimits := map[url.URL]Limit{ - *fooBarURL: Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + expectedGlobalRateLimit := GlobalRateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalRateLimitConfigVersion, + Kind: GlobalRateLimitConfigKind, + Name: GlobalRateLimitConfigKind, }, - } - expectedJails := map[url.URL]Jail{ - *fooBarURL: { + Spec: GlobalRateLimitSpec{ Limit: Limit{ - Count: 10, - Duration: time.Minute, + Count: 20, + Duration: time.Second, Enabled: true, }, - BanDuration: time.Hour, + }, + } + expectedGlobalSettings := GlobalSettingsConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalSettingsConfigVersion, + Kind: GlobalSettingsConfigKind, + Name: GlobalSettingsConfigKind, + }, + Spec: GlobalSettingsSpec{ + ReportOnly: true, + }, + } + expectedRateLimits := []RateLimitConfig{ + RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: "/foo/bar", + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + Conditions: Conditions{ + Path: "/foo/bar", + }, + }, + }, + } + expectedJails := []JailConfig{ + JailConfig{ + ConfigMetadata: ConfigMetadata{ + Version: JailConfigVersion, + Kind: JailConfigKind, + Name: "/foo/bar", + }, + Spec: JailSpec{ + Jail: Jail{ + Limit: Limit{ + Count: 10, + Duration: time.Minute, + Enabled: true, + }, + BanDuration: time.Hour, + }, + }, }, } @@ -110,20 +152,24 @@ func TestConfStoreFetchesSets(t *testing.T) { t.Fatalf("got error: %v", err) } - if err := c.SetLimit(expectedLimit); err != nil { + if err := c.ApplyGlobalRateLimitConfig(expectedGlobalRateLimit); err != nil { t.Fatalf("got error: %v", err) } - if err := c.SetReportOnly(expectedReportOnly); err != nil { + if err := c.ApplyGlobalSettingsConfig(expectedGlobalSettings); err != nil { t.Fatalf("got error: %v", err) } - if err := c.SetRouteRateLimits(expectedRouteRateLimits); err != nil { - t.Fatalf("got error: %v", err) + for _, config := range expectedRateLimits { + if err := c.ApplyRateLimitConfig(config); err != nil { + t.Fatalf("got error: %v", err) + } } - if err := c.SetJails(expectedJails); err != nil { - t.Fatalf("got error: %v", err) + for _, config := range expectedJails { + if err := c.ApplyJailConfig(config); err != nil { + t.Fatalf("got error: %v", err) + } } gotWhitelist, err := c.FetchWhitelist() @@ -136,25 +182,18 @@ func TestConfStoreFetchesSets(t *testing.T) { t.Fatalf("got error: %v", err) } - gotLimit, err := c.FetchLimit() + gotGlobalRateLimit, err := c.FetchGlobalRateLimitConfig() if err != nil { t.Fatalf("got error: %v", err) } - gotReportOnly, err := c.FetchReportOnly() + gotGlobalSettings, err := c.FetchGlobalSettingsConfig() if err != nil { t.Fatalf("got error: %v", err) } - gotRouteRateLimits, err := c.FetchRouteRateLimits() - if err != nil { - t.Fatalf("got error: %v", err) - } - - gotJails, err := c.FetchJails() - if err != nil { - t.Fatalf("got error: %v", err) - } + gotRateLimits := c.FetchRateLimitConfigs() + gotJails := c.FetchJailConfigs() if !cmp.Equal(gotWhitelist, expectedWhitelist) { t.Errorf("expected: %v received: %v", expectedWhitelist, gotWhitelist) @@ -164,16 +203,16 @@ func TestConfStoreFetchesSets(t *testing.T) { t.Errorf("expected: %v received: %v", expectedBlacklist, gotBlacklist) } - if gotLimit != expectedLimit { - t.Errorf("expected: %v received: %v", expectedLimit, gotLimit) + if gotGlobalRateLimit != expectedGlobalRateLimit { + t.Errorf("expected: %v received: %v", expectedGlobalRateLimit, gotGlobalRateLimit) } - if gotReportOnly != expectedReportOnly { - t.Errorf("expected: %v received: %v", expectedReportOnly, gotReportOnly) + if gotGlobalSettings != expectedGlobalSettings { + t.Errorf("expected: %v received: %v", expectedGlobalSettings, gotGlobalSettings) } - if !cmp.Equal(gotRouteRateLimits, expectedRouteRateLimits) { - t.Errorf("expected: %v received: %v", expectedRouteRateLimits, gotRouteRateLimits) + if !cmp.Equal(gotRateLimits, expectedRateLimits) { + t.Errorf("expected: %v received: %v", expectedRateLimits, gotRateLimits) } if !cmp.Equal(gotJails, expectedJails) { @@ -187,8 +226,30 @@ func TestConfStoreUpdateCacheConf(t *testing.T) { expectedWhitelist := parseCIDRs([]string{"10.0.0.1/8"}) expectedBlacklist := parseCIDRs([]string{"12.0.0.1/8"}) - expectedLimit := Limit{Count: 20, Duration: time.Second, Enabled: true} - expectedReportOnly := true + expectedGlobalRateLimit := GlobalRateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalRateLimitConfigVersion, + Kind: GlobalRateLimitConfigKind, + Name: GlobalRateLimitConfigKind, + }, + Spec: GlobalRateLimitSpec{ + Limit: Limit{ + Count: 20, + Duration: time.Second, + Enabled: true, + }, + }, + } + expectedGlobalSettings := GlobalSettingsConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalSettingsConfigVersion, + Kind: GlobalSettingsConfigKind, + Name: GlobalSettingsConfigKind, + }, + Spec: GlobalSettingsSpec{ + ReportOnly: true, + }, + } if err := c.AddWhitelistCidrs(expectedWhitelist); err != nil { t.Fatalf("got error: %v", err) @@ -198,11 +259,11 @@ func TestConfStoreUpdateCacheConf(t *testing.T) { t.Fatalf("got error: %v", err) } - if err := c.SetLimit(expectedLimit); err != nil { + if err := c.ApplyGlobalRateLimitConfig(expectedGlobalRateLimit); err != nil { t.Fatalf("got error: %v", err) } - if err := c.SetReportOnly(expectedReportOnly); err != nil { + if err := c.ApplyGlobalSettingsConfig(expectedGlobalSettings); err != nil { t.Fatalf("got error: %v", err) } @@ -221,12 +282,12 @@ func TestConfStoreUpdateCacheConf(t *testing.T) { t.Errorf("expected: %v received: %v", expectedBlacklist, gotBlacklist) } - if gotLimit != expectedLimit { - t.Errorf("expected: %v received: %v", expectedLimit, gotLimit) + if gotLimit != expectedGlobalRateLimit.Spec.Limit { + t.Errorf("expected: %v received: %v", expectedGlobalRateLimit.Spec.Limit, gotLimit) } - if gotReportOnly != expectedReportOnly { - t.Errorf("expected: %v received: %v", expectedReportOnly, gotReportOnly) + if gotReportOnly != expectedGlobalSettings.Spec.ReportOnly { + t.Errorf("expected: %v received: %v", expectedGlobalSettings.Spec.ReportOnly, gotReportOnly) } } @@ -236,8 +297,30 @@ func TestConfStoreRunUpdatesCache(t *testing.T) { expectedWhitelist := parseCIDRs([]string{"10.1.1.1/8"}) expectedBlacklist := parseCIDRs([]string{"11.1.1.1/8"}) - expectedLimit := Limit{Count: 40, Duration: time.Minute, Enabled: true} - expectedReportOnly := true + expectedGlobalRateLimit := GlobalRateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalRateLimitConfigVersion, + Kind: GlobalRateLimitConfigKind, + Name: GlobalRateLimitConfigKind, + }, + Spec: GlobalRateLimitSpec{ + Limit: Limit{ + Count: 40, + Duration: time.Minute, + Enabled: true, + }, + }, + } + expectedGlobalSettings := GlobalSettingsConfig{ + ConfigMetadata: ConfigMetadata{ + Version: GlobalSettingsConfigVersion, + Kind: GlobalSettingsConfigKind, + Name: GlobalSettingsConfigKind, + }, + Spec: GlobalSettingsSpec{ + ReportOnly: true, + }, + } if err := c.AddWhitelistCidrs(expectedWhitelist); err != nil { t.Fatalf("got error: %v", err) @@ -247,11 +330,11 @@ func TestConfStoreRunUpdatesCache(t *testing.T) { t.Fatalf("got error: %v", err) } - if err := c.SetLimit(expectedLimit); err != nil { + if err := c.ApplyGlobalRateLimitConfig(expectedGlobalRateLimit); err != nil { t.Fatalf("got error: %v", err) } - if err := c.SetReportOnly(expectedReportOnly); err != nil { + if err := c.ApplyGlobalSettingsConfig(expectedGlobalSettings); err != nil { t.Fatalf("got error: %v", err) } @@ -278,12 +361,12 @@ func TestConfStoreRunUpdatesCache(t *testing.T) { t.Errorf("expected: %v received: %v", expectedWhitelist, gotWhitelist) } - if gotLimit != expectedLimit { - t.Errorf("expected: %v received: %v", expectedLimit, gotLimit) + if gotLimit != expectedGlobalRateLimit.Spec.Limit { + t.Errorf("expected: %v received: %v", expectedGlobalRateLimit.Spec.Limit, gotLimit) } - if gotReportOnly != expectedReportOnly { - t.Errorf("expected: %v received: %v", expectedReportOnly, gotReportOnly) + if gotReportOnly != expectedGlobalSettings.Spec.ReportOnly { + t.Errorf("expected: %v received: %v", expectedGlobalSettings.Spec.ReportOnly, gotReportOnly) } } @@ -338,55 +421,80 @@ func TestConfStoreRemoveBlacklistCidr(t *testing.T) { func TestConfStoreAddRemoveRouteRateLimits(t *testing.T) { c, s := newTestConfStore(t) defer s.Close() - fooBarURL, _ := url.Parse("/foo/bar") - fooBarLimit := Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + + fooBarRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: "/foo/bar", + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + Conditions: Conditions{ + Path: "/foo/bar", + }, + }, } - fooBazURL, _ := url.Parse("/foo/baz") - fooBazLimit := Limit{ - Count: 3, - Duration: time.Second, - Enabled: false, + fooBazRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: "/foo/baz", + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 3, + Duration: time.Second, + Enabled: false, + }, + Conditions: Conditions{ + Path: "/foo/baz", + }, + }, } - routeRateLimits := map[url.URL]Limit{ - *fooBarURL: fooBarLimit, - *fooBazURL: fooBazLimit, + rateLimits := []RateLimitConfig{ + fooBarRateLimit, + fooBazRateLimit, } - if err := c.SetRouteRateLimits(routeRateLimits); err != nil { - t.Fatalf("got error: %v", err) + for _, config := range rateLimits { + if err := c.ApplyRateLimitConfig(config); err != nil { + t.Fatalf("got error: %v", err) + } } - got, err := c.FetchRouteRateLimit(*fooBarURL) + config, err := c.FetchRateLimitConfig(fooBarRateLimit.Name) + if err != nil { t.Fatalf("got error: %v", err) } - if !cmp.Equal(got, fooBarLimit) { - t.Errorf("expected: %v, received: %v", fooBarLimit, got) + if !cmp.Equal(config, fooBarRateLimit) { + t.Errorf("expected: %v, received: %v", fooBarRateLimit, config) } // Ensure configuration cache is updated after a confSyncInterval c.UpdateCachedConf() + fooBarURL, _ := url.Parse(fooBarRateLimit.Spec.Conditions.Path) cachedItem := c.GetRouteRateLimit(*fooBarURL) - if !cmp.Equal(cachedItem, fooBarLimit) { - t.Errorf("expected: %v, received: %v", fooBarLimit, cachedItem) + if !cmp.Equal(cachedItem, fooBarRateLimit.Spec.Limit) { + t.Errorf("expected: %v, received: %v", fooBarRateLimit.Spec.Limit, cachedItem) } - var urls []url.URL - urls = append(urls, *fooBarURL) - if err := c.RemoveRouteRateLimits(urls); err != nil { + if err := c.DeleteRateLimitConfig(fooBarRateLimit.Name); err != nil { t.Fatalf("got error: %v", err) } // Expect an error since we removed the limits for this route - got, err = c.FetchRouteRateLimit(*fooBarURL) + rateLimit, err := c.FetchRateLimitConfig(fooBarRateLimit.Name) if err == nil { - t.Fatalf("expected error fetching route limit which didn't exist") + t.Fatalf("found rate limit which shouldn't exist") } // Ensure configuration cache is updated after a confSyncInterval @@ -396,101 +504,150 @@ func TestConfStoreAddRemoveRouteRateLimits(t *testing.T) { t.Errorf("expected: %v, received: %v", Limit{}, cachedItem) } - got, err = c.FetchRouteRateLimit(*fooBazURL) + rateLimit, err = c.FetchRateLimitConfig(fooBazRateLimit.Name) if err != nil { t.Fatalf("got error: %v", err) } - - if !cmp.Equal(got, fooBazLimit) { - t.Errorf("expected: %v, received: %v", fooBazLimit, got) + if !cmp.Equal(rateLimit, fooBazRateLimit) { + t.Errorf("expected: %v, received: %v", fooBazRateLimit, config) } } func TestConfStoreSetExistingRoute(t *testing.T) { c, s := newTestConfStore(t) defer s.Close() - fooBarURL, _ := url.Parse("/foo/bar") - originalRouteRateLimit := map[url.URL]Limit{ - *fooBarURL: Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + + fooBarPath := "/foo/bar" + originalRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: fooBarPath, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + Conditions: Conditions{ + Path: fooBarPath, + }, }, } - if err := c.SetRouteRateLimits(originalRouteRateLimit); err != nil { + if err := c.ApplyRateLimitConfig(originalRateLimit); err != nil { t.Fatalf("got error: %v", err) } - newLimit := Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + newRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: fooBarPath, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + Conditions: Conditions{ + Path: fooBarPath, + }, + }, } - newRouteRateLimit := map[url.URL]Limit{ - *fooBarURL: newLimit, - } - if err := c.SetRouteRateLimits(newRouteRateLimit); err != nil { + if err := c.ApplyRateLimitConfig(newRateLimit); err != nil { t.Fatalf("got error: %v", err) } - got, err := c.FetchRouteRateLimit(*fooBarURL) + got, err := c.FetchRateLimitConfig(fooBarPath) if err != nil { t.Fatalf("got error: %v", err) } - if !cmp.Equal(got, newLimit) { - t.Errorf("expected: %v, received: %v", newLimit, got) + if !cmp.Equal(newRateLimit, got) { + t.Errorf("expected: %v, received: %v", newRateLimit, got) } } func TestConfStoreRemoveNonexistentRoute(t *testing.T) { c, s := newTestConfStore(t) defer s.Close() - fooBarURL, _ := url.Parse("/foo/bar") - fooBarLimit := Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + fooBarPath := "/foo/bar" + + fooBarRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: fooBarPath, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + Conditions: Conditions{ + Path: fooBarPath, + }, + }, } - fooBazURL, _ := url.Parse("/foo/baz") - fooBazLimit := Limit{ - Count: 3, - Duration: time.Second, - Enabled: false, + fooBazPath := "/foo/baz" + fooBazRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: fooBazPath, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 3, + Duration: time.Second, + Enabled: false, + }, + Conditions: Conditions{ + Path: fooBazPath, + }, + }, } - routeRateLimits := map[url.URL]Limit{ - *fooBarURL: fooBarLimit, - *fooBazURL: fooBazLimit, + rateLimits := []RateLimitConfig{ + fooBarRateLimit, + fooBazRateLimit, } - if err := c.SetRouteRateLimits(routeRateLimits); err != nil { - t.Fatalf("got error: %v", err) + for _, config := range rateLimits { + if err := c.ApplyRateLimitConfig(config); err != nil { + t.Fatalf("got error: %v", err) + + } } - var urls []url.URL - nonExistentURL, _ := url.Parse("/foo/foo") - urls = append(urls, *nonExistentURL, *fooBarURL) - if err := c.RemoveRouteRateLimits(urls); err != nil { - t.Fatalf("got error: %v", err) + nonExistentPath := "/foo/foo" + names := []string{nonExistentPath, fooBarPath} + + for _, name := range names { + if err := c.DeleteRateLimitConfig(name); err != nil { + t.Fatalf("got error: %v", err) + } } // Expect an error since we removed the limits for this route - got, err := c.FetchRouteRateLimit(*fooBarURL) + got, err := c.FetchRateLimitConfig(fooBarRateLimit.Name) if err == nil { t.Fatalf("expected error fetching route limit which didn't exist") } - got, err = c.FetchRouteRateLimit(*fooBazURL) + got, err = c.FetchRateLimitConfig(fooBazRateLimit.Name) if err != nil { t.Fatalf("got error: %v", err) } - if !cmp.Equal(got, fooBazLimit) { - t.Errorf("expected: %v, received: %v", fooBazLimit, got) + if !cmp.Equal(got, fooBazRateLimit) { + t.Errorf("expected: %v, received: %v", fooBazRateLimit, got) } } @@ -498,36 +655,62 @@ func TestConfStoreAddRemoveJails(t *testing.T) { c, s := newTestConfStore(t) defer s.Close() - fooBarURL, _ := url.Parse("/foo/bar") - fooBarJail := Jail{ - Limit: Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + fooBarPath := "/foo/bar" + fooBarJail := JailConfig{ + ConfigMetadata: ConfigMetadata{ + Version: JailConfigVersion, + Kind: JailConfigKind, + Name: fooBarPath, + }, + Spec: JailSpec{ + Jail: Jail{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + BanDuration: time.Hour, + }, + Conditions: Conditions{ + Path: fooBarPath, + }, }, - BanDuration: time.Hour, } - fooBazURL, _ := url.Parse("/foo/baz") - fooBazJail := Jail{ - Limit: Limit{ - Count: 3, - Duration: time.Second, - Enabled: false, + fooBazPath := "/foo/baz" + fooBazJail := JailConfig{ + ConfigMetadata: ConfigMetadata{ + Version: JailConfigVersion, + Kind: JailConfigKind, + Name: fooBazPath, + }, + Spec: JailSpec{ + Jail: Jail{ + Limit: Limit{ + Count: 3, + Duration: time.Second, + Enabled: false, + }, + BanDuration: time.Hour, + }, + Conditions: Conditions{ + Path: fooBazPath, + }, }, - BanDuration: time.Hour, } - jails := map[url.URL]Jail{ - *fooBarURL: fooBarJail, - *fooBazURL: fooBazJail, + jails := []JailConfig{ + fooBarJail, + fooBazJail, } - if err := c.SetJails(jails); err != nil { - t.Fatalf("got error: %v", err) + for _, config := range jails { + if err := c.ApplyJailConfig(config); err != nil { + t.Fatalf("got error: %v", err) + } } - got, err := c.FetchJail(*fooBarURL) + got, err := c.FetchJailConfig(fooBarJail.Name) if err != nil { t.Fatalf("got error: %v", err) } @@ -536,21 +719,20 @@ func TestConfStoreAddRemoveJails(t *testing.T) { t.Errorf("expected: %v, received: %v", fooBarJail, got) } + fooBarURL, _ := url.Parse(fooBarPath) // Ensure configuration cache is updated after a confSyncInterval c.UpdateCachedConf() cachedItem := c.GetJail(*fooBarURL) - if !cmp.Equal(cachedItem, fooBarJail) { + if !cmp.Equal(cachedItem, fooBarJail.Spec.Jail) { t.Errorf("expected: %v, received: %v", fooBarJail, cachedItem) } - var urls []url.URL - urls = append(urls, *fooBarURL) - if err := c.RemoveJails(urls); err != nil { + if err := c.DeleteJailConfig(fooBarJail.Name); err != nil { t.Fatalf("got error: %v", err) } // Expect an error since we removed the limits for this route - got, err = c.FetchJail(*fooBarURL) + got, err = c.FetchJailConfig(fooBarJail.Name) if err == nil { t.Fatalf("expected error fetching route limit which didn't exist") } @@ -562,7 +744,7 @@ func TestConfStoreAddRemoveJails(t *testing.T) { t.Errorf("expected: %v, received: %v", Jail{}, cachedItem) } - got, err = c.FetchJail(*fooBazURL) + got, err = c.FetchJailConfig(fooBazJail.Name) if err != nil { t.Fatalf("got error: %v", err) } @@ -576,39 +758,58 @@ func TestConfStoreSetExistingJail(t *testing.T) { c, s := newTestConfStore(t) defer s.Close() - fooBarURL, _ := url.Parse("/foo/bar") - fooBarJail := Jail{ - Limit: Limit{ - Count: 5, - Duration: time.Second, - Enabled: true, + fooBarPath := "/foo/bar" + fooBarJail := JailConfig{ + ConfigMetadata: ConfigMetadata{ + Version: JailConfigVersion, + Kind: JailConfigKind, + Name: fooBarPath, + }, + Spec: JailSpec{ + Jail: Jail{ + Limit: Limit{ + Count: 5, + Duration: time.Second, + Enabled: true, + }, + BanDuration: time.Hour, + }, + Conditions: Conditions{ + Path: fooBarPath, + }, }, - BanDuration: time.Hour, } - jails := map[url.URL]Jail{*fooBarURL: fooBarJail} - - if err := c.SetJails(jails); err != nil { + if err := c.ApplyJailConfig(fooBarJail); err != nil { t.Fatalf("got error: %v", err) } - newJail := Jail{ - Limit: Limit{ - Count: 100, - Duration: time.Minute, - Enabled: false, + newJail := JailConfig{ + ConfigMetadata: ConfigMetadata{ + Version: JailConfigVersion, + Kind: JailConfigKind, + Name: fooBarPath, + }, + Spec: JailSpec{ + Jail: Jail{ + Limit: Limit{ + Count: 100, + Duration: time.Minute, + Enabled: false, + }, + BanDuration: time.Minute, + }, + Conditions: Conditions{ + Path: fooBarPath, + }, }, - BanDuration: time.Minute, } - newJails := map[url.URL]Jail{ - *fooBarURL: newJail, - } - if err := c.SetJails(newJails); err != nil { + if err := c.ApplyJailConfig(newJail); err != nil { t.Fatalf("got error: %v", err) } - got, err := c.FetchJail(*fooBarURL) + got, err := c.FetchJailConfig(fooBarJail.Name) if err != nil { t.Fatalf("got error: %v", err) } diff --git a/pkg/guardian/routeratelimit_test.go b/pkg/guardian/routeratelimit_test.go index 9222ae1..087d5b5 100644 --- a/pkg/guardian/routeratelimit_test.go +++ b/pkg/guardian/routeratelimit_test.go @@ -1,21 +1,35 @@ package guardian import ( - "net/url" "reflect" "testing" "time" ) func TestRouteLimitProvider(t *testing.T) { - fooBarRouteLimit := Limit{Count: 2, Duration: time.Minute, Enabled: true} - route := url.URL{Path: "/foo/bar"} - routeLimits := map[url.URL]Limit{route: fooBarRouteLimit} + route := "/foo/bar" + fooBarRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: route, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 2, + Duration: time.Minute, + Enabled: true, + }, + Conditions: Conditions{ + Path: route, + }, + }, + } globalLimit := Limit{Count: 2, Duration: time.Minute, Enabled: true} cs, s := newTestConfStoreWithDefaults(t, nil, nil, globalLimit, false) defer s.Close() - cs.SetRouteRateLimits(routeLimits) + cs.ApplyRateLimitConfig(fooBarRateLimit) cs.UpdateCachedConf() tests := []struct { @@ -26,7 +40,7 @@ func TestRouteLimitProvider(t *testing.T) { { name: "route with limit", req: Request{Path: "/foo/bar"}, - wantLimit: fooBarRouteLimit, + wantLimit: fooBarRateLimit.Spec.Limit, }, { name: "sub route without limit", @@ -51,30 +65,59 @@ func TestRouteLimitProvider(t *testing.T) { } func TestRouteLimitProviderUpdates(t *testing.T) { - fooBarRouteLimit := Limit{Count: 2, Duration: time.Minute, Enabled: true} - route := url.URL{Path: "/foo/bar"} - routeLimits := map[url.URL]Limit{route: fooBarRouteLimit} + route := "/foo/bar" + fooBarRateLimit := RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: route, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 2, + Duration: time.Minute, + Enabled: true, + }, + Conditions: Conditions{ + Path: route, + }, + }, + } globalLimit := Limit{Count: 2, Duration: time.Minute, Enabled: true} cs, s := newTestConfStoreWithDefaults(t, nil, nil, globalLimit, false) defer s.Close() - cs.SetRouteRateLimits(routeLimits) + cs.ApplyRateLimitConfig(fooBarRateLimit) cs.UpdateCachedConf() rlp := NewRouteRateLimitProvider(cs, TestingLogger) gotLimit := rlp.GetLimit(Request{Path: "/foo/bar"}) - if !reflect.DeepEqual(gotLimit, fooBarRouteLimit) { - t.Errorf("GetLimit() = %v, want %v", gotLimit, fooBarRouteLimit) + if !reflect.DeepEqual(gotLimit, fooBarRateLimit.Spec.Limit) { + t.Errorf("GetLimit() = %v, want %v", gotLimit, fooBarRateLimit.Spec.Limit) } - fooBarRouteLimit = Limit{Count: 43, Duration: time.Minute, Enabled: true} - - newRouteLimits := map[url.URL]Limit{route: fooBarRouteLimit} - cs.SetRouteRateLimits(newRouteLimits) + fooBarRateLimit = RateLimitConfig{ + ConfigMetadata: ConfigMetadata{ + Version: RateLimitConfigVersion, + Kind: RateLimitConfigKind, + Name: route, + }, + Spec: RateLimitSpec{ + Limit: Limit{ + Count: 43, + Duration: time.Minute, + Enabled: true, + }, + Conditions: Conditions{ + Path: route, + }, + }, + } + cs.ApplyRateLimitConfig(fooBarRateLimit) cs.UpdateCachedConf() gotLimit = rlp.GetLimit(Request{Path: "/foo/bar"}) - if !reflect.DeepEqual(gotLimit, fooBarRouteLimit) { - t.Errorf("GetLimit() = %v, want %v", gotLimit, fooBarRouteLimit) + if !reflect.DeepEqual(gotLimit, fooBarRateLimit.Spec.Limit) { + t.Errorf("GetLimit() = %v, want %v", gotLimit, fooBarRateLimit.Spec.Limit) } } diff --git a/vendor/gopkg.in/yaml.v2/go.mod b/vendor/gopkg.in/yaml.v2/go.mod index 1934e87..643be14 100644 --- a/vendor/gopkg.in/yaml.v2/go.mod +++ b/vendor/gopkg.in/yaml.v2/go.mod @@ -1,5 +1,5 @@ -module "gopkg.in/yaml.v2" +module gopkg.in/yaml.v2 -require ( - "gopkg.in/check.v1" v0.0.0-20161208181325-20d25e280405 -) +go 1.14 + +require gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 diff --git a/vendor/gopkg.in/yaml.v2/go.sum b/vendor/gopkg.in/yaml.v2/go.sum new file mode 100644 index 0000000..bfc2806 --- /dev/null +++ b/vendor/gopkg.in/yaml.v2/go.sum @@ -0,0 +1 @@ +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=