Test concurrent inference on default endpoint
davidkyle committed Nov 15, 2024
1 parent 7d98561 commit 4605650
Showing 2 changed files with 47 additions and 1 deletion.
@@ -8,6 +8,9 @@
package org.elasticsearch.xpack.inference;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService;
@@ -16,9 +19,12 @@
import org.junit.Before;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;

import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.oneOf;
@@ -110,4 +116,39 @@ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
Matchers.is(Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32))
);
}

public void testMultipleInferencesTriggeringDownloadAndDeploy() throws InterruptedException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());

int numParallelRequests = 4;
var latch = new CountDownLatch(numParallelRequests);
// listener callbacks may run concurrently on the client's I/O threads, so the error list must be thread safe
var errors = Collections.synchronizedList(new ArrayList<Exception>());

var listener = new ResponseListener() {
@Override
public void onSuccess(Response response) {
latch.countDown();
}

@Override
public void onFailure(Exception exception) {
errors.add(exception);
latch.countDown();
}
};

var inputs = List.of("Hello World", "Goodnight moon");
var queryParams = Map.of("timeout", "120s");
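// Send the requests concurrently; together they exercise the download and deployment of the default endpoint's model.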
for (int i = 0; i < numParallelRequests; i++) {
var request = createInferenceRequest(
Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
inputs,
queryParams
);
client().performRequestAsync(request, listener);
}

latch.await();
assertThat(errors.toString(), errors, empty());
}
}
@@ -373,12 +373,17 @@ protected Map<String, Object> infer(String modelId, TaskType taskType, List<Stri
return inferInternal(endpoint, input, queryParameters);
}

private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
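// Request construction extracted from inferInternal so that tests can build a request and dispatch it asynchronously themselves.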
protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
var request = new Request("POST", endpoint);
request.setJsonEntity(jsonBody(input));
if (queryParameters.isEmpty() == false) {
request.addParameters(queryParameters);
}
return request;
}

private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
var request = createInferenceRequest(endpoint, input, queryParameters);
var response = client().performRequest(request);
assertOkOrCreated(response);
return entityAsMap(response);
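Taken together, the two hunks leave the existing synchronous inferInternal path in place while letting a test build the Request itself and send it asynchronously, which is what the new concurrency test above does. A minimal sketch of that asynchronous pattern, assuming it runs inside a test extending the base REST test class (so client(), createInferenceRequest, ResponseListener and CountDownLatch are all in scope) and using placeholder input text and empty query parameters:

var request = createInferenceRequest(
    Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
    List.of("some placeholder text"),
    Map.of()
);
var latch = new CountDownLatch(1);
client().performRequestAsync(request, new ResponseListener() {
    @Override
    public void onSuccess(Response response) {
        latch.countDown(); // release the waiting test thread on success
    }

    @Override
    public void onFailure(Exception exception) {
        latch.countDown(); // record or assert on the failure before releasing the test thread
    }
});
latch.await(); // block the test until the asynchronous response (or failure) arrives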
