diff --git a/src/main/java/demo/Demo.java b/src/main/java/demo/Demo.java index 910f22b14..e4d2aa8f8 100644 --- a/src/main/java/demo/Demo.java +++ b/src/main/java/demo/Demo.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.workflow.ProcessNode; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -59,7 +60,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/demo/TemplateParseDemo.java b/src/main/java/demo/TemplateParseDemo.java index a2d0db443..e284764da 100644 --- a/src/main/java/demo/TemplateParseDemo.java +++ b/src/main/java/demo/TemplateParseDemo.java @@ -20,6 +20,7 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -55,7 +56,8 @@ public static void main(String[] args) throws IOException { } ClusterService clusterService = new ClusterService(null, null, null); Client client = new NodeClient(null, null); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); ThreadPool threadPool = new ThreadPool(Settings.EMPTY); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(factory, threadPool); diff --git a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java index 0bac15c61..b9a35c083 100644 --- a/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java +++ b/src/main/java/org/opensearch/flowframework/FlowFrameworkPlugin.java @@ -34,6 +34,7 @@ import org.opensearch.flowframework.workflow.CreateIndexStep; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.repositories.RepositoriesService; @@ -76,7 +77,8 @@ public Collection createComponents( IndexNameExpressionResolver indexNameExpressionResolver, Supplier repositoriesServiceSupplier ) { - WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client); + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + WorkflowStepFactory workflowStepFactory = new WorkflowStepFactory(clusterService, client, mlClient); WorkflowProcessSorter workflowProcessSorter = new WorkflowProcessSorter(workflowStepFactory, threadPool); // TODO : Refactor, move system index creation/associated methods outside of the CreateIndexStep diff --git a/src/main/java/org/opensearch/flowframework/client/MLClient.java b/src/main/java/org/opensearch/flowframework/client/MLClient.java deleted file mode 100644 index 977e24588..000000000 --- a/src/main/java/org/opensearch/flowframework/client/MLClient.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.flowframework.client; - -import org.opensearch.client.Client; -import org.opensearch.ml.client.MachineLearningNodeClient; - -/** - Class to initiate an instance of MLClient - */ -public class MLClient { - private static MachineLearningNodeClient INSTANCE; - - private MLClient() {} - - /** - * Creates machine learning client. - * - * @param client client of OpenSearch. - * @return machine learning client from ml-commons. - */ - public static MachineLearningNodeClient createMLClient(Client client) { - if (INSTANCE == null) { - INSTANCE = new MachineLearningNodeClient(client); - } - return INSTANCE; - } -} diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 528590d7a..ce0df435a 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -48,8 +48,8 @@ private CommonValue() {} public static final String MODEL_ID = "model_id"; /** Function Name field */ public static final String FUNCTION_NAME = "function_name"; - /** Model Name field */ - public static final String MODEL_NAME = "name"; + /** Name field */ + public static final String NAME_FIELD = "name"; /** Model Version field */ public static final String MODEL_VERSION = "model_version"; /** Model Group Id field */ @@ -62,4 +62,14 @@ private CommonValue() {} public static final String MODEL_FORMAT = "model_format"; /** Model config field */ public static final String MODEL_CONFIG = "model_config"; + /** Version field */ + public static final String VERSION_FIELD = "version"; + /** Connector protocol field */ + public static final String PROTOCOL_FIELD = "protocol"; + /** Connector parameters field */ + public static final String PARAMETERS_FIELD = "parameters"; + /** Connector credentials field */ + public static final String CREDENTIALS_FIELD = "credentials"; + /** Connector actions field */ + public static final String ACTIONS_FIELD = "actions"; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java new file mode 100644 index 000000000..dff6ac22a --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -0,0 +1,164 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.connector.ConnectorAction; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; + +import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIALS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; + +/** + * Step to create a connector for a remote model + */ +public class CreateConnectorStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(CreateConnectorStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "create_connector"; + + /** + * Instantiate this class + * @param mlClient client to instantiate MLClient + */ + public CreateConnectorStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute(List data) throws IOException { + CompletableFuture createConnectorFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { + logger.info("Created connector successfully"); + // TODO Add the response to Global Context + createConnectorFuture.complete( + new WorkflowData(Map.ofEntries(Map.entry("connector_id", mlCreateConnectorResponse.getConnectorId()))) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to create connector"); + createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String name = null; + String description = null; + String version = null; + String protocol = null; + Map parameters = new HashMap<>(); + Map credentials = new HashMap<>(); + List actions = null; + + for (WorkflowData workflowData : data) { + Map content = workflowData.getContent(); + + for (Entry entry : content.entrySet()) { + switch (entry.getKey()) { + case NAME_FIELD: + name = (String) content.get(NAME_FIELD); + break; + case DESCRIPTION: + description = (String) content.get(DESCRIPTION); + break; + case VERSION_FIELD: + version = (String) content.get(VERSION_FIELD); + break; + case PROTOCOL_FIELD: + protocol = (String) content.get(PROTOCOL_FIELD); + break; + case PARAMETERS_FIELD: + parameters = getParameterMap((Map) content.get(PARAMETERS_FIELD)); + break; + case CREDENTIALS_FIELD: + credentials = (Map) content.get(CREDENTIALS_FIELD); + break; + case ACTIONS_FIELD: + actions = (List) content.get(ACTIONS_FIELD); + break; + } + + } + } + + if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { + MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder() + .name(name) + .description(description) + .version(version) + .protocol(protocol) + .parameters(parameters) + .credential(credentials) + .actions(actions) + .build(); + + mlClient.createConnector(mlInput, actionListener); + } else { + createConnectorFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); + } + + return createConnectorFuture; + } + + @Override + public String getName() { + return NAME; + } + + private static Map getParameterMap(Map params) { + + Map parameters = new HashMap<>(); + for (String key : params.keySet()) { + String value = params.get(key); + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + parameters.put(key, value); + return null; + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } + } + return parameters; + } + +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index e4c9b1a14..ba22f3682 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -10,9 +10,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; +import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; @@ -28,15 +29,15 @@ public class DeployModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(DeployModelStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "deploy_model"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public DeployModelStep(Client client) { - this.client = client; + public DeployModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override @@ -44,8 +45,6 @@ public CompletableFuture execute(List data) { CompletableFuture deployModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @@ -57,8 +56,8 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { @Override public void onFailure(Exception e) { - logger.error("Model deployment failed"); - deployModelFuture.completeExceptionally(e); + logger.error("Failed to deploy model"); + deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -70,7 +69,13 @@ public void onFailure(Exception e) { break; } } - machineLearningNodeClient.deploy(modelId, actionListener); + + if (modelId != null) { + mlClient.deploy(modelId, actionListener); + } else { + deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST)); + } + return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java index b97c56d57..df14d6c54 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelStep.java @@ -10,9 +10,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.opensearch.client.Client; +import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.flowframework.client.MLClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.model.MLModelConfig; @@ -20,7 +21,6 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.Map; @@ -34,8 +34,8 @@ import static org.opensearch.flowframework.common.CommonValue.MODEL_CONFIG; import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_ID; -import static org.opensearch.flowframework.common.CommonValue.MODEL_NAME; import static org.opensearch.flowframework.common.CommonValue.MODEL_VERSION; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; /** * Step to register a remote model @@ -44,16 +44,16 @@ public class RegisterModelStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(RegisterModelStep.class); - private Client client; + private MachineLearningNodeClient mlClient; static final String NAME = "register_model"; /** * Instantiate this class - * @param client client to instantiate MLClient + * @param mlClient client to instantiate MLClient */ - public RegisterModelStep(Client client) { - this.client = client; + public RegisterModelStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; } @Override @@ -61,8 +61,6 @@ public CompletableFuture execute(List data) { CompletableFuture registerModelFuture = new CompletableFuture<>(); - MachineLearningNodeClient machineLearningNodeClient = MLClient.createMLClient(client); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @@ -80,7 +78,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to register model"); - registerModelFuture.completeExceptionally(new IOException("Failed to register model ")); + registerModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -101,8 +99,8 @@ public void onFailure(Exception e) { case FUNCTION_NAME: functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); break; - case MODEL_NAME: - modelName = (String) content.get(MODEL_NAME); + case NAME_FIELD: + modelName = (String) content.get(NAME_FIELD); break; case MODEL_VERSION: modelVersion = (String) content.get(MODEL_VERSION); @@ -139,7 +137,11 @@ public void onFailure(Exception e) { .connectorId(connectorId) .build(); - machineLearningNodeClient.register(mlInput, actionListener); + mlClient.register(mlInput, actionListener); + } else { + registerModelFuture.completeExceptionally( + new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) + ); } return registerModelFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index fdb82ef0b..5aabd679f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -10,6 +10,7 @@ import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.ml.client.MachineLearningNodeClient; import java.util.HashMap; import java.util.List; @@ -30,17 +31,19 @@ public class WorkflowStepFactory { * * @param clusterService The OpenSearch cluster service * @param client The OpenSearch client steps can use + * @param mlClient Machine Learning client to perform ml operations */ - public WorkflowStepFactory(ClusterService clusterService, Client client) { - populateMap(clusterService, client); + public WorkflowStepFactory(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { + populateMap(clusterService, client, mlClient); } - private void populateMap(ClusterService clusterService, Client client) { + private void populateMap(ClusterService clusterService, Client client, MachineLearningNodeClient mlClient) { stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); stepMap.put(CreateIngestPipelineStep.NAME, new CreateIngestPipelineStep(client)); - stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(client)); - stepMap.put(DeployModelStep.NAME, new DeployModelStep(client)); + stepMap.put(RegisterModelStep.NAME, new RegisterModelStep(mlClient)); + stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); + stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient)); // TODO: These are from the demo class as placeholders, remove when demos are deleted stepMap.put("demo_delay_3", new DemoWorkflowStep(3000)); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java new file mode 100644 index 000000000..bf170ab9f --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class CreateConnectorStepTests extends OpenSearchTestCase { + private WorkflowData inputData = WorkflowData.EMPTY; + + @Mock + ActionListener registerModelActionListener; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + Map params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7")); + Map credentials = Map.ofEntries(Map.entry("key1", "value1"), Map.entry("key2", "value2")); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("version", "1"), + Map.entry("protocol", "test"), + Map.entry("params", params), + Map.entry("credentials", credentials), + Map.entry("actions", List.of("actions")) + ) + ); + + } + + public void testCreateConnector() throws IOException, ExecutionException, InterruptedException { + + String connectorId = "connect"; + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isDone()); + assertEquals(connectorId, future.get().getContent().get("connector_id")); + + } + + public void testCreateConnectorFailure() throws IOException { + CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = createConnectorStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to create connector", ex.getCause().getMessage()); + } + +} diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index fc7c695f8..4cdfaebae 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -10,35 +10,33 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.MLTaskType; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class DeployModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private NodeClient nodeClient; - @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -50,23 +48,20 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - nodeClient = new NoOpNodeClient("xyz"); - } - public void testDeployModel() { - + public void testDeployModel() throws ExecutionException, InterruptedException { String taskId = "taskId"; String status = MLTaskState.CREATED.name(); MLTaskType mlTaskType = MLTaskType.DEPLOY_MODEL; - DeployModelStep deployModel = new DeployModelStep(nodeClient); + DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLDeployModelResponse output = new MLDeployModelResponse(taskId, mlTaskType, status); actionListener.onResponse(output); return null; @@ -74,10 +69,31 @@ public void testDeployModel() { CompletableFuture future = deployModel.execute(List.of(inputData)); - // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); + assertEquals(status, future.get().getContent().get("deploy_model_status")); + } + + public void testDeployModelFailure() { + DeployModelStep deployModel = new DeployModelStep(machineLearningNodeClient); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to deploy model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + CompletableFuture future = deployModel.execute(List.of(inputData)); + + verify(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to deploy model", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java index b1a2b2fc0..59fb1b173 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelStepTests.java @@ -10,8 +10,9 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; -import org.opensearch.client.node.NodeClient; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLTaskState; @@ -20,28 +21,24 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.test.client.NoOpNodeClient; import java.util.List; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.Answers; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; @ThreadLeakScope(ThreadLeakScope.Scope.NONE) public class RegisterModelStepTests extends OpenSearchTestCase { private WorkflowData inputData = WorkflowData.EMPTY; - @Mock(answer = Answers.RETURNS_DEEP_STUBS) - private NodeClient nodeClient; - @Mock ActionListener registerModelActionListener; @@ -70,7 +67,6 @@ public void setUp() throws Exception { ) ); - nodeClient = new NoOpNodeClient("xyz"); } public void testRegisterModel() throws ExecutionException, InterruptedException { @@ -85,12 +81,12 @@ public void testRegisterModel() throws ExecutionException, InterruptedException .connectorId("abcdefgh") .build(); - RegisterModelStep registerModelStep = new RegisterModelStep(nodeClient); + RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); + ActionListener actionListener = invocation.getArgument(1); MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); actionListener.onResponse(output); return null; @@ -98,11 +94,33 @@ public void testRegisterModel() throws ExecutionException, InterruptedException CompletableFuture future = registerModelStep.execute(List.of(inputData)); - // TODO: Find a way to verify the below - // verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get("model_id")); + assertEquals(status, future.get().getContent().get("register_model_status")); + + } + public void testRegisterModelFailure() { + RegisterModelStep registerModelStep = new RegisterModelStep(machineLearningNodeClient); + + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to register model", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + CompletableFuture future = registerModelStep.execute(List.of(inputData)); + + verify(machineLearningNodeClient).register(any(MLRegisterModelInput.class), actionListenerCaptor.capture()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to register model", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index e8ada0e15..f728dd7b1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.model.TemplateTestJsonUtil; import org.opensearch.flowframework.model.Workflow; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -60,11 +61,12 @@ public static void setup() { AdminClient adminClient = mock(AdminClient.class); ClusterService clusterService = mock(ClusterService.class); Client client = mock(Client.class); + MachineLearningNodeClient mlClient = mock(MachineLearningNodeClient.class); when(client.admin()).thenReturn(adminClient); testThreadPool = new TestThreadPool(WorkflowProcessSorterTests.class.getName()); - WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client); + WorkflowStepFactory factory = new WorkflowStepFactory(clusterService, client, mlClient); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool); }