Skip to content

Commit

Permalink
Allow strings for boolean workflow step parameters
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Apr 19, 2024
1 parent 54f8b3f commit d8baebf
Show file tree
Hide file tree
Showing 11 changed files with 577 additions and 10 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/)
### Features
### Enhancements
- Add guardrails to default use case params ([#658](https://github.com/opensearch-project/flow-framework/pull/658))
- Allow strings for boolean workflow step parameters ([#671](https://github.com/opensearch-project/flow-framework/pull/671))

### Bug Fixes
- Reset workflow state to initial state after successful deprovision ([#635](https://github.com/opensearch-project/flow-framework/pull/635))
- Silently ignore content on APIs that don't require it ([#639](https://github.com/opensearch-project/flow-framework/pull/639))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
Expand Down Expand Up @@ -113,7 +115,7 @@ public PlainActionFuture<WorkflowData> execute(
String description = (String) inputs.get(DESCRIPTION_FIELD);
String modelGroupId = (String) inputs.get(MODEL_GROUP_ID);
String allConfig = (String) inputs.get(ALL_CONFIG);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;

// Build register model input
MLRegisterModelInputBuilder mlInputBuilder = MLRegisterModelInput.builder()
Expand Down Expand Up @@ -217,6 +219,8 @@ public PlainActionFuture<WorkflowData> execute(
logger.error(errorMessage, exception);
registerLocalModelFuture.onFailure(new WorkflowStepException(errorMessage, ExceptionsHelper.status(exception)));
}));
} catch (IllegalArgumentException iae) {
registerLocalModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST));
} catch (FlowFrameworkException e) {
registerLocalModelFuture.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand Down Expand Up @@ -139,7 +141,9 @@ public void onFailure(Exception e) {
String description = (String) inputs.get(DESCRIPTION_FIELD);
List<String> backendRoles = getBackendRoles(inputs);
AccessMode modelAccessMode = (AccessMode) inputs.get(MODEL_ACCESS_MODE);
Boolean isAddAllBackendRoles = (Boolean) inputs.get(ADD_ALL_BACKEND_ROLES);
Boolean isAddAllBackendRoles = inputs.containsKey(ADD_ALL_BACKEND_ROLES)
? Booleans.parseBoolean(inputs.get(ADD_ALL_BACKEND_ROLES).toString())
: null;

MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder();
builder.name(modelGroupName);
Expand All @@ -158,6 +162,8 @@ public void onFailure(Exception e) {
MLRegisterModelGroupInput mlInput = builder.build();

mlClient.registerModelGroup(mlInput, actionListener);
} catch (IllegalArgumentException iae) {
registerModelGroupFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST));
} catch (FlowFrameworkException e) {
registerModelGroupFuture.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.common.Booleans;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
Expand Down Expand Up @@ -90,7 +92,7 @@ public PlainActionFuture<WorkflowData> execute(
String description = (String) inputs.get(DESCRIPTION_FIELD);
String connectorId = (String) inputs.get(CONNECTOR_ID);
Guardrails guardRails = (Guardrails) inputs.get(GUARDRAILS_FIELD);
final Boolean deploy = (Boolean) inputs.get(DEPLOY_FIELD);
final Boolean deploy = inputs.containsKey(DEPLOY_FIELD) ? Booleans.parseBoolean(inputs.get(DEPLOY_FIELD).toString()) : null;

MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder()
.functionName(FunctionName.REMOTE)
Expand All @@ -106,7 +108,6 @@ public PlainActionFuture<WorkflowData> execute(
if (deploy != null) {
builder.deployModel(deploy);
}

if (guardRails != null) {
builder.guardrails(guardRails);
}
Expand Down Expand Up @@ -190,6 +191,8 @@ public void onFailure(Exception e) {
}
});

} catch (IllegalArgumentException iae) {
registerRemoteModelFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST));
} catch (FlowFrameworkException e) {
registerRemoteModelFuture.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Booleans;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.common.agent.MLToolSpec;

Expand Down Expand Up @@ -61,7 +64,9 @@ public PlainActionFuture<WorkflowData> execute(
String type = (String) inputs.get(TYPE);
String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
Boolean includeOutputInAgentResponse = (Boolean) inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE);
Boolean includeOutputInAgentResponse = inputs.containsKey(INCLUDE_OUTPUT_IN_AGENT_RESPONSE)
? Booleans.parseBoolean(inputs.get(INCLUDE_OUTPUT_IN_AGENT_RESPONSE).toString())
: null;
Map<String, String> parameters = getToolsParametersMap(inputs.get(PARAMETERS_FIELD), previousNodeInputs, outputs);

MLToolSpec.MLToolSpecBuilder builder = MLToolSpec.builder();
Expand Down Expand Up @@ -92,6 +97,8 @@ public PlainActionFuture<WorkflowData> execute(

logger.info("Tool registered successfully {}", type);

} catch (IllegalArgumentException iae) {
toolFuture.onFailure(new WorkflowStepException(iae.getMessage(), RestStatus.BAD_REQUEST));
} catch (FlowFrameworkException e) {
toolFuture.onFailure(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.common.FlowFrameworkSettings;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.MLTask;
Expand All @@ -31,6 +33,7 @@
import org.opensearch.threadpool.ThreadPool;
import org.junit.AfterClass;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.concurrent.ExecutionException;
Expand All @@ -40,6 +43,7 @@
import org.mockito.MockitoAnnotations;

import static org.opensearch.action.DocWriteResponse.Result.UPDATED;
import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD;
import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS;
Expand Down Expand Up @@ -283,4 +287,118 @@ public void testMissingInputs() {
}
assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]"));
}

public void testBoolParse() throws IOException, ExecutionException, InterruptedException {
String taskId = "abcd";
String modelId = "model-id";
String status = MLTaskState.COMPLETED.name();

// Stub register for success case
doAnswer(invocation -> {
ActionListener<MLRegisterModelResponse> actionListener = invocation.getArgument(1);
MLRegisterModelResponse output = new MLRegisterModelResponse(taskId, status, null);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any());

// Stub getTask for success case
doAnswer(invocation -> {
ActionListener<MLTask> actionListener = invocation.getArgument(1);
MLTask output = new MLTask(
taskId,
modelId,
null,
null,
MLTaskState.COMPLETED,
null,
null,
null,
null,
null,
null,
null,
null,
false
);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).getTask(any(), any());

doAnswer(invocation -> {
ActionListener<UpdateResponse> updateResponseListener = invocation.getArgument(4);
updateResponseListener.onResponse(new UpdateResponse(new ShardId(WORKFLOW_STATE_INDEX, "", 1), "id", -2, 0, 0, UPDATED));
return null;
}).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any());

WorkflowData boolStringWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("version", "1.0.0"),
Map.entry("description", "description"),
Map.entry("function_name", "SPARSE_TOKENIZE"),
Map.entry("model_format", "TORCH_SCRIPT"),
Map.entry(MODEL_GROUP_ID, "abcdefg"),
Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"),
Map.entry("model_type", "bert"),
Map.entry("embedding_dimension", "384"),
Map.entry("framework_type", "sentence_transformers"),
Map.entry("url", "something.com"),
Map.entry(DEPLOY_FIELD, "false")
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = registerLocalModelStep.execute(
boolStringWorkflowData.getNodeId(),
boolStringWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

future.actionGet();

verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any());
verify(machineLearningNodeClient, times(1)).getTask(any(), any());

assertEquals(modelId, future.get().getContent().get(MODEL_ID));
assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS));
}

public void testBoolParseFail() throws IOException, ExecutionException, InterruptedException {
WorkflowData boolStringWorkflowData = new WorkflowData(
Map.ofEntries(
Map.entry("name", "xyz"),
Map.entry("version", "1.0.0"),
Map.entry("description", "description"),
Map.entry("function_name", "SPARSE_TOKENIZE"),
Map.entry("model_format", "TORCH_SCRIPT"),
Map.entry(MODEL_GROUP_ID, "abcdefg"),
Map.entry("model_content_hash_value", "aiwoeifjoaijeofiwe"),
Map.entry("model_type", "bert"),
Map.entry("embedding_dimension", "384"),
Map.entry("framework_type", "sentence_transformers"),
Map.entry("url", "something.com"),
Map.entry(DEPLOY_FIELD, "no")
),
"test-id",
"test-node-id"
);

PlainActionFuture<WorkflowData> future = registerLocalModelStep.execute(
boolStringWorkflowData.getNodeId(),
boolStringWorkflowData,
Collections.emptyMap(),
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isDone());
ExecutionException e = assertThrows(ExecutionException.class, () -> future.get());
assertEquals(WorkflowStepException.class, e.getCause().getClass());
WorkflowStepException w = (WorkflowStepException) e.getCause();
assertEquals("Failed to parse value [no] as only [true] or [false] are allowed.", w.getMessage());
assertEquals(RestStatus.BAD_REQUEST, w.getRestStatus());
}
}
Loading

0 comments on commit d8baebf

Please sign in to comment.