From a46beb19f4a862878a4638773d2e5921b63238a9 Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Tue, 22 Oct 2024 11:36:47 -0700 Subject: [PATCH] ByFieldRerank Processor (ReRankProcessor enhancement) (#932) * Implements initial By Field re rank Signed-off-by: Brian Flores --- DEVELOPER_GUIDE.md | 2 +- ...ch-neural-search.release-notes-2.18.0.0.md | 3 + .../factory/RerankProcessorFactory.java | 74 +- .../rerank/ByFieldRerankProcessor.java | 189 +++ .../processor/rerank/RerankProcessor.java | 21 + .../processor/rerank/RerankType.java | 3 +- .../processor/util/ProcessorUtils.java | 183 +++ .../factory/RerankProcessorFactoryTests.java | 72 ++ .../rerank/ByFieldRerankProcessorIT.java | 201 +++ .../rerank/ByFieldRerankProcessorTests.java | 1075 +++++++++++++++++ .../util/ProcessorUtilsTests.java | 216 ++++ .../ReRankByFieldPipelineConfiguration.json | 14 + 12 files changed, 2033 insertions(+), 20 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java create mode 100644 src/test/resources/processor/ReRankByFieldPipelineConfiguration.json diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 37d78f9b1..5ee672da5 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -313,7 +313,7 @@ merged to main, the workflow will create a backport PR to the `2.x` branch. ## Building On Lucene Version Updates There may be a Lucene version update that can affect your workflow causing errors like -`java.lang.NoClassDefFoundError: org/apache/lucene/codecs/lucene99/Lucene99Codec` or +`java.lang.NoClassDefFoundError: org/apache/lucene/codecs/lucene99/Lucene99Codec` or `Provider org.opensearch.knn.index.codec.KNN910Codec.KNN910Codec could not be instantiated`. In this case we can observe there may be an issue with a dependency with [K-NN](https://github.com/opensearch-project/k-NN). This results in having issues with not being able to do `./gradlew run` or `./gradlew build`. diff --git a/release-notes/opensearch-neural-search.release-notes-2.18.0.0.md b/release-notes/opensearch-neural-search.release-notes-2.18.0.0.md index 298bd704b..2b6ef6a66 100644 --- a/release-notes/opensearch-neural-search.release-notes-2.18.0.0.md +++ b/release-notes/opensearch-neural-search.release-notes-2.18.0.0.md @@ -3,6 +3,9 @@ Compatible with OpenSearch 2.18.0 +### Features +- Introduces ByFieldRerankProcessor for second level reranking on documents ([#932](https://github.com/opensearch-project/neural-search/pull/932)) + ### Enhancements - Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907)) - Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917)) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 9b9715df5..e88892359 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -4,16 +4,12 @@ */ package org.opensearch.neuralsearch.processor.factory; -import java.util.ArrayList; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Set; -import java.util.StringJoiner; - +import com.google.common.collect.Sets; +import lombok.AllArgsConstructor; import org.opensearch.cluster.service.ClusterService; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankType; import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; @@ -22,9 +18,17 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; -import com.google.common.collect.Sets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.StringJoiner; -import lombok.AllArgsConstructor; +import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_KEEP_PREVIOUS_SCORE; +import static org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor.DEFAULT_REMOVE_TARGET_FIELD; +import static org.opensearch.neuralsearch.processor.rerank.RerankProcessor.processorRequiresContext; /** * Factory for rerank processors. Must: @@ -51,15 +55,17 @@ public SearchResponseProcessor create( ) { RerankType type = findRerankType(config); boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type); - List contextFetchers = ContextFetcherFactory.createFetchers( - config, - includeQueryContextFetcher, - tag, - clusterService - ); + + // Currently the createFetchers method requires that you provide a context map, this branch makes sure we can ignore this on + // processors that don't need the context map + List contextFetchers = processorRequiresContext(type) + ? ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag, clusterService) + : Collections.emptyList(); + + Map rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); + switch (type) { case ML_OPENSEARCH: - Map rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel()); String modelId = ConfigurationUtils.readStringProperty( RERANK_PROCESSOR_TYPE, tag, @@ -67,6 +73,37 @@ public SearchResponseProcessor create( MLOpenSearchRerankProcessor.MODEL_ID_FIELD ); return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor); + case BY_FIELD: + String targetField = ConfigurationUtils.readStringProperty( + RERANK_PROCESSOR_TYPE, + tag, + rerankerConfig, + ByFieldRerankProcessor.TARGET_FIELD + ); + boolean removeTargetField = ConfigurationUtils.readBooleanProperty( + RERANK_PROCESSOR_TYPE, + tag, + rerankerConfig, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + DEFAULT_REMOVE_TARGET_FIELD + ); + boolean keepPreviousScore = ConfigurationUtils.readBooleanProperty( + RERANK_PROCESSOR_TYPE, + tag, + rerankerConfig, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + DEFAULT_KEEP_PREVIOUS_SCORE + ); + + return new ByFieldRerankProcessor( + description, + tag, + ignoreFailure, + targetField, + removeTargetField, + keepPreviousScore, + contextFetchers + ); default: throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel())); } @@ -100,6 +137,7 @@ private static class ContextFetcherFactory { /** * Map rerank types to whether they should include the query context source fetcher + * * @param type the constructing RerankType * @return does this RerankType depend on the QueryContextSourceFetcher? */ @@ -109,8 +147,8 @@ public static boolean shouldIncludeQueryContextFetcher(RerankType type) { /** * Create necessary queryContextFetchers for this processor - * @param config processor config object. Look for "context" field to find fetchers - * @param includeQueryContextFetcher should I include the queryContextFetcher? + * @param config Processor config object. Look for "context" field to find fetchers + * @param includeQueryContextFetcher Should I include the queryContextFetcher? * @return list of contextFetchers for the processor to use */ public static List createFetchers( diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java new file mode 100644 index 000000000..28bf7866f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessor.java @@ -0,0 +1,189 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher; +import org.opensearch.neuralsearch.processor.util.ProcessorUtils.SearchHitValidator; +import org.opensearch.search.SearchHit; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; + +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getScoreFromSourceMap; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.removeTargetFieldFromSource; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria; + +/** + * A reranking processor that reorders search results based on the content of a specified field. + *

+ * The ByFieldRerankProcessor allows for reordering of search results by considering the content of a + * designated target field within each document. This processor will update the _score field with what has been provided + * by {@code target_field}. When {@code keep_previous_score} is enabled a new field is appended called previous_score which was the score prior to reranking. + *

+ * Key features: + *

    + *
  • Reranks search results based on a specified target field
  • + *
  • Optionally removes the target field from the final search results
  • + *
  • Supports nested field structures using dot notation
  • + *
+ *

+ * The processor uses the following configuration parameters: + *

    + *
  • {@code target_field}: The field to be used for reranking (required)
  • + *
  • {@code remove_target_field}: Whether to remove the target field from the final results (optional, default: false)
  • + *
  • {@code keep_previous_score}: Whether to append the previous score in a field called previous_score (optional, default: false)
  • + *
+ *

+ * Usage example: + *

+ * {
+ *   "rerank": {
+ *     "by_field": {
+ *       "target_field": "document.relevance_score",
+ *       "remove_target_field": true,
+ *       "keep_previous_score": false
+ *     }
+ *   }
+ * }
+ * 
+ *

+ * This processor is useful in scenarios where additional, document-specific + * information stored in a field can be used to improve the relevance of search results + * beyond the initial scoring. + */ +@Log4j2 +public class ByFieldRerankProcessor extends RescoringRerankProcessor { + + public static final String TARGET_FIELD = "target_field"; + public static final String REMOVE_TARGET_FIELD = "remove_target_field"; + public static final String KEEP_PREVIOUS_SCORE = "keep_previous_score"; + + public static final boolean DEFAULT_REMOVE_TARGET_FIELD = false; + public static final boolean DEFAULT_KEEP_PREVIOUS_SCORE = false; + + protected final String targetField; + protected final boolean removeTargetField; + protected final boolean keepPreviousScore; + + /** + * Constructor to pass values to the RerankProcessor constructor. + * + * @param description The description of the processor + * @param tag The processor's identifier + * @param ignoreFailure If true, OpenSearch ignores any failure of this processor and + * continues to run the remaining processors in the search pipeline. + * @param targetField The field you want to replace your _score with + * @param removeTargetField A flag to let you delete the target_field for better visualization (i.e. removes a duplicate value) + * @param keepPreviousScore A flag to let you decide to stash your previous _score in a field called previous_score (i.e. for debugging purposes) + * @param contextSourceFetchers Context from some source and puts it in a map for a reranking processor to use (Unused in ByFieldRerankProcessor) + */ + public ByFieldRerankProcessor( + final String description, + final String tag, + final boolean ignoreFailure, + final String targetField, + final boolean removeTargetField, + final boolean keepPreviousScore, + final List contextSourceFetchers + ) { + super(RerankType.BY_FIELD, description, tag, ignoreFailure, contextSourceFetchers); + this.targetField = targetField; + this.removeTargetField = removeTargetField; + this.keepPreviousScore = keepPreviousScore; + } + + @Override + public void rescoreSearchResponse( + final SearchResponse response, + final Map rerankingContext, + final ActionListener> listener + ) { + SearchHit[] searchHits = response.getHits().getHits(); + + SearchHitValidator searchHitValidator = this::byFieldSearchHitValidator; + + if (!validateRerankCriteria(searchHits, searchHitValidator, listener)) { + return; + } + + List scores = new ArrayList<>(searchHits.length); + + for (SearchHit hit : searchHits) { + Map sourceAsMap = hit.getSourceAsMap(); + + float score = getScoreFromSourceMap(sourceAsMap, targetField); + scores.add(score); + + if (keepPreviousScore) { + sourceAsMap.put("previous_score", hit.getScore()); + } + + if (removeTargetField) { + removeTargetFieldFromSource(sourceAsMap, targetField); + } + + try { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + BytesReference sourceMapAsBytes = BytesReference.bytes(builder.map(sourceAsMap)); + hit.sourceRef(sourceMapAsBytes); + } catch (IOException e) { + log.error(e.getMessage()); + listener.onFailure(new RuntimeException(e)); + return; + } + } + + listener.onResponse(scores); + } + + /** + * Implements the behavior of the SearchHit validator {@code SearchHitValidator} + * It checks all the following + *

    + *
  • Checks the search hit has a source mapping
  • + *
  • Checks that the mapping exists in the source mapping using the target_field
  • + *
  • Checks that the mapping has a numerical score for it to rerank
  • + *
+ * @param hit A search hit to validate + */ + public void byFieldSearchHitValidator(final SearchHit hit) { + if (!hit.hasSource()) { + log.error(String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId())); + throw new IllegalArgumentException( + String.format(Locale.ROOT, "There is no source field to be able to perform rerank on hit [%d]", hit.docId()) + ); + } + + Map sourceMap = hit.getSourceAsMap(); + if (!mappingExistsInSource(sourceMap, targetField)) { + log.error(String.format(Locale.ROOT, "The field to rerank [%s] is not found at hit [%d]", targetField, hit.docId())); + + throw new IllegalArgumentException(String.format(Locale.ROOT, "The field to rerank by is not found at hit [%d]", hit.docId())); + } + + Optional val = getValueFromSource(sourceMap, targetField); + + if (!(val.get() instanceof Number)) { + log.error(String.format(Locale.ROOT, "The field mapping to rerank [%s: %s] is not Numerical", targetField, val.orElse(null))); + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "The field mapping to rerank by [%s] is not Numerical", val.orElse(null)) + ); + } + + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 93a2c8416..42d7d56ee 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -35,6 +35,7 @@ public abstract class RerankProcessor implements SearchResponseProcessor { @Getter private final boolean ignoreFailure; protected List contextSourceFetchers; + static final protected List processorsWithNoContext = List.of(RerankType.BY_FIELD); /** * Generate the information that this processor needs in order to rerank. @@ -48,6 +49,11 @@ public void generateRerankingContext( final SearchResponse searchResponse, final ActionListener> listener ) { + // Processors that don't require context, result on a listener infinitely waiting for a response without this check + if (!processorRequiresContext(subType)) { + listener.onResponse(Map.of()); + } + Map overallContext = new ConcurrentHashMap<>(); AtomicInteger successfulContexts = new AtomicInteger(contextSourceFetchers.size()); for (ContextSourceFetcher csf : contextSourceFetchers) { @@ -102,4 +108,19 @@ public void processResponseAsync( responseListener.onFailure(e); } } + + /** + * There are scenarios where ranking occurs without needing context. Currently, these are the processors don't require + * the context mapping + *
    + *
  • + * ByFieldRerankProcessor - Uses the search response to get value to rescore by + *
  • + *
+ * @param subType The kind of rerank processor + * @return Whether a rerank subtype needs context to perform the rescore search response action. + */ + public static boolean processorRequiresContext(RerankType subType) { + return !processorsWithNoContext.contains(subType); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java index 2063242dd..60a70e766 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankType.java @@ -16,7 +16,8 @@ */ public enum RerankType { - ML_OPENSEARCH("ml_opensearch"); + ML_OPENSEARCH("ml_opensearch"), + BY_FIELD("by_field"); @Getter private final String label; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java new file mode 100644 index 000000000..a6a377843 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/util/ProcessorUtils.java @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.util; + +import org.opensearch.common.collect.Tuple; +import org.opensearch.core.action.ActionListener; +import org.opensearch.search.SearchHit; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Stack; + +/** + * Utility class for evaluating SearchResponse data. This is useful when you want + * to see that the searchResponse is in correct form or if the data you want to extract/edit + * from the SearchResponse + */ +public class ProcessorUtils { + + /** + * Represents a function used to validate a SearchHit based on the provided implementation + * When it is incorrect an Exception is expected to be thrown. Otherwise, no return value is given + * to the caller indicating that the SearchHit is valid. + *
+ *

This is a functional interface + * whose functional method is {@link #validate(SearchHit)}}. + */ + @FunctionalInterface + public interface SearchHitValidator { + + /** + * Performs the validation for the SearchHit and takes in metadata of what happened when the error occurred. + *


+ * When the SearchHit is not in correct form, an exception is thrown + * @param hit The specific SearchHit were the invalidation occurred + * @throws IllegalArgumentException if the validation for the hit fails + */ + void validate(final SearchHit hit) throws IllegalArgumentException; + } + + /** + * This is the preflight check for Reranking. It checks that + * every Search Hit in the array from a given search Response has all the following + * for each SearchHit follows the correct form as specified by the validator. + * When just one of the conditions fail (as specified by the validator) the exception will be thrown to the listener. + * @param searchHits from the SearchResponse + * @param listener returns an error to the listener in case on of the conditions fail + * @return The status indicating that the SearchHits are in correct form to perform the Rerank + */ + public static boolean validateRerankCriteria( + final SearchHit[] searchHits, + final SearchHitValidator validator, + final ActionListener> listener + ) { + for (SearchHit hit : searchHits) { + try { + validator.validate(hit); + } catch (IllegalArgumentException e) { + listener.onFailure(e); + return false; + } + } + return true; + } + + /** + * Used to get the numeric mapping from the sourcemap using the target_field + *
+ * This method assumes that the path to the mapping exists (and is numerical) as checked by {@link #validateRerankCriteria(SearchHit[], SearchHitValidator, ActionListener)} + * As such no error checking is done in the methods implementing this functionality + * @param sourceAsMap the map of maps that contains the targetField + * @param targetField the path to take to get the score to replace by + * @return The numerical score found using the target_field + */ + public static float getScoreFromSourceMap(final Map sourceAsMap, final String targetField) { + Object val = getValueFromSource(sourceAsMap, targetField).get(); + return ((Number) val).floatValue(); + } + + /** + * This method performs the deletion of the targetField and emptyMaps in 3 phases + *
    + *
  1. Collect the maps and the respective keys (the key is used to get the inner map) in a stack. It will be used + * to delete empty maps and the target field
  2. + *
  3. Delete the top most entry, this is guaranteed even when the source mapping is non nested. This is the + * mapping containing the targetField
  4. + *
  5. Iteratively delete the rest of the maps that have (possibly been) emptied as the result of deleting the targetField
  6. + *
+ *
+ * This method assumes that the path to the mapping exists as checked by {@link #validateRerankCriteria(SearchHit[], SearchHitValidator, ActionListener)} + * As such no error checking is done in the methods implementing this functionality + *
+ * You can think of this algorithm as a recursive one the base case is deleting the targetField. The recursive case + * is going to the next map along with the respective key. Along the way if it finds a map is empty it will delete it + * @param sourceAsMap the map of maps that contains the targetField + * @param targetField The path to take to remove the targetField + */ + public static void removeTargetFieldFromSource(final Map sourceAsMap, final String targetField) { + Stack, String>> parentMapChildrenKeyTupleStack = new Stack<>(); + String[] keys = targetField.split("\\."); + + Map currentMap = sourceAsMap; + String lastKey = keys[keys.length - 1]; + + // Collect the parent maps with respective children to use them inside out + for (String key : keys) { + parentMapChildrenKeyTupleStack.add(new Tuple<>(currentMap, key)); + if (key.equals(lastKey)) { + break; + } + currentMap = (Map) currentMap.get(key); + } + + // Remove the last key this is guaranteed + Tuple, String> currentParentMapWithChild = parentMapChildrenKeyTupleStack.pop(); + Map parentMap = currentParentMapWithChild.v1(); + String key = currentParentMapWithChild.v2(); + parentMap.remove(key); + + // Delete the empty maps inside out using the stack to mock a recursive solution + while (!parentMapChildrenKeyTupleStack.isEmpty()) { + currentParentMapWithChild = parentMapChildrenKeyTupleStack.pop(); + parentMap = currentParentMapWithChild.v1(); + key = currentParentMapWithChild.v2(); + + @SuppressWarnings("unchecked") + Map innerMap = (Map) parentMap.get(key); + + if (innerMap != null && innerMap.isEmpty()) { + parentMap.remove(key); + } + } + } + + /** + * Returns the mapping associated with a path to a value, otherwise + * returns an empty optional when it encounters a dead end. + *
+ * When the targetField has the form (key[.key]) it will iterate through + * the map to see if a mapping exists. + * + * @param sourceAsMap The Source map (a map of maps) to iterate through + * @param targetField The path to take to get the desired mapping + * @return A possible result within an optional + */ + public static Optional getValueFromSource(final Map sourceAsMap, final String targetField) { + String[] keys = targetField.split("\\."); + Optional currentValue = Optional.of(sourceAsMap); + + for (String key : keys) { + currentValue = currentValue.flatMap(value -> { + if (!(value instanceof Map)) { + return Optional.empty(); + } + Map currentMap = (Map) value; + return Optional.ofNullable(currentMap.get(key)); + }); + + if (currentValue.isEmpty()) { + return Optional.empty(); + } + } + + return currentValue; + } + + /** + * Determines whether there exists a value that has a mapping according to the pathToValue. This is particularly + * useful when the source map is a map of maps and when the pathToValue is of the form key[.key]. + *
+ * To Exist in a map it must have a mapping that is not null or the key-value pair does not exist + * @param sourceAsMap the source field converted to a map + * @param pathToValue A string of the form key[.key] indicating what keys to apply to the sourceMap + * @return Whether the mapping using the pathToValue exists + */ + public static boolean mappingExistsInSource(final Map sourceAsMap, final String pathToValue) { + return getValueFromSource(sourceAsMap, pathToValue).isPresent(); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java index c464f2826..af5b374f3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactoryTests.java @@ -21,6 +21,7 @@ import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.ByFieldRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.processor.rerank.RerankType; @@ -72,8 +73,16 @@ public void testRerankProcessorFactory_whenNonExistentType_thenFail() { IllegalArgumentException.class, () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); + + Map config2 = new HashMap<>(Map.of("key", Map.of(ByFieldRerankProcessor.TARGET_FIELD, "path.to.target_field"))); + assertThrows( + "no rerank type found", + IllegalArgumentException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config2, pipelineContext) + ); } + // Start of MLOpenSearchRerankProcessor Tests public void testCrossEncoder_whenCorrectParams_thenSuccessful() { Map config = new HashMap<>( Map.of( @@ -218,5 +227,68 @@ public void testCrossEncoder_whenTooManyDocFields_thenFail() { () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) ); } + // End of MLOpenSearchRerankProcessor Tests + + // Start of ByFieldRerankProcessor Tests + public void testByFieldCreation_whenTargetFieldSpecifiedWithDefaultRemoveTargetFieldAndDefaultPreviousScore_thenSuccessful() { + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, "path.to.target_field"))) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof ByFieldRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testByFieldCreation_whenTargetFieldSpecifiedWithManualRemoveTargetFieldAndPreviousKeptScore_thenSuccessful() { + boolean removeTargetField = true; + boolean keepPreviousScore = true; + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + "path.to.target_field", + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + assert (processor instanceof RerankProcessor); + assert (processor instanceof ByFieldRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testByFieldCreation_WithContext_thenSucceed() { + // You can pass context but, it won't ever be used by ByFieldRerank + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, "path.to.target_field")), + RerankProcessorFactory.CONTEXT_CONFIG_FIELD, + new HashMap<>(Map.of(DocumentContextSourceFetcher.NAME, new ArrayList<>(List.of("text_representation")))) + ) + ); + SearchResponseProcessor processor = factory.create(Map.of(), TAG, DESC, false, config, pipelineContext); + + assert (processor instanceof RerankProcessor); + assert (processor instanceof ByFieldRerankProcessor); + assert (processor.getType().equals(RerankProcessor.TYPE)); + } + + public void testByField_whenEmptySubConfig_thenFail() { + Map config = new HashMap<>(Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>())); + assertThrows( + String.format(Locale.ROOT, "[%s] required property is missing", ByFieldRerankProcessor.TARGET_FIELD), + OpenSearchParseException.class, + () -> factory.create(Map.of(), TAG, DESC, false, config, pipelineContext) + ); + } + // End of ByFieldRerankProcessor Tests } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorIT.java new file mode 100644 index 000000000..93cf182c8 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorIT.java @@ -0,0 +1,201 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import com.google.common.collect.ImmutableList; +import lombok.SneakyThrows; +import lombok.extern.log4j.Log4j2; +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.ParseException; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.search.SearchHit; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Map; + +import static org.opensearch.ml.repackage.com.google.common.net.HttpHeaders.USER_AGENT; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_USER_AGENT; + +@Log4j2 +public class ByFieldRerankProcessorIT extends BaseNeuralSearchIT { + + private final static String PIPELINE_NAME = "rerank-byfield-pipeline"; + private final static String INDEX_NAME = "diary_index"; + private final static String INDEX_CONFIG = """ + { + "mappings" : { + "properties" : { + "diary" : { "type" : "text" }, + "similarity_score" : { "type" : "float" } + } + } + } + """.replace("\n", ""); + private final static List> sampleCrossEncoderData = List.of( + Map.entry("how are you", -11.055182f), + Map.entry("today is sunny", 8.969885f), + Map.entry("today is july fifth", -5.736348f), + Map.entry("it is winter", -10.045217f) + ); + private final static String SAMPLE_CROSS_ENCODER_DATA_FORMAT = """ + { + "diary" : "%s", + "similarity_score" : %s + } + """.replace("\n", ""); + + private final static String PATH_TO_BY_FIELD_RERANK_PIPELINE_TEMPLATE = "processor/ReRankByFieldPipelineConfiguration.json"; + private final static String POST = "POST"; + private final static String TARGET_FIELD = "similarity_score"; + private final static String REMOVE_TARGET_FIELD = "true"; + private final static String KEEP_PREVIOUS_FIELD = "true"; + private SearchResponse searchResponse; + + /** + * This test creates a simple index with as many documents that + * {@code sampleCrossEncoderData} has. It will then onboard a search pipeline + * with the byFieldRerankProcessor. When it applies the search pipeline it will + * capture the response string into a SearchResponse processor, which is tested like + * the Unit Tests. + *
+ * In this scenario the target_field is found within the first level and the + * target_field will be removed. + * + */ + @SneakyThrows + public void testByFieldRerankProcessor() throws IOException { + try { + createAndPopulateIndex(); + createPipeline(); + applyPipeLine(); + testSearchResponse(); + } finally { + wipeOfTestResources(INDEX_NAME, null, null, PIPELINE_NAME); + } + } + + private void createAndPopulateIndex() throws Exception { + createIndexWithConfiguration(INDEX_NAME, INDEX_CONFIG, PIPELINE_NAME); + for (int i = 0; i < sampleCrossEncoderData.size(); i++) { + String diary = sampleCrossEncoderData.get(i).getKey(); + String similarity = sampleCrossEncoderData.get(i).getValue() + ""; + + Response responseI = makeRequest( + client(), + POST, + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(String.format(LOCALE, SAMPLE_CROSS_ENCODER_DATA_FORMAT, diary, similarity)), + ImmutableList.of(new BasicHeader(USER_AGENT, DEFAULT_USER_AGENT)) + ); + + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(responseI.getEntity()), + false + ); + + assertEquals("The index has not been `created` instead was " + map.get("result"), "created", map.get("result")); + } + } + + private void createPipeline() throws URISyntaxException, IOException, ParseException { + String pipelineConfiguration = String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource(PATH_TO_BY_FIELD_RERANK_PIPELINE_TEMPLATE).toURI())), + TARGET_FIELD, + REMOVE_TARGET_FIELD, + KEEP_PREVIOUS_FIELD + ).replace("\"true\"", "true").replace("\"false\"", "false"); + + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_search/pipeline/" + PIPELINE_NAME, + null, + toHttpEntity(pipelineConfiguration), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map node = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(pipelineCreateResponse.getEntity()), + false + ); + assertEquals("Could not create the pipeline with node:" + node, "true", node.get("acknowledged").toString()); + } + + private void applyPipeLine() throws IOException, ParseException { + Request request = new Request(POST, "/" + INDEX_NAME + "/_search"); + request.addParameter("search_pipeline", PIPELINE_NAME); + // Filter out index metaData and only get document data. This gives search hits a score of 1 because of match all + request.setJsonEntity(""" + { + "query": { + "match_all": {} + } + } + """.replace("\n", "")); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + String responseBody = EntityUtils.toString(response.getEntity()); + this.searchResponse = stringToSearchResponse(responseBody); + } + + private void testSearchResponse() { + List> sortedDescendingSampleData = sampleCrossEncoderData.stream() + .sorted(Map.Entry.comparingByKey().reversed()) + .toList(); + + SearchHit[] searchHits = this.searchResponse.getHits().getHits(); + assertEquals("The sample data size should match the search response hits", sampleCrossEncoderData.size(), searchHits.length); + + for (int i = 0; i < searchHits.length; i++) { + float currentSimilarityScore = sortedDescendingSampleData.get(i).getValue(); + String currentDiary = sortedDescendingSampleData.get(i).getKey(); + SearchHit hit = this.searchResponse.getHits().getAt(i); + + assertEquals( + "The new score at hit[" + i + "] should match the current sampleScore", + currentSimilarityScore, + hit.getScore(), + 0.01 + ); + + Map sourceMap = hit.getSourceAsMap(); + assertEquals("The source map at hit[" + i + "] should be 2 keys `previous_score` and `diary`", 2, sourceMap.size()); + + float previousScore = (((Number) sourceMap.get("previous_score")).floatValue()); + String diary = (String) sourceMap.get("diary"); + + assertEquals("The `previous_score` should be 1.0f", 1.0f, previousScore, 0.01); + assertEquals("The `diary` fields should match based on the score", currentDiary, diary); + } + } + + // This assumes that the response is in the shape of a SearchResponse Object + private SearchResponse stringToSearchResponse(String response) throws IOException { + XContentParser parser = XContentType.JSON.xContent() + .createParser(NamedXContentRegistry.EMPTY, LoggingDeprecationHandler.INSTANCE, response); + + return SearchResponse.fromXContent(parser); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java new file mode 100644 index 000000000..a2555663d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/rerank/ByFieldRerankProcessorTests.java @@ -0,0 +1,1075 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.rerank; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.document.DocumentField; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.AbstractMap; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ByFieldRerankProcessorTests extends OpenSearchTestCase { + private SearchRequest request; + + private SearchResponse response; + + @Mock + private Processor.PipelineContext pipelineContext; + + @Mock + private PipelineProcessingContext ppctx; + + @Mock + private ClusterService clusterService; + + private RerankProcessorFactory factory; + + private ByFieldRerankProcessor processor; + + private final List> sampleIndexMLScorePairs = List.of( + Map.entry(1, 12.0f), + Map.entry(2, 5.2f), + Map.entry(3, 18.0f), + Map.entry(4, 1.0f) + ); + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + doReturn(Settings.EMPTY).when(clusterService).getSettings(); + factory = new RerankProcessorFactory(null, clusterService); + } + + public void testBasics() throws IOException { + setUpValidSearchResultsWithNonNestedTargetValueWithDenseSourceMapping(); + String targetField = "ml_score"; + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField))) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field", + false, + config, + pipelineContext + ); + + assert (processor.getTag().equals("rerank processor")); + assert (processor.getDescription().equals("processor for 2nd level reranking based on provided field")); + assert (!processor.isIgnoreFailure()); + assertThrows( + "Use asyncProcessResponse unless you can guarantee to not deadlock yourself", + UnsupportedOperationException.class, + () -> processor.processResponse(request, response) + ); + } + + /** + * This test checks that the ByField successfully extracts the values using the targetField, this is + * the responsibility of extending the RescoreRerankProcessor. + * In this scenario it checks that the targetField is within the first level of the _source mapping. + *
+ * The expected behavior is to check that the sample ML Scores are returned from the rescoreSearchResponse. + * The target field is ml_score + */ + public void testRescoreSearchResponse_returnsScoresSuccessfully_WhenResponseHasTargetValueFirstLevelOfSourceMapping() + throws IOException { + String targetField = "ml_score"; + setUpValidSearchResultsWithNonNestedTargetValueWithDenseSourceMapping(); + + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField))) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field", + false, + config, + pipelineContext + ); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.rescoreSearchResponse(response, Map.of(), listener); + + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + + assert (argCaptor.getValue().size() == sampleIndexMLScorePairs.size()); + for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) { + float mlScore = sampleIndexMLScorePairs.get(i).getValue(); + assert (argCaptor.getValue().get(i) == mlScore); + } + } + + /** + * This test checks that the ByField successfully extracts the values using the targetField (where the + * target value is within a nested map), this is the responsibility of extending the RescoreRerankProcessor. + * In this scenario it checks that the targetField is within a nested map. + *
+ * The expected behavior is to check that the sample ML Scores are returned from the rescoreSearchResponse. + * the targetField is ml.info.score + */ + public void testRescoreSearchResponse_returnsScoresSuccessfully_WhenResponseHasTargetValueInNestedMapping() throws IOException { + String targetField = "ml.info.score"; + setUpValidSearchResultsWithNestedTargetValue(); + + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField))) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field", + false, + config, + pipelineContext + ); + + @SuppressWarnings("unchecked") + ActionListener> listener = mock(ActionListener.class); + processor.rescoreSearchResponse(response, Map.of(), listener); + + ArgumentCaptor> argCaptor = ArgumentCaptor.forClass(List.class); + verify(listener, times(1)).onResponse(argCaptor.capture()); + + assertEquals(sampleIndexMLScorePairs.size(), argCaptor.getValue().size()); + for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) { + float mlScore = sampleIndexMLScorePairs.get(i).getValue(); + assertEquals(mlScore, argCaptor.getValue().get(i), 0.01); + } + } + + /** + * In this scenario the reRanking is being tested i.e. making sure that the search response has + * updated _score fields. This also tests that they are returned in sorted order as + * specified by sortedScoresDescending + */ + public void testReRank_SortsDescendingWithNewScores_WhenResponseHasNestedField() throws IOException { + String targetField = "ml.info.score"; + setUpValidSearchResultsWithNestedTargetValue(); + List> sortedScoresDescending = sampleIndexMLScorePairs.stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .toList(); + + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField))) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + assertEquals(sortedScoresDescending.getFirst().getValue(), searchResponse.getHits().getMaxScore(), 0.0001); + + for (int i = 0; i < sortedScoresDescending.size(); i++) { + int docId = sortedScoresDescending.get(i).getKey(); + float ml_score = sortedScoresDescending.get(i).getValue(); + assertEquals(docId, searchResponse.getHits().getAt(i).docId()); + assertEquals(ml_score, searchResponse.getHits().getAt(i).getScore(), 0.001); + } + } + + /** + * This scenario adds the remove_target_field to be able to test that _source mapping + * has been modified. It also asserts that the previous_score has been aggregated by keep_previous_score + *

+ * In this scenario the object will start off like this + *

+     * {
+     *    "my_field" : "%s",
+     *    "ml": {
+     *         "model" : "myModel",
+     *         "info"  : {
+     *          "score": %s
+     *         }
+     *    }
+     *  }
+     * 
+ * and then be transformed into + *
+     * {
+     *     "my_field" : "%s",
+     *     "ml": {
+     *         "model" : "myModel"
+     *      },
+     *      "previous_score" : float
+     * }
+     * 
+ * The reason for this was to delete any empty maps as the result of deleting score. + * This test also checks that previous score was added as a result of keep_previous_score being true + */ + public void testReRank_deletesEmptyMapsAndKeepsPreviousScore_WhenResponseHasNestedField() throws IOException { + String targetField = "ml.info.score"; + boolean removeTargetField = true; + boolean keepPreviousScore = true; + + setUpValidSearchResultsWithNestedTargetValue(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + targetField, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + + for (int i = 0; i < searchResponse.getHits().getHits().length; i++) { + SearchHit searchHit = searchResponse.getHits().getAt(i); + Map sourceMap = searchHit.getSourceAsMap(); + + assertTrue("The source mapping now has `previous_score` entry", sourceMap.containsKey("previous_score")); + assertEquals("The first level of the map is the containing `my_field`, `ml`, and `previous_score`", 3, sourceMap.size()); + + @SuppressWarnings("unchecked") + Map innerMLMap = (Map) sourceMap.get("ml"); + + assertEquals("The ml map now only has 1 mapping `model` instead of 2", 1, innerMLMap.size()); + assertTrue("The ml map has `model` as a mapping", innerMLMap.containsKey("model")); + assertFalse("The ml map no longer has the score `info` mapping ", innerMLMap.containsKey("info")); + + } + } + + /** + * This scenario tests the rerank functionality when the response has a nested field. + * It adds the remove_target_field to verify that the _source mapping + * has been modified. It also asserts that empty maps are deleted and no previous score is retained. + *

+ * In this scenario the object will start off like this: + *

+     * {
+     *    "my_field" : "%s",
+     *    "ml": {
+     *         "model" : "myModel",
+     *         "info"  : {
+     *          "score": %s
+     *         }
+     *    }
+     * }
+     * 
+ * and then be transformed into: + *
+     * {
+     *     "my_field" : "%s",
+     *     "ml": {
+     *         "model" : "myModel"
+     *     }
+     * }
+     * 
+ * The reason for this transformation is to delete any empty maps resulting from removing the target_field. + * This test also verifies that the nested structure is properly handled and the target field is removed. + */ + public void testReRank_deletesEmptyMapsAndHasNoPreviousScore_WhenResponseHasNestedField() throws IOException { + String targetField = "ml.info.score"; + boolean removeTargetField = true; + boolean keepPreviousScore = false; + + setUpValidSearchResultsWithNestedTargetValue(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + targetField, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + + for (int i = 0; i < searchResponse.getHits().getHits().length; i++) { + SearchHit searchHit = searchResponse.getHits().getAt(i); + Map sourceMap = searchHit.getSourceAsMap(); + + assertTrue( + "The source mapping does ot have `previous_score` entry because " + + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE + + " is " + + keepPreviousScore, + !sourceMap.containsKey("previous_score") + ); + assertEquals("The first level of the map is the containing `my_field` and `ml`", 2, sourceMap.size()); + + @SuppressWarnings("unchecked") + Map innerMLMap = (Map) sourceMap.get("ml"); + + assertEquals("The ml map now only has 1 mapping `model` instead of 2", 1, innerMLMap.size()); + assertTrue("The ml map has `model` as a mapping", innerMLMap.containsKey("model")); + assertFalse("The ml map no longer has the score `info` mapping ", innerMLMap.containsKey("info")); + + } + } + + /** + * This scenario adds the remove_target_field to be able to test that _source mapping + * has been modified. It also enables keep_previous_score to test that previous_score is appended. + *

+ * In this scenario the object will start off like this + *

+     * {
+     *  "my_field" : "%s",
+     *  "ml_score" : %s,
+     *   "info"    : {
+     *          "model" : "myModel"
+     *    }
+     * }
+     * 
+ * and then be transformed into + *
+     * {
+     *  "my_field" : "%s",
+     *   "info"    : {
+     *          "model" : "myModel"
+     *    },
+     *    "previous_score" : float
+     * }
+     * 
+ */ + public void testReRank_deletesEmptyMapsAndKeepsPreviousScore_WhenResponseHasNonNestedField() throws IOException { + String targetField = "ml_score"; + boolean removeTargetField = true; + boolean keepPreviousScore = true; + setUpValidSearchResultsWithNonNestedTargetValueWithDenseSourceMapping(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + targetField, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a NON-nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + + for (int i = 0; i < searchResponse.getHits().getHits().length; i++) { + SearchHit searchHit = searchResponse.getHits().getAt(i); + Map sourceMap = searchHit.getSourceAsMap(); + + assertTrue("The source mapping now has `previous_score` entry", sourceMap.containsKey("previous_score")); + assertEquals("The first level of the map is the containing `my_field`, `info`, and `previous_score`", 3, sourceMap.size()); + + @SuppressWarnings("unchecked") + Map innerInfoMap = (Map) sourceMap.get("info"); + + assertEquals("The info map has 1 mapping", 1, innerInfoMap.size()); + assertTrue("The info map has the model as the only mapping", innerInfoMap.containsKey("model")); + + } + } + + /** + * This scenario tests the rerank functionality when the response has a non-nested field. + * It adds the remove_target_field to verify that the _source mapping + * has been modified. It also disables keep_previous_score to test that previous_score is not appended. + *

+ * In this scenario the object will start off like this: + *

+    * {
+    *  "my_field" : "%s",
+    *  "ml_score" : %s,
+    *   "info"    : {
+    *          "model" : "myModel"
+    *    }
+    * }
+    * 
+ * and then be transformed into: + *
+    * {
+    *  "my_field" : "%s",
+    *   "info"    : {
+    *          "model" : "myModel"
+    *    }
+    * }
+    * 
+ * This test verifies that the target field is removed, empty maps are deleted, and no previous score is retained + * when dealing with a non-nested field structure. + */ + public void testReRank_deletesEmptyMapsAndHasNoPreviousScore_WhenResponseHasNonNestedField() throws IOException { + String targetField = "ml_score"; + boolean removeTargetField = true; + boolean keepPreviousScore = false; + setUpValidSearchResultsWithNonNestedTargetValueWithDenseSourceMapping(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + targetField, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a NON-nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + + for (int i = 0; i < searchResponse.getHits().getHits().length; i++) { + SearchHit searchHit = searchResponse.getHits().getAt(i); + Map sourceMap = searchHit.getSourceAsMap(); + + assertTrue( + "The source mapping does ot have `previous_score` entry because " + + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE + + " is " + + keepPreviousScore, + !sourceMap.containsKey("previous_score") + ); + assertEquals("The first level of the map is the containing `my_field` and `info`", 2, sourceMap.size()); + + @SuppressWarnings("unchecked") + Map innerInfoMap = (Map) sourceMap.get("info"); + + assertEquals("The info map has 1 mapping", 1, innerInfoMap.size()); + assertTrue("The info map has the model as the only mapping", innerInfoMap.containsKey("model")); + + } + } + + /** + * This scenario makes sure turning on keep_previous_score, updates the contents of the nested + * mapping by checking that a new field previous_score was added along with the correct values in which they came from + * and that the targetField has been deleted (along with other empty maps as a result of deleting this entry). + *
+ * In order to check that previous_score is valid it will check that the docIds and scores match from + * the resulting rerank and original sample data + */ + public void testReRank_storesPreviousScoresInSourceMap_WhenResponseHasNestedField() throws IOException { + String targetField = "ml.info.score"; + boolean removeTargetField = true; + boolean keepPreviousScore = true; + setUpValidSearchResultsWithNestedTargetValue(); + + List> previousDocIdScorePair = IntStream.range( + 0, + response.getHits().getHits().length + ) + .boxed() + .map(i -> new AbstractMap.SimpleImmutableEntry<>(response.getHits().getAt(i).docId(), response.getHits().getAt(i).getScore()) { + }) + .toList(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + targetField, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + for (int i = 0; i < searchResponse.getHits().getHits().length; i++) { + float currentPreviousScore = ((Number) searchResponse.getHits().getAt(i).getSourceAsMap().get("previous_score")).floatValue(); + int currentDocId = searchResponse.getHits().getAt(i).docId(); + + // to access the corresponding document id it does so by counting at 0 + float trackedPreviousScore = previousDocIdScorePair.get(currentDocId - 1).getValue(); + int trackedDocId = previousDocIdScorePair.get(currentDocId - 1).getKey(); + + assertEquals("The document Ids need to match to compare previous scores", trackedDocId, currentDocId); + assertEquals( + "The scores for the search response previoiusly need to match to the score in the source map", + trackedPreviousScore, + currentPreviousScore, + 0.01 + ); + + } + } + + /** + * This scenario makes sure turning on keep_previous_score, updates the contents of the NON-nested + * mapping by checking that a new field previous_score was added along with the correct values in which they came from + * and that the targetField has been deleted (along with other empty maps as a result of deleting this entry). + *
+ * In order to check that previous_score is valid it will check that the docIds and scores match from + * the resulting rerank and original sample data + */ + public void testReRank_storesPreviousScoresInSourceMap_WhenResponseHasNonNestedField() throws IOException { + String targetField = "ml_score"; + boolean removeTargetField = true; + boolean keepPreviousScore = true; + setUpValidSearchResultsWithNonNestedTargetValueWithDenseSourceMapping(); + + List> previousDocIdScorePair = IntStream.range( + 0, + response.getHits().getHits().length + ) + .boxed() + .map(i -> new AbstractMap.SimpleImmutableEntry<>(response.getHits().getAt(i).docId(), response.getHits().getAt(i).getScore()) { + }) + .toList(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of( + ByFieldRerankProcessor.TARGET_FIELD, + targetField, + ByFieldRerankProcessor.REMOVE_TARGET_FIELD, + removeTargetField, + ByFieldRerankProcessor.KEEP_PREVIOUS_SCORE, + keepPreviousScore + ) + ) + ) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a Non-nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + for (int i = 0; i < searchResponse.getHits().getHits().length; i++) { + float currentPreviousScore = ((Number) searchResponse.getHits().getAt(i).getSourceAsMap().get("previous_score")).floatValue(); + int currentDocId = searchResponse.getHits().getAt(i).docId(); + + // to access the corresponding document id it does so by counting at 0 + float trackedPreviousScore = previousDocIdScorePair.get(currentDocId - 1).getValue(); + int trackedDocId = previousDocIdScorePair.get(currentDocId - 1).getKey(); + + assertEquals("The document Ids need to match to compare previous scores", trackedDocId, currentDocId); + assertEquals( + "The scores for the search response previously need to match to the score in the source map", + trackedPreviousScore, + currentPreviousScore, + 0.01 + ); + + } + } + + public void testRerank_keepsTargetFieldAndHasNoPreviousScore_WhenByFieldHasDefaultValues() throws IOException { + String targetField = "ml.info.score"; + setUpValidSearchResultsWithNestedTargetValue(); + List> sortedScoresDescending = sampleIndexMLScorePairs.stream() + .sorted(Map.Entry.comparingByValue().reversed()) + .toList(); + + Map config = new HashMap<>( + Map.of(RerankType.BY_FIELD.getLabel(), new HashMap<>(Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField))) + ); + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field, This will check a nested field", + false, + config, + pipelineContext + ); + ActionListener listener = mock(ActionListener.class); + processor.rerank(response, Map.of(), listener); + + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(SearchResponse.class); + + verify(listener, times(1)).onResponse(argCaptor.capture()); + SearchResponse searchResponse = argCaptor.getValue(); + + assertEquals(sampleIndexMLScorePairs.size(), searchResponse.getHits().getHits().length); + assertEquals(sortedScoresDescending.getFirst().getValue(), searchResponse.getHits().getMaxScore(), 0.0001); + + for (int i = 0; i < sortedScoresDescending.size(); i++) { + int docId = sortedScoresDescending.get(i).getKey(); + float ml_score = sortedScoresDescending.get(i).getValue(); + assertEquals(docId, searchResponse.getHits().getAt(i).docId()); + assertEquals(ml_score, searchResponse.getHits().getAt(i).getScore(), 0.001); + + // Test that the path to targetField is valid + Map currentMap = searchResponse.getHits().getAt(i).getSourceAsMap(); + String[] keys = targetField.split("\\."); + String lastKey = keys[keys.length - 1]; + for (int keyIndex = 0; keyIndex < keys.length - 1; keyIndex++) { + String key = keys[keyIndex]; + assertTrue("The key:" + key + "does not exist in" + currentMap, currentMap.containsKey(key)); + currentMap = (Map) currentMap.get(key); + } + assertTrue("The key:" + lastKey + "does not exist in" + currentMap, currentMap.containsKey(lastKey)); + + } + } + + /** + * Creates a searchResponse where the value to reRank by is Nested. + * The location where the target is within a map of size 1 meaning after + * Using ByFieldReRank the expected behavior is to delete the info mapping + * as it is only has one mapping i.e. the duplicate value. + *
+ * The targetField for this scenario is ml.info.score + */ + private void setUpValidSearchResultsWithNestedTargetValue() throws IOException { + SearchHit[] hits = new SearchHit[sampleIndexMLScorePairs.size()]; + + String templateString = """ + { + "my_field" : "%s", + "ml": { + "model" : "myModel", + "info" : { + "score": %s + } + } + } + """.replace("\n", ""); + + for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) { + int docId = sampleIndexMLScorePairs.get(i).getKey(); + String mlScore = sampleIndexMLScorePairs.get(i).getValue() + ""; + + String sourceMap = templateString.formatted(i, mlScore); + + hits[i] = new SearchHit(docId, docId + "", Collections.emptyMap(), Collections.emptyMap()); + hits[i].sourceRef(new BytesArray(sourceMap)); + hits[i].score(1); + } + + TotalHits totalHits = new TotalHits(sampleIndexMLScorePairs.size(), TotalHits.Relation.EQUAL_TO); + + SearchHits searchHits = new SearchHits(hits, totalHits, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new SearchResponse.Clusters(1, 1, 0), null); + } + + /** + * Creates a searchResponse where the value to reRank is not Nested. + * The location where the target is within the first level of the _source mapping. + * There will be other fields as well (this is a dense map), the expected behavior is to leave the _source mapping + * without the targetField and leave the other fields intact. + *
+ * The targetField for this scenario is ml_score + */ + private void setUpValidSearchResultsWithNonNestedTargetValueWithDenseSourceMapping() throws IOException { + SearchHit[] hits = new SearchHit[sampleIndexMLScorePairs.size()]; + + String templateString = """ + { + "my_field" : "%s", + "ml_score" : %s, + "info" : { + "model" : "myModel" + } + } + """.replace("\n", ""); + + for (int i = 0; i < sampleIndexMLScorePairs.size(); i++) { + int docId = sampleIndexMLScorePairs.get(i).getKey(); + String mlScore = sampleIndexMLScorePairs.get(i).getValue() + ""; + + String sourceMap = templateString.formatted(i, mlScore); + + hits[i] = new SearchHit(docId, docId + "", Collections.emptyMap(), Collections.emptyMap()); + hits[i].sourceRef(new BytesArray(sourceMap)); + hits[i].score(1); + } + + TotalHits totalHits = new TotalHits(sampleIndexMLScorePairs.size(), TotalHits.Relation.EQUAL_TO); + + SearchHits searchHits = new SearchHits(hits, totalHits, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + response = new SearchResponse(internal, null, 1, 1, 0, 1, new ShardSearchFailure[0], new SearchResponse.Clusters(1, 1, 0), null); + } + + /** + * This scenario checks the byField rerank is able to check when a search hit has no source mapping. + * It is always required to have a source mapping if you want to use this processor. + */ + public void testRerank_throwsExceptionOnNoSource_WhenSearchResponseHasNoSourceMapping() { + String targetField = "similarity_score"; + boolean removeTargetField = true; + setUpInvalidSearchResultsWithNonSourceMapping(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField, ByFieldRerankProcessor.REMOVE_TARGET_FIELD, removeTargetField) + ) + ) + ); + + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field. ", + false, + config, + pipelineContext + ); + + ActionListener listener = mock(ActionListener.class); + + processor.rerank(response, config, listener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argumentCaptor.capture()); + + assertEquals("There is no source field to be able to perform rerank on hit [" + 1 + "]", argumentCaptor.getValue().getMessage()); + assert (argumentCaptor.getValue() instanceof IllegalArgumentException); + } + + /** + * The scenario checks that the search response has a source mapping for each search hit and verifies that the target field exists. + * In this case the test will see that the target field has no entry inside the source mapping. + */ + public void testRerank_throwsExceptionOnMappingNotExistingInSource_WhenSearchResponseHasAMissingMapping() { + String targetField = "similarity_score"; + boolean removeTargetField = true; + setUpInvalidSearchResultsWithMissingTargetFieldMapping(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField, ByFieldRerankProcessor.REMOVE_TARGET_FIELD, removeTargetField) + ) + ) + ); + + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field. ", + false, + config, + pipelineContext + ); + + ActionListener listener = mock(ActionListener.class); + + processor.rerank(response, config, listener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argumentCaptor.capture()); + + assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage()); + assert (argumentCaptor.getValue() instanceof IllegalArgumentException); + } + + /** + * The scenario checks that the search response has source mapping within each search hit + * and the entry for the target field exists. However, the value for the target value target field is null the + * expected behavior is to return a message that it is not found, similar to the test case where there's no + * entry mapping for this target field + */ + public void testRerank_throwsExceptionOnHavingEmptyMapping_WhenTargetFieldHasNullMapping() { + String targetField = "similarity_score"; + boolean removeTargetField = true; + setUpInvalidSearchResultsWithTargetFieldHavingNullMapping(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField, ByFieldRerankProcessor.REMOVE_TARGET_FIELD, removeTargetField) + ) + ) + ); + + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field. ", + false, + config, + pipelineContext + ); + + ActionListener listener = mock(ActionListener.class); + + processor.rerank(response, config, listener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argumentCaptor.capture()); + + assertEquals("The field to rerank by is not found at hit [" + 1 + "]", argumentCaptor.getValue().getMessage()); + assert (argumentCaptor.getValue() instanceof IllegalArgumentException); + } + + /** + * the scenario checks that the search response has source mapping within each search it and the entry for the target field exist. + * however, the value for the target field is non-numeric the expected behaviors is to throw an exception that the value is not numeric. + */ + public void testRerank_throwsExceptionOnHavingNonNumericValue_WhenTargetFieldHasNonNumericMapping() { + String targetField = "similarity_score"; + boolean removeTargetField = true; + setUpInvalidSearchResultsWithTargetFieldHavingNonNumericMapping(); + + Map config = new HashMap<>( + Map.of( + RerankType.BY_FIELD.getLabel(), + new HashMap<>( + Map.of(ByFieldRerankProcessor.TARGET_FIELD, targetField, ByFieldRerankProcessor.REMOVE_TARGET_FIELD, removeTargetField) + ) + ) + ); + + processor = (ByFieldRerankProcessor) factory.create( + Map.of(), + "rerank processor", + "processor for 2nd level reranking based on provided field. ", + false, + config, + pipelineContext + ); + + ActionListener listener = mock(ActionListener.class); + + processor.rerank(response, config, listener); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argumentCaptor.capture()); + + assertEquals("The field mapping to rerank by [hello world] is not Numerical", argumentCaptor.getValue().getMessage()); + assert (argumentCaptor.getValue() instanceof IllegalArgumentException); + + } + + /** + * This creates a search response with two hits, the first hit being in the correct form. + * While, the second search hit has a non-numeric target field mapping. + */ + private void setUpInvalidSearchResultsWithTargetFieldHavingNonNumericMapping() { + SearchHit[] hits = new SearchHit[2]; + hits[0] = new SearchHit(0, "1", Collections.emptyMap(), Collections.emptyMap()); + hits[0].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":777}")); + hits[0].score(1.0F); + + Map dummyMap = new HashMap<>(); + dummyMap.put("test", new DocumentField("test", Collections.singletonList("test-field-mapping"))); + hits[1] = new SearchHit(1, "2", dummyMap, Collections.emptyMap()); + hits[1].sourceRef(new BytesArray("{\"diary\" : \"how do you do\",\"similarity_score\":\"hello world\"}")); + hits[1].score(1.0F); + + SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + this.response = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + /** + * This creates a search response with two hits, the first hit being in the correct form. + * While, the second search hit has a null target field mapping. + */ + private void setUpInvalidSearchResultsWithTargetFieldHavingNullMapping() { + SearchHit[] hits = new SearchHit[2]; + hits[0] = new SearchHit(0, "1", Collections.emptyMap(), Collections.emptyMap()); + hits[0].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":-11.055182}")); + hits[0].score(1.0F); + + Map dummyMap = new HashMap<>(); + dummyMap.put("test", new DocumentField("test", Collections.singletonList("test-field-mapping"))); + hits[1] = new SearchHit(1, "2", dummyMap, Collections.emptyMap()); + hits[1].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":null}")); + hits[1].score(1.0F); + + SearchHits searchHits = new SearchHits(hits, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 1); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + this.response = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + /** + * This creates a search response with two hits, the first hit being in the correct form. + * While, the second search hit has having a missing entry that is needed to perform reranking + */ + private void setUpInvalidSearchResultsWithMissingTargetFieldMapping() { + SearchHit[] hits = new SearchHit[2]; + hits[0] = new SearchHit(0, "1", Collections.emptyMap(), Collections.emptyMap()); + hits[0].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":-11.055182}")); + hits[0].score(1.0F); + + hits[1] = new SearchHit(1, "2", Collections.emptyMap(), Collections.emptyMap()); + hits[1].sourceRef(new BytesArray("{\"diary\" : \"how are you\" }")); + hits[1].score(1.0F); + + SearchHits searchHits = new SearchHits(hits, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 1); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + this.response = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + + } + + /** + * This creates a search response with two hits, the first hit being in correct form. + * While, the second search hit has no source mapping. + */ + private void setUpInvalidSearchResultsWithNonSourceMapping() { + SearchHit[] hits = new SearchHit[2]; + hits[0] = new SearchHit(0, "1", Collections.emptyMap(), Collections.emptyMap()); + hits[0].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":-11.055182}")); + hits[0].score(1.0F); + + Map dummyMap = new HashMap<>(); + dummyMap.put("test", new DocumentField("test", Collections.singletonList("test-field-mapping"))); + hits[1] = new SearchHit(1, "2", dummyMap, Collections.emptyMap()); + hits[1].score(1.0F); + + SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + this.response = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java new file mode 100644 index 000000000..101060ca1 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/ProcessorUtilsTests.java @@ -0,0 +1,216 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util; + +import org.apache.lucene.search.TotalHits; +import org.mockito.ArgumentCaptor; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.common.document.DocumentField; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.neuralsearch.processor.util.ProcessorUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.getValueFromSource; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.mappingExistsInSource; +import static org.opensearch.neuralsearch.processor.util.ProcessorUtils.validateRerankCriteria; + +public class ProcessorUtilsTests extends OpenSearchTestCase { + + private Map sourceMap; + private SearchResponse searchResponse; + private float expectedScore; + + /** + * SourceMap of the form + *
+     *  {
+     *    "my_field" : "test_value",
+     *    "ml": {
+     *         "model" : "myModel",
+     *         "info"  : {
+     *                   "score": 0.95
+     *         }
+     *    }
+     * }
+     * 
+ */ + public void setUpValidSourceMap() { + expectedScore = 0.95f; + + sourceMap = new HashMap<>(); + sourceMap.put("my_field", "test_value"); + + Map mlMap = new HashMap<>(); + mlMap.put("model", "myModel"); + + Map infoMap = new HashMap<>(); + infoMap.put("score", expectedScore); + + mlMap.put("info", infoMap); + sourceMap.put("ml", mlMap); + } + + /** + * This creates a search response with two hits, the first hit being in the correct form. + * While, the second search hit has a non-numeric target field mapping. + */ + private void setUpInvalidSearchResultsWithTargetFieldHavingNonNumericMapping() { + SearchHit[] hits = new SearchHit[2]; + hits[0] = new SearchHit(0, "1", Collections.emptyMap(), Collections.emptyMap()); + hits[0].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":777}")); + hits[0].score(1.0F); + + Map dummyMap = new HashMap<>(); + dummyMap.put("test", new DocumentField("test", Collections.singletonList("test-field-mapping"))); + hits[1] = new SearchHit(1, "2", dummyMap, Collections.emptyMap()); + hits[1].sourceRef(new BytesArray("{\"diary\" : \"how do you do\",\"similarity_score\":\"hello world\"}")); + hits[1].score(1.0F); + + SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + this.searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + /** + * This creates a search response with two hits, Both are in correct form + */ + private void setUpValidSearchResults() { + SearchHit[] hits = new SearchHit[2]; + hits[0] = new SearchHit(0, "1", Collections.emptyMap(), Collections.emptyMap()); + hits[0].sourceRef(new BytesArray("{\"diary\" : \"how are you\",\"similarity_score\":777}")); + hits[0].score(1.0F); + + hits[1] = new SearchHit(1, "2", Collections.emptyMap(), Collections.emptyMap()); + hits[1].sourceRef(new BytesArray("{\"diary\" : \"how do you do\",\"similarity_score\":888}")); + hits[1].score(1.0F); + + SearchHits searchHits = new SearchHits(hits, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + this.searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + } + + public void testValidateRerankCriteria_throwsException_OnSearchResponseHavingNonNumericalScore() { + String targetField = "similarity_score"; + setUpInvalidSearchResultsWithTargetFieldHavingNonNumericMapping(); + ActionListener> listener = mock(ActionListener.class); + // Check that the mapping has non-numerical mapping + ProcessorUtils.SearchHitValidator searchHitValidator = (hit) -> { + Map sourceAsMap = hit.getSourceAsMap(); + Optional val = getValueFromSource(sourceAsMap, targetField); + + if (!(val.get() instanceof Number)) { + throw new IllegalArgumentException( + "The field mapping to rerank [" + targetField + ": " + val.get() + "] is a not Numerical" + ); + } + + }; + + boolean validRerankCriteria = validateRerankCriteria(searchResponse.getHits().getHits(), searchHitValidator, listener); + + assertFalse("This search response has invalid reranking criteria", validRerankCriteria); + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(argumentCaptor.capture()); + + assertEquals( + "The field mapping to rerank [" + targetField + ": hello world] is a not Numerical", + argumentCaptor.getValue().getMessage() + ); + assert (argumentCaptor.getValue() instanceof IllegalArgumentException); + } + + public void testValidateRerankCriteria_returnTrue_OnSearchResponseHavingCorrectForm() { + String targetField = "similarity_score"; + setUpValidSearchResults(); + ActionListener> listener = mock(ActionListener.class); + // This check is emulating the byField SearchHit Validation + ProcessorUtils.SearchHitValidator searchHitValidator = (hit) -> { + if (!hit.hasSource()) { + throw new IllegalArgumentException("There is no source field to be able to perform rerank on hit [" + hit.docId() + "]"); + } + + Map sourceAsMap = hit.getSourceAsMap(); + if (!mappingExistsInSource(sourceAsMap, targetField)) { + throw new IllegalArgumentException("The field to rerank [" + targetField + "] is not found at hit [" + hit.docId() + "]"); + } + Optional val = getValueFromSource(sourceAsMap, targetField); + + if (!(val.get() instanceof Number)) { + throw new IllegalArgumentException( + "The field mapping to rerank [" + targetField + ": " + val.get() + "] is a not Numerical" + ); + } + + }; + + boolean validRerankCriteria = validateRerankCriteria(searchResponse.getHits().getHits(), searchHitValidator, listener); + + assertTrue("This search response has valid reranking criteria", validRerankCriteria); + } + + public void testGetValueFromSource_returnsExpectedScore_WithExistingKeys() { + String targetField = "ml.info.score"; + setUpValidSourceMap(); + Optional result = getValueFromSource(sourceMap, targetField); + assertTrue(result.isPresent()); + assertEquals(expectedScore, (Float) result.get(), 0.01); + } + + public void testGetScoreFromSource_returnsExpectedScore_WithExistingKeys() { + String targetField = "ml.info.score"; + setUpValidSourceMap(); + float result = ProcessorUtils.getScoreFromSourceMap(sourceMap, targetField); + assertEquals(expectedScore, result, 0.01); + } + + public void testGetValueFromSource_returnsEmptyValue_WithNonExistingKeys() { + String targetField = "ml.info.score.wrong"; + setUpValidSourceMap(); + Optional result = getValueFromSource(sourceMap, targetField); + assertTrue(result.isEmpty()); + } + + public void testMappingExistsInSource_returnsTrue_withExistingKeys() { + String targetField = "ml.info.score"; + setUpValidSourceMap(); + boolean result = mappingExistsInSource(sourceMap, targetField); + assertTrue(result); + } + + public void testMappingExistsInSource_returnsFalse_withNonExistingKeys() { + String targetField = "ml.info.score.wrong"; + setUpValidSourceMap(); + boolean result = mappingExistsInSource(sourceMap, targetField); + assertFalse(result); + } + + public void testRemoveTargetFieldFromSource_successfullyDeletesTargetField_WithExistingKeys() { + String targetField = "ml.info.score"; + setUpValidSourceMap(); + ProcessorUtils.removeTargetFieldFromSource(sourceMap, targetField); + assertEquals("The first level of the map is the containing `my_field` and `ml`", 2, sourceMap.size()); + @SuppressWarnings("unchecked") + Map innerMLMap = (Map) sourceMap.get("ml"); + + assertEquals("The ml map now only has 1 mapping `model` instead of 2", 1, innerMLMap.size()); + assertTrue("The ml map has `model` as a mapping", innerMLMap.containsKey("model")); + assertFalse("The ml map no longer has the score `info` mapping ", innerMLMap.containsKey("info")); + } + +} diff --git a/src/test/resources/processor/ReRankByFieldPipelineConfiguration.json b/src/test/resources/processor/ReRankByFieldPipelineConfiguration.json new file mode 100644 index 000000000..37c8b1129 --- /dev/null +++ b/src/test/resources/processor/ReRankByFieldPipelineConfiguration.json @@ -0,0 +1,14 @@ +{ + "description": "Pipeline for reranking ByField", + "response_processors": [ + { + "rerank": { + "by_field": { + "target_field": "%s", + "remove_target_field": "%s", + "keep_previous_score": "%s" + } + } + } + ] +}