Skip to content

Commit

Permalink
fix: force default value for endpoint weight
Browse files Browse the repository at this point in the history
  - in order to avoid having endpoint weight to 0, validation is done on service side and a protection is applied on gateway side
  • Loading branch information
guillaumelamirand committed Dec 13, 2024
1 parent 12e1a16 commit 1a1c798
Show file tree
Hide file tree
Showing 8 changed files with 169 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ public Endpoint deserialize(JsonParser parser, DeserializationContext ctxt) thro
final JsonNode weightNode = node.get("weight");
if (weightNode != null) {
int weight = weightNode.asInt(Endpoint.DEFAULT_WEIGHT);
if (weight <= 0) {
weight = Endpoint.DEFAULT_WEIGHT;
}
endpoint.setWeight(weight);
} else {
endpoint.setWeight(Endpoint.DEFAULT_WEIGHT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
public class Endpoint implements Serializable {

private static final long serialVersionUID = 7139083731513897591L;
private static final int DEFAULT_WEIGHT = 1;
public static final int DEFAULT_WEIGHT = 1;

@NotBlank
private String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,15 @@ private void loadRuntimeRatios() {
int position = 0;

for (Endpoint endpoint : endpoints) {
runtimeRatios.add(new WeightRatio(position++, endpoint.weight()));
runtimeRatios.add(new WeightRatio(position++, computeWeight(endpoint)));
}
}

private int computeWeight(final Endpoint endpoint) {
// has been implemented to protect the load balancer behavior as the initial weight cannot be 0 or lower
return endpoint.weight() > 0 ? endpoint.weight() : 1;
}

boolean isRuntimeRatiosZeroed() {
boolean cleared = true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,14 @@ public void refresh() {

int position = 0;
for (ManagedEndpoint managedEndpoint : endpoints) {
computedDistribution.add(new WeightDistributions.WeightDistribution(position++, managedEndpoint.getDefinition().getWeight()));
computedDistribution.add(new WeightDistributions.WeightDistribution(position++, computeWeight(managedEndpoint)));
}
weightDistributions.set(new WeightDistributions(computedDistribution));
}

private int computeWeight(final ManagedEndpoint managedEndpoint) {
// has been implemented to protect the load balancer behavior as the initial weight cannot be 0 or lower
int weight = managedEndpoint.getDefinition().getWeight();
return weight > 0 ? weight : 1;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
/*
* Copyright © 2015 The Gravitee team (http://gravitee.io)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.gravitee.gateway.core.loadbalancer;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import io.gravitee.gateway.api.endpoint.Endpoint;
import io.gravitee.gateway.reactive.api.connector.endpoint.EndpointConnector;
import io.gravitee.gateway.reactive.core.v4.endpoint.DefaultManagedEndpoint;
import io.gravitee.gateway.reactive.core.v4.endpoint.DefaultManagedEndpointGroup;
import io.gravitee.gateway.reactive.core.v4.endpoint.ManagedEndpoint;
import java.util.ArrayList;
import java.util.List;
import org.junit.jupiter.api.DisplayNameGeneration;
import org.junit.jupiter.api.DisplayNameGenerator;
import org.junit.jupiter.api.Test;

/**
* @author Guillaume LAMIRAND (guillaume.lamirand at graviteesource.com)
* @author GraviteeSource Team
*/
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class WeightedRoundRobinLoadBalancerTest {

@Test
void should_return_null_with_empty_endpoints() {
WeightedRoundRobinLoadBalancer cut = new WeightedRoundRobinLoadBalancer(List.of());
Endpoint next = cut.next();
assertThat(next).isNull();
}

@Test
void should_return_endpoint_even_with_invalid_configured_weight() {
List<Endpoint> endpoints = new ArrayList<>();
Endpoint endpoint1 = endpoint(0);
endpoints.add(endpoint1);
WeightedRoundRobinLoadBalancer cut = new WeightedRoundRobinLoadBalancer(endpoints);
// 1
Endpoint next = cut.next();
assertThat(next).isEqualTo(endpoint1); // 1 > 0
}

@Test
void should_return_endpoints_in_order() {
List<Endpoint> endpoints = new ArrayList<>();
Endpoint endpoint1 = endpoint(1);
endpoints.add(endpoint1);
Endpoint endpoint2 = endpoint(5);
endpoints.add(endpoint2);
Endpoint endpoint3 = endpoint(3);
endpoints.add(endpoint3);

WeightedRoundRobinLoadBalancer cut = new WeightedRoundRobinLoadBalancer(endpoints);
// 1
Endpoint next = cut.next();
assertThat(next).isEqualTo(endpoint1); // 1 > 0
next = cut.next();
assertThat(next).isEqualTo(endpoint2); // 5 > 4
next = cut.next();
assertThat(next).isEqualTo(endpoint3); // 3 > 2

// 2
next = cut.next();
assertThat(next).isEqualTo(endpoint2); // 4 > 3
next = cut.next();
assertThat(next).isEqualTo(endpoint3); // 2 > 1

// 3
next = cut.next();
assertThat(next).isEqualTo(endpoint2); // 3 > 2
next = cut.next();
assertThat(next).isEqualTo(endpoint3); // 1 > 0

// 4
next = cut.next();
assertThat(next).isEqualTo(endpoint2); // 2 > 1
next = cut.next();
assertThat(next).isEqualTo(endpoint2); // 1 > 0

// 5
next = cut.next();
assertThat(next).isEqualTo(endpoint1);
next = cut.next();
assertThat(next).isEqualTo(endpoint2);
next = cut.next();
assertThat(next).isEqualTo(endpoint3);
}

private static Endpoint endpoint(int weight) {
Endpoint endpoint = mock(Endpoint.class, "endpoint with %s".formatted(weight));
when(endpoint.weight()).thenReturn(weight);
when(endpoint.primary()).thenReturn(true);
return endpoint;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,23 @@ void should_return_null_with_empty_endpoints() {
assertThat(next).isNull();
}

@Test
void should_return_endpoint_even_with_invalid_configured_weight() {
List<ManagedEndpoint> endpoints = new ArrayList<>();
Endpoint endpoint1 = new Endpoint();
endpoint1.setWeight(0);
ManagedEndpoint managedEndpoint1 = new DefaultManagedEndpoint(
endpoint1,
new DefaultManagedEndpointGroup(new EndpointGroup()),
mock(EndpointConnector.class)
);
endpoints.add(managedEndpoint1);
WeightedRoundRobinLoadBalancer cut = new WeightedRoundRobinLoadBalancer(endpoints);
// 1
ManagedEndpoint next = cut.next();
assertThat(next).isEqualTo(managedEndpoint1); // 1 > 0
}

@Test
void should_return_endpoints_in_order() {
List<ManagedEndpoint> endpoints = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public List<EndpointGroup> validateAndSanitize(ApiType apiType, List<EndpointGro
.getEndpoints()
.forEach(endpoint -> {
validateUniqueEndpointName(endpoint.getName(), names);
validateEndpointWeight(endpoint);
validateEndpointType(endpoint.getType());
validateServices(endpointGroup.getServices(), endpoint.getServices());
validateEndpointMatchType(endpointGroup, endpoint);
Expand All @@ -103,6 +104,12 @@ public List<EndpointGroup> validateAndSanitize(ApiType apiType, List<EndpointGro
return endpointGroups;
}

private void validateEndpointWeight(final Endpoint endpoint) {
if (endpoint.getWeight() <= 0) {
endpoint.setWeight(Endpoint.DEFAULT_WEIGHT);
}
}

private void validateEndpointConfiguration(ConnectorPluginEntity endpointConnector, Endpoint endpoint) {
endpoint.setConfiguration(endpointService.validateConnectorConfiguration(endpointConnector, endpoint.getConfiguration()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void shouldReturnValidatedEndpointGroupsWithEndpoints() {
Endpoint validatedEndpoint = endpoints.get(0);
assertThat(validatedEndpoint.getName()).isEqualTo("endpoint");
assertThat(validatedEndpoint.getType()).isEqualTo("http");
assertThat(validatedEndpoint.getWeight()).isEqualTo(1);
assertThat(validatedEndpointGroup.getServices()).isNotNull();
assertThat(validatedEndpointGroup.getSharedConfiguration()).isNull();
assertThat(validatedEndpointGroup.getLoadBalancer()).isNotNull();
Expand Down Expand Up @@ -185,6 +186,7 @@ public void shouldReturnValidatedEndpointGroupsWithGroupHealthChecks() {
Endpoint validatedEndpoint = endpoints.get(0);
assertThat(validatedEndpoint.getName()).isEqualTo("endpoint");
assertThat(validatedEndpoint.getType()).isEqualTo("http");
assertThat(validatedEndpoint.getWeight()).isEqualTo(1);
assertThat(validatedEndpointGroup.getServices())
.isNotNull()
.matches(svc -> svc.getHealthCheck().getConfiguration().equals(FIXED_HC_CONFIG));
Expand Down Expand Up @@ -457,6 +459,21 @@ public void shouldThrowValidationExceptionWithMismatch() {
.isThrownBy(() -> endpointGroupsValidationService.validateAndSanitize(ApiType.PROXY, List.of(endpointGroup)));
}

@Test
public void shouldReturnDefaultWeightWithWrongEndpointWeight() {
EndpointGroup endpointGroup = new EndpointGroup();
endpointGroup.setName("name");
endpointGroup.setType("http");
Endpoint endpoint = new Endpoint();
endpoint.setName("endpoint");
endpoint.setType("http");
endpoint.setWeight(0);
endpointGroup.setEndpoints(List.of(endpoint));
List<EndpointGroup> endpointGroups = endpointGroupsValidationService.validateAndSanitize(ApiType.PROXY, List.of(endpointGroup));

assertThat(endpointGroups.get(0).getEndpoints().get(0).getWeight()).isEqualTo(1);
}

@Test(expected = EndpointMissingException.class)
public void shouldThrowExceptionWithNullParameter() {
assertThat(endpointGroupsValidationService.validateAndSanitize(ApiType.PROXY, null)).isNull();
Expand Down

0 comments on commit 1a1c798

Please sign in to comment.