Skip to content

Commit

Permalink
[ML] Enable built-in Inference Endpoints and default for Semantic Text (
Browse files Browse the repository at this point in the history
elastic#116931)

Adds built-in inference endpoints for the ELSER (.elser-2-elasticsearch)
and multilingual-e5-small models (.multilingual-e5-small-elasticsearch).
The semantic text inference ID field now defaults to .elser-2-elasticsearch.
  • Loading branch information
davidkyle authored Nov 18, 2024
1 parent 366fa74 commit 9790cc4
Show file tree
Hide file tree
Showing 13 changed files with 39 additions and 98 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/116931.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 116931
summary: Enable built-in Inference Endpoints and default for Semantic Text
area: "Machine Learning"
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ public enum FeatureFlag {
TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null),
FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null),
SUB_OBJECTS_AUTO_ENABLED("es.sub_objects_auto_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
INFERENCE_DEFAULT_ELSER("es.inference_default_elser_feature_flag_enabled=true", Version.fromString("8.16.0"), null),
ML_SCALE_FROM_ZERO("es.ml_scale_from_zero_feature_flag_enabled=true", Version.fromString("8.16.0"), null);

public final String systemProperty;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ public void tearDown() throws Exception {

@SuppressWarnings("unchecked")
public void testInferDeploysDefaultElser() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_ELSER_ID);
assertDefaultElserConfig(model);

Expand Down Expand Up @@ -78,7 +77,6 @@ private static void assertDefaultElserConfig(Map<String, Object> modelConfig) {

@SuppressWarnings("unchecked")
public void testInferDeploysDefaultE5() throws IOException {
assumeTrue("Default config requires a feature flag", DefaultElserFeatureFlag.isEnabled());
var model = getModel(ElasticsearchInternalService.DEFAULT_E5_ID);
assertDefaultE5Config(model);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,18 @@ public void testCRUD() throws IOException {
}

var getAllModels = getAllModels();
int numModels = DefaultElserFeatureFlag.isEnabled() ? 11 : 9;
int numModels = 11;
assertThat(getAllModels, hasSize(numModels));

var getSparseModels = getModels("_all", TaskType.SPARSE_EMBEDDING);
int numSparseModels = DefaultElserFeatureFlag.isEnabled() ? 6 : 5;
int numSparseModels = 6;
assertThat(getSparseModels, hasSize(numSparseModels));
for (var sparseModel : getSparseModels) {
assertEquals("sparse_embedding", sparseModel.get("task_type"));
}

var getDenseModels = getModels("_all", TaskType.TEXT_EMBEDDING);
int numDenseModels = DefaultElserFeatureFlag.isEnabled() ? 5 : 4;
int numDenseModels = 5;
assertThat(getDenseModels, hasSize(numDenseModels));
for (var denseModel : getDenseModels) {
assertEquals("text_embedding", denseModel.get("task_type"));
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.elasticsearch.xpack.inference.rank.random.RandomRankRetrieverBuilder;
import org.elasticsearch.xpack.inference.rank.textsimilarity.TextSimilarityRankRetrieverBuilder;

import java.util.HashSet;
import java.util.Set;

/**
Expand All @@ -24,16 +23,14 @@ public class InferenceFeatures implements FeatureSpecification {

@Override
public Set<NodeFeature> getFeatures() {
var features = new HashSet<NodeFeature>();
features.add(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED);
features.add(RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED);
features.add(SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID);
features.add(SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS);
features.add(TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED);
if (DefaultElserFeatureFlag.isEnabled()) {
features.add(SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2);
}
return Set.copyOf(features);
return Set.of(
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_RETRIEVER_SUPPORTED,
RandomRankRetrieverBuilder.RANDOM_RERANKER_RETRIEVER_SUPPORTED,
SemanticTextFieldMapper.SEMANTIC_TEXT_SEARCH_INFERENCE_ID,
SemanticQueryBuilder.SEMANTIC_TEXT_INNER_HITS,
SemanticTextFieldMapper.SEMANTIC_TEXT_DEFAULT_ELSER_2,
TextSimilarityRankRetrieverBuilder.TEXT_SIMILARITY_RERANKER_COMPOSITION_SUPPORTED
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,8 @@ public Collection<?> createComponents(PluginServices services) {
// reference correctly
var registry = new InferenceServiceRegistry(inferenceServices, factoryContext);
registry.init(services.client());
if (DefaultElserFeatureFlag.isEnabled()) {
for (var service : registry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
}
for (var service : registry.getServices().values()) {
service.defaultConfigIds().forEach(modelRegistry::addDefaultIds);
}
inferenceServiceRegistry.set(registry);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults;
import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -111,16 +110,12 @@ public static class Builder extends FieldMapper.Builder {
INFERENCE_ID_FIELD,
false,
mapper -> ((SemanticTextFieldType) mapper.fieldType()).inferenceId,
DefaultElserFeatureFlag.isEnabled() ? DEFAULT_ELSER_2_INFERENCE_ID : null
DEFAULT_ELSER_2_INFERENCE_ID
).addValidator(v -> {
if (Strings.isEmpty(v)) {
// If the default ELSER feature flag is enabled, the only way we get here is if the user explicitly sets the param to an
// empty value. However, if the feature flag is disabled, we can get here if the user didn't set the param.
// Adjust the error message appropriately.
String message = DefaultElserFeatureFlag.isEnabled()
? "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty"
: "[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must be specified";
throw new IllegalArgumentException(message);
throw new IllegalArgumentException(
"[" + INFERENCE_ID_FIELD + "] on mapper [" + leafName() + "] of type [" + CONTENT_TYPE + "] must not be empty"
);
}
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import org.elasticsearch.rest.ServerlessScope;
import org.elasticsearch.rest.action.RestToXContentListener;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;

import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

Expand Down Expand Up @@ -69,11 +66,6 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient

@Override
public Set<String> supportedCapabilities() {
Set<String> capabilities = new HashSet<>();
if (DefaultElserFeatureFlag.isEnabled()) {
capabilities.add(DEFAULT_ELSER_2_CAPABILITY);
}

return Collections.unmodifiableSet(capabilities);
return Set.of(DEFAULT_ELSER_2_CAPABILITY);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import org.elasticsearch.xpack.core.ml.inference.TrainedModelPrefixStrings;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.InferenceConfigUpdate;
import org.elasticsearch.xpack.core.ml.utils.MlPlatformArchitecturesUtil;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.InferencePlugin;

import java.io.IOException;
Expand Down Expand Up @@ -296,11 +295,6 @@ protected void maybeStartDeployment(
InferModelAction.Request request,
ActionListener<InferModelAction.Response> listener
) {
if (DefaultElserFeatureFlag.isEnabled() == false) {
listener.onFailure(e);
return;
}

if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
client.execute(InferModelAction.INSTANCE, request, listener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.DefaultElserFeatureFlag;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.model.TestModel;
import org.junit.AssumptionViolatedException;
Expand Down Expand Up @@ -103,9 +102,6 @@ protected Collection<? extends Plugin> getPlugins() {
@Override
protected void minimalMapping(XContentBuilder b) throws IOException {
b.field("type", "semantic_text");
if (DefaultElserFeatureFlag.isEnabled() == false) {
b.field("inference_id", "test_model");
}
}

@Override
Expand Down Expand Up @@ -175,9 +171,7 @@ public void testDefaults() throws Exception {
DocumentMapper mapper = mapperService.documentMapper();
assertEquals(Strings.toString(fieldMapping), mapper.mappingSource().toString());
assertSemanticTextField(mapperService, fieldName, false);
if (DefaultElserFeatureFlag.isEnabled()) {
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID);
}
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, DEFAULT_ELSER_2_INFERENCE_ID);

ParsedDocument doc1 = mapper.parse(source(this::writeField));
List<IndexableField> fields = doc1.rootDoc().getFields("field");
Expand Down Expand Up @@ -211,15 +205,13 @@ public void testSetInferenceEndpoints() throws IOException {
assertSerialization.accept(fieldMapping, mapperService);
}
{
if (DefaultElserFeatureFlag.isEnabled()) {
final XContentBuilder fieldMapping = fieldMapping(
b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
);
final MapperService mapperService = createMapperService(fieldMapping);
assertSemanticTextField(mapperService, fieldName, false);
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId);
assertSerialization.accept(fieldMapping, mapperService);
}
final XContentBuilder fieldMapping = fieldMapping(
b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, searchInferenceId)
);
final MapperService mapperService = createMapperService(fieldMapping);
assertSemanticTextField(mapperService, fieldName, false);
assertInferenceEndpoints(mapperService, fieldName, DEFAULT_ELSER_2_INFERENCE_ID, searchInferenceId);
assertSerialization.accept(fieldMapping, mapperService);
}
{
final XContentBuilder fieldMapping = fieldMapping(
Expand All @@ -246,26 +238,18 @@ public void testInvalidInferenceEndpoints() {
);
}
{
final String expectedMessage = DefaultElserFeatureFlag.isEnabled()
? "[inference_id] on mapper [field] of type [semantic_text] must not be empty"
: "[inference_id] on mapper [field] of type [semantic_text] must be specified";
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(INFERENCE_ID_FIELD, "")))
);
assertThat(e.getMessage(), containsString(expectedMessage));
assertThat(e.getMessage(), containsString("[inference_id] on mapper [field] of type [semantic_text] must not be empty"));
}
{
if (DefaultElserFeatureFlag.isEnabled()) {
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "")))
);
assertThat(
e.getMessage(),
containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty")
);
}
Exception e = expectThrows(
MapperParsingException.class,
() -> createMapperService(fieldMapping(b -> b.field("type", "semantic_text").field(SEARCH_INFERENCE_ID_FIELD, "")))
);
assertThat(e.getMessage(), containsString("[search_inference_id] on mapper [field] of type [semantic_text] must not be empty"));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ setup:
---
"Calculates embeddings using the default ELSER 2 endpoint":
- requires:
reason: "default ELSER 2 inference ID is behind a feature flag"
reason: "default ELSER 2 inference ID is enabled via a capability"
test_runner_features: [capabilities]
capabilities:
- method: GET
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -843,7 +843,7 @@ setup:
---
"Query a field that uses the default ELSER 2 endpoint":
- requires:
reason: "default ELSER 2 inference ID is behind a feature flag"
reason: "default ELSER 2 inference ID is enabled via a capability"
test_runner_features: [capabilities]
capabilities:
- method: GET
Expand Down

0 comments on commit 9790cc4

Please sign in to comment.