diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index c49ff1b1f0d29..9f6c1229dc239 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -134,6 +134,7 @@ static TransportVersion def(int id) { public static final TransportVersion NODE_INFO_REQUEST_SIMPLIFIED = def(8_510_00_0); public static final TransportVersion NESTED_KNN_VECTOR_QUERY_V = def(8_511_00_0); public static final TransportVersion ML_PACKAGE_LOADER_PLATFORM_ADDED = def(8_512_00_0); + public static final TransportVersion ELSER_SERVICE_MODEL_VERSION_ADDED = def(8_513_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java index f8e8584a6a382..331db82df9fe6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeService.java @@ -26,6 +26,7 @@ import java.util.List; import java.util.Map; +import java.util.Set; import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED; import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull; @@ -35,7 +36,10 @@ public class ElserMlNodeService implements InferenceService { public static final String NAME = "elser_mlnode"; - private static final String ELSER_V1_MODEL = ".elser_model_1"; + static final String ELSER_V1_MODEL = ".elser_model_1"; + // Default non platform specific v2 model + static final String ELSER_V2_MODEL = ".elser_model_2"; + static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64"; public static ElserMlNodeModel parseConfig( boolean throwOnUnknownFields, @@ -106,7 +110,10 @@ public void start(Model model, ActionListener listener) { var elserModel = (ElserMlNodeModel) model; var serviceSettings = elserModel.getServiceSettings(); - var startRequest = new StartTrainedModelDeploymentAction.Request(ELSER_V1_MODEL, model.getConfigurations().getModelId()); + var startRequest = new StartTrainedModelDeploymentAction.Request( + serviceSettings.getModelVariant(), + model.getConfigurations().getModelId() + ); startRequest.setNumberOfAllocations(serviceSettings.getNumAllocations()); startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads()); startRequest.setWaitForState(STARTED); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java index 42cb491c76204..2a99001df28b7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java @@ -20,15 +20,24 @@ import java.io.IOException; import java.util.Map; import java.util.Objects; +import java.util.Set; public class ElserMlNodeServiceSettings implements ServiceSettings { public static final String NAME = "elser_mlnode_service_settings"; public static final String NUM_ALLOCATIONS = "num_allocations"; public static final String NUM_THREADS = "num_threads"; + public static final String MODEL_VERSION = "model_version"; + + public static Set VALID_ELSER_MODELS = Set.of( + ElserMlNodeService.ELSER_V1_MODEL, + ElserMlNodeService.ELSER_V2_MODEL, + ElserMlNodeService.ELSER_V2_MODEL_LINUX_X86 + ); private final int numAllocations; private final int numThreads; + private final String modelVariant; /** * Parse the Elser service setting from map and validate the setting values. @@ -61,21 +70,34 @@ public static ElserMlNodeServiceSettings fromMap(Map map) { validationException.addValidationError(mustBeAPositiveNumberError(NUM_THREADS, numThreads)); } + String version = MapParsingUtils.removeAsType(map, MODEL_VERSION, String.class); + if (version != null && VALID_ELSER_MODELS.contains(version) == false) { + validationException.addValidationError("unknown ELSER model version [" + version + "]"); + } else { + version = ElserMlNodeService.ELSER_V2_MODEL; + } + if (validationException.validationErrors().isEmpty() == false) { throw validationException; } - return new ElserMlNodeServiceSettings(numAllocations, numThreads); + return new ElserMlNodeServiceSettings(numAllocations, numThreads, version); } - public ElserMlNodeServiceSettings(int numAllocations, int numThreads) { + public ElserMlNodeServiceSettings(int numAllocations, int numThreads, String variant) { this.numAllocations = numAllocations; this.numThreads = numThreads; + this.modelVariant = variant; } public ElserMlNodeServiceSettings(StreamInput in) throws IOException { numAllocations = in.readVInt(); numThreads = in.readVInt(); + if (in.getTransportVersion().onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED)) { + modelVariant = in.readString(); + } else { + modelVariant = ElserMlNodeService.ELSER_V1_MODEL; + } } public int getNumAllocations() { @@ -86,11 +108,16 @@ public int getNumThreads() { return numThreads; } + public String getModelVariant() { + return modelVariant; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(NUM_ALLOCATIONS, numAllocations); builder.field(NUM_THREADS, numThreads); + builder.field(MODEL_VERSION, modelVariant); builder.endObject(); return builder; } @@ -109,11 +136,14 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeVInt(numAllocations); out.writeVInt(numThreads); + if (out.getTransportVersion().onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED)) { + out.writeString(modelVariant); + } } @Override public int hashCode() { - return Objects.hash(numAllocations, numThreads); + return Objects.hash(numAllocations, numThreads, modelVariant); } @Override @@ -121,7 +151,7 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; ElserMlNodeServiceSettings that = (ElserMlNodeServiceSettings) o; - return numAllocations == that.numAllocations && numThreads == that.numThreads; + return numAllocations == that.numAllocations && numThreads == that.numThreads && Objects.equals(modelVariant, that.modelVariant); } private static String mustBeAPositiveNumberError(String settingName, int value) { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java index 5ffc2347b63e6..102fe627968b5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.test.AbstractWireSerializingTestCase; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import static org.hamcrest.Matchers.containsString; @@ -19,14 +20,31 @@ public class ElserMlNodeServiceSettingsTests extends AbstractWireSerializingTestCase { public static ElserMlNodeServiceSettings createRandom() { - return new ElserMlNodeServiceSettings(randomIntBetween(1, 4), randomIntBetween(1, 2)); + return new ElserMlNodeServiceSettings(randomIntBetween(1, 4), randomIntBetween(1, 2), + randomFrom(ElserMlNodeServiceSettings.VALID_ELSER_MODELS)); } - public void testFromMap() { + public void testFromMap_DefaultModelVersion() { var serviceSettings = ElserMlNodeServiceSettings.fromMap( new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4)) ); - assertEquals(new ElserMlNodeServiceSettings(1, 4), serviceSettings); + assertEquals(new ElserMlNodeServiceSettings(1, 4, ".elser_model_2"), serviceSettings); + } + + public void testFromMap() { + var serviceSettings = ElserMlNodeServiceSettings.fromMap( + new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4, + "model_version", ".elser_model_1")) + ); + assertEquals(new ElserMlNodeServiceSettings(1, 4, ".elser_model_1"), serviceSettings); + } + + public void testFromMapInvalidVersion() { + var e = expectThrows(ValidationException.class, () ->ElserMlNodeServiceSettings.fromMap( + new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4, + "model_version", ".elser_model_27")) + )); + assertThat(e.getMessage(), containsString("faeafa")); } public void testFromMapMissingOptions() { @@ -67,9 +85,14 @@ protected ElserMlNodeServiceSettings createTestInstance() { @Override protected ElserMlNodeServiceSettings mutateInstance(ElserMlNodeServiceSettings instance) { - return switch (randomIntBetween(0, 1)) { - case 0 -> new ElserMlNodeServiceSettings(instance.getNumAllocations() + 1, instance.getNumThreads()); - case 1 -> new ElserMlNodeServiceSettings(instance.getNumAllocations(), instance.getNumThreads() + 1); + return switch (randomIntBetween(0, 2)) { + case 0 -> new ElserMlNodeServiceSettings(instance.getNumAllocations() + 1, instance.getNumThreads(), instance.getModelVariant()); + case 1 -> new ElserMlNodeServiceSettings(instance.getNumAllocations(), instance.getNumThreads() + 1, instance.getModelVariant()); + case 2 -> { + var versions = new HashSet<>(ElserMlNodeServiceSettings.VALID_ELSER_MODELS); + versions.remove(instance.getModelVariant()); + yield new ElserMlNodeServiceSettings(instance.getNumAllocations(), instance.getNumThreads(), versions.iterator().next()); + } default -> throw new IllegalStateException(); }; } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java index 1ab580eec358b..520e84479bb31 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceTests.java @@ -54,7 +54,7 @@ public void testParseConfigStrict() { "foo", TaskType.SPARSE_EMBEDDING, ElserMlNodeService.NAME, - new ElserMlNodeServiceSettings(1, 4), + new ElserMlNodeServiceSettings(1, 4, ElserMlNodeService.ELSER_V2_MODEL), ElserMlNodeTaskSettings.DEFAULT ), parsedModel @@ -77,7 +77,7 @@ public void testParseConfigStrictWithNoTaskSettings() { "foo", TaskType.SPARSE_EMBEDDING, ElserMlNodeService.NAME, - new ElserMlNodeServiceSettings(1, 4), + new ElserMlNodeServiceSettings(1, 4, ElserMlNodeService.ELSER_V2_MODEL), ElserMlNodeTaskSettings.DEFAULT ), parsedModel