diff --git a/pkg/networkservice/chains/client/client.go b/pkg/networkservice/chains/client/client.go index a8467383d..07d5a8410 100644 --- a/pkg/networkservice/chains/client/client.go +++ b/pkg/networkservice/chains/client/client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2022 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -44,6 +44,7 @@ func NewClient(ctx context.Context, clientOpts ...Option) networkservice.Network authorizeClient: null.NewClient(), healClient: null.NewClient(), refreshClient: refresh.NewClient(ctx), + reselectFunc: begin.DefaultReselectFunc, } for _, opt := range clientOpts { opt(opts) @@ -53,7 +54,7 @@ func NewClient(ctx context.Context, clientOpts ...Option) networkservice.Network append( []networkservice.NetworkServiceClient{ updatepath.NewClient(opts.name), - begin.NewClient(), + begin.NewClient(begin.WithReselectFunc(opts.reselectFunc)), metadata.NewClient(), opts.refreshClient, clienturl.NewClient(opts.clientURL), diff --git a/pkg/networkservice/chains/client/options.go b/pkg/networkservice/chains/client/options.go index 1bea59244..d3f0f4564 100644 --- a/pkg/networkservice/chains/client/options.go +++ b/pkg/networkservice/chains/client/options.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -23,6 +23,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "google.golang.org/grpc" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" "github.com/networkservicemesh/sdk/pkg/networkservice/common/null" ) @@ -36,6 +37,7 @@ type clientOptions struct { healClient networkservice.NetworkServiceClient dialOptions []grpc.DialOption dialTimeout time.Duration + reselectFunc begin.ReselectFunc } // Option modifies default client chain values. @@ -109,3 +111,10 @@ func WithoutRefresh() Option { c.refreshClient = null.NewClient() } } + +// WithReselectFunc sets a function for changing request parameters on reselect +func WithReselectFunc(f func(*networkservice.NetworkServiceRequest)) Option { + return func(c *clientOptions) { + c.reselectFunc = f + } +} diff --git a/pkg/networkservice/chains/nsmgr/heal_test.go b/pkg/networkservice/chains/nsmgr/heal_test.go index ac54b5b22..01bc6c3cc 100644 --- a/pkg/networkservice/chains/nsmgr/heal_test.go +++ b/pkg/networkservice/chains/nsmgr/heal_test.go @@ -48,8 +48,10 @@ import ( ) const ( - tick = 10 * time.Millisecond - timeout = 10 * time.Second + tick = 10 * time.Millisecond + timeout = 10 * time.Second + labelKey = "key" + labelValue = "value" ) func TestNSMGR_HealEndpoint(t *testing.T) { @@ -911,3 +913,44 @@ func TestNSMGR_RefreshFailed_ControlPlaneBroken(t *testing.T) { require.NoError(t, err) require.Equal(t, 2, counter.Requests()) } + +func TestNSMGRHealEndpoint_CustomReselectFunc(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + + defer cancel() + domain := sandbox.NewBuilder(ctx, t). + SetNodesCount(1). + SetNSMgrProxySupplier(nil). + SetRegistryProxySupplier(nil). + Build() + + nsReg, err := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken).Register(ctx, defaultRegistryService(t.Name())) + require.NoError(t, err) + + nseReg := defaultRegistryEndpoint(nsReg.Name) + nse := domain.Nodes[0].NewEndpoint(ctx, nseReg, sandbox.GenerateTestToken) + + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, nsclient.WithHealClient(heal.NewClient(ctx)), + nsclient.WithReselectFunc( + func(request *networkservice.NetworkServiceRequest) { + request.Connection.Labels = make(map[string]string) + request.Connection.Labels[labelKey] = labelValue + request.Connection.NetworkServiceEndpointName = "" + })) + + request := defaultRequest(nsReg.Name) + _, err = nsc.Request(ctx, request.Clone()) + require.NoError(t, err) + + nse.Cancel() + + nseReg2 := defaultRegistryEndpoint(nsReg.Name) + nseReg2.Name += "-2" + domain.Nodes[0].NewEndpoint(ctx, nseReg2, sandbox.GenerateTestToken) + + require.Eventually(t, func() bool { + resp, err := nsc.Request(ctx, request.Clone()) + return err == nil && resp.Labels[labelKey] == labelValue + }, timeout, tick) +} diff --git a/pkg/networkservice/common/begin/client.go b/pkg/networkservice/common/begin/client.go index 3980cd24f..658b63f82 100644 --- a/pkg/networkservice/common/begin/client.go +++ b/pkg/networkservice/common/begin/client.go @@ -1,4 +1,4 @@ -// Copyright (c) 2021-2023 Cisco and/or its affiliates. +// Copyright (c) 2021-2024 Cisco and/or its affiliates. // // SPDX-License-Identifier: Apache-2.0 // @@ -32,11 +32,21 @@ import ( type beginClient struct { genericsync.Map[string, *eventFactoryClient] + reselectFunc ReselectFunc } // NewClient - creates a new begin chain element -func NewClient() networkservice.NetworkServiceClient { - return &beginClient{} +func NewClient(opts ...ClientOption) networkservice.NetworkServiceClient { + o := &clientOption{ + reselectFunc: DefaultReselectFunc, + } + for _, opt := range opts { + opt(o) + } + + return &beginClient{ + reselectFunc: o.reselectFunc, + } } func (b *beginClient) Request(ctx context.Context, request *networkservice.NetworkServiceRequest, opts ...grpc.CallOption) (conn *networkservice.Connection, err error) { @@ -54,6 +64,7 @@ func (b *beginClient) Request(ctx context.Context, request *networkservice.Netwo func() { b.Delete(request.GetRequestConnection().GetId()) }, + b.reselectFunc, opts..., ), ) diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index ae92164e9..bafa58ec6 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -40,6 +40,18 @@ const ( var _ connectionState = zero +// ReselectFunc - function for changing request parameters on reselect +type ReselectFunc func(request *networkservice.NetworkServiceRequest) + +// DefaultReselectFunc - default ReselectFunc +var DefaultReselectFunc ReselectFunc = func(request *networkservice.NetworkServiceRequest) { + if request.GetConnection() != nil { + request.GetConnection().Mechanism = nil + request.GetConnection().NetworkServiceEndpointName = "" + request.GetConnection().State = networkservice.State_RESELECT_REQUESTED + } +} + // EventFactory - allows firing off a Request or Close event from midchain type EventFactory interface { Request(opts ...Option) <-chan error @@ -56,12 +68,14 @@ type eventFactoryClient struct { opts []grpc.CallOption client networkservice.NetworkServiceClient afterCloseFunc func() + reselectFunc ReselectFunc } -func newEventFactoryClient(ctx context.Context, afterClose func(), opts ...grpc.CallOption) *eventFactoryClient { +func newEventFactoryClient(ctx context.Context, afterClose func(), reselectFunc func(*networkservice.NetworkServiceRequest), opts ...grpc.CallOption) *eventFactoryClient { f := &eventFactoryClient{ client: next.Client(ctx), initialCtxFunc: postpone.Context(ctx), + reselectFunc: reselectFunc, opts: opts, } f.updateContext(ctx) @@ -103,11 +117,7 @@ func (f *eventFactoryClient) Request(opts ...Option) <-chan error { if o.reselect { ctx, cancel := f.ctxFunc() _, _ = f.client.Close(ctx, request.GetConnection(), f.opts...) - if request.GetConnection() != nil { - request.GetConnection().Mechanism = nil - request.GetConnection().NetworkServiceEndpointName = "" - request.GetConnection().State = networkservice.State_RESELECT_REQUESTED - } + f.reselectFunc(request) cancel() } ctx, cancel := f.ctxFunc() diff --git a/pkg/networkservice/common/begin/options.go b/pkg/networkservice/common/begin/options.go index abd68afbb..fe6c3e1e2 100644 --- a/pkg/networkservice/common/begin/options.go +++ b/pkg/networkservice/common/begin/options.go @@ -25,9 +25,16 @@ type option struct { reselect bool } -// Option - event option +type clientOption struct { + reselectFunc ReselectFunc +} + +// Option - event factory option type Option func(*option) +// ClientOption - begin client option +type ClientOption func(*clientOption) + // CancelContext - optionally provide a context that, when canceled will preclude the event from running func CancelContext(cancelCtx context.Context) Option { return func(o *option) { @@ -41,3 +48,10 @@ func WithReselect() Option { o.reselect = true } } + +// WithReselectFunc - sets a function for changing request parameters on reselect +func WithReselectFunc(reselectFunc ReselectFunc) ClientOption { + return func(o *clientOption) { + o.reselectFunc = reselectFunc + } +}