diff --git a/pkg/registry/common/memory/ns_server.go b/pkg/registry/common/memory/ns_server.go index 4db145d8a..9fbf22420 100644 --- a/pkg/registry/common/memory/ns_server.go +++ b/pkg/registry/common/memory/ns_server.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2021 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco and/or its affiliates. +// Copyright (c) 2023-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -37,7 +37,7 @@ import ( type memoryNSServer struct { networkServices genericsync.Map[string, *registry.NetworkService] executor serialize.Executor - eventChannels map[string]chan *registry.NetworkService + eventChannels map[string]chan *registry.NetworkServiceResponse eventChannelSize int } @@ -45,7 +45,7 @@ type memoryNSServer struct { func NewNetworkServiceRegistryServer(options ...Option) registry.NetworkServiceRegistryServer { s := &memoryNSServer{ eventChannelSize: defaultEventChannelSize, - eventChannels: make(map[string]chan *registry.NetworkService), + eventChannels: make(map[string]chan *registry.NetworkServiceResponse), } for _, o := range options { o.apply(s) @@ -65,12 +65,12 @@ func (s *memoryNSServer) Register(ctx context.Context, ns *registry.NetworkServi s.networkServices.Store(r.Name, r.Clone()) - s.sendEvent(r) + s.sendEvent(®istry.NetworkServiceResponse{NetworkService: r}) return r, nil } -func (s *memoryNSServer) sendEvent(event *registry.NetworkService) { +func (s *memoryNSServer) sendEvent(event *registry.NetworkServiceResponse) { event = event.Clone() s.executor.AsyncExec(func() { for _, ch := range s.eventChannels { @@ -93,13 +93,13 @@ func (s *memoryNSServer) Find(query *registry.NetworkServiceQuery, server regist return next.NetworkServiceRegistryServer(server.Context()).Find(query, server) } - eventCh := make(chan *registry.NetworkService, s.eventChannelSize) + eventCh := make(chan *registry.NetworkServiceResponse, s.eventChannelSize) id := uuid.New().String() s.executor.AsyncExec(func() { s.eventChannels[id] = eventCh for _, entity := range s.allMatches(query) { - eventCh <- entity + eventCh <- ®istry.NetworkServiceResponse{NetworkService: entity} } }) defer s.closeEventChannel(id, eventCh) @@ -123,7 +123,7 @@ func (s *memoryNSServer) allMatches(query *registry.NetworkServiceQuery) (matche return matches } -func (s *memoryNSServer) closeEventChannel(id string, eventCh <-chan *registry.NetworkService) { +func (s *memoryNSServer) closeEventChannel(id string, eventCh <-chan *registry.NetworkServiceResponse) { ctx, cancel := context.WithCancel(context.Background()) s.executor.AsyncExec(func() { @@ -143,22 +143,18 @@ func (s *memoryNSServer) closeEventChannel(id string, eventCh <-chan *registry.N func (s *memoryNSServer) receiveEvent( query *registry.NetworkServiceQuery, server registry.NetworkServiceRegistry_FindServer, - eventCh <-chan *registry.NetworkService, + eventCh <-chan *registry.NetworkServiceResponse, ) error { select { case <-server.Context().Done(): return errors.WithStack(io.EOF) case event := <-eventCh: - if matchutils.MatchNetworkServices(query.NetworkService, event) { - nse := ®istry.NetworkServiceResponse{ - NetworkService: event, - } - - if err := server.Send(nse); err != nil { + if matchutils.MatchNetworkServices(query.NetworkService, event.NetworkService) { + if err := server.Send(event); err != nil { if server.Context().Err() != nil { return errors.WithStack(io.EOF) } - return errors.Wrapf(err, "NetworkServiceRegistry find server failed to send a response %s", nse.String()) + return errors.Wrapf(err, "NetworkServiceRegistry find server failed to send a response %s", event.String()) } } return nil @@ -166,7 +162,9 @@ func (s *memoryNSServer) receiveEvent( } func (s *memoryNSServer) Unregister(ctx context.Context, ns *registry.NetworkService) (*empty.Empty, error) { - s.networkServices.Delete(ns.Name) - + if unregisterNS, ok := s.networkServices.LoadAndDelete(ns.GetName()); ok { + unregisterNS = unregisterNS.Clone() + s.sendEvent(®istry.NetworkServiceResponse{NetworkService: unregisterNS, Deleted: true}) + } return next.NetworkServiceRegistryServer(ctx).Unregister(ctx, ns) } diff --git a/pkg/registry/common/memory/ns_server_test.go b/pkg/registry/common/memory/ns_server_test.go index e5aea9860..32d107fa2 100644 --- a/pkg/registry/common/memory/ns_server_test.go +++ b/pkg/registry/common/memory/ns_server_test.go @@ -1,6 +1,6 @@ // Copyright (c) 2020-2022 Doc.ai and/or its affiliates. // -// Copyright (c) 2023 Cisco Systems, Inc. +// Copyright (c) 2023-2024 Cisco Systems, Inc. // // SPDX-License-Identifier: Apache-2.0 // @@ -251,6 +251,42 @@ func TestNetworkServiceRegistryServer_ShouldReceiveAllRegisters(t *testing.T) { wgWait(ctx, t, &wg) } +func TestNetworkServiceRegistryServer_DeleteEvent(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + s := memory.NewNetworkServiceRegistryServer() + + ns, err := s.Register(ctx, ®istry.NetworkService{Name: "ns"}) + require.NoError(t, err) + + findCtx, findCancel := context.WithCancel(ctx) + defer findCancel() + + ch := make(chan *registry.NetworkServiceResponse, 2) + go func() { + defer close(ch) + findErr := s.Find(®istry.NetworkServiceQuery{ + NetworkService: ®istry.NetworkService{Name: "ns"}, + Watch: true, + }, streamchannel.NewNetworkServiceFindServer(findCtx, ch)) + require.NoError(t, findErr) + }() + + nsResp := <-ch + require.False(t, nsResp.Deleted) + + _, err = s.Unregister(ctx, ns) + require.NoError(t, err) + + // Read unregister event + nsResp, err = readNSResponse(findCtx, ch) + require.NoError(t, err) + require.True(t, nsResp.Deleted) +} + func readNSResponse(ctx context.Context, ch <-chan *registry.NetworkServiceResponse) (*registry.NetworkServiceResponse, error) { select { case <-ctx.Done():