diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index e22017eaf..9e3b8d067 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -18,16 +18,24 @@ import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.ml.common.agent.LLMSpec; import java.io.IOException; import java.time.Instant; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; @@ -39,6 +47,9 @@ public class ParseUtils { private static final Logger logger = LogManager.getLogger(ParseUtils.class); + // Matches ${{ foo.bar }} (whitespace optional) with capturing groups 1=foo, 2=bar + private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}"); + private ParseUtils() {} /** @@ -161,4 +172,95 @@ public static Map getStringToStringMap(Object map, String fieldN throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map."); } + /** + * Creates a map containing the specified input keys, with values derived from template data or previous node + * output. + * + * @param requiredInputKeys A set of keys that must be present, or will cause an exception to be thrown + * @param optionalInputKeys A set of keys that may be present, or will be absent in the returned map + * @param currentNodeInputs Input params and content for this node, from workflow parsing + * @param outputs WorkflowData content of previous steps + * @param previousNodeInputs Input params for this node that come from previous steps + * @return A map containing the requiredInputKeys with their corresponding values, + * and optionalInputKeys with their corresponding values if present. + * Throws a {@link FlowFrameworkException} if a required key is not present. + */ + public static Map getInputsFromPreviousSteps( + Set requiredInputKeys, + Set optionalInputKeys, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) { + // Mutable set to ensure all required keys are used + Set requiredKeys = new HashSet<>(requiredInputKeys); + // Merge input sets to add all requested keys + Set keys = new HashSet<>(requiredInputKeys); + keys.addAll(optionalInputKeys); + // Initialize return map + Map inputs = new HashMap<>(); + for (String key : keys) { + Object value = null; + // Priority 1: specifically named prior step inputs + // ... parse the previousNodeInputs map and fill in the specified keys + Optional previousNodeForKey = previousNodeInputs.entrySet() + .stream() + .filter(e -> key.equals(e.getValue())) + .map(Map.Entry::getKey) + .findAny(); + if (previousNodeForKey.isPresent()) { + WorkflowData previousNodeOutput = outputs.get(previousNodeForKey.get()); + if (previousNodeOutput != null) { + value = previousNodeOutput.getContent().get(key); + } + } + // Priority 2: inputs specified in template + // ... fetch from currentNodeInputs (params take precedence) + if (value == null) { + value = currentNodeInputs.getParams().get(key); + } + if (value == null) { + value = currentNodeInputs.getContent().get(key); + } + // Priority 3: other inputs + if (value == null) { + Optional matchedValue = outputs.values() + .stream() + .map(WorkflowData::getContent) + .filter(m -> m.containsKey(key)) + .map(m -> m.get(key)) + .findAny(); + if (matchedValue.isPresent()) { + value = matchedValue.get(); + } + } + // Check for substitution + if (value != null) { + Matcher m = SUBSTITUTION_PATTERN.matcher(value.toString()); + if (m.matches()) { + WorkflowData data = outputs.get(m.group(1)); + if (data != null && data.getContent().containsKey(m.group(2))) { + value = data.getContent().get(m.group(2)); + } + } + inputs.put(key, value); + requiredKeys.remove(key); + } + } + // After iterating is complete, throw exception if requiredKeys is not empty + if (!requiredKeys.isEmpty()) { + throw new FlowFrameworkException( + "Missing required inputs " + + requiredKeys + + " in workflow [" + + currentNodeInputs.getWorkflowId() + + "] node [" + + currentNodeInputs.getNodeId() + + "]", + RestStatus.BAD_REQUEST + ); + } + // Finally return the map + return inputs; + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index dc4c83d4e..bca1c8856 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -16,6 +16,7 @@ import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorAction.ActionType; @@ -33,8 +34,8 @@ import java.util.Locale; import java.util.Map; import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; @@ -120,57 +121,44 @@ public void onFailure(Exception e) { } }; - String name = null; - String description = null; - String version = null; - String protocol = null; - Map parameters = Collections.emptyMap(); - Map credentials = Collections.emptyMap(); - List actions = Collections.emptyList(); - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); + Set requiredKeys = Set.of( + NAME_FIELD, + DESCRIPTION_FIELD, + VERSION_FIELD, + PROTOCOL_FIELD, + PARAMETERS_FIELD, + CREDENTIAL_FIELD, + ACTIONS_FIELD + ); + Set optionalKeys = Collections.emptySet(); try { - for (WorkflowData workflowData : data) { - for (Entry entry : workflowData.getContent().entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - name = (String) entry.getValue(); - break; - case DESCRIPTION_FIELD: - description = (String) entry.getValue(); - break; - case VERSION_FIELD: - version = (String) entry.getValue(); - break; - case PROTOCOL_FIELD: - protocol = (String) entry.getValue(); - break; - case PARAMETERS_FIELD: - parameters = getParameterMap(entry.getValue()); - break; - case CREDENTIAL_FIELD: - credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD); - break; - case ACTIONS_FIELD: - actions = getConnectorActionList(entry.getValue()); - break; - } - } + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + String name = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String version = (String) inputs.get(VERSION_FIELD); + String protocol = (String) inputs.get(PROTOCOL_FIELD); + Map parameters; + Map credentials; + List actions; + + try { + parameters = getParameterMap(inputs.get(PARAMETERS_FIELD)); + credentials = getStringToStringMap(inputs.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD); + actions = getConnectorActionList(inputs.get(ACTIONS_FIELD)); + } catch (IllegalArgumentException iae) { + throw new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST); + } catch (PrivilegedActionException pae) { + throw new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED); } - } catch (IllegalArgumentException iae) { - createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST)); - return createConnectorFuture; - } catch (PrivilegedActionException pae) { - createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED)); - return createConnectorFuture; - } - if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) { MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder() .name(name) .description(description) @@ -182,12 +170,9 @@ public void onFailure(Exception e) { .build(); mlClient.createConnector(mlInput, actionListener); - } else { - createConnectorFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + createConnectorFuture.completeExceptionally(e); } - return createConnectorFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java index bf0fae33e..517e484a7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -13,13 +13,14 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import java.io.IOException; +import java.util.Collections; import java.util.Map; -import java.util.Optional; +import java.util.Set; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; @@ -72,29 +73,23 @@ public void onFailure(Exception e) { } }; - String connectorId = null; - - // Previous Node inputs defines which step the connector ID came from - Optional previousNode = previousNodeInputs.entrySet() - .stream() - .filter(e -> CONNECTOR_ID.equals(e.getValue())) - .map(Map.Entry::getKey) - .findFirst(); - if (previousNode.isPresent()) { - WorkflowData previousNodeOutput = outputs.get(previousNode.get()); - if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) { - connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString(); - } - } + Set requiredKeys = Set.of(CONNECTOR_ID); + Set optionalKeys = Collections.emptySet(); - if (connectorId != null) { - mlClient.deleteConnector(connectorId, actionListener); - } else { - deleteConnectorFuture.completeExceptionally( - new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST) + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs ); - } + String connectorId = (String) inputs.get(CONNECTOR_ID); + mlClient.deleteConnector(connectorId, actionListener); + } catch (FlowFrameworkException e) { + deleteConnectorFuture.completeExceptionally(e); + } return deleteConnectorFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 81409ef77..f878fbdc2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -12,14 +12,14 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; @@ -71,27 +71,24 @@ public void onFailure(Exception e) { } }; - String modelId = null; + Set requiredKeys = Set.of(MODEL_ID); + Set optionalKeys = Collections.emptySet(); - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - for (WorkflowData workflowData : data) { - if (workflowData.getContent().containsKey(MODEL_ID)) { - modelId = (String) workflowData.getContent().get(MODEL_ID); - break; - } - } + String modelId = (String) inputs.get(MODEL_ID); - if (modelId != null) { mlClient.deploy(modelId, actionListener); - } else { - deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST)); + } catch (FlowFrameworkException e) { + deployModelFuture.completeExceptionally(e); } - return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java index 50ae30986..e2aea19df 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ModelGroupStep.java @@ -12,10 +12,10 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.AccessMode; import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupInput; @@ -23,10 +23,9 @@ import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse; import java.io.IOException; -import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES; @@ -113,49 +112,23 @@ public void onFailure(Exception e) { } }; - String modelGroupName = null; - String description = null; - List backendRoles = new ArrayList<>(); - AccessMode modelAccessMode = null; - Boolean isAddAllBackendRoles = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelGroupName = (String) content.get(NAME_FIELD); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case BACKEND_ROLES_FIELD: - backendRoles = getBackendRoles(content); - break; - case MODEL_ACCESS_MODE: - modelAccessMode = (AccessMode) content.get(MODEL_ACCESS_MODE); - break; - case ADD_ALL_BACKEND_ROLES: - isAddAllBackendRoles = (Boolean) content.get(ADD_ALL_BACKEND_ROLES); - break; - default: - break; - } - } - } + Set requiredKeys = Set.of(NAME_FIELD); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, BACKEND_ROLES_FIELD, MODEL_ACCESS_MODE, ADD_ALL_BACKEND_ROLES); - if (modelGroupName == null) { - registerModelGroupFuture.completeExceptionally( - new FlowFrameworkException("Model group name is not provided", RestStatus.BAD_REQUEST) + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs ); - } else { + String modelGroupName = (String) inputs.get(NAME_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + List backendRoles = getBackendRoles(inputs); + AccessMode modelAccessMode = (AccessMode) inputs.get(MODEL_ACCESS_MODE); + Boolean isAddAllBackendRoles = (Boolean) inputs.get(ADD_ALL_BACKEND_ROLES); + MLRegisterModelGroupInputBuilder builder = MLRegisterModelGroupInput.builder(); builder.name(modelGroupName); if (description != null) { @@ -173,6 +146,8 @@ public void onFailure(Exception e) { MLRegisterModelGroupInput mlInput = builder.build(); mlClient.registerModelGroup(mlInput, actionListener); + } catch (FlowFrameworkException e) { + registerModelGroupFuture.completeExceptionally(e); } return registerModelGroupFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java index 3dc730b54..fb1d383b5 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterLocalModelStep.java @@ -19,6 +19,7 @@ import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.MLTaskState; import org.opensearch.ml.common.model.MLModelConfig; @@ -30,10 +31,8 @@ import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; @@ -91,12 +90,6 @@ public CompletableFuture execute( CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { @@ -115,76 +108,41 @@ public void onFailure(Exception e) { } }; - String modelName = null; - String modelVersion = null; - String description = null; - MLModelFormat modelFormat = null; - String modelGroupId = null; - String modelContentHashValue = null; - String modelType = null; - String embeddingDimension = null; - FrameworkType frameworkType = null; - String allConfig = null; - String url = null; - - for (WorkflowData workflowData : data) { - Map content = workflowData.getContent(); - - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelName = (String) content.get(NAME_FIELD); - break; - case VERSION_FIELD: - modelVersion = (String) content.get(VERSION_FIELD); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case MODEL_FORMAT: - modelFormat = MLModelFormat.from((String) content.get(MODEL_FORMAT)); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case MODEL_TYPE: - modelType = (String) content.get(MODEL_TYPE); - break; - case EMBEDDING_DIMENSION: - embeddingDimension = (String) content.get(EMBEDDING_DIMENSION); - break; - case FRAMEWORK_TYPE: - frameworkType = FrameworkType.from((String) content.get(FRAMEWORK_TYPE)); - break; - case ALL_CONFIG: - allConfig = (String) content.get(ALL_CONFIG); - break; - case MODEL_CONTENT_HASH_VALUE: - modelContentHashValue = (String) content.get(MODEL_CONTENT_HASH_VALUE); - break; - case URL: - url = (String) content.get(URL); - break; - default: - break; + Set requiredKeys = Set.of( + NAME_FIELD, + VERSION_FIELD, + MODEL_FORMAT, + MODEL_GROUP_ID, + MODEL_TYPE, + EMBEDDING_DIMENSION, + FRAMEWORK_TYPE, + MODEL_CONTENT_HASH_VALUE, + URL + ); + Set optionalKeys = Set.of(DESCRIPTION_FIELD, ALL_CONFIG); - } - } - } + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - if (Stream.of( - modelName, - modelVersion, - modelFormat, - modelGroupId, - modelType, - embeddingDimension, - frameworkType, - modelContentHashValue, - url - ).allMatch(x -> x != null)) { + String modelName = (String) inputs.get(NAME_FIELD); + String modelVersion = (String) inputs.get(VERSION_FIELD); + String description = (String) inputs.get(DESCRIPTION_FIELD); + MLModelFormat modelFormat = MLModelFormat.from((String) inputs.get(MODEL_FORMAT)); + String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); + String modelContentHashValue = (String) inputs.get(MODEL_CONTENT_HASH_VALUE); + String modelType = (String) inputs.get(MODEL_TYPE); + String embeddingDimension = (String) inputs.get(EMBEDDING_DIMENSION); + FrameworkType frameworkType = FrameworkType.from((String) inputs.get(FRAMEWORK_TYPE)); + String allConfig = (String) inputs.get(ALL_CONFIG); + String url = (String) inputs.get(URL); - // Create Model configudation + // Create Model configuration TextEmbeddingModelConfigBuilder modelConfigBuilder = TextEmbeddingModelConfig.builder() .modelType(modelType) .embeddingDimension(Integer.valueOf(embeddingDimension)) @@ -210,12 +168,9 @@ public void onFailure(Exception e) { MLRegisterModelInput mlInput = mlInputBuilder.build(); mlClient.register(mlInput, actionListener); - } else { - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); + } catch (FlowFrameworkException e) { + registerLocalModelFuture.completeExceptionally(e); } - return registerLocalModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 7e33937bc..cde546d32 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -12,23 +12,20 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.WorkflowResources; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.common.transport.register.MLRegisterModelInput.MLRegisterModelInputBuilder; import org.opensearch.ml.common.transport.register.MLRegisterModelResponse; -import java.util.ArrayList; -import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.Map.Entry; +import java.util.Set; import java.util.concurrent.CompletableFuture; -import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -115,48 +112,23 @@ public void onFailure(Exception e) { } }; - String modelName = null; - FunctionName functionName = null; - String modelGroupId = null; - String description = null; - String connectorId = null; - - // TODO: Recreating the list to get this compiling - // Need to refactor the below iteration to pull directly from the maps - List data = new ArrayList<>(); - data.add(currentNodeInputs); - data.addAll(outputs.values()); - - // TODO : Handle inline connector configuration : https://github.com/opensearch-project/flow-framework/issues/149 - - for (WorkflowData workflowData : data) { - - Map content = workflowData.getContent(); - for (Entry entry : content.entrySet()) { - switch (entry.getKey()) { - case NAME_FIELD: - modelName = (String) content.get(NAME_FIELD); - break; - case FUNCTION_NAME: - functionName = FunctionName.from(((String) content.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); - break; - case MODEL_GROUP_ID: - modelGroupId = (String) content.get(MODEL_GROUP_ID); - break; - case DESCRIPTION_FIELD: - description = (String) content.get(DESCRIPTION_FIELD); - break; - case CONNECTOR_ID: - connectorId = (String) content.get(CONNECTOR_ID); - break; - default: - break; + Set requiredKeys = Set.of(NAME_FIELD, FUNCTION_NAME, CONNECTOR_ID); + Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD); - } - } - } + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); - if (Stream.of(modelName, functionName, connectorId).allMatch(x -> x != null)) { + String modelName = (String) inputs.get(NAME_FIELD); + FunctionName functionName = FunctionName.from(((String) inputs.get(FUNCTION_NAME)).toUpperCase(Locale.ROOT)); + String modelGroupId = (String) inputs.get(MODEL_GROUP_ID); + String description = (String) inputs.get(DESCRIPTION_FIELD); + String connectorId = (String) inputs.get(CONNECTOR_ID); MLRegisterModelInputBuilder builder = MLRegisterModelInput.builder() .functionName(functionName) @@ -172,12 +144,10 @@ public void onFailure(Exception e) { MLRegisterModelInput mlInput = builder.build(); mlClient.register(mlInput, actionListener); - } else { - registerRemoteModelFuture.completeExceptionally( - new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST) - ); - } + } catch (FlowFrameworkException e) { + registerRemoteModelFuture.completeExceptionally(e); + } return registerRemoteModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index f106ee652..738a31497 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -21,8 +21,8 @@ public interface WorkflowStep { * Triggers the actual processing of the building block. * @param currentNodeId The id of the node executing this step * @param currentNodeInputs Input params and content for this node, from workflow parsing - * @param previousNodeInputs Input params for this node that come from previous steps * @param outputs WorkflowData content of previous steps. + * @param previousNodeInputs Input params for this node that come from previous steps * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. * @throws IOException on a failure. */ diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index a5c4253b3..94fe7b01e 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -9,12 +9,17 @@ package org.opensearch.flowframework.util; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; import java.time.Instant; +import java.util.Map; +import java.util.Set; public class ParseUtilsTests extends OpenSearchTestCase { public void testToInstant() throws IOException { @@ -54,4 +59,50 @@ public void testToInstantWithNotValue() throws IOException { Instant instant = ParseUtils.parseInstant(parser); assertNull(instant); } + + public void testGetInputsFromPreviousSteps() { + WorkflowData currentNodeInputs = new WorkflowData( + Map.ofEntries(Map.entry("content1", 1), Map.entry("param1", 2), Map.entry("content3", "${{step1.output1}}")), + Map.of("param1", "value1"), + "workflowId", + "nodeId" + ); + Map outputs = Map.ofEntries( + Map.entry( + "step1", + new WorkflowData( + Map.ofEntries(Map.entry("output1", "outputvalue1"), Map.entry("output2", "step1outputvalue2")), + "workflowId", + "step1" + ) + ), + Map.entry("step2", new WorkflowData(Map.of("output2", "step2outputvalue2"), "workflowId", "step2")) + ); + Map previousNodeInputs = Map.of("step2", "output2"); + Set requiredKeys = Set.of("param1", "content1"); + Set optionalKeys = Set.of("output1", "output2", "content3", "no-output"); + + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + + assertEquals("value1", inputs.get("param1")); + assertEquals(1, inputs.get("content1")); + assertEquals("outputvalue1", inputs.get("output1")); + assertEquals("step2outputvalue2", inputs.get("output2")); + assertEquals("outputvalue1", inputs.get("content3")); + assertNull(inputs.get("no-output")); + + Set missingRequiredKeys = Set.of("not-here"); + FlowFrameworkException e = assertThrows( + FlowFrameworkException.class, + () -> ParseUtils.getInputsFromPreviousSteps(missingRequiredKeys, optionalKeys, currentNodeInputs, outputs, previousNodeInputs) + ); + assertEquals("Missing required inputs [not-here] in workflow [workflowId] node [nodeId]", e.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, e.getRestStatus()); + } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index 3c997a02e..a766d51c9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -13,7 +13,6 @@ import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; import org.opensearch.core.rest.RestStatus; -import org.opensearch.flowframework.common.CommonValue; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; @@ -44,7 +43,7 @@ public void setUp() throws Exception { MockitoAnnotations.openMocks(this); - inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id"); + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); } public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException { @@ -86,7 +85,7 @@ public void testNoConnectorIdInOutput() throws IOException { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required field connector_id is not provided", ex.getCause().getMessage()); + assertEquals("Missing required inputs [connector_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } public void testDeleteConnectorFailure() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index d78a97e8a..cc5acbc30 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -41,8 +41,8 @@ import static org.mockito.Mockito.verify; public class ModelGroupStepTests extends OpenSearchTestCase { - private WorkflowData inputData = WorkflowData.EMPTY; - private WorkflowData inputDataWithNoName = WorkflowData.EMPTY; + private WorkflowData inputData; + private WorkflowData inputDataWithNoName; @Mock MachineLearningNodeClient machineLearningNodeClient; @@ -65,7 +65,7 @@ public void setUp() throws Exception { "test-id", "test-node-id" ); - + inputDataWithNoName = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); } public void testRegisterModelGroup() throws ExecutionException, InterruptedException, IOException { @@ -146,7 +146,7 @@ public void testRegisterModelGroupWithNoName() throws IOException { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Model group name is not provided", ex.getCause().getMessage()); + assertEquals("Missing required inputs [name] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java index c38f8a120..ffa6d82d1 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalModelStepTests.java @@ -237,7 +237,7 @@ public void testRegisterLocalModelTaskFailure() { public void testMissingInputs() { CompletableFuture future = registerLocalModelStep.execute( "nodeId", - WorkflowData.EMPTY, + new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); @@ -245,7 +245,19 @@ public void testMissingInputs() { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); + assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); + for (String s : new String[] { + "model_format", + "name", + "model_type", + "embedding_dimension", + "framework_type", + "model_group_id", + "version", + "url", + "model_content_hash_value" }) { + assertTrue(ex.getCause().getMessage().contains(s)); + } + assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } - } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index a83443f05..865526b79 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -127,7 +127,7 @@ public void testRegisterRemoteModelFailure() { public void testMissingInputs() { CompletableFuture future = this.registerRemoteModelStep.execute( "nodeId", - WorkflowData.EMPTY, + new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); @@ -135,7 +135,11 @@ public void testMissingInputs() { assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); - assertEquals("Required fields are not provided", ex.getCause().getMessage()); + assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); + for (String s : new String[] { "name", "function_name", "connector_id" }) { + assertTrue(ex.getCause().getMessage().contains(s)); + } + assertTrue(ex.getCause().getMessage().endsWith("] in workflow [test-id] node [test-node-id]")); } }