Skip to content

Commit

Permalink
Pass timeout to start deployment
Browse files: browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 13, 2024
1 parent 103a8b0 commit ea3abee
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,10 @@ void chunkedInfer(
/**
* Start or prepare the model for use.
* @param model The model
* @param timeout Start timeout
* @param listener The listener
*/
void start(Model model, ActionListener<Boolean> listener);
void start(Model model, TimeValue timeout, ActionListener<Boolean> listener);

/**
* Stop the model deployment.
Expand All @@ -153,17 +154,6 @@ default void stop(UnparsedModel unparsedModel, ActionListener<Boolean> listener)
listener.onResponse(true);
}

/**
* Put the model definition (if applicable)
* The main purpose of this function is to download ELSER
* The default action does nothing except acknowledge the request (true).
* @param modelVariant The configuration of the model variant to be downloaded
* @param listener The listener
*/
default void putModel(Model modelVariant, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

/**
* Optionally test the new model configuration in the inference service.
* This function should be called when the model is first created, the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
Expand Down Expand Up @@ -90,7 +91,7 @@ public Model parsePersistedConfig(String modelId, TaskType taskType, Map<String,
protected abstract ServiceSettings getServiceSettingsFromMap(Map<String, Object> serviceSettingsMap);

@Override
public void start(Model model, ActionListener<Boolean> listener) {
public void start(Model model, TimeValue timeout, ActionListener<Boolean> listener) {
listener.onResponse(true);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.mapper.StrictDynamicMappingException;
import org.elasticsearch.inference.InferenceService;
import org.elasticsearch.inference.InferenceServiceRegistry;
Expand Down Expand Up @@ -159,20 +160,21 @@ protected void masterOperation(
return;
}

parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, listener);
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
}

private void parseAndStoreModel(
InferenceService service,
String inferenceEntityId,
TaskType taskType,
Map<String, Object> config,
TimeValue timeout,
ActionListener<PutInferenceModelAction.Response> listener
) {
ActionListener<Model> storeModelListener = listener.delegateFailureAndWrap(
(delegate, verifiedModel) -> modelRegistry.storeModel(
verifiedModel,
ActionListener.wrap(r -> startInferenceEndpoint(service, verifiedModel, delegate), e -> {
ActionListener.wrap(r -> startInferenceEndpoint(service, timeout, verifiedModel, delegate), e -> {
if (e.getCause() instanceof StrictDynamicMappingException && e.getCause().getMessage().contains("chunking_settings")) {
delegate.onFailure(
new ElasticsearchStatusException(
Expand All @@ -199,11 +201,16 @@ private void parseAndStoreModel(
service.parseRequestConfig(inferenceEntityId, taskType, config, parsedModelListener);
}

private void startInferenceEndpoint(InferenceService service, Model model, ActionListener<PutInferenceModelAction.Response> listener) {
private void startInferenceEndpoint(
InferenceService service,
TimeValue timeout,
Model model,
ActionListener<PutInferenceModelAction.Response> listener
) {
if (skipValidationAndStart) {
listener.onResponse(new PutInferenceModelAction.Response(model.getConfigurations()));
} else {
service.start(model, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
service.start(model, timeout, listener.map(started -> new PutInferenceModelAction.Response(model.getConfigurations())));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,16 @@ protected abstract void doChunkedInfer(
ActionListener<List<ChunkedInferenceServiceResults>> listener
);

@Override
public void start(Model model, ActionListener<Boolean> listener) {
init();

doStart(model, listener);
}

@Override
public void start(Model model, @Nullable TimeValue unused, ActionListener<Boolean> listener) {
start(model, listener);
}

protected void doStart(Model model, ActionListener<Boolean> listener) {
listener.onResponse(true);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public BaseElasticsearchInternalService(
}

@Override
public void start(Model model, ActionListener<Boolean> finalListener) {
public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalListener) {
if (model instanceof ElasticsearchInternalModel esModel) {
if (supportedTaskTypes().contains(model.getTaskType()) == false) {
finalListener.onFailure(
Expand All @@ -107,7 +107,7 @@ public void start(Model model, ActionListener<Boolean> finalListener) {
}
})
.<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest();
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
})
Expand Down Expand Up @@ -149,8 +149,7 @@ protected static IllegalStateException notElasticsearchModelException(Model mode
);
}

@Override
public void putModel(Model model, ActionListener<Boolean> listener) {
protected void putModel(Model model, ActionListener<Boolean> listener) {
if (model instanceof ElasticsearchInternalModel == false) {
listener.onFailure(notElasticsearchModelException(model));
return;
Expand Down Expand Up @@ -303,10 +302,9 @@ protected void maybeStartDeployment(
}

if (isDefaultId(model.getInferenceEntityId()) && ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) {
this.start(
model,
listener.delegateFailureAndWrap((l, started) -> { client.execute(InferModelAction.INSTANCE, request, listener); })
);
this.start(model, request.getInferenceTimeout(), listener.delegateFailureAndWrap((l, started) -> {
client.execute(InferModelAction.INSTANCE, request, listener);
}));
} else {
listener.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.inference.services.elasticsearch;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
Expand All @@ -31,7 +32,7 @@ public boolean usesExistingDeployment() {
}

@Override
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest() {
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest(TimeValue timeout) {
throw new IllegalStateException("cannot start model that uses an existing deployment");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.Strings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ChunkingSettings;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
Expand Down Expand Up @@ -67,11 +68,12 @@ public ElasticsearchInternalModel(
this.internalServiceSettings = internalServiceSettings;
}

public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest() {
public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentActionRequest(TimeValue timeout) {
var startRequest = new StartTrainedModelDeploymentAction.Request(internalServiceSettings.modelId(), this.getInferenceEntityId());
startRequest.setNumberOfAllocations(internalServiceSettings.getNumAllocations());
startRequest.setThreadsPerAllocation(internalServiceSettings.getNumThreads());
startRequest.setAdaptiveAllocationsSettings(internalServiceSettings.getAdaptiveAllocationsSettings());
startRequest.setTimeout(timeout);
startRequest.setWaitForState(STARTED);

return startRequest;
Expand Down

0 comments on commit ea3abee

Please sign in to comment.