Skip to content

Commit

Permalink
Separated Register and Deploy Steps
Browse files Browse the repository at this point in the history
Signed-off-by: Owais Kazi <[email protected]>
  • Loading branch information
owaiskazi19 committed Oct 3, 2023
1 parent a126c47 commit 4404633
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 117 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,74 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

public class DeployModel {
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;

public class DeployModel implements WorkflowStep {
private static final Logger logger = LogManager.getLogger(DeployModel.class);

public void deployModel(MachineLearningNodeClient machineLearningNodeClient, String modelId) {
private NodeClient nodeClient;
private static final String MODEL_ID = "model_id";
static final String NAME = "deploy_model";

public DeployModel(Client client) {
this.nodeClient = (NodeClient) client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> deployModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient);

ActionListener<MLDeployModelResponse> actionListener = new ActionListener<>() {
@Override
public void onResponse(MLDeployModelResponse mlDeployModelResponse) {
if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) {
logger.info("Model deployed successfully");
deployModelFuture.complete(
new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus())))
);
}
}

@Override
public void onFailure(Exception e) {
logger.error("Model deployment failed");
deployModelFuture.completeExceptionally(e);
}
};

String modelId = null;

for (WorkflowData workflowData : data) {
if (workflowData != null) {
Map<String, Object> content = workflowData.getContent();

for (Map.Entry<String, Object> entry : content.entrySet()) {
if (entry.getKey() == MODEL_ID) {
modelId = (String) content.get(MODEL_ID);
}

}
}
}
machineLearningNodeClient.deploy(modelId, actionListener);
return deployModelFuture;
}

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,44 +21,43 @@
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

public class RegisterAndDeployModelStep implements WorkflowStep {
public class RegisterModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(RegisterAndDeployModelStep.class);
private static final Logger logger = LogManager.getLogger(RegisterModelStep.class);

private Client client;
private ThreadPool threadPool;
private NodeClient nodeClient;
private volatile Scheduler.Cancellable scheduledFuture;

static final String NAME = "register_model_step";
static final String NAME = "register_model";

private static final String FUNCTION_NAME = "function_name";
private static final String MODEL_NAME = "model_name";
private static final String MODEL_NAME = "name";
private static final String MODEL_VERSION = "model_version";
private static final String MODEL_GROUP_ID = "model_group_id";
private static final String DESCRIPTION = "description";
private static final String CONNECTOR_ID = "connector_id";
private static final String MODEL_FORMAT = "model_format";
private static final String MODEL_CONFIG = "model_config";

public RegisterAndDeployModelStep(Client client) {
this.client = client;
public RegisterModelStep(Client client) {
this.nodeClient = (NodeClient) client;
}

@Override
public CompletableFuture<WorkflowData> execute(List<WorkflowData> data) {

CompletableFuture<WorkflowData> registerModelFuture = new CompletableFuture<>();

MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client);
MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient);

ActionListener<MLRegisterModelResponse> actionListener = new ActionListener<>() {
@Override
Expand All @@ -85,9 +84,17 @@ public void onFailure(Exception e) {
// scheduledFuture = threadPool.scheduleWithFixedDelay(new GetTask(machineLearningNodeClient,
// mlRegisterModelResponse.getTaskId()), TimeValue.timeValueMillis(10L), ThreadPool.Names.GENERIC);

DeployModel deployModel = new DeployModel();
deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId());

/*DeployModel deployModel = new DeployModel();
deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId());*/
logger.info("Model registration successful");
registerModelFuture.complete(
new WorkflowData(
Map.ofEntries(
Map.entry("modelId", mlRegisterModelResponse.getModelId()),
Map.entry("model-register-status", mlRegisterModelResponse.getStatus())
)
)
);
}

@Override
Expand All @@ -107,53 +114,50 @@ public void onFailure(Exception e) {
MLModelConfig modelConfig = null;

for (WorkflowData workflowData : data) {
Map<String, String> parameters = workflowData.getParams();
Map<String, Object> content = workflowData.getContent();
logger.info("Previous step sent params: {}, content: {}", parameters, content);

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case FUNCTION_NAME:
functionName = (FunctionName) content.get(FUNCTION_NAME);
break;
case MODEL_NAME:
modelName = (String) content.get(MODEL_NAME);
break;
case MODEL_VERSION:
modelVersion = (String) content.get(MODEL_VERSION);
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case MODEL_FORMAT:
modelFormat = (MLModelFormat) content.get(MODEL_FORMAT);
break;
case MODEL_CONFIG:
modelConfig = (MLModelConfig) content.get(MODEL_CONFIG);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
break;
default:
break;
if (workflowData != null) {
Map<String, Object> content = workflowData.getContent();
logger.info("Previous step sent content: {}", content);

for (Entry<String, Object> entry : content.entrySet()) {
switch (entry.getKey()) {
case FUNCTION_NAME:
functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT));
break;
case MODEL_NAME:
modelName = (String) content.get(MODEL_NAME);
break;
case MODEL_VERSION:
modelVersion = (String) content.get(MODEL_VERSION);
break;
case MODEL_GROUP_ID:
modelGroupId = (String) content.get(MODEL_GROUP_ID);
break;
case MODEL_FORMAT:
modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT));
break;
case MODEL_CONFIG:
modelConfig = (MLModelConfig) content.get(MODEL_CONFIG);
break;
case DESCRIPTION:
description = (String) content.get(DESCRIPTION);
break;
case CONNECTOR_ID:
connectorId = (String) content.get(CONNECTOR_ID);
break;
default:
break;

}
}
}
}

if (Stream.of(functionName, modelName, modelVersion, modelGroupId, description, connectorId).allMatch(x -> x != null)) {
if (Stream.of(functionName, modelName, description, connectorId).allMatch(x -> x != null)) {

// TODO: Add model Config and type cast correctly
MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.modelName(modelName)
.version(modelVersion)
.modelGroupId(modelGroupId)
.modelFormat(modelFormat)
.modelConfig(modelConfig)
.description(description)
.connectorId(connectorId)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ private WorkflowStepFactory(Client client) {
private void populateMap(Client client) {
stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(client));
stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client));
stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client));
stepMap.put(DeployModel.NAME, new DeployModel(client));

// TODO: These are from the demo class as placeholders, remove when demos are deleted
stepMap.put("demo_delay_3", new DemoWorkflowStep(3000));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,19 @@
*/
package org.opensearch.flowframework.workflow;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;

import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.client.MLClient;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.model.MLModelConfig;
import org.opensearch.ml.common.model.MLModelFormat;
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.client.NoOpNodeClient;

import java.util.List;
import java.util.Map;
Expand All @@ -30,13 +31,15 @@

import static org.mockito.Mockito.*;

public class RegisterAndDeployModelStepTests extends OpenSearchTestCase {
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class RegisterModelStepTests extends OpenSearchTestCase {
private WorkflowData inputData = WorkflowData.EMPTY;

@Mock(answer = Answers.RETURNS_DEEP_STUBS)
private NodeClient nodeClient;

private MachineLearningNodeClient machineLearningNodeClient;
@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
Expand All @@ -49,68 +52,41 @@ public void setUp() throws Exception {
.embeddingDimension(100)
.build();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(
Map.of(
"function_name",
FunctionName.KMEANS,
"model_name",
"bedrock",
"model_version",
"1.0.0",
"model_group_id",
"1.0",
"model_format",
MLModelFormat.TORCH_SCRIPT,
"model_config",
config,
"description",
"description",
"connector_id",
"abcdefgh"
Map.ofEntries(
Map.entry("function_name", "remote"),
Map.entry("name", "xyz"),
Map.entry("description", "description"),
Map.entry("connector_id", "abcdefg")
)
);

nodeClient = mock(NodeClient.class);

nodeClient = new NoOpNodeClient("xyz");
}

public void testRegisterModel() throws ExecutionException, InterruptedException {

FunctionName functionName = FunctionName.KMEANS;

MLModelConfig config = TextEmbeddingModelConfig.builder()
.modelType("testModelType")
.allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}")
.frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
.embeddingDimension(100)
.build();

MLRegisterModelInput mlInput = MLRegisterModelInput.builder()
.functionName(functionName)
.functionName(FunctionName.from("REMOTE"))
.modelName("testModelName")
.version("testModelVersion")
.modelGroupId("modelGroupId")
.url("url")
.modelFormat(MLModelFormat.ONNX)
.modelConfig(config)
.description("description")
.connectorId("abcdefgh")
.build();

RegisterAndDeployModelStep registerModelStep = new RegisterAndDeployModelStep(nodeClient);
RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient);

ArgumentCaptor<ActionListener<MLRegisterModelResponse>> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class);
CompletableFuture<WorkflowData> future = registerModelStep.execute(List.of(inputData));

assertFalse(future.isDone());

/*try (MockedStatic<MLClient> mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) {
mlClientMockedStatic
.when(() -> MLClient.createMLClient(any(NodeClient.class)))
.thenReturn(machineLearningNodeClient);
}*/
when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient);
// when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient);
verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture());
actionListenerCaptor.getValue().onResponse(new MLRegisterModelResponse("xyz", MLTaskState.COMPLETED.name(), "abc"));

Expand Down
Loading

0 comments on commit 4404633

Please sign in to comment.