Skip to content

Commit

Permalink
return model id in registering remote model (#1396)
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <[email protected]>
(cherry picked from commit d21b032)

Signed-off-by: zane-neo <[email protected]>
Co-authored-by: Xun Zhang <[email protected]>
Co-authored-by: zane-neo <[email protected]>
  • Loading branch information
3 people authored Sep 27, 2023
1 parent b42104c commit 79a7829
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,46 @@
@Getter
public class MLRegisterModelResponse extends ActionResponse implements ToXContentObject {
public static final String TASK_ID_FIELD = "task_id";
public static final String MODEL_ID_FIELD = "model_id";
public static final String STATUS_FIELD = "status";

private String taskId;
private String status;
private String modelId;

public MLRegisterModelResponse(StreamInput in) throws IOException {
super(in);
this.taskId = in.readString();
this.status = in.readString();
this.modelId = in.readOptionalString();
}

public MLRegisterModelResponse(String taskId, String status) {
this.taskId = taskId;
this.status= status;
}

public MLRegisterModelResponse(String taskId, String status, String modelId) {
this.taskId = taskId;
this.status= status;
this.modelId = modelId;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(taskId);
out.writeString(status);
out.writeOptionalString(modelId);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
builder.startObject();
builder.field(TASK_ID_FIELD, taskId);
builder.field(STATUS_FIELD, status);
if (modelId != null) {
builder.field(MODEL_ID_FIELD, modelId);
}
builder.endObject();
return builder;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,27 @@ public class MLRegisterModelResponseTest {

private String taskId;
private String status;
private String modelId;

@Before
public void setUp() throws Exception {
taskId = "test_id";
status = "test";
modelId = "model_id";
}

@Test
public void writeTo_Success() throws IOException {
// Setup
BytesStreamOutput bytesStreamOutput = new BytesStreamOutput();
MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status);
MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status, modelId);
// Run the test
response.writeTo(bytesStreamOutput);
MLRegisterModelResponse parsedResponse = new MLRegisterModelResponse(bytesStreamOutput.bytes().streamInput());
// Verify the results
assertEquals(response.getTaskId(), parsedResponse.getTaskId());
assertEquals(response.getStatus(), parsedResponse.getStatus());
assertEquals(response.getModelId(), parsedResponse.getModelId());
}

@Test
Expand All @@ -52,4 +55,18 @@ public void testToXContent() throws IOException {
assertEquals("{\"task_id\":\"test_id\"," +
"\"status\":\"test\"}", jsonStr);
}

@Test
public void testToXContent_withModelId() throws IOException {
// Setup
MLRegisterModelResponse response = new MLRegisterModelResponse(taskId, status, modelId);
// Run the test
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertNotNull(builder);
String jsonStr = builder.toString();
// Verify the results
assertEquals("{\"task_id\":\"test_id\"," +
"\"status\":\"test\"," + "\"model_id\":\"model_id\"}", jsonStr);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ private void registerModel(MLRegisterModelInput registerModelInput, ActionListen
throw new IllegalArgumentException("URL can't match trusted url regex");
}
}
System.out.println("registering the model");
boolean isAsync = registerModelInput.getFunctionName() != FunctionName.REMOTE;
MLTask mlTask = MLTask
.builder()
Expand Down
232 changes: 191 additions & 41 deletions plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.exception.MLLimitExceededException;
Expand All @@ -103,6 +104,7 @@
import org.opensearch.ml.common.transport.deploy.MLDeployModelRequest;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput;
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.MLExecutable;
Expand Down Expand Up @@ -222,30 +224,22 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput,
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
Map<String, Object> source = modelGroup.getSourceAsMap();
int latestVersion = (int) source.get(MLModelGroup.LATEST_VERSION_FIELD);
int newVersion = latestVersion + 1;
source.put(MLModelGroup.LATEST_VERSION_FIELD, newVersion);
source.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
UpdateRequest updateModelGroupRequest = new UpdateRequest();
long seqNo = modelGroup.getSeqNo();
long primaryTerm = modelGroup.getPrimaryTerm();
updateModelGroupRequest
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.setIfSeqNo(seqNo)
.setIfPrimaryTerm(primaryTerm)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.doc(source);
client
.update(
updateModelGroupRequest,
ActionListener
.wrap(r -> { uploadMLModelMeta(mlRegisterModelMetaInput, newVersion + "", listener); }, e -> {
log.error("Failed to update model group", e);
listener.onFailure(e);
})
);
Map<String, Object> modelGroupSource = modelGroup.getSourceAsMap();
int updatedVersion = incrementLatestVersion(modelGroupSource);
UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest(
modelGroupSource,
modelGroupId,
modelGroup.getSeqNo(),
modelGroup.getPrimaryTerm(),
updatedVersion
);

client.update(updateModelGroupRequest, ActionListener.wrap(r -> {
uploadMLModelMeta(mlRegisterModelMetaInput, updatedVersion + "", listener);
}, e -> {
log.error("Failed to update model group", e);
listener.onFailure(e);
}));
} else {
log.error("Model group not found");
listener.onFailure(new MLResourceNotFoundException("Fail to find model group"));
Expand Down Expand Up @@ -312,6 +306,80 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput
}
}

/**
*
* @param mlRegisterModelInput register model input for remote models
* @param mlTask ML task
* @param listener action listener
*/
public void registerMLRemoteModel(
MLRegisterModelInput mlRegisterModelInput,
MLTask mlTask,
ActionListener<MLRegisterModelResponse> listener
) {
checkAndAddRunningTask(mlTask, maxRegisterTasksPerNode);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment();
mlStats.createCounterStatIfAbsent(mlTask.getFunctionName(), REGISTER, ML_ACTION_REQUEST_COUNT).increment();
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment();

String modelGroupId = mlRegisterModelInput.getModelGroupId();
GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId);
if (Strings.isBlank(modelGroupId)) {
indexRemoteModel(mlRegisterModelInput, mlTask, "1", listener);
}

client.get(getModelGroupRequest, ActionListener.wrap(getModelGroupResponse -> {
if (getModelGroupResponse.isExists()) {
Map<String, Object> modelGroupSourceMap = getModelGroupResponse.getSourceAsMap();
int updatedVersion = incrementLatestVersion(modelGroupSourceMap);
UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest(
modelGroupSourceMap,
modelGroupId,
getModelGroupResponse.getSeqNo(),
getModelGroupResponse.getPrimaryTerm(),
updatedVersion
);
client.update(updateModelGroupRequest, ActionListener.wrap(r -> {
indexRemoteModel(mlRegisterModelInput, mlTask, updatedVersion + "", listener);
}, e -> {
log.error("Failed to update model group " + modelGroupId, e);
handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e);
listener.onFailure(e);
}));
} else {
log.error("Model group response is empty");
handleException(
mlRegisterModelInput.getFunctionName(),
mlTask.getTaskId(),
new MLValidationException("Model group not found")
);
listener.onFailure(new MLResourceNotFoundException("Model Group Response is empty for " + modelGroupId));
}
}, error -> {
if (error instanceof IndexNotFoundException) {
log.error("Model group Index is missing");
handleException(
mlRegisterModelInput.getFunctionName(),
mlTask.getTaskId(),
new MLResourceNotFoundException("Failed to get model group due to index missing")
);
listener.onFailure(error);
} else {
log.error("Failed to get model group", error);
handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), error);
listener.onFailure(error);
}
}));
} catch (Exception e) {
log.error("Failed to register remote model", e);
handleException(mlRegisterModelInput.getFunctionName(), mlTask.getTaskId(), e);
listener.onFailure(e);
} finally {
mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).decrement();
}
}

/**
* Register model. Basically download model file, split into chunks and save into model index.
*
Expand All @@ -334,25 +402,19 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> {
if (modelGroup.isExists()) {
Map<String, Object> source = modelGroup.getSourceAsMap();
int latestVersion = (int) source.get(MLModelGroup.LATEST_VERSION_FIELD);
int newVersion = latestVersion + 1;
source.put(MLModelGroup.LATEST_VERSION_FIELD, newVersion);
source.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
UpdateRequest updateModelGroupRequest = new UpdateRequest();
long seqNo = modelGroup.getSeqNo();
long primaryTerm = modelGroup.getPrimaryTerm();
updateModelGroupRequest
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.setIfSeqNo(seqNo)
.setIfPrimaryTerm(primaryTerm)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.doc(source);
Map<String, Object> modelGroupSourceMap = modelGroup.getSourceAsMap();
int updatedVersion = incrementLatestVersion(modelGroupSourceMap);
UpdateRequest updateModelGroupRequest = createUpdateModelGroupRequest(
modelGroupSourceMap,
modelGroupId,
modelGroup.getSeqNo(),
modelGroup.getPrimaryTerm(),
updatedVersion
);
client
.update(
updateModelGroupRequest,
ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, newVersion + ""); }, e -> {
ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, updatedVersion + ""); }, e -> {
log.error("Failed to update model group", e);
handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e);
})
Expand Down Expand Up @@ -388,6 +450,95 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa
}
}

private UpdateRequest createUpdateModelGroupRequest(
Map<String, Object> modelGroupSourceMap,
String modelGroupId,
long seqNo,
long primaryTerm,
int updatedVersion
) {
modelGroupSourceMap.put(MLModelGroup.LATEST_VERSION_FIELD, updatedVersion);
modelGroupSourceMap.put(MLModelGroup.LAST_UPDATED_TIME_FIELD, Instant.now().toEpochMilli());
UpdateRequest updateModelGroupRequest = new UpdateRequest();

updateModelGroupRequest
.index(ML_MODEL_GROUP_INDEX)
.id(modelGroupId)
.setIfSeqNo(seqNo)
.setIfPrimaryTerm(primaryTerm)
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.doc(modelGroupSourceMap);

return updateModelGroupRequest;
}

private int incrementLatestVersion(Map<String, Object> modelGroupSourceMap) {
return (int) modelGroupSourceMap.get(MLModelGroup.LATEST_VERSION_FIELD) + 1;
}

private void indexRemoteModel(
MLRegisterModelInput registerModelInput,
MLTask mlTask,
String modelVersion,
ActionListener<MLRegisterModelResponse> listener
) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
String modelName = registerModelInput.getModelName();
String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion;
Instant now = Instant.now();
if (registerModelInput.getConnector() != null) {
registerModelInput.getConnector().encrypt(mlEngine::encrypt);
}

mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(boolResponse -> {
MLModel mlModelMeta = MLModel
.builder()
.name(modelName)
.algorithm(functionName)
.modelGroupId(registerModelInput.getModelGroupId())
.version(version)
.description(registerModelInput.getDescription())
.modelFormat(registerModelInput.getModelFormat())
.modelState(MLModelState.REGISTERED)
.connector(registerModelInput.getConnector())
.connectorId(registerModelInput.getConnectorId())
.modelConfig(registerModelInput.getModelConfig())
.createdTime(now)
.lastUpdateTime(now)
.build();

IndexRequest indexModelMetaRequest = new IndexRequest(ML_MODEL_INDEX);
indexModelMetaRequest.source(mlModelMeta.toXContent(XContentBuilder.builder(JSON.xContent()), EMPTY_PARAMS));
indexModelMetaRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// index remote model doc
ActionListener<IndexResponse> indexListener = ActionListener.wrap(modelMetaRes -> {
String modelId = modelMetaRes.getId();
mlTask.setModelId(modelId);
log.info("create new model meta doc {} for upload task {}", modelId, taskId);
mlTaskManager.updateMLTask(taskId, ImmutableMap.of(MODEL_ID_FIELD, modelId, STATE_FIELD, COMPLETED), 5000, true);
// if (registerModelInput.isDeployModel()) {
// deployModelAfterRegistering(registerModelInput, modelId);
// }
listener.onResponse(new MLRegisterModelResponse(taskId, MLTaskState.CREATED.name(), modelId));
}, e -> {
log.error("Failed to index model meta doc", e);
handleException(functionName, taskId, e);
listener.onFailure(e);
});

client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener));
}, error -> {
// failed to initialize the model index
log.error("Failed to init model index", error);
handleException(functionName, taskId, error);
listener.onFailure(error);
}));
}
}

private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask mlTask, String modelVersion) {
String taskId = mlTask.getTaskId();
FunctionName functionName = mlTask.getFunctionName();
Expand Down Expand Up @@ -431,7 +582,6 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml
log.error("Failed to index model meta doc", e);
handleException(functionName, taskId, e);
});

client.index(indexModelMetaRequest, threadedActionListener(REGISTER_THREAD_POOL, indexListener));
}, e -> {
log.error("Failed to init model index", e);
Expand Down
Loading

0 comments on commit 79a7829

Please sign in to comment.