From d8baebfbb90de80dbdd9259396518271d1dc9fa2 Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Thu, 18 Apr 2024 17:09:03 -0700 Subject: [PATCH 1/2] Allow strings for boolean workflow step parameters Signed-off-by: Daniel Widdis --- CHANGELOG.md | 2 + .../AbstractRegisterLocalModelStep.java | 6 +- .../workflow/RegisterModelGroupStep.java | 8 +- .../workflow/RegisterRemoteModelStep.java | 7 +- .../flowframework/workflow/ToolStep.java | 9 +- .../RegisterLocalCustomModelStepTests.java | 118 ++++++++++++++++++ ...RegisterLocalPretrainedModelStepTests.java | 106 ++++++++++++++++ ...sterLocalSparseEncodingModelStepTests.java | 110 ++++++++++++++++ ....java => RegisterModelGroupStepTests.java} | 83 +++++++++++- .../RegisterRemoteModelStepTests.java | 76 +++++++++++ .../flowframework/workflow/ToolStepTests.java | 62 ++++++++- 11 files changed, 577 insertions(+), 10 deletions(-) rename src/test/java/org/opensearch/flowframework/workflow/{ModelGroupStepTests.java => RegisterModelGroupStepTests.java} (63%) diff --git a/CHANGELOG.md b/CHANGELOG.md index a66add559..d15a7d8b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Features ### Enhancements - Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658)) +- Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671)) + ### Bug Fixes - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) - Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639)) diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 442e41355..51bab0a8f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -12,7 +12,9 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Booleans; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -113,7 +115,7 @@ public PlainActionFuture execute( String description = (String) inputs.get(DESCRIPTION_FIELD); String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String allConfig = (String) inputs.get(ALL_CONFIG); - final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); + final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null; // Build register model input MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder() @@ -217,6 +219,8 @@ public PlainActionFuture execute( logger.error(errorMessage, exception); registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(exception))); })); + } catch (IllegalArgumentException iae) { + registerLocalModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { registerLocalModelFuture.onFailure(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index b8a79325a..10824c2d5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -12,8 +12,10 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Booleans; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -139,7 +141,9 @@ public void onFailure(Exception e) { String description = (String) inputs.get(DESCRIPTION_FIELD); List backendRoles = getBackendRoles(inputs); AccessMode modelAccessMode = (AccessMode) inputs.get(MODEL_ACCESS_MODE); - Boolean isAddAllBackendRoles = (Boolean) inputs.get(ADD_ALL_BACKEND_ROLES); + Boolean isAddAllBackendRoles = inputs.containsKey(ADD_ALL_BACKEND_ROLES) + ? Booleans.parseBoolean(inputs.get(ADD_ALL_BACKEND_ROLES).toString()) + : null; MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder(); builder.name(modelGroupName); @@ -158,6 +162,8 @@ public void onFailure(Exception e) { MLRegisterModelGroupInput mlInput = builder.build(); mlClient.registerModelGroup(mlInput, actionListener); + } catch (IllegalArgumentException iae) { + registerModelGroupFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { registerModelGroupFuture.onFailure(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index cc3800284..cfdc21cd9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.common.Booleans; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -90,7 +92,7 @@ public PlainActionFuture execute( String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD); - final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); + final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null; MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(FunctionName.REMOTE) @@ -106,7 +108,6 @@ public PlainActionFuture execute( if (deploy != null) { builder.deployModel(deploy); } - if (guardRails != null) { builder.guardrails(guardRails); } @@ -190,6 +191,8 @@ public void onFailure(Exception e) { } }); + } catch (IllegalArgumentException iae) { + registerRemoteModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { registerRemoteModelFuture.onFailure(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 6809e7832..87ecf762e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -11,7 +11,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Booleans; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; @@ -61,7 +64,9 @@ public PlainActionFuture execute( String type = (String) inputs.get(TYPE); String name = (String) inputs.get(NAME_FIELD); String description = (String) inputs.get(DESCRIPTION_FIELD); - Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + Boolean includeOutputInAgentResponse = inputs.containsKey(INCLUDE_OUTPUT_IN_AGENT_RESPONSE) + ? Booleans.parseBoolean(inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE).toString()) + : null; Map parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs); MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); @@ -92,6 +97,8 @@ public PlainActionFuture execute( logger.info("Tool registered successfully {}", type); + } catch (IllegalArgumentException iae) { + toolFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { toolFuture.onFailure(e); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index a6d73dd33..a5f9b72b5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -17,8 +17,10 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -40,6 +43,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -283,4 +287,118 @@ public void testMissingInputs() { } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "no") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 4dff97a54..004928777 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -17,8 +17,10 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -40,6 +43,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -268,4 +272,106 @@ public void testMissingInputs() { } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("description", "aiwoeifjoaijeofiwe"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalPretrainedModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("description", "aiwoeifjoaijeofiwe"), + Map.entry(DEPLOY_FIELD, "no") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalPretrainedModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 607b0ebed..dc0b0146f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -17,8 +17,10 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -40,6 +43,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -271,4 +275,110 @@ public void testMissingInputs() { } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "no") + ), + "test-id", + "test-node-id" + ); + PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java similarity index 63% rename from src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java index 0df8432c4..65c26959a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; @@ -41,9 +42,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -public class ModelGroupStepTests extends OpenSearchTestCase { +public class RegisterModelGroupStepTests extends OpenSearchTestCase { private WorkflowData inputData; private WorkflowData inputDataWithNoName; + private WorkflowData boolStringInputData; + private WorkflowData badBoolInputData; + + private String modelGroupId = MODEL_GROUP_ID; + private String status = MLTaskState.CREATED.name(); @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -67,12 +73,31 @@ public void setUp() throws Exception { "test-node-id" ); inputDataWithNoName = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + boolStringInputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("backend_roles", List.of("role-1")), + Map.entry("access_mode", AccessMode.PUBLIC), + Map.entry("add_all_backend_roles", "false") + ), + "test-id", + "test-node-id" + ); + badBoolInputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("backend_roles", List.of("role-1")), + Map.entry("access_mode", AccessMode.PUBLIC), + Map.entry("add_all_backend_roles", "no") + ), + "test-id", + "test-node-id" + ); } public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { - String modelGroupId = MODEL_GROUP_ID; - String status = MLTaskState.CREATED.name(); - RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") @@ -153,4 +178,54 @@ public void testRegisterModelGroupWithNoName() throws IOException { assertEquals("Missing required inputs [name] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + PlainActionFuture future = modelGroupStep.execute( + boolStringInputData.getNodeId(), + boolStringInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); + assertEquals(status, future.get().getContent().get("model_group_status")); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); + + PlainActionFuture future = modelGroupStep.execute( + badBoolInputData.getNodeId(), + badBoolInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index e2627685e..1151262db 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -14,7 +14,9 @@ import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; @@ -22,6 +24,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -193,4 +196,77 @@ public void testMissingInputs() { assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + // updates both register and deploy resources + verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, "yes") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 27bb44c83..3bfd5d8d4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -9,6 +9,8 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.test.OpenSearchTestCase; @@ -18,7 +20,9 @@ import java.util.concurrent.ExecutionException; public class ToolStepTests extends OpenSearchTestCase { - private WorkflowData inputData = WorkflowData.EMPTY; + private WorkflowData inputData; + private WorkflowData boolStringInputData; + private WorkflowData badBoolInputData; @Override public void setUp() throws Exception { @@ -35,6 +39,28 @@ public void setUp() throws Exception { "test-id", "test-node-id" ); + boolStringInputData = new WorkflowData( + Map.ofEntries( + Map.entry("type", "type"), + Map.entry("name", "name"), + Map.entry("description", "description"), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("include_output_in_agent_response", "false") + ), + "test-id", + "test-node-id" + ); + badBoolInputData = new WorkflowData( + Map.ofEntries( + Map.entry("type", "type"), + Map.entry("name", "name"), + Map.entry("description", "description"), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("include_output_in_agent_response", "yes") + ), + "test-id", + "test-node-id" + ); } public void testTool() throws IOException, ExecutionException, InterruptedException { @@ -51,4 +77,38 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept assertTrue(future.isDone()); assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + boolStringInputData.getNodeId(), + boolStringInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + badBoolInputData.getNodeId(), + badBoolInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } From cd42c1408b8bca4ad57da39e2aea5584b5783d8d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Fri, 19 Apr 2024 10:26:43 -0700 Subject: [PATCH 2/2] Reduce redundancy in similar tests Signed-off-by: Daniel Widdis --- .../RegisterLocalCustomModelStepTests.java | 113 ++++++------------ ...RegisterLocalPretrainedModelStepTests.java | 101 +++++----------- ...sterLocalSparseEncodingModelStepTests.java | 105 +++++----------- .../workflow/RegisterModelGroupStepTests.java | 43 ++----- .../RegisterRemoteModelStepTests.java | 73 ++++------- .../flowframework/workflow/ToolStepTests.java | 9 +- 6 files changed, 134 insertions(+), 310 deletions(-) diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index a5f9b72b5..859b7bf0d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -188,6 +188,41 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + + future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(2)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(2)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testRegisterLocalCustomModelFailure() { @@ -288,84 +323,6 @@ public void testMissingInputs() { assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - public void testBoolParse() throws IOException, ExecutionException, InterruptedException { - String taskId = "abcd"; - String modelId = "model-id"; - String status = MLTaskState.COMPLETED.name(); - - // Stub register for success case - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - - // Stub getTask for success case - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).getTask(any(), any()); - - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - - WorkflowData boolStringWorkflowData = new WorkflowData( - Map.ofEntries( - Map.entry("name", "xyz"), - Map.entry("version", "1.0.0"), - Map.entry("description", "description"), - Map.entry("function_name", "SPARSE_TOKENIZE"), - Map.entry("model_format", "TORCH_SCRIPT"), - Map.entry(MODEL_GROUP_ID, "abcdefg"), - Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), - Map.entry("model_type", "bert"), - Map.entry("embedding_dimension", "384"), - Map.entry("framework_type", "sentence_transformers"), - Map.entry("url", "something.com"), - Map.entry(DEPLOY_FIELD, "false") - ), - "test-id", - "test-node-id" - ); - - PlainActionFuture future = registerLocalModelStep.execute( - boolStringWorkflowData.getNodeId(), - boolStringWorkflowData, - Collections.emptyMap(), - Collections.emptyMap(), - Collections.emptyMap() - ); - - future.actionGet(); - - verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); - - assertEquals(modelId, future.get().getContent().get(MODEL_ID)); - assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); - } - public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { WorkflowData boolStringWorkflowData = new WorkflowData( Map.ofEntries( diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 004928777..431827a1c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -182,6 +182,35 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("description", "aiwoeifjoaijeofiwe"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + + future = registerLocalPretrainedModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(2)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(2)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testRegisterLocalPretrainedModelFailure() { @@ -273,78 +302,6 @@ public void testMissingInputs() { assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - public void testBoolParse() throws IOException, ExecutionException, InterruptedException { - String taskId = "abcd"; - String modelId = "model-id"; - String status = MLTaskState.COMPLETED.name(); - - // Stub register for success case - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - - // Stub getTask for success case - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).getTask(any(), any()); - - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - - WorkflowData boolStringWorkflowData = new WorkflowData( - Map.ofEntries( - Map.entry("name", "xyz"), - Map.entry("version", "1.0.0"), - Map.entry("model_format", "TORCH_SCRIPT"), - Map.entry(MODEL_GROUP_ID, "abcdefg"), - Map.entry("description", "aiwoeifjoaijeofiwe"), - Map.entry(DEPLOY_FIELD, "false") - ), - "test-id", - "test-node-id" - ); - - PlainActionFuture future = registerLocalPretrainedModelStep.execute( - boolStringWorkflowData.getNodeId(), - boolStringWorkflowData, - Collections.emptyMap(), - Collections.emptyMap(), - Collections.emptyMap() - ); - - future.actionGet(); - - verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); - - assertEquals(modelId, future.get().getContent().get(MODEL_ID)); - assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); - } - public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { WorkflowData boolStringWorkflowData = new WorkflowData( Map.ofEntries( diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index dc0b0146f..e98b7d5d5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -185,6 +185,37 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + future = registerLocalSparseEncodingModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(2)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(2)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testRegisterLocalSparseEncodingModelFailure() { @@ -276,80 +307,6 @@ public void testMissingInputs() { assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - public void testBoolParse() throws IOException, ExecutionException, InterruptedException { - String taskId = "abcd"; - String modelId = "model-id"; - String status = MLTaskState.COMPLETED.name(); - - // Stub register for success case - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - - // Stub getTask for success case - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLTask output = new MLTask( - taskId, - modelId, - null, - null, - MLTaskState.COMPLETED, - null, - null, - null, - null, - null, - null, - null, - null, - false - ); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).getTask(any(), any()); - - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - - WorkflowData boolStringWorkflowData = new WorkflowData( - Map.ofEntries( - Map.entry("name", "xyz"), - Map.entry("version", "1.0.0"), - Map.entry("description", "description"), - Map.entry("function_name", "SPARSE_TOKENIZE"), - Map.entry("model_format", "TORCH_SCRIPT"), - Map.entry(MODEL_GROUP_ID, "abcdefg"), - Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), - Map.entry("url", "something.com"), - Map.entry(DEPLOY_FIELD, "false") - ), - "test-id", - "test-node-id" - ); - PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( - boolStringWorkflowData.getNodeId(), - boolStringWorkflowData, - Collections.emptyMap(), - Collections.emptyMap(), - Collections.emptyMap() - ); - - future.actionGet(); - - verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - verify(machineLearningNodeClient, times(1)).getTask(any(), any()); - - assertEquals(modelId, future.get().getContent().get(MODEL_ID)); - assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); - } - public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { WorkflowData boolStringWorkflowData = new WorkflowData( Map.ofEntries( diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java index 65c26959a..921c7ac15 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java @@ -130,6 +130,17 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); assertEquals(status, future.get().getContent().get("model_group_status")); + future = modelGroupStep.execute( + boolStringInputData.getNodeId(), + boolStringInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); + assertEquals(status, future.get().getContent().get("model_group_status")); } public void testRegisterModelGroupFailure() throws IOException { @@ -178,38 +189,6 @@ public void testRegisterModelGroupWithNoName() throws IOException { assertEquals("Missing required inputs [name] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } - public void testBoolParse() throws IOException, ExecutionException, InterruptedException { - RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); - - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); - actionListener.onResponse(output); - return null; - }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - - PlainActionFuture future = modelGroupStep.execute( - boolStringInputData.getNodeId(), - boolStringInputData, - Collections.emptyMap(), - Collections.emptyMap(), - Collections.emptyMap() - ); - - assertTrue(future.isDone()); - assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); - assertEquals(status, future.get().getContent().get("model_group_status")); - } - public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 1151262db..11eb6af05 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -107,7 +107,6 @@ public void testRegisterRemoteModelSuccess() throws Exception { assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); - } public void testRegisterAndDeployRemoteModelSuccess() throws Exception { @@ -155,6 +154,32 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { assertTrue(future.isDone()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + + deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(mlNodeClient, times(2)).register(any(MLRegisterModelInput.class), any()); + // updates both register and deploy resources + verify(flowFrameworkIndicesHandler, times(4)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } public void testRegisterRemoteModelFailure() { @@ -196,52 +221,6 @@ public void testMissingInputs() { assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - public void testBoolParse() throws IOException, ExecutionException, InterruptedException { - String taskId = "abcd"; - String modelId = "efgh"; - String status = MLTaskState.CREATED.name(); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(1); - MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); - actionListener.onResponse(output); - return null; - }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - - doAnswer(invocation -> { - ActionListener updateResponseListener = invocation.getArgument(4); - updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); - return null; - }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - - WorkflowData deployWorkflowData = new WorkflowData( - Map.ofEntries( - Map.entry("name", "xyz"), - Map.entry("description", "description"), - Map.entry(CONNECTOR_ID, "abcdefg"), - Map.entry(DEPLOY_FIELD, "true") - ), - "test-id", - "test-node-id" - ); - - PlainActionFuture future = this.registerRemoteModelStep.execute( - deployWorkflowData.getNodeId(), - deployWorkflowData, - Collections.emptyMap(), - Collections.emptyMap(), - Collections.emptyMap() - ); - - verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); - // updates both register and deploy resources - verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - - assertTrue(future.isDone()); - assertEquals(modelId, future.get().getContent().get(MODEL_ID)); - assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); - } - public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { WorkflowData deployWorkflowData = new WorkflowData( Map.ofEntries( diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 3bfd5d8d4..45cb3816c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -73,22 +73,17 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isDone()); assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); - } - - public void testBoolParse() throws IOException, ExecutionException, InterruptedException { - ToolStep toolStep = new ToolStep(); - PlainActionFuture future = toolStep.execute( + toolStep = new ToolStep(); + future = toolStep.execute( boolStringInputData.getNodeId(), boolStringInputData, Collections.emptyMap(), Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isDone()); assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); }