Skip to content

Commit

Permalink
Add Util method to fetch inputs from parameters, content, and previou…
Browse files Browse the repository at this point in the history
…s step output (#234)

* Util method to get required inputs

Signed-off-by: Daniel Widdis <[email protected]>

* Implement parsing in some of the steps

Signed-off-by: Daniel Widdis <[email protected]>

* Handle parsing exceptions in the future

Signed-off-by: Daniel Widdis <[email protected]>

* Improve exception handling

Signed-off-by: Daniel Widdis <[email protected]>

* More steps using the new input parsing

Signed-off-by: Daniel Widdis <[email protected]>

* Update Delete Connector Step with parsing util

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis authored Dec 2, 2023
1 parent c35a809 commit 54ab4c7
Show file tree
Hide file tree
Showing 13 changed files with 327 additions and 282 deletions.
102 changes: 102 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,24 @@
import org.opensearch.commons.ConfigConstants;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.ml.common.agent.LLMSpec;

import java.io.IOException;
import java.time.Instant;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
Expand All @@ -39,6 +47,9 @@
public class ParseUtils {
private static final Logger logger = LogManager.getLogger(ParseUtils.class);

// Matches ${{ foo.bar }} (whitespace optional) with capturing groups 1=foo, 2=bar
private static final Pattern SUBSTITUTION_PATTERN = Pattern.compile("\\$\\{\\{\\s*(.+)\\.(.+?)\\s*\\}\\}");

private ParseUtils() {}

/**
Expand Down Expand Up @@ -161,4 +172,95 @@ public static Map<String, String> getStringToStringMap(Object map, String fieldN
throw new IllegalArgumentException("[" + fieldName + "] must be a key-value map.");
}

/**
* Creates a map containing the specified input keys, with values derived from template data or previous node
* output.
*
* @param requiredInputKeys A set of keys that must be present, or will cause an exception to be thrown
* @param optionalInputKeys A set of keys that may be present, or will be absent in the returned map
* @param currentNodeInputs Input params and content for this node, from workflow parsing
* @param outputs WorkflowData content of previous steps
* @param previousNodeInputs Input params for this node that come from previous steps
* @return A map containing the requiredInputKeys with their corresponding values,
* and optionalInputKeys with their corresponding values if present.
* Throws a {@link FlowFrameworkException} if a required key is not present.
*/
public static Map<String, Object> getInputsFromPreviousSteps(
Set<String> requiredInputKeys,
Set<String> optionalInputKeys,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) {
// Mutable set to ensure all required keys are used
Set<String> requiredKeys = new HashSet<>(requiredInputKeys);
// Merge input sets to add all requested keys
Set<String> keys = new HashSet<>(requiredInputKeys);
keys.addAll(optionalInputKeys);
// Initialize return map
Map<String, Object> inputs = new HashMap<>();
for (String key : keys) {
Object value = null;
// Priority 1: specifically named prior step inputs
// ... parse the previousNodeInputs map and fill in the specified keys
Optional<String> previousNodeForKey = previousNodeInputs.entrySet()
.stream()
.filter(e -> key.equals(e.getValue()))
.map(Map.Entry::getKey)
.findAny();
if (previousNodeForKey.isPresent()) {
WorkflowData previousNodeOutput = outputs.get(previousNodeForKey.get());
if (previousNodeOutput != null) {
value = previousNodeOutput.getContent().get(key);
}
}
// Priority 2: inputs specified in template
// ... fetch from currentNodeInputs (params take precedence)
if (value == null) {
value = currentNodeInputs.getParams().get(key);
}
if (value == null) {
value = currentNodeInputs.getContent().get(key);
}
// Priority 3: other inputs
if (value == null) {
Optional<Object> matchedValue = outputs.values()
.stream()
.map(WorkflowData::getContent)
.filter(m -> m.containsKey(key))
.map(m -> m.get(key))
.findAny();
if (matchedValue.isPresent()) {
value = matchedValue.get();
}
}
// Check for substitution
if (value != null) {
Matcher m = SUBSTITUTION_PATTERN.matcher(value.toString());
if (m.matches()) {
WorkflowData data = outputs.get(m.group(1));
if (data != null && data.getContent().containsKey(m.group(2))) {
value = data.getContent().get(m.group(2));
}
}
inputs.put(key, value);
requiredKeys.remove(key);
}
}
// After iterating is complete, throw exception if requiredKeys is not empty
if (!requiredKeys.isEmpty()) {
throw new FlowFrameworkException(
"Missing required inputs "
+ requiredKeys
+ " in workflow ["
+ currentNodeInputs.getWorkflowId()
+ "] node ["
+ currentNodeInputs.getNodeId()
+ "]",
RestStatus.BAD_REQUEST
);
}
// Finally return the map
return inputs;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.flowframework.common.WorkflowResources;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.ConnectorAction.ActionType;
Expand All @@ -33,8 +34,8 @@
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Stream;

import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD;
import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD;
Expand Down Expand Up @@ -120,57 +121,44 @@ public void onFailure(Exception e) {
}
};

String name = null;
String description = null;
String version = null;
String protocol = null;
Map<String, String> parameters = Collections.emptyMap();
Map<String, String> credentials = Collections.emptyMap();
List<ConnectorAction> actions = Collections.emptyList();

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());
Set<String> requiredKeys = Set.of(
NAME_FIELD,
DESCRIPTION_FIELD,
VERSION_FIELD,
PROTOCOL_FIELD,
PARAMETERS_FIELD,
CREDENTIAL_FIELD,
ACTIONS_FIELD
);
Set<String> optionalKeys = Collections.emptySet();

try {
for (WorkflowData workflowData : data) {
for (Entry<String, Object> entry : workflowData.getContent().entrySet()) {
switch (entry.getKey()) {
case NAME_FIELD:
name = (String) entry.getValue();
break;
case DESCRIPTION_FIELD:
description = (String) entry.getValue();
break;
case VERSION_FIELD:
version = (String) entry.getValue();
break;
case PROTOCOL_FIELD:
protocol = (String) entry.getValue();
break;
case PARAMETERS_FIELD:
parameters = getParameterMap(entry.getValue());
break;
case CREDENTIAL_FIELD:
credentials = getStringToStringMap(entry.getValue(), CREDENTIAL_FIELD);
break;
case ACTIONS_FIELD:
actions = getConnectorActionList(entry.getValue());
break;
}
}
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

String name = (String) inputs.get(NAME_FIELD);
String description = (String) inputs.get(DESCRIPTION_FIELD);
String version = (String) inputs.get(VERSION_FIELD);
String protocol = (String) inputs.get(PROTOCOL_FIELD);
Map<String, String> parameters;
Map<String, String> credentials;
List<ConnectorAction> actions;

try {
parameters = getParameterMap(inputs.get(PARAMETERS_FIELD));
credentials = getStringToStringMap(inputs.get(CREDENTIAL_FIELD), CREDENTIAL_FIELD);
actions = getConnectorActionList(inputs.get(ACTIONS_FIELD));
} catch (IllegalArgumentException iae) {
throw new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST);
} catch (PrivilegedActionException pae) {
throw new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED);
}
} catch (IllegalArgumentException iae) {
createConnectorFuture.completeExceptionally(new FlowFrameworkException(iae.getMessage(), RestStatus.BAD_REQUEST));
return createConnectorFuture;
} catch (PrivilegedActionException pae) {
createConnectorFuture.completeExceptionally(new FlowFrameworkException(pae.getMessage(), RestStatus.UNAUTHORIZED));
return createConnectorFuture;
}

if (Stream.of(name, description, version, protocol, parameters, credentials, actions).allMatch(x -> x != null)) {
MLCreateConnectorInput mlInput = MLCreateConnectorInput.builder()
.name(name)
.description(description)
Expand All @@ -182,12 +170,9 @@ public void onFailure(Exception e) {
.build();

mlClient.createConnector(mlInput, actionListener);
} else {
createConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required fields are not provided", RestStatus.BAD_REQUEST)
);
} catch (FlowFrameworkException e) {
createConnectorFuture.completeExceptionally(e);
}

return createConnectorFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
import org.opensearch.ExceptionsHelper;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;

import java.io.IOException;
import java.util.Collections;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID;
Expand Down Expand Up @@ -72,29 +73,23 @@ public void onFailure(Exception e) {
}
};

String connectorId = null;

// Previous Node inputs defines which step the connector ID came from
Optional<String> previousNode = previousNodeInputs.entrySet()
.stream()
.filter(e -> CONNECTOR_ID.equals(e.getValue()))
.map(Map.Entry::getKey)
.findFirst();
if (previousNode.isPresent()) {
WorkflowData previousNodeOutput = outputs.get(previousNode.get());
if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) {
connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString();
}
}
Set<String> requiredKeys = Set.of(CONNECTOR_ID);
Set<String> optionalKeys = Collections.emptySet();

if (connectorId != null) {
mlClient.deleteConnector(connectorId, actionListener);
} else {
deleteConnectorFuture.completeExceptionally(
new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST)
try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);
}
String connectorId = (String) inputs.get(CONNECTOR_ID);

mlClient.deleteConnector(connectorId, actionListener);
} catch (FlowFrameworkException e) {
deleteConnectorFuture.completeExceptionally(e);
}
return deleteConnectorFuture;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;

import java.util.ArrayList;
import java.util.List;
import java.util.Collections;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
Expand Down Expand Up @@ -71,27 +71,24 @@ public void onFailure(Exception e) {
}
};

String modelId = null;
Set<String> requiredKeys = Set.of(MODEL_ID);
Set<String> optionalKeys = Collections.emptySet();

// TODO: Recreating the list to get this compiling
// Need to refactor the below iteration to pull directly from the maps
List<WorkflowData> data = new ArrayList<>();
data.add(currentNodeInputs);
data.addAll(outputs.values());
try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

for (WorkflowData workflowData : data) {
if (workflowData.getContent().containsKey(MODEL_ID)) {
modelId = (String) workflowData.getContent().get(MODEL_ID);
break;
}
}
String modelId = (String) inputs.get(MODEL_ID);

if (modelId != null) {
mlClient.deploy(modelId, actionListener);
} else {
deployModelFuture.completeExceptionally(new FlowFrameworkException("Model ID is not provided", RestStatus.BAD_REQUEST));
} catch (FlowFrameworkException e) {
deployModelFuture.completeExceptionally(e);
}

return deployModelFuture;
}

Expand Down
Loading

0 comments on commit 54ab4c7

Please sign in to comment.