From 3d3c69d2af0922c82fdc0e4359cbf4118fd87e96 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 10 Oct 2023 18:00:43 +0100 Subject: [PATCH] Compatible with patch version --- .../org/elasticsearch/TransportVersions.java | 1 + .../elser/ElserMlNodeServiceSettings.java | 17 +++++-- .../ElserMlNodeServiceSettingsTests.java | 48 +++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 6267fb3b86ae4..6b936284b18f7 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -134,6 +134,7 @@ static TransportVersion def(int id) { public static final TransportVersion NODE_INFO_REQUEST_SIMPLIFIED = def(8_510_00_0); public static final TransportVersion NESTED_KNN_VECTOR_QUERY_V = def(8_511_00_0); public static final TransportVersion ML_PACKAGE_LOADER_PLATFORM_ADDED = def(8_512_00_0); + public static final TransportVersion ELSER_SERVICE_MODEL_VERSION_ADDED_PATCH = def(8_512_00_1); public static final TransportVersion PLUGIN_DESCRIPTOR_OPTIONAL_CLASSNAME = def(8_513_00_0); public static final TransportVersion UNIVERSAL_PROFILING_LICENSE_ADDED = def(8_514_00_0); public static final TransportVersion ELSER_SERVICE_MODEL_VERSION_ADDED = def(8_515_00_0); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java index 7dffbc693ca51..d1f27302f85f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettings.java @@ -87,10 +87,21 @@ public ElserMlNodeServiceSettings(int numAllocations, int numThreads, String var public ElserMlNodeServiceSettings(StreamInput in) throws IOException { numAllocations = in.readVInt(); numThreads = in.readVInt(); - if (in.getTransportVersion().onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED)) { + if (transportVersionIsCompatibleWithElserModelVersion(in.getTransportVersion())) { modelVariant = in.readString(); } else { - modelVariant = ElserMlNodeService.ELSER_V1_MODEL; + modelVariant = ElserMlNodeService.ELSER_V2_MODEL; + } + } + + static boolean transportVersionIsCompatibleWithElserModelVersion(TransportVersion transportVersion) { + var nextNonPatchVersion = TransportVersions.PLUGIN_DESCRIPTOR_OPTIONAL_CLASSNAME; + + if (transportVersion.onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED)) { + return true; + } else { + return transportVersion.onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED_PATCH) + && transportVersion.before(nextNonPatchVersion); } } @@ -130,7 +141,7 @@ public TransportVersion getMinimalSupportedVersion() { public void writeTo(StreamOutput out) throws IOException { out.writeVInt(numAllocations); out.writeVInt(numThreads); - if (out.getTransportVersion().onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED)) { + if (transportVersionIsCompatibleWithElserModelVersion(out.getTransportVersion())) { out.writeString(modelVariant); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java index 35d5c0b8e9603..8b6f3f1a56ba6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserMlNodeServiceSettingsTests.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.inference.services.elser; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; +import java.io.IOException; import java.util.HashMap; import java.util.HashSet; import java.util.Map; @@ -85,6 +87,52 @@ public void testFromMapMissingOptions() { assertThat(e.getMessage(), containsString("[service_settings] does not contain the required setting [num_allocations]")); } + public void testTransportVersionIsCompatibleWithElserModelVersion() { + assertTrue( + ElserMlNodeServiceSettings.transportVersionIsCompatibleWithElserModelVersion( + TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED + ) + ); + assertTrue( + ElserMlNodeServiceSettings.transportVersionIsCompatibleWithElserModelVersion( + TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED_PATCH + ) + ); + + assertFalse( + ElserMlNodeServiceSettings.transportVersionIsCompatibleWithElserModelVersion(TransportVersions.ML_PACKAGE_LOADER_PLATFORM_ADDED) + ); + assertFalse( + ElserMlNodeServiceSettings.transportVersionIsCompatibleWithElserModelVersion( + TransportVersions.PLUGIN_DESCRIPTOR_OPTIONAL_CLASSNAME + ) + ); + assertFalse( + ElserMlNodeServiceSettings.transportVersionIsCompatibleWithElserModelVersion( + TransportVersions.UNIVERSAL_PROFILING_LICENSE_ADDED + ) + ); + } + + public void testBwcWrite() throws IOException { + { + var settings = new ElserMlNodeServiceSettings(1, 1, ".elser_model_1"); + var copy = copyInstance(settings, TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED); + assertEquals(settings, copy); + } + { + var settings = new ElserMlNodeServiceSettings(1, 1, ".elser_model_1"); + var copy = copyInstance(settings, TransportVersions.PLUGIN_DESCRIPTOR_OPTIONAL_CLASSNAME); + assertNotEquals(settings, copy); + assertEquals(".elser_model_2", copy.getModelVariant()); + } + { + var settings = new ElserMlNodeServiceSettings(1, 1, ".elser_model_1"); + var copy = copyInstance(settings, TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED_PATCH); + assertEquals(settings, copy); + } + } + public void testFromMapInvalidSettings() { var settingsMap = new HashMap( Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 0, ElserMlNodeServiceSettings.NUM_THREADS, -1)