Skip to content

Commit

Permalink
include deployment status in deploy API response (#1395)
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
(cherry picked from commit 9992067)

Co-authored-by: Xun Zhang <[email protected]>
  • Loading branch information
zane-neo and Zhangxunmt authored Sep 27, 2023
1 parent 44946da commit b42104c
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,38 +12,47 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.IOException;

@Getter
public class MLDeployModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String TASK_TYPE_FIELD = "task_type";
public static final String STATUS_FIELD = "status";

private String taskId;
private MLTaskType taskType;
private String status;

public MLDeployModelResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
this.taskType = in.readEnum(MLTaskType.class);
this.status = in.readString();
}

public MLDeployModelResponse(String taskId, String status) {
public MLDeployModelResponse(String taskId, MLTaskType mlTaskType, String status) {
this.taskId = taskId;
this.taskType = mlTaskType;
this.status= status;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
out.writeEnum(taskType);
out.writeString(status);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
if (taskType != null) {
builder.field(TASK_TYPE_FIELD, taskType);
}
builder.field(STATUS_FIELD, status);
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.ml.common.MLTaskType;

import java.io.IOException;

Expand All @@ -19,37 +20,40 @@ public class MLDeployModelResponseTest {

private String taskId;
private String status;
private MLTaskType taskType;

@Before
public void setUp() throws Exception {
taskId = "test_id";
status = "test";
taskType = MLTaskType.DEPLOY_MODEL;
}

@Test
public void writeTo_Success() throws IOException {
// Setup
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLDeployModelResponse response = new MLDeployModelResponse(taskId, status);
MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status);
// Run the test
response.writeTo(bytesStreamOutput);
MLDeployModelResponse parsedResponse = new MLDeployModelResponse(bytesStreamOutput.bytes().streamInput());
// Verify the results
assertEquals(response.getTaskId(), parsedResponse.getTaskId());
assertEquals(response.getTaskType(), parsedResponse.getTaskType());
assertEquals(response.getStatus(), parsedResponse.getStatus());
}

@Test
public void testToXContent() throws IOException {
// Setup
MLDeployModelResponse response = new MLDeployModelResponse(taskId, status);
MLDeployModelResponse response = new MLDeployModelResponse(taskId, taskType, status);
// Run the test
XContentBuilder builder = MediaTypeRegistry.contentBuilder(XContentType.JSON);
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
// Verify the results
assertEquals("{\"task_id\":\"test_id\"," +
assertEquals("{\"task_id\":\"test_id\"," + "\"task_type\":\"DEPLOY_MODEL\"," +
"\"status\":\"test\"}", jsonStr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,14 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
if (algorithm == FunctionName.REMOTE) {
mlTaskManager.add(mlTask, nodeIds);
deployRemoteModel(mlModel, mlTask, localNodeId, eligibleNodes, deployToAllNodes, listener);
return;
}
try {
mlTaskManager.add(mlTask, nodeIds);
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name()));
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.CREATED.name()));
threadPool
.executor(DEPLOY_THREAD_POOL)
.execute(
Expand Down Expand Up @@ -260,6 +265,82 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl

}

@VisibleForTesting
void deployRemoteModel(
MLModel mlModel,
MLTask mlTask,
String localNodeId,
List<DiscoveryNode> eligibleNodes,
boolean deployToAllNodes,
ActionListener<MLDeployModelResponse> listener
) {
MLDeployModelInput deployModelInput = new MLDeployModelInput(
mlModel.getModelId(),
mlTask.getTaskId(),
mlModel.getModelContentHash(),
eligibleNodes.size(),
localNodeId,
deployToAllNodes,
mlTask
);

MLDeployModelNodesRequest deployModelRequest = new MLDeployModelNodesRequest(
eligibleNodes.toArray(new DiscoveryNode[0]),
deployModelInput
);

ActionListener<MLDeployModelNodesResponse> actionListener = deployModelNodesResponseListener(
mlTask.getTaskId(),
mlModel.getModelId(),
listener
);
List<String> workerNodes = eligibleNodes.stream().map(n -> n.getId()).collect(Collectors.toList());
mlModelManager
.updateModel(
mlModel.getModelId(),
ImmutableMap
.of(
MLModel.MODEL_STATE_FIELD,
MLModelState.DEPLOYING,
MLModel.PLANNING_WORKER_NODE_COUNT_FIELD,
eligibleNodes.size(),
MLModel.PLANNING_WORKER_NODES_FIELD,
workerNodes,
MLModel.DEPLOY_TO_ALL_NODES_FIELD,
deployToAllNodes
),
ActionListener
.wrap(
r -> client.execute(MLDeployModelOnNodeAction.INSTANCE, deployModelRequest, actionListener),
actionListener::onFailure
)
);
}

private ActionListener<MLDeployModelNodesResponse> deployModelNodesResponseListener(
String taskId,
String modelId,
ActionListener<MLDeployModelResponse> listener
) {
return ActionListener.wrap(r -> {
if (mlTaskManager.contains(taskId)) {
mlTaskManager.updateMLTask(taskId, ImmutableMap.of(STATE_FIELD, MLTaskState.RUNNING), TASK_SEMAPHORE_TIMEOUT, false);
}
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, MLTaskState.COMPLETED.name()));
}, e -> {
log.error("Failed to deploy model " + modelId, e);
mlTaskManager
.updateMLTask(
taskId,
ImmutableMap.of(MLTask.ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(e), STATE_FIELD, FAILED),
TASK_SEMAPHORE_TIMEOUT,
true
);
mlModelManager.updateModel(modelId, ImmutableMap.of(MLModel.MODEL_STATE_FIELD, MLModelState.DEPLOY_FAILED));
listener.onFailure(e);
});
}

@VisibleForTesting
void updateModelDeployStatusAndTriggerOnNodesAction(
String modelId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
mlTask.setTaskId(taskId);
mlModelManager.registerMLModel(registerModelInput, mlTask);
listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name()));
mlModelManager.registerMLRemoteModel(registerModelInput, mlTask, listener);
}, e -> {
logException("Failed to register model", e, log);
listener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ public void test_toXContent() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String xContentString = TestHelper.xContentBuilderToString(builder);
System.out.println(xContentString);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ public void testDeployRemoteModel() throws IOException, InterruptedException {
String modelId = (String) responseMap.get("model_id");
response = deployRemoteModel(modelId);
responseMap = parseResponseToMap(response);
assertEquals("CREATED", (String) responseMap.get("status"));
assertEquals("COMPLETED", (String) responseMap.get("status"));
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ public void testPrepareRequest() throws Exception {
SearchRequest searchRequest = argumentCaptor.getValue();
String[] indices = searchRequest.indices();
assertArrayEquals(new String[] { ML_CONNECTOR_INDEX }, indices);
System.out.println(searchRequest);
assertEquals(
"{\"query\":{\"match_all\":{\"boost\":1.0}},\"version\":true,\"seq_no_primary_term\":true,\"_source\":{\"includes\":[],\"excludes\":[\"content\",\"model_content\",\"ui_metadata\"]}}",
searchRequest.source().toString()
Expand Down

0 comments on commit b42104c

Please sign in to comment.