diff --git a/CHANGELOG.md b/CHANGELOG.md index a66add559..d15a7d8b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) ### Features ### Enhancements - Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658)) +- Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671)) + ### Bug Fixes - Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635)) - Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639)) diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 442e41355..51bab0a8f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -12,7 +12,9 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Booleans; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; @@ -113,7 +115,7 @@ public PlainActionFuture execute( String description = (String) inputs.get(DESCRIPTION_FIELD); String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); String allConfig = (String) inputs.get(ALL_CONFIG); - final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); + final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null; // Build register model input MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder() @@ -217,6 +219,8 @@ public PlainActionFuture execute( logger.error(errorMessage, exception); registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(exception))); })); + } catch (IllegalArgumentException iae) { + registerLocalModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { registerLocalModelFuture.onFailure(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index b8a79325a..10824c2d5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -12,8 +12,10 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Booleans; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -139,7 +141,9 @@ public void onFailure(Exception e) { String description = (String) inputs.get(DESCRIPTION_FIELD); List backendRoles = getBackendRoles(inputs); AccessMode modelAccessMode = (AccessMode) inputs.get(MODEL_ACCESS_MODE); - Boolean isAddAllBackendRoles = (Boolean) inputs.get(ADD_ALL_BACKEND_ROLES); + Boolean isAddAllBackendRoles = inputs.containsKey(ADD_ALL_BACKEND_ROLES) + ? Booleans.parseBoolean(inputs.get(ADD_ALL_BACKEND_ROLES).toString()) + : null; MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder(); builder.name(modelGroupName); @@ -158,6 +162,8 @@ public void onFailure(Exception e) { MLRegisterModelGroupInput mlInput = builder.build(); mlClient.registerModelGroup(mlInput, actionListener); + } catch (IllegalArgumentException iae) { + registerModelGroupFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { registerModelGroupFuture.onFailure(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index cc3800284..cfdc21cd9 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -13,7 +13,9 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; +import org.opensearch.common.Booleans; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; @@ -90,7 +92,7 @@ public PlainActionFuture execute( String description = (String) inputs.get(DESCRIPTION_FIELD); String connectorId = (String) inputs.get(CONNECTOR_ID); Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD); - final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD); + final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null; MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(FunctionName.REMOTE) @@ -106,7 +108,6 @@ public PlainActionFuture execute( if (deploy != null) { builder.deployModel(deploy); } - if (guardRails != null) { builder.guardrails(guardRails); } @@ -190,6 +191,8 @@ public void onFailure(Exception e) { } }); + } catch (IllegalArgumentException iae) { + registerRemoteModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { registerRemoteModelFuture.onFailure(e); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 6809e7832..87ecf762e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -11,7 +11,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.Booleans; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; @@ -61,7 +64,9 @@ public PlainActionFuture execute( String type = (String) inputs.get(TYPE); String name = (String) inputs.get(NAME_FIELD); String description = (String) inputs.get(DESCRIPTION_FIELD); - Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE); + Boolean includeOutputInAgentResponse = inputs.containsKey(INCLUDE_OUTPUT_IN_AGENT_RESPONSE) + ? Booleans.parseBoolean(inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE).toString()) + : null; Map parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs); MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder(); @@ -92,6 +97,8 @@ public PlainActionFuture execute( logger.info("Tool registered successfully {}", type); + } catch (IllegalArgumentException iae) { + toolFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST)); } catch (FlowFrameworkException e) { toolFuture.onFailure(e); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index a6d73dd33..a5f9b72b5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -17,8 +17,10 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -40,6 +43,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -283,4 +287,118 @@ public void testMissingInputs() { } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("model_type", "bert"), + Map.entry("embedding_dimension", "384"), + Map.entry("framework_type", "sentence_transformers"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "no") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 4dff97a54..004928777 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -17,8 +17,10 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -40,6 +43,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -268,4 +272,106 @@ public void testMissingInputs() { } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("description", "aiwoeifjoaijeofiwe"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalPretrainedModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("description", "aiwoeifjoaijeofiwe"), + Map.entry(DEPLOY_FIELD, "no") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = registerLocalPretrainedModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 607b0ebed..dc0b0146f 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -17,8 +17,10 @@ import org.opensearch.common.util.concurrent.OpenSearchExecutors; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTask; @@ -31,6 +33,7 @@ import org.opensearch.threadpool.ThreadPool; import org.junit.AfterClass; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -40,6 +43,7 @@ import org.mockito.MockitoAnnotations; import static org.opensearch.action.DocWriteResponse.Result.UPDATED; +import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; @@ -271,4 +275,110 @@ public void testMissingInputs() { } assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "model-id"; + String status = MLTaskState.COMPLETED.name(); + + // Stub register for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); + + // Stub getTask for success case + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLTask output = new MLTask( + taskId, + modelId, + null, + null, + MLTaskState.COMPLETED, + null, + null, + null, + null, + null, + null, + null, + null, + false + ); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).getTask(any(), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "false") + ), + "test-id", + "test-node-id" + ); + PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + future.actionGet(); + + verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + verify(machineLearningNodeClient, times(1)).getTask(any(), any()); + + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData boolStringWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("version", "1.0.0"), + Map.entry("description", "description"), + Map.entry("function_name", "SPARSE_TOKENIZE"), + Map.entry("model_format", "TORCH_SCRIPT"), + Map.entry(MODEL_GROUP_ID, "abcdefg"), + Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"), + Map.entry("url", "something.com"), + Map.entry(DEPLOY_FIELD, "no") + ), + "test-id", + "test-node-id" + ); + PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( + boolStringWorkflowData.getNodeId(), + boolStringWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java similarity index 63% rename from src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java rename to src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java index 0df8432c4..65c26959a 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterModelGroupStepTests.java @@ -14,6 +14,7 @@ import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; @@ -41,9 +42,14 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; -public class ModelGroupStepTests extends OpenSearchTestCase { +public class RegisterModelGroupStepTests extends OpenSearchTestCase { private WorkflowData inputData; private WorkflowData inputDataWithNoName; + private WorkflowData boolStringInputData; + private WorkflowData badBoolInputData; + + private String modelGroupId = MODEL_GROUP_ID; + private String status = MLTaskState.CREATED.name(); @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -67,12 +73,31 @@ public void setUp() throws Exception { "test-node-id" ); inputDataWithNoName = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + boolStringInputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("backend_roles", List.of("role-1")), + Map.entry("access_mode", AccessMode.PUBLIC), + Map.entry("add_all_backend_roles", "false") + ), + "test-id", + "test-node-id" + ); + badBoolInputData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "test"), + Map.entry("description", "description"), + Map.entry("backend_roles", List.of("role-1")), + Map.entry("access_mode", AccessMode.PUBLIC), + Map.entry("add_all_backend_roles", "no") + ), + "test-id", + "test-node-id" + ); } public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { - String modelGroupId = MODEL_GROUP_ID; - String status = MLTaskState.CREATED.name(); - RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); @SuppressWarnings("unchecked") @@ -153,4 +178,54 @@ public void testRegisterModelGroupWithNoName() throws IOException { assertEquals("Missing required inputs [name] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); + + @SuppressWarnings("unchecked") + ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelGroupResponse output = new MLRegisterModelGroupResponse(modelGroupId, status); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + PlainActionFuture future = modelGroupStep.execute( + boolStringInputData.getNodeId(), + boolStringInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(modelGroupId, future.get().getContent().get(MODEL_GROUP_ID)); + assertEquals(status, future.get().getContent().get("model_group_status")); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); + + PlainActionFuture future = modelGroupStep.execute( + badBoolInputData.getNodeId(), + badBoolInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index e2627685e..1151262db 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -14,7 +14,9 @@ import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; @@ -22,6 +24,7 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; import org.opensearch.test.OpenSearchTestCase; +import java.io.IOException; import java.util.Collections; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -193,4 +196,77 @@ public void testMissingInputs() { assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + String taskId = "abcd"; + String modelId = "efgh"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, modelId); + actionListener.onResponse(output); + return null; + }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); + + doAnswer(invocation -> { + ActionListener updateResponseListener = invocation.getArgument(4); + updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED)); + return null; + }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, "true") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + verify(mlNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); + // updates both register and deploy resources + verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); + + assertTrue(future.isDone()); + assertEquals(modelId, future.get().getContent().get(MODEL_ID)); + assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + WorkflowData deployWorkflowData = new WorkflowData( + Map.ofEntries( + Map.entry("name", "xyz"), + Map.entry("description", "description"), + Map.entry(CONNECTOR_ID, "abcdefg"), + Map.entry(DEPLOY_FIELD, "yes") + ), + "test-id", + "test-node-id" + ); + + PlainActionFuture future = this.registerRemoteModelStep.execute( + deployWorkflowData.getNodeId(), + deployWorkflowData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index 27bb44c83..3bfd5d8d4 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -9,6 +9,8 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.test.OpenSearchTestCase; @@ -18,7 +20,9 @@ import java.util.concurrent.ExecutionException; public class ToolStepTests extends OpenSearchTestCase { - private WorkflowData inputData = WorkflowData.EMPTY; + private WorkflowData inputData; + private WorkflowData boolStringInputData; + private WorkflowData badBoolInputData; @Override public void setUp() throws Exception { @@ -35,6 +39,28 @@ public void setUp() throws Exception { "test-id", "test-node-id" ); + boolStringInputData = new WorkflowData( + Map.ofEntries( + Map.entry("type", "type"), + Map.entry("name", "name"), + Map.entry("description", "description"), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("include_output_in_agent_response", "false") + ), + "test-id", + "test-node-id" + ); + badBoolInputData = new WorkflowData( + Map.ofEntries( + Map.entry("type", "type"), + Map.entry("name", "name"), + Map.entry("description", "description"), + Map.entry("parameters", Collections.emptyMap()), + Map.entry("include_output_in_agent_response", "yes") + ), + "test-id", + "test-node-id" + ); } public void testTool() throws IOException, ExecutionException, InterruptedException { @@ -51,4 +77,38 @@ public void testTool() throws IOException, ExecutionException, InterruptedExcept assertTrue(future.isDone()); assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); } + + public void testBoolParse() throws IOException, ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + boolStringInputData.getNodeId(), + boolStringInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + assertEquals(MLToolSpec.class, future.get().getContent().get("tools").getClass()); + } + + public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException { + ToolStep toolStep = new ToolStep(); + + PlainActionFuture future = toolStep.execute( + badBoolInputData.getNodeId(), + badBoolInputData, + Collections.emptyMap(), + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isDone()); + ExecutionException e = assertThrows(ExecutionException.class, () -> future.get()); + assertEquals(WorkflowStepException.class, e.getCause().getClass()); + WorkflowStepException w = (WorkflowStepException) e.getCause(); + assertEquals("Failed to parse value [yes] as only [true] or [false] are allowed.", w.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus()); + } }