diff --git a/gravitee-policy-spikearrest/pom.xml b/gravitee-policy-spikearrest/pom.xml index 38ec791..819a601 100644 --- a/gravitee-policy-spikearrest/pom.xml +++ b/gravitee-policy-spikearrest/pom.xml @@ -69,12 +69,6 @@ io.vertx vertx-rx-java3 - - - * - * - - provided diff --git a/gravitee-policy-spikearrest/src/main/java/io/gravitee/policy/spike/SpikeArrestPolicy.java b/gravitee-policy-spikearrest/src/main/java/io/gravitee/policy/spike/SpikeArrestPolicy.java index 2a5e531..7574051 100644 --- a/gravitee-policy-spikearrest/src/main/java/io/gravitee/policy/spike/SpikeArrestPolicy.java +++ b/gravitee-policy-spikearrest/src/main/java/io/gravitee/policy/spike/SpikeArrestPolicy.java @@ -37,7 +37,6 @@ import io.vertx.rxjava3.core.RxHelper; import io.vertx.rxjava3.core.Vertx; import java.util.concurrent.TimeUnit; -import java.util.function.Supplier; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -107,24 +106,21 @@ public void onRequest(Request request, Response response, ExecutionContext execu .incrementAndGet( key, spikeArrestPolicyConfiguration.isAsync(), - new Supplier() { - @Override - public RateLimit get() { - // Set the time at which the current rate limit window resets in UTC epoch seconds. - long resetTimeMillis = getEndOfPeriod(request.timestamp(), slice.getPeriod(), TimeUnit.MILLISECONDS); - - RateLimit rate = new RateLimit(key); - rate.setCounter(0); - rate.setLimit(slice.getLimit()); - rate.setResetTime(resetTimeMillis); - rate.setSubscription((String) executionContext.getAttribute(ExecutionContext.ATTR_SUBSCRIPTION_ID)); - return rate; - } + () -> { + // Set the time at which the current rate limit window resets, in UTC epoch milliseconds.
+ long resetTimeMillis = getEndOfPeriod(request.timestamp(), slice.getPeriod(), TimeUnit.MILLISECONDS); + + RateLimit rate = new RateLimit(key); + rate.setCounter(0); + rate.setLimit(slice.getLimit()); + rate.setResetTime(resetTimeMillis); + rate.setSubscription((String) executionContext.getAttribute(ExecutionContext.ATTR_SUBSCRIPTION_ID)); + return rate; } ) .observeOn(RxHelper.scheduler(context)) .subscribe( - new SingleObserver() { + new SingleObserver<>() { @Override public void onSubscribe(Disposable d) {} diff --git a/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/SpikeArrestPolicyTest.java b/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/SpikeArrestPolicyTest.java new file mode 100644 index 0000000..a75d368 --- /dev/null +++ b/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/SpikeArrestPolicyTest.java @@ -0,0 +1,334 @@ +/** + * Copyright (C) 2015 The Gravitee team (http://gravitee.io) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.gravitee.policy.spike; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.lenient; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.gravitee.gateway.api.ExecutionContext; +import io.gravitee.gateway.api.Request; +import io.gravitee.gateway.api.Response; +import io.gravitee.gateway.api.http.HttpHeaders; +import io.gravitee.policy.api.PolicyChain; +import io.gravitee.policy.api.PolicyResult; +import io.gravitee.policy.spike.configuration.SpikeArrestConfiguration; +import io.gravitee.policy.spike.configuration.SpikeArrestPolicyConfiguration; +import io.gravitee.policy.spike.local.ExecutionContextStub; +import io.gravitee.policy.spike.local.LocalCacheRateLimitProvider; +import io.gravitee.repository.ratelimit.api.RateLimitService; +import io.reactivex.rxjava3.core.Single; +import io.vertx.rxjava3.core.Vertx; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Consumer; +import org.assertj.core.api.SoftAssertions; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Nested; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentCaptor; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +/** + * @author David BRASSELY (david.brassely at graviteesource.com) + * @author GraviteeSource Team + */ +@ExtendWith(MockitoExtension.class) +public class SpikeArrestPolicyTest { + + private final 
LocalCacheRateLimitProvider rateLimitService = new LocalCacheRateLimitProvider(); + + private Vertx vertx; + + @Captor + ArgumentCaptor policyResultCaptor; + + @Mock + private Request request; + + @Mock + private Response response; + + private ExecutionContext executionContext; + HttpHeaders responseHttpHeaders; + + @BeforeEach + public void init() { + vertx = Vertx.vertx(); + + executionContext = spy(new ExecutionContextStub()); + executionContext.setAttribute(ExecutionContext.ATTR_PLAN, "my-plan"); + executionContext.setAttribute(ExecutionContext.ATTR_SUBSCRIPTION_ID, "my-subscription"); + + lenient().when(executionContext.getComponent(Vertx.class)).thenReturn(vertx); + lenient().when(executionContext.getComponent(RateLimitService.class)).thenReturn(rateLimitService); + + responseHttpHeaders = HttpHeaders.create(); + lenient().when(response.headers()).thenReturn(responseHttpHeaders); + } + + @AfterEach + void tearDown() { + rateLimitService.clean(); + vertx.close().blockingAwait(); + } + + @Test + public void should_fail_when_no_service_installed() { + var policy = new SpikeArrestPolicy( + SpikeArrestPolicyConfiguration + .builder() + .spike(SpikeArrestConfiguration.builder().limit(1).periodTime(1).periodTimeUnit(TimeUnit.SECONDS).build()) + .build() + ); + + vertx.runOnContext(event -> { + // Given + var policyChain = mock(PolicyChain.class); + when(executionContext.getComponent(RateLimitService.class)).thenReturn(null); + + // When + policy.onRequest(request, response, executionContext, policyChain); + + // Then + verify(policyChain).failWith(policyResultCaptor.capture()); + SoftAssertions.assertSoftly(soft -> { + var result = policyResultCaptor.getValue(); + soft.assertThat(result.statusCode()).isEqualTo(500); + soft.assertThat(result.message()).isEqualTo("No rate-limit service has been installed."); + }); + }); + } + + @Test + public void should_add_headers_when_enabled() throws InterruptedException { + var latch = new CountDownLatch(1); + var policy = new 
SpikeArrestPolicy( + SpikeArrestPolicyConfiguration + .builder() + .addHeaders(true) + .spike(SpikeArrestConfiguration.builder().limit(10).periodTime(10).periodTimeUnit(TimeUnit.SECONDS).build()) + .build() + ); + + vertx.runOnContext(event -> + policy.onRequest( + request, + response, + executionContext, + chain( + (req, res) -> { + assertThat(responseHttpHeaders.get(SpikeArrestPolicy.X_SPIKE_ARREST_LIMIT)).isEqualTo("1"); + assertThat(responseHttpHeaders.get(SpikeArrestPolicy.X_SPIKE_ARREST_SLICE)).isEqualTo("1000ms"); + assertThat(responseHttpHeaders.get(SpikeArrestPolicy.X_SPIKE_ARREST_RESET)).isEqualTo("1000"); + latch.countDown(); + }, + policyResult -> { + fail("Unexpected failure: " + policyResult.message()); + latch.countDown(); + } + ) + ) + ); + + assertThat(latch.await(10000, TimeUnit.MILLISECONDS)).isTrue(); + } + + @Test + public void should_not_add_headers_when_disabled() throws InterruptedException { + var latch = new CountDownLatch(1); + var policy = new SpikeArrestPolicy( + SpikeArrestPolicyConfiguration + .builder() + .addHeaders(false) + .spike(SpikeArrestConfiguration.builder().limit(10).periodTime(10).periodTimeUnit(TimeUnit.SECONDS).build()) + .build() + ); + + vertx.runOnContext(event -> + policy.onRequest( + request, + response, + executionContext, + chain( + (req, res) -> { + assertThat(responseHttpHeaders.toSingleValueMap()) + .doesNotContainKey(SpikeArrestPolicy.X_SPIKE_ARREST_LIMIT) + .doesNotContainKey(SpikeArrestPolicy.X_SPIKE_ARREST_SLICE) + .doesNotContainKey(SpikeArrestPolicy.X_SPIKE_ARREST_RESET); + latch.countDown(); + }, + policyResult -> { + fail("Unexpected failure: " + policyResult.message()); + latch.countDown(); + } + ) + ) + ); + + assertThat(latch.await(10000, TimeUnit.MILLISECONDS)).isTrue(); + } + + @Test + public void should_provide_info_when_limit_exceeded() throws InterruptedException { + var latch = new CountDownLatch(2); + var policy = new SpikeArrestPolicy( + SpikeArrestPolicyConfiguration + .builder() + .spike( + 
SpikeArrestConfiguration + .builder() + .limit(1) + .dynamicLimit("0") + .periodTime(100) + .periodTimeUnit(TimeUnit.MILLISECONDS) + .build() + ) + .build() + ); + vertx.runOnContext(event -> + // Run 1st request + policy.onRequest( + request, + response, + executionContext, + chain( + (_req, _res) -> { + latch.countDown(); + + // Run 2nd request that should fail + policy.onRequest( + request, + response, + executionContext, + chain( + (req, res) -> { + fail("Should fail"); + latch.countDown(); + }, + policyResult -> { + SoftAssertions.assertSoftly(soft -> { + soft.assertThat(policyResult.statusCode()).isEqualTo(429); + soft.assertThat(policyResult.key()).isEqualTo("SPIKE_ARREST_TOO_MANY_REQUESTS"); + soft + .assertThat(policyResult.message()) + .isEqualTo("Spike limit exceeded ! You reach the limit of 1 requests per 100 ms."); + soft + .assertThat(policyResult.parameters()) + .contains( + Map.entry("slice_limit", 1L), + Map.entry("slice_period_time", 100L), + Map.entry("slice_period_unit", TimeUnit.MILLISECONDS), + Map.entry("limit", 1L), + Map.entry("period_time", 100L), + Map.entry("period_unit", TimeUnit.MILLISECONDS) + ); + }); + latch.countDown(); + } + ) + ); + }, + policyResult -> { + latch.countDown(); + fail("Unexpected failure: " + policyResult.message()); + } + ) + ) + ); + + assertThat(latch.await(10000, TimeUnit.MILLISECONDS)).isTrue(); + } + + @Nested + class WhenErrorsOccursAtRepositoryLevel { + + @BeforeEach + void setUp() { + var mockedRateLimitService = mock(RateLimitService.class); + when(mockedRateLimitService.incrementAndGet(any(), anyBoolean(), any())) + .thenReturn(Single.error(new RuntimeException("Error"))); + lenient().when(executionContext.getComponent(RateLimitService.class)).thenReturn(mockedRateLimitService); + } + + @Test + public void should_add_headers_when_enabled() throws InterruptedException { + var latch = new CountDownLatch(1); + var policy = new SpikeArrestPolicy( + SpikeArrestPolicyConfiguration + .builder() + 
.addHeaders(true) + .spike(SpikeArrestConfiguration.builder().limit(10).periodTime(10).periodTimeUnit(TimeUnit.SECONDS).build()) + .build() + ); + + vertx.runOnContext(event -> + policy.onRequest( + request, + response, + executionContext, + chain( + (req, res) -> { + assertThat(responseHttpHeaders.toSingleValueMap()) + .contains( + Map.entry(SpikeArrestPolicy.X_SPIKE_ARREST_LIMIT, "1"), + Map.entry(SpikeArrestPolicy.X_SPIKE_ARREST_SLICE, "1000ms"), + Map.entry(SpikeArrestPolicy.X_SPIKE_ARREST_RESET, "-1") + ); + latch.countDown(); + }, + policyResult -> { + fail("Unexpected failure: " + policyResult.message()); + latch.countDown(); + } + ) + ) + ); + + assertThat(latch.await(10000, TimeUnit.MILLISECONDS)).isTrue(); + } + } + + private PolicyChain chain(BiConsumer doNext, Consumer failWith) { + return new PolicyChain() { + @Override + public void doNext(Request request, Response response) { + doNext.accept(request, response); + } + + @Override + public void failWith(PolicyResult policyResult) { + failWith.accept(policyResult); + } + + @Override + public void streamFailWith(PolicyResult policyResult) {} + }; + } +} diff --git a/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/local/ExecutionContextStub.java b/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/local/ExecutionContextStub.java new file mode 100644 index 0000000..49e1e16 --- /dev/null +++ b/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/local/ExecutionContextStub.java @@ -0,0 +1,80 @@ +/** + * Copyright (C) 2015 The Gravitee team (http://gravitee.io) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.gravitee.policy.spike.local; + +import io.gravitee.el.TemplateEngine; +import io.gravitee.el.spel.SpelTemplateEngineFactory; +import io.gravitee.gateway.api.ExecutionContext; +import io.gravitee.gateway.api.Request; +import io.gravitee.gateway.api.Response; +import io.gravitee.tracing.api.Tracer; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.Map; + +public class ExecutionContextStub implements ExecutionContext { + + Map attributes = new HashMap<>(); + + @Override + public Request request() { + return null; + } + + @Override + public Response response() { + return null; + } + + @Override + public T getComponent(Class aClass) { + return null; + } + + @Override + public void setAttribute(String s, Object o) { + attributes.put(s, o); + } + + @Override + public void removeAttribute(String s) {} + + @Override + public Object getAttribute(String s) { + return attributes.get(s); + } + + @Override + public Enumeration getAttributeNames() { + return Collections.enumeration(attributes.keySet()); + } + + @Override + public Map getAttributes() { + return attributes; + } + + @Override + public TemplateEngine getTemplateEngine() { + return new SpelTemplateEngineFactory().templateEngine(); + } + + @Override + public Tracer getTracer() { + return null; + } +} diff --git a/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/local/LocalCacheRateLimitProvider.java b/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/local/LocalCacheRateLimitProvider.java new 
file mode 100644 index 0000000..26b37d8 --- /dev/null +++ b/gravitee-policy-spikearrest/src/test/java/io/gravitee/policy/spike/local/LocalCacheRateLimitProvider.java @@ -0,0 +1,44 @@ +/** + * Copyright (C) 2015 The Gravitee team (http://gravitee.io) + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.gravitee.policy.spike.local; + +import io.gravitee.repository.ratelimit.api.RateLimitService; +import io.gravitee.repository.ratelimit.model.RateLimit; +import io.reactivex.rxjava3.core.Single; +import java.io.Serializable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.function.Supplier; + +/** + * @author David BRASSELY (david.brassely at graviteesource.com) + * @author GraviteeSource Team + */ +public class LocalCacheRateLimitProvider implements RateLimitService { + + private ConcurrentMap rateLimits = new ConcurrentHashMap<>(); + + public void clean() { + rateLimits.clear(); + } + + @Override + public Single incrementAndGet(String key, long weight, boolean async, Supplier supplier) { + RateLimit rateLimit = rateLimits.computeIfAbsent(key, serializable -> supplier.get()); + rateLimit.setCounter(rateLimit.getCounter() + weight); + return Single.just(rateLimit); + } +}