Skip to content

Commit

Permalink
add a check for services before selecting an endpoint (#1634)
Browse files Browse the repository at this point in the history
  • Loading branch information
omerlavanet authored Aug 18, 2024
1 parent 801095d commit d1ce606
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 3 deletions.
4 changes: 2 additions & 2 deletions protocol/lavasession/consumer_session_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ func (csm *ConsumerSessionManager) probeProviders(ctx context.Context, pairingLi

// this code needs to be thread safe
func (csm *ConsumerSessionManager) probeProvider(ctx context.Context, consumerSessionsWithProvider *ConsumerSessionsWithProvider, epoch uint64, tryReconnectToDisabledEndpoints bool) (latency time.Duration, providerAddress string, err error) {
connected, endpoints, providerAddress, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, tryReconnectToDisabledEndpoints, true)
connected, endpoints, providerAddress, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, tryReconnectToDisabledEndpoints, true, "", nil)
if err != nil || !connected {
if AllProviderEndpointsDisabledError.Is(err) {
csm.blockProvider(providerAddress, true, epoch, MaxConsecutiveConnectionAttempts, 0, false, csm.GenerateReconnectCallback(consumerSessionsWithProvider), []error{err}) // reporting and blocking provider this epoch
Expand Down Expand Up @@ -454,7 +454,7 @@ func (csm *ConsumerSessionManager) GetSessions(ctx context.Context, cuNeededForS
sessionEpoch := sessionWithProvider.CurrentEpoch

// Get a valid Endpoint from the provider chosen
connected, endpoints, _, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, false, false)
connected, endpoints, _, err := consumerSessionsWithProvider.fetchEndpointConnectionFromConsumerSessionWithProvider(ctx, false, false, addon, extensionNames)
if err != nil {
// verify err is AllProviderEndpointsDisabled and report.
if AllProviderEndpointsDisabledError.Is(err) {
Expand Down
25 changes: 24 additions & 1 deletion protocol/lavasession/consumer_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,23 @@ type Endpoint struct {
Geolocation planstypes.Geolocation
}

func (e *Endpoint) CheckSupportForServices(addon string, extensions []string) (supported bool) {
if addon != "" {
if _, ok := e.Addons[addon]; !ok {
return false
}
}
for _, extension := range extensions {
if extension == "" {
continue
}
if _, ok := e.Extensions[extension]; !ok {
return false
}
}
return true
}

type SessionWithProvider struct {
SessionsWithProvider *ConsumerSessionsWithProvider
CurrentEpoch uint64
Expand Down Expand Up @@ -457,7 +474,7 @@ func (cswp *ConsumerSessionsWithProvider) sortEndpointsByLatency(endpointInfos [

// fetching an endpoint from a ConsumerSessionWithProvider and establishing a connection,
// can fail without an error if trying to connect once to each endpoint but none of them are active.
func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSessionWithProvider(ctx context.Context, retryDisabledEndpoints bool, getAllEndpoints bool) (connected bool, endpointsList []*EndpointAndChosenConnection, providerAddress string, err error) {
func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSessionWithProvider(ctx context.Context, retryDisabledEndpoints bool, getAllEndpoints bool, addon string, extensionNames []string) (connected bool, endpointsList []*EndpointAndChosenConnection, providerAddress string, err error) {
getConnectionFromConsumerSessionsWithProvider := func(ctx context.Context) (connected bool, endpointPtr []*EndpointAndChosenConnection, allDisabled bool) {
endpoints := make([]*EndpointAndChosenConnection, 0)
cswp.Lock.Lock()
Expand All @@ -468,6 +485,12 @@ func (cswp *ConsumerSessionsWithProvider) fetchEndpointConnectionFromConsumerSes
if !retryDisabledEndpoints && !endpoint.Enabled {
continue
}

// check endpoint supports the requested addons
supported := endpoint.CheckSupportForServices(addon, extensionNames)
if !supported {
continue
}
// return
connectEndpoint := func(cswp *ConsumerSessionsWithProvider, ctx context.Context, endpoint *Endpoint) (endpointConnection_ *EndpointConnection, connected_ bool) {
for _, endpointConnection := range endpoint.Connections {
Expand Down

0 comments on commit d1ce606

Please sign in to comment.