From f172279432eedf2d05584a402f0f6bdf4cc8fb6e Mon Sep 17 00:00:00 2001 From: HenryL27 Date: Thu, 16 Nov 2023 15:41:22 -0800 Subject: [PATCH] add unittests Signed-off-by: HenryL27 --- .../factory/RerankProcessorFactory.java | 11 +- .../rerank/CrossEncoderRerankProcessor.java | 26 +- .../processor/rerank/RerankType.java | 10 +- .../rerank/RescoringRerankProcessor.java | 7 +- .../query/ext/RerankSearchExtBuilder.java | 9 +- .../common/BaseNeuralSearchIT.java | 17 + .../ml/MLCommonsClientAccessorTests.java | 82 +++++ .../factory/RerankProcessorFactoryTests.java | 144 +++++++++ .../CrossEncoderRerankProcessorTests.java | 304 ++++++++++++++++++ .../ext/RerankSearchExtBuilderTests.java | 102 ++++++ 10 files changed, 691 insertions(+), 21 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index ed1d56b4b..03e7c8154 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -27,6 +27,8 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; +import com.google.common.annotations.VisibleForTesting; + @AllArgsConstructor public class RerankProcessorFactory implements Processor.Factory { @@ -49,14 +51,21 @@ public SearchResponseProcessor create( @SuppressWarnings("unchecked") Map rerankerConfig = (Map) config.get(type.getLabel()); String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); + if (modelId == null) { + throw new IllegalArgumentException(CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified"); + } String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD); + if (rerankContext == null) { + throw new IllegalArgumentException(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified"); + } return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor); default: throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel()); } } - private RerankType findRerankType(final Map config) throws IllegalArgumentException { + @VisibleForTesting + RerankType findRerankType(final Map config) throws IllegalArgumentException { for (String key : config.keySet()) { try { RerankType attempt = RerankType.from(key); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index 61193ea36..3c60e6570 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -17,13 +17,15 @@ */ package org.opensearch.neuralsearch.processor.rerank; -import java.io.PipedInputStream; -import java.io.PipedOutputStream; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import lombok.extern.log4j.Log4j2; + import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.common.xcontent.XContentType; @@ -38,6 +40,7 @@ import org.opensearch.search.SearchExtBuilder; import org.opensearch.search.SearchHit; +@Log4j2 public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { public static final String MODEL_ID_FIELD = "model_id"; @@ -76,26 +79,29 @@ public void generateScoringContext( Map scoringContext = new HashMap<>(); if (params.containsKey(QUERY_TEXT_FIELD)) { if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { - throw new IllegalArgumentException("Cannot specify both \"query_text\" and \"query_text_path\""); + throw new IllegalArgumentException( + "Cannot specify both \"" + QUERY_TEXT_FIELD + "\" and \"" + QUERY_TEXT_PATH_FIELD + "\"" + ); } scoringContext.put(QUERY_TEXT_FIELD, (String) params.get(QUERY_TEXT_FIELD)); } else if (params.containsKey(QUERY_TEXT_PATH_FIELD)) { String path = (String) params.get(QUERY_TEXT_PATH_FIELD); // Convert query to a map with io/xcontent shenanigans - PipedOutputStream os = new PipedOutputStream(); - XContentBuilder builder = XContentType.CBOR.contentBuilder(os); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + XContentBuilder builder = XContentType.CBOR.contentBuilder(baos); searchRequest.source().toXContent(builder, ToXContent.EMPTY_PARAMS); - PipedInputStream is = new PipedInputStream(os); - XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, is); + builder.close(); + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + XContentParser parser = XContentType.CBOR.xContent().createParser(NamedXContentRegistry.EMPTY, null, bais); Map map = parser.map(); // Get the text at the path Object queryText = ObjectPath.eval(path, map); if (!(queryText instanceof String)) { - throw new IllegalArgumentException("query_text_path must point to a string field"); + throw new IllegalArgumentException(QUERY_TEXT_PATH_FIELD + " must point to a string field"); } scoringContext.put(QUERY_TEXT_FIELD, (String) queryText); } else { - throw new IllegalArgumentException("Must specify either \"query_text\" or \"query_text_path\""); + throw new IllegalArgumentException("Must specify either \"" + QUERY_TEXT_FIELD + "\" or \"" + QUERY_TEXT_PATH_FIELD + "\""); } listener.onResponse(scoringContext); } catch (Exception e) { @@ -115,7 +121,7 @@ public void rescoreSearchResponse(SearchResponse response, Map s private String contextFromSearchHit(final SearchHit hit) { if (hit.getFields().containsKey(this.rerankContext)) { return (String) hit.field(this.rerankContext).getValue(); - } else if (hit.getSourceAsMap().containsKey(this.rerankContext)) { + } else if (hit.hasSource() && hit.getSourceAsMap().containsKey(this.rerankContext)) { return (String) hit.getSourceAsMap().get(this.rerankContext); } else { return null; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index 6bfb9feed..e474c4b11 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -17,6 +17,9 @@ */ package org.opensearch.neuralsearch.processor.rerank; +import java.util.Arrays; +import java.util.Optional; + import lombok.Getter; /** @@ -39,9 +42,10 @@ private RerankType(String label) { * @return RerankType represented by the label */ public static RerankType from(String label) { - try { - return RerankType.valueOf(label); - } catch (Exception e) { + Optional typeMaybe = Arrays.stream(RerankType.values()).filter(rrt -> rrt.label.equals(label)).findFirst(); + if (typeMaybe.isPresent()) { + return typeMaybe.get(); + } else { throw new IllegalArgumentException("Wrong rerank type name: " + label); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index 907c26c5d..c1479b2c9 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -84,7 +84,9 @@ public void rerank(SearchResponse searchResponse, Map scoringCon rescoreSearchResponse(searchResponse, scoringContext, ActionListener.wrap(scores -> { // Assign new scores SearchHit[] hits = searchResponse.getHits().getHits(); - assert (hits.length == scores.size()); + if (hits.length != scores.size()) { + throw new Exception("scores and hits are not the same length"); + } for (int i = 0; i < hits.length; i++) { hits[i].score(scores.get(i)); } @@ -92,7 +94,8 @@ public void rerank(SearchResponse searchResponse, Map scoringCon Collections.sort(Arrays.asList(hits), new Comparator() { @Override public int compare(SearchHit hit1, SearchHit hit2) { - return Float.compare(hit1.getScore(), hit2.getScore()); + // backwards to sort DESC + return Float.compare(hit2.getScore(), hit1.getScore()); } }); // Reconstruct the search response, replacing the max score diff --git a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java index ad3756aa8..56623f768 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilder.java @@ -55,10 +55,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(PARAM_FIELD_NAME, this.params); - builder.endObject(); - return builder; + return builder.field(PARAM_FIELD_NAME, this.params); } @Override @@ -92,7 +89,9 @@ public static RerankSearchExtBuilder fromExtBuilderList(List b * @throws IOException if problems parsing */ public static RerankSearchExtBuilder parse(XContentParser parser) throws IOException { - return new RerankSearchExtBuilder(parser.map()); + @SuppressWarnings("unchecked") + RerankSearchExtBuilder ans = new RerankSearchExtBuilder((Map) parser.map().get(PARAM_FIELD_NAME)); + return ans; } } diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 33cdff9a0..91bd1dd04 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -315,6 +315,23 @@ protected void createSearchRequestProcessor(String modelId, String pipelineName) assertEquals("true", node.get("acknowledged").toString()); } + protected void createSearchPipelineViaConfig(String modelId, String pipelineName, String configPath) throws Exception { + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + toHttpEntity( + String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource(configPath).toURI())), + modelId + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } + /** * Get the number of documents in a particular index * diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 68d5f79eb..32d382e4f 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -326,6 +326,71 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } + public void testInferenceSimilarity_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createManyModelTensorOutputs(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferencesSimilarity_whenExceptionFromMLClient_ThenFail() { + final RuntimeException exception = new RuntimeException(); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(exception); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceSimilarity_whenNodeNotConnectedException_ThenTryThreeTimes() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSimilarity( + TestCommonConstants.MODEL_ID, + "is it sunny", + List.of("it is sunny today", "roses are red"), + singleSentenceResultListener + ); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -353,4 +418,21 @@ private ModelTensorOutput createModelTensorOutput(final Map map) tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createManyModelTensorOutputs(final Float[] output) { + final List tensorsList = new ArrayList<>(); + for (Float score : output) { + List tensorList = new ArrayList<>(); + String name = "logits"; + Number[] data = new Number[] { score }; + long[] shape = new long[] { 1 }; + MLResultDataType dataType = MLResultDataType.FLOAT32; + MLResultDataType mlResultDataType = MLResultDataType.valueOf(dataType.name()); + ModelTensor tensor = ModelTensor.builder().name(name).data(data).shape(shape).dataType(mlResultDataType).build(); + tensorList.add(tensor); + tensorsList.add(new ModelTensors(tensorList)); + } + ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + return modelTensorOutput; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java new file mode 100644 index 000000000..b663b3c93 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -0,0 +1,144 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; + +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.junit.Before; +import org.mockito.Mock; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankType; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.test.OpenSearchTestCase; + +@Log4j2 +public class RerankProcessorFactoryTests extends OpenSearchTestCase { + + final String TAG = "default-tag"; + final String DESC = "processor description"; + + RerankProcessorFactory factory; + + @Mock + MLCommonsClientAccessor clientAccessor; + + @Mock + PipelineContext pipelineContext; + + @Before + public void setup() { + pipelineContext = mock(PipelineContext.class); + clientAccessor = mock(MLCommonsClientAccessor.class); + factory = new RerankProcessorFactory(clientAccessor); + } + + public void testRerankProcessorFactory_EmptyConfig_ThenFail() { + Map config = Map.of(); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_NonExistentType_ThenFail() { + Map config = Map.of("jpeo rvgh we iorgn", Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id")); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_HappyPath() { + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation" + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testRerankProcessorFactory_CrossEncoder_MessyConfig_ThenHappy() { + Map config = Map.of( + "poafn aorr;anv", + Map.of(";oawhls", "aowirhg "), + RerankType.CROSS_ENCODER.getLabel(), + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation", + "pqiohg rpowierhg", + "pw;oith4pt3ih go" + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof CrossEncoderRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testRerankProcessorFactory_CrossEncoder_EmptySubConfig_ThenFail() { + Map config = Map.of(RerankType.CROSS_ENCODER.getLabel(), Map.of()); + assertThrows( + CrossEncoderRerankProcessor.MODEL_ID_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_NoContextField_ThenFail() { + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of(CrossEncoderRerankProcessor.MODEL_ID_FIELD, "model-id") + ); + assertThrows( + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + + public void testRerankProcessorFactory_CrossEncoder_NoModelId_ThenFail() { + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, "text_representation") + ); + assertThrows( + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD + " must be specified", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java new file mode 100644 index 000000000..5bbb5c38c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessorTests.java @@ -0,0 +1,304 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.processor.rerank; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyList; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponse.Clusters; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.xcontent.json.JsonXContent; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.Processor.PipelineContext; +import org.opensearch.test.OpenSearchTestCase; + +@Log4j2 +public class CrossEncoderRerankProcessorTests extends OpenSearchTestCase { + + @Mock + SearchRequest request; + + SearchResponse response; + + @Mock + MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + PipelineContext pipelineContext; + + RerankProcessorFactory factory; + + CrossEncoderRerankProcessor processor; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + factory = new RerankProcessorFactory(mlCommonsClientAccessor); + Map config = Map.of( + RerankType.CROSS_ENCODER.getLabel(), + Map.of( + CrossEncoderRerankProcessor.MODEL_ID_FIELD, + "model-id", + CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD, + "text_representation" + ) + ); + processor = (CrossEncoderRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for reranking with a cross encoder", + false, + config, + pipelineContext + ); + } + + private void setupParams(Map params) { + SearchSourceBuilder ssb = new SearchSourceBuilder(); + NeuralQueryBuilder nqb = new NeuralQueryBuilder(); + nqb.fieldName("embedding").k(3).modelId("embedding_id").queryText("Question about dolphins"); + ssb.query(nqb); + List exts = List.of(new RerankSearchExtBuilder(params)); + ssb.ext(exts); + doReturn(ssb).when(request).source(); + } + + private void setupSimilarityRescoring() { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + List scores = List.of(1f, 2f, 3f); + listener.onResponse(scores); + return null; + }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + } + + private void setupSearchResults() throws IOException { + XContentBuilder sourceContent = JsonXContent.contentBuilder() + .startObject() + .field("text_representation", "source passage") + .endObject(); + SearchHit sourceHit = new SearchHit(0, "0", Map.of(), Map.of()); + sourceHit.sourceRef(BytesReference.bytes(sourceContent)); + sourceHit.score(1.5f); + + DocumentField field = new DocumentField("text_representation", List.of("field passage")); + SearchHit fieldHit = new SearchHit(1, "1", Map.of("text_representation", field), Map.of()); + fieldHit.score(1.7f); + + SearchHit nullHit = new SearchHit(2, "2", Map.of(), Map.of()); + nullHit.score(0f); + + SearchHit[] hitArray = new SearchHit[] { fieldHit, sourceHit, nullHit }; + + SearchHits searchHits = new SearchHits(hitArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new Clusters(1, 1, 0), null); + } + + public void testScoringContext_QueryText_ThenSucceed() { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().containsKey(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD).equals("query text")); + } + + public void testScoringContext_QueryTextPath_ThenSucceed() { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding.query_text")); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(Map.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().containsKey(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD)); + assert (argCaptor.getValue().get(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD).equals("Question about dolphins")); + } + + public void testScoringContext_QueryTextAndPath_ThenFail() { + setupParams( + Map.of( + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, + "query.neural.embedding.query_text", + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, + "query text" + ) + ); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + "Cannot specify both \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD + + "\" and \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + + "\"" + )); + } + + public void testScoringContext_NoQueryInfo_ThenFail() { + setupParams(Map.of()); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals( + "Must specify either \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_FIELD + + "\" or \"" + + CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + + "\"" + )); + } + + public void testScoringContext_QueryTextPath_BadPointer_ThenFail() { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD, "query.neural.embedding")); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.generateScoringContext(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue() instanceof IllegalArgumentException); + assert (argCaptor.getValue() + .getMessage() + .equals(CrossEncoderRerankProcessor.QUERY_TEXT_PATH_FIELD + " must point to a string field")); + } + + public void testRescoreSearchResponse_HappyPath() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + processor.rescoreSearchResponse(response, scoringContext, listener); + @SuppressWarnings("unchecked") + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + assert (argCaptor.getValue().size() == 3); + assert (argCaptor.getValue().get(0) == 1f); + assert (argCaptor.getValue().get(1) == 2f); + assert (argCaptor.getValue().get(2) == 3f); + } + + public void testRerank_HappyPath() throws IOException { + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + processor.rerank(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse rsp = argCaptor.getValue(); + assert (rsp.getHits().getAt(0).docId() == 2); + assert (rsp.getHits().getAt(0).getScore() == 3f); + assert (rsp.getHits().getAt(1).docId() == 0); + assert (rsp.getHits().getAt(1).getScore() == 2f); + assert (rsp.getHits().getAt(2).docId() == 1); + assert (rsp.getHits().getAt(2).getScore() == 1f); + } + + public void testRerank_ScoresAndHitsHaveDiffLengths() throws IOException { + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(3); + List scores = List.of(1f, 2f); + listener.onResponse(scores); + return null; + }).when(mlCommonsClientAccessor).inferenceSimilarity(anyString(), anyString(), anyList(), any()); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + Map scoringContext = Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text"); + processor.rerank(response, scoringContext, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argCaptor.capture()); + assert (argCaptor.getValue().getMessage().equals("scores and hits are not the same length")); + } + + public void testBasics() throws IOException { + assert (processor.getTag().equals("rerank processor")); + assert (processor.getDescription().equals("processor for reranking with a cross encoder")); + assert (!processor.isIgnoreFailure()); + assertThrows( + "Use asyncProcessResponse unless you can guarantee to not deadlock yourself", + UnsupportedOperationException.class, + () -> processor.processResponse(request, response) + ); + } + + public void testProcessResponseAsync() throws IOException { + setupParams(Map.of(CrossEncoderRerankProcessor.QUERY_TEXT_FIELD, "query text")); + setupSimilarityRescoring(); + setupSearchResults(); + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + processor.processResponseAsync(request, response, listener); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse rsp = argCaptor.getValue(); + assert (rsp.getHits().getAt(0).docId() == 2); + assert (rsp.getHits().getAt(0).getScore() == 3f); + assert (rsp.getHits().getAt(1).docId() == 0); + assert (rsp.getHits().getAt(1).getScore() == 2f); + assert (rsp.getHits().getAt(2).docId() == 1); + assert (rsp.getHits().getAt(2).getScore() == 1f); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java new file mode 100644 index 000000000..f6a22b675 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/ext/RerankSearchExtBuilderTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2023 Aryn + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.opensearch.neuralsearch.query.ext; + +import static org.mockito.Mockito.mock; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.junit.Before; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.BytesStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.search.SearchExtBuilder; +import org.opensearch.test.OpenSearchTestCase; + +@Log4j2 +public class RerankSearchExtBuilderTests extends OpenSearchTestCase { + + Map params; + + @Before + public void setup() { + params = Map.of("query_text", "question about the meaning of life, the universe, and everything"); + } + + public void testStreaming() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + BytesStreamOutput outbytes = new BytesStreamOutput(); + StreamOutput osso = new OutputStreamStreamOutput(outbytes); + b1.writeTo(osso); + StreamInput in = new BytesStreamInput(BytesReference.toBytes(outbytes.bytes())); + RerankSearchExtBuilder b2 = new RerankSearchExtBuilder(in); + assert (b2.getParams().equals(params)); + assert (b1.equals(b2)); + } + + public void testToXContent() throws IOException { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + XContentBuilder builder = XContentType.JSON.contentBuilder(); + builder.startObject(); + b1.toXContent(builder, ToXContentObject.EMPTY_PARAMS); + builder.endObject(); + String extString = builder.toString(); + log.info(extString); + XContentParser parser = this.createParser(XContentType.JSON.xContent(), extString); + RerankSearchExtBuilder b2 = RerankSearchExtBuilder.parse(parser); + assert (b2.getParams().equals(params)); + assert (b1.equals(b2)); + } + + public void testPullFromListOfExtBuilders() { + RerankSearchExtBuilder builder = new RerankSearchExtBuilder(params); + SearchExtBuilder otherBuilder = mock(SearchExtBuilder.class); + assert (!builder.equals(otherBuilder)); + List builders1 = List.of(otherBuilder, builder); + List builders2 = List.of(otherBuilder); + List builders3 = List.of(); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders1).equals(builder)); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders2) == null); + assert (RerankSearchExtBuilder.fromExtBuilderList(builders3) == null); + } + + public void testHash() { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + RerankSearchExtBuilder b2 = new RerankSearchExtBuilder(params); + RerankSearchExtBuilder b3 = new RerankSearchExtBuilder(Map.of()); + assert (b1.hashCode() == b2.hashCode()); + assert (b1.hashCode() != b3.hashCode()); + assert (!b1.equals(b3)); + } + + public void testWriteableName() { + RerankSearchExtBuilder b1 = new RerankSearchExtBuilder(params); + assert (b1.getWriteableName().equals(RerankSearchExtBuilder.PARAM_FIELD_NAME)); + } +}