From d1ce6060e26571a38c880b7d877b634e01173f0e Mon Sep 17 00:00:00 2001 From: Omer <100387053+omerlavanet@users.noreply.github.com> Date: Sun, 18 Aug 2024 21:11:53 +0300 Subject: [PATCH] add a check for services before selecting an endpoint (#1634) --- .../lavasession/consumer_session_manager.go | 4 +-- protocol/lavasession/consumer_types.go | 25 ++++++++++++++++++- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/protocol/lavasession/consumer_session_manager.go b/protocol/lavasession/consumer_session_manager.go index b54e3c0b87..0074c74508 100644 --- a/protocol/lavasession/consumer_session_manager.go +++ b/protocol/lavasession/consumer_session_manager.go @@ -222,7 +222,7 @@ func (csm *ConsumerSessionManager) probeProviders(ctx context.Context, pairingLi // this code needs to be thread safe func (csm *ConsumerSessionManager) probeProvider(ctx context.Context, consumerSessionsWithProvider *ConsumerSessionsWithProvider, epoch uint64, tryReconnectToDisabledEndpoints bool) (latency time.Duration, providerAddress string, err error) { - connected, endpoints, providerAddress, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, tryReconnectToDisabledEndpoints, true) + connected, endpoints, providerAddress, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, tryReconnectToDisabledEndpoints, true, "", nil) if err != nil || !connected { if AllProviderEndpointsDisabledError.Is(err) { csm.blockProvider(providerAddress, true, epoch, MaxConsecutiveConnectionAttempts, 0, false, csm.GenerateReconnectCallback(consumerSessionsWithProvider), []error{err}) // reporting and blocking provider this epoch @@ -454,7 +454,7 @@ func (csm *ConsumerSessionManager) GetSessions(ctx context.Context, cuNeededForS sessionEpoch := sessionWithProvider.CurrentEpoch // Get a valid Endpoint from the provider chosen - connected, endpoints, _, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, false, false) + connected, endpoints, _, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, false, false, addon, extensionNames) if err != nil { // verify err is AllProviderEndpointsDisabled and report. if AllProviderEndpointsDisabledError.Is(err) { diff --git a/protocol/lavasession/consumer_types.go b/protocol/lavasession/consumer_types.go index 51e2b325fa..ef3e6735bf 100644 --- a/protocol/lavasession/consumer_types.go +++ b/protocol/lavasession/consumer_types.go @@ -153,6 +153,23 @@ type Endpoint struct { Geolocation planstypes.Geolocation } +func (e *Endpoint) CheckSupportForServices(addon string, extensions []string) (supported bool) { + if addon != "" { + if _, ok := e.Addons[addon]; !ok { + return false + } + } + for _, extension := range extensions { + if extension == "" { + continue + } + if _, ok := e.Extensions[extension]; !ok { + return false + } + } + return true +} + type SessionWithProvider struct { SessionsWithProvider *ConsumerSessionsWithProvider CurrentEpoch uint64 @@ -457,7 +474,7 @@ func (cswp *ConsumerSessionsWithProvider) sortEndpointsByLatency(endpointInfos [ // fetching an endpoint from a ConsumerSessionWithProvider and establishing a connection, // can fail without an error if trying to connect once to each endpoint but none of them are active. -func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSessionWithProvider(ctx context.Context, retryDisabledEndpoints bool, getAllEndpoints bool) (connected bool, endpointsList []*EndpointAndChosenConnection, providerAddress string, err error) { +func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSessionWithProvider(ctx context.Context, retryDisabledEndpoints bool, getAllEndpoints bool, addon string, extensionNames []string) (connected bool, endpointsList []*EndpointAndChosenConnection, providerAddress string, err error) { getConnectionFromConsumerSessionsWithProvider := func(ctx context.Context) (connected bool, endpointPtr []*EndpointAndChosenConnection, allDisabled bool) { endpoints := make([]*EndpointAndChosenConnection, 0) cswp.Lock.Lock() @@ -468,6 +485,12 @@ func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSes if !retryDisabledEndpoints && !endpoint.Enabled { continue } + + // check endpoint supports the requested addons + supported := endpoint.CheckSupportForServices(addon, extensionNames) + if !supported { + continue + } // return connectEndpoint := func(cswp *ConsumerSessionsWithProvider, ctx context.Context, endpoint *Endpoint) (endpointConnection_ *EndpointConnection, connected_ bool) { for _, endpointConnection := range endpoint.Connections {