From e8fe2847a5237a03edd414a333799f7a5d2d8c7d Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 18 Dec 2024 10:52:25 +0800 Subject: [PATCH] [Enhancement] Implement pruning for neural sparse search (#988) * add impl Signed-off-by: zhichao-aws * add UT Signed-off-by: zhichao-aws * rename pruneType; UT Signed-off-by: zhichao-aws * changelog Signed-off-by: zhichao-aws * ut Signed-off-by: zhichao-aws * add it Signed-off-by: zhichao-aws * change on 2-phase Signed-off-by: zhichao-aws * UT Signed-off-by: zhichao-aws * it Signed-off-by: zhichao-aws * rename Signed-off-by: zhichao-aws * enhance: more detailed error message Signed-off-by: zhichao-aws * refactor to prune and split Signed-off-by: zhichao-aws * changelog Signed-off-by: zhichao-aws * fix UT cov Signed-off-by: zhichao-aws * address review comments Signed-off-by: zhichao-aws * enlarge score diff range Signed-off-by: zhichao-aws * address comments: check lowScores non null instead of flag Signed-off-by: zhichao-aws --------- Signed-off-by: zhichao-aws --- CHANGELOG.md | 1 + .../NeuralSparseTwoPhaseProcessor.java | 97 +++--- .../processor/SparseEncodingProcessor.java | 29 +- .../SparseEncodingProcessorFactory.java | 42 ++- .../query/NeuralSparseQueryBuilder.java | 23 +- .../neuralsearch/util/prune/PruneType.java | 47 +++ .../neuralsearch/util/prune/PruneUtils.java | 293 ++++++++++++++++++ .../NeuralSparseTwoPhaseProcessorIT.java | 32 +- .../NeuralSparseTwoPhaseProcessorTests.java | 83 +++-- .../processor/SparseEncodingProcessIT.java | 30 ++ .../SparseEncodingProcessorTests.java | 105 ++++++- ...ncodingEmbeddingProcessorFactoryTests.java | 182 +++++++++++ .../query/NeuralSparseQueryBuilderTests.java | 39 +++ .../util/prune/PruneTypeTests.java | 30 ++ .../util/prune/PruneUtilsTests.java | 266 ++++++++++++++++ ...ncodingPipelineConfigurationWithPrune.json | 21 ++ .../UploadSparseEncodingModelRequestBody.json | 10 +- .../neuralsearch/BaseNeuralSearchIT.java | 7 +- 18 files changed, 1197 insertions(+), 140 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java create mode 100644 src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java create mode 100644 src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 14441f3f7..5345d416f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Enhancements - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Support new knn query parameter expand_nested ([#1013](https://github.com/opensearch-project/neural-search/pull/1013)) +- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java 
index 8d386e615..bc5971e3f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessor.java @@ -9,11 +9,12 @@ import lombok.Getter; import lombok.Setter; import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.collect.Tuple; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.util.prune.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; @@ -21,11 +22,9 @@ import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.search.rescore.RescorerBuilder; -import java.util.Collections; import java.util.Locale; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; /** * A SearchRequestProcessor to generate two-phase NeuralSparseQueryBuilder, @@ -37,41 +36,37 @@ public class NeuralSparseTwoPhaseProcessor extends AbstractProcessor implements public static final String TYPE = "neural_sparse_two_phase_processor"; private boolean enabled; - private float ratio; + private float pruneRatio; + private PruneType pruneType; private float windowExpansion; private int maxWindowSize; private static final String PARAMETER_KEY = "two_phase_parameter"; - private static final String RATIO_KEY = "prune_ratio"; private static final String ENABLE_KEY = "enabled"; private static final String EXPANSION_KEY = "expansion_rate"; private static final String MAX_WINDOW_SIZE_KEY = "max_window_size"; private static final boolean DEFAULT_ENABLED = true; private static final float DEFAULT_RATIO = 0.4f; + private static final PruneType DEFAULT_PRUNE_TYPE = PruneType.MAX_RATIO; private static final float DEFAULT_WINDOW_EXPANSION = 5.0f; private static final int DEFAULT_MAX_WINDOW_SIZE = 10000; private static final int DEFAULT_BASE_QUERY_SIZE = 10; private static final int MAX_WINDOWS_SIZE_LOWER_BOUND = 50; private static final float WINDOW_EXPANSION_LOWER_BOUND = 1.0f; - private static final float RATIO_LOWER_BOUND = 0f; - private static final float RATIO_UPPER_BOUND = 1f; protected NeuralSparseTwoPhaseProcessor( String tag, String description, boolean ignoreFailure, boolean enabled, - float ratio, + float pruneRatio, + PruneType pruneType, float windowExpansion, int maxWindowSize ) { super(tag, description, ignoreFailure); this.enabled = enabled; - if (ratio < RATIO_LOWER_BOUND || ratio > RATIO_UPPER_BOUND) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "The two_phase_parameter.prune_ratio must be within [0, 1]. Received: %f", ratio) - ); - } - this.ratio = ratio; + this.pruneRatio = pruneRatio; + this.pruneType = pruneType; if (windowExpansion < WINDOW_EXPANSION_LOWER_BOUND) { throw new IllegalArgumentException( String.format(Locale.ROOT, "The two_phase_parameter.expansion_rate must >= 1.0. 
Received: %f", windowExpansion) @@ -93,7 +88,7 @@ protected NeuralSparseTwoPhaseProcessor( */ @Override public SearchRequest processRequest(final SearchRequest request) { - if (!enabled || ratio == 0f) { + if (!enabled || pruneRatio == 0f) { return request; } QueryBuilder queryBuilder = request.source().query(); @@ -117,43 +112,6 @@ public String getType() { return TYPE; } - /** - * Based on ratio, split a Map into two map by the value. - * - * @param queryTokens the queryTokens map, key is the token String, value is the score. - * @param thresholdRatio The ratio that control how tokens map be split. - * @return A tuple has two element, { token map whose value above threshold, token map whose value below threshold } - */ - public static Tuple, Map> splitQueryTokensByRatioedMaxScoreAsThreshold( - final Map queryTokens, - final float thresholdRatio - ) { - if (Objects.isNull(queryTokens)) { - throw new IllegalArgumentException("Query tokens cannot be null or empty."); - } - float max = 0f; - for (Float value : queryTokens.values()) { - max = Math.max(value, max); - } - float threshold = max * thresholdRatio; - - Map> queryTokensByScore = queryTokens.entrySet() - .stream() - .collect( - Collectors.partitioningBy(entry -> entry.getValue() >= threshold, Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)) - ); - - Map highScoreTokens = queryTokensByScore.get(Boolean.TRUE); - Map lowScoreTokens = queryTokensByScore.get(Boolean.FALSE); - if (Objects.isNull(highScoreTokens)) { - highScoreTokens = Collections.emptyMap(); - } - if (Objects.isNull(lowScoreTokens)) { - lowScoreTokens = Collections.emptyMap(); - } - return Tuple.tuple(highScoreTokens, lowScoreTokens); - } - private QueryBuilder getNestedQueryBuilderFromNeuralSparseQueryBuilderMap( final Multimap queryBuilderFloatMap ) { @@ -201,7 +159,10 @@ private Multimap collectNeuralSparseQueryBuilde * - Docs besides TopDocs: Score = HighScoreToken's score * - Final TopDocs: Score = HighScoreToken's score + LowScoreToken's score */ - NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase(ratio); + NeuralSparseQueryBuilder modifiedQueryBuilder = neuralSparseQueryBuilder.getCopyNeuralSparseQueryBuilderForTwoPhase( + pruneRatio, + pruneType + ); result.put(modifiedQueryBuilder, updatedBoost); } // We only support BoostQuery, BooleanQuery and NeuralSparseQuery now. 
For other compound query type which are not support now, will @@ -248,16 +209,40 @@ public NeuralSparseTwoPhaseProcessor create( boolean enabled = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, ENABLE_KEY, DEFAULT_ENABLED); Map twoPhaseConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, PARAMETER_KEY); - float ratio = DEFAULT_RATIO; + float pruneRatio = DEFAULT_RATIO; float windowExpansion = DEFAULT_WINDOW_EXPANSION; int maxWindowSize = DEFAULT_MAX_WINDOW_SIZE; + PruneType pruneType = DEFAULT_PRUNE_TYPE; if (Objects.nonNull(twoPhaseConfigMap)) { - ratio = ((Number) twoPhaseConfigMap.getOrDefault(RATIO_KEY, ratio)).floatValue(); + pruneRatio = ((Number) twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_RATIO_FIELD, pruneRatio)).floatValue(); windowExpansion = ((Number) twoPhaseConfigMap.getOrDefault(EXPANSION_KEY, windowExpansion)).floatValue(); maxWindowSize = ((Number) twoPhaseConfigMap.getOrDefault(MAX_WINDOW_SIZE_KEY, maxWindowSize)).intValue(); + pruneType = PruneType.fromString( + twoPhaseConfigMap.getOrDefault(PruneUtils.PRUNE_TYPE_FIELD, pruneType.getValue()).toString() + ); + } + if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Illegal prune_ratio %f for prune_type: %s. %s", + pruneRatio, + pruneType.getValue(), + PruneUtils.getValidPruneRatioDescription(pruneType) + ) + ); } - return new NeuralSparseTwoPhaseProcessor(tag, description, ignoreFailure, enabled, ratio, windowExpansion, maxWindowSize); + return new NeuralSparseTwoPhaseProcessor( + tag, + description, + ignoreFailure, + enabled, + pruneRatio, + pruneType, + windowExpansion, + maxWindowSize + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index e01840fbb..9250c8d64 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -9,14 +9,17 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; +import lombok.Getter; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.prune.PruneType; import org.opensearch.neuralsearch.util.TokenWeightUtil; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.util.prune.PruneUtils; /** * This processor is used for user input data text sparse encoding processing, model_id can be used to indicate which model user use, @@ -27,6 +30,10 @@ public final class SparseEncodingProcessor extends InferenceProcessor { public static final String TYPE = "sparse_encoding"; public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; + @Getter + private final PruneType pruneType; + @Getter + private final float pruneRatio; public SparseEncodingProcessor( String tag, @@ -34,11 +41,15 @@ public SparseEncodingProcessor( int batchSize, String modelId, Map fieldMap, + PruneType pruneType, + float pruneRatio, MLCommonsClientAccessor clientAccessor, Environment environment, ClusterService clusterService ) { super(tag, description, batchSize, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, clusterService); + this.pruneType = pruneType; + 
this.pruneRatio = pruneRatio; } @Override @@ -49,17 +60,23 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps) + .stream() + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) + .toList(); + setVectorFieldsToDocument(ingestDocument, ProcessMap, sparseVectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } @Override public void doBatchExecute(List inferenceList, Consumer> handler, Consumer onException) { - mlCommonsClientAccessor.inferenceSentencesWithMapResult( - this.modelId, - inferenceList, - ActionListener.wrap(resultMaps -> handler.accept(TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)), onException) - ); + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { + List> sparseVectors = TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps) + .stream() + .map(vector -> PruneUtils.pruneSparseVector(pruneType, pruneRatio, vector)) + .toList(); + handler.accept(sparseVectors); + }, onException)); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 46055df16..7a7d7dfde 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -6,10 +6,13 @@ import static org.opensearch.ingest.ConfigurationUtils.readMap; import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.TYPE; +import static org.opensearch.ingest.ConfigurationUtils.readOptionalStringProperty; +import static org.opensearch.ingest.ConfigurationUtils.readDoubleProperty; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import java.util.Locale; import java.util.Map; import org.opensearch.cluster.service.ClusterService; @@ -19,6 +22,8 @@ import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.util.prune.PruneUtils; +import org.opensearch.neuralsearch.util.prune.PruneType; /** * Factory for sparse encoding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. 
@@ -40,7 +45,40 @@ public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, En
     protected AbstractBatchingProcessor newProcessor(String tag, String description, int batchSize, Map<String, Object> config) {
         String modelId = readStringProperty(TYPE, tag, config, MODEL_ID_FIELD);
         Map<String, Object> fieldMap = readMap(TYPE, tag, config, FIELD_MAP_FIELD);
+        // if the field is missing, this will return PruneType.NONE
+        PruneType pruneType = PruneType.fromString(readOptionalStringProperty(TYPE, tag, config, PruneUtils.PRUNE_TYPE_FIELD));
+        float pruneRatio = 0;
+        if (pruneType != PruneType.NONE) {
+            // if we have a prune type, then the prune ratio field must have a value
+            // readDoubleProperty will throw an exception if the value is not present
+            pruneRatio = readDoubleProperty(TYPE, tag, config, PruneUtils.PRUNE_RATIO_FIELD).floatValue();
+            if (!PruneUtils.isValidPruneRatio(pruneType, pruneRatio)) {
+                throw new IllegalArgumentException(
+                    String.format(
+                        Locale.ROOT,
+                        "Illegal prune_ratio %f for prune_type: %s. %s",
+                        pruneRatio,
+                        pruneType.getValue(),
+                        PruneUtils.getValidPruneRatioDescription(pruneType)
+                    )
+                );
+            }
+        } else if (config.containsKey(PruneUtils.PRUNE_RATIO_FIELD)) {
+            // if we don't have a prune type, then the prune ratio field must not have a value
+            throw new IllegalArgumentException("prune_ratio field is not supported when prune_type is not provided");
+        }
-        return new SparseEncodingProcessor(tag, description, batchSize, modelId, fieldMap, clientAccessor, environment, clusterService);
+        return new SparseEncodingProcessor(
+            tag,
+            description,
+            batchSize,
+            modelId,
+            fieldMap,
+            pruneType,
+            pruneRatio,
+            clientAccessor,
+            environment,
+            clusterService
+        );
     }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
index f46997d5e..be9719452 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
@@ -47,8 +47,8 @@
 import lombok.NoArgsConstructor;
 import lombok.Setter;
 import lombok.experimental.Accessors;
-
-import static org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold;
+import org.opensearch.neuralsearch.util.prune.PruneType;
+import org.opensearch.neuralsearch.util.prune.PruneUtils;
 /**
  * SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
@@ -90,6 +90,7 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder tokens = queryTokensSupplier.get();
 // Splitting tokens based on a threshold value: tokens greater than the threshold are stored in v1,
 // while those less than or equal to the threshold are stored in v2.
-            Tuple<Map<String, Float>, Map<String, Float>> splitTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(tokens, ratio);
+            Tuple<Map<String, Float>, Map<String, Float>> splitTokens = PruneUtils.splitSparseVector(pruneType, pruneRatio, tokens);
             this.queryTokensSupplier(() -> splitTokens.v1());
             copy.queryTokensSupplier(() -> splitTokens.v2());
         } else {
@@ -346,9 +348,10 @@ private BiConsumer> getModelInferenceAsync(SetOnce { Map queryTokens = TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0);
             if (Objects.nonNull(twoPhaseSharedQueryToken)) {
-                Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = splitQueryTokensByRatioedMaxScoreAsThreshold(
-                    queryTokens,
-                    twoPhasePruneRatio
+                Tuple<Map<String, Float>, Map<String, Float>> splitQueryTokens = PruneUtils.splitSparseVector(
+                    twoPhasePruneType,
+                    twoPhasePruneRatio,
+                    queryTokens
                 );
                 setOnce.set(splitQueryTokens.v1());
                 twoPhaseSharedQueryToken = splitQueryTokens.v2();
diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java
new file mode 100644
index 000000000..5f8e62b7c
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneType.java
@@ -0,0 +1,47 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.util.prune;
+
+import org.apache.commons.lang.StringUtils;
+
+import java.util.Locale;
+
+/**
+ * Enum representing different types of prune methods for sparse vectors.
+ */
+public enum PruneType {
+    NONE("none"),
+    TOP_K("top_k"),
+    ALPHA_MASS("alpha_mass"),
+    MAX_RATIO("max_ratio"),
+    ABS_VALUE("abs_value");
+
+    private final String value;
+
+    PruneType(String value) {
+        this.value = value;
+    }
+
+    public String getValue() {
+        return value;
+    }
+
+    /**
+     * Get PruneType from string value
+     *
+     * @param value string representation of prune type
+     * @return corresponding PruneType enum
+     * @throws IllegalArgumentException if value doesn't match any prune type
+     */
+    public static PruneType fromString(final String value) {
+        if (StringUtils.isEmpty(value)) return NONE;
+        for (PruneType type : PruneType.values()) {
+            if (type.value.equals(value)) {
+                return type;
+            }
+        }
+        throw new IllegalArgumentException(String.format(Locale.ROOT, "Unknown prune type: %s", value));
+    }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java
new file mode 100644
index 000000000..a4c35adcc
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/util/prune/PruneUtils.java
@@ -0,0 +1,293 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.util.prune;
+
+import org.opensearch.common.collect.Tuple;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.PriorityQueue;
+
+/**
+ * Utility class providing methods for pruning sparse vectors using different strategies.
+ * Pruning helps reduce the dimensionality of sparse vectors by removing less significant elements
+ * based on various criteria.
+ */
+public class PruneUtils {
+    public static final String PRUNE_TYPE_FIELD = "prune_type";
+    public static final String PRUNE_RATIO_FIELD = "prune_ratio";
+
+    /**
+     * Prunes a sparse vector by keeping only the top K elements with the highest values.
+ * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param k The number of top elements to keep + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with top K elements, the second with remaining elements (or null) + */ + private static Tuple, Map> pruneByTopK( + Map sparseVector, + float k, + boolean requiresPrunedEntries + ) { + PriorityQueue> pq = new PriorityQueue<>((a, b) -> Float.compare(a.getValue(), b.getValue())); + + for (Map.Entry entry : sparseVector.entrySet()) { + if (pq.size() < (int) k) { + pq.offer(entry); + } else if (entry.getValue() > pq.peek().getValue()) { + pq.poll(); + pq.offer(entry); + } + } + + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? new HashMap<>(sparseVector) : null; + + while (!pq.isEmpty()) { + Map.Entry entry = pq.poll(); + highScores.put(entry.getKey(), entry.getValue()); + if (Objects.nonNull(lowScores)) { + lowScores.remove(entry.getKey()); + } + } + + return new Tuple<>(highScores, lowScores); + } + + /** + * Prunes a sparse vector by keeping only elements whose values are within a certain ratio + * of the maximum value in the vector. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param ratio The minimum ratio relative to the maximum value for elements to be kept + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with elements meeting the ratio threshold, + * the second with elements below the threshold (or null) + */ + private static Tuple, Map> pruneByMaxRatio( + Map sparseVector, + float ratio, + boolean requiresPrunedEntries + ) { + float maxValue = sparseVector.values().stream().max(Float::compareTo).orElse(0f); + + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? new HashMap<>() : null; + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() >= ratio * maxValue) { + highScores.put(entry.getKey(), entry.getValue()); + } else if (Objects.nonNull(lowScores)) { + lowScores.put(entry.getKey(), entry.getValue()); + } + } + + return new Tuple<>(highScores, lowScores); + } + + /** + * Prunes a sparse vector by removing elements with values below a certain threshold. + * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param thresh The minimum absolute value for elements to be kept + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with elements above the threshold, + * the second with elements below the threshold (or null) + */ + private static Tuple, Map> pruneByValue( + Map sparseVector, + float thresh, + boolean requiresPrunedEntries + ) { + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? new HashMap<>() : null; + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() >= thresh) { + highScores.put(entry.getKey(), entry.getValue()); + } else if (Objects.nonNull(lowScores)) { + lowScores.put(entry.getKey(), entry.getValue()); + } + } + + return new Tuple<>(highScores, lowScores); + } + + /** + * Prunes a sparse vector by keeping only elements whose cumulative sum of values + * is within a certain ratio of the total sum. 
+ * + * @param sparseVector The input sparse vector as a map of string keys to float values + * @param alpha The minimum ratio relative to the total sum for elements to be kept + * @param requiresPrunedEntries Whether to return pruned entries + * @return A tuple containing two maps: the first with elements meeting the alpha mass threshold, + * the second with remaining elements (or null) + */ + private static Tuple, Map> pruneByAlphaMass( + Map sparseVector, + float alpha, + boolean requiresPrunedEntries + ) { + List> sortedEntries = new ArrayList<>(sparseVector.entrySet()); + sortedEntries.sort(Map.Entry.comparingByValue(Comparator.reverseOrder())); + + float sum = (float) sparseVector.values().stream().mapToDouble(Float::doubleValue).sum(); + float topSum = 0f; + + Map highScores = new HashMap<>(); + Map lowScores = requiresPrunedEntries ? new HashMap<>() : null; + + for (Map.Entry entry : sortedEntries) { + float value = entry.getValue(); + topSum += value; + + if (topSum <= alpha * sum) { + highScores.put(entry.getKey(), value); + } else if (Objects.nonNull(lowScores)) { + lowScores.put(entry.getKey(), value); + } + } + + return new Tuple<>(highScores, lowScores); + } + + /** + * Split a sparse vector using the specified prune type and ratio. + * + * @param pruneType The type of prune strategy to use + * @param pruneRatio The ratio or threshold for prune + * @param sparseVector The input sparse vector as a map of string keys to float values + * @return A tuple containing two maps: the first with high-scoring elements, + * the second with low-scoring elements + */ + public static Tuple, Map> splitSparseVector( + PruneType pruneType, + float pruneRatio, + Map sparseVector + ) { + if (Objects.isNull(pruneType)) { + throw new IllegalArgumentException("Prune type must be provided"); + } + + if (Objects.isNull(sparseVector)) { + throw new IllegalArgumentException("Sparse vector must be provided"); + } + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException("Pruned values must be positive"); + } + } + + switch (pruneType) { + case TOP_K: + return pruneByTopK(sparseVector, pruneRatio, true); + case ALPHA_MASS: + return pruneByAlphaMass(sparseVector, pruneRatio, true); + case MAX_RATIO: + return pruneByMaxRatio(sparseVector, pruneRatio, true); + case ABS_VALUE: + return pruneByValue(sparseVector, pruneRatio, true); + default: + return new Tuple<>(new HashMap<>(sparseVector), new HashMap<>()); + } + } + + /** + * Prune a sparse vector using the specified prune type and ratio. 
+ * + * @param pruneType The type of prune strategy to use + * @param pruneRatio The ratio or threshold for prune + * @param sparseVector The input sparse vector as a map of string keys to float values + * @return A map with high-scoring elements + */ + public static Map pruneSparseVector( + final PruneType pruneType, + final float pruneRatio, + final Map sparseVector + ) { + if (Objects.isNull(pruneType)) { + throw new IllegalArgumentException("Prune type must be provided"); + } + + if (Objects.isNull(sparseVector)) { + throw new IllegalArgumentException("Sparse vector must be provided"); + } + + for (Map.Entry entry : sparseVector.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException("Pruned values must be positive"); + } + } + + switch (pruneType) { + case TOP_K: + return pruneByTopK(sparseVector, pruneRatio, false).v1(); + case ALPHA_MASS: + return pruneByAlphaMass(sparseVector, pruneRatio, false).v1(); + case MAX_RATIO: + return pruneByMaxRatio(sparseVector, pruneRatio, false).v1(); + case ABS_VALUE: + return pruneByValue(sparseVector, pruneRatio, false).v1(); + default: + return sparseVector; + } + } + + /** + * Validates whether a prune ratio is valid for a given prune type. + * + * @param pruneType The type of prune strategy + * @param pruneRatio The ratio or threshold to validate + * @return true if the ratio is valid for the given prune type, false otherwise + * @throws IllegalArgumentException if prune type is null + */ + public static boolean isValidPruneRatio(final PruneType pruneType, final float pruneRatio) { + if (pruneType == null) { + throw new IllegalArgumentException("Prune type cannot be null"); + } + + switch (pruneType) { + case TOP_K: + return pruneRatio > 0 && pruneRatio == Math.floor(pruneRatio); + case ALPHA_MASS: + case MAX_RATIO: + return pruneRatio >= 0 && pruneRatio < 1; + case ABS_VALUE: + return pruneRatio >= 0; + default: + return true; + } + } + + /** + * Get description of valid prune ratio for a given prune type. 
+ * + * @param pruneType The type of prune strategy + * @throws IllegalArgumentException if prune type is null + */ + public static String getValidPruneRatioDescription(final PruneType pruneType) { + if (pruneType == null) { + throw new IllegalArgumentException("Prune type cannot be null"); + } + + switch (pruneType) { + case TOP_K: + return "prune_ratio should be positive integer."; + case MAX_RATIO: + case ALPHA_MASS: + return "prune_ratio should be in the range [0, 1)."; + case ABS_VALUE: + return "prune_ratio should be non-negative."; + default: + return "prune_ratio field is not supported when prune_type is none"; + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java index 3e4ed8844..bc61f7c29 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorIT.java @@ -43,19 +43,12 @@ public class NeuralSparseTwoPhaseProcessorIT extends BaseNeuralSearchIT { private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); - private static final Float DELTA = 1e-5f; + private static final Float DELTA = 1e-4f; private final Map testRankFeaturesDoc = createRandomTokenWeightMap(TEST_TOKENS); private static final List TWO_PHASE_TEST_TOKEN = List.of("hello", "world"); - private static final Map testFixedQueryTokens = new HashMap<>(); + private static final Map testFixedQueryTokens = Map.of("hello", 5.0f, "world", 4.0f, "a", 3.0f, "b", 2.0f, "c", 1.0f); private static final Supplier> testFixedQueryTokenSupplier = () -> testFixedQueryTokens; - static { - testFixedQueryTokens.put("hello", 5.0f); - testFixedQueryTokens.put("world", 4.0f); - testFixedQueryTokens.put("a", 3.0f); - testFixedQueryTokens.put("b", 2.0f); - testFixedQueryTokens.put("c", 1.0f); - } @Before public void setUp() throws Exception { @@ -82,7 +75,6 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries_whenTwoPhaseEnabl NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) .queryTokensSupplier(randomTokenWeightSupplier); NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(randomTokenWeightSupplier); boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); @@ -116,7 +108,7 @@ private void setDefaultSearchPipelineForIndex(String indexName) { * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "TEST_QUERY_TEXT", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -127,13 +119,12 @@ private void setDefaultSearchPipelineForIndex(String indexName) { * } */ @SneakyThrows - public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScore() { + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabled_thenGetExpectedScore() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); initializeTwoPhaseProcessor(); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); Map searchResponseAsMap = 
search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -148,14 +139,13 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabled_thenGetExpectedScor } @SneakyThrows - public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { + public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabledAndDisabled_thenGetSameScore() { try { initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); initializeTwoPhaseProcessor(); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -164,7 +154,6 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS float scoreWithoutTwoPhase = objectToFloat(firstInnerHit.get("_score")); sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); @@ -190,7 +179,7 @@ public void testBasicQueryUsingQueryText_whenTwoPhaseEnabledAndDisabled_thenGetS * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -209,7 +198,6 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); QueryBuilder queryBuilder = new MatchAllQueryBuilder(); @@ -232,7 +220,7 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -240,7 +228,7 @@ public void testNeuralSparseQueryAsRescoreQuery_whenTwoPhase_thenGetExpectedScor * { * "neural_sparse": { * "field": "test-sparse-encoding-1", - * "query_text": "Hello world a b", + * "query_tokens": testFixedQueryTokens, * "model_id": "dcsdcasd", * "boost": 2.0 * } @@ -316,7 +304,6 @@ public void testMultiNeuralSparseQuery_whenTwoPhaseAndFilter_thenGetExpectedScor setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(2.0f); boolQueryBuilder.should(sparseEncodingQueryBuilder); @@ -401,7 +388,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInConstantScoreQuery_then createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(1.0f); ConstantScoreQueryBuilder 
constantScoreQueryBuilder = new ConstantScoreQueryBuilder(sparseEncodingQueryBuilder); @@ -421,7 +407,6 @@ public void testNeuralSParseQuery_whenTwoPhaseAndNestedInDisjunctionMaxQuery_the createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(5.0f); DisMaxQueryBuilder disMaxQueryBuilder = new DisMaxQueryBuilder(); @@ -444,7 +429,6 @@ public void testNeuralSparseQuery_whenTwoPhaseAndNestedInFunctionScoreQuery_then createNeuralSparseTwoPhaseSearchProcessor(search_pipeline, 0.6f, 5f, 10000); setDefaultSearchPipelineForIndex(TEST_BASIC_INDEX_NAME); NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1) - .queryText(TEST_QUERY_TEXT) .queryTokensSupplier(testFixedQueryTokenSupplier) .boost(5.0f); FunctionScoreQueryBuilder functionScoreQueryBuilder = new FunctionScoreQueryBuilder(sparseEncodingQueryBuilder); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java index 2ce7c7b96..40230a618 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralSparseTwoPhaseProcessorTests.java @@ -6,10 +6,11 @@ import lombok.SneakyThrows; import org.opensearch.action.search.SearchRequest; -import org.opensearch.common.collect.Tuple; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.util.prune.PruneType; +import org.opensearch.neuralsearch.util.prune.PruneUtils; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.test.OpenSearchTestCase; @@ -20,7 +21,6 @@ public class NeuralSparseTwoPhaseProcessorTests extends OpenSearchTestCase { static final private String PARAMETER_KEY = "two_phase_parameter"; - static final private String RATIO_KEY = "prune_ratio"; static final private String ENABLE_KEY = "enabled"; static final private String EXPANSION_KEY = "expansion_rate"; static final private String MAX_WINDOW_SIZE_KEY = "max_window_size"; @@ -28,9 +28,10 @@ public class NeuralSparseTwoPhaseProcessorTests extends OpenSearchTestCase { public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory); - assertEquals(0.3f, processor.getRatio(), 1e-3); + assertEquals(0.3f, processor.getPruneRatio(), 1e-3); assertEquals(4.0f, processor.getWindowExpansion(), 1e-3); assertEquals(10000, processor.getMaxWindowSize()); + assertEquals(PruneType.MAX_RATIO, processor.getPruneType()); NeuralSparseTwoPhaseProcessor defaultProcessor = factory.create( Collections.emptyMap(), @@ -40,14 +41,26 @@ public void testFactory_whenCreateDefaultPipeline_thenSuccess() throws Exception Collections.emptyMap(), null ); - assertEquals(0.4f, defaultProcessor.getRatio(), 1e-3); + assertEquals(0.4f, 
defaultProcessor.getPruneRatio(), 1e-3); assertEquals(5.0f, defaultProcessor.getWindowExpansion(), 1e-3); assertEquals(10000, defaultProcessor.getMaxWindowSize()); + assertEquals(PruneType.MAX_RATIO, processor.getPruneType()); + } + + public void testFactory_whenCreatePipelineWithCustomPruneType_thenSuccess() throws Exception { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 5f, "top_k", true, 5f, 1000); + assertEquals(5f, processor.getPruneRatio(), 1e-6); + assertEquals(PruneType.TOP_K, processor.getPruneType()); } public void testFactory_whenRatioOutOfRange_thenThrowException() { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, "max_ratio", true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 0f, "top_k", true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, 1.1f, "alpha_mass", true, 5.0f, 10000)); + expectThrows(IllegalArgumentException.class, () -> createTestProcessor(factory, -1f, "abs_value", true, 5.0f, 10000)); } public void testFactory_whenWindowExpansionOutOfRange_thenThrowException() { @@ -73,6 +86,19 @@ public void testProcessRequest_whenTwoPhaseEnabled_thenSuccess() throws Exceptio assertNotNull(searchRequest.source().rescores()); } + public void testProcessRequest_whenUseCustomPruneType_thenSuccess() throws Exception { + NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); + NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder)); + NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory, 0.5f, "alpha_mass", true, 4.0f, 10000); + processor.processRequest(searchRequest); + NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) searchRequest.source().query(); + assertEquals(queryBuilder.twoPhasePruneRatio(), 0.5f, 1e-3); + assertEquals(queryBuilder.twoPhasePruneType(), PruneType.ALPHA_MASS); + assertNotNull(searchRequest.source().rescores()); + } + public void testProcessRequest_whenTwoPhaseEnabledAndNestedBoolean_thenSuccess() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseQueryBuilder neuralQueryBuilder = new NeuralSparseQueryBuilder(); @@ -140,32 +166,6 @@ public void testProcessRequest_whenTwoPhaseEnabledAndWithOutNeuralSparseQuery_th assertNull(returnRequest.source().rescores()); } - @SneakyThrows - public void testGetSplitSetOnceByScoreThreshold() { - Map queryTokens = new HashMap<>(); - for (int i = 0; i < 10; i++) { - queryTokens.put(String.valueOf(i), (float) i); - } - Tuple, Map> splitQueryTokens = NeuralSparseTwoPhaseProcessor - .splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, 0.4f); - assertNotNull(splitQueryTokens); - Map upSet = splitQueryTokens.v1(); - Map downSet = splitQueryTokens.v2(); - assertNotNull(upSet); - assertEquals(6, upSet.size()); - assertNotNull(downSet); - assertEquals(4, downSet.size()); - } - - @SneakyThrows - public void testGetSplitSetOnceByScoreThreshold_whenNullQueryToken_thenThrowException() { 
- Map queryTokens = null; - expectThrows( - IllegalArgumentException.class, - () -> NeuralSparseTwoPhaseProcessor.splitQueryTokensByRatioedMaxScoreAsThreshold(queryTokens, 0.4f) - ); - } - public void testType() throws Exception { NeuralSparseTwoPhaseProcessor.Factory factory = new NeuralSparseTwoPhaseProcessor.Factory(); NeuralSparseTwoPhaseProcessor processor = createTestProcessor(factory); @@ -182,9 +182,28 @@ private NeuralSparseTwoPhaseProcessor createTestProcessor( Map configMap = new HashMap<>(); configMap.put(ENABLE_KEY, enabled); Map twoPhaseParaMap = new HashMap<>(); - twoPhaseParaMap.put(RATIO_KEY, ratio); + twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, ratio); + twoPhaseParaMap.put(EXPANSION_KEY, expand); + twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, max_window); + configMap.put(PARAMETER_KEY, twoPhaseParaMap); + return factory.create(Collections.emptyMap(), null, null, false, configMap, null); + } + + private NeuralSparseTwoPhaseProcessor createTestProcessor( + NeuralSparseTwoPhaseProcessor.Factory factory, + float ratio, + String type, + boolean enabled, + float expand, + int max_window + ) throws Exception { + Map configMap = new HashMap<>(); + configMap.put(ENABLE_KEY, enabled); + Map twoPhaseParaMap = new HashMap<>(); + twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, ratio); twoPhaseParaMap.put(EXPANSION_KEY, expand); twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, max_window); + twoPhaseParaMap.put(PruneUtils.PRUNE_TYPE_FIELD, type); configMap.put(PARAMETER_KEY, twoPhaseParaMap); return factory.create(Collections.emptyMap(), null, null, false, configMap, null); } @@ -193,7 +212,7 @@ private NeuralSparseTwoPhaseProcessor createTestProcessor(NeuralSparseTwoPhasePr Map configMap = new HashMap<>(); configMap.put(ENABLE_KEY, true); Map twoPhaseParaMap = new HashMap<>(); - twoPhaseParaMap.put(RATIO_KEY, 0.3f); + twoPhaseParaMap.put(PruneUtils.PRUNE_RATIO_FIELD, 0.3f); twoPhaseParaMap.put(EXPANSION_KEY, 4.0f); twoPhaseParaMap.put(MAX_WINDOW_SIZE_KEY, 10000); configMap.put(PARAMETER_KEY, twoPhaseParaMap); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index 349da1033..83b680d19 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -18,6 +18,7 @@ import org.opensearch.neuralsearch.BaseNeuralSearchIT; import com.google.common.collect.ImmutableList; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; public class SparseEncodingProcessIT extends BaseNeuralSearchIT { @@ -39,6 +40,35 @@ public void testSparseEncodingProcessor() throws Exception { createSparseEncodingIndex(); ingestDocument(); assertEquals(1, getDocCount(INDEX_NAME)); + + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("title_sparse"); + neuralSparseQueryBuilder.queryTokensSupplier(() -> Map.of("good", 1.0f, "a", 2.0f)); + Map searchResponse = search(INDEX_NAME, neuralSparseQueryBuilder, 2); + assertFalse(searchResponse.isEmpty()); + double maxScore = (Double) ((Map) searchResponse.get("hits")).get("max_score"); + assertEquals(4.4433594, maxScore, 1e-3); + } finally { + wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); + } + } + + public void testSparseEncodingProcessorWithPrune() throws Exception { + String modelId = null; + try { + modelId = 
prepareSparseEncodingModel(); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.SPARSE_ENCODING_PRUNE); + createSparseEncodingIndex(); + ingestDocument(); + assertEquals(1, getDocCount(INDEX_NAME)); + + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("title_sparse"); + neuralSparseQueryBuilder.queryTokensSupplier(() -> Map.of("good", 1.0f, "a", 2.0f)); + Map searchResponse = search(INDEX_NAME, neuralSparseQueryBuilder, 2); + assertFalse(searchResponse.isEmpty()); + double maxScore = (Double) ((Map) searchResponse.get("hits")).get("max_score"); + assertEquals(3.640625, maxScore, 1e-3); } finally { wipeOfTestResources(INDEX_NAME, PIPELINE_NAME, modelId, null); } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java index 9486ee2ca..8d512cc4c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -14,10 +14,12 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.when; import static org.mockito.Mockito.verify; +import java.util.Arrays; import java.util.Map; import java.util.ArrayList; import java.util.Collections; @@ -49,6 +51,7 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.opensearch.neuralsearch.util.prune.PruneType; public class SparseEncodingProcessorTests extends InferenceProcessorTestCase { @Mock @@ -90,6 +93,17 @@ private SparseEncodingProcessor createInstance(int batchSize) { return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } + @SneakyThrows + private SparseEncodingProcessor createInstance(PruneType pruneType, float pruneRatio) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + config.put("prune_type", pruneType.getValue()); + config.put("prune_ratio", pruneRatio); + return (SparseEncodingProcessor) sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + public void testExecute_successful() { Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); @@ -260,9 +274,98 @@ public void test_batchExecute_exception() { } } + @SuppressWarnings("unchecked") + public void testExecute_withPruneConfig_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> dataAsMapList = Collections.singletonList( + Map.of("response", Arrays.asList(ImmutableMap.of("hello", 1.0f, "world", 0.1f), ImmutableMap.of("test", 0.8f, "low", 0.4f))) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + 
}).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + + ArgumentCaptor docCaptor = ArgumentCaptor.forClass(IngestDocument.class); + verify(handler).accept(docCaptor.capture(), isNull()); + + IngestDocument processedDoc = docCaptor.getValue(); + Map first = (Map) processedDoc.getFieldValue("key1Mapped", Map.class); + Map second = (Map) processedDoc.getFieldValue("key2Mapped", Map.class); + + assertNotNull(first); + assertNotNull(second); + + assertTrue(first.containsKey("hello")); + assertFalse(first.containsKey("world")); + assertEquals(1.0f, first.get("hello"), 0.001f); + + assertTrue(second.containsKey("test")); + assertTrue(second.containsKey("low")); + assertEquals(0.8f, second.get("test"), 0.001f); + assertEquals(0.4f, second.get("low"), 0.001f); + } + + public void test_batchExecute_withPrune_successful() { + SparseEncodingProcessor processor = createInstance(PruneType.MAX_RATIO, 0.5f); + + List> mockMLResponse = Collections.singletonList( + Map.of( + "response", + Arrays.asList( + ImmutableMap.of("token1", 1.0f, "token2", 0.3f, "token3", 0.8f), + ImmutableMap.of("token4", 0.9f, "token5", 0.2f, "token6", 0.7f) + ) + ) + ); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(mockMLResponse); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + Consumer> resultHandler = mock(Consumer.class); + Consumer exceptionHandler = mock(Consumer.class); + + List inferenceList = Arrays.asList("test1", "test2"); + processor.doBatchExecute(inferenceList, resultHandler, exceptionHandler); + + ArgumentCaptor>> resultCaptor = ArgumentCaptor.forClass(List.class); + verify(resultHandler).accept(resultCaptor.capture()); + verify(exceptionHandler, never()).accept(any()); + + List> processedResults = resultCaptor.getValue(); + + assertEquals(2, processedResults.size()); + + Map firstResult = processedResults.get(0); + assertEquals(2, firstResult.size()); + assertTrue(firstResult.containsKey("token1")); + assertTrue(firstResult.containsKey("token3")); + assertFalse(firstResult.containsKey("token2")); + + Map secondResult = processedResults.get(1); + assertEquals(2, secondResult.size()); + assertTrue(secondResult.containsKey("token4")); + assertTrue(secondResult.containsKey("token6")); + assertFalse(secondResult.containsKey("token5")); + } + private List> createMockMapResult(int number) { List> mockSparseEncodingResult = new ArrayList<>(); - IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f, "world", 0.1f))); List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); return mockMapResult; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..5d098e77e --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingEmbeddingProcessorFactoryTests.java @@ -0,0 +1,182 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package 
org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.SparseEncodingProcessor.TYPE; +import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_TYPE_FIELD; +import static org.opensearch.neuralsearch.util.prune.PruneUtils.PRUNE_RATIO_FIELD; + +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.OpenSearchParseException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; +import org.opensearch.neuralsearch.util.prune.PruneType; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class SparseEncodingEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + private static final String MODEL_ID = "testModelId"; + private static final int BATCH_SIZE = 1; + + private MLCommonsClientAccessor clientAccessor; + private Environment environment; + private ClusterService clusterService; + private SparseEncodingProcessorFactory sparseEncodingProcessorFactory; + + @Before + public void setup() { + clientAccessor = mock(MLCommonsClientAccessor.class); + environment = mock(Environment.class); + clusterService = mock(ClusterService.class); + sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(clientAccessor, environment, clusterService); + } + + @SneakyThrows + public void testCreateProcessor_whenAllRequiredParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.NONE, processor.getPruneType()); + assertEquals(0f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenPruneParamsPassed_thenSuccessful() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 2f); + + SparseEncodingProcessor processor = (SparseEncodingProcessor) sparseEncodingProcessorFactory.create( + Map.of(), + PROCESSOR_TAG, + DESCRIPTION, + config + ); + + assertNotNull(processor); + assertEquals(TYPE, processor.getType()); + assertEquals(PROCESSOR_TAG, processor.getTag()); + assertEquals(DESCRIPTION, processor.getDescription()); + assertEquals(PruneType.TOP_K, processor.getPruneType()); + assertEquals(2f, processor.getPruneRatio(), 1e-6); + } + + @SneakyThrows + public void testCreateProcessor_whenEmptyFieldMapField_thenFail() { + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of()); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () 
-> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unable to create the processor as field_map has invalid key or value", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingModelIdField_thenFail() { + Map<String, Object> config = new HashMap<>(); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[model_id] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingFieldMapField_thenFail() { + Map<String, Object> config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[field_map] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneType_thenFail() { + Map<String, Object> config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "invalid_prune_type"); + config.put(PRUNE_RATIO_FIELD, 2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Unknown prune type: invalid_prune_type", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenInvalidPruneRatio_thenFail() { + Map<String, Object> config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "top_k"); + config.put(PRUNE_RATIO_FIELD, 0.2f); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("Illegal prune_ratio 0.200000 for prune_type: top_k.
prune_ratio should be positive integer.", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneRatio_thenFail() { + Map<String, Object> config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_TYPE_FIELD, "alpha_mass"); + + OpenSearchParseException exception = assertThrows( + OpenSearchParseException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[prune_ratio] required property is missing", exception.getMessage()); + } + + @SneakyThrows + public void testCreateProcessor_whenMissingPruneType_thenFail() { + Map<String, Object> config = new HashMap<>(); + config.put(MODEL_ID_FIELD, MODEL_ID); + config.put(FIELD_MAP_FIELD, Map.of("a", "b")); + config.put(PRUNE_RATIO_FIELD, 0.1); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> sparseEncodingProcessorFactory.create(Map.of(), PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("prune_ratio field is not supported when prune_type is not provided", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 7509efd42..2c4c88871 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -52,6 +52,7 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; +import org.opensearch.neuralsearch.util.prune.PruneType; import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -649,6 +650,44 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get()); } + @SneakyThrows + public void testRewrite_whenQueryTokensSupplierNull_andPruneSet_thenSuccessPrune() { + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .twoPhaseSharedQueryToken(Map.of()) + .twoPhasePruneRatio(3.0f) + .twoPhasePruneType(PruneType.ABS_VALUE); + Map<String, Float> expectedMap = Map.of("1", 1f, "2", 5f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener<List<Map<String, ?>>> listener = invocation.getArgument(2); + listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); + NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer<Client, ActionListener<?>> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set query tokens supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.queryTokensSupplier()); +
assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertEquals(Map.of("2", 5f), queryBuilder.queryTokensSupplier().get()); + assertEquals(Map.of("1", 1f), queryBuilder.twoPhaseSharedQueryToken()); + } + @SneakyThrows public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME) diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java new file mode 100644 index 000000000..f8ba5b604 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneTypeTests.java @@ -0,0 +1,30 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.prune; + +import org.opensearch.test.OpenSearchTestCase; + +public class PruneTypeTests extends OpenSearchTestCase { + public void testGetValue() { + assertEquals("none", PruneType.NONE.getValue()); + assertEquals("top_k", PruneType.TOP_K.getValue()); + assertEquals("alpha_mass", PruneType.ALPHA_MASS.getValue()); + assertEquals("max_ratio", PruneType.MAX_RATIO.getValue()); + assertEquals("abs_value", PruneType.ABS_VALUE.getValue()); + } + + public void testFromString() { + assertEquals(PruneType.NONE, PruneType.fromString("none")); + assertEquals(PruneType.NONE, PruneType.fromString(null)); + assertEquals(PruneType.NONE, PruneType.fromString("")); + assertEquals(PruneType.TOP_K, PruneType.fromString("top_k")); + assertEquals(PruneType.ALPHA_MASS, PruneType.fromString("alpha_mass")); + assertEquals(PruneType.MAX_RATIO, PruneType.fromString("max_ratio")); + assertEquals(PruneType.ABS_VALUE, PruneType.fromString("abs_value")); + + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneType.fromString("test_value")); + assertEquals("Unknown prune type: test_value", exception.getMessage()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java new file mode 100644 index 000000000..536125152 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/prune/PruneUtilsTests.java @@ -0,0 +1,266 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.util.prune; + +import org.opensearch.common.collect.Tuple; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +public class PruneUtilsTests extends OpenSearchTestCase { + + public void testPruneByTopK() { + Map<String, Float> input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + // Test prune + Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input); + + assertEquals(2, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + + // Test split + Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.TOP_K, 2, input); + + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(4.0f, tupleResult.v1().get("c"), 0.001); + assertEquals(3.0f, tupleResult.v2().get("b"), 0.001); + assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); + } + + public void testPruneByMaxRatio() { + Map<String, Float> input = new HashMap<>(); + input.put("a",
10.0f); + input.put("b", 8.0f); + input.put("c", 5.0f); + input.put("d", 2.0f); + + // Test prune + Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.MAX_RATIO, 0.7f, input); + + assertEquals(2, result.size()); + assertEquals(10.0f, result.get("a"), 0.001); + assertEquals(8.0f, result.get("b"), 0.001); + + // Test split + Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.MAX_RATIO, 0.7f, input); + + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(10.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(8.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(5.0f, tupleResult.v2().get("c"), 0.001); + assertEquals(2.0f, tupleResult.v2().get("d"), 0.001); + } + + public void testPruneByValue() { + Map<String, Float> input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 2.0f); + input.put("d", 1.0f); + + // Test prune + Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.ABS_VALUE, 3.0f, input); + + assertEquals(2, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(3.0f, result.get("b"), 0.001); + + // Test split + Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.ABS_VALUE, 3.0f, input); + + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(3.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(2.0f, tupleResult.v2().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); + } + + public void testPruneByAlphaMass() { + Map<String, Float> input = new HashMap<>(); + input.put("a", 10.0f); + input.put("b", 6.0f); + input.put("c", 3.0f); + input.put("d", 1.0f); + + // Test prune + Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.ALPHA_MASS, 0.8f, input); + + assertEquals(2, result.size()); + assertEquals(10.0f, result.get("a"), 0.001); + assertEquals(6.0f, result.get("b"), 0.001); + + // Test split + Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.ALPHA_MASS, 0.8f, input); + + assertEquals(2, tupleResult.v1().size()); + assertEquals(2, tupleResult.v2().size()); + assertEquals(10.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(6.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(3.0f, tupleResult.v2().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v2().get("d"), 0.001); + } + + public void testNonePrune() { + Map<String, Float> input = new HashMap<>(); + input.put("a", 5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + input.put("d", 1.0f); + + // Test prune + Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.NONE, 2, input); + + assertEquals(4, result.size()); + assertEquals(5.0f, result.get("a"), 0.001); + assertEquals(3.0f, result.get("b"), 0.001); + assertEquals(4.0f, result.get("c"), 0.001); + assertEquals(1.0f, result.get("d"), 0.001); + + // Test split + Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.NONE, 2, input); + + assertEquals(4, tupleResult.v1().size()); + assertEquals(0, tupleResult.v2().size()); + assertEquals(5.0f, tupleResult.v1().get("a"), 0.001); + assertEquals(3.0f, tupleResult.v1().get("b"), 0.001); + assertEquals(4.0f, tupleResult.v1().get("c"), 0.001); + assertEquals(1.0f, tupleResult.v1().get("d"), 0.001); + } + + public void testEmptyInput() { + Map<String, Float> input = new HashMap<>(); + + // Test prune + Map<String, Float> result = PruneUtils.pruneSparseVector(PruneType.TOP_K, 5, input); + assertTrue(result.isEmpty()); + + // Test split + Tuple<Map<String, Float>, Map<String, Float>> tupleResult = PruneUtils.splitSparseVector(PruneType.TOP_K, 5,
input); + assertTrue(tupleResult.v1().isEmpty()); + assertTrue(tupleResult.v2().isEmpty()); + } + + public void testNegativeValues() { + Map<String, Float> input = new HashMap<>(); + input.put("a", -5.0f); + input.put("b", 3.0f); + input.put("c", 4.0f); + + // Test prune + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, input) + ); + assertEquals("Pruned values must be positive", exception1.getMessage()); + + // Test split + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, input) + ); + assertEquals("Pruned values must be positive", exception2.getMessage()); + } + + public void testInvalidPruneType() { + Map<String, Float> input = new HashMap<>(); + input.put("a", 1.0f); + input.put("b", 2.0f); + + // Test prune + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruneSparseVector(null, 2, input) + ); + assertEquals(exception1.getMessage(), "Prune type must be provided"); + + // Test split + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(null, 2, input) + ); + assertEquals(exception2.getMessage(), "Prune type must be provided"); + } + + public void testNullSparseVector() { + IllegalArgumentException exception1 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.pruneSparseVector(PruneType.TOP_K, 2, null) + ); + assertEquals(exception1.getMessage(), "Sparse vector must be provided"); + + IllegalArgumentException exception2 = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.splitSparseVector(PruneType.TOP_K, 2, null) + ); + assertEquals(exception2.getMessage(), "Sparse vector must be provided"); + } + + public void testIsValidPruneRatio() { + // Test TOP_K validation + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 100)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, -1)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.TOP_K, 1.5f)); + + // Test ALPHA_MASS validation + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 0)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, 1.1f)); + + // Test MAX_RATIO validation + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 0.5f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, -0.1f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, 1.1f)); + + // Test ABS_VALUE validation + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 0.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 1.0f)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, 100.0f)); + assertFalse(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, -0.1f)); + + // Test with extreme cases + assertTrue(PruneUtils.isValidPruneRatio(PruneType.TOP_K, Float.MAX_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.ABS_VALUE, Float.MAX_VALUE)); +
assertTrue(PruneUtils.isValidPruneRatio(PruneType.ALPHA_MASS, Float.MIN_VALUE)); + assertTrue(PruneUtils.isValidPruneRatio(PruneType.MAX_RATIO, Float.MIN_VALUE)); + } + + public void testIsValidPruneRatioWithNullType() { + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, () -> PruneUtils.isValidPruneRatio(null, 1.0f)); + assertEquals("Prune type cannot be null", exception.getMessage()); + } + + public void testGetValidPruneRatioDescription() { + assertEquals("prune_ratio should be positive integer.", PruneUtils.getValidPruneRatioDescription(PruneType.TOP_K)); + assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.MAX_RATIO)); + assertEquals("prune_ratio should be in the range [0, 1).", PruneUtils.getValidPruneRatioDescription(PruneType.ALPHA_MASS)); + assertEquals("prune_ratio should be non-negative.", PruneUtils.getValidPruneRatioDescription(PruneType.ABS_VALUE)); + assertEquals( + "prune_ratio field is not supported when prune_type is none", + PruneUtils.getValidPruneRatioDescription(PruneType.NONE) + ); + + IllegalArgumentException exception = assertThrows( + IllegalArgumentException.class, + () -> PruneUtils.getValidPruneRatioDescription(null) + ); + assertEquals(exception.getMessage(), "Prune type cannot be null"); + } +} diff --git a/src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json b/src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json new file mode 100644 index 000000000..642228e06 --- /dev/null +++ b/src/test/resources/processor/SparseEncodingPipelineConfigurationWithPrune.json @@ -0,0 +1,21 @@ +{ + "description": "An example sparse Encoding pipeline", + "processors" : [ + { + "sparse_encoding": { + "model_id": "%s", + "batch_size": "%d", + "prune_type": "max_ratio", + "prune_ratio": 0.8, + "field_map": { + "title": "title_sparse", + "favor_list": "favor_list_sparse", + "favorites": { + "game": "game_sparse", + "movie": "movie_sparse" + } + } + } + } + ] +} diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json index 5c2a73f6e..6bdac87c5 100644 --- a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json +++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json @@ -1,10 +1,6 @@ { - "name": "tokenize-idf-0915", - "version": "1.0.0", - "function_name": "SPARSE_TOKENIZE", - "description": "test model", - "model_format": "TORCH_SCRIPT", + "name": "amazon/neural-sparse/opensearch-neural-sparse-tokenizer-v1", + "version": "1.0.1", "model_group_id": "%s", - "model_content_hash_value": "b345e9e943b62c405a8dd227ef2c46c84c5ff0a0b71b584be9132b37bce91a9a", - "url": "https://github.com/opensearch-project/ml-commons/raw/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/sparse_encoding/sparse_demo.zip" + "model_format": "TORCH_SCRIPT" } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 4f154e78b..08628c247 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -87,7 +87,9 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ProcessorType.TEXT_IMAGE_EMBEDDING, "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json", 
ProcessorType.TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, - "processor/PipelineConfigurationWithNestedFieldsMapping.json" + "processor/PipelineConfigurationWithNestedFieldsMapping.json", + ProcessorType.SPARSE_ENCODING_PRUNE, + "processor/SparseEncodingPipelineConfigurationWithPrune.json" ); private static final Set<RestStatus> SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; @@ -1466,6 +1468,7 @@ protected enum ProcessorType { TEXT_EMBEDDING, TEXT_EMBEDDING_WITH_NESTED_FIELDS_MAPPING, TEXT_IMAGE_EMBEDDING, - SPARSE_ENCODING + SPARSE_ENCODING, + SPARSE_ENCODING_PRUNE } }