Skip to content

Commit

Permalink
[ML] Fixing put inference service request hanging (elastic#100725) (elastic#100763)
Browse files Browse the repository at this point in the history

* Handling in cluster non cloud listener on response

* Cleaning up code

* spelling

(cherry picked from commit 0848c2b)

Co-authored-by: Elastic Machine <[email protected]>
  • Loading branch information
jonathan-buttner and elasticmachine authored Oct 12, 2023
1 parent 6ae07ca commit d20b45b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,17 @@ protected Function<Client, Client> getClientWrapper() {

public void testMockService() {
String modelId = "test-mock";
ModelConfigurations putModel = putMockService(modelId, TaskType.SPARSE_EMBEDDING);
ModelConfigurations putModel = putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
assertModelsAreEqual(putModel, readModel);

// The response is randomly generated, the input can be anything
inferOnMockService(modelId, TaskType.SPARSE_EMBEDDING, randomAlphaOfLength(10));
}

public void testMockInClusterService() {
String modelId = "test-mock-in-cluster";
ModelConfigurations putModel = putMockService(modelId, "test_service_in_cluster_service", TaskType.SPARSE_EMBEDDING);
ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);
assertModelsAreEqual(putModel, readModel);

Expand All @@ -85,7 +95,7 @@ public void testMockService() {

public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOException {
String modelId = "test-mock";
putMockService(modelId, TaskType.SPARSE_EMBEDDING);
putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);
ModelConfigurations readModel = getModel(modelId, TaskType.SPARSE_EMBEDDING);

assertThat(readModel.getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class));
Expand All @@ -103,7 +113,7 @@ public void testMockService_DoesNotReturnSecretsInGetResponse() throws IOExcepti

public void testGetUnparsedModelMap_ForTestServiceModel_ReturnsSecretsPopulated() {
String modelId = "test-unparsed";
putMockService(modelId, TaskType.SPARSE_EMBEDDING);
putMockService(modelId, "test_service", TaskType.SPARSE_EMBEDDING);

var listener = new PlainActionFuture<ModelRegistry.ModelConfigMap>();
modelRegistry.getUnparsedModelMap(modelId, listener);
Expand All @@ -114,10 +124,10 @@ public void testGetUnparsedModelMap_ForTestServiceModel_ReturnsSecretsPopulated(
assertThat(secrets.apiKey(), is("abc64"));
}

private ModelConfigurations putMockService(String modelId, TaskType taskType) {
String body = """
private ModelConfigurations putMockService(String modelId, String serviceName, TaskType taskType) {
String body = Strings.format("""
{
"service": "test_service",
"service": "%s",
"service_settings": {
"model": "my_model",
"api_key": "abc64"
Expand All @@ -126,7 +136,7 @@ private ModelConfigurations putMockService(String modelId, TaskType taskType) {
"temperature": 3
}
}
""";
""", serviceName);
var request = new PutInferenceModelAction.Request(
taskType.toString(),
modelId,
Expand All @@ -135,7 +145,7 @@ private ModelConfigurations putMockService(String modelId, TaskType taskType) {
);

var response = client().execute(PutInferenceModelAction.INSTANCE, request).actionGet();
assertEquals("test_service", response.getModel().getService());
assertEquals(serviceName, response.getModel().getService());

assertThat(response.getModel().getServiceSettings(), instanceOf(TestInferenceServicePlugin.TestServiceSettings.class));
var serviceSettings = (TestInferenceServicePlugin.TestServiceSettings) response.getModel().getServiceSettings();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class TestInferenceServicePlugin extends Plugin implements InferenceServi

@Override
public List<Factory> getInferenceServiceFactories() {
return List.of(TestInferenceService::new);
return List.of(TestInferenceService::new, TestInferenceServiceClusterService::new);
}

@Override
Expand All @@ -54,10 +54,39 @@ public List<NamedWriteableRegistry.Entry> getInferenceServiceNamedWriteables() {
);
}

public static class TestInferenceService implements InferenceService {

/**
 * Mock inference service used by the integration tests; registered under the
 * service name {@code "test_service"}. All behavior comes from
 * {@link TestInferenceServiceBase}.
 */
public static class TestInferenceService extends TestInferenceServiceBase {

    public TestInferenceService(InferenceServiceFactoryContext context) {
        super(context);
    }

    /** The identifier clients use in the "service" field of put-inference requests. */
    @Override
    public String name() {
        return "test_service";
    }
}

/**
 * Mock inference service that reports itself as running inside the cluster
 * (see {@link #isInClusterService()}); registered under the service name
 * {@code "test_service_in_cluster_service"}. All other behavior comes from
 * {@link TestInferenceServiceBase}.
 */
public static class TestInferenceServiceClusterService extends TestInferenceServiceBase {

    public TestInferenceServiceClusterService(InferenceServiceFactoryContext context) {
        super(context);
    }

    /** The identifier clients use in the "service" field of put-inference requests. */
    @Override
    public String name() {
        return "test_service_in_cluster_service";
    }

    /** Marks this mock as an in-cluster service so the in-cluster code path is exercised. */
    @Override
    public boolean isInClusterService() {
        return true;
    }
}

public abstract static class TestInferenceServiceBase implements InferenceService {

private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settings) {
Map<String, Object> taskSettingsMap;
// task settings are optional
Expand All @@ -70,13 +99,8 @@ private static Map<String, Object> getTaskSettingsMap(Map<String, Object> settin
return taskSettingsMap;
}

public TestInferenceService(InferenceServicePlugin.InferenceServiceFactoryContext context) {

}
public TestInferenceServiceBase(InferenceServicePlugin.InferenceServiceFactoryContext context) {

@Override
public String name() {
return NAME;
}

@Override
Expand All @@ -93,11 +117,11 @@ public TestServiceModel parseRequestConfig(
var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);

throwIfNotEmptyMap(config, NAME);
throwIfNotEmptyMap(serviceSettingsMap, NAME);
throwIfNotEmptyMap(taskSettingsMap, NAME);
throwIfNotEmptyMap(config, name());
throwIfNotEmptyMap(serviceSettingsMap, name());
throwIfNotEmptyMap(taskSettingsMap, name());

return new TestServiceModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}

@Override
Expand All @@ -116,7 +140,7 @@ public TestServiceModel parsePersistedConfig(
var taskSettingsMap = getTaskSettingsMap(config);
var taskSettings = TestTaskSettings.fromMap(taskSettingsMap);

return new TestServiceModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings);
return new TestServiceModel(modelId, taskType, name(), serviceSettings, taskSettings, secretSettings);
}

@Override
Expand All @@ -125,7 +149,7 @@ public void infer(Model model, String input, Map<String, Object> taskSettings, A
case SPARSE_EMBEDDING -> listener.onResponse(TextExpansionResultsTests.createRandomResults(1, 10));
default -> listener.onFailure(
new ElasticsearchStatusException(
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), NAME),
TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()),
RestStatus.BAD_REQUEST
)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,17 @@ protected void masterOperation(
// information when creating the model
MlPlatformArchitecturesUtil.getMlNodesArchitecturesSet(ActionListener.wrap(architectures -> {
if (architectures.isEmpty() && clusterIsInElasticCloud(clusterService.getClusterSettings())) {
// In Elastic cloud ml nodes run on Linux x86
architectures = Set.of("linux-x86_64");
parseAndStoreModel(
service.get(),
request.getModelId(),
request.getTaskType(),
requestAsMap,
// In Elastic cloud ml nodes run on Linux x86
Set.of("linux-x86_64"),
listener
);
} else {
// The architecture field could be an empty set, the individual services will need to handle that
parseAndStoreModel(service.get(), request.getModelId(), request.getTaskType(), requestAsMap, architectures, listener);
}
}, listener::onFailure), client, threadPool.executor(InferencePlugin.UTILITY_THREAD_POOL_NAME));
Expand All @@ -118,10 +127,10 @@ private void parseAndStoreModel(
String modelId,
TaskType taskType,
Map<String, Object> config,
Set<String> platfromArchitectures,
Set<String> platformArchitectures,
ActionListener<PutInferenceModelAction.Response> listener
) {
var model = service.parseRequestConfig(modelId, taskType, config, platfromArchitectures);
var model = service.parseRequestConfig(modelId, taskType, config, platformArchitectures);
// model is valid good to persist then start
this.modelRegistry.storeModel(model, ActionListener.wrap(r -> { startModel(service, model, listener); }, listener::onFailure));
}
Expand Down

0 comments on commit d20b45b

Please sign in to comment.