diff --git a/api/v1beta2/ratelimitpolicy_types.go b/api/v1beta2/ratelimitpolicy_types.go index 28f93f30f..2840a71f4 100644 --- a/api/v1beta2/ratelimitpolicy_types.go +++ b/api/v1beta2/ratelimitpolicy_types.go @@ -25,7 +25,6 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" gatewayapiv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" - gatewayapiv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" ) // EDIT THIS FILE! THIS IS SCAFFOLDING FOR YOU TO OWN! @@ -40,13 +39,14 @@ import ( // +kubebuilder:validation:MaxLength=253 type ContextSelector string -// +kubebuilder:validation:Enum:=eq;neq;startswith;incl;excl;matches +// +kubebuilder:validation:Enum:=eq;neq;startswith;endswith;incl;excl;matches type WhenConditionOperator string const ( EqualOperator WhenConditionOperator = "eq" NotEqualOperator WhenConditionOperator = "neq" StartsWithOperator WhenConditionOperator = "startswith" + EndsWithOperator WhenConditionOperator = "endswith" IncludeOperator WhenConditionOperator = "incl" ExcludeOperator WhenConditionOperator = "excl" MatchesOperator WhenConditionOperator = "matches" @@ -83,20 +83,6 @@ type WhenCondition struct { Value string `json:"value"` } -// RouteSelector defines semantics for matching an HTTP request based on conditions -// https://gateway-api.sigs.k8s.io/v1alpha2/references/spec/#gateway.networking.k8s.io/v1beta1.HTTPRouteSpec -type RouteSelector struct { - // Hostnames defines a set of hostname that should match against the HTTP Host header to select a HTTPRoute to process the request - // https://gateway-api.sigs.k8s.io/v1alpha2/references/spec/#gateway.networking.k8s.io/v1beta1.HTTPRouteSpec - // +optional - Hostnames []gatewayapiv1beta1.Hostname `json:"hostnames,omitempty"` - - // Matches define conditions used for matching the rule against incoming HTTP requests. - // https://gateway-api.sigs.k8s.io/v1alpha2/references/spec/#gateway.networking.k8s.io/v1beta1.HTTPRouteSpec - // +optional - Matches []gatewayapiv1beta1.HTTPRouteMatch `json:"matches,omitempty"` -} - // Limit represents a complete rate limit configuration type Limit struct { // RouteSelectors defines semantics for matching an HTTP request based on conditions @@ -197,6 +183,15 @@ func (r *RateLimitPolicy) Validate() error { return fmt.Errorf("invalid targetRef.Namespace %s. Currently only supporting references to the same namespace", *r.Spec.TargetRef.Namespace) } + // prevents usage of routeSelectors in a gateway RLP + if r.Spec.TargetRef.Kind == gatewayapiv1alpha2.Kind("Gateway") { + for _, limit := range r.Spec.Limits { + if len(limit.RouteSelectors) > 0 { + return fmt.Errorf("route selectors not supported when targetting a Gateway") + } + } + } + return nil } diff --git a/api/v1beta2/route_selectors.go b/api/v1beta2/route_selectors.go new file mode 100644 index 000000000..f896a2cdd --- /dev/null +++ b/api/v1beta2/route_selectors.go @@ -0,0 +1,52 @@ +package v1beta2 + +import ( + gatewayapiv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + + orderedmap "github.com/elliotchance/orderedmap/v2" + + "github.com/kuadrant/kuadrant-operator/pkg/common" +) + +// RouteSelector defines semantics for matching an HTTP request based on conditions +// https://gateway-api.sigs.k8s.io/v1alpha2/references/spec/#gateway.networking.k8s.io/v1beta1.HTTPRouteSpec +type RouteSelector struct { + // Hostnames defines a set of hostname that should match against the HTTP Host header to select a HTTPRoute to process the request + // https://gateway-api.sigs.k8s.io/v1alpha2/references/spec/#gateway.networking.k8s.io/v1beta1.HTTPRouteSpec + // +optional + Hostnames []gatewayapiv1beta1.Hostname `json:"hostnames,omitempty"` + + // Matches define conditions used for matching the rule against incoming HTTP requests. + // https://gateway-api.sigs.k8s.io/v1alpha2/references/spec/#gateway.networking.k8s.io/v1beta1.HTTPRouteSpec + // +optional + Matches []gatewayapiv1beta1.HTTPRouteMatch `json:"matches,omitempty"` +} + +// SelectRules returns, from a HTTPRoute, all HTTPRouteRules that either specify no HTTRouteMatches or that contain at +// least one HTTRouteMatch whose statements expressly include (partially or totally) the statements of at least one of +// the matches of the selector. If the selector does not specify any matches, then all HTTPRouteRules are selected. +// +// Additionally, if the selector specifies a non-empty list of hostnames, a non-empty intersection between the literal +// hostnames of the selector and set of hostnames specified in the HTTPRoute must exist. Otherwise, the function +// returns nil. +func (s *RouteSelector) SelectRules(route *gatewayapiv1beta1.HTTPRoute) (rules []gatewayapiv1beta1.HTTPRouteRule) { + rulesIndices := orderedmap.NewOrderedMap[int, gatewayapiv1beta1.HTTPRouteRule]() + if len(s.Hostnames) > 0 && !common.Intersect(s.Hostnames, route.Spec.Hostnames) { + return nil + } + if len(s.Matches) == 0 { + return route.Spec.Rules + } + for _, routeSelectorMatch := range s.Matches { + for idx, rule := range route.Spec.Rules { + rs := common.HTTPRouteRuleSelector{HTTPRouteMatch: &routeSelectorMatch} + if rs.Selects(rule) { + rulesIndices.Set(idx, rule) + } + } + } + for el := rulesIndices.Front(); el != nil; el = el.Next() { + rules = append(rules, el.Value) + } + return +} diff --git a/api/v1beta2/route_selectors_test.go b/api/v1beta2/route_selectors_test.go new file mode 100644 index 000000000..347dbe20e --- /dev/null +++ b/api/v1beta2/route_selectors_test.go @@ -0,0 +1,211 @@ +//go:build unit + +package v1beta2 + +import ( + "fmt" + "reflect" + "testing" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + gatewayapiv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" + + "github.com/kuadrant/kuadrant-operator/pkg/common" +) + +func TestRouteSelectors(t *testing.T) { + gatewayHostnames := []gatewayapiv1beta1.Hostname{ + "*.toystore.com", + } + + gateway := &gatewayapiv1beta1.Gateway{ + ObjectMeta: metav1.ObjectMeta{ + Name: "my-gateway", + }, + } + + for _, hostname := range gatewayHostnames { + gateway.Spec.Listeners = append(gateway.Spec.Listeners, gatewayapiv1beta1.Listener{Hostname: &hostname}) + } + + route := &gatewayapiv1beta1.HTTPRoute{ + Spec: gatewayapiv1beta1.HTTPRouteSpec{ + CommonRouteSpec: gatewayapiv1beta1.CommonRouteSpec{ + ParentRefs: []gatewayapiv1beta1.ParentReference{ + { + Name: gatewayapiv1beta1.ObjectName(gateway.Name), + }, + }, + }, + Hostnames: []gatewayapiv1beta1.Hostname{"api.toystore.com"}, + Rules: []gatewayapiv1beta1.HTTPRouteRule{ + { + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + // get /toys* + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethod("GET")}[0], + }, + // post /toys* + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethod("POST")}[0], + }, + }, + }, + { + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + // /assets* + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/assets"}[0], + }, + }, + }, + }, + }, + }, + } + + testCases := []struct { + name string + routeSelector RouteSelector + route *gatewayapiv1beta1.HTTPRoute + expected []gatewayapiv1beta1.HTTPRouteRule + }{ + { + name: "empty route selector selects all HTTPRouteRules", + routeSelector: RouteSelector{}, + route: route, + expected: route.Spec.Rules, + }, + { + name: "route selector selects the HTTPRouteRules whose set of HTTPRouteMatch is a perfect match", + routeSelector: RouteSelector{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/assets"}[0], + }, + }, + }, + }, + route: route, + expected: []gatewayapiv1beta1.HTTPRouteRule{route.Spec.Rules[1]}, + }, + { + name: "route selector selects the HTTPRouteRules whose set of HTTPRouteMatch contains at least one match", + routeSelector: RouteSelector{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethod("POST")}[0], + }, + }, + }, + route: route, + expected: []gatewayapiv1beta1.HTTPRouteRule{route.Spec.Rules[0]}, + }, + { + name: "route selector with missing part of a HTTPRouteMatch still selects the HTTPRouteRules that match", + routeSelector: RouteSelector{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + }, + }, + }, + route: route, + expected: []gatewayapiv1beta1.HTTPRouteRule{route.Spec.Rules[0]}, + }, + { + name: "route selector selects no HTTPRouteRule when no criterion matches", + routeSelector: RouteSelector{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchExact}[0], + Value: &[]string{"/toy"}[0], + }, + }, + }, + }, + route: route, + expected: nil, + }, + { + name: "route selector selects the HTTPRouteRules whose HTTPRoute's hostnames match the selector", + routeSelector: RouteSelector{ + Hostnames: []gatewayapiv1beta1.Hostname{"api.toystore.com"}, + }, + route: route, + expected: route.Spec.Rules, + }, + { + name: "route selector selects the HTTPRouteRules whose HTTPRoute's hostnames match the selector additionally to other criteria", + routeSelector: RouteSelector{ + Hostnames: []gatewayapiv1beta1.Hostname{"api.toystore.com"}, + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + }, + }, + }, + route: route, + expected: []gatewayapiv1beta1.HTTPRouteRule{route.Spec.Rules[0]}, + }, + { + name: "route selector does not select HTTPRouteRules whose HTTPRoute's hostnames do not match the selector", + routeSelector: RouteSelector{ + Hostnames: []gatewayapiv1beta1.Hostname{"www.toystore.com"}, + }, + route: route, + expected: nil, + }, + { + name: "route selector does not select HTTPRouteRules whose HTTPRoute's hostnames do not match the selector even when other criteria match", + routeSelector: RouteSelector{ + Hostnames: []gatewayapiv1beta1.Hostname{"www.toystore.com"}, + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + }, + }, + }, + route: route, + expected: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rules := tc.routeSelector.SelectRules(tc.route) + rulesToStringSlice := func(rules []gatewayapiv1beta1.HTTPRouteRule) []string { + return common.Map(common.Map(rules, common.HTTPRouteRuleToString), func(r string) string { return fmt.Sprintf("{%s}", r) }) + } + if !reflect.DeepEqual(rules, tc.expected) { + t.Errorf("expected %v, got %v", rulesToStringSlice(tc.expected), rulesToStringSlice(rules)) + } + }) + } +} diff --git a/bundle/manifests/kuadrant.io_ratelimitpolicies.yaml b/bundle/manifests/kuadrant.io_ratelimitpolicies.yaml index 2eeebd93a..af8a4c3dc 100644 --- a/bundle/manifests/kuadrant.io_ratelimitpolicies.yaml +++ b/bundle/manifests/kuadrant.io_ratelimitpolicies.yaml @@ -270,6 +270,7 @@ spec: - eq - neq - startswith + - endswith - incl - excl - matches diff --git a/config/crd/bases/kuadrant.io_ratelimitpolicies.yaml b/config/crd/bases/kuadrant.io_ratelimitpolicies.yaml index be4b708de..1c398ed41 100644 --- a/config/crd/bases/kuadrant.io_ratelimitpolicies.yaml +++ b/config/crd/bases/kuadrant.io_ratelimitpolicies.yaml @@ -269,6 +269,7 @@ spec: - eq - neq - startswith + - endswith - incl - excl - matches diff --git a/controllers/ratelimitpolicy_controller_test.go b/controllers/ratelimitpolicy_controller_test.go index 3933e154d..d25d0025c 100644 --- a/controllers/ratelimitpolicy_controller_test.go +++ b/controllers/ratelimitpolicy_controller_test.go @@ -13,6 +13,7 @@ import ( istioclientgoextensionv1alpha1 "istio.io/client-go/pkg/apis/extensions/v1alpha1" istioclientnetworkingv1alpha3 "istio.io/client-go/pkg/apis/networking/v1alpha3" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" "k8s.io/apimachinery/pkg/api/meta" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" @@ -52,17 +53,7 @@ func testBuildBasicGateway(gwName, ns string) *gatewayapiv1beta1.Gateway { } } -func testBuildBasicHttpRoute(routeName, gwName, ns string, hostnamesStrSlice []string) *gatewayapiv1beta1.HTTPRoute { - tmpMatchPathPrefix := gatewayapiv1beta1.PathMatchPathPrefix - tmpMatchValue := "/toy" - tmpMatchMethod := gatewayapiv1beta1.HTTPMethod("GET") - gwNamespace := gatewayapiv1beta1.Namespace(ns) - - var hostnames []gatewayapiv1beta1.Hostname - for _, str := range hostnamesStrSlice { - hostnames = append(hostnames, gatewayapiv1beta1.Hostname(str)) - } - +func testBuildBasicHttpRoute(routeName, gwName, ns string, hostnames []string) *gatewayapiv1beta1.HTTPRoute { return &gatewayapiv1beta1.HTTPRoute{ TypeMeta: metav1.TypeMeta{ Kind: "HTTPRoute", @@ -78,78 +69,20 @@ func testBuildBasicHttpRoute(routeName, gwName, ns string, hostnamesStrSlice []s ParentRefs: []gatewayapiv1beta1.ParentReference{ { Name: gatewayapiv1beta1.ObjectName(gwName), - Namespace: &gwNamespace, + Namespace: common.Ptr(gatewayapiv1beta1.Namespace(ns)), }, }, }, - Hostnames: hostnames, + Hostnames: common.Map(hostnames, func(hostname string) gatewayapiv1beta1.Hostname { return gatewayapiv1beta1.Hostname(hostname) }), Rules: []gatewayapiv1beta1.HTTPRouteRule{ { Matches: []gatewayapiv1beta1.HTTPRouteMatch{ { Path: &gatewayapiv1beta1.HTTPPathMatch{ - Type: &tmpMatchPathPrefix, - Value: &tmpMatchValue, + Type: common.Ptr(gatewayapiv1beta1.PathMatchPathPrefix), + Value: common.Ptr("/toy"), }, - Method: &tmpMatchMethod, - }, - }, - }, - }, - }, - } -} - -func testBuildBasicRoutePolicy(policyName, ns, routeName string) *kuadrantv1beta2.RateLimitPolicy { - return &kuadrantv1beta2.RateLimitPolicy{ - TypeMeta: metav1.TypeMeta{ - Kind: "RateLimitPolicy", - APIVersion: kuadrantv1beta2.GroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: policyName, - Namespace: ns, - }, - Spec: kuadrantv1beta2.RateLimitPolicySpec{ - TargetRef: gatewayapiv1alpha2.PolicyTargetReference{ - Group: gatewayapiv1beta1.Group("gateway.networking.k8s.io"), - Kind: "HTTPRoute", - Name: gatewayapiv1beta1.ObjectName(routeName), - }, - Limits: map[string]kuadrantv1beta2.Limit{ - "l1": { - Rates: []kuadrantv1beta2.Rate{ - { - Limit: 1, Duration: 3, Unit: kuadrantv1beta2.TimeUnit("minute"), - }, - }, - }, - }, - }, - } -} - -func testBuildGatewayPolicy(policyName, ns, gwName string) *kuadrantv1beta2.RateLimitPolicy { - return &kuadrantv1beta2.RateLimitPolicy{ - TypeMeta: metav1.TypeMeta{ - Kind: "RateLimitPolicy", - APIVersion: kuadrantv1beta2.GroupVersion.String(), - }, - ObjectMeta: metav1.ObjectMeta{ - Name: policyName, - Namespace: ns, - }, - Spec: kuadrantv1beta2.RateLimitPolicySpec{ - TargetRef: gatewayapiv1alpha2.PolicyTargetReference{ - Group: gatewayapiv1beta1.Group("gateway.networking.k8s.io"), - Kind: "Gateway", - Name: gatewayapiv1beta1.ObjectName(gwName), - }, - Limits: map[string]kuadrantv1beta2.Limit{ - "l1": { - Rates: []kuadrantv1beta2.Rate{ - { - Limit: 1, Duration: 3, Unit: kuadrantv1beta2.TimeUnit("minute"), + Method: common.Ptr(gatewayapiv1beta1.HTTPMethod("GET")), }, }, }, @@ -190,48 +123,64 @@ var _ = Describe("RateLimitPolicy controller", func() { }, 15*time.Second, 5*time.Second).Should(BeTrue()) ApplyKuadrantCR(testNamespace) + + // Check Limitador Status is Ready + Eventually(func() bool { + limitador := &limitadorv1alpha1.Limitador{} + err := k8sClient.Get(context.Background(), client.ObjectKey{Name: common.LimitadorName, Namespace: testNamespace}, limitador) + if err != nil { + return false + } + if !meta.IsStatusConditionTrue(limitador.Status.Conditions, "Ready") { + return false + } + return true + }, time.Minute, 5*time.Second).Should(BeTrue()) } BeforeEach(beforeEachCallback) AfterEach(DeleteNamespaceCallback(&testNamespace)) - Context("Basic: RLP targeting HTTPRoute", func() { - It("check created resources", func() { - // Check Limitador Status is Ready - Eventually(func() bool { - limitador := &limitadorv1alpha1.Limitador{} - err := k8sClient.Get(context.Background(), client.ObjectKey{Name: common.LimitadorName, Namespace: testNamespace}, limitador) - if err != nil { - return false - } - if !meta.IsStatusConditionTrue(limitador.Status.Conditions, "Ready") { - return false - } - return true - }, time.Minute, 5*time.Second).Should(BeTrue()) - + Context("RLP targeting HTTPRoute", func() { + It("Creates all the resources for a basic HTTPRoute and RateLimitPolicy", func() { + // create httproute httpRoute := testBuildBasicHttpRoute(routeName, gwName, testNamespace, []string{"*.example.com"}) err := k8sClient.Create(context.Background(), httpRoute) Expect(err).ToNot(HaveOccurred()) - rlp := testBuildBasicRoutePolicy(rlpName, testNamespace, routeName) - rlpKey := client.ObjectKey{Name: rlpName, Namespace: testNamespace} + // create ratelimitpolicy + rlp := &kuadrantv1beta2.RateLimitPolicy{ + TypeMeta: metav1.TypeMeta{ + Kind: "RateLimitPolicy", + APIVersion: kuadrantv1beta2.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: rlpName, + Namespace: testNamespace, + }, + Spec: kuadrantv1beta2.RateLimitPolicySpec{ + TargetRef: gatewayapiv1alpha2.PolicyTargetReference{ + Group: gatewayapiv1beta1.Group("gateway.networking.k8s.io"), + Kind: "HTTPRoute", + Name: gatewayapiv1beta1.ObjectName(routeName), + }, + Limits: map[string]kuadrantv1beta2.Limit{ + "l1": { + Rates: []kuadrantv1beta2.Rate{ + { + Limit: 1, Duration: 3, Unit: kuadrantv1beta2.TimeUnit("minute"), + }, + }, + }, + }, + }, + } err = k8sClient.Create(context.Background(), rlp) Expect(err).ToNot(HaveOccurred()) // Check RLP status is available - Eventually(func() bool { - existingRLP := &kuadrantv1beta2.RateLimitPolicy{} - err := k8sClient.Get(context.Background(), rlpKey, existingRLP) - if err != nil { - return false - } - if !meta.IsStatusConditionTrue(existingRLP.Status.Conditions, "Available") { - return false - } - - return true - }, time.Minute, 5*time.Second).Should(BeTrue()) + rlpKey := client.ObjectKeyFromObject(rlp) + Eventually(testRLPIsAvailable(rlpKey), time.Minute, 5*time.Second).Should(BeTrue()) // Check HTTPRoute direct back reference routeKey := client.ObjectKey{Name: routeName, Namespace: testNamespace} @@ -277,10 +226,8 @@ var _ = Describe("RateLimitPolicy controller", func() { FailureMode: wasm.FailureModeDeny, RateLimitPolicies: []wasm.RateLimitPolicy{ { - Name: "*.example.com", - Domain: common.MarshallNamespace(client.ObjectKeyFromObject(gateway), "*.example.com"), - Service: common.KuadrantRateLimitClusterName, - Hostnames: []string{"*.example.com"}, + Name: rlpKey.String(), + Domain: fmt.Sprintf("%s/%s#%s", testNamespace, gwName, "*.example.com"), Rules: []wasm.Rule{ { Conditions: []wasm.Condition{ @@ -309,6 +256,8 @@ var _ = Describe("RateLimitPolicy controller", func() { }, }, }, + Hostnames: []string{"*.example.com"}, + Service: common.KuadrantRateLimitClusterName, }, }, })) @@ -325,41 +274,259 @@ var _ = Describe("RateLimitPolicy controller", func() { Expect(existingGateway.GetAnnotations()).To(HaveKeyWithValue( common.RateLimitPoliciesBackRefAnnotation, string(serialized))) }) + + It("Creates the correct WasmPlugin for a complex HTTPRoute and a RateLimitPolicy", func() { + // create httproute + httpRoute := testBuildBasicHttpRoute(routeName, gwName, testNamespace, []string{"*.toystore.acme.com", "api.toystore.io"}) + httpRoute.Spec.Rules = []gatewayapiv1beta1.HTTPRouteRule{ + { + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { // get /toys* + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: common.Ptr(gatewayapiv1beta1.PathMatchPathPrefix), + Value: common.Ptr("/toys"), + }, + Method: common.Ptr(gatewayapiv1beta1.HTTPMethod("GET")), + }, + { // post /toys* + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: common.Ptr(gatewayapiv1beta1.PathMatchPathPrefix), + Value: common.Ptr("/toys"), + }, + Method: common.Ptr(gatewayapiv1beta1.HTTPMethod("POST")), + }, + }, + }, + { + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { // /assets* + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: common.Ptr(gatewayapiv1beta1.PathMatchPathPrefix), + Value: common.Ptr("/assets"), + }, + }, + }, + }, + } + err := k8sClient.Create(context.Background(), httpRoute) + Expect(err).ToNot(HaveOccurred()) + + // create ratelimitpolicy + rlp := &kuadrantv1beta2.RateLimitPolicy{ + TypeMeta: metav1.TypeMeta{ + Kind: "RateLimitPolicy", + APIVersion: kuadrantv1beta2.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: rlpName, + Namespace: testNamespace, + }, + Spec: kuadrantv1beta2.RateLimitPolicySpec{ + TargetRef: gatewayapiv1alpha2.PolicyTargetReference{ + Group: gatewayapiv1beta1.Group("gateway.networking.k8s.io"), + Kind: "HTTPRoute", + Name: gatewayapiv1beta1.ObjectName(routeName), + }, + Limits: map[string]kuadrantv1beta2.Limit{ + "toys": { + Rates: []kuadrantv1beta2.Rate{ + {Limit: 50, Duration: 1, Unit: kuadrantv1beta2.TimeUnit("minute")}, + }, + Counters: []kuadrantv1beta2.ContextSelector{"auth.identity.username"}, + RouteSelectors: []kuadrantv1beta2.RouteSelector{ + { // selects the 1st HTTPRouteRule (i.e. get|post /toys*) for one of the hostnames + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: common.Ptr(gatewayapiv1beta1.PathMatchPathPrefix), + Value: common.Ptr("/toys"), + }, + }, + }, + Hostnames: []gatewayapiv1beta1.Hostname{"*.toystore.acme.com"}, + }, + }, + When: []kuadrantv1beta2.WhenCondition{ + { + Selector: "auth.identity.group", + Operator: kuadrantv1beta2.WhenConditionOperator("neq"), + Value: "admin", + }, + }, + }, + "assets": { + Rates: []kuadrantv1beta2.Rate{ + {Limit: 5, Duration: 1, Unit: kuadrantv1beta2.TimeUnit("minute")}, + {Limit: 100, Duration: 12, Unit: kuadrantv1beta2.TimeUnit("hour")}, + }, + RouteSelectors: []kuadrantv1beta2.RouteSelector{ + { // selects the 2nd HTTPRouteRule (i.e. /assets*) for all hostnames + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: common.Ptr(gatewayapiv1beta1.PathMatchPathPrefix), + Value: common.Ptr("/assets"), + }, + }, + }, + }, + }, + }, + }, + }, + } + err = k8sClient.Create(context.Background(), rlp) + Expect(err).ToNot(HaveOccurred()) + + // Check RLP status is available + rlpKey := client.ObjectKeyFromObject(rlp) + Eventually(testRLPIsAvailable(rlpKey), time.Minute, 5*time.Second).Should(BeTrue()) + + // Check wasm plugin + wpName := fmt.Sprintf("kuadrant-%s", gwName) + wasmPluginKey := client.ObjectKey{Name: wpName, Namespace: testNamespace} + existingWasmPlugin := &istioclientgoextensionv1alpha1.WasmPlugin{} + err = k8sClient.Get(context.Background(), wasmPluginKey, existingWasmPlugin) + // must exist + Expect(err).ToNot(HaveOccurred()) + existingWASMConfig, err := rlptools.WASMPluginFromStruct(existingWasmPlugin.Spec.PluginConfig) + Expect(err).ToNot(HaveOccurred()) + Expect(existingWASMConfig.FailureMode).To(Equal(wasm.FailureModeDeny)) + Expect(existingWASMConfig.RateLimitPolicies).To(HaveLen(1)) + wasmRLP := existingWASMConfig.RateLimitPolicies[0] + Expect(wasmRLP.Name).To(Equal(rlpKey.String())) + Expect(wasmRLP.Domain).To(Equal(fmt.Sprintf("%s/%s#%s", testNamespace, gwName, "*.toystore.acme.com"))) + Expect(wasmRLP.Rules).To(ContainElement(wasm.Rule{ // rule to activate the 'toys' limit defintion + Conditions: []wasm.Condition{ + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toys", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "GET", + }, + { + Selector: "request.host", + Operator: wasm.PatternOperator(kuadrantv1beta2.EndsWithOperator), + Value: ".toystore.acme.com", + }, + { + Selector: "auth.identity.group", + Operator: wasm.PatternOperator(kuadrantv1beta2.NotEqualOperator), + Value: "admin", + }, + }, + }, + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toys", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "POST", + }, + { + Selector: "request.host", + Operator: wasm.PatternOperator(kuadrantv1beta2.EndsWithOperator), + Value: ".toystore.acme.com", + }, + { + Selector: "auth.identity.group", + Operator: wasm.PatternOperator(kuadrantv1beta2.NotEqualOperator), + Value: "admin", + }, + }, + }, + }, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: fmt.Sprintf("%s/%s/toys", testNamespace, rlpName), + Value: "1", + }, + }, + { + Selector: &wasm.SelectorSpec{ + Selector: kuadrantv1beta2.ContextSelector("auth.identity.username"), + }, + }, + }, + })) + Expect(wasmRLP.Rules).To(ContainElement(wasm.Rule{ // rule to activate the 'assets' limit defintion + Conditions: []wasm.Condition{ + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/assets", + }, + }, + }, + }, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: fmt.Sprintf("%s/%s/assets", testNamespace, rlpName), + Value: "1", + }, + }, + }, + })) + Expect(wasmRLP.Hostnames).To(Equal([]string{"*.toystore.acme.com", "api.toystore.io"})) + Expect(wasmRLP.Service).To(Equal(common.KuadrantRateLimitClusterName)) + }) }) - Context("Basic: RLP targeting Gateway", func() { - It("check created resources", func() { - // Check Limitador Status is Ready - Eventually(func() bool { - limitador := &limitadorv1alpha1.Limitador{} - err := k8sClient.Get(context.Background(), client.ObjectKey{Name: common.LimitadorName, Namespace: testNamespace}, limitador) - if err != nil { - return false - } - if !meta.IsStatusConditionTrue(limitador.Status.Conditions, "Ready") { - return false - } - return true - }, time.Minute, 5*time.Second).Should(BeTrue()) - - rlp := testBuildGatewayPolicy(rlpName, testNamespace, gwName) - rlpKey := client.ObjectKey{Name: rlpName, Namespace: testNamespace} - err := k8sClient.Create(context.Background(), rlp) + Context("RLP targeting Gateway", func() { + It("Creates all the resources for a basic Gateway and RateLimitPolicy", func() { + // create httproute + httpRoute := testBuildBasicHttpRoute(routeName, gwName, testNamespace, []string{"*.example.com"}) + err := k8sClient.Create(context.Background(), httpRoute) + Expect(err).ToNot(HaveOccurred()) + + // create ratelimitpolicy + rlp := &kuadrantv1beta2.RateLimitPolicy{ + TypeMeta: metav1.TypeMeta{ + Kind: "RateLimitPolicy", + APIVersion: kuadrantv1beta2.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: rlpName, + Namespace: testNamespace, + }, + Spec: kuadrantv1beta2.RateLimitPolicySpec{ + TargetRef: gatewayapiv1alpha2.PolicyTargetReference{ + Group: gatewayapiv1beta1.Group("gateway.networking.k8s.io"), + Kind: "Gateway", + Name: gatewayapiv1beta1.ObjectName(gwName), + }, + Limits: map[string]kuadrantv1beta2.Limit{ + "l1": { + Rates: []kuadrantv1beta2.Rate{ + { + Limit: 1, Duration: 3, Unit: kuadrantv1beta2.TimeUnit("minute"), + }, + }, + }, + }, + }, + } + err = k8sClient.Create(context.Background(), rlp) Expect(err).ToNot(HaveOccurred()) // Check RLP status is available - Eventually(func() bool { - existingRLP := &kuadrantv1beta2.RateLimitPolicy{} - err := k8sClient.Get(context.Background(), rlpKey, existingRLP) - if err != nil { - return false - } - if !meta.IsStatusConditionTrue(existingRLP.Status.Conditions, "Available") { - return false - } - - return true - }, time.Minute, 5*time.Second).Should(BeTrue()) + rlpKey := client.ObjectKey{Name: rlpName, Namespace: testNamespace} + Eventually(testRLPIsAvailable(rlpKey), time.Minute, 5*time.Second).Should(BeTrue()) // Check Gateway direct back reference gwKey := client.ObjectKeyFromObject(gateway) @@ -405,13 +572,26 @@ var _ = Describe("RateLimitPolicy controller", func() { FailureMode: wasm.FailureModeDeny, RateLimitPolicies: []wasm.RateLimitPolicy{ { - Name: "*", - Domain: common.MarshallNamespace(client.ObjectKeyFromObject(gateway), "*"), - Service: common.KuadrantRateLimitClusterName, - Hostnames: []string{"*"}, + Name: rlpKey.String(), + Domain: fmt.Sprintf("%s/%s#%s", testNamespace, gwName, "*"), Rules: []wasm.Rule{ { - Conditions: nil, + Conditions: []wasm.Condition{ + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toy", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "GET", + }, + }, + }, + }, Data: []wasm.DataItem{ { Static: &wasm.StaticSpec{ @@ -422,6 +602,8 @@ var _ = Describe("RateLimitPolicy controller", func() { }, }, }, + Hostnames: []string{"*"}, + Service: common.KuadrantRateLimitClusterName, }, }, })) @@ -433,8 +615,106 @@ var _ = Describe("RateLimitPolicy controller", func() { refs := []client.ObjectKey{rlpKey} serialized, err := json.Marshal(refs) Expect(err).ToNot(HaveOccurred()) + Expect(existingGateway.GetAnnotations()).To(HaveKeyWithValue(common.RateLimitPoliciesBackRefAnnotation, string(serialized))) + }) + + It("Creates all the resources for a basic Gateway and RateLimitPolicy when missing a HTTPRoute attached to the Gateway", func() { + // create ratelimitpolicy + rlp := &kuadrantv1beta2.RateLimitPolicy{ + TypeMeta: metav1.TypeMeta{ + Kind: "RateLimitPolicy", + APIVersion: kuadrantv1beta2.GroupVersion.String(), + }, + ObjectMeta: metav1.ObjectMeta{ + Name: rlpName, + Namespace: testNamespace, + }, + Spec: kuadrantv1beta2.RateLimitPolicySpec{ + TargetRef: gatewayapiv1alpha2.PolicyTargetReference{ + Group: gatewayapiv1beta1.Group("gateway.networking.k8s.io"), + Kind: "Gateway", + Name: gatewayapiv1beta1.ObjectName(gwName), + }, + Limits: map[string]kuadrantv1beta2.Limit{ + "l1": { + Rates: []kuadrantv1beta2.Rate{ + { + Limit: 1, Duration: 3, Unit: kuadrantv1beta2.TimeUnit("minute"), + }, + }, + }, + }, + }, + } + err := k8sClient.Create(context.Background(), rlp) + Expect(err).ToNot(HaveOccurred()) + + // Check RLP status is available + rlpKey := client.ObjectKey{Name: rlpName, Namespace: testNamespace} + Eventually(testRLPIsAvailable(rlpKey), time.Minute, 5*time.Second).Should(BeTrue()) + + // Check Gateway direct back reference + gwKey := client.ObjectKeyFromObject(gateway) + existingGateway := &gatewayapiv1beta1.Gateway{} + err = k8sClient.Get(context.Background(), gwKey, existingGateway) + // must exist + Expect(err).ToNot(HaveOccurred()) Expect(existingGateway.GetAnnotations()).To(HaveKeyWithValue( - common.RateLimitPoliciesBackRefAnnotation, string(serialized))) + common.RateLimitPolicyBackRefAnnotation, client.ObjectKeyFromObject(rlp).String())) + + // check limits + limitadorKey := client.ObjectKey{Name: common.LimitadorName, Namespace: testNamespace} + existingLimitador := &limitadorv1alpha1.Limitador{} + err = k8sClient.Get(context.Background(), limitadorKey, existingLimitador) + // must exist + Expect(err).ToNot(HaveOccurred()) + Expect(existingLimitador.Spec.Limits).To(ContainElements(limitadorv1alpha1.RateLimit{ + MaxValue: 1, + Seconds: 3 * 60, + Namespace: common.MarshallNamespace(client.ObjectKeyFromObject(gateway), "*"), + Conditions: []string{fmt.Sprintf("%s/%s/l1 == \"1\"", testNamespace, rlpName)}, + Variables: []string{}, + })) + + // Check envoy filter + efName := fmt.Sprintf("kuadrant-ratelimiting-cluster-%s", gwName) + efKey := client.ObjectKey{Name: efName, Namespace: testNamespace} + existingEF := &istioclientnetworkingv1alpha3.EnvoyFilter{} + err = k8sClient.Get(context.Background(), efKey, existingEF) + // must exist + Expect(err).ToNot(HaveOccurred()) + + // Check wasm plugin + wpName := fmt.Sprintf("kuadrant-%s", gwName) + wasmPluginKey := client.ObjectKey{Name: wpName, Namespace: testNamespace} + existingWasmPlugin := &istioclientgoextensionv1alpha1.WasmPlugin{} + // must not exist + err = k8sClient.Get(context.Background(), wasmPluginKey, existingWasmPlugin) + Expect(apierrors.IsNotFound(err)).To(BeTrue()) + + // Check gateway back references + err = k8sClient.Get(context.Background(), gwKey, existingGateway) + // must exist + Expect(err).ToNot(HaveOccurred()) + refs := []client.ObjectKey{rlpKey} + serialized, err := json.Marshal(refs) + Expect(err).ToNot(HaveOccurred()) + Expect(existingGateway.GetAnnotations()).To(HaveKeyWithValue(common.RateLimitPoliciesBackRefAnnotation, string(serialized))) }) }) }) + +func testRLPIsAvailable(rlpKey client.ObjectKey) func() bool { + return func() bool { + existingRLP := &kuadrantv1beta2.RateLimitPolicy{} + err := k8sClient.Get(context.Background(), rlpKey, existingRLP) + if err != nil { + return false + } + if !meta.IsStatusConditionTrue(existingRLP.Status.Conditions, "Available") { + return false + } + + return true + } +} diff --git a/controllers/ratelimitpolicy_limits.go b/controllers/ratelimitpolicy_limits.go index f753c78e7..900cd8db9 100644 --- a/controllers/ratelimitpolicy_limits.go +++ b/controllers/ratelimitpolicy_limits.go @@ -155,7 +155,7 @@ func (r *RateLimitPolicyReconciler) gatewayLimits(ctx context.Context, limits["*"] = append(limits["*"], rlptools.ReadLimitsFromRLP(gwRLP)...) } else { for _, gwHostname := range gw.Hostnames() { - limits[gwHostname] = append(limits[gwHostname], rlptools.ReadLimitsFromRLP(gwRLP)...) + limits[string(gwHostname)] = append(limits[string(gwHostname)], rlptools.ReadLimitsFromRLP(gwRLP)...) } } } diff --git a/controllers/ratelimitpolicy_wasm_plugins.go b/controllers/ratelimitpolicy_wasm_plugins.go index ef393cc19..22f191bbb 100644 --- a/controllers/ratelimitpolicy_wasm_plugins.go +++ b/controllers/ratelimitpolicy_wasm_plugins.go @@ -120,55 +120,73 @@ func (r *RateLimitPolicyReconciler) gatewayWASMPlugin(ctx context.Context, gw co } // returns nil when there is no rate limit policy to apply -func (r *RateLimitPolicyReconciler) wasmPluginConfig(ctx context.Context, - gw common.GatewayWrapper, rlpRefs []client.ObjectKey) (*wasm.WASMPlugin, error) { +func (r *RateLimitPolicyReconciler) wasmPluginConfig(ctx context.Context, gw common.GatewayWrapper, rlpRefs []client.ObjectKey) (*wasm.WASMPlugin, error) { logger, _ := logr.FromContext(ctx) + logger = logger.WithName("wasmPluginConfig").WithValues("gateway", gw.Key()) - routeRLPList := make([]*kuadrantv1beta2.RateLimitPolicy, 0) - var gwRLP *kuadrantv1beta2.RateLimitPolicy + type store struct { + rlp kuadrantv1beta2.RateLimitPolicy + route gatewayapiv1beta1.HTTPRoute + skip bool + } + rlps := make(map[string]*store, len(rlpRefs)) + routeKeys := make(map[string]struct{}, 0) + var gwRLPKey string + + // store all rlps and find the one that targets the gateway (if there is one) for _, rlpKey := range rlpRefs { rlp := &kuadrantv1beta2.RateLimitPolicy{} err := r.Client().Get(ctx, rlpKey, rlp) - logger.V(1).Info("wasmPluginConfig", "get rlp", rlpKey, "err", err) + logger.V(1).Info("get rlp", "ratelimitpolicy", rlpKey, "err", err) if err != nil { return nil, err } + // target ref is a HTTPRoute if common.IsTargetRefHTTPRoute(rlp.Spec.TargetRef) { - routeRLPList = append(routeRLPList, rlp) - } else if common.IsTargetRefGateway(rlp.Spec.TargetRef) { - if gwRLP == nil { - gwRLP = rlp - } else { - return nil, fmt.Errorf("wasmPluginConfig: multiple gateway RLP found and only one expected. rlp keys: %v", rlpRefs) + route, err := r.FetchValidHTTPRoute(ctx, rlp.TargetKey()) + if err != nil { + return nil, err } + rlps[rlpKey.String()] = &store{rlp: *rlp, route: *route} + routeKeys[client.ObjectKeyFromObject(route).String()] = struct{}{} + continue } - } - - wasmRulesByDomain := make(rlptools.WasmRulesByDomain) - if gwRLP != nil { - if len(gw.Hostnames()) == 0 { - // wildcard domain - wasmRulesByDomain["*"] = append(wasmRulesByDomain["*"], rlptools.WasmRules(gwRLP, nil)...) - } else { - for _, gwHostname := range gw.Hostnames() { - wasmRulesByDomain[gwHostname] = append(wasmRulesByDomain[gwHostname], rlptools.WasmRules(gwRLP, nil)...) - } + // target ref is a Gateway + if rlps[rlpKey.String()] != nil { + return nil, fmt.Errorf("wasmPluginConfig: multiple gateway RLP found and only one expected. rlp keys: %v", rlpRefs) } + gwRLPKey = rlpKey.String() + rlps[gwRLPKey] = &store{rlp: *rlp} } - for _, httpRouteRLP := range routeRLPList { - httpRoute, err := r.FetchValidHTTPRoute(ctx, httpRouteRLP.TargetKey()) - if err != nil { - return nil, err - } + gwHostnames := gw.Hostnames() + if len(gwHostnames) == 0 { + gwHostnames = []gatewayapiv1beta1.Hostname{"*"} + } - // gateways limits merged with the route level limits - mergedGatewayActions := mergeRules(httpRouteRLP, gwRLP, httpRoute) - // routeLimits referenced by multiple hostnames - for _, hostname := range httpRoute.Spec.Hostnames { - wasmRulesByDomain[string(hostname)] = append(wasmRulesByDomain[string(hostname)], mergedGatewayActions...) + // if there is a gateway rlp, fake a single httproute with all rules from all httproutes accepted by the gateway, + // that do not have a rlp of its own, so we can generate wasm rules for those cases + if gwRLPKey != "" { + rules := make([]gatewayapiv1beta1.HTTPRouteRule, 0) + for _, route := range r.FetchAcceptedGatewayHTTPRoutes(ctx, rlps[gwRLPKey].rlp.TargetKey()) { + // skip routes that have a rlp of its own + if _, found := routeKeys[client.ObjectKeyFromObject(&route).String()]; found { + continue + } + rules = append(rules, route.Spec.Rules...) + } + if len(rules) == 0 { + logger.V(1).Info("no httproutes attached to the targeted gateway, skipping wasm config for the gateway rlp", "ratelimitpolicy", gwRLPKey) + rlps[gwRLPKey].skip = true + } else { + rlps[gwRLPKey].route = gatewayapiv1beta1.HTTPRoute{ + Spec: gatewayapiv1beta1.HTTPRouteSpec{ + Hostnames: gwHostnames, + Rules: rules, + }, + } } } @@ -177,29 +195,40 @@ func (r *RateLimitPolicyReconciler) wasmPluginConfig(ctx context.Context, RateLimitPolicies: make([]wasm.RateLimitPolicy, 0), } - // One RateLimitPolicy per domain - for domain, rules := range wasmRulesByDomain { - rateLimitPolicy := wasm.RateLimitPolicy{ - Name: domain, - Domain: common.MarshallNamespace(gw.Key(), domain), - Service: common.KuadrantRateLimitClusterName, - Hostnames: []string{domain}, - Rules: rules, + for _, rlpKey := range rlpRefs { + s := rlps[rlpKey.String()] + if s.skip { + continue } - wasmPlugin.RateLimitPolicies = append(wasmPlugin.RateLimitPolicies, rateLimitPolicy) - } + rlp := s.rlp + route := s.route + + // narrow the list of hostnames specified in the route so we don't generate wasm rules that only apply to other gateways + // this is a no-op for the gateway rlp + hostnames := common.FilterValidSubdomains(gwHostnames, route.Spec.Hostnames) + if len(hostnames) == 0 { // it should only happen when the route specifies no hostnames + hostnames = gwHostnames + } + route.Spec.Hostnames = hostnames - return wasmPlugin, nil -} + rules := rlptools.WasmRules(&rlp, &route) + if len(rules) == 0 { + continue // no need to add the policy if there are no rules; a rlp can return no rules if all its limits fail to match any route rule + } -// merge operations currently implemented with list append operation -func mergeRules(routeRLP *kuadrantv1beta2.RateLimitPolicy, gwRLP *kuadrantv1beta2.RateLimitPolicy, route *gatewayapiv1beta1.HTTPRoute) []wasm.Rule { - routeRules := rlptools.WasmRules(routeRLP, route) + wasmPlugin.RateLimitPolicies = append(wasmPlugin.RateLimitPolicies, wasm.RateLimitPolicy{ + Name: rlpKey.String(), + Domain: common.MarshallNamespace(gw.Key(), string(hostnames[0])), // TODO(guicassolato): https://github.com/Kuadrant/kuadrant-operator/issues/201. Meanwhile, we are using the first hostname so it matches at least one set of limit definitions in the Limitador CR + Rules: rules, + Hostnames: common.HostnamesToStrings(hostnames), // we might be listing more hostnames than needed due to route selectors hostnames possibly being more restrictive + Service: common.KuadrantRateLimitClusterName, + }) + } - if gwRLP == nil { - return routeRules + // avoid building a wasm plugin config if there are no rules to apply + if len(wasmPlugin.RateLimitPolicies) == 0 { + return nil, nil } - // add gateway level actions - return append(routeRules, rlptools.WasmRules(gwRLP, nil)...) + return wasmPlugin, nil } diff --git a/go.mod b/go.mod index 9405b3d00..e1865d4e8 100644 --- a/go.mod +++ b/go.mod @@ -54,6 +54,7 @@ require ( github.com/decred/dcrd/dcrec/secp256k1/v4 v4.0.0-20210816181553-5444fa50b93d // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/eko/gocache v1.2.0 // indirect + github.com/elliotchance/orderedmap/v2 v2.2.0 // indirect github.com/emicklei/go-restful/v3 v3.9.0 // indirect github.com/envoyproxy/go-control-plane v0.10.3 // indirect github.com/envoyproxy/protoc-gen-validate v0.9.1 // indirect diff --git a/go.sum b/go.sum index 4285abe27..465166fb9 100644 --- a/go.sum +++ b/go.sum @@ -215,6 +215,8 @@ github.com/eko/gocache v1.2.0 h1:SCtTs65qMXjhdtu62yHPCQuzdMkQjP+fQmkNrVutkRw= github.com/eko/gocache v1.2.0/go.mod h1:6u8/2bnr+nOf87mRXWS710rqNNZUECF4CGsPNnsoJ78= github.com/elazarl/goproxy v0.0.0-20170405201442-c4fc26588b6e/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= github.com/elazarl/goproxy v0.0.0-20180725130230-947c36da3153/go.mod h1:/Zj4wYkgs4iZTTu3o/KG3Itv/qCCa8VVMlb3i9OVuzc= +github.com/elliotchance/orderedmap/v2 v2.2.0 h1:7/2iwO98kYT4XkOjA9mBEIwvi4KpGB4cyHeOFOnj4Vk= +github.com/elliotchance/orderedmap/v2 v2.2.0/go.mod h1:85lZyVbpGaGvHvnKa7Qhx7zncAdBIBq6u56Hb1PRU5Q= github.com/emicklei/go-restful v0.0.0-20170410110728-ff4f55a20633/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/emicklei/go-restful v2.9.5+incompatible/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/emicklei/go-restful/v3 v3.9.0 h1:XwGDlfxEnQZzuopoqxwSEllNcCOM9DhhFyhFIIGKwxE= diff --git a/pkg/common/common.go b/pkg/common/common.go index 7c0e24bd7..e036a3ee3 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -47,6 +47,10 @@ type KuadrantPolicy interface { GetRulesHostnames() []string } +func Ptr[T any](t T) *T { + return &t +} + // FetchEnv fetches the value of the environment variable with the specified key, // or returns the default value if the variable is not found or has an empty value. // If an error occurs during the lookup, the function returns the default value. @@ -87,7 +91,7 @@ func NamespacedNameToObjectKey(namespacedName, defaultNamespace string) client.O // Contains checks if the given target string is present in the slice of strings 'slice'. // It returns true if the target string is found in the slice, false otherwise. -func Contains(slice []string, target string) bool { +func Contains[T comparable](slice []T, target T) bool { for idx := range slice { if slice[idx] == target { return true @@ -96,6 +100,44 @@ func Contains(slice []string, target string) bool { return false } +// SameElements checks if the two slices contain the exact same elements. Order does not matter. +func SameElements[T comparable](s1, s2 []T) bool { + if len(s1) != len(s2) { + return false + } + for _, v := range s1 { + if !Contains(s2, v) { + return false + } + } + return true +} + +func Intersect[T comparable](slice1, slice2 []T) bool { + for _, item := range slice1 { + if Contains(slice2, item) { + return true + } + } + return false +} + +func Intersection[T comparable](slice1, slice2 []T) []T { + smallerSlice := slice1 + largerSlice := slice2 + if len(slice1) > len(slice2) { + smallerSlice = slice2 + largerSlice = slice1 + } + var result []T + for _, item := range smallerSlice { + if Contains(largerSlice, item) { + result = append(result, item) + } + } + return result +} + func Find[T any](slice []T, match func(T) bool) (*T, bool) { for _, item := range slice { if match(item) { @@ -114,6 +156,17 @@ func Map[T, U any](slice []T, f func(T) U) []U { return arr } +// Filter filters the input slice using the given predicate function and returns a new slice with the results. +func Filter[T any](slice []T, f func(T) bool) []T { + arr := make([]T, 0) + for _, e := range slice { + if f(e) { + arr = append(arr, e) + } + } + return arr +} + // SliceCopy copies the elements from the input slice into the output slice, and returns the output slice. func SliceCopy[T any](s1 []T) []T { s2 := make([]T, len(s1)) @@ -192,11 +245,9 @@ func UnMarshallObjectKey(keyStr string) (client.ObjectKey, error) { // HostnamesToStrings converts []gatewayapi_v1alpha2.Hostname to []string func HostnamesToStrings(hostnames []gatewayapiv1beta1.Hostname) []string { - hosts := make([]string, len(hostnames)) - for i, h := range hostnames { - hosts[i] = string(h) - } - return hosts + return Map(hostnames, func(hostname gatewayapiv1beta1.Hostname) string { + return string(hostname) + }) } // ValidSubdomains returns (true, "") when every single subdomains item @@ -221,3 +272,16 @@ func ValidSubdomains(domains, subdomains []string) (bool, string) { } return true, "" } + +// FilterValidSubdomains returns every subdomain that is a subset of at least one of the (super) domains specified in the first argument. +func FilterValidSubdomains(domains, subdomains []gatewayapiv1beta1.Hostname) []gatewayapiv1beta1.Hostname { + arr := make([]gatewayapiv1beta1.Hostname, 0) + for _, subsubdomain := range subdomains { + if _, found := Find(domains, func(domain gatewayapiv1beta1.Hostname) bool { + return Name(subsubdomain).SubsetOf(Name(domain)) + }); found { + arr = append(arr, subsubdomain) + } + } + return arr +} diff --git a/pkg/common/common_test.go b/pkg/common/common_test.go index cb093479b..8b181a790 100644 --- a/pkg/common/common_test.go +++ b/pkg/common/common_test.go @@ -250,6 +250,274 @@ func TestContains(t *testing.T) { } } +func TestContainsWithInts(t *testing.T) { + testCases := []struct { + name string + slice []int + target int + expected bool + }{ + { + name: "when slice has one target item then return true", + slice: []int{1}, + target: 1, + expected: true, + }, + { + name: "when slice is empty then return false", + slice: []int{}, + target: 2, + expected: false, + }, + { + name: "when target is in a slice then return true", + slice: []int{1, 2, 3}, + target: 2, + expected: true, + }, + { + name: "when no target in a slice then return false", + slice: []int{1, 2, 3}, + target: 4, + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if Contains(tc.slice, tc.target) != tc.expected { + t.Errorf("when slice=%v and target=%d, expected=%v, but got=%v", tc.slice, tc.target, tc.expected, !tc.expected) + } + }) + } +} + +func TestSameElements(t *testing.T) { + testCases := []struct { + name string + slice1 []string + slice2 []string + expected bool + }{ + { + name: "when slice1 and slice2 contain the same elements then return true", + slice1: []string{"test-gw1", "test-gw2", "test-gw3"}, + slice2: []string{"test-gw1", "test-gw2", "test-gw3"}, + expected: true, + }, + { + name: "when slice1 and slice2 contain unique elements then return false", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{"test-gw1", "test-gw3"}, + expected: false, + }, + { + name: "when both slices are empty then return true", + slice1: []string{}, + slice2: []string{}, + expected: true, + }, + { + name: "when both slices are nil then return true", + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if SameElements(tc.slice1, tc.slice2) != tc.expected { + t.Errorf("when slice1=%v and slice2=%v, expected=%v, but got=%v", tc.slice1, tc.slice2, tc.expected, !tc.expected) + } + }) + } +} + +func TestIntersect(t *testing.T) { + testCases := []struct { + name string + slice1 []string + slice2 []string + expected bool + }{ + { + name: "when slice1 and slice2 have one common item then return true", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{"test-gw1", "test-gw3", "test-gw4"}, + expected: true, + }, + { + name: "when slice1 and slice2 have no common item then return false", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{"test-gw3", "test-gw4"}, + expected: false, + }, + { + name: "when slice1 is empty then return false", + slice1: []string{}, + slice2: []string{"test-gw3", "test-gw4"}, + expected: false, + }, + { + name: "when slice2 is empty then return false", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{}, + expected: false, + }, + { + name: "when both slices are empty then return false", + slice1: []string{}, + slice2: []string{}, + expected: false, + }, + { + name: "when slice1 is nil then return false", + slice2: []string{"test-gw3", "test-gw4"}, + expected: false, + }, + { + name: "when slice2 is nil then return false", + slice1: []string{"test-gw1", "test-gw2"}, + expected: false, + }, + { + name: "when both slices are nil then return false", + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if Intersect(tc.slice1, tc.slice2) != tc.expected { + t.Errorf("when slice1=%v and slice2=%v, expected=%v, but got=%v", tc.slice1, tc.slice2, tc.expected, !tc.expected) + } + }) + } +} + +func TestIntersectWithInts(t *testing.T) { + testCases := []struct { + name string + slice1 []int + slice2 []int + expected bool + }{ + { + name: "when slice1 and slice2 have one common item then return true", + slice1: []int{1, 2}, + slice2: []int{1, 3, 4}, + expected: true, + }, + { + name: "when slice1 and slice2 have no common item then return false", + slice1: []int{1, 2}, + slice2: []int{3, 4}, + expected: false, + }, + { + name: "when slice1 is empty then return false", + slice1: []int{}, + slice2: []int{3, 4}, + expected: false, + }, + { + name: "when slice2 is empty then return false", + slice1: []int{1, 2}, + slice2: []int{}, + expected: false, + }, + { + name: "when both slices are empty then return false", + slice1: []int{}, + slice2: []int{}, + expected: false, + }, + { + name: "when slice1 is nil then return false", + slice2: []int{3, 4}, + expected: false, + }, + { + name: "when slice2 is nil then return false", + slice1: []int{1, 2}, + expected: false, + }, + { + name: "when both slices are nil then return false", + expected: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if Intersect(tc.slice1, tc.slice2) != tc.expected { + t.Errorf("when slice1=%v and slice2=%v, expected=%v, but got=%v", tc.slice1, tc.slice2, tc.expected, !tc.expected) + } + }) + } +} + +func TestIntersection(t *testing.T) { + testCases := []struct { + name string + slice1 []string + slice2 []string + expected []string + }{ + { + name: "when slice1 and slice2 have one common item then return that item", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{"test-gw1", "test-gw3", "test-gw4"}, + expected: []string{"test-gw1"}, + }, + { + name: "when slice1 and slice2 have no common item then return nil", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{"test-gw3", "test-gw4"}, + expected: nil, + }, + { + name: "when slice1 is empty then return nil", + slice1: []string{}, + slice2: []string{"test-gw3", "test-gw4"}, + expected: nil, + }, + { + name: "when slice2 is empty then return nil", + slice1: []string{"test-gw1", "test-gw2"}, + slice2: []string{}, + expected: nil, + }, + { + name: "when both slices are empty then return nil", + slice1: []string{}, + slice2: []string{}, + expected: nil, + }, + { + name: "when slice1 is nil then return nil", + slice2: []string{"test-gw3", "test-gw4"}, + expected: nil, + }, + { + name: "when slice2 is nil then return nil", + slice1: []string{"test-gw1", "test-gw2"}, + expected: nil, + }, + { + name: "when both slices are nil then return nil", + expected: nil, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if r := Intersection(tc.slice1, tc.slice2); !reflect.DeepEqual(r, tc.expected) { + t.Errorf("expected=%v; got=%v", tc.expected, r) + } + }) + } +} + func TestMap(t *testing.T) { slice1 := []int{1, 2, 3, 4} f1 := func(x int) int { return x + 1 } @@ -652,3 +920,51 @@ func TestHostnamesToStrings(t *testing.T) { }) } } + +func TestFilterValidSubdomains(t *testing.T) { + testCases := []struct { + name string + domains []gatewayapiv1beta1.Hostname + subdomains []gatewayapiv1beta1.Hostname + expected []gatewayapiv1beta1.Hostname + }{ + { + name: "when all subdomains are valid", + domains: []gatewayapiv1beta1.Hostname{"my-app.apps.io", "*.acme.com"}, + subdomains: []gatewayapiv1beta1.Hostname{"toystore.acme.com", "my-app.apps.io", "carstore.acme.com"}, + expected: []gatewayapiv1beta1.Hostname{"toystore.acme.com", "my-app.apps.io", "carstore.acme.com"}, + }, + { + name: "when some subdomains are valid and some are not", + domains: []gatewayapiv1beta1.Hostname{"my-app.apps.io", "*.acme.com"}, + subdomains: []gatewayapiv1beta1.Hostname{"toystore.acme.com", "my-app.apps.io", "other-app.apps.io"}, + expected: []gatewayapiv1beta1.Hostname{"toystore.acme.com", "my-app.apps.io"}, + }, + { + name: "when none of subdomains are valid", + domains: []gatewayapiv1beta1.Hostname{"my-app.apps.io", "*.acme.com"}, + subdomains: []gatewayapiv1beta1.Hostname{"other-app.apps.io"}, + expected: []gatewayapiv1beta1.Hostname{}, + }, + { + name: "when the set of super domains is empty", + domains: []gatewayapiv1beta1.Hostname{}, + subdomains: []gatewayapiv1beta1.Hostname{"toystore.acme.com"}, + expected: []gatewayapiv1beta1.Hostname{}, + }, + { + name: "when the set of subdomains is empty", + domains: []gatewayapiv1beta1.Hostname{"my-app.apps.io", "*.acme.com"}, + subdomains: []gatewayapiv1beta1.Hostname{}, + expected: []gatewayapiv1beta1.Hostname{}, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if r := FilterValidSubdomains(tc.domains, tc.subdomains); !reflect.DeepEqual(r, tc.expected) { + t.Errorf("expected=%v; got=%v", tc.expected, r) + } + }) + } +} diff --git a/pkg/common/gatewayapi_utils.go b/pkg/common/gatewayapi_utils.go index a527f64d9..c13ba6964 100644 --- a/pkg/common/gatewayapi_utils.go +++ b/pkg/common/gatewayapi_utils.go @@ -91,6 +91,117 @@ func RulesFromHTTPRoute(route *gatewayapiv1beta1.HTTPRoute) []HTTPRouteRule { return rules } +type HTTPRouteRuleSelector struct { + *gatewayapiv1beta1.HTTPRouteMatch +} + +func (s *HTTPRouteRuleSelector) Selects(rule gatewayapiv1beta1.HTTPRouteRule) bool { + if s.HTTPRouteMatch == nil { + return true + } + + _, found := Find(rule.Matches, func(ruleMatch gatewayapiv1beta1.HTTPRouteMatch) bool { + // path + if s.Path != nil && !reflect.DeepEqual(s.Path, ruleMatch.Path) { + return false + } + + // method + if s.Method != nil && !reflect.DeepEqual(s.Method, ruleMatch.Method) { + return false + } + + // headers + for _, header := range s.Headers { + if _, found := Find(ruleMatch.Headers, func(otherHeader gatewayapiv1beta1.HTTPHeaderMatch) bool { + return reflect.DeepEqual(header, otherHeader) + }); !found { + return false + } + } + + // query params + for _, param := range s.QueryParams { + if _, found := Find(ruleMatch.QueryParams, func(otherParam gatewayapiv1beta1.HTTPQueryParamMatch) bool { + return reflect.DeepEqual(param, otherParam) + }); !found { + return false + } + } + + return true + }) + + return found +} + +// HTTPRouteRuleToString prints the matches of a HTTPRouteRule as string +func HTTPRouteRuleToString(rule gatewayapiv1beta1.HTTPRouteRule) string { + matches := Map(rule.Matches, HTTPRouteMatchToString) + return fmt.Sprintf("{matches:[%s]}", strings.Join(matches, ",")) +} + +func HTTPRouteMatchToString(match gatewayapiv1beta1.HTTPRouteMatch) string { + var patterns []string + if method := match.Method; method != nil { + patterns = append(patterns, fmt.Sprintf("method:%v", HTTPMethodToString(method))) + } + if path := match.Path; path != nil { + patterns = append(patterns, fmt.Sprintf("path:%s", HTTPPathMatchToString(path))) + } + if len(match.QueryParams) > 0 { + queryParams := Map(match.QueryParams, HTTPQueryParamMatchToString) + patterns = append(patterns, fmt.Sprintf("queryParams:[%s]", strings.Join(queryParams, ","))) + } + if len(match.Headers) > 0 { + headers := Map(match.Headers, HTTPHeaderMatchToString) + patterns = append(patterns, fmt.Sprintf("headers:[%s]", strings.Join(headers, ","))) + } + return fmt.Sprintf("{%s}", strings.Join(patterns, ",")) +} + +func HTTPPathMatchToString(path *gatewayapiv1beta1.HTTPPathMatch) string { + if path == nil { + return "*" + } + if path.Type != nil { + switch *path.Type { + case gatewayapiv1beta1.PathMatchExact: + return fmt.Sprintf("%s", *path.Value) + case gatewayapiv1beta1.PathMatchRegularExpression: + return fmt.Sprintf("~/%s/", *path.Value) + } + } + return fmt.Sprintf("%s*", *path.Value) +} + +func HTTPHeaderMatchToString(header gatewayapiv1beta1.HTTPHeaderMatch) string { + if header.Type != nil { + switch *header.Type { + case gatewayapiv1beta1.HeaderMatchRegularExpression: + return fmt.Sprintf("{%s:~/%s/}", header.Name, header.Value) + } + } + return fmt.Sprintf("{%s:%s}", header.Name, header.Value) +} + +func HTTPQueryParamMatchToString(queryParam gatewayapiv1beta1.HTTPQueryParamMatch) string { + if queryParam.Type != nil { + switch *queryParam.Type { + case gatewayapiv1beta1.QueryParamMatchRegularExpression: + return fmt.Sprintf("{%s:~/%s/}", queryParam.Name, queryParam.Value) + } + } + return fmt.Sprintf("{%s:%s}", queryParam.Name, queryParam.Value) +} + +func HTTPMethodToString(method *gatewayapiv1beta1.HTTPMethod) string { + if method == nil { + return "*" + } + return string(*method) +} + func GetNamespaceFromPolicyTargetRef(ctx context.Context, cli client.Client, policy KuadrantPolicy) (string, error) { targetRef := policy.GetTargetRef() gwNamespacedName := types.NamespacedName{Namespace: string(GetDefaultIfNil(targetRef.Namespace, policy.GetWrappedNamespace())), Name: string(targetRef.Name)} @@ -374,15 +485,15 @@ func (g GatewayWrapper) DeletePolicy(policyKey client.ObjectKey) bool { } // Hostnames builds a list of hostnames from the listeners. -func (g GatewayWrapper) Hostnames() []string { - hostnames := make([]string, 0) +func (g GatewayWrapper) Hostnames() []gatewayapiv1beta1.Hostname { + hostnames := make([]gatewayapiv1beta1.Hostname, 0) if g.Gateway == nil { return hostnames } for idx := range g.Spec.Listeners { if g.Spec.Listeners[idx].Hostname != nil { - hostnames = append(hostnames, string(*g.Spec.Listeners[idx].Hostname)) + hostnames = append(hostnames, *g.Spec.Listeners[idx].Hostname) } } diff --git a/pkg/common/gatewayapi_utils_test.go b/pkg/common/gatewayapi_utils_test.go index d17ef8025..1f0668c22 100644 --- a/pkg/common/gatewayapi_utils_test.go +++ b/pkg/common/gatewayapi_utils_test.go @@ -241,6 +241,443 @@ func TestRulesFromHTTPRoute(t *testing.T) { } } +func TestHTTPRouteRuleSelectorSelects(t *testing.T) { + testCases := []struct { + name string + selector HTTPRouteRuleSelector + rule gatewayapiv1beta1.HTTPRouteRule + expected bool + }{ + { + name: "when the httproutrule contains the exact match then return true", + selector: HTTPRouteRuleSelector{ + HTTPRouteMatch: &gatewayapiv1beta1.HTTPRouteMatch{ + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + rule: gatewayapiv1beta1.HTTPRouteRule{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "when the httproutrule contains the exact match and more then return true", + selector: HTTPRouteRuleSelector{ + HTTPRouteMatch: &gatewayapiv1beta1.HTTPRouteMatch{ + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + }, + }, + rule: gatewayapiv1beta1.HTTPRouteRule{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "when the httproutrule contains all the matching headers and more then return true", + selector: HTTPRouteRuleSelector{ + HTTPRouteMatch: &gatewayapiv1beta1.HTTPRouteMatch{ + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + rule: gatewayapiv1beta1.HTTPRouteRule{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchRegularExpression}[0], + Name: "someotherheader", + Value: "someregex.*", + }, + }, + }, + }, + }, + expected: true, + }, + { + name: "when the httproutrule contains an inexact match then return false", + selector: HTTPRouteRuleSelector{ + HTTPRouteMatch: &gatewayapiv1beta1.HTTPRouteMatch{ + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + rule: gatewayapiv1beta1.HTTPRouteRule{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodPost}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + }, + expected: false, + }, + { + name: "when the httproutrule is empty then return false", + rule: gatewayapiv1beta1.HTTPRouteRule{}, + selector: HTTPRouteRuleSelector{ + HTTPRouteMatch: &gatewayapiv1beta1.HTTPRouteMatch{ + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + }, + }, + expected: false, + }, + { + name: "when the selector is empty then return true", + selector: HTTPRouteRuleSelector{}, + rule: gatewayapiv1beta1.HTTPRouteRule{ + Matches: []gatewayapiv1beta1.HTTPRouteMatch{ + { + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + Headers: []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "someheader", + Value: "somevalue", + }, + }, + }, + }, + }, + expected: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if r := tc.selector.Selects(tc.rule); r != tc.expected { + expectedStr := "" + resultStr := "not" + if !tc.expected { + expectedStr = "not" + resultStr = "" + } + t.Error("expected selector", HTTPRouteMatchToString(*tc.selector.HTTPRouteMatch), expectedStr, "to select rule", HTTPRouteRuleToString(tc.rule), "but it does", resultStr) + } + }) + } +} + +func TestHTTPPathMatchToString(t *testing.T) { + testCases := []struct { + name string + input *gatewayapiv1beta1.HTTPPathMatch + expected string + }{ + { + name: "exact path match", + input: &[]gatewayapiv1beta1.HTTPPathMatch{ + { + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchExact}[0], + Value: &[]string{"/foo"}[0], + }, + }[0], + expected: "/foo", + }, + { + name: "regex path match", + input: &[]gatewayapiv1beta1.HTTPPathMatch{ + { + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchRegularExpression}[0], + Value: &[]string{"^\\/foo.*"}[0], + }, + }[0], + expected: "~/^\\/foo.*/", + }, + { + name: "path prefix match", + input: &[]gatewayapiv1beta1.HTTPPathMatch{ + { + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/foo"}[0], + }, + }[0], + expected: "/foo*", + }, + { + name: "path match with default type", + input: &[]gatewayapiv1beta1.HTTPPathMatch{ + { + Value: &[]string{"/foo"}[0], + }, + }[0], + expected: "/foo*", + }, + { + name: "nil path match", + input: nil, + expected: "*", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if r := HTTPPathMatchToString(tc.input); r != tc.expected { + t.Errorf("expected: %s, got: %s", tc.expected, r) + } + }) + } +} + +func TestHTTPHeaderMatchToString(t *testing.T) { + testCases := []struct { + name string + input gatewayapiv1beta1.HTTPHeaderMatch + expected string + }{ + { + name: "exact header match", + input: gatewayapiv1beta1.HTTPHeaderMatch{ + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchExact}[0], + Name: "some-header", + Value: "foo", + }, + expected: "{some-header:foo}", + }, + { + name: "regex header match", + input: gatewayapiv1beta1.HTTPHeaderMatch{ + Type: &[]gatewayapiv1beta1.HeaderMatchType{gatewayapiv1beta1.HeaderMatchRegularExpression}[0], + Name: "some-header", + Value: "^foo.*", + }, + expected: "{some-header:~/^foo.*/}", + }, + { + name: "header match with default type", + input: gatewayapiv1beta1.HTTPHeaderMatch{ + Name: "some-header", + Value: "foo", + }, + expected: "{some-header:foo}", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if r := HTTPHeaderMatchToString(tc.input); r != tc.expected { + t.Errorf("expected: %s, got: %s", tc.expected, r) + } + }) + } +} + +func TestHTTPQueryParamMatchToString(t *testing.T) { + testCases := []struct { + name string + input gatewayapiv1beta1.HTTPQueryParamMatch + expected string + }{ + { + name: "exact query param match", + input: gatewayapiv1beta1.HTTPQueryParamMatch{ + Type: &[]gatewayapiv1beta1.QueryParamMatchType{gatewayapiv1beta1.QueryParamMatchExact}[0], + Name: "some-param", + Value: "foo", + }, + expected: "{some-param:foo}", + }, + { + name: "regex query param match", + input: gatewayapiv1beta1.HTTPQueryParamMatch{ + Type: &[]gatewayapiv1beta1.QueryParamMatchType{gatewayapiv1beta1.QueryParamMatchRegularExpression}[0], + Name: "some-param", + Value: "^foo.*", + }, + expected: "{some-param:~/^foo.*/}", + }, + { + name: "query param match with default type", + input: gatewayapiv1beta1.HTTPQueryParamMatch{ + Name: "some-param", + Value: "foo", + }, + expected: "{some-param:foo}", + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if r := HTTPQueryParamMatchToString(tc.input); r != tc.expected { + t.Errorf("expected: %s, got: %s", tc.expected, r) + } + }) + } +} + +func TestHTTPMethodToString(t *testing.T) { + testCases := []struct { + input *gatewayapiv1beta1.HTTPMethod + expected string + }{ + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + expected: "GET", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodHead}[0], + expected: "HEAD", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodPost}[0], + expected: "POST", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodPut}[0], + expected: "PUT", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodPatch}[0], + expected: "PATCH", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodDelete}[0], + expected: "DELETE", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodConnect}[0], + expected: "CONNECT", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodOptions}[0], + expected: "OPTIONS", + }, + { + input: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodTrace}[0], + expected: "TRACE", + }, + { + input: nil, + expected: "*", + }, + } + for _, tc := range testCases { + if r := HTTPMethodToString(tc.input); r != tc.expected { + t.Errorf("expected: %s, got: %s", tc.expected, r) + } + } +} + +func TestHTTPRouteMatchToString(t *testing.T) { + match := gatewayapiv1beta1.HTTPRouteMatch{ + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchExact}[0], + Value: &[]string{"/foo"}[0], + }, + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + QueryParams: []gatewayapiv1beta1.HTTPQueryParamMatch{ + { + Type: &[]gatewayapiv1beta1.QueryParamMatchType{gatewayapiv1beta1.QueryParamMatchRegularExpression}[0], + Name: "page", + Value: "\\d+", + }, + }, + } + + expected := "{method:GET,path:/foo,queryParams:[{page:~/\\d+/}]}" + + if r := HTTPRouteMatchToString(match); r != expected { + t.Errorf("expected: %s, got: %s", expected, r) + } + + match.Headers = []gatewayapiv1beta1.HTTPHeaderMatch{ + { + Name: "x-foo", + Value: "bar", + }, + } + + expected = "{method:GET,path:/foo,queryParams:[{page:~/\\d+/}],headers:[{x-foo:bar}]}" + + if r := HTTPRouteMatchToString(match); r != expected { + t.Errorf("expected: %s, got: %s", expected, r) + } +} + +func TestHTTPRouteRuleToString(t *testing.T) { + rule := gatewayapiv1beta1.HTTPRouteRule{} + + expected := "{matches:[]}" + + if r := HTTPRouteRuleToString(rule); r != expected { + t.Errorf("expected: %s, got: %s", expected, r) + } + + rule.Matches = []gatewayapiv1beta1.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchExact}[0], + Value: &[]string{"/foo"}[0], + }, + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethodGet}[0], + QueryParams: []gatewayapiv1beta1.HTTPQueryParamMatch{ + { + Type: &[]gatewayapiv1beta1.QueryParamMatchType{gatewayapiv1beta1.QueryParamMatchRegularExpression}[0], + Name: "page", + Value: "\\d+", + }, + }, + }, + } + + expected = "{matches:[{method:GET,path:/foo,queryParams:[{page:~/\\d+/}]}]}" + + if r := HTTPRouteRuleToString(rule); r != expected { + t.Errorf("expected: %s, got: %s", expected, r) + } +} + func TestGatewaysMissingPolicyRef(t *testing.T) { gwList := &gatewayapiv1beta1.GatewayList{ Items: []gatewayapiv1beta1.Gateway{ diff --git a/pkg/reconcilers/targetref_reconciler.go b/pkg/reconcilers/targetref_reconciler.go index 124114194..a4bbc3e78 100644 --- a/pkg/reconcilers/targetref_reconciler.go +++ b/pkg/reconcilers/targetref_reconciler.go @@ -94,6 +94,35 @@ func (r *TargetRefReconciler) FetchValidTargetRef(ctx context.Context, targetRef return nil, fmt.Errorf("FetchValidTargetRef: targetRef (%v) to unknown network resource", targetRef) } +// FetchAcceptedGatewayHTTPRoutes returns the list of HTTPRoutes that have been accepted as children of a gateway. +func (r *TargetRefReconciler) FetchAcceptedGatewayHTTPRoutes(ctx context.Context, gwKey client.ObjectKey) (routes []gatewayapiv1beta1.HTTPRoute) { + logger, _ := logr.FromContext(ctx) + logger = logger.WithName("FetchAcceptedGatewayHTTPRoutes").WithValues("gateway", gwKey) + + routeList := &gatewayapiv1beta1.HTTPRouteList{} + err := r.Client().List(ctx, routeList) + if err != nil { + logger.V(1).Info("failed to list httproutes", "err", err) + return + } + + for _, route := range routeList.Items { + routeParentStatus, found := common.Find(route.Status.RouteStatus.Parents, func(p gatewayapiv1beta1.RouteParentStatus) bool { + return *p.ParentRef.Kind == ("Gateway") && + ((p.ParentRef.Namespace == nil && route.GetNamespace() == gwKey.Namespace) || string(*p.ParentRef.Namespace) == gwKey.Namespace) && + string(p.ParentRef.Name) == gwKey.Name + }) + if found && meta.IsStatusConditionTrue(routeParentStatus.Conditions, "Accepted") { + logger.V(1).Info("found route attached to gateway", "httproute", client.ObjectKeyFromObject(&route)) + routes = append(routes, route) + continue + } + logger.V(1).Info("skipping route, not attached to gateway", "httproute", client.ObjectKeyFromObject(&route)) + } + + return +} + // TargetedGatewayKeys returns the list of gateways that are being referenced from the target. func (r *TargetRefReconciler) TargetedGatewayKeys(ctx context.Context, targetNetworkObject client.Object) []client.ObjectKey { switch obj := targetNetworkObject.(type) { diff --git a/pkg/rlptools/wasm_utils.go b/pkg/rlptools/wasm_utils.go index 8f66100b9..ca6b28d98 100644 --- a/pkg/rlptools/wasm_utils.go +++ b/pkg/rlptools/wasm_utils.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "reflect" + "strings" _struct "github.com/golang/protobuf/ptypes/struct" istioclientgoextensionv1alpha1 "istio.io/client-go/pkg/apis/extensions/v1alpha1" @@ -20,7 +21,9 @@ var ( WASMFilterImageURL = common.FetchEnv("RELATED_IMAGE_WASMSHIM", "oci://quay.io/kuadrant/wasm-shim:latest") ) -// WasmRules computes WASM rules from the policy and the targeted Route (which can be nil when a gateway is targeted) +// WasmRules computes WASM rules from the policy and the targeted route. +// It returns an empty list of wasm rules if the policy specifies no limits or if all limits specified in the policy +// fail to match any route rule according to the limits route selectors. func WasmRules(rlp *kuadrantv1beta2.RateLimitPolicy, route *gatewayapiv1beta1.HTTPRoute) []wasm.Rule { rules := make([]wasm.Rule, 0) if rlp == nil { @@ -30,93 +33,138 @@ func WasmRules(rlp *kuadrantv1beta2.RateLimitPolicy, route *gatewayapiv1beta1.HT for limitName, limit := range rlp.Spec.Limits { // 1 RLP limit <---> 1 WASM rule limitFullName := FullLimitName(rlp, limitName) - rule := ruleFromLimit(limitFullName, &limit, route) - - rules = append(rules, rule) + rule, err := ruleFromLimit(limitFullName, &limit, route) + if err == nil { + rules = append(rules, rule) + } } return rules } -func ruleFromLimit(limitFullName string, limit *kuadrantv1beta2.Limit, route *gatewayapiv1beta1.HTTPRoute) wasm.Rule { - if limit == nil { - return wasm.Rule{} +func ruleFromLimit(limitFullName string, limit *kuadrantv1beta2.Limit, route *gatewayapiv1beta1.HTTPRoute) (wasm.Rule, error) { + rule := wasm.Rule{} + + if conditions, err := conditionsFromLimit(limit, route); err != nil { + return rule, err + } else { + rule.Conditions = conditions } - return wasm.Rule{ - Conditions: conditionsFromLimit(limit, route), - Data: dataFromLimt(limitFullName, limit), + if data := dataFromLimt(limitFullName, limit); data != nil { + rule.Data = data } + + return rule, nil } -func conditionsFromLimit(limit *kuadrantv1beta2.Limit, route *gatewayapiv1beta1.HTTPRoute) []wasm.Condition { +func conditionsFromLimit(limit *kuadrantv1beta2.Limit, route *gatewayapiv1beta1.HTTPRoute) ([]wasm.Condition, error) { if limit == nil { - return make([]wasm.Condition, 0) + return nil, errors.New("limit should not be nil") } - // TODO(eastizle): review this implementation. This is a first naive implementation. - // The conditions must always be a subset of the route's matching rules. - - conditions := make([]wasm.Condition, 0) + routeConditions := make([]wasm.Condition, 0) - for routeSelectorIdx := range limit.RouteSelectors { - // TODO(eastizle): what if there are only Hostnames (i.e. empty "matches" list) - for matchIdx := range limit.RouteSelectors[routeSelectorIdx].Matches { - condition := wasm.Condition{ - AllOf: patternExpresionsFromMatch(&limit.RouteSelectors[routeSelectorIdx].Matches[matchIdx]), + if len(limit.RouteSelectors) > 0 { + // build conditions from the rules selected by the route selectors + for _, routeSelector := range limit.RouteSelectors { + hostnamesForConditions := hostnamesForConditions(route, &routeSelector) + for _, rule := range routeSelector.SelectRules(route) { + routeConditions = append(routeConditions, conditionsFromRule(rule, hostnamesForConditions)...) } - - // merge hostnames expression in the same condition - for _, hostname := range limit.RouteSelectors[routeSelectorIdx].Hostnames { - condition.AllOf = append(condition.AllOf, patternExpresionFromHostname(hostname)) - } - - conditions = append(conditions, condition) + } + if len(routeConditions) == 0 { + return nil, errors.New("cannot match any route rules, check for invalid route selectors in the policy") + } + } else { + // build conditions from all rules if no route selectors are defined + for _, rule := range route.Spec.Rules { + routeConditions = append(routeConditions, conditionsFromRule(rule, hostnamesForConditions(route, nil))...) } } - if len(conditions) == 0 { - conditions = append(conditions, conditionsFromRoute(route)...) + if len(limit.When) == 0 { + if len(routeConditions) == 0 { + return nil, nil + } + return routeConditions, nil } - // merge when expression in the same condition - // must be done after adding route level conditions when no route selector are available - // prevent conditions only filled with "when" definitions - for whenIdx := range limit.When { - for idx := range conditions { - conditions[idx].AllOf = append(conditions[idx].AllOf, patternExpresionFromWhen(limit.When[whenIdx])) + if len(routeConditions) > 0 { + // merge the 'when' conditions into each route level one + mergedConditions := make([]wasm.Condition, len(routeConditions)) + for _, when := range limit.When { + for idx := range routeConditions { + mergedCondition := routeConditions[idx] + mergedCondition.AllOf = append(mergedCondition.AllOf, patternExpresionFromWhen(when)) + mergedConditions[idx] = mergedCondition + } } + return mergedConditions, nil } - return conditions + // build conditions only from the 'when' field + whenConditions := make([]wasm.Condition, len(limit.When)) + for idx, when := range limit.When { + whenConditions[idx] = wasm.Condition{AllOf: []wasm.PatternExpression{patternExpresionFromWhen(when)}} + } + return whenConditions, nil } -func conditionsFromRoute(route *gatewayapiv1beta1.HTTPRoute) []wasm.Condition { - if route == nil { - return make([]wasm.Condition, 0) - } +// hostnamesForConditions allows avoiding building conditions for hostnames that are excluded by the selector +// or when the hostname is irrelevant (i.e. matches all hostnames) +func hostnamesForConditions(route *gatewayapiv1beta1.HTTPRoute, routeSelector *kuadrantv1beta2.RouteSelector) []gatewayapiv1beta1.Hostname { + hostnames := route.Spec.Hostnames - conditions := make([]wasm.Condition, 0) + if routeSelector != nil && len(routeSelector.Hostnames) > 0 { + hostnames = common.Intersection(routeSelector.Hostnames, hostnames) + } - for ruleIdx := range route.Spec.Rules { - // One condition per match - for matchIdx := range route.Spec.Rules[ruleIdx].Matches { - conditions = append(conditions, wasm.Condition{ - AllOf: patternExpresionsFromMatch(&route.Spec.Rules[ruleIdx].Matches[matchIdx]), - }) - } + if common.SameElements(hostnames, route.Spec.Hostnames) { + return []gatewayapiv1beta1.Hostname{"*"} } - return conditions + return hostnames } -func patternExpresionsFromMatch(match *gatewayapiv1beta1.HTTPRouteMatch) []wasm.PatternExpression { - // TODO(eastizle): only paths and methods implemented +// conditionsFromRule builds a list of conditions from a rule and a list of hostnames +// each combination of a rule match and hostname yields one condition +// rules that specify no explicit match are assumed to match all request (i.e. implicit catch-all rule) +// empty list of hostnames yields a condition without a hostname pattern expression +func conditionsFromRule(rule gatewayapiv1beta1.HTTPRouteRule, hostnames []gatewayapiv1beta1.Hostname) (conditions []wasm.Condition) { + if len(rule.Matches) == 0 { + for _, hostname := range hostnames { + if hostname == "*" { + continue + } + condition := wasm.Condition{AllOf: []wasm.PatternExpression{patternExpresionFromHostname(hostname)}} + conditions = append(conditions, condition) + } + return + } + + for _, match := range rule.Matches { + condition := wasm.Condition{AllOf: patternExpresionsFromMatch(match)} + + if len(hostnames) > 0 { + for _, hostname := range hostnames { + if hostname == "*" { + conditions = append(conditions, condition) + continue + } + mergedCondition := condition + mergedCondition.AllOf = append(mergedCondition.AllOf, patternExpresionFromHostname(hostname)) + conditions = append(conditions, mergedCondition) + } + continue + } - if match == nil { - return make([]wasm.PatternExpression, 0) + conditions = append(conditions, condition) } + return +} +func patternExpresionsFromMatch(match gatewayapiv1beta1.HTTPRouteMatch) []wasm.PatternExpression { expressions := make([]wasm.PatternExpression, 0) if match.Path != nil { @@ -127,6 +175,8 @@ func patternExpresionsFromMatch(match *gatewayapiv1beta1.HTTPRouteMatch) []wasm. expressions = append(expressions, patternExpresionFromMethod(*match.Method)) } + // TODO(eastizle): only paths and methods implemented + return expressions } @@ -162,29 +212,33 @@ func patternExpresionFromMethod(method gatewayapiv1beta1.HTTPMethod) wasm.Patter } } -func patternExpresionFromWhen(when kuadrantv1beta2.WhenCondition) wasm.PatternExpression { +func patternExpresionFromHostname(hostname gatewayapiv1beta1.Hostname) wasm.PatternExpression { + value := string(hostname) + operator := "eq" + if strings.HasPrefix(value, "*.") { + operator = "endswith" + value = value[1:] + } return wasm.PatternExpression{ - Selector: when.Selector, - Operator: wasm.PatternOperator(when.Operator), - Value: when.Value, + Selector: "request.host", + Operator: wasm.PatternOperator(operator), + Value: string(value), } } -func patternExpresionFromHostname(hostname gatewayapiv1beta1.Hostname) wasm.PatternExpression { +func patternExpresionFromWhen(when kuadrantv1beta2.WhenCondition) wasm.PatternExpression { return wasm.PatternExpression{ - Selector: "request.host", - Operator: "eq", - Value: string(hostname), + Selector: when.Selector, + Operator: wasm.PatternOperator(when.Operator), + Value: when.Value, } } -func dataFromLimt(limitFullName string, limit *kuadrantv1beta2.Limit) []wasm.DataItem { +func dataFromLimt(limitFullName string, limit *kuadrantv1beta2.Limit) (data []wasm.DataItem) { if limit == nil { - return make([]wasm.DataItem, 0) + return } - data := make([]wasm.DataItem, 0) - // static key representing the limit data = append(data, wasm.DataItem{Static: &wasm.StaticSpec{Key: limitFullName, Value: "1"}}) diff --git a/pkg/rlptools/wasm_utils_test.go b/pkg/rlptools/wasm_utils_test.go index 573308caf..fc6b02dd4 100644 --- a/pkg/rlptools/wasm_utils_test.go +++ b/pkg/rlptools/wasm_utils_test.go @@ -3,13 +3,13 @@ package rlptools import ( - "reflect" "testing" + "github.com/google/go-cmp/cmp" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + gatewayapiv1alpha2 "sigs.k8s.io/gateway-api/apis/v1alpha2" gatewayapiv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1" - "github.com/google/go-cmp/cmp" kuadrantv1beta2 "github.com/kuadrant/kuadrant-operator/api/v1beta2" "github.com/kuadrant/kuadrant-operator/pkg/rlptools/wasm" ) @@ -19,7 +19,10 @@ import ( func TestWasmRules(t *testing.T) { httpRoute := &gatewayapiv1beta1.HTTPRoute{ Spec: gatewayapiv1beta1.HTTPRouteSpec{ - Hostnames: []gatewayapiv1beta1.Hostname{"*.example.com"}, + Hostnames: []gatewayapiv1beta1.Hostname{ + "*.example.com", + "*.apps.example.internal", + }, Rules: []gatewayapiv1beta1.HTTPRouteRule{ { Matches: []gatewayapiv1beta1.HTTPRouteMatch{ @@ -36,59 +39,300 @@ func TestWasmRules(t *testing.T) { }, } - t.Run("minimal RLP", func(subT *testing.T) { - rlp := &kuadrantv1beta2.RateLimitPolicy{ - TypeMeta: metav1.TypeMeta{ - Kind: "RateLimitPolicy", APIVersion: kuadrantv1beta2.GroupVersion.String()}, - ObjectMeta: metav1.ObjectMeta{Name: "rlpA", Namespace: "nsA"}, + catchAllHTTPRoute := &gatewayapiv1beta1.HTTPRoute{ + Spec: gatewayapiv1beta1.HTTPRouteSpec{ + Hostnames: []gatewayapiv1beta1.Hostname{"*"}, + }, + } + + rlp := func(name string, limits map[string]kuadrantv1beta2.Limit) *kuadrantv1beta2.RateLimitPolicy { + return &kuadrantv1beta2.RateLimitPolicy{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: "my-app", + }, Spec: kuadrantv1beta2.RateLimitPolicySpec{ - Limits: map[string]kuadrantv1beta2.Limit{ - "l1": kuadrantv1beta2.Limit{ - Rates: []kuadrantv1beta2.Rate{ - { - Limit: 1, Duration: 3, Unit: kuadrantv1beta2.TimeUnit("minute"), + Limits: limits, + }, + } + } + + // a simple 50rps counter, for convinience, to be used in tests + counter50rps := kuadrantv1beta2.Rate{ + Limit: 50, + Duration: 1, + Unit: kuadrantv1beta2.TimeUnit("second"), + } + + testCases := []struct { + name string + rlp *kuadrantv1beta2.RateLimitPolicy + route *gatewayapiv1beta1.HTTPRoute + expectedRules []wasm.Rule + }{ + { + name: "minimal RLP", + rlp: rlp("minimal", map[string]kuadrantv1beta2.Limit{ + "50rps": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + }, + }), + route: httpRoute, + expectedRules: []wasm.Rule{ + { + Conditions: []wasm.Condition{ + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toy", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "GET", + }, + }, + }, + }, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: "my-app/minimal/50rps", + Value: "1", }, }, }, }, }, - } - - expectedRule := wasm.Rule{ - Conditions: []wasm.Condition{ + }, + { + name: "RLP with route selector based on hostname", + rlp: rlp("my-rlp", map[string]kuadrantv1beta2.Limit{ + "50rps-for-selected-hostnames": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + RouteSelectors: []kuadrantv1beta2.RouteSelector{ + { + Hostnames: []gatewayapiv1beta1.Hostname{ + "*.example.com", + "myapp.apps.example.com", // ignored + }, + }, + }, + }, + }), + route: httpRoute, + expectedRules: []wasm.Rule{ { - AllOf: []wasm.PatternExpression{ + Conditions: []wasm.Condition{ { - Selector: "request.url_path", - Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), - Value: "/toy", + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toy", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "GET", + }, + { + Selector: "request.host", + Operator: wasm.PatternOperator(kuadrantv1beta2.EndsWithOperator), + Value: ".example.com", + }, + }, }, + }, + Data: []wasm.DataItem{ { - Selector: "request.method", - Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), - Value: "GET", + Static: &wasm.StaticSpec{ + Key: "my-app/my-rlp/50rps-for-selected-hostnames", + Value: "1", + }, }, }, }, }, - Data: []wasm.DataItem{ + }, + { + name: "RLP with route selector based on http route matches (full match)", + rlp: rlp("my-rlp", map[string]kuadrantv1beta2.Limit{ + "50rps-for-selected-route": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + RouteSelectors: []kuadrantv1beta2.RouteSelector{ + { + Matches: []gatewayapiv1alpha2.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethod("GET")}[0], + }, + }, + }, + }, + }, + }), + route: httpRoute, + expectedRules: []wasm.Rule{ { - Static: &wasm.StaticSpec{ - Key: "nsA/rlpA/l1", - Value: "1", + Conditions: []wasm.Condition{ + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toy", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "GET", + }, + }, + }, + }, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: "my-app/my-rlp/50rps-for-selected-route", + Value: "1", + }, + }, }, }, }, - } - - rules := WasmRules(rlp, httpRoute) - if len(rules) != 1 { - subT.Errorf("expected 1 rule, got (%d)", len(rules)) - } + }, + { + name: "RLP with route selector based on http route matches (partial match)", + rlp: rlp("my-rlp", map[string]kuadrantv1beta2.Limit{ + "50rps-for-selected-path": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + RouteSelectors: []kuadrantv1beta2.RouteSelector{ + { + Matches: []gatewayapiv1alpha2.HTTPRouteMatch{ + { + Path: &gatewayapiv1beta1.HTTPPathMatch{ + Type: &[]gatewayapiv1beta1.PathMatchType{gatewayapiv1beta1.PathMatchPathPrefix}[0], + Value: &[]string{"/toy"}[0], + }, + }, + }, + }, + }, + }, + }), + route: httpRoute, + expectedRules: []wasm.Rule{ + { + Conditions: []wasm.Condition{ + { + AllOf: []wasm.PatternExpression{ + { + Selector: "request.url_path", + Operator: wasm.PatternOperator(kuadrantv1beta2.StartsWithOperator), + Value: "/toy", + }, + { + Selector: "request.method", + Operator: wasm.PatternOperator(kuadrantv1beta2.EqualOperator), + Value: "GET", + }, + }, + }, + }, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: "my-app/my-rlp/50rps-for-selected-path", + Value: "1", + }, + }, + }, + }, + }, + }, + { + name: "RLP with mismatching route selectors", + rlp: rlp("my-rlp", map[string]kuadrantv1beta2.Limit{ + "50rps-for-non-existent-route": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + RouteSelectors: []kuadrantv1beta2.RouteSelector{ + { + Matches: []gatewayapiv1alpha2.HTTPRouteMatch{ + { + Method: &[]gatewayapiv1beta1.HTTPMethod{gatewayapiv1beta1.HTTPMethod("POST")}[0], + }, + }, + }, + }, + }, + }), + route: httpRoute, + expectedRules: []wasm.Rule{}, + }, + { + name: "HTTPRouteRules without rule matches", + rlp: rlp("my-rlp", map[string]kuadrantv1beta2.Limit{ + "50rps": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + }, + }), + route: catchAllHTTPRoute, + expectedRules: []wasm.Rule{ + { + Conditions: nil, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: "my-app/my-rlp/50rps", + Value: "1", + }, + }, + }, + }, + }, + }, + { + name: "RLP with counter qualifier", + rlp: rlp("my-rlp", map[string]kuadrantv1beta2.Limit{ + "50rps-per-username": { + Rates: []kuadrantv1beta2.Rate{counter50rps}, + Counters: []kuadrantv1beta2.ContextSelector{"auth.identity.username"}, + }, + }), + route: catchAllHTTPRoute, + expectedRules: []wasm.Rule{ + { + Conditions: nil, + Data: []wasm.DataItem{ + { + Static: &wasm.StaticSpec{ + Key: "my-app/my-rlp/50rps-per-username", + Value: "1", + }, + }, + { + Selector: &wasm.SelectorSpec{ + Selector: "auth.identity.username", + }, + }, + }, + }, + }, + }, + } - if !reflect.DeepEqual(rules[0], expectedRule) { - diff := cmp.Diff(rules[0], expectedRule) - subT.Errorf("expected rule not found: %s", diff) - } - }) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + computedRules := WasmRules(tc.rlp, tc.route) + if diff := cmp.Diff(tc.expectedRules, computedRules); diff != "" { + t.Errorf("unexpected wasm rules (-want +got):\n%s", diff) + } + }) + } }