Skip to content

Commit

Permalink
Make Elser model version configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Oct 5, 2023
1 parent 24ed4ca commit ffc28a7
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NODE_INFO_REQUEST_SIMPLIFIED = def(8_510_00_0);
public static final TransportVersion NESTED_KNN_VECTOR_QUERY_V = def(8_511_00_0);
public static final TransportVersion ML_PACKAGE_LOADER_PLATFORM_ADDED = def(8_512_00_0);
public static final TransportVersion ELSER_SERVICE_MODEL_VERSION_ADDED = def(8_513_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus.State.STARTED;
import static org.elasticsearch.xpack.inference.services.MapParsingUtils.removeFromMapOrThrowIfNull;
Expand All @@ -35,7 +36,10 @@ public class ElserMlNodeService implements InferenceService {

public static final String NAME = "elser_mlnode";

private static final String ELSER_V1_MODEL = ".elser_model_1";
static final String ELSER_V1_MODEL = ".elser_model_1";
// Default non platform specific v2 model
static final String ELSER_V2_MODEL = ".elser_model_2";
static final String ELSER_V2_MODEL_LINUX_X86 = ".elser_model_2_linux-x86_64";

public static ElserMlNodeModel parseConfig(
boolean throwOnUnknownFields,
Expand Down Expand Up @@ -106,7 +110,10 @@ public void start(Model model, ActionListener<Boolean> listener) {
var elserModel = (ElserMlNodeModel) model;
var serviceSettings = elserModel.getServiceSettings();

var startRequest = new StartTrainedModelDeploymentAction.Request(ELSER_V1_MODEL, model.getConfigurations().getModelId());
var startRequest = new StartTrainedModelDeploymentAction.Request(
serviceSettings.getModelVariant(),
model.getConfigurations().getModelId()
);
startRequest.setNumberOfAllocations(serviceSettings.getNumAllocations());
startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads());
startRequest.setWaitForState(STARTED);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,24 @@
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

public class ElserMlNodeServiceSettings implements ServiceSettings {

public static final String NAME = "elser_mlnode_service_settings";
public static final String NUM_ALLOCATIONS = "num_allocations";
public static final String NUM_THREADS = "num_threads";
public static final String MODEL_VERSION = "model_version";

/**
 * The ELSER model ids accepted for the {@code model_version} setting.
 * Declared {@code final} so this shared reference cannot be reassigned;
 * {@code Set.of} already makes the contents unmodifiable.
 */
public static final Set<String> VALID_ELSER_MODELS = Set.of(
    ElserMlNodeService.ELSER_V1_MODEL,
    ElserMlNodeService.ELSER_V2_MODEL,
    ElserMlNodeService.ELSER_V2_MODEL_LINUX_X86
);

private final int numAllocations;
private final int numThreads;
private final String modelVariant;

/**
* Parse the Elser service setting from map and validate the setting values.
Expand Down Expand Up @@ -61,21 +70,34 @@ public static ElserMlNodeServiceSettings fromMap(Map<String, Object> map) {
validationException.addValidationError(mustBeAPositiveNumberError(NUM_THREADS, numThreads));
}

// The model version is optional. Fall back to the default only when the
// setting is absent: a supplied value is kept when it is a known model and
// reported as a validation error when it is not. (Previously the default
// also overwrote every valid, explicitly supplied version.)
String version = MapParsingUtils.removeAsType(map, MODEL_VERSION, String.class);
if (version == null) {
    version = ElserMlNodeService.ELSER_V2_MODEL;
} else if (VALID_ELSER_MODELS.contains(version) == false) {
    validationException.addValidationError("unknown ELSER model version [" + version + "]");
}

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}

return new ElserMlNodeServiceSettings(numAllocations, numThreads);
return new ElserMlNodeServiceSettings(numAllocations, numThreads, version);
}

public ElserMlNodeServiceSettings(int numAllocations, int numThreads) {
public ElserMlNodeServiceSettings(int numAllocations, int numThreads, String variant) {
this.numAllocations = numAllocations;
this.numThreads = numThreads;
this.modelVariant = variant;
}

public ElserMlNodeServiceSettings(StreamInput in) throws IOException {
numAllocations = in.readVInt();
numThreads = in.readVInt();
if (in.getTransportVersion().onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED)) {
modelVariant = in.readString();
} else {
modelVariant = ElserMlNodeService.ELSER_V1_MODEL;
}
}

public int getNumAllocations() {
Expand All @@ -86,11 +108,16 @@ public int getNumThreads() {
return numThreads;
}

/**
 * @return the ELSER model variant held by these settings
 */
public String getModelVariant() {
    return this.modelVariant;
}

/**
 * Writes the settings as a single object containing the allocation count,
 * the thread count and the model variant.
 */
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
    return builder.startObject()
        .field(NUM_ALLOCATIONS, numAllocations)
        .field(NUM_THREADS, numThreads)
        .field(MODEL_VERSION, modelVariant)
        .endObject();
}
Expand All @@ -109,19 +136,22 @@ public TransportVersion getMinimalSupportedVersion() {
public void writeTo(StreamOutput out) throws IOException {
    // Field order must mirror the StreamInput constructor.
    out.writeVInt(numAllocations);
    out.writeVInt(numThreads);

    // The model variant is written only to streams at or after the
    // transport version that introduced it.
    final boolean streamCarriesVariant = out.getTransportVersion().onOrAfter(TransportVersions.ELSER_SERVICE_MODEL_VERSION_ADDED);
    if (streamCarriesVariant) {
        out.writeString(modelVariant);
    }
}

/**
 * Hashes every field compared by {@code equals}, including the model
 * variant. The stale two-field return that preceded this one is removed: it
 * made the current return unreachable and left the variant out of the hash.
 */
@Override
public int hashCode() {
    return Objects.hash(numAllocations, numThreads, modelVariant);
}

/**
 * Two settings objects are equal when they are of the same class and have the
 * same allocation count, thread count and model variant. The stale return
 * that ignored the model variant is removed.
 */
@Override
public boolean equals(Object o) {
    if (this == o) return true;
    if (o == null || getClass() != o.getClass()) return false;
    ElserMlNodeServiceSettings that = (ElserMlNodeServiceSettings) o;
    return numAllocations == that.numAllocations
        && numThreads == that.numThreads
        && Objects.equals(modelVariant, that.modelVariant);
}

private static String mustBeAPositiveNumberError(String settingName, int value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,39 @@
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;

import static org.hamcrest.Matchers.containsString;

public class ElserMlNodeServiceSettingsTests extends AbstractWireSerializingTestCase<ElserMlNodeServiceSettings> {

/**
 * Builds settings with random allocation and thread counts and a model
 * variant drawn from the valid ELSER models. The stale two-argument
 * constructor call is removed.
 */
public static ElserMlNodeServiceSettings createRandom() {
    return new ElserMlNodeServiceSettings(
        randomIntBetween(1, 4),
        randomIntBetween(1, 2),
        randomFrom(ElserMlNodeServiceSettings.VALID_ELSER_MODELS)
    );
}

public void testFromMap() {
public void testFromMap_DefaultModelVersion() {
    var serviceSettings = ElserMlNodeServiceSettings.fromMap(
        new HashMap<>(Map.of(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1, ElserMlNodeServiceSettings.NUM_THREADS, 4))
    );
    // No model_version is supplied, so the parsed settings must carry the
    // default ".elser_model_2". The stale two-argument assertion is removed.
    assertEquals(new ElserMlNodeServiceSettings(1, 4, ".elser_model_2"), serviceSettings);
}

// An explicitly supplied, valid model version must be preserved by fromMap.
public void testFromMap() {
    Map<String, Object> settingsMap = new HashMap<>();
    settingsMap.put(ElserMlNodeServiceSettings.NUM_ALLOCATIONS, 1);
    settingsMap.put(ElserMlNodeServiceSettings.NUM_THREADS, 4);
    settingsMap.put("model_version", ".elser_model_1");

    var parsed = ElserMlNodeServiceSettings.fromMap(settingsMap);

    var expected = new ElserMlNodeServiceSettings(1, 4, ".elser_model_1");
    assertEquals(expected, parsed);
}

// An unrecognised model version must be rejected with a validation error
// that names the offending value.
public void testFromMapInvalidVersion() {
    var e = expectThrows(
        ValidationException.class,
        () -> ElserMlNodeServiceSettings.fromMap(
            new HashMap<>(
                Map.of(
                    ElserMlNodeServiceSettings.NUM_ALLOCATIONS,
                    1,
                    ElserMlNodeServiceSettings.NUM_THREADS,
                    4,
                    "model_version",
                    ".elser_model_27"
                )
            )
        )
    );
    // Assert on the message fromMap actually emits, not the placeholder text.
    assertThat(e.getMessage(), containsString("unknown ELSER model version [.elser_model_27]"));
}

public void testFromMapMissingOptions() {
Expand Down Expand Up @@ -67,9 +85,14 @@ protected ElserMlNodeServiceSettings createTestInstance() {

@Override
protected ElserMlNodeServiceSettings mutateInstance(ElserMlNodeServiceSettings instance) {
return switch (randomIntBetween(0, 1)) {
case 0 -> new ElserMlNodeServiceSettings(instance.getNumAllocations() + 1, instance.getNumThreads());
case 1 -> new ElserMlNodeServiceSettings(instance.getNumAllocations(), instance.getNumThreads() + 1);
return switch (randomIntBetween(0, 2)) {
case 0 -> new ElserMlNodeServiceSettings(instance.getNumAllocations() + 1, instance.getNumThreads(), instance.getModelVariant());
case 1 -> new ElserMlNodeServiceSettings(instance.getNumAllocations(), instance.getNumThreads() + 1, instance.getModelVariant());
case 2 -> {
var versions = new HashSet<>(ElserMlNodeServiceSettings.VALID_ELSER_MODELS);
versions.remove(instance.getModelVariant());
yield new ElserMlNodeServiceSettings(instance.getNumAllocations(), instance.getNumThreads(), versions.iterator().next());
}
default -> throw new IllegalStateException();
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public void testParseConfigStrict() {
"foo",
TaskType.SPARSE_EMBEDDING,
ElserMlNodeService.NAME,
new ElserMlNodeServiceSettings(1, 4),
new ElserMlNodeServiceSettings(1, 4, ElserMlNodeService.ELSER_V2_MODEL),
ElserMlNodeTaskSettings.DEFAULT
),
parsedModel
Expand All @@ -77,7 +77,7 @@ public void testParseConfigStrictWithNoTaskSettings() {
"foo",
TaskType.SPARSE_EMBEDDING,
ElserMlNodeService.NAME,
new ElserMlNodeServiceSettings(1, 4),
new ElserMlNodeServiceSettings(1, 4, ElserMlNodeService.ELSER_V2_MODEL),
ElserMlNodeTaskSettings.DEFAULT
),
parsedModel
Expand Down

0 comments on commit ffc28a7

Please sign in to comment.