From d65fd5eb2994759e79a9d0eacb7935c81bc3f63c Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 18 Oct 2023 22:00:51 -0700 Subject: [PATCH 1/5] Added initial implementation of create connector Signed-off-by: Owais Kazi --- .../flowframework/common/CommonValue.java | 14 +- .../workflow/CreateConnectorStep.java | 149 ++++++++++++++++++ .../workflow/RegisterModelStep.java | 8 +- .../workflow/CreateConnectorStepTests.java | 87 ++++++++++ 4 files changed, 252 insertions(+), 6 deletions(-) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 528590d7a..ce0df435a 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -48,8 +48,8 @@ private CommonValue() {} public static final String MODEL_ID = "model_id"; /** Function Name field */ public static final String FUNCTION_NAME = "function_name"; - /** Model Name field */ - public static final String MODEL_NAME = "name"; + /** Name field */ + public static final String NAME_FIELD = "name"; /** Model Version field */ public static final String MODEL_VERSION = "model_version"; /** Model Group Id field */ @@ -62,4 +62,14 @@ private CommonValue() {} public static final String MODEL_FORMAT = "model_format"; /** Model config field */ public static final String MODEL_CONFIG = "model_config"; + /** Version field */ + public static final String VERSION_FIELD = "version"; + /** Connector protocol field */ + public static final String PROTOCOL_FIELD = "protocol"; + /** Connector parameters field */ + public static final String PARAMETERS_FIELD = "parameters"; + /** Connector credentials field */ + public static final String CREDENTIALS_FIELD = "credentials"; + /** Connector actions field */ + public static final String ACTIONS_FIELD = "actions"; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java new file mode 100644 index 000000000..71096d4e0 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -0,0 +1,149 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.client.MLClient; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.*; + +public class CreateConnectorStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); + + private Client client; + + static final String NAME = "create_connector"; + + /** + * Instantiate this class + * @param client client to instantiate MLClient + */ + public CreateConnectorStep(Client client) { + this.client = client; + } + + @Override + public CompletableFuture execute(List data) throws IOException { + CompletableFuture createConnectorFuture = new CompletableFuture<>(); + + MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { + logger.info("Created connector successfully"); + createConnectorFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("connector-id", mlCreateConnectorResponse.getConnectorId()))) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to create connector"); + createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + } + }; + + String name = null; + String description = null; + String version = null; + String protocol = null; + Map parameters = new HashMap<>(); + Map credentials = new HashMap<>(); + List actions = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + name = (String) content.get(NAME_FIELD); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + case VERSION_FIELD: + version = (String) content.get(VERSION_FIELD); + break; + case PROTOCOL_FIELD: + protocol = (String) content.get(PROTOCOL_FIELD); + case PARAMETERS_FIELD: + parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); + case CREDENTIALS_FIELD: + credentials = (Map) content.get(CREDENTIALS_FIELD); + case ACTIONS_FIELD: + actions = (List) content.get(ACTIONS_FIELD); + } + + } + } + + if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { + MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder() + .name(name) + .description(description) + .version(version) + .protocol(protocol) + .parameters(parameters) + .credential(credentials) + .actions(actions) + .build(); + + machineLearningNodeClient.createConnector(mlInput, actionListener); + } + + return createConnectorFuture; + } + + @Override + public String getName() { + return NAME; + } + + private static Map getParameterMap(Map params) { + + Map parameters = new HashMap<>(); + for (String key : params.keySet()) { + String value = params.get(key); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + parameters.put(key, value); + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + return parameters; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index b97c56d57..cf880626b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -34,8 +34,8 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; /** * Step to register a remote model @@ -80,7 +80,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + registerModelFuture.completeExceptionally(new IOException("Failed to register model")); } }; @@ -101,8 +101,8 @@ public void onFailure(Exception e) { case FUNCTION_NAME: functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); break; - case MODEL_NAME: - modelName = (String) content.get(MODEL_NAME); + case NAME_FIELD: + modelName = (String) content.get(NAME_FIELD); break; case MODEL_VERSION: modelVersion = (String) content.get(MODEL_VERSION); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java new file mode 100644 index 000000000..98c31470e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -0,0 +1,87 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.client.Client; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.mockito.Answers; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class CreateConnectorStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock(answer = Answers.RETURNS_DEEP_STUBS) + private Client client; + + @Mock + ActionListener registerModelActionListener; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); + Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("version", "1"), + Map.entry("protocol", "test"), + Map.entry("params", params), + Map.entry("credentials", credentials), + Map.entry("actions", List.of("actions")) + ) + ); + + } + + public void testCreateConnector() throws IOException { + + String connectorId = "connect"; + CreateConnectorStep createConnectorStep = new CreateConnectorStep(client); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + } + +} From 8cebc8c98cf8acb6c8395dfb1148d26f688754dd Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 19 Oct 2023 13:02:37 -0700 Subject: [PATCH 2/5] Added test for create connector Signed-off-by: Owais Kazi --- src/main/java/demo/Demo.java | 4 +++- src/main/java/demo/TemplateParseDemo.java | 4 +++- .../flowframework/FlowFrameworkPlugin.java | 4 +++- .../workflow/CreateConnectorStep.java | 18 +++++++++--------- .../workflow/DeployModelStep.java | 14 +++++--------- .../workflow/RegisterModelStep.java | 14 +++++--------- .../workflow/WorkflowStepFactory.java | 12 +++++++----- .../workflow/CreateConnectorStepTests.java | 11 ++++------- .../workflow/DeployModelStepTests.java | 18 +++++------------- .../workflow/RegisterModelStepTests.java | 17 +++++------------ .../workflow/WorkflowProcessSorterTests.java | 4 +++- 11 files changed, 52 insertions(+), 68 deletions(-) diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 910f22b14..e4d2aa8f8 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -59,7 +60,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index a2d0db443..e284764da 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -55,7 +56,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 0bac15c61..b9a35c083 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -34,6 +34,7 @@ import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; @@ -76,7 +77,8 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); // TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 71096d4e0..afae055bc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -10,10 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; @@ -37,24 +35,22 @@ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "create_connector"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public CreateConnectorStep(Client client) { - this.client = client; + public CreateConnectorStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override public CompletableFuture execute(List data) throws IOException { CompletableFuture createConnectorFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override @@ -96,12 +92,16 @@ public void onFailure(Exception e) { break; case PROTOCOL_FIELD: protocol = (String) content.get(PROTOCOL_FIELD); + break; case PARAMETERS_FIELD: parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); + break; case CREDENTIALS_FIELD: credentials = (Map) content.get(CREDENTIALS_FIELD); + break; case ACTIONS_FIELD: actions = (List) content.get(ACTIONS_FIELD); + break; } } @@ -118,7 +118,7 @@ public void onFailure(Exception e) { .actions(actions) .build(); - machineLearningNodeClient.createConnector(mlInput, actionListener); + mlClient.createConnector(mlInput, actionListener); } return createConnectorFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index e4c9b1a14..07558fe0c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -10,9 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -28,15 +26,15 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "deploy_model"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public DeployModelStep(Client client) { - this.client = client; + public DeployModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override @@ -44,8 +42,6 @@ public CompletableFuture execute(List data) { CompletableFuture deployModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @@ -70,7 +66,7 @@ public void onFailure(Exception e) { break; } } - machineLearningNodeClient.deploy(modelId, actionListener); + mlClient.deploy(modelId, actionListener); return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index cf880626b..bdbda66e4 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -10,9 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; @@ -44,16 +42,16 @@ public class RegisterModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "register_model"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public RegisterModelStep(Client client) { - this.client = client; + public RegisterModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override @@ -61,8 +59,6 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @@ -139,7 +135,7 @@ public void onFailure(Exception e) { .connectorId(connectorId) .build(); - machineLearningNodeClient.register(mlInput, actionListener); + mlClient.register(mlInput, actionListener); } return registerModelFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index fdb82ef0b..be52a5fcd 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -10,6 +10,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; import java.util.List; @@ -32,15 +33,16 @@ public class WorkflowStepFactory { * @param client The OpenSearch client steps can use */ - public WorkflowStepFactory(ClusterService clusterService, Client client) { - populateMap(clusterService, client); + public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { + populateMap(clusterService, client, mlClient); } - private void populateMap(ClusterService clusterService, Client client) { + private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 98c31470e..65329661f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -8,7 +8,6 @@ */ package org.opensearch.flowframework.workflow; -import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; @@ -20,7 +19,6 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -32,9 +30,6 @@ public class CreateConnectorStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private Client client; - @Mock ActionListener registerModelActionListener; @@ -67,12 +62,12 @@ public void setUp() throws Exception { public void testCreateConnector() throws IOException { String connectorId = "connect"; - CreateConnectorStep createConnectorStep = new CreateConnectorStep(client); + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); actionListener.onResponse(output); return null; @@ -82,6 +77,8 @@ public void testCreateConnector() throws IOException { verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + assertTrue(future.isDone()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index fc7c695f8..e32a7c75f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -10,35 +10,30 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class DeployModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private NodeClient nodeClient; - @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -50,8 +45,6 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - nodeClient = new NoOpNodeClient("xyz"); - } public void testDeployModel() { @@ -60,13 +53,13 @@ public void testDeployModel() { String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; - DeployModelStep deployModel = new DeployModelStep(nodeClient); + DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; @@ -74,10 +67,9 @@ public void testDeployModel() { CompletableFuture future = deployModel.execute(List.of(inputData)); - // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index b1a2b2fc0..1e8026a15 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -10,7 +10,6 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; @@ -20,28 +19,24 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private NodeClient nodeClient; - @Mock ActionListener registerModelActionListener; @@ -70,7 +65,6 @@ public void setUp() throws Exception { ) ); - nodeClient = new NoOpNodeClient("xyz"); } public void testRegisterModel() throws ExecutionException, InterruptedException { @@ -85,12 +79,12 @@ public void testRegisterModel() throws ExecutionException, InterruptedException .connectorId("abcdefgh") .build(); - RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); + RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); actionListener.onResponse(output); return null; @@ -98,10 +92,9 @@ public void testRegisterModel() throws ExecutionException, InterruptedException CompletableFuture future = registerModelStep.execute(List.of(inputData)); - // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e8ada0e15..f728dd7b1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -60,11 +61,12 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); } From 3f7dfb6bedcfae29f033ae4527ef338aea8ed680 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Thu, 19 Oct 2023 14:51:56 -0700 Subject: [PATCH 3/5] Added more tests and updated MLClient initialization Signed-off-by: Owais Kazi --- .../flowframework/client/MLClient.java | 34 ------------------- .../workflow/CreateConnectorStep.java | 13 +++++-- .../workflow/DeployModelStep.java | 6 ++-- .../workflow/RegisterModelStep.java | 5 +-- .../workflow/WorkflowStepFactory.java | 1 + .../workflow/CreateConnectorStepTests.java | 28 ++++++++++++++- .../workflow/DeployModelStepTests.java | 28 +++++++++++++-- .../workflow/RegisterModelStepTests.java | 25 ++++++++++++++ 8 files changed, 97 insertions(+), 43 deletions(-) delete mode 100644 src/main/java/org/opensearch/flowframework/client/MLClient.java diff --git a/src/main/java/org/opensearch/flowframework/client/MLClient.java b/src/main/java/org/opensearch/flowframework/client/MLClient.java deleted file mode 100644 index 977e24588..000000000 --- a/src/main/java/org/opensearch/flowframework/client/MLClient.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.client; - -import org.opensearch.client.Client; -import org.opensearch.ml.client.MachineLearningNodeClient; - -/** - Class to initiate an instance of MLClient - */ -public class MLClient { - private static MachineLearningNodeClient INSTANCE; - - private MLClient() {} - - /** - * Creates machine learning client. - * - * @param client client of OpenSearch. - * @return machine learning client from ml-commons. - */ - public static MachineLearningNodeClient createMLClient(Client client) { - if (INSTANCE == null) { - INSTANCE = new MachineLearningNodeClient(client); - } - return INSTANCE; - } -} diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index afae055bc..5e040baaf 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -29,8 +29,17 @@ import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; -import static org.opensearch.flowframework.common.CommonValue.*; - +import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; + +/** + * Step to create a connector for a remote model + */ public class CreateConnectorStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 07558fe0c..9dd183e37 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -53,8 +55,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @Override public void onFailure(Exception e) { - logger.error("Model deployment failed"); - deployModelFuture.completeExceptionally(e); + logger.error("Failed to deploy model"); + deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } }; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index bdbda66e4..9a430c6f5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; @@ -18,7 +20,6 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.Map; @@ -76,7 +77,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new IOException("Failed to register model")); + registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } }; diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index be52a5fcd..5aabd679f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -31,6 +31,7 @@ public class WorkflowStepFactory { * * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use + * @param mlClient Machine Learning client to perform ml operations */ public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 65329661f..9fd125a1a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -9,6 +9,8 @@ package org.opensearch.flowframework.workflow; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -18,11 +20,13 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import static org.junit.Assert.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; @@ -59,7 +63,7 @@ public void setUp() throws Exception { } - public void testCreateConnector() throws IOException { + public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { String connectorId = "connect"; CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); @@ -78,7 +82,29 @@ public void testCreateConnector() throws IOException { verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); assertTrue(future.isDone()); + assertEquals(connectorId, future.get().getContent().get("connector-id")); } + public void testCreateConnectorFailure() throws IOException { + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to create connector", ex.getCause().getMessage()); + } + } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index e32a7c75f..4cdfaebae 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -11,6 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; @@ -20,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; import org.mockito.Mock; @@ -47,8 +50,7 @@ public void setUp() throws Exception { } - public void testDeployModel() { - + public void testDeployModel() throws ExecutionException, InterruptedException { String taskId = "taskId"; String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; @@ -70,6 +72,28 @@ public void testDeployModel() { verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); assertTrue(future.isDone()); + assertEquals(status, future.get().getContent().get("deploy_model_status")); + } + + public void testDeployModelFailure() { + DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to deploy model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + CompletableFuture future = deployModel.execute(List.of(inputData)); + + verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to deploy model", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index 1e8026a15..59fb1b173 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -11,6 +11,8 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTaskState; @@ -95,7 +97,30 @@ public void testRegisterModel() throws ExecutionException, InterruptedException verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get("model_id")); + assertEquals(status, future.get().getContent().get("register_model_status")); } + public void testRegisterModelFailure() { + RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerModelStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register model", ex.getCause().getMessage()); + } + } From ae5c15538f9d487b14df93d038bd3ef51284cc62 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Fri, 20 Oct 2023 15:46:55 -0700 Subject: [PATCH 4/5] Addressed PR comments Signed-off-by: Owais Kazi --- .../flowframework/workflow/CreateConnectorStep.java | 7 ++++--- .../opensearch/flowframework/workflow/DeployModelStep.java | 4 ++-- .../flowframework/workflow/RegisterModelStep.java | 4 ++-- .../flowframework/workflow/CreateConnectorStepTests.java | 2 +- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 5e040baaf..6c7f22d1a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -10,8 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; @@ -65,15 +65,16 @@ public CompletableFuture execute(List data) throws I @Override public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { logger.info("Created connector successfully"); + // TODO Add the response to Global Context createConnectorFuture.complete( - new WorkflowData(Map.ofEntries(Map.entry("connector-id", mlCreateConnectorResponse.getConnectorId()))) + new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) ); } @Override public void onFailure(Exception e) { logger.error("Failed to create connector"); - createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 9dd183e37..d43e7a28b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -10,8 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -56,7 +56,7 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to deploy model"); - deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 9a430c6f5..3f60e7bb7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -10,8 +10,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; @@ -77,7 +77,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 9fd125a1a..bf170ab9f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -82,7 +82,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); assertTrue(future.isDone()); - assertEquals(connectorId, future.get().getContent().get("connector-id")); + assertEquals(connectorId, future.get().getContent().get("connector_id")); } From 6163fc2c1cbcf9b3a93808c4b73018ef85a509e8 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Wed, 25 Oct 2023 10:06:25 -0700 Subject: [PATCH 5/5] CompletedFuture exceptionally if fields are not present Signed-off-by: Owais Kazi --- .../flowframework/workflow/CreateConnectorStep.java | 5 +++++ .../flowframework/workflow/DeployModelStep.java | 9 ++++++++- .../flowframework/workflow/RegisterModelStep.java | 5 +++++ 3 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 6c7f22d1a..dff6ac22a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; @@ -129,6 +130,10 @@ public void onFailure(Exception e) { .build(); mlClient.createConnector(mlInput, actionListener); + } else { + createConnectorFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); } return createConnectorFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index d43e7a28b..ba22f3682 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -68,7 +69,13 @@ public void onFailure(Exception e) { break; } } - mlClient.deploy(modelId, actionListener); + + if (modelId != null) { + mlClient.deploy(modelId, actionListener); + } else { + deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST)); + } + return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index 3f60e7bb7..df14d6c54 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; @@ -137,6 +138,10 @@ public void onFailure(Exception e) { .build(); mlClient.register(mlInput, actionListener); + } else { + registerModelFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); } return registerModelFuture;