diff --git a/pkg/networkservice/common/begin/event_factory.go b/pkg/networkservice/common/begin/event_factory.go index 6a8b00bc7..ae92164e9 100644 --- a/pkg/networkservice/common/begin/event_factory.go +++ b/pkg/networkservice/common/begin/event_factory.go @@ -18,7 +18,6 @@ package begin import ( "context" - "time" "github.com/edwarnicke/serialize" "github.com/networkservicemesh/api/pkg/api/networkservice" @@ -159,16 +158,14 @@ type eventFactoryServer struct { ctxFunc func() (context.Context, context.CancelFunc) request *networkservice.NetworkServiceRequest returnedConnection *networkservice.Connection - contextTimeout time.Duration afterCloseFunc func() server networkservice.NetworkServiceServer } -func newEventFactoryServer(ctx context.Context, contextTimeout time.Duration, afterClose func()) *eventFactoryServer { +func newEventFactoryServer(ctx context.Context, afterClose func()) *eventFactoryServer { f := &eventFactoryServer{ server: next.Server(ctx), initialCtxFunc: postpone.Context(ctx), - contextTimeout: contextTimeout, } f.updateContext(ctx) @@ -206,12 +203,7 @@ func (f *eventFactoryServer) Request(opts ...Option) <-chan error { default: ctx, cancel := f.ctxFunc() defer cancel() - - extendedCtx, cancel := context.WithTimeout(context.Background(), f.contextTimeout) - defer cancel() - - extendedCtx = extend.WithValuesFromContext(extendedCtx, ctx) - conn, err := f.server.Request(extendedCtx, f.request) + conn, err := f.server.Request(ctx, f.request) if err == nil && f.request != nil { f.request.Connection = conn } @@ -239,12 +231,7 @@ func (f *eventFactoryServer) Close(opts ...Option) <-chan error { default: ctx, cancel := f.ctxFunc() defer cancel() - - extendedCtx, cancel := context.WithTimeout(context.Background(), f.contextTimeout) - defer cancel() - - extendedCtx = extend.WithValuesFromContext(extendedCtx, ctx) - _, err := f.server.Close(extendedCtx, f.request.GetConnection()) + _, err := f.server.Close(ctx, f.request.GetConnection()) f.afterCloseFunc() ch <- err } diff --git a/pkg/networkservice/common/begin/event_factory_server_test.go b/pkg/networkservice/common/begin/event_factory_server_test.go index 3a105998b..3292ec8de 100644 --- a/pkg/networkservice/common/begin/event_factory_server_test.go +++ b/pkg/networkservice/common/begin/event_factory_server_test.go @@ -33,6 +33,7 @@ import ( "github.com/networkservicemesh/sdk/pkg/networkservice/core/chain" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" "github.com/networkservicemesh/sdk/pkg/tools/clock" + "github.com/networkservicemesh/sdk/pkg/tools/clockmock" ) // This test reproduces the situation when refresh changes the eventFactory context @@ -137,12 +138,18 @@ func TestContextTimeout_Server(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - contextTimeout := time.Second * 2 + // Add clockMock to the context + clockMock := clockmock.New(ctx) + ctx = clock.WithClock(ctx, clockMock) + + ctx, cancel = context.WithDeadline(ctx, clockMock.Now().Add(time.Second*3)) + defer cancel() + eventFactoryServ := &eventFactoryServer{} server := chain.NewNetworkServiceServer( - begin.NewServer(begin.WithContextTimeout(contextTimeout)), + begin.NewServer(), eventFactoryServ, - &delayedNSEServer{t: t, contextTimeout: contextTimeout}, + &delayedNSEServer{t: t, clock: clockMock}, ) // Do Request @@ -221,8 +228,8 @@ func (f *failedNSEServer) Close(ctx context.Context, conn *networkservice.Connec type delayedNSEServer struct { t *testing.T + clock *clockmock.Mock initialTimeout time.Duration - contextTimeout time.Duration } func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { @@ -238,10 +245,10 @@ func (d *delayedNSEServer) Request(ctx context.Context, request *networkservice. d.initialTimeout = timeout } // All requests timeout must be equal the first - require.Less(d.t, (d.initialTimeout - timeout).Abs(), time.Second) + require.Equal(d.t, d.initialTimeout, timeout) // Add delay - time.Sleep(timeout / 2) + d.clock.Add(timeout / 2) return next.Server(ctx).Request(ctx, request) } @@ -249,9 +256,9 @@ func (d *delayedNSEServer) Close(ctx context.Context, conn *networkservice.Conne require.Greater(d.t, d.initialTimeout, time.Duration(0)) deadline, _ := ctx.Deadline() - timeout := time.Until(deadline) + clockTime := clock.FromContext(ctx) - require.Less(d.t, (d.contextTimeout - timeout).Abs(), time.Second) + require.Equal(d.t, d.initialTimeout, clockTime.Until(deadline)) return next.Server(ctx).Close(ctx, conn) } diff --git a/pkg/networkservice/common/begin/options.go b/pkg/networkservice/common/begin/options.go index 509acc7aa..abd68afbb 100644 --- a/pkg/networkservice/common/begin/options.go +++ b/pkg/networkservice/common/begin/options.go @@ -18,13 +18,11 @@ package begin import ( "context" - "time" ) type option struct { - cancelCtx context.Context - reselect bool - contextTimeout time.Duration + cancelCtx context.Context + reselect bool } // Option - event option @@ -43,10 +41,3 @@ func WithReselect() Option { o.reselect = true } } - -// WithContextTimeout - set a custom timeout for a context in begin.Close -func WithContextTimeout(timeout time.Duration) Option { - return func(o *option) { - o.contextTimeout = timeout - } -} diff --git a/pkg/networkservice/common/begin/server.go b/pkg/networkservice/common/begin/server.go index a057a9299..f2a9b09e6 100644 --- a/pkg/networkservice/common/begin/server.go +++ b/pkg/networkservice/common/begin/server.go @@ -18,14 +18,12 @@ package begin import ( "context" - "time" "github.com/edwarnicke/genericsync" "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/pkg/errors" "google.golang.org/protobuf/types/known/emptypb" - "github.com/networkservicemesh/sdk/pkg/tools/extend" "github.com/networkservicemesh/sdk/pkg/tools/log" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" @@ -33,30 +31,14 @@ import ( type beginServer struct { genericsync.Map[string, *eventFactoryServer] - contextTimeout time.Duration } // NewServer - creates a new begin chain element -func NewServer(opts ...Option) networkservice.NetworkServiceServer { - o := &option{ - cancelCtx: context.Background(), - reselect: false, - contextTimeout: time.Minute, - } - - for _, opt := range opts { - opt(o) - } - - return &beginServer{ - contextTimeout: o.contextTimeout, - } +func NewServer() networkservice.NetworkServiceServer { + return &beginServer{} } -func (b *beginServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - var conn *networkservice.Connection - var err error - +func (b *beginServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (conn *networkservice.Connection, err error) { // No connection.ID, no service if request.GetConnection().GetId() == "" { return nil, errors.New("request.EventFactory.Id must not be zero valued") @@ -68,14 +50,12 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer, _ := b.LoadOrStore(request.GetConnection().GetId(), newEventFactoryServer( ctx, - b.contextTimeout, func() { b.Delete(request.GetRequestConnection().GetId()) }, ), ) - select { - case <-eventFactoryServer.executor.AsyncExec(func() { + <-eventFactoryServer.executor.AsyncExec(func() { currentEventFactoryServer, _ := b.Load(request.GetConnection().GetId()) if currentEventFactoryServer != eventFactoryServer { log.FromContext(ctx).Debug("recalling begin.Request because currentEventFactoryServer != eventFactoryServer") @@ -88,12 +68,8 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer.request != nil && eventFactoryServer.request.Connection != nil { log.FromContext(ctx).Info("Closing connection due to RESELECT_REQUESTED state") - closeCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout) - defer cancel() - eventFactoryCtx, eventFactoryCtxCancel := eventFactoryServer.ctxFunc() - closeCtx = extend.WithValuesFromContext(closeCtx, eventFactoryCtx) - _, closeErr := next.Server(closeCtx).Close(closeCtx, eventFactoryServer.request.Connection) + _, closeErr := next.Server(eventFactoryCtx).Close(eventFactoryCtx, eventFactoryServer.request.Connection) if closeErr != nil { log.FromContext(ctx).Errorf("Can't close old connection: %v", closeErr) } @@ -101,11 +77,8 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryCtxCancel() } - extendedCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout) - extendedCtx = extend.WithValuesFromContext(extendedCtx, withEventFactory(ctx, eventFactoryServer)) - defer cancel() - - conn, err = next.Server(extendedCtx).Request(extendedCtx, request) + withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) + conn, err = next.Server(withEventFactoryCtx).Request(withEventFactoryCtx, request) if err != nil { if eventFactoryServer.state != established { eventFactoryServer.state = closed @@ -120,48 +93,33 @@ func (b *beginServer) Request(ctx context.Context, request *networkservice.Netwo eventFactoryServer.returnedConnection = conn.Clone() eventFactoryServer.updateContext(ctx) - }): - case <-ctx.Done(): - return nil, ctx.Err() - } - + }) return conn, err } -func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection) (*emptypb.Empty, error) { - var err error - connID := conn.GetId() +func (b *beginServer) Close(ctx context.Context, conn *networkservice.Connection) (emp *emptypb.Empty, err error) { // If some other EventFactory is already in the ctx... we are already running in an executor, and can just execute normally if fromContext(ctx) != nil { return next.Server(ctx).Close(ctx, conn) } - eventFactoryServer, ok := b.Load(connID) + eventFactoryServer, ok := b.Load(conn.GetId()) if !ok { // If we don't have a connection to Close, just let it be return &emptypb.Empty{}, nil } - - select { - case <-eventFactoryServer.executor.AsyncExec(func() { + <-eventFactoryServer.executor.AsyncExec(func() { if eventFactoryServer.state != established || eventFactoryServer.request == nil { return } - currentServerClient, _ := b.Load(connID) + currentServerClient, _ := b.Load(conn.GetId()) if currentServerClient != eventFactoryServer { return } - extendedCtx, cancel := context.WithTimeout(context.Background(), b.contextTimeout) - extendedCtx = extend.WithValuesFromContext(extendedCtx, withEventFactory(ctx, eventFactoryServer)) - defer cancel() - // Always close with the last valid EventFactory we got conn = eventFactoryServer.request.Connection - _, err = next.Server(extendedCtx).Close(extendedCtx, conn) + withEventFactoryCtx := withEventFactory(ctx, eventFactoryServer) + emp, err = next.Server(withEventFactoryCtx).Close(withEventFactoryCtx, conn) eventFactoryServer.afterCloseFunc() - }): - return &emptypb.Empty{}, err - case <-ctx.Done(): - b.Delete(connID) - return nil, ctx.Err() - } + }) + return &emptypb.Empty{}, err } diff --git a/pkg/networkservice/common/begin/server_test.go b/pkg/networkservice/common/begin/server_test.go deleted file mode 100644 index 70c3142f0..000000000 --- a/pkg/networkservice/common/begin/server_test.go +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright (c) 2024 Cisco and/or its affiliates. -// -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at: -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package begin_test - -import ( - "context" - "sync/atomic" - "testing" - "time" - - "github.com/golang/protobuf/ptypes/empty" - "github.com/networkservicemesh/api/pkg/api/networkservice" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" - "google.golang.org/protobuf/types/known/emptypb" - - "github.com/networkservicemesh/sdk/pkg/networkservice/common/begin" - "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" -) - -const ( - waitTime = time.Second -) - -type waitServer struct { - requestDone atomic.Int32 - closeDone atomic.Int32 -} - -func (s *waitServer) Request(ctx context.Context, request *networkservice.NetworkServiceRequest) (*networkservice.Connection, error) { - afterCh := time.After(time.Second) - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-afterCh: - s.requestDone.Add(1) - } - return next.Server(ctx).Request(ctx, request) -} - -func (s *waitServer) Close(ctx context.Context, connection *networkservice.Connection) (*empty.Empty, error) { - afterCh := time.After(time.Second) - select { - case <-ctx.Done(): - return &emptypb.Empty{}, nil - case <-afterCh: - s.closeDone.Add(1) - } - return next.Server(ctx).Close(ctx, connection) -} - -func TestBeginWorksWithSmallTimeout(t *testing.T) { - t.Cleanup(func() { - goleak.VerifyNone(t) - }) - requestCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - - waitSrv := &waitServer{} - server := next.NewNetworkServiceServer( - begin.NewServer(), - waitSrv, - ) - - request := testRequest("id") - _, err := server.Request(requestCtx, request) - require.EqualError(t, err, context.DeadlineExceeded.Error()) - require.Equal(t, int32(0), waitSrv.requestDone.Load()) - require.Eventually(t, func() bool { - return waitSrv.requestDone.Load() == 1 - }, waitTime*2, time.Millisecond*500) - - closeCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - _, err = server.Close(closeCtx, request.Connection) - require.EqualError(t, err, context.DeadlineExceeded.Error()) - require.Equal(t, int32(0), waitSrv.closeDone.Load()) - require.Eventually(t, func() bool { - return waitSrv.closeDone.Load() == 1 - }, waitTime*2, time.Millisecond*500) -} - -func TestBeginHasExtendedTimeoutOnReselect(t *testing.T) { - t.Cleanup(func() { - goleak.VerifyNone(t) - }) - requestCtx, cancel := context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - - waitSrv := &waitServer{} - server := next.NewNetworkServiceServer( - begin.NewServer(), - waitSrv, - ) - - // Make a first request to create an event factory. Begin should make Request only - request := testRequest("id") - _, err := server.Request(requestCtx, request) - require.EqualError(t, err, context.DeadlineExceeded.Error()) - require.Equal(t, int32(0), waitSrv.requestDone.Load()) - require.Eventually(t, func() bool { - return waitSrv.requestDone.Load() == 1 - }, waitTime*2, time.Millisecond*500) - - // Make a second request with RESELECT_REQUESTED. Begin should make Close with extended context first and then Request - requestCtx, cancel = context.WithTimeout(context.Background(), time.Millisecond*200) - defer cancel() - newRequest := request.Clone() - newRequest.Connection.State = networkservice.State_RESELECT_REQUESTED - - _, err = server.Request(requestCtx, newRequest) - require.EqualError(t, err, context.DeadlineExceeded.Error()) - require.Equal(t, int32(0), waitSrv.closeDone.Load()) - require.Eventually(t, func() bool { - return waitSrv.closeDone.Load() == 1 && waitSrv.requestDone.Load() == 2 - }, waitTime*4, time.Millisecond*500) -} diff --git a/pkg/networkservice/common/dial/client.go b/pkg/networkservice/common/dial/client.go index 31be1deac..d53d9217f 100644 --- a/pkg/networkservice/common/dial/client.go +++ b/pkg/networkservice/common/dial/client.go @@ -72,12 +72,8 @@ func (d *dialClient) Request(ctx context.Context, request *networkservice.Networ return next.Client(ctx).Request(ctx, request, opts...) } - di.mu.Lock() - dialClientURL := di.clientURL - di.mu.Unlock() - // If our existing dialer has a different URL close down the chain - if dialClientURL != nil && dialClientURL.String() != clientURL.String() { + if di.clientURL != nil && di.clientURL.String() != clientURL.String() { closeCtx, closeCancel := closeContextFunc() defer closeCancel() err := di.Dial(closeCtx, di.clientURL) diff --git a/pkg/networkservice/common/dial/dialer.go b/pkg/networkservice/common/dial/dialer.go index 2d9e769b7..b0abe5d14 100644 --- a/pkg/networkservice/common/dial/dialer.go +++ b/pkg/networkservice/common/dial/dialer.go @@ -20,7 +20,6 @@ import ( "context" "net/url" "runtime" - "sync" "time" "github.com/pkg/errors" @@ -31,13 +30,13 @@ import ( ) type dialer struct { - ctx context.Context - clientURL *url.URL - cleanupCancel context.CancelFunc + ctx context.Context + cleanupContext context.Context + clientURL *url.URL + cleanupCancel context.CancelFunc *grpc.ClientConn dialOptions []grpc.DialOption dialTimeout time.Duration - mu sync.Mutex } func newDialer(ctx context.Context, dialTimeout time.Duration, dialOptions ...grpc.DialOption) *dialer { @@ -57,10 +56,8 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { di.cleanupCancel() } - di.mu.Lock() // Set the clientURL di.clientURL = clientURL - di.mu.Unlock() // Setup dialTimeout if needed dialCtx := ctx @@ -69,10 +66,7 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { } // Dial - di.mu.Lock() target := grpcutils.URLToTarget(di.clientURL) - di.mu.Unlock() - cc, err := grpc.DialContext(dialCtx, target, di.dialOptions...) if err != nil { if cc != nil { @@ -80,32 +74,26 @@ func (di *dialer) Dial(ctx context.Context, clientURL *url.URL) error { } return errors.Wrapf(err, "failed to dial %s", target) } - di.mu.Lock() di.ClientConn = cc - var cleanupContext context.Context - cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx) - di.mu.Unlock() + + di.cleanupContext, di.cleanupCancel = context.WithCancel(di.ctx) go func(cleanupContext context.Context, cc *grpc.ClientConn) { <-cleanupContext.Done() _ = cc.Close() - }(cleanupContext, cc) + }(di.cleanupContext, cc) return nil } func (di *dialer) Close() error { if di != nil && di.cleanupCancel != nil { - di.mu.Lock() di.cleanupCancel() - di.mu.Unlock() runtime.Gosched() } return nil } func (di *dialer) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { - di.mu.Lock() - defer di.mu.Unlock() if di.ClientConn == nil { return errors.New("no dialer.ClientConn found") } @@ -113,9 +101,6 @@ func (di *dialer) Invoke(ctx context.Context, method string, args, reply interfa } func (di *dialer) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - di.mu.Lock() - defer di.mu.Unlock() - if di.ClientConn == nil { return nil, errors.New("no dialer.ClientConn found") }