diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java
new file mode 100644
index 0000000000..1fb3b1471e
--- /dev/null
+++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponse.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.ml.processor;
+
+import java.io.IOException;
+import java.util.Map;
+
+import org.opensearch.action.search.SearchResponse;
+import org.opensearch.action.search.SearchResponseSections;
+import org.opensearch.action.search.ShardSearchFailure;
+import org.opensearch.core.xcontent.XContentBuilder;
+
+public class MLInferenceSearchResponse extends SearchResponse {
+    private static final String EXT_SECTION_NAME = "ext";
+
+    private Map<String, Object> params;
+
+    public MLInferenceSearchResponse(
+        Map<String, Object> params,
+        SearchResponseSections internalResponse,
+        String scrollId,
+        int totalShards,
+        int successfulShards,
+        int skippedShards,
+        long tookInMillis,
+        ShardSearchFailure[] shardFailures,
+        Clusters clusters
+    ) {
+        super(internalResponse, scrollId, totalShards, successfulShards, skippedShards, tookInMillis, shardFailures, clusters);
+        this.params = params;
+    }
+
+    public void setParams(Map<String, Object> params) {
+        this.params = params;
+    }
+
+    public Map<String, Object> getParams() {
+        return this.params;
+    }
+
+    @Override
+    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
+        builder.startObject();
+        innerToXContent(builder, params);
+
+        if (this.params != null) {
+            builder.startObject(EXT_SECTION_NAME);
+            builder.field(MLInferenceSearchResponseProcessor.TYPE, this.params);
+
+            builder.endObject();
+        }
+        builder.endObject();
+        return builder;
+    }
+}
diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java
b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java
index 2164877b9f..56e98474c7 100644
--- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java
+++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java
@@ -20,6 +20,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicBoolean;
@@ -84,6 +85,9 @@ public class MLInferenceSearchResponseProcessor extends AbstractProcessor implem
     // it can be overwritten using max_prediction_tasks when creating processor
     public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
     public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
+    // allow to write to the extension of the search response, the path to point to search extension
+    // is prefix with ext.ml_inference
+    public static final String EXTENSION_PREFIX = "ext.ml_inference";
 
     protected MLInferenceSearchResponseProcessor(
         String modelId,
@@ -158,7 +162,28 @@ public void processResponseAsync(
         // if many to one, run rewriteResponseDocuments
         if (!oneToOne) {
-            rewriteResponseDocuments(response, responseListener);
+            // use MLInferenceSearchResponseProcessor to allow writing to extension
+            // check if the search response is in the type of MLInferenceSearchResponse
+            // if not, initiate a new one MLInferenceSearchResponse
+            MLInferenceSearchResponse mlInferenceSearchResponse;
+
+            if (response instanceof MLInferenceSearchResponse) {
+                mlInferenceSearchResponse = (MLInferenceSearchResponse) response;
+            } else {
+                mlInferenceSearchResponse = new MLInferenceSearchResponse(
+                    null,
+                    response.getInternalResponse(),
+                    response.getScrollId(),
+                    response.getTotalShards(),
+                    response.getSuccessfulShards(),
+                    response.getSkippedShards(),
+                    response.getTook().millis(), // tookInMillis — was response.getSuccessfulShards(), which set took to the shard count
+                    response.getShardFailures(),
+                    response.getClusters()
+                );
+            }
+
+            
rewriteResponseDocuments(mlInferenceSearchResponse, responseListener); } else { // if one to one, make one hit search response and run rewriteResponseDocuments GroupedActionListener combineResponseListener = getCombineResponseGroupedActionListener( @@ -545,22 +570,37 @@ public void onResponse(Map multipleMLOutputs) { } else { modelOutputValuePerDoc = modelOutputValue; } - - if (sourceAsMap.containsKey(newDocumentFieldName)) { - if (override) { - sourceAsMapWithInference.remove(newDocumentFieldName); - sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + // writing to search response extension + if (newDocumentFieldName.startsWith(EXTENSION_PREFIX)) { + Map params = ((MLInferenceSearchResponse) response).getParams(); + String paramsName = newDocumentFieldName.replaceFirst(EXTENSION_PREFIX + ".", ""); + + if (params != null) { + params.put(paramsName, modelOutputValuePerDoc); + ((MLInferenceSearchResponse) response).setParams(params); } else { - logger - .debug( - "{} already exists in the search response hit. Skip processing this field.", - newDocumentFieldName - ); - // TODO when the response has the same field name, should it throw exception? currently, - // ingest processor quietly skip it + Map newParams = new HashMap<>(); + newParams.put(paramsName, modelOutputValuePerDoc); + ((MLInferenceSearchResponse) response).setParams(newParams); } } else { - sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + // writing to search response hits + if (sourceAsMap.containsKey(newDocumentFieldName)) { + if (override) { + sourceAsMapWithInference.remove(newDocumentFieldName); + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + } else { + logger + .debug( + "{} already exists in the search response hit. Skip processing this field.", + newDocumentFieldName + ); + // TODO when the response has the same field name, should it throw exception? 
currently, + // ingest processor quietly skip it + } + } else { + sourceAsMapWithInference.put(newDocumentFieldName, modelOutputValuePerDoc); + } } } } @@ -774,6 +814,19 @@ public MLInferenceSearchResponseProcessor create( + ". Please adjust mappings." ); } + boolean writeToSearchExtension = false; + + if (outputMaps != null) { + writeToSearchExtension = outputMaps + .stream() + .filter(Objects::nonNull) // To avoid potential NullPointerExceptions from null outputMaps + .flatMap(outputMap -> outputMap.keySet().stream()) + .anyMatch(key -> key.startsWith(EXTENSION_PREFIX)); + } + + if (writeToSearchExtension & oneToOne) { + throw new IllegalArgumentException("Write model response to search extension does not support when one_to_one is true."); + } return new MLInferenceSearchResponseProcessor( modelId, diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 8f04cab9d4..f462408943 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -21,6 +21,7 @@ import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.FULL_RESPONSE_PATH; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.FUNCTION_NAME; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.MODEL_INPUT; +import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.ONE_TO_ONE; import static org.opensearch.ml.processor.MLInferenceSearchResponseProcessor.TYPE; import java.util.ArrayList; @@ -33,6 +34,7 @@ import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.OpenSearchParseException; @@ -60,6 +62,7 @@ import 
org.opensearch.search.SearchHits; import org.opensearch.search.SearchModule; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.test.AbstractBuilderTestCase; @@ -85,6 +88,7 @@ public void setup() { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseException() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( @@ -113,6 +117,7 @@ public void testProcessResponseException() throws Exception { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseSuccess() throws Exception { String modelInputField = "inputs"; String originalDocumentField = "text"; @@ -172,6 +177,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { String newDocumentField = "context"; @@ -261,6 +267,7 @@ public void onFailure(Exception e) { * with many to one prediction, 5 documents in hits are calling 1 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { String documentField = "text"; @@ -358,6 +365,7 @@ public void onFailure(Exception e) { * with full response path false and no output mapping is provided * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseManyToOneWithCustomPromptFullResponsePathFalse() throws Exception { String documentField = "text"; @@ -434,6 +442,7 @@ public void onFailure(Exception e) { * with full response path true and no output mapping is provided * @throws Exception if an error occurs during 
the test */ + @Test public void testProcessResponseManyToOneWithCustomPromptFullResponsePathTrue() throws Exception { String documentField = "text"; @@ -503,12 +512,89 @@ public void onFailure(Exception e) { verify(client, times(1)).execute(any(), any(), any()); } + /** + * Tests the successful processing of a response with a single pair of input and output mappings. + * read the query text into model config + * with query extensions + * @throws Exception if an error occurs during the test + */ + @Test + public void testProcessResponseSuccessWriteToExt() throws Exception { + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "ext.ml_inference.llm_response"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. 
\\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(1)).execute(any(), any(), any()); + } + /** * Tests create processor with one_to_one is true * with no mapping provided * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithNoMappings() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new 
MLInferenceSearchResponseProcessor( @@ -589,6 +675,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithEmptyMappings() throws Exception { List> outputMap = new ArrayList<>(); List> inputMap = new ArrayList<>(); @@ -670,6 +757,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappings() throws Exception { String newDocumentField = "text_embedding"; @@ -759,6 +847,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerFail() throws Exception { String newDocumentField = "text_embedding"; @@ -820,6 +909,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerException() throws Exception { String newDocumentField = "text_embedding"; @@ -876,6 +966,7 @@ public void onFailure(Exception e) { * when there is one document and ignoreFailure, should return the original response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsCombineResponseListenerExceptionIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -932,6 +1023,7 @@ public void onFailure(Exception e) { * when there is one document and ignoreFailure, should return the original response * @throws Exception if an error occurs during the test */ + 
@Test public void testProcessResponseCreateRewriteResponseListenerExceptionIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -978,14 +1070,18 @@ public void testProcessResponseCreateRewriteResponseListenerExceptionIgnoreFailu SearchResponse mockResponse = mock(SearchResponse.class); SearchHits searchHits = response.getHits(); + + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, null, 1); + when(mockResponse.getInternalResponse()).thenReturn(internalSearchResponse); + RuntimeException mockException = new RuntimeException("Mock exception"); AtomicInteger callCount = new AtomicInteger(0); - ; + when(mockResponse.getHits()).thenAnswer(invocation -> { int count = callCount.getAndIncrement(); - if (count == 2) { + if (count == 6) { // throw exception when it reaches createRewriteResponseListener throw mockException; } else { @@ -1011,13 +1107,14 @@ public void onFailure(Exception e) { } /** - * Tests create processor with one_to_one is true + * Tests create processor with one_to_one is false * with output_maps * createRewriteResponseListener throw Exceptions * expect to run one prediction task * createRewriteResponseListener should reach on Failure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseCreateRewriteResponseListenerException() throws Exception { String newDocumentField = "text_embedding"; @@ -1066,7 +1163,10 @@ public void testProcessResponseCreateRewriteResponseListenerException() throws E SearchHits searchHits = response.getHits(); RuntimeException mockException = new RuntimeException("Mock exception"); AtomicInteger callCount = new AtomicInteger(0); - ; + + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, null, 1); + when(mockResponse.getInternalResponse()).thenReturn(internalSearchResponse); + when(mockResponse.getHits()).thenAnswer(invocation -> { int 
count = callCount.getAndIncrement(); @@ -1101,6 +1201,7 @@ public void onFailure(Exception e) { * test throwing OpenSearchStatusException * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOpenSearchStatusException() throws Exception { String newDocumentField = "text_embedding"; @@ -1184,6 +1285,7 @@ public void onFailure(Exception e) { * test throwing MLResourceNotFoundException * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseMLResourceNotFoundException() throws Exception { String newDocumentField = "text_embedding"; @@ -1269,6 +1371,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1330,6 +1433,7 @@ public void onFailure(Exception e) { * when there is one document, the combinedResponseListener calls onFailure * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsMLTaskResponseExceptionIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1392,6 +1496,7 @@ public void onFailure(Exception e) { * expect to run one prediction task and the rest 4 predictions tasks are not created * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsPredictException() throws Exception { String newDocumentField = "text_embedding"; @@ -1448,6 +1553,7 @@ public void onFailure(Exception e) { * expect to run one prediction task and the rest 4 predictions tasks are not created * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsPredictFail() throws Exception { String newDocumentField = 
"text_embedding"; @@ -1510,6 +1616,7 @@ public void onFailure(Exception e) { * then return original response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneWithOutputMappingsPredictFailIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -1569,6 +1676,7 @@ public void onFailure(Exception e) { * with one to one prediction, 5 documents in hits are calling 10 prediction tasks * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictions() throws Exception { String modelInputField = "inputs"; @@ -1699,6 +1807,7 @@ public void onFailure(Exception e) { * expect to throw exception without further processing * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictionsOneException() throws Exception { String modelInputField = "inputs"; @@ -1799,6 +1908,7 @@ public void onFailure(Exception e) { * expect to return document with second round prediction results * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreMissing() throws Exception { String modelInputField = "inputs"; @@ -1909,6 +2019,7 @@ public void onFailure(Exception e) { * expect to return document with second round prediction results * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneTwoRoundsPredictionsOneExceptionIgnoreFailure() throws Exception { String modelInputField = "inputs"; @@ -2002,6 +2113,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseNoMappingSuccess() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2079,6 +2191,7 @@ public void onFailure(Exception e) { * * @throws Exception if 
an error occurs during the test */ + @Test public void testProcessResponseEmptyMappingSuccess() throws Exception { List> inputMap = new ArrayList<>(); Map input = new HashMap<>(); @@ -2159,6 +2272,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseListOfEmbeddingsSuccess() throws Exception { /** * sample response before inference @@ -2246,6 +2360,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOverrideSameField() throws Exception { /** * sample response before inference @@ -2332,6 +2447,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOverrideSameFieldFalse() throws Exception { /** * sample response before inference @@ -2420,6 +2536,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseListOfEmbeddingsMissingOneInputIgnoreMissingSuccess() throws Exception { /** * sample response before inference @@ -2502,6 +2619,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseListOfEmbeddingsMissingOneInputException() throws Exception { /** * sample response before inference @@ -2586,6 +2704,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseTwoRoundsOfPredictionSuccess() throws Exception { String modelInputField = "inputs"; String modelOutputField = "response"; @@ -2683,6 +2802,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneModelInputMultipleModelOutputs() throws Exception { // one model input String modelInputField = "inputs"; @@ -2769,6 +2889,7 @@ 
public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponsePredictionException() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2814,6 +2935,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponsePredictionFailed() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2864,6 +2986,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponsePredictionExceptionIgnoreFailure() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2914,6 +3037,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseEmptyHit() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -2958,6 +3082,7 @@ public void onFailure(Exception e) { * * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseHitWithNoSource() throws Exception { MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( "model1", @@ -3003,6 +3128,7 @@ public void onFailure(Exception e) { * Exceptions happen when replaceHits to be one Hit Response * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneMadeOneHitResponseExceptions() throws Exception { String newDocumentField = "text_embedding"; @@ -3070,6 +3196,7 @@ public void onFailure(Exception e) { * Exceptions happen when replaceHits and ignoreFailure return original response * @throws Exception if an error occurs during the 
test */ + @Test public void testProcessResponseOneToOneMadeOneHitResponseExceptionsIgnoreFailure() throws Exception { String newDocumentField = "text_embedding"; @@ -3136,6 +3263,7 @@ public void onFailure(Exception e) { * Exceptions happen when replaceHits * @throws Exception if an error occurs during the test */ + @Test public void testProcessResponseOneToOneCombinedHitsExceptions() throws Exception { String newDocumentField = "text_embedding"; @@ -3199,6 +3327,246 @@ public void onFailure(Exception e) { responseProcessor.processResponseAsync(request, mockResponse, responseContext, listener); } + /** + * Tests the processResponseAsync method when the input is a regular SearchResponse. + * + * This test verifies that when a regular SearchResponse is passed to the method, + * it attempts to create a new MLInferenceSearchResponse object. + */ + @Test + public void testProcessResponseAsync_WithRegularSearchResponse() { + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + + SearchResponse response = getSearchResponse(5, true, originalDocumentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 
4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + assertEquals(responseAfterProcessor.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); + assertEquals(responseAfterProcessor.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); + assertEquals(responseAfterProcessor.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); + assertEquals(responseAfterProcessor.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); + assertEquals(responseAfterProcessor.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + assertEquals(responseAfterProcessor.getParams(), params); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + + /** + * Tests the processResponseAsync method when the input is already an MLInferenceSearchResponse. 
+ * + * This test verifies that when an MLInferenceSearchResponse is passed to the method, + * and the params is being passed over + */ + @Test + public void testProcessResponseAsync_WithMLInferenceSearchResponse() { + String modelInputField = "inputs"; + String originalDocumentField = "text"; + String newDocumentField = "text_embedding"; + String modelOutputField = "response"; + + SearchResponse response = getSearchResponse(5, true, originalDocumentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + MLInferenceSearchResponseProcessor responseProcessor = getMlInferenceSearchResponseProcessorSinglePairMapping( + modelOutputField, + modelInputField, + originalDocumentField, + newDocumentField, + false, + false, + false + ); + SearchRequest request = getSearchRequest(); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap(ImmutableMap.of("response", Arrays.asList(0.0, 1.0, 2.0, 3.0, 4.0))) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) 
newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + assertEquals(responseAfterProcessor.getHits().getHits()[0].getSourceAsMap().get("text_embedding"), 0.0); + assertEquals(responseAfterProcessor.getHits().getHits()[1].getSourceAsMap().get("text_embedding"), 1.0); + assertEquals(responseAfterProcessor.getHits().getHits()[2].getSourceAsMap().get("text_embedding"), 2.0); + assertEquals(responseAfterProcessor.getHits().getHits()[3].getSourceAsMap().get("text_embedding"), 3.0); + assertEquals(responseAfterProcessor.getHits().getHits()[4].getSourceAsMap().get("text_embedding"), 4.0); + assertEquals(responseAfterProcessor.getParams(), params); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + + /** + * Tests the processResponseAsync method when the input is already an MLInferenceSearchResponse. + * + * This test verifies that when an MLInferenceSearchResponse is passed to the method, + * and the params is being passed over and new params is added + */ + @Test + public void testProcessResponseAsync_WriteExtensionToMLInferenceSearchResponse() { + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "ext.ml_inference.summary"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. 
If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, + false + ); + SearchResponse response = getSearchResponse(5, true, documentField); + Map params = new HashMap<>(); + params.put("llm_response", "answer"); + MLInferenceSearchResponse mLInferenceSearchResponse = new MLInferenceSearchResponse( + params, + response.getInternalResponse(), + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getSuccessfulShards(), + response.getShardFailures(), + response.getClusters() + ); + + SearchRequest request = getSearchRequest(); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + MLInferenceSearchResponse responseAfterProcessor = (MLInferenceSearchResponse) newSearchResponse; + assertEquals(responseAfterProcessor.getHits().getHits().length, 5); + Map newParams = new HashMap<>(); + 
newParams.put("llm_response", "answer"); + newParams.put("summary", "there is 1 value"); + assertEquals(responseAfterProcessor.getParams(), newParams); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + }; + + responseProcessor.processResponseAsync(request, mLInferenceSearchResponse, responseContext, listener); + + } + private static SearchRequest getSearchRequest() { QueryBuilder incomingQuery = new TermQueryBuilder("text", "foo"); SearchSourceBuilder source = new SearchSourceBuilder().query(incomingQuery); @@ -3538,7 +3906,7 @@ public void testOutputMapsExceedInputMaps() throws Exception { output2.put("hashtag_embedding", "response"); outputMap.add(output2); Map output3 = new HashMap<>(); - output2.put("hashtvg_embedding", "response"); + output3.put("hashtvg_embedding", "response"); outputMap.add(output3); config.put(OUTPUT_MAP, outputMap); config.put(MAX_PREDICTION_TASKS, 2); @@ -3587,4 +3955,40 @@ public void testCreateOptionalFields() throws Exception { assertEquals(MLInferenceSearchResponseProcessor.getTag(), processorTag); assertEquals(MLInferenceSearchResponseProcessor.getType(), MLInferenceSearchResponseProcessor.TYPE); } + + /** + * Tests the case where output map try to write to extension and one to one inference is true + * and an exception is expected. 
+ * + * @throws Exception if an error occurs during the test + */ + public void testWriteToExtensionAndOneToOne() throws Exception { + Map config = new HashMap<>(); + config.put(MODEL_ID, "model2"); + List> inputMap = new ArrayList<>(); + Map input0 = new HashMap<>(); + input0.put("inputs", "text"); + inputMap.add(input0); + Map input1 = new HashMap<>(); + input1.put("inputs", "hashtag"); + inputMap.add(input1); + config.put(INPUT_MAP, inputMap); + List> outputMap = new ArrayList<>(); + Map output1 = new HashMap<>(); + output1.put("text_embedding", "response"); + outputMap.add(output1); + Map output2 = new HashMap<>(); + output2.put("ext.inference.hashtag_embedding", "response"); + outputMap.add(output2); + config.put(OUTPUT_MAP, outputMap); + config.put(ONE_TO_ONE, true); + String processorTag = randomAlphaOfLength(10); + + try { + factory.create(Collections.emptyMap(), processorTag, null, false, config, null); + } catch (IllegalArgumentException e) { + assertEquals(e.getMessage(), ""); + + } + } } diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java new file mode 100644 index 0000000000..a50467e261 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseTests.java @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.ml.processor; + +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.io.OutputStream; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.xcontent.ToXContent; 
+import org.opensearch.core.xcontent.XContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentGenerator; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.test.OpenSearchTestCase; + +public class MLInferenceSearchResponseTests extends OpenSearchTestCase { + + @Rule + public ExpectedException exceptionRule = ExpectedException.none(); + + /** + * Tests the toXContent method of MLInferenceSearchResponse with non-null parameters. + * This test ensures that the method correctly serializes the response when parameters are present. + * + * @throws IOException if an I/O error occurs during the test + */ + @Test + public void testToXContent() throws IOException { + Map params = new HashMap<>(); + params.put("key1", "value1"); + params.put("key2", "value2"); + + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + params, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } + + /** + * Tests the toXContent method of MLInferenceSearchResponse with null parameters. + * This test verifies that the method handles null parameters correctly during serialization. 
+ * + * @throws IOException if an I/O error occurs during the test + */ + @Test + public void testToXContentWithNullParams() throws IOException { + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + null, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + XContent xc = mock(XContent.class); + OutputStream os = mock(OutputStream.class); + XContentGenerator generator = mock(XContentGenerator.class); + when(xc.createGenerator(any(), any(), any())).thenReturn(generator); + XContentBuilder builder = new XContentBuilder(xc, os); + XContentBuilder actual = searchResponse.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertNotNull(actual); + } + + /** + * Tests the getParams method of MLInferenceSearchResponse. + * This test ensures that the method correctly returns the parameters that were set during object creation. + */ + @Test + public void testGetParams() { + Map params = new HashMap<>(); + params.put("key1", "value1"); + params.put("key2", "value2"); + + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + params, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + assertEquals(params, searchResponse.getParams()); + } + + /** + * Tests the setParams method of MLInferenceSearchResponse. + * This test verifies that the method correctly updates the parameters of the response object. 
+ */ + @Test + public void testSetParams() { + SearchResponseSections internal = new SearchResponseSections( + new SearchHits(new SearchHit[0], null, 0), + null, + null, + false, + false, + null, + 0 + ); + MLInferenceSearchResponse searchResponse = new MLInferenceSearchResponse( + null, + internal, + null, + 0, + 0, + 0, + 0, + new ShardSearchFailure[0], + MLInferenceSearchResponse.Clusters.EMPTY + ); + + Map newParams = new HashMap<>(); + newParams.put("key3", "value3"); + searchResponse.setParams(newParams); + + assertEquals(newParams, searchResponse.getParams()); + } +}