diff --git a/pkg/registry/common/querycache/cache.go b/pkg/registry/common/querycache/ns_cache.go similarity index 65% rename from pkg/registry/common/querycache/cache.go rename to pkg/registry/common/querycache/ns_cache.go index f9558dc42..dc06ce306 100644 --- a/pkg/registry/common/querycache/cache.go +++ b/pkg/registry/common/querycache/ns_cache.go @@ -1,6 +1,4 @@ -// Copyright (c) 2021 Doc.ai and/or its affiliates. -// -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -24,19 +22,20 @@ import ( "time" "github.com/edwarnicke/genericsync" + "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/tools/clock" ) -type cache struct { +type nsCache struct { expireTimeout time.Duration entries genericsync.Map[string, *cacheEntry] clockTime clock.Clock } -func newCache(ctx context.Context, opts ...Option) *cache { - c := &cache{ +func newNSCache(ctx context.Context, opts ...NSCacheOption) *nsCache { + c := &nsCache{ expireTimeout: time.Minute, clockTime: clock.FromContext(ctx), } @@ -60,7 +59,6 @@ func newCache(ctx context.Context, opts ...Option) *cache { if c.clockTime.Until(e.expirationTime) < 0 { e.cleanup() } - return true }) } @@ -70,51 +68,53 @@ func newCache(ctx context.Context, opts ...Option) *cache { return c } -func (c *cache) LoadOrStore(key string, nse *registry.NetworkServiceEndpoint, cancel context.CancelFunc) (*cacheEntry, bool) { +func (c *nsCache) LoadOrStore(value *registry.NetworkService, cancel context.CancelFunc) (*cacheEntry, bool) { var once sync.Once - return c.entries.LoadOrStore(key, &cacheEntry{ - nse: nse, + + entry, ok := c.entries.LoadOrStore(value.GetName(), &cacheEntry{ + value: value, expirationTime: c.clockTime.Now().Add(c.expireTimeout), cleanup: func() { once.Do(func() { - c.entries.Delete(key) + c.entries.Delete(value.GetName()) cancel() }) - }, - }) -} - -func (c *cache) Load(key string) (*registry.NetworkServiceEndpoint, bool) { - e, ok := c.entries.Load(key) - if !ok { - return nil, false - } + }}) - e.lock.Lock() - defer e.lock.Unlock() + return entry, ok +} - if c.clockTime.Until(e.expirationTime) < 0 { - e.cleanup() - return nil, false +func (c *nsCache) Load(ctx context.Context, query *registry.NetworkService) *registry.NetworkService { + entry, ok := c.entries.Load(query.Name) + if ok { + entry.lock.Lock() + defer entry.lock.Unlock() + if c.clockTime.Until(entry.expirationTime) < 0 { + entry.cleanup() + } else { + entry.expirationTime = c.clockTime.Now().Add(c.expireTimeout) + ns, ok := entry.value.(*registry.NetworkService) + if ok { + return ns + } + } } - e.expirationTime = c.clockTime.Now().Add(c.expireTimeout) - - return e.nse, true + return nil } type cacheEntry struct { - nse *registry.NetworkServiceEndpoint + value interface{} expirationTime time.Time lock sync.Mutex cleanup func() } -func (e *cacheEntry) Update(nse *registry.NetworkServiceEndpoint) { +func (e *cacheEntry) Update(value interface{}) { e.lock.Lock() defer e.lock.Unlock() - e.nse = nse + e.value = value } func (e *cacheEntry) Cleanup() { diff --git a/pkg/registry/common/querycache/ns_client.go b/pkg/registry/common/querycache/ns_client.go new file mode 100644 index 000000000..c3ab9f582 --- /dev/null +++ b/pkg/registry/common/querycache/ns_client.go @@ -0,0 +1,126 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package querycache adds possibility to cache Find queries +package querycache + +import ( + "context" + + "github.com/golang/protobuf/ptypes/empty" + "google.golang.org/grpc" + + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/registry/core/streamchannel" +) + +type queryCacheNSClient struct { + ctx context.Context + cache *nsCache +} + +// NewNetworkServiceClient creates new querycache NS registry client that caches all resolved NSs +func NewNetworkServiceClient(ctx context.Context, opts ...NSCacheOption) registry.NetworkServiceRegistryClient { + return &queryCacheNSClient{ + ctx: ctx, + cache: newNSCache(ctx, opts...), + } +} + +func (q *queryCacheNSClient) Register(ctx context.Context, nse *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, nse, opts...) +} + +func (q *queryCacheNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + if query.Watch { + return next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) + } + + if client, ok := q.findInCache(ctx, query); ok { + return client, nil + } + + client, err := next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) + if err != nil { + return nil, err + } + + nses := registry.ReadNetworkServiceList(client) + + resultCh := make(chan *registry.NetworkServiceResponse, len(nses)) + for _, nse := range nses { + resultCh <- ®istry.NetworkServiceResponse{NetworkService: nse} + q.storeInCache(ctx, nse.Clone(), opts...) + } + close(resultCh) + + return streamchannel.NewNetworkServiceFindClient(ctx, resultCh), nil +} + +func (q *queryCacheNSClient) findInCache(ctx context.Context, query *registry.NetworkServiceQuery) (registry.NetworkServiceRegistry_FindClient, bool) { + ns := q.cache.Load(ctx, query.NetworkService) + if ns == nil { + return nil, false + } + + resultCh := make(chan *registry.NetworkServiceResponse, 1) + resultCh <- ®istry.NetworkServiceResponse{NetworkService: ns.Clone()} + close(resultCh) + + return streamchannel.NewNetworkServiceFindClient(ctx, resultCh), true +} + +func (q *queryCacheNSClient) storeInCache(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) { + nsQuery := ®istry.NetworkServiceQuery{ + NetworkService: ®istry.NetworkService{ + Name: ns.Name, + }, + } + + findCtx, cancel := context.WithCancel(q.ctx) + entry, loaded := q.cache.LoadOrStore(ns, cancel) + if loaded { + cancel() + return + } + + go func() { + defer entry.Cleanup() + + nsQuery.Watch = true + stream, err := next.NetworkServiceRegistryClient(ctx).Find(findCtx, nsQuery, opts...) + if err != nil { + return + } + + for nsResp, err := stream.Recv(); err == nil; nsResp, err = stream.Recv() { + if nsResp.NetworkService.Name != nsQuery.NetworkService.Name { + continue + } + if nsResp.Deleted { + break + } + + entry.Update(nsResp.NetworkService) + } + }() +} + +func (q *queryCacheNSClient) Unregister(ctx context.Context, in *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, in, opts...) +} diff --git a/pkg/registry/common/querycache/ns_client_test.go b/pkg/registry/common/querycache/ns_client_test.go new file mode 100644 index 000000000..303de22c6 --- /dev/null +++ b/pkg/registry/common/querycache/ns_client_test.go @@ -0,0 +1,197 @@ +// Copyright (c) 2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package querycache_test + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/golang/protobuf/ptypes/empty" + "github.com/networkservicemesh/api/pkg/api/registry" + "github.com/pkg/errors" + "github.com/stretchr/testify/require" + "go.uber.org/goleak" + "google.golang.org/grpc" + + "github.com/networkservicemesh/sdk/pkg/registry/common/memory" + "github.com/networkservicemesh/sdk/pkg/registry/common/querycache" + "github.com/networkservicemesh/sdk/pkg/registry/core/adapters" + "github.com/networkservicemesh/sdk/pkg/registry/core/next" + "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" +) + +const ( + payload1 = "ethernet" + payload2 = "ip" +) + +func testNSQuery(nsName string) *registry.NetworkServiceQuery { + return ®istry.NetworkServiceQuery{ + NetworkService: ®istry.NetworkService{ + Name: nsName, + }, + } +} + +func Test_QueryCacheClient_ShouldCacheNSs(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + mem := memory.NewNetworkServiceRegistryServer() + + failureClient := new(failureNSClient) + c := next.NewNetworkServiceRegistryClient( + querycache.NewNetworkServiceClient(ctx, querycache.WithNSExpireTimeout(expireTimeout)), + failureClient, + adapters.NetworkServiceServerToClient(mem), + ) + + reg, err := mem.Register(ctx, ®istry.NetworkService{ + Name: name, + Payload: payload1, + }) + require.NoError(t, err) + + // Goroutines should be cleaned up on ns unregister + t.Cleanup(func() { goleak.VerifyNone(t) }) + + // 1. Find from memory + atomic.StoreInt32(&failureClient.shouldFail, 0) + + stream, err := c.Find(ctx, testNSQuery("")) + require.NoError(t, err) + nsResp, err := stream.Recv() + require.NoError(t, err) + require.Equal(t, name, nsResp.NetworkService.Name) + + // 2. Find from cache + atomic.StoreInt32(&failureClient.shouldFail, 1) + + stream, err = c.Find(ctx, testNSQuery(name)) + require.NoError(t, err) + nsResp, err = stream.Recv() + require.NoError(t, err) + require.Equal(t, name, nsResp.NetworkService.Name) + + // 3. Update NS in memory + reg.Payload = payload2 + reg, err = mem.Register(ctx, reg) + require.NoError(t, err) + + require.Eventually(t, func() bool { + if stream, err = c.Find(ctx, testNSQuery(name)); err != nil { + return false + } + if nsResp, err = stream.Recv(); err != nil { + return false + } + return name == nsResp.NetworkService.Name && payload2 == nsResp.NetworkService.Payload + }, testWait, testTick) + + // 4. Delete ns from memory + _, err = mem.Unregister(ctx, reg) + require.NoError(t, err) + + require.Eventually(t, func() bool { + _, err = c.Find(ctx, testNSQuery(name)) + return err != nil + }, testWait, testTick) +} + +func Test_QueryCacheClient_ShouldCleanUpNSOnTimeout(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) + + mem := memory.NewNetworkServiceRegistryServer() + + failureClient := new(failureNSClient) + c := next.NewNetworkServiceRegistryClient( + querycache.NewNetworkServiceClient(ctx, querycache.WithNSExpireTimeout(expireTimeout)), + failureClient, + adapters.NetworkServiceServerToClient(mem), + ) + + _, err := mem.Register(ctx, ®istry.NetworkService{ + Name: name, + }) + require.NoError(t, err) + + // Goroutines should be cleaned up on cache entry expiration + t.Cleanup(func() { goleak.VerifyNone(t) }) + + // 1. Find from memory + atomic.StoreInt32(&failureClient.shouldFail, 0) + + stream, err := c.Find(ctx, testNSQuery("")) + require.NoError(t, err) + + _, err = stream.Recv() + require.NoError(t, err) + + // 2. Find from cache + atomic.StoreInt32(&failureClient.shouldFail, 1) + + require.Eventually(t, func() bool { + if stream, err = c.Find(ctx, testNSQuery(name)); err == nil { + _, err = stream.Recv() + } + return err == nil + }, testWait, testTick) + + // 3. Keep finding from cache to prevent expiration + for start := clockMock.Now(); clockMock.Since(start) < 2*expireTimeout; clockMock.Add(expireTimeout / 3) { + stream, err = c.Find(ctx, testNSQuery(name)) + require.NoError(t, err) + + _, err = stream.Recv() + require.NoError(t, err) + } + + // 4. Wait for the expire to happen + clockMock.Add(expireTimeout) + + _, err = c.Find(ctx, testNSQuery(name)) + require.Errorf(t, err, "find error") +} + +type failureNSClient struct { + shouldFail int32 +} + +func (c *failureNSClient) Register(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*registry.NetworkService, error) { + return next.NetworkServiceRegistryClient(ctx).Register(ctx, ns, opts...) +} + +func (c *failureNSClient) Find(ctx context.Context, query *registry.NetworkServiceQuery, opts ...grpc.CallOption) (registry.NetworkServiceRegistry_FindClient, error) { + if atomic.LoadInt32(&c.shouldFail) == 1 && !query.Watch { + return nil, errors.New("find error") + } + return next.NetworkServiceRegistryClient(ctx).Find(ctx, query, opts...) +} + +func (c *failureNSClient) Unregister(ctx context.Context, ns *registry.NetworkService, opts ...grpc.CallOption) (*empty.Empty, error) { + return next.NetworkServiceRegistryClient(ctx).Unregister(ctx, ns, opts...) +} diff --git a/pkg/registry/common/querycache/nse_cache.go b/pkg/registry/common/querycache/nse_cache.go new file mode 100644 index 000000000..a3fc36040 --- /dev/null +++ b/pkg/registry/common/querycache/nse_cache.go @@ -0,0 +1,141 @@ +// Copyright (c) 2021 Doc.ai and/or its affiliates. +// +// Copyright (c) 2023-2024 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package querycache + +import ( + "context" + "sync" + "time" + + "github.com/edwarnicke/genericsync" + + "github.com/networkservicemesh/api/pkg/api/registry" + + "github.com/networkservicemesh/sdk/pkg/tools/clock" +) + +type nseCache struct { + expireTimeout time.Duration + entries genericsync.Map[string, *cacheEntry] + clockTime clock.Clock +} + +func newNSECache(ctx context.Context, opts ...NSECacheOption) *nseCache { + c := &nseCache{ + expireTimeout: time.Minute, + clockTime: clock.FromContext(ctx), + } + + for _, opt := range opts { + opt(c) + } + + ticker := c.clockTime.Ticker(c.expireTimeout) + go func() { + for { + select { + case <-ctx.Done(): + ticker.Stop() + return + case <-ticker.C(): + c.entries.Range(func(_ string, e *cacheEntry) bool { + e.lock.Lock() + defer e.lock.Unlock() + + if c.clockTime.Until(e.expirationTime) < 0 { + e.cleanup() + } + return true + }) + } + } + }() + + return c +} + +func (c *nseCache) LoadOrStore(value *registry.NetworkServiceEndpoint, cancel context.CancelFunc) (*cacheEntry, bool) { + var once sync.Once + + entry, ok := c.entries.LoadOrStore(value.GetName(), &cacheEntry{ + value: value, + expirationTime: c.clockTime.Now().Add(c.expireTimeout), + cleanup: func() { + once.Do(func() { + c.entries.Delete(value.GetName()) + cancel() + }) + }}) + + return entry, ok +} + +func (c *nseCache) add(entry *cacheEntry, values []*registry.NetworkServiceEndpoint) []*registry.NetworkServiceEndpoint { + entry.lock.Lock() + defer entry.lock.Unlock() + if c.clockTime.Until(entry.expirationTime) < 0 { + entry.cleanup() + } else { + entry.expirationTime = c.clockTime.Now().Add(c.expireTimeout) + nse, ok := entry.value.(*registry.NetworkServiceEndpoint) + if ok { + values = append(values, nse) + } + } + + return values +} + +// Checks if a is a subset of b +func subset(a, b []string) bool { + set := make(map[string]struct{}) + for _, value := range a { + set[value] = struct{}{} + } + + for _, value := range b { + if _, found := set[value]; !found { + return false + } + } + + return true +} + +func (c *nseCache) Load(ctx context.Context, query *registry.NetworkServiceEndpointQuery) []*registry.NetworkServiceEndpoint { + values := make([]*registry.NetworkServiceEndpoint, 0) + + if query.NetworkServiceEndpoint.Name != "" { + entry, ok := c.entries.Load(query.NetworkServiceEndpoint.Name) + if ok { + values = c.add(entry, values) + } + return values + } + + c.entries.Range(func(key string, entry *cacheEntry) bool { + nse, ok := entry.value.(*registry.NetworkServiceEndpoint) + if ok && subset(query.NetworkServiceEndpoint.NetworkServiceNames, nse.NetworkServiceNames) { + values = c.add(entry, values) + } + return true + }) + + return values +} diff --git a/pkg/registry/common/querycache/nse_client.go b/pkg/registry/common/querycache/nse_client.go index efc37f204..f441d257c 100644 --- a/pkg/registry/common/querycache/nse_client.go +++ b/pkg/registry/common/querycache/nse_client.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -31,14 +33,14 @@ import ( type queryCacheNSEClient struct { ctx context.Context - cache *cache + cache *nseCache } -// NewClient creates new querycache NSE registry client that caches all resolved NSEs -func NewClient(ctx context.Context, opts ...Option) registry.NetworkServiceEndpointRegistryClient { +// NewNetworkServiceEndpointClient creates new querycache NSE registry client that caches all resolved NSEs +func NewNetworkServiceEndpointClient(ctx context.Context, opts ...NSECacheOption) registry.NetworkServiceEndpointRegistryClient { return &queryCacheNSEClient{ ctx: ctx, - cache: newCache(ctx, opts...), + cache: newNSECache(ctx, opts...), } } @@ -51,7 +53,7 @@ func (q *queryCacheNSEClient) Find(ctx context.Context, query *registry.NetworkS return next.NetworkServiceEndpointRegistryClient(ctx).Find(ctx, query, opts...) } - if client, ok := q.findInCache(ctx, query.String()); ok { + if client, ok := q.findInCache(ctx, query); ok { return client, nil } @@ -72,31 +74,24 @@ func (q *queryCacheNSEClient) Find(ctx context.Context, query *registry.NetworkS return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultCh), nil } -func (q *queryCacheNSEClient) findInCache(ctx context.Context, key string) (registry.NetworkServiceEndpointRegistry_FindClient, bool) { - nse, ok := q.cache.Load(key) - if !ok { +func (q *queryCacheNSEClient) findInCache(ctx context.Context, query *registry.NetworkServiceEndpointQuery) (registry.NetworkServiceEndpointRegistry_FindClient, bool) { + nses := q.cache.Load(ctx, query) + if len(nses) == 0 { return nil, false } - resultCh := make(chan *registry.NetworkServiceEndpointResponse, 1) - resultCh <- ®istry.NetworkServiceEndpointResponse{NetworkServiceEndpoint: nse.Clone()} + resultCh := make(chan *registry.NetworkServiceEndpointResponse, len(nses)) + for _, nse := range nses { + resultCh <- ®istry.NetworkServiceEndpointResponse{NetworkServiceEndpoint: nse.Clone()} + } close(resultCh) return streamchannel.NewNetworkServiceEndpointFindClient(ctx, resultCh), true } func (q *queryCacheNSEClient) storeInCache(ctx context.Context, nse *registry.NetworkServiceEndpoint, opts ...grpc.CallOption) { - nseQuery := ®istry.NetworkServiceEndpointQuery{ - NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ - Name: nse.Name, - }, - } - - key := nseQuery.String() - findCtx, cancel := context.WithCancel(q.ctx) - - entry, loaded := q.cache.LoadOrStore(key, nse, cancel) + entry, loaded := q.cache.LoadOrStore(nse, cancel) if loaded { cancel() return @@ -105,7 +100,12 @@ func (q *queryCacheNSEClient) storeInCache(ctx context.Context, nse *registry.Ne go func() { defer entry.Cleanup() - nseQuery.Watch = true + nseQuery := ®istry.NetworkServiceEndpointQuery{ + NetworkServiceEndpoint: ®istry.NetworkServiceEndpoint{ + Name: nse.Name, + }, + Watch: true, + } stream, err := next.NetworkServiceEndpointRegistryClient(ctx).Find(findCtx, nseQuery, opts...) if err != nil { diff --git a/pkg/registry/common/querycache/nse_client_test.go b/pkg/registry/common/querycache/nse_client_test.go index 516a911ba..700ffa52c 100644 --- a/pkg/registry/common/querycache/nse_client_test.go +++ b/pkg/registry/common/querycache/nse_client_test.go @@ -1,5 +1,7 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -64,7 +66,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { failureClient := new(failureNSEClient) c := next.NewNetworkServiceEndpointRegistryClient( - querycache.NewClient(ctx, querycache.WithExpireTimeout(expireTimeout)), + querycache.NewNetworkServiceEndpointClient(ctx, querycache.WithNSEExpireTimeout(expireTimeout)), failureClient, adapters.NetworkServiceEndpointServerToClient(mem), ) @@ -86,26 +88,21 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { nseResp, err := stream.Recv() require.NoError(t, err) - require.Equal(t, name, nseResp.NetworkServiceEndpoint.Name) require.Equal(t, url1, nseResp.NetworkServiceEndpoint.Url) // 2. Find from cache atomic.StoreInt32(&failureClient.shouldFail, 1) - require.Eventually(t, func() bool { - if stream, err = c.Find(ctx, testNSEQuery(name)); err != nil { - return false - } - if nseResp, err = stream.Recv(); err != nil { - return false - } - return name == nseResp.NetworkServiceEndpoint.Name && url1 == nseResp.NetworkServiceEndpoint.Url - }, testWait, testTick) + stream, err = c.Find(ctx, testNSEQuery(name)) + require.NoError(t, err) + nseResp, err = stream.Recv() + require.NoError(t, err) + require.Equal(t, name, nseResp.NetworkServiceEndpoint.Name) + require.Equal(t, url1, nseResp.NetworkServiceEndpoint.Url) // 3. Update NSE in memory reg.Url = url2 - reg, err = mem.Register(ctx, reg) require.NoError(t, err) @@ -129,7 +126,7 @@ func Test_QueryCacheClient_ShouldCacheNSEs(t *testing.T) { }, testWait, testTick) } -func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { +func Test_QueryCacheClient_ShouldCleanUpNSEOnTimeout(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) ctx, cancel := context.WithCancel(context.Background()) @@ -142,7 +139,7 @@ func Test_QueryCacheClient_ShouldCleanUpOnTimeout(t *testing.T) { failureClient := new(failureNSEClient) c := next.NewNetworkServiceEndpointRegistryClient( - querycache.NewClient(ctx, querycache.WithExpireTimeout(expireTimeout)), + querycache.NewNetworkServiceEndpointClient(ctx, querycache.WithNSEExpireTimeout(expireTimeout)), failureClient, adapters.NetworkServiceEndpointServerToClient(mem), ) diff --git a/pkg/registry/common/querycache/option.go b/pkg/registry/common/querycache/option.go index f70220737..55338a20d 100644 --- a/pkg/registry/common/querycache/option.go +++ b/pkg/registry/common/querycache/option.go @@ -1,5 +1,7 @@ // Copyright (c) 2021 Doc.ai and/or its affiliates. // +// Copyright (c) 2024 Cisco and/or its affiliates. +// // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,12 +20,22 @@ package querycache import "time" -// Option is an option for cache -type Option func(c *cache) +// NSCacheOption is an option for NS cache +type NSCacheOption func(c *nsCache) + +// NSECacheOption is an option for NSE cache +type NSECacheOption func(c *nseCache) + +// WithNSExpireTimeout sets NS cache expire timeout +func WithNSExpireTimeout(expireTimeout time.Duration) NSCacheOption { + return func(c *nsCache) { + c.expireTimeout = expireTimeout + } +} -// WithExpireTimeout sets cache expire timeout -func WithExpireTimeout(expireTimeout time.Duration) Option { - return func(c *cache) { +// WithNSEExpireTimeout sets NSE cache expire timeout +func WithNSEExpireTimeout(expireTimeout time.Duration) NSECacheOption { + return func(c *nseCache) { c.expireTimeout = expireTimeout } }