Skip to content

Commit

Permalink
use standard config in ingest processor intead of always return list (#…
Browse files Browse the repository at this point in the history
…3011)

Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl authored Oct 1, 2024
1 parent 69f6515 commit 0fe4c15
Show file tree
Hide file tree
Showing 4 changed files with 237 additions and 27 deletions.
1 change: 1 addition & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies {
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
exclude group: 'com.google.guava', module: 'listenablefuture'
}
compileOnly 'com.jayway.jsonpath:json-path:2.9.0'
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
import com.jayway.jsonpath.JsonPath;


@Log4j2
public class StringUtils {
Expand All @@ -56,6 +58,7 @@ public class StringUtils {
static {
gson = new Gson();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static boolean isValidJsonString(String Json) {
try {
Expand Down Expand Up @@ -239,4 +242,89 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

public static String obtainFieldNameFromJsonPath(String jsonPath) {
String[] parts = jsonPath.split("\\.");

// Get the last part which is the field name
return parts[parts.length - 1];
}

public static String getJsonPath(String jsonPathWithSource) {
// Find the index of the first occurrence of "$."
int startIndex = jsonPathWithSource.indexOf("$.");

// Extract the substring from the startIndex to the end of the input string
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
}

/**
* Checks if the given input string matches the JSONPath format.
*
* <p>The JSONPath format is a way to navigate and extract data from JSON documents.
* It uses a syntax similar to XPath for XML documents. This method attempts to compile
* the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath}
* library. If the compilation succeeds, it means the input string is a valid JSONPath
* expression.
*
* @param input the input string to be checked for JSONPath format validity
* @return true if the input string is a valid JSONPath expression, false otherwise
*/
public static boolean isValidJSONPath(String input) {
if (input == null || input.isBlank()) {
return false;
}
try {
JsonPath.compile(input); // This will throw an exception if the path is invalid
return true;
} catch (Exception e) {
return false;
}
}


/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
* @param map A map containing key-value pairs where the values may contain toString() method calls.
* @return A list of prefixes for the toString() method calls found in the map values.
*/
public static List<String> collectToStringPrefixes(Map<String, String> map) {
List<String> prefixes = new ArrayList<>();
for (String key : map.keySet()) {
String value = map.get(key);
if (value != null) {
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String prefix = matcher.group(1);
prefixes.add(prefix);
}
}
}
return prefixes;
}

/**
* Parses the given parameters map and processes the values containing toString() method calls.
*
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
* @return A new map with the processed values for the toString() method calls.
*/
public static Map<String, String> parseParameters(Map<String, String> parameters) {
if (parameters != null) {
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);

if (!toStringParametersPrefixes.isEmpty()) {
for (String prefix : toStringParametersPrefixes) {
String value = parameters.get(prefix);
if (value != null) {
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
}
}
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
Expand All @@ -31,9 +30,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;

/**
* MLInferenceIngestProcessor requires a modelId string to call model inferences
Expand All @@ -57,11 +54,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;

private Configuration suppressExceptionConfiguration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST)
.build();

protected MLInferenceIngestProcessor(
String modelId,
List<Map<String, String>> inputMaps,
Expand Down Expand Up @@ -243,24 +235,29 @@ private void getMappedModelInputFromDocuments(
Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
String documentFieldValueAsString = toString(documentFieldValue);
updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters);
return;
}
// else when cannot find field path in document, try check for nested array using json path
else {
if (documentFieldName.contains(DOT_SYMBOL)) {

Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
ArrayList<Object> fieldValueList = JsonPath
.using(suppressExceptionConfiguration)
.parse(sourceObject)
.read(documentFieldName);
if (!fieldValueList.isEmpty()) {
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
} else if (!ignoreMissing) {
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
// If the standard dot path fails, try to check for a nested array using JSON path
if (StringUtils.isValidJSONPath(documentFieldName)) {
Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName);

if (fieldValue != null) {
if (fieldValue instanceof List) {
List<?> fieldValueList = (List<?>) fieldValue;
if (!fieldValueList.isEmpty()) {
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
} else if (!ignoreMissing) {
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
} else {
updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters);
}
} else if (!ignoreMissing) {
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
} else {
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
}

Expand Down
Loading

0 comments on commit 0fe4c15

Please sign in to comment.