Skip to content

Commit

Permalink
Support ML Inference Search Processor Writing to Search Extension (#3061
Browse files Browse the repository at this point in the history
)

(cherry picked from commit d9a56cf)
  • Loading branch information
mingshl authored and github-actions[bot] committed Oct 17, 2024
1 parent 41dae61 commit c0c205d
Show file tree
Hide file tree
Showing 4 changed files with 714 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.ml.processor;

import java.io.IOException;
import java.util.Map;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.action.search.ShardSearchFailure;
import org.opensearch.core.xcontent.XContentBuilder;

/**
 * A {@link SearchResponse} that can additionally carry a map of parameters and render
 * them under an {@code ext} object when the response is serialized with
 * {@link #toXContent(XContentBuilder, Params)}.
 *
 * <p>The parameter map is optional: when it is {@code null}, no {@code ext} object is written.
 */
public class MLInferenceSearchResponse extends SearchResponse {
    /** Name of the object under which the parameters are rendered. */
    private static final String EXT_SECTION_NAME = "ext";

    /** Parameters rendered under the {@code ext} object; may be {@code null}. */
    private Map<String, Object> params;

    /**
     * Creates a response holding the given parameters; every other argument is passed
     * unchanged to the {@link SearchResponse} constructor.
     *
     * @param params           parameters to render under the {@code ext} object, or {@code null} for none
     * @param internalResponse forwarded to the superclass constructor
     * @param scrollId         forwarded to the superclass constructor
     * @param totalShards      forwarded to the superclass constructor
     * @param successfulShards forwarded to the superclass constructor
     * @param skippedShards    forwarded to the superclass constructor
     * @param tookInMillis     forwarded to the superclass constructor
     * @param shardFailures    forwarded to the superclass constructor
     * @param clusters         forwarded to the superclass constructor
     */
    public MLInferenceSearchResponse(
        Map<String, Object> params,
        SearchResponseSections internalResponse,
        String scrollId,
        int totalShards,
        int successfulShards,
        int skippedShards,
        long tookInMillis,
        ShardSearchFailure[] shardFailures,
        Clusters clusters
    ) {
        super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters);
        this.params = params;
    }

    /**
     * Replaces the parameter map held by this response.
     *
     * @param params the new parameters, or {@code null} to omit the {@code ext} object
     */
    public void setParams(Map<String, Object> params) {
        this.params = params;
    }

    /**
     * Returns the parameter map held by this response.
     *
     * @return the same map instance that was stored (not a copy), or {@code null} if none is set
     */
    public Map<String, Object> getParams() {
        return params;
    }

    /**
     * Writes this response as a single object: first the content produced by
     * {@code innerToXContent}, then, if parameters are present, an {@code ext} object
     * containing one field named by {@code MLInferenceSearchResponseProcessor.TYPE}
     * whose value is the parameter map.
     *
     * @param builder the builder to write to
     * @param params  serialization parameters, passed on to {@code innerToXContent}
     * @return the same builder, after the enclosing object has been closed
     * @throws IOException if writing to the builder fails
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        innerToXContent(builder, params);

        // The method parameter shadows the field of the same name, so alias the field explicitly.
        final Map<String, Object> extensionParams = this.params;
        if (extensionParams != null) {
            builder.startObject(EXT_SECTION_NAME).field(MLInferenceSearchResponseProcessor.TYPE, extensionParams).endObject();
        }

        return builder.endObject();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;

Expand Down Expand Up @@ -84,6 +85,9 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
// allows writing to the extension of the search response; a path that points to the search extension
// is prefixed with ext.ml_inference
public static final String EXTENSION_PREFIX = "ext.ml_inference";

protected MLInferenceSearchResponseProcessor(
String modelId,
Expand Down Expand Up @@ -158,7 +162,28 @@ public void processResponseAsync(

// if many to one, run rewriteResponseDocuments
if (!oneToOne) {
rewriteResponseDocuments(response, responseListener);
// Use MLInferenceSearchResponse so that the response can carry extension params.
// If the incoming search response already is an MLInferenceSearchResponse, reuse it;
// otherwise create a new MLInferenceSearchResponse from the original response's sections.
MLInferenceSearchResponse mlInferenceSearchResponse;

if (response instanceof MLInferenceSearchResponse) {
    mlInferenceSearchResponse = (MLInferenceSearchResponse) response;
} else {
    mlInferenceSearchResponse = new MLInferenceSearchResponse(
        null, // extension params start out unset
        response.getInternalResponse(),
        response.getScrollId(),
        response.getTotalShards(),
        response.getSuccessfulShards(),
        response.getSkippedShards(),
        // tookInMillis: carry over the original elapsed time (this slot previously received
        // the successful shard count by mistake)
        response.getTook().millis(),
        response.getShardFailures(),
        response.getClusters()
    );
}

rewriteResponseDocuments(mlInferenceSearchResponse, responseListener);
} else {
// if one to one, make one hit search response and run rewriteResponseDocuments
GroupedActionListener<SearchResponse> combineResponseListener = getCombineResponseGroupedActionListener(
Expand Down Expand Up @@ -545,22 +570,37 @@ public void onResponse(Map<Integer, MLOutput> multipleMLOutputs) {
} else {
modelOutputValuePerDoc = modelOutputValue;
}

if (sourceAsMap.containsKey(newDocumentFieldName)) {
if (override) {
sourceAsMapWithInference.remove(newDocumentFieldName);
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
// writing to search response extension
if (newDocumentFieldName.startsWith(EXTENSION_PREFIX)) {
Map<String, Object> params = ((MLInferenceSearchResponse) response).getParams();
String paramsName = newDocumentFieldName.replaceFirst(EXTENSION_PREFIX + ".", "");

if (params != null) {
params.put(paramsName, modelOutputValuePerDoc);
((MLInferenceSearchResponse) response).setParams(params);
} else {
logger
.debug(
"{} already exists in the search response hit. Skip processing this field.",
newDocumentFieldName
);
// TODO when the response has the same field name, should it throw exception? currently,
// ingest processor quietly skip it
Map<String, Object> newParams = new HashMap<>();
newParams.put(paramsName, modelOutputValuePerDoc);
((MLInferenceSearchResponse) response).setParams(newParams);
}
} else {
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
// writing to search response hits
if (sourceAsMap.containsKey(newDocumentFieldName)) {
if (override) {
sourceAsMapWithInference.remove(newDocumentFieldName);
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
} else {
logger
.debug(
"{} already exists in the search response hit. Skip processing this field.",
newDocumentFieldName
);
// TODO when the response has the same field name, should it throw exception? currently,
// ingest processor quietly skip it
}
} else {
sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc);
}
}
}
}
Expand Down Expand Up @@ -774,6 +814,19 @@ public MLInferenceSearchResponseProcessor create(
+ ". Please adjust mappings."
);
}
// Determine whether any output mapping key starts with EXTENSION_PREFIX.
boolean writeToSearchExtension = false;

if (outputMaps != null) {
    writeToSearchExtension = outputMaps
        .stream()
        .filter(Objects::nonNull) // skip null entries in outputMaps to avoid NullPointerExceptions
        .flatMap(outputMap -> outputMap.keySet().stream())
        .anyMatch(key -> key.startsWith(EXTENSION_PREFIX));
}

// Reject the combination of an extension-prefixed output key and one_to_one mode.
// Uses the short-circuit && rather than the non-short-circuit & for boolean logic.
if (writeToSearchExtension && oneToOne) {
    throw new IllegalArgumentException("Write model response to search extension does not support when one_to_one is true.");
}

return new MLInferenceSearchResponseProcessor(
modelId,
Expand Down
Loading

0 comments on commit c0c205d

Please sign in to comment.