Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.14] Fix ml inference ingest processor always return list using JsonPath #3011

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dependencies {
exclude group: 'com.google.j2objc', module: 'j2objc-annotations'
exclude group: 'com.google.guava', module: 'listenablefuture'
}
compileOnly 'com.jayway.jsonpath:json-path:2.9.0'
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@

import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_ADDITIONAL_INFO_FIELD;
import static org.opensearch.ml.common.conversation.ConversationalIndexConstants.INTERACTIONS_RESPONSE_FIELD;
import com.jayway.jsonpath.JsonPath;


@Log4j2
public class StringUtils {
Expand All @@ -56,6 +58,7 @@ public class StringUtils {
static {
gson = new Gson();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static boolean isValidJsonString(String Json) {
try {
Expand Down Expand Up @@ -239,4 +242,89 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

public static String obtainFieldNameFromJsonPath(String jsonPath) {
String[] parts = jsonPath.split("\\.");

// Get the last part which is the field name
return parts[parts.length - 1];
}

public static String getJsonPath(String jsonPathWithSource) {
// Find the index of the first occurrence of "$."
int startIndex = jsonPathWithSource.indexOf("$.");

// Extract the substring from the startIndex to the end of the input string
return (startIndex != -1) ? jsonPathWithSource.substring(startIndex) : jsonPathWithSource;
}

/**
* Checks if the given input string matches the JSONPath format.
*
* <p>The JSONPath format is a way to navigate and extract data from JSON documents.
* It uses a syntax similar to XPath for XML documents. This method attempts to compile
* the input string as a JSONPath expression using the {@link com.jayway.jsonpath.JsonPath}
* library. If the compilation succeeds, it means the input string is a valid JSONPath
* expression.
*
* @param input the input string to be checked for JSONPath format validity
* @return true if the input string is a valid JSONPath expression, false otherwise
*/
public static boolean isValidJSONPath(String input) {
if (input == null || input.isBlank()) {
return false;
}
try {
JsonPath.compile(input); // This will throw an exception if the path is invalid
return true;
} catch (Exception e) {
return false;
}
}


/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
* @param map A map containing key-value pairs where the values may contain toString() method calls.
* @return A list of prefixes for the toString() method calls found in the map values.
*/
public static List<String> collectToStringPrefixes(Map<String, String> map) {
List<String> prefixes = new ArrayList<>();
for (String key : map.keySet()) {
String value = map.get(key);
if (value != null) {
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String prefix = matcher.group(1);
prefixes.add(prefix);
}
}
}
return prefixes;
}

/**
* Parses the given parameters map and processes the values containing toString() method calls.
*
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
* @return A new map with the processed values for the toString() method calls.
*/
public static Map<String, String> parseParameters(Map<String, String> parameters) {
if (parameters != null) {
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);

if (!toStringParametersPrefixes.isEmpty()) {
for (String prefix : toStringParametersPrefixes) {
String value = parameters.get(prefix);
if (value != null) {
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
}
}
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
Expand All @@ -31,9 +30,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.script.TemplateScript;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;

/**
* MLInferenceIngestProcessor requires a modelId string to call model inferences
Expand All @@ -57,11 +54,6 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;

private Configuration suppressExceptionConfiguration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL, Option.ALWAYS_RETURN_LIST)
.build();

protected MLInferenceIngestProcessor(
String modelId,
List<Map<String, String>> inputMaps,
Expand Down Expand Up @@ -243,24 +235,29 @@ private void getMappedModelInputFromDocuments(
Object documentFieldValue = ingestDocument.getFieldValue(originalFieldPath, Object.class);
String documentFieldValueAsString = toString(documentFieldValue);
updateModelParameters(modelInputFieldName, documentFieldValueAsString, modelParameters);
return;
}
// else when cannot find field path in document, try check for nested array using json path
else {
if (documentFieldName.contains(DOT_SYMBOL)) {

Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
ArrayList<Object> fieldValueList = JsonPath
.using(suppressExceptionConfiguration)
.parse(sourceObject)
.read(documentFieldName);
if (!fieldValueList.isEmpty()) {
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
} else if (!ignoreMissing) {
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
// If the standard dot path fails, try to check for a nested array using JSON path
if (StringUtils.isValidJSONPath(documentFieldName)) {
Map<String, Object> sourceObject = ingestDocument.getSourceAndMetadata();
Object fieldValue = JsonPath.using(suppressExceptionConfiguration).parse(sourceObject).read(documentFieldName);

if (fieldValue != null) {
if (fieldValue instanceof List) {
List<?> fieldValueList = (List<?>) fieldValue;
if (!fieldValueList.isEmpty()) {
updateModelParameters(modelInputFieldName, toString(fieldValueList), modelParameters);
} else if (!ignoreMissing) {
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
} else {
updateModelParameters(modelInputFieldName, toString(fieldValue), modelParameters);
}
} else if (!ignoreMissing) {
throw new IllegalArgumentException("cannot find field name defined from input map: " + documentFieldName);
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
} else {
throw new IllegalArgumentException("Cannot find field name defined from input map: " + documentFieldName);
}
}

Expand Down
Loading
Loading