diff --git a/config/interceptor/role.yaml b/config/interceptor/role.yaml index 51c0d076..8cc9a7a2 100644 --- a/config/interceptor/role.yaml +++ b/config/interceptor/role.yaml @@ -12,6 +12,14 @@ rules: - get - list - watch +- apiGroups: + - "" + resources: + - services + verbs: + - get + - list + - watch - apiGroups: - http.keda.sh resources: diff --git a/interceptor/main.go b/interceptor/main.go index cb0cf634..5a67887a 100644 --- a/interceptor/main.go +++ b/interceptor/main.go @@ -17,6 +17,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" + k8sinformers "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/log/zap" @@ -42,6 +43,7 @@ var ( // +kubebuilder:rbac:groups=http.keda.sh,resources=httpscaledobjects,verbs=get;list;watch // +kubebuilder:rbac:groups="",resources=endpoints,verbs=get;list;watch +// +kubebuilder:rbac:groups="",resources=services,verbs=get;list;watch func main() { timeoutCfg := config.MustParseTimeouts() @@ -85,11 +87,10 @@ func main() { setupLog.Error(err, "creating new Kubernetes ClientSet") os.Exit(1) } - endpointsCache := k8s.NewInformerBackedEndpointsCache( - ctrl.Log, - cl, - time.Millisecond*time.Duration(servingCfg.EndpointsCachePollIntervalMS), - ) + + k8sSharedInformerFactory := k8sinformers.NewSharedInformerFactory(cl, time.Millisecond*time.Duration(servingCfg.EndpointsCachePollIntervalMS)) + svcCache := k8s.NewInformerBackedServiceCache(ctrl.Log, cl, k8sSharedInformerFactory) + endpointsCache := k8s.NewInformerBackedEndpointsCache(ctrl.Log, cl, time.Millisecond*time.Duration(servingCfg.EndpointsCachePollIntervalMS)) if err != nil { setupLog.Error(err, "creating new endpoints cache") os.Exit(1) @@ -123,6 +124,7 @@ func main() { setupLog.Info("starting the endpoints cache") endpointsCache.Start(ctx) + k8sSharedInformerFactory.Start(ctx.Done()) return nil }) @@ -173,10 +175,11 @@ func main() { eg.Go(func() error { proxyTLSConfig := map[string]string{"certificatePath": servingCfg.TLSCertPath, "keyPath": servingCfg.TLSKeyPath, "certstorePaths": servingCfg.TLSCertStorePaths} proxyTLSPort := servingCfg.TLSPort + k8sSharedInformerFactory.WaitForCacheSync(ctx.Done()) setupLog.Info("starting the proxy server with TLS enabled", "port", proxyTLSPort) - if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, endpointsCache, timeoutCfg, proxyTLSPort, proxyTLSEnabled, proxyTLSConfig); !util.IsIgnoredErr(err) { + if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, svcCache, timeoutCfg, proxyTLSPort, proxyTLSEnabled, proxyTLSConfig); !util.IsIgnoredErr(err) { setupLog.Error(err, "tls proxy server failed") return err } @@ -186,9 +189,11 @@ func main() { // start a proxy server without TLS. eg.Go(func() error { + k8sSharedInformerFactory.WaitForCacheSync(ctx.Done()) setupLog.Info("starting the proxy server with TLS disabled", "port", proxyPort) - if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, endpointsCache, timeoutCfg, proxyPort, false, nil); !util.IsIgnoredErr(err) { + k8sSharedInformerFactory.WaitForCacheSync(ctx.Done()) + if err := runProxyServer(ctx, ctrl.Log, queues, waitFunc, routingTable, svcCache, timeoutCfg, proxyPort, false, nil); !util.IsIgnoredErr(err) { setupLog.Error(err, "proxy server failed") return err } @@ -369,7 +374,7 @@ func runProxyServer( q queue.Counter, waitFunc forwardWaitFunc, routingTable routing.Table, - endpointsCache k8s.EndpointsCache, + svcCache k8s.ServiceCache, timeouts *config.Timeouts, port int, tlsEnabled bool, @@ -417,7 +422,7 @@ func runProxyServer( routingTable, probeHandler, upstreamHandler, - endpointsCache, + svcCache, tlsEnabled, ) rootHandler = middleware.NewLogging( diff --git a/interceptor/main_test.go b/interceptor/main_test.go index 67e93f1e..1809c0cd 100644 --- a/interceptor/main_test.go +++ b/interceptor/main_test.go @@ -63,7 +63,7 @@ func TestRunProxyServerCountMiddleware(t *testing.T) { // server routingTable := routingtest.NewTable() routingTable.Memory[host] = httpso - endpointsCache := k8s.NewFakeEndpointsCache() + svcCache := k8s.NewFakeServiceCache() timeouts := &config.Timeouts{} waiterCh := make(chan struct{}) @@ -78,7 +78,7 @@ func TestRunProxyServerCountMiddleware(t *testing.T) { q, waitFunc, routingTable, - endpointsCache, + svcCache, timeouts, port, false, @@ -196,7 +196,7 @@ func TestRunProxyServerWithTLSCountMiddleware(t *testing.T) { // server routingTable := routingtest.NewTable() routingTable.Memory[host] = httpso - endpointsCache := k8s.NewFakeEndpointsCache() + svcCache := k8s.NewFakeServiceCache() timeouts := &config.Timeouts{} waiterCh := make(chan struct{}) @@ -212,7 +212,7 @@ func TestRunProxyServerWithTLSCountMiddleware(t *testing.T) { q, waitFunc, routingTable, - endpointsCache, + svcCache, timeouts, port, true, @@ -343,7 +343,7 @@ func TestRunProxyServerWithMultipleCertsTLSCountMiddleware(t *testing.T) { // server routingTable := routingtest.NewTable() routingTable.Memory[host] = httpso - endpointsCache := k8s.NewFakeEndpointsCache() + svcCache := k8s.NewFakeServiceCache() timeouts := &config.Timeouts{} waiterCh := make(chan struct{}) @@ -359,7 +359,7 @@ func TestRunProxyServerWithMultipleCertsTLSCountMiddleware(t *testing.T) { q, waitFunc, routingTable, - endpointsCache, + svcCache, timeouts, port, true, diff --git a/interceptor/middleware/routing.go b/interceptor/middleware/routing.go index 339ed2d9..69c6bcd5 100644 --- a/interceptor/middleware/routing.go +++ b/interceptor/middleware/routing.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "fmt" "net/http" "net/url" @@ -22,16 +23,16 @@ type Routing struct { routingTable routing.Table probeHandler http.Handler upstreamHandler http.Handler - endpointsCache k8s.EndpointsCache + svcCache k8s.ServiceCache tlsEnabled bool } -func NewRouting(routingTable routing.Table, probeHandler http.Handler, upstreamHandler http.Handler, endpointsCache k8s.EndpointsCache, tlsEnabled bool) *Routing { +func NewRouting(routingTable routing.Table, probeHandler http.Handler, upstreamHandler http.Handler, svcCache k8s.ServiceCache, tlsEnabled bool) *Routing { return &Routing{ routingTable: routingTable, probeHandler: probeHandler, upstreamHandler: upstreamHandler, - endpointsCache: endpointsCache, + svcCache: svcCache, tlsEnabled: tlsEnabled, } } @@ -55,7 +56,7 @@ func (rm *Routing) ServeHTTP(w http.ResponseWriter, r *http.Request) { } r = r.WithContext(util.ContextWithHTTPSO(r.Context(), httpso)) - stream, err := rm.streamFromHTTPSO(httpso) + stream, err := rm.streamFromHTTPSO(r.Context(), httpso) if err != nil { sh := handler.NewStatic(http.StatusInternalServerError, err) sh.ServeHTTP(w, r) @@ -67,29 +68,27 @@ func (rm *Routing) ServeHTTP(w http.ResponseWriter, r *http.Request) { rm.upstreamHandler.ServeHTTP(w, r) } -func (rm *Routing) getPort(httpso *httpv1alpha1.HTTPScaledObject) (int32, error) { +func (rm *Routing) getPort(ctx context.Context, httpso *httpv1alpha1.HTTPScaledObject) (int32, error) { if httpso.Spec.ScaleTargetRef.Port != 0 { return httpso.Spec.ScaleTargetRef.Port, nil } if httpso.Spec.ScaleTargetRef.PortName == "" { - return 0, fmt.Errorf("must specify either port or portName") + return 0, fmt.Errorf(`must specify either "port" or "portName"`) } - endpoints, err := rm.endpointsCache.Get(httpso.GetNamespace(), httpso.Spec.ScaleTargetRef.Service) + svc, err := rm.svcCache.Get(ctx, httpso.GetNamespace(), httpso.Spec.ScaleTargetRef.Service) if err != nil { - return 0, fmt.Errorf("failed to get Endpoints: %w", err) + return 0, fmt.Errorf("failed to get Service: %w", err) } - for _, subset := range endpoints.Subsets { - for _, port := range subset.Ports { - if port.Name == httpso.Spec.ScaleTargetRef.PortName { - return port.Port, nil - } + for _, port := range svc.Spec.Ports { + if port.Name == httpso.Spec.ScaleTargetRef.PortName { + return port.Port, nil } } - return 0, fmt.Errorf("portName %s not found in Endpoints", httpso.Spec.ScaleTargetRef.PortName) + return 0, fmt.Errorf("portName %q not found in Service", httpso.Spec.ScaleTargetRef.PortName) } -func (rm *Routing) streamFromHTTPSO(httpso *httpv1alpha1.HTTPScaledObject) (*url.URL, error) { - port, err := rm.getPort(httpso) +func (rm *Routing) streamFromHTTPSO(ctx context.Context, httpso *httpv1alpha1.HTTPScaledObject) (*url.URL, error) { + port, err := rm.getPort(ctx, httpso) if err != nil { return nil, fmt.Errorf("failed to get port: %w", err) } diff --git a/interceptor/middleware/routing_test.go b/interceptor/middleware/routing_test.go index b4a82f3e..b57f2321 100644 --- a/interceptor/middleware/routing_test.go +++ b/interceptor/middleware/routing_test.go @@ -25,9 +25,9 @@ var _ = Describe("RoutingMiddleware", func() { emptyHandler := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {}) probeHandler.Handle("/probe", emptyHandler) upstreamHandler.Handle("/upstream", emptyHandler) - endpointsCache := k8s.NewFakeEndpointsCache() + svcCache := k8s.NewFakeServiceCache() - rm := NewRouting(routingTable, probeHandler, upstreamHandler, endpointsCache, false) + rm := NewRouting(routingTable, probeHandler, upstreamHandler, svcCache, false) Expect(rm).NotTo(BeNil()) Expect(rm.routingTable).To(Equal(routingTable)) Expect(rm.probeHandler).To(Equal(probeHandler)) @@ -44,7 +44,7 @@ var _ = Describe("RoutingMiddleware", func() { var ( upstreamHandler *http.ServeMux probeHandler *http.ServeMux - endpointsCache *k8s.FakeEndpointsCache + svcCache *k8s.FakeServiceCache routingTable *routingtest.Table routingMiddleware *Routing w *httptest.ResponseRecorder @@ -76,18 +76,16 @@ var _ = Describe("RoutingMiddleware", func() { }, }, } - endpoints = corev1.Endpoints{ + svc = &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ Name: "keda-svc", Namespace: "default", }, - Subsets: []corev1.EndpointSubset{ - { - Ports: []corev1.EndpointPort{ - { - Name: "http", - Port: 80, - }, + Spec: corev1.ServiceSpec{ + Ports: []corev1.ServicePort{ + { + Name: "http", + Port: 80, }, }, }, @@ -98,8 +96,8 @@ var _ = Describe("RoutingMiddleware", func() { upstreamHandler = http.NewServeMux() probeHandler = http.NewServeMux() routingTable = routingtest.NewTable() - endpointsCache = k8s.NewFakeEndpointsCache() - routingMiddleware = NewRouting(routingTable, probeHandler, upstreamHandler, endpointsCache, false) + svcCache = k8s.NewFakeServiceCache() + routingMiddleware = NewRouting(routingTable, probeHandler, upstreamHandler, svcCache, false) w = httptest.NewRecorder() @@ -141,7 +139,7 @@ var _ = Describe("RoutingMiddleware", func() { When("route is found with portName", func() { It("routes to the upstream handler", func() { - endpointsCache.Set(endpoints) + svcCache.Add(*svc) var ( sc = http.StatusTeapot st = http.StatusText(sc) diff --git a/interceptor/proxy_handlers_integration_test.go b/interceptor/proxy_handlers_integration_test.go index c443b3c7..8898b275 100644 --- a/interceptor/proxy_handlers_integration_test.go +++ b/interceptor/proxy_handlers_integration_test.go @@ -281,6 +281,7 @@ func newHarness( }, ) + svcCache := k8s.NewFakeServiceCache() endpCache := k8s.NewFakeEndpointsCache() waitFunc := newWorkloadReplicasForwardWaitFunc( logr.Discard(), @@ -308,7 +309,7 @@ func newHarness( respHeaderTimeout: time.Second, }, &tls.Config{}), - endpCache, + svcCache, false, ) diff --git a/pkg/k8s/svc_cache.go b/pkg/k8s/svc_cache.go new file mode 100644 index 00000000..2069e5e7 --- /dev/null +++ b/pkg/k8s/svc_cache.go @@ -0,0 +1,77 @@ +package k8s + +import ( + "context" + "fmt" + "sync" + + "github.com/go-logr/logr" + v1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/client-go/informers" + "k8s.io/client-go/kubernetes" + listerv1 "k8s.io/client-go/listers/core/v1" +) + +// ServiceCache is an interface for caching service objects +type ServiceCache interface { + // Get gets a service with the given namespace and name from the cache + // If the service doesn't exist in the cache, it will be fetched from the API server + Get(ctx context.Context, namespace, name string) (*v1.Service, error) +} + +// InformerBackedServicesCache is a cache of services backed by a shared informer +type InformerBackedServicesCache struct { + lggr logr.Logger + cl kubernetes.Interface + svcLister listerv1.ServiceLister +} + +// FakeServiceCache is a fake implementation of a ServiceCache for testing +type FakeServiceCache struct { + current map[string]v1.Service + mut sync.RWMutex +} + +// NewInformerBackedServiceCache creates a new InformerBackedServicesCache +func NewInformerBackedServiceCache(lggr logr.Logger, cl kubernetes.Interface, factory informers.SharedInformerFactory) *InformerBackedServicesCache { + return &InformerBackedServicesCache{ + lggr: lggr.WithName("InformerBackedServicesCache"), + cl: cl, + svcLister: factory.Core().V1().Services().Lister(), + } +} + +// Get gets a service with the given namespace and name from the cache and as a fallback from the API server +func (c *InformerBackedServicesCache) Get(ctx context.Context, namespace, name string) (*v1.Service, error) { + svc, err := c.svcLister.Services(namespace).Get(name) + if err == nil { + c.lggr.V(1).Info("Service found in cache", "namespace", namespace, "name", name) + return svc, nil + } + c.lggr.V(1).Info("Service not found in cache, fetching from API server", "namespace", namespace, "name", name, "error", err) + return c.cl.CoreV1().Services(namespace).Get(ctx, name, metav1.GetOptions{}) +} + +// NewFakeServiceCache creates a new FakeServiceCache +func NewFakeServiceCache() *FakeServiceCache { + return &FakeServiceCache{current: make(map[string]v1.Service)} +} + +// Get gets a service with the given namespace and name from the cache +func (c *FakeServiceCache) Get(ctx context.Context, namespace, name string) (*v1.Service, error) { + c.mut.RLock() + defer c.mut.RUnlock() + svc, ok := c.current[key(namespace, name)] + if !ok { + return nil, fmt.Errorf("service not found") + } + return &svc, nil +} + +// Add adds a service to the cache +func (c *FakeServiceCache) Add(svc v1.Service) { + c.mut.Lock() + defer c.mut.Unlock() + c.current[key(svc.Namespace, svc.Name)] = svc +} diff --git a/tests/checks/internal_service_port_name/internal_service_port_name_test.go b/tests/checks/internal_service_port_name/internal_service_port_name_test.go index dd3c44b4..eb8db1de 100644 --- a/tests/checks/internal_service_port_name/internal_service_port_name_test.go +++ b/tests/checks/internal_service_port_name/internal_service_port_name_test.go @@ -6,6 +6,7 @@ package internal_service_port_name_test import ( "fmt" "testing" + "time" "github.com/stretchr/testify/assert" "k8s.io/client-go/kubernetes" @@ -150,6 +151,7 @@ func TestCheck(t *testing.T) { func testScaleOut(t *testing.T, kc *kubernetes.Clientset, data templateData) { t.Log("--- testing scale out ---") + time.Sleep(5 * time.Second) KubectlApplyWithTemplate(t, data, "loadJobTemplate", loadJobTemplate) assert.True(t, WaitForDeploymentReplicaReadyCount(t, kc, deploymentName, testNamespace, maxReplicaCount, 6, 10),