diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java index 1fef26989d845..2251f140852c1 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/DefaultEndPointsIT.java @@ -8,6 +8,9 @@ package org.elasticsearch.xpack.inference; import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseListener; +import org.elasticsearch.common.Strings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; @@ -16,9 +19,12 @@ import org.junit.Before; import java.io.IOException; +import java.util.ArrayList; import java.util.List; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -110,4 +116,39 @@ private static void assertDefaultE5Config(Map modelConfig) { Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32)) ); } + + public void testMultipleInferencesTiggeringDownloadAndDeploy() throws InterruptedException { + assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled()); + + int numParallelRequests = 4; + var latch = new CountDownLatch(numParallelRequests); + var errors = new ArrayList(); + + var listener = new ResponseListener() { + @Override + public void onSuccess(Response response) { + latch.countDown(); + } + + @Override + public void onFailure(Exception exception) { + errors.add(exception); + latch.countDown(); + } + }; + + var inputs = List.of("Hello World", "Goodnight moon"); + var queryParams = Map.of("timeout", "120s"); + for (int i = 0; i < numParallelRequests; i++) { + var request = createInferenceRequest( + Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID), + inputs, + queryParams + ); + client().performRequestAsync(request, listener); + } + + latch.await(); + assertThat(errors.toString(), errors, empty()); + } } diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 6790b9bb14c5a..4e32ef99d06dd 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -373,12 +373,17 @@ protected Map infer(String modelId, TaskType taskType, List inferInternal(String endpoint, List input, Map queryParameters) throws IOException { + protected Request createInferenceRequest(String endpoint, List input, Map queryParameters) { var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input)); if (queryParameters.isEmpty() == false) { request.addParameters(queryParameters); } + return request; + } + + private Map inferInternal(String endpoint, List input, Map queryParameters) throws IOException { + var request = createInferenceRequest(endpoint, input, queryParameters); var response = client().performRequest(request); assertOkOrCreated(response); return entityAsMap(response);