Revert "[ML] Pass inference timeout to start deployment (elastic#116725
Browse files Browse the repository at this point in the history
…)"

This reverts commit 59602a9.
davidkyle committed Nov 13, 2024
1 parent 7eb37bd commit 095836b
Showing 7 changed files with 29 additions and 31 deletions.
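
All seven files below follow from a single interface change. As a rough sketch (the interface name here is hypothetical, the body is trimmed to the two InferenceService methods this revert touches, and it is reconstructed from the hunks below rather than copied from the real source file), the restored surface looks like this:

    import org.elasticsearch.action.ActionListener;
    import org.elasticsearch.inference.Model;

    // Sketch only: a stand-in for the two InferenceService methods touched by
    // this revert; the real interface declares many more methods.
    interface InferenceServiceStartSketch {

        // Shape removed by this revert (elastic#116725 had added the timeout parameter):
        // void start(Model model, TimeValue timeout, ActionListener<Boolean> listener);

        // Shape restored by this revert: no start timeout is passed through.
        void start(Model model, ActionListener<Boolean> listener);

        // Default putModel is restored to the interface; it only acknowledges the
        // request unless a service (such as the internal Elasticsearch service in
        // the hunks below) overrides it to download the model definition.
        default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
            listener.onResponse(true);
        }
    }

Callers accordingly go back to service.start(model, listener), and no timeout is set on StartTrainedModelDeploymentAction.Request any more.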
@@ -139,10 +139,9 @@ void chunkedInfer(
/**
* Start or prepare the model for use.
* @param model The model
* @param timeout Start timeout
* @param listener The listener
*/
void start(Model model, TimeValue timeout, ActionListener<Boolean> listener);
void start(Model model, ActionListener<Boolean> listener);

/**
* Stop the model deployment.
@@ -154,6 +153,17 @@ default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener)
listener.onResponse(true);
}

/**
* Put the model definition (if applicable)
* The main purpose of this function is to download ELSER
* The default action does nothing except acknowledge the request (true).
* @param modelVariant The configuration of the model variant to be downloaded
* @param listener The listener
*/
default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

/**
* Optionally test the new model configuration in the inference service.
* This function should be called when the model is first created, the
@@ -12,7 +12,6 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
@@ -91,7 +90,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);

@Override
public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
public void start(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

@@ -22,7 +22,6 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.StrictDynamicMappingException;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
@@ -160,21 +159,20 @@ protected void masterOperation(
return;
}

parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener);
}

private void parseAndStoreModel(
InferenceService service,
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
TimeValue timeout,
ActionListener<PutInferenceModelAction.Response> listener
) {
ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
(delegate, verifiedModel) -> modelRegistry.storeModel(
verifiedModel,
ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> {
if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
delegate.onFailure(
new ElasticsearchStatusException(
@@ -201,16 +199,11 @@ private void parseAndStoreModel(
service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
}

private void startInferenceEndpoint(
InferenceService service,
TimeValue timeout,
Model model,
ActionListener<PutInferenceModelAction.Response> listener
) {
private void startInferenceEndpoint(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> listener) {
if (skipValidationAndStart) {
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
} else {
service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
}
}

@@ -104,14 +104,11 @@ protected abstract void doChunkedInfer(
ActionListener<List<ChunkedInferenceServiceResults>> listener
);

@Override
public void start(Model model, ActionListener<Boolean> listener) {
init();
doStart(model, listener);
}

@Override
public void start(Model model, @Nullable TimeValue unused, ActionListener<Boolean> listener) {
start(model, listener);
doStart(model, listener);
}

protected void doStart(Model model, ActionListener<Boolean> listener) {
@@ -83,7 +83,7 @@ public BaseElasticsearchInternalService(
}

@Override
public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalListener) {
public void start(Model model, ActionListener<Boolean> finalListener) {
if (model instanceof ElasticsearchInternalModel esModel) {
if (supportedTaskTypes().contains(model.getTaskType()) == false) {
finalListener.onFailure(
@@ -107,7 +107,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
}
})
.<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest();
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
})
@@ -149,7 +149,8 @@ protected static IllegalStateException notElasticsearchModelException(Model mode
);
}

protected void putModel(Model model, ActionListener<Boolean> listener) {
@Override
public void putModel(Model model, ActionListener<Boolean> listener) {
if (model instanceof ElasticsearchInternalModel == false) {
listener.onFailure(notElasticsearchModelException(model));
return;
@@ -302,9 +303,10 @@ protected void maybeStartDeployment(
}

if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
client.execute(InferModelAction.INSTANCE, request, listener);
}));
this.start(
model,
listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })
);
} else {
listener.onFailure(e);
}
@@ -8,7 +8,6 @@
package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
@@ -32,7 +31,7 @@ public boolean usesExistingDeployment() {
}

@Override
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest(TimeValue timeout) {
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest() {
throw new IllegalStateException("cannot start model that uses an existing deployment");
}

@@ -9,7 +9,6 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
@@ -68,12 +67,11 @@ public ElasticsearchInternalModel(
this.internalServiceSettings = internalServiceSettings;
}

public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest(TimeValue timeout) {
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest() {
var startRequest = new StartTrainedModelDeploymentAction.Request(internalServiceSettings.modelId(), this.getInferenceEntityId());
startRequest.setNumberOfAllocations(internalServiceSettings.getNumAllocations());
startRequest.setThreadsPerAllocation(internalServiceSettings.getNumThreads());
startRequest.setAdaptiveAllocationsSettings(internalServiceSettings.getAdaptiveAllocationsSettings());
startRequest.setTimeout(timeout);
startRequest.setWaitForState(STARTED);

return startRequest;
