From 923fd0fdad65c26b7ee4401919f4225de59f67d6 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Tue, 3 Oct 2023 10:56:08 -0700 Subject: [PATCH] Separated Register and Deploy Steps Signed-off-by: Owais Kazi --- .../flowframework/workflow/DeployModel.java | 48 +++++++- ...yModelStep.java => RegisterModelStep.java} | 104 +++++++++--------- .../workflow/WorkflowStepFactory.java | 2 + ...Tests.java => RegisterModelStepTests.java} | 60 +++------- src/test/resources/template/demo.json | 40 +++---- .../resources/template/finaltemplate.json | 3 +- 6 files changed, 140 insertions(+), 117 deletions(-) rename src/main/java/org/opensearch/flowframework/workflow/{RegisterAndDeployModelStep.java => RegisterModelStep.java} (63%) rename src/test/java/org/opensearch/flowframework/workflow/{RegisterAndDeployModelStepTests.java => RegisterModelStepTests.java} (64%) diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java index 8bb6ec232..65ba93e58 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModel.java @@ -10,30 +10,74 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -public class DeployModel { +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public class DeployModel implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModel.class); - public void deployModel(MachineLearningNodeClient machineLearningNodeClient, String modelId) { + private NodeClient nodeClient; + private static final String MODEL_ID = "model_id"; + static final String NAME = "deploy_model"; + + public DeployModel(Client client) { + this.nodeClient = (NodeClient) client; + } + + @Override + public CompletableFuture execute(List data) { + + CompletableFuture deployModelFuture = new CompletableFuture<>(); + + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { if (mlDeployModelResponse.getStatus() == MLTaskState.COMPLETED.name()) { logger.info("Model deployed successfully"); + deployModelFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("deploy-model-status", mlDeployModelResponse.getStatus()))) + ); } } @Override public void onFailure(Exception e) { logger.error("Model deployment failed"); + deployModelFuture.completeExceptionally(e); } }; + + String modelId = null; + + for (WorkflowData workflowData : data) { + if (workflowData != null) { + Map content = workflowData.getContent(); + + for (Map.Entry entry : content.entrySet()) { + if (entry.getKey() == MODEL_ID) { + modelId = (String) content.get(MODEL_ID); + } + + } + } + } machineLearningNodeClient.deploy(modelId, actionListener); + return deployModelFuture; + } + @Override + public String getName() { + return NAME; } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java similarity index 63% rename from src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java rename to src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index ef2e4e59c..7d11e2f73 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -21,27 +21,26 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.threadpool.Scheduler; -import org.opensearch.threadpool.ThreadPool; import java.io.IOException; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Map.Entry; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; -public class RegisterAndDeployModelStep implements WorkflowStep { +public class RegisterModelStep implements WorkflowStep { - private static final Logger logger = LogManager.getLogger(RegisterAndDeployModelStep.class); + private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); - private Client client; - private ThreadPool threadPool; + private NodeClient nodeClient; private volatile Scheduler.Cancellable scheduledFuture; - static final String NAME = "register_model_step"; + static final String NAME = "register_model"; private static final String FUNCTION_NAME = "function_name"; - private static final String MODEL_NAME = "model_name"; + private static final String MODEL_NAME = "name"; private static final String MODEL_VERSION = "model_version"; private static final String MODEL_GROUP_ID = "model_group_id"; private static final String DESCRIPTION = "description"; @@ -49,8 +48,8 @@ public class RegisterAndDeployModelStep implements WorkflowStep { private static final String MODEL_FORMAT = "model_format"; private static final String MODEL_CONFIG = "model_config"; - public RegisterAndDeployModelStep(Client client) { - this.client = client; + public RegisterModelStep(Client client) { + this.nodeClient = (NodeClient) client; } @Override @@ -58,7 +57,7 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient((NodeClient) client); + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(nodeClient); ActionListener actionListener = new ActionListener<>() { @Override @@ -85,9 +84,17 @@ public void onFailure(Exception e) { // scheduledFuture = threadPool.scheduleWithFixedDelay(new GetTask(machineLearningNodeClient, // mlRegisterModelResponse.getTaskId()), TimeValue.timeValueMillis(10L), ThreadPool.Names.GENERIC); - DeployModel deployModel = new DeployModel(); - deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId()); - + /*DeployModel deployModel = new DeployModel(); + deployModel.deployModel(machineLearningNodeClient, mlRegisterModelResponse.getModelId());*/ + logger.info("Model registration successful"); + registerModelFuture.complete( + new WorkflowData( + Map.ofEntries( + Map.entry("modelId", mlRegisterModelResponse.getModelId()), + Map.entry("model-register-status", mlRegisterModelResponse.getStatus()) + ) + ) + ); } @Override @@ -107,53 +114,50 @@ public void onFailure(Exception e) { MLModelConfig modelConfig = null; for (WorkflowData workflowData : data) { - Map parameters = workflowData.getParams(); - Map content = workflowData.getContent(); - logger.info("Previous step sent params: {}, content: {}", parameters, content); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case FUNCTION_NAME: - functionName = (FunctionName) content.get(FUNCTION_NAME); - break; - case MODEL_NAME: - modelName = (String) content.get(MODEL_NAME); - break; - case MODEL_VERSION: - modelVersion = (String) content.get(MODEL_VERSION); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case MODEL_FORMAT: - modelFormat = (MLModelFormat) content.get(MODEL_FORMAT); - break; - case MODEL_CONFIG: - modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); - break; - case DESCRIPTION: - description = (String) content.get(DESCRIPTION); - break; - case CONNECTOR_ID: - connectorId = (String) content.get(CONNECTOR_ID); - break; - default: - break; + if (workflowData != null) { + Map content = workflowData.getContent(); + logger.info("Previous step sent content: {}", content); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case FUNCTION_NAME: + functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + break; + case MODEL_NAME: + modelName = (String) content.get(MODEL_NAME); + break; + case MODEL_VERSION: + modelVersion = (String) content.get(MODEL_VERSION); + break; + case MODEL_GROUP_ID: + modelGroupId = (String) content.get(MODEL_GROUP_ID); + break; + case MODEL_FORMAT: + modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); + break; + case MODEL_CONFIG: + modelConfig = (MLModelConfig) content.get(MODEL_CONFIG); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + case CONNECTOR_ID: + connectorId = (String) content.get(CONNECTOR_ID); + break; + default: + break; + } } } } - if (Stream.of(functionName, modelName, modelVersion, modelGroupId, description, connectorId).allMatch(x -> x != null)) { + if (Stream.of(functionName, modelName, description, connectorId).allMatch(x -> x != null)) { // TODO: Add model Config and type cast correctly MLRegisterModelInput mlInput = MLRegisterModelInput.builder() .functionName(functionName) .modelName(modelName) - .version(modelVersion) - .modelGroupId(modelGroupId) - .modelFormat(modelFormat) - .modelConfig(modelConfig) .description(description) .connectorId(connectorId) .build(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 26dab0f42..aea2ac7d5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -36,6 +36,8 @@ public WorkflowStepFactory(Client client) { private void populateMap(Client client) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); + stepMap.put(DeployModel.NAME, new DeployModel(client)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java similarity index 64% rename from src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 5d3ba1972..13bf9b30c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAndDeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -8,18 +8,19 @@ */ package org.opensearch.flowframework.workflow; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; + import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; -import org.opensearch.ml.common.model.MLModelFormat; import org.opensearch.ml.common.model.TextEmbeddingModelConfig; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; @@ -30,13 +31,15 @@ import static org.mockito.Mockito.*; -public class RegisterAndDeployModelStepTests extends OpenSearchTestCase { +@ThreadLeakScope(ThreadLeakScope.Scope.NONE) +public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; @Mock(answer = Answers.RETURNS_DEEP_STUBS) private NodeClient nodeClient; - private MachineLearningNodeClient machineLearningNodeClient; + @Mock + MachineLearningNodeClient machineLearningNodeClient; @Override public void setUp() throws Exception { @@ -49,68 +52,41 @@ public void setUp() throws Exception { .embeddingDimension(100) .build(); + MockitoAnnotations.openMocks(this); + inputData = new WorkflowData( - Map.of( - "function_name", - FunctionName.KMEANS, - "model_name", - "bedrock", - "model_version", - "1.0.0", - "model_group_id", - "1.0", - "model_format", - MLModelFormat.TORCH_SCRIPT, - "model_config", - config, - "description", - "description", - "connector_id", - "abcdefgh" + Map.ofEntries( + Map.entry("function_name", "remote"), + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry("connector_id", "abcdefg") ) ); - nodeClient = mock(NodeClient.class); - + nodeClient = new NoOpNodeClient("xyz"); } public void testRegisterModel() throws ExecutionException, InterruptedException { - FunctionName functionName = FunctionName.KMEANS; - - MLModelConfig config = TextEmbeddingModelConfig.builder() - .modelType("testModelType") - .allConfig("{\"field1\":\"value1\",\"field2\":\"value2\"}") - .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) - .embeddingDimension(100) - .build(); - MLRegisterModelInput mlInput = MLRegisterModelInput.builder() - .functionName(functionName) + .functionName(FunctionName.from("REMOTE")) .modelName("testModelName") - .version("testModelVersion") - .modelGroupId("modelGroupId") - .url("url") - .modelFormat(MLModelFormat.ONNX) - .modelConfig(config) .description("description") .connectorId("abcdefgh") .build(); - RegisterAndDeployModelStep registerModelStep = new RegisterAndDeployModelStep(nodeClient); + RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); CompletableFuture future = registerModelStep.execute(List.of(inputData)); - assertFalse(future.isDone()); - /*try (MockedStatic mlClientMockedStatic = Mockito.mockStatic(MLClient.class)) { mlClientMockedStatic .when(() -> MLClient.createMLClient(any(NodeClient.class))) .thenReturn(machineLearningNodeClient); }*/ - when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient); + // when(spy(MLClient.createMLClient(nodeClient))).thenReturn(machineLearningNodeClient); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new MLRegisterModelResponse("xyz", MLTaskState.COMPLETED.name(), "abc")); diff --git a/src/test/resources/template/demo.json b/src/test/resources/template/demo.json index e27158bff..103afb92a 100644 --- a/src/test/resources/template/demo.json +++ b/src/test/resources/template/demo.json @@ -9,37 +9,33 @@ "nodes": [ { "id": "fetch_model", - "type": "demo_delay_3" + "type": "demo_delay_3", + "inputs": { + "ingest_key": "ingest_value" + } }, { - "id": "create_ingest_pipeline", - "type": "demo_delay_3" + "id": "register_model", + "type": "register_model", + "inputs": { + "name": "openAI-gpt-3.5-turbo", + "function_name": "remote", + "description": "test model", + "connector_id": "uDna54oB76l1MtYJF84U" + } }, { - "id": "create_search_pipeline", - "type": "demo_delay_5" - }, - { - "id": "create_neural_search_index", - "type": "demo_delay_3" + "id": "deploy_model", + "type": "deploy_model", + "inputs": { + "model_id": "abc" + } } ], "edges": [ { "source": "fetch_model", - "dest": "create_ingest_pipeline" - }, - { - "source": "fetch_model", - "dest": "create_search_pipeline" - }, - { - "source": "create_ingest_pipeline", - "dest": "create_neural_search_index" - }, - { - "source": "create_search_pipeline", - "dest": "create_neural_search_index" + "dest": "deploy_model" } ] } diff --git a/src/test/resources/template/finaltemplate.json b/src/test/resources/template/finaltemplate.json index 88a7425f3..689547d8d 100644 --- a/src/test/resources/template/finaltemplate.json +++ b/src/test/resources/template/finaltemplate.json @@ -16,7 +16,8 @@ }, "user_inputs": { "index_name": "my-knn-index", - "index_settings": {} + "index_settings": {}, + }, "workflows": { "provision": {