Skip to content

Commit

Permalink
Adding default endpoint for Elastic Rerank (elastic#117939)
Browse files Browse the repository at this point in the history
* Adding default endpoint for Elastic Rerank

* CustomElandRerankTaskSettings -> RerankTaskSettings

* Update docs/changelog/117939.yaml
  • Loading branch information
ymao1 committed Dec 6, 2024
1 parent d08c26b commit a76a0e4
Show file tree
Hide file tree
Showing 11 changed files with 185 additions and 96 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117939.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117939
summary: Adding default endpoint for Elastic Rerank
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ public void testGet() throws IOException {

var e5Model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
assertDefaultE5Config(e5Model);

var rerankModel = getModel(ElasticsearchInternalService.DEFAULT_RERANK_ID);
assertDefaultRerankConfig(rerankModel);
}

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -117,6 +120,42 @@ private static void assertDefaultE5Config(Map<String, Object> modelConfig) {
assertDefaultChunkingSettings(modelConfig);
}

/**
 * Checks the default rerank endpoint's configuration, runs a rerank inference
 * against it with two input documents and a query, and expects two entries under
 * the {@code rerank} key of the response.
 */
@SuppressWarnings("unchecked")
public void testInferDeploysDefaultRerank() throws IOException {
    final String rerankId = ElasticsearchInternalService.DEFAULT_RERANK_ID;
    assertDefaultRerankConfig(getModel(rerankId));

    var documents = List.of("Hello World", "Goodnight moon");
    var rerankQuery = "but why";
    // NOTE(review): the generous timeout presumably leaves room for the model deployment to start - confirm.
    var requestParams = Map.of("timeout", "120s");

    var response = infer(rerankId, TaskType.RERANK, documents, rerankQuery, requestParams);
    var rankedDocs = (List<Map<String, Object>>) response.get("rerank");
    assertThat(response.toString(), rankedDocs, hasSize(2));
}

/**
 * Asserts that the given endpoint configuration matches the default rerank endpoint:
 * its id, service and task type, the {@code .rerank-v1} model with one thread,
 * adaptive allocations enabled between 0 and 32, no chunking settings, and
 * {@code return_documents} set to {@code true} in the task settings.
 *
 * @param modelConfig the endpoint configuration as a map; its string form is used
 *                    as the failure description for every assertion
 */
@SuppressWarnings("unchecked")
private static void assertDefaultRerankConfig(Map<String, Object> modelConfig) {
    final String description = modelConfig.toString();

    // Endpoint identity.
    assertEquals(description, ElasticsearchInternalService.DEFAULT_RERANK_ID, modelConfig.get("inference_id"));
    assertEquals(description, ElasticsearchInternalService.NAME, modelConfig.get("service"));
    assertEquals(description, TaskType.RERANK.toString(), modelConfig.get("task_type"));

    // Service settings, including the adaptive allocation bounds.
    var settings = (Map<String, Object>) modelConfig.get("service_settings");
    assertThat(description, settings.get("model_id"), is(".rerank-v1"));
    assertEquals(description, 1, settings.get("num_threads"));

    var expectedAllocations = Map.of("enabled", true, "min_number_of_allocations", 0, "max_number_of_allocations", 32);
    var actualAllocations = (Map<String, Object>) settings.get("adaptive_allocations");
    assertThat(description, actualAllocations, Matchers.is(expectedAllocations));

    // The rerank endpoint is expected to carry task settings but no chunking settings.
    var chunking = (Map<String, Object>) modelConfig.get("chunking_settings");
    assertNull(chunking);
    var rerankTaskSettings = (Map<String, Object>) modelConfig.get("task_settings");
    assertThat(description, rerankTaskSettings, Matchers.is(Map.of("return_documents", true)));
}

@SuppressWarnings("unchecked")
private static void assertDefaultChunkingSettings(Map<String, Object> modelConfig) {
var chunkingSettings = (Map<String, Object>) modelConfig.get("chunking_settings");
Expand Down Expand Up @@ -151,6 +190,7 @@ public void onFailure(Exception exception) {
var request = createInferenceRequest(
Strings.format("_inference/%s", ElasticsearchInternalService.DEFAULT_ELSER_ID),
inputs,
null,
queryParams
);
client().performRequestAsync(request, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ private List<Object> getInternalAsList(String endpoint) throws IOException {

/**
 * Runs inference against {@code _inference/<modelId>} with no query parameters
 * and returns the result of {@code inferInternal}.
 *
 * @param modelId the inference endpoint id used to build the request path
 * @param input   the input strings forwarded to {@code inferInternal}
 * @throws IOException declared by this method; raised by the underlying call
 */
protected Map<String, Object> infer(String modelId, List<String> input) throws IOException {
var endpoint = Strings.format("_inference/%s", modelId);
// Pre-change line from the rendered diff: three-argument inferInternal.
return inferInternal(endpoint, input, Map.of());
// Post-change line: passes null for the query argument.
return inferInternal(endpoint, input, null, Map.of());
}

protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskType taskType, List<String> input) throws Exception {
Expand All @@ -344,7 +344,7 @@ protected Deque<ServerSentEvent> streamInferOnMockService(String modelId, TaskTy
private Deque<ServerSentEvent> callAsync(String endpoint, List<String> input) throws Exception {
var responseConsumer = new AsyncInferenceResponseConsumer();
var request = new Request("POST", endpoint);
request.setJsonEntity(jsonBody(input));
request.setJsonEntity(jsonBody(input, null));
request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build());
var latch = new CountDownLatch(1);
client().performRequestAsync(request, new ResponseListener() {
Expand All @@ -364,33 +364,60 @@ public void onFailure(Exception exception) {

/**
 * Runs inference against {@code _inference/<taskType>/<modelId>} with no query
 * parameters and returns the result of {@code inferInternal}.
 *
 * @param modelId  the inference endpoint id used to build the request path
 * @param taskType the task type placed before the id in the request path
 * @param input    the input strings forwarded to {@code inferInternal}
 * @throws IOException declared by this method; raised by the underlying call
 */
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input) throws IOException {
var endpoint = Strings.format("_inference/%s/%s", taskType, modelId);
// Pre-change line from the rendered diff: three-argument inferInternal.
return inferInternal(endpoint, input, Map.of());
// Post-change line: passes null for the query argument.
return inferInternal(endpoint, input, null, Map.of());
}

/**
 * Runs inference against {@code _inference/<taskType>/<modelId>?error_trace},
 * forwarding the caller's query parameters to {@code inferInternal}.
 *
 * @param modelId         the inference endpoint id used to build the request path
 * @param taskType        the task type placed before the id in the request path
 * @param input           the input strings forwarded to {@code inferInternal}
 * @param queryParameters query parameters forwarded to {@code inferInternal}
 * @throws IOException declared by this method; raised by the underlying call
 */
protected Map<String, Object> infer(String modelId, TaskType taskType, List<String> input, Map<String, String> queryParameters)
throws IOException {
var endpoint = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
// Pre-change line from the rendered diff: three-argument inferInternal.
return inferInternal(endpoint, input, queryParameters);
// Post-change line: passes null for the query argument.
return inferInternal(endpoint, input, null, queryParameters);
}

protected Request createInferenceRequest(String endpoint, List<String> input, Map<String, String> queryParameters) {
/**
 * Runs inference against {@code _inference/<taskType>/<modelId>?error_trace},
 * forwarding a query string and the caller's query parameters to
 * {@code inferInternal}.
 *
 * @param modelId         the inference endpoint id used to build the request path
 * @param taskType        the task type placed before the id in the request path
 * @param input           the input strings forwarded to {@code inferInternal}
 * @param query           the query string forwarded to {@code inferInternal}
 * @param queryParameters query parameters forwarded to {@code inferInternal}
 * @return the map returned by {@code inferInternal}
 * @throws IOException declared by this method; raised by the underlying call
 */
protected Map<String, Object> infer(
    String modelId,
    TaskType taskType,
    List<String> input,
    String query,
    Map<String, String> queryParameters
) throws IOException {
    final String path = Strings.format("_inference/%s/%s?error_trace", taskType, modelId);
    return inferInternal(path, input, query, queryParameters);
}

/**
 * Builds a POST request for the given endpoint whose JSON body is produced by
 * {@code jsonBody}, adding the query parameters when the map is not empty.
 *
 * @param endpoint        the request path
 * @param input           the input strings passed to {@code jsonBody}
 * @param query           optional query string passed to {@code jsonBody}; may be null
 * @param queryParameters request parameters; must not be null because
 *                        {@code isEmpty()} is called on it
 * @return the constructed request
 */
protected Request createInferenceRequest(
String endpoint,
List<String> input,
@Nullable String query,
Map<String, String> queryParameters
) {
var request = new Request("POST", endpoint);
// Pre-change line from the rendered diff: single-argument jsonBody.
request.setJsonEntity(jsonBody(input));
// Post-change line: the optional query is included in the body.
request.setJsonEntity(jsonBody(input, query));
if (queryParameters.isEmpty() == false) {
request.addParameters(queryParameters);
}
return request;
}

private Map<String, Object> inferInternal(String endpoint, List<String> input, Map<String, String> queryParameters) throws IOException {
var request = createInferenceRequest(endpoint, input, queryParameters);
/**
 * Builds the inference request via {@code createInferenceRequest}, performs it
 * synchronously, asserts the response with {@code assertOkOrCreated} and returns
 * the response entity as a map.
 *
 * @param endpoint        the request path
 * @param input           the input strings for the request body
 * @param query           optional query string for the request body; may be null
 * @param queryParameters request parameters forwarded to {@code createInferenceRequest}
 * @return the response entity converted by {@code entityAsMap}
 * @throws IOException declared by this method; raised by the underlying calls
 */
private Map<String, Object> inferInternal(
    String endpoint,
    List<String> input,
    @Nullable String query,
    Map<String, String> queryParameters
) throws IOException {
    var response = client().performRequest(createInferenceRequest(endpoint, input, query, queryParameters));
    assertOkOrCreated(response);
    return entityAsMap(response);
}

private String jsonBody(List<String> input) {
var bodyBuilder = new StringBuilder("{\"input\": [");
private String jsonBody(List<String> input, @Nullable String query) {
final StringBuilder bodyBuilder = new StringBuilder("{");

if (query != null) {
bodyBuilder.append("\"query\":\"").append(query).append("\",");
}

bodyBuilder.append("\"input\": [");
for (var in : input) {
bodyBuilder.append('"').append(in).append('"').append(',');
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public void testCRUD() throws IOException {
}

var getAllModels = getAllModels();
int numModels = 11;
int numModels = 12;
assertThat(getAllModels, hasSize(numModels));

var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
Expand Down Expand Up @@ -328,7 +328,7 @@ public void testSupportedStream() throws Exception {
}

/**
 * Asserts that listing all endpoints for a task type returns an empty result.
 */
public void testGetZeroModels() throws IOException {
// Pre-change line from the rendered diff: queried RERANK.
var models = getModels("_all", TaskType.RERANK);
// Post-change line: queries COMPLETION instead.
// NOTE(review): presumably because a default rerank endpoint now always exists - confirm.
var models = getModels("_all", TaskType.COMPLETION);
assertThat(models, empty());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@
import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceSparseEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandInternalTextEmbeddingServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticRerankerServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.ElserMlNodeTaskSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.MultilingualE5SmallInternalServiceSettings;
import org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiSecretSettings;
Expand Down Expand Up @@ -510,9 +510,7 @@ private static void addCustomElandWriteables(final List<NamedWriteableRegistry.E
CustomElandInternalTextEmbeddingServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, CustomElandRerankTaskSettings.NAME, CustomElandRerankTaskSettings::new)
);
namedWriteables.add(new NamedWriteableRegistry.Entry(TaskSettings.class, RerankTaskSettings.NAME, RerankTaskSettings::new));
}

private static void addAnthropicNamedWritables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import java.util.HashMap;
import java.util.Map;

import static org.elasticsearch.xpack.inference.services.elasticsearch.CustomElandRerankTaskSettings.RETURN_DOCUMENTS;
import static org.elasticsearch.xpack.inference.services.elasticsearch.RerankTaskSettings.RETURN_DOCUMENTS;

public class CustomElandRerankModel extends CustomElandModel {

Expand All @@ -26,7 +26,7 @@ public CustomElandRerankModel(
TaskType taskType,
String service,
CustomElandInternalServiceSettings serviceSettings,
CustomElandRerankTaskSettings taskSettings
RerankTaskSettings taskSettings
) {
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction;
Expand All @@ -22,9 +21,9 @@ public ElasticRerankerModel(
TaskType taskType,
String service,
ElasticRerankerServiceSettings serviceSettings,
ChunkingSettings chunkingSettings
RerankTaskSettings taskSettings
) {
super(inferenceEntityId, taskType, service, serviceSettings, chunkingSettings);
super(inferenceEntityId, taskType, service, serviceSettings, taskSettings);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
public static final int EMBEDDING_MAX_BATCH_SIZE = 10;
public static final String DEFAULT_ELSER_ID = ".elser-2-elasticsearch";
public static final String DEFAULT_E5_ID = ".multilingual-e5-small-elasticsearch";
public static final String DEFAULT_RERANK_ID = ".rerank-v1-elasticsearch";

private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.RERANK,
Expand Down Expand Up @@ -225,7 +226,7 @@ public void parseRequestConfig(
)
);
} else if (RERANKER_ID.equals(modelId)) {
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, chunkingSettings, modelListener);
rerankerCase(inferenceEntityId, taskType, config, serviceSettingsMap, taskSettingsMap, modelListener);
} else {
customElandCase(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, chunkingSettings, modelListener);
}
Expand Down Expand Up @@ -308,7 +309,7 @@ private static CustomElandModel createCustomElandModel(
taskType,
NAME,
elandServiceSettings(serviceSettings, context),
CustomElandRerankTaskSettings.fromMap(taskSettings)
RerankTaskSettings.fromMap(taskSettings)
);
default -> throw new ElasticsearchStatusException(TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
};
Expand All @@ -331,7 +332,7 @@ private void rerankerCase(
TaskType taskType,
Map<String, Object> config,
Map<String, Object> serviceSettingsMap,
ChunkingSettings chunkingSettings,
Map<String, Object> taskSettingsMap,
ActionListener<Model> modelListener
) {

Expand All @@ -346,7 +347,7 @@ private void rerankerCase(
taskType,
NAME,
new ElasticRerankerServiceSettings(esServiceSettingsBuilder.build()),
chunkingSettings
RerankTaskSettings.fromMap(taskSettingsMap)
)
);
}
Expand Down Expand Up @@ -512,6 +513,14 @@ public Model parsePersistedConfig(String inferenceEntityId, TaskType taskType, M
ElserMlNodeTaskSettings.DEFAULT,
chunkingSettings
);
} else if (modelId.equals(RERANKER_ID)) {
return new ElasticRerankerModel(
inferenceEntityId,
taskType,
NAME,
new ElasticRerankerServiceSettings(ElasticsearchInternalServiceSettings.fromPersistedMap(serviceSettingsMap)),
RerankTaskSettings.fromMap(taskSettingsMap)
);
} else {
return createCustomElandModel(
inferenceEntityId,
Expand Down Expand Up @@ -653,21 +662,23 @@ public void inferRerank(
) {
var request = buildInferenceRequest(model.mlNodeDeploymentId(), new TextSimilarityConfigUpdate(query), inputs, inputType, timeout);

var modelSettings = (CustomElandRerankTaskSettings) model.getTaskSettings();
var requestSettings = CustomElandRerankTaskSettings.fromMap(requestTaskSettings);
Boolean returnDocs = CustomElandRerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
var returnDocs = Boolean.TRUE;
if (model.getTaskSettings() instanceof RerankTaskSettings modelSettings) {
var requestSettings = RerankTaskSettings.fromMap(requestTaskSettings);
returnDocs = RerankTaskSettings.of(modelSettings, requestSettings).returnDocuments();
}

Function<Integer, String> inputSupplier = returnDocs == Boolean.TRUE ? inputs::get : i -> null;

client.execute(
InferModelAction.INSTANCE,
request,
listener.delegateFailureAndWrap(
(l, inferenceResult) -> l.onResponse(
textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier)
)
)
ActionListener<InferModelAction.Response> mlResultsListener = listener.delegateFailureAndWrap(
(l, inferenceResult) -> l.onResponse(textSimilarityResultsToRankedDocs(inferenceResult.getInferenceResults(), inputSupplier))
);

var maybeDeployListener = mlResultsListener.delegateResponse(
(l, exception) -> maybeStartDeployment(model, exception, request, mlResultsListener)
);

client.execute(InferModelAction.INSTANCE, request, maybeDeployListener);
}

public void chunkedInfer(
Expand Down Expand Up @@ -811,7 +822,8 @@ private RankedDocsResults textSimilarityResultsToRankedDocs(
/**
 * Returns the ids of this service's default endpoints, each paired with its
 * task type and this service instance.
 */
public List<DefaultConfigId> defaultConfigIds() {
return List.of(
new DefaultConfigId(DEFAULT_ELSER_ID, TaskType.SPARSE_EMBEDDING, this),
// Pre-change line from the rendered diff: E5 was the last entry.
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this)
// Post-change lines: E5 followed by the new default rerank entry.
new DefaultConfigId(DEFAULT_E5_ID, TaskType.TEXT_EMBEDDING, this),
new DefaultConfigId(DEFAULT_RERANK_ID, TaskType.RERANK, this)
);
}

Expand Down Expand Up @@ -903,12 +915,19 @@ private List<Model> defaultConfigs(boolean useLinuxOptimizedModel) {
),
ChunkingSettingsBuilder.DEFAULT_SETTINGS
);
return List.of(defaultElser, defaultE5);
var defaultRerank = new ElasticRerankerModel(
DEFAULT_RERANK_ID,
TaskType.RERANK,
NAME,
new ElasticRerankerServiceSettings(null, 1, RERANKER_ID, new AdaptiveAllocationsSettings(Boolean.TRUE, 0, 32)),
RerankTaskSettings.DEFAULT_SETTINGS
);
return List.of(defaultElser, defaultE5, defaultRerank);
}

/**
 * Reports whether the given inference id equals one of this service's default
 * endpoint ids. A null id yields {@code false} because the constants are the
 * receivers of {@code equals}.
 */
@Override
boolean isDefaultId(String inferenceId) {
// Pre-change line from the rendered diff: only ELSER and E5 were checked.
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId);
// Post-change line: the default rerank id is checked as well.
return DEFAULT_ELSER_ID.equals(inferenceId) || DEFAULT_E5_ID.equals(inferenceId) || DEFAULT_RERANK_ID.equals(inferenceId);
}

static EmbeddingRequestChunker.EmbeddingType embeddingTypeFromTaskTypeAndSettings(
Expand Down
Loading

0 comments on commit a76a0e4

Please sign in to comment.