Skip to content

Commit

Permalink
Enable pass query string to input_map in ml inference search response…
Browse files Browse the repository at this point in the history
… processor (#2899) (#3129)

* enable add query_text to model_config

Signed-off-by: Mingshi Liu <[email protected]>

* change javadoc

Signed-off-by: Mingshi Liu <[email protected]>

* add more tests

Signed-off-by: Mingshi Liu <[email protected]>

* use standard json path config

Signed-off-by: Mingshi Liu <[email protected]>

* add example in javadoc

Signed-off-by: Mingshi Liu <[email protected]>

* read query mapping  from input_map

Signed-off-by: Mingshi Liu <[email protected]>

* recognize query mapping by prefix _request.

Signed-off-by: Mingshi Liu <[email protected]>

---------

Signed-off-by: Mingshi Liu <[email protected]>
(cherry picked from commit 083abad)

Co-authored-by: Mingshi Liu <[email protected]>
  • Loading branch information
opensearch-trigger-bot[bot] and mingshl authored Oct 21, 2024
1 parent 09aa6ea commit 4f01193
Show file tree
Hide file tree
Showing 5 changed files with 508 additions and 47 deletions.
1 change: 0 additions & 1 deletion common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ dependencies {
compileOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0'
compileOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1'
compileOnly group: 'org.json', name: 'json', version: '20231013'

implementation('com.google.guava:guava:32.1.2-jre') {
exclude group: 'com.google.guava', module: 'failureaccess'
exclude group: 'com.google.code.findbugs', module: 'jsr305'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
import static org.opensearch.ml.common.utils.StringUtils.getJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.isValidJSONPath;
import static org.opensearch.ml.common.utils.StringUtils.obtainFieldNameFromJsonPath;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
Expand Down Expand Up @@ -457,4 +458,53 @@ public void testGetJsonPath_ValidJsonPathWithoutSource() {
String result = getJsonPath(input);
assertEquals("$.response.body.data[*].embedding", result);
}

@Test
public void testisValidJSONPath_InvalidInputs() {
Assert.assertFalse(isValidJSONPath("..bar"));
Assert.assertFalse(isValidJSONPath("."));
Assert.assertFalse(isValidJSONPath(".."));
Assert.assertFalse(isValidJSONPath("foo.bar."));
Assert.assertFalse(isValidJSONPath(".foo.bar."));
}

@Test
public void testisValidJSONPath_NullInput() {
Assert.assertFalse(isValidJSONPath(null));
}

@Test
public void testisValidJSONPath_EmptyInput() {
Assert.assertFalse(isValidJSONPath(""));
}

@Test
public void testisValidJSONPath_ValidInputs() {
Assert.assertTrue(isValidJSONPath("foo"));
Assert.assertTrue(isValidJSONPath("foo.bar"));
Assert.assertTrue(isValidJSONPath("foo.bar.baz"));
Assert.assertTrue(isValidJSONPath("foo.bar.baz.qux"));
Assert.assertTrue(isValidJSONPath(".foo"));
Assert.assertTrue(isValidJSONPath("$.foo"));
Assert.assertTrue(isValidJSONPath(".foo.bar"));
Assert.assertTrue(isValidJSONPath("$.foo.bar"));
}

@Test
public void testisValidJSONPath_WithFilter() {
Assert.assertTrue(isValidJSONPath("$.store['book']"));
Assert.assertTrue(isValidJSONPath("$['store']['book'][0]['title']"));
Assert.assertTrue(isValidJSONPath("$.store.book[0]"));
Assert.assertTrue(isValidJSONPath("$.store.book[1,2]"));
Assert.assertTrue(isValidJSONPath("$.store.book[-1:] "));
Assert.assertTrue(isValidJSONPath("$.store.book[0:2]"));
Assert.assertTrue(isValidJSONPath("$.store.book[*]"));
Assert.assertTrue(isValidJSONPath("$.store.book[?(@.price < 10)]"));
Assert.assertTrue(isValidJSONPath("$.store.book[?(@.author == 'J.K. Rowling')]"));
Assert.assertTrue(isValidJSONPath("$..author"));
Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 15)]"));
Assert.assertTrue(isValidJSONPath("$.store.book[0,1]"));
Assert.assertTrue(isValidJSONPath("$['store','warehouse']"));
Assert.assertTrue(isValidJSONPath("$..book[?(@.price > 20)].title"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.processor;

import static java.lang.Math.max;
import static org.opensearch.ml.common.utils.StringUtils.toJson;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.INPUT_MAP;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MAX_PREDICTION_TASKS;
import static org.opensearch.ml.processor.InferenceProcessorAttributes.MODEL_CONFIG;
Expand Down Expand Up @@ -55,12 +56,11 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.jayway.jsonpath.Configuration;
import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Option;

public class MLInferenceSearchResponseProcessor extends AbstractProcessor implements SearchResponseProcessor, ModelExecutor {

public static final String REQUEST_PREFIX = "_request.";
private final NamedXContentRegistry xContentRegistry;
private static final Logger logger = LogManager.getLogger(MLInferenceSearchResponseProcessor.class);
private final InferenceProcessorAttributes inferenceProcessorAttributes;
Expand Down Expand Up @@ -155,6 +155,8 @@ public void processResponseAsync(
try {
SearchHit[] hits = response.getHits().getHits();
// skip processing when there is no hit

String queryString = request.source().toString();
if (hits.length == 0) {
responseListener.onResponse(response);
return;
Expand Down Expand Up @@ -183,7 +185,7 @@ public void processResponseAsync(
);
}

rewriteResponseDocuments(mlInferenceSearchResponse, responseListener);
rewriteResponseDocuments(mlInferenceSearchResponse, responseListener, queryString);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
Expand All @@ -198,7 +200,7 @@ public void processResponseAsync(
newHits[0] = hit;
SearchResponse oneHitResponse = SearchResponseUtil.replaceHits(newHits, response);
ActionListener<SearchResponse> oneHitListener = getOneHitListener(combineResponseListener, isOneHitListenerFailed);
rewriteResponseDocuments(oneHitResponse, oneHitListener);
rewriteResponseDocuments(oneHitResponse, oneHitListener, queryString);
// if any OneHitListener failure, try stop the rest of the predictions
if (isOneHitListenerFailed.get()) {
break;
Expand Down Expand Up @@ -305,9 +307,11 @@ public void onFailure(Exception e) {
*
* @param response the search response
* @param responseListener the listener to be notified when the response is processed
* @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
* @throws IOException if an I/O error occurs during the rewriting process
*/
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener) throws IOException {
private void rewriteResponseDocuments(SearchResponse response, ActionListener<SearchResponse> responseListener, String queryString)
throws IOException {
List<Map<String, String>> processInputMap = inferenceProcessorAttributes.getInputMaps();
List<Map<String, String>> processOutputMap = inferenceProcessorAttributes.getOutputMaps();
int inputMapSize = (processInputMap == null) ? 0 : processInputMap.size();
Expand All @@ -329,7 +333,7 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
);
SearchHit[] hits = response.getHits().getHits();
for (int inputMapIndex = 0; inputMapIndex < max(inputMapSize, 1); inputMapIndex++) {
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions);
processPredictions(hits, processInputMap, inputMapIndex, batchPredictionListener, hitCountInPredictions, queryString);
}
}

Expand All @@ -341,56 +345,80 @@ private void rewriteResponseDocuments(SearchResponse response, ActionListener<Se
* @param inputMapIndex the index of the input mapping to process
* @param batchPredictionListener the listener to be notified when the predictions are processed
* @param hitCountInPredictions a map to keep track of the count of hits that have the required input fields for each round of prediction
* @param queryString the query body in string format, for example, "{ \"query\": { \"match_all\": {} } }\n"
* @throws IOException if an I/O error occurs during the prediction process
*/
private void processPredictions(
SearchHit[] hits,
List<Map<String, String>> processInputMap,
int inputMapIndex,
GroupedActionListener<Map<Integer, MLOutput>> batchPredictionListener,
Map<Integer, Integer> hitCountInPredictions
Map<Integer, Integer> hitCountInPredictions,
String queryString
) throws IOException {

Map<String, String> modelParameters = new HashMap<>();
Map<String, String> modelConfigs = new HashMap<>();

if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
Map<String, String> modelConfigMapsInput = inferenceProcessorAttributes.getModelConfigMaps();

modelParameters.putAll(modelConfigMapsInput);
modelConfigs.putAll(modelConfigMapsInput);

}

Map<String, Object> modelInputParameters = new HashMap<>();

Map<String, String> inputMapping;
if (processInputMap != null && !processInputMap.isEmpty()) {
inputMapping = processInputMap.get(inputMapIndex);
boolean isRequestInputMissing = checkIsRequestInputMissing(queryString, inputMapping);
if (isRequestInputMissing) {
if (!ignoreMissing) {
throw new IllegalArgumentException(
"Missing required input field in query body. input_map: " + inputMapping.values() + ", query body:" + queryString
);
}
}

for (SearchHit hit : hits) {
Map<String, Object> document = hit.getSourceAsMap();
boolean isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
if (!isModelInputMissing) {
boolean isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping);
if (!isDocumentFieldMissing) {
MapUtils.incrementCounter(hitCountInPredictions, inputMapIndex);
for (Map.Entry<String, String> entry : inputMapping.entrySet()) {
// model field as key, document field name as value
String modelInputFieldName = entry.getKey();
String documentFieldName = entry.getValue();

Object documentJson = JsonPath.parse(document).read("$");
Configuration configuration = Configuration
.builder()
.options(Option.SUPPRESS_EXCEPTIONS, Option.DEFAULT_PATH_LEAF_TO_NULL)
.build();

Object documentValue = JsonPath.using(configuration).parse(documentJson).read(documentFieldName);
if (documentValue != null) {
// when not existed in the map, add into the modelInputParameters map
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
// read the query string when the mapping field name starts with "$._request." or "_request."
// skip when modelInputParameters already has this modelInputFieldName to avoid duplicate read
if (StringUtils.isValidJSONPath(documentFieldName)
&& (documentFieldName.startsWith("$." + REQUEST_PREFIX) || documentFieldName.startsWith(REQUEST_PREFIX))
&& !modelInputParameters.containsKey(modelInputFieldName)) {
String requestFieldName = documentFieldName.replaceFirst(REQUEST_PREFIX, "");

Object queryText = JsonPath.using(suppressExceptionConfiguration).parse(queryString).read(requestFieldName);
if (queryText != null) {
modelInputParameters.put(modelInputFieldName, toJson(queryText));
}
} else {
Object documentValue = JsonPath.using(suppressExceptionConfiguration).parse(document).read(documentFieldName);
if (documentValue != null) {
// when not existed in the map, add into the modelInputParameters map
updateModelInputParameters(modelInputParameters, modelInputFieldName, documentValue);
}
}
}
} else { // when document does not contain the documentFieldName, skip when ignoreMissing
if (!ignoreMissing) {
throw new IllegalArgumentException(
"cannot find all required input fields: " + inputMapping.values() + " in hit:" + hit
"cannot find all required input fields: "
+ inputMapping.values()
+ " in hit:"
+ hit
+ " and query body:"
+ queryString
);
}
}
Expand Down Expand Up @@ -542,11 +570,11 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
Map<String, String> inputMapping = getDefaultInputMapping(sourceAsMap, mappingIndex, processInputMap);
Map<String, String> outputMapping = getDefaultOutputMapping(mappingIndex, processOutputMap);

boolean isModelInputMissing = false;
boolean isDocumentFieldMissing = false;
if (processInputMap != null && !processInputMap.isEmpty()) {
isModelInputMissing = checkIsModelInputMissing(document, inputMapping);
isDocumentFieldMissing = checkIsDocumentFieldMissing(document, inputMapping);
}
if (!isModelInputMissing) {
if (!isDocumentFieldMissing) {
// Iterate over outputMapping
for (Map.Entry<String, String> outputMapEntry : outputMapping.entrySet()) {

Expand Down Expand Up @@ -637,22 +665,45 @@ public void onFailure(Exception e) {

/**
* Checks if the document is missing any of the required input fields specified in the input mapping.
* When model config contains the default model_input value, it's not considered as missing model input.
*
* @param document the document map
* @param inputMapping the input mapping
* @return true if the document is missing any of the required input fields, false otherwise
*/
private boolean checkIsModelInputMissing(Map<String, Object> document, Map<String, String> inputMapping) {
boolean isModelInputMissing = false;
for (Map.Entry<String, String> inputMapEntry : inputMapping.entrySet()) {
String oldDocumentFieldName = inputMapEntry.getValue();
boolean checkSingleModelInputPresent = hasField(document, oldDocumentFieldName);
if (!checkSingleModelInputPresent) {
isModelInputMissing = true;
break;
}
}
return isModelInputMissing;
private boolean checkIsDocumentFieldMissing(Map<String, Object> document, Map<String, String> inputMapping) {
return inputMapping
.values()
.stream()
.filter(fieldName -> !(fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX)))
.anyMatch(fieldName -> {
boolean isFieldPresentInDocument = document != null && hasField(document, fieldName);
boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null
&& this.inferenceProcessorAttributes.modelConfigMaps.containsKey(fieldName);
return !isFieldPresentInDocument && !isFieldPresentInModelConfig;
});
}

/**
* Checks if the request is missing any of the required input fields specified in the input mapping.
* When model config contains the default model_input value, it's not considered as missing model input.
*
* @param queryString the query body in string format, e.g., "{ \"query\": { \"match_all\": {} } }\n"
* @param inputMapping the input mapping
* @return true if the document is missing any of the required input fields, false otherwise
*/
private boolean checkIsRequestInputMissing(String queryString, Map<String, String> inputMapping) {
return inputMapping
.values()
.stream()
.filter(fieldName -> fieldName.startsWith("$." + REQUEST_PREFIX) || fieldName.startsWith(REQUEST_PREFIX))
.map(fieldName -> fieldName.replaceFirst(REQUEST_PREFIX, ""))
.anyMatch(requestFieldName -> {
boolean isFieldPresentInQuery = queryString != null && hasField(queryString, requestFieldName);
boolean isFieldPresentInModelConfig = this.inferenceProcessorAttributes.modelConfigMaps != null
&& this.inferenceProcessorAttributes.modelConfigMaps.containsKey(requestFieldName);
return !isFieldPresentInQuery && !isFieldPresentInModelConfig;
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,12 @@ default String toString(Object originalFieldValue) {
}

default boolean hasField(Object json, String path) {
Object value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);

Object value;
if (json instanceof String) {
value = JsonPath.using(suppressExceptionConfiguration).parse((String) json).read(path);
} else {
value = JsonPath.using(suppressExceptionConfiguration).parse(json).read(path);
}
if (value != null) {
return true;
}
Expand Down
Loading

0 comments on commit 4f01193

Please sign in to comment.