From 7bef7a0c955922d97cb025b65e4c8cbccbba283c Mon Sep 17 00:00:00 2001
From: zhichao-aws
Date: Wed, 27 Sep 2023 13:47:02 +0800
Subject: [PATCH] Supporting sparse semantic retrieval in neural search (#333)

* sparse mapper field and query builder Signed-off-by: zhichao-aws
* fix typo Signed-off-by: zhichao-aws
* Add map result support in neural search for non text embedding models Signed-off-by: zane-neo
* Fix compilation failure issue Signed-off-by: zane-neo
* Add more UTs Signed-off-by: zane-neo
* add sparse encoding processor Signed-off-by: xinyual
* add sparse encoding processor Signed-off-by: xinyual
* remove guava in gradle Signed-off-by: xinyual
* modify access control Signed-off-by: xinyual
* Add map result support in neural search for non text embedding models Signed-off-by: zane-neo
* Fix compilation failure issue Signed-off-by: zane-neo
* change output logic Signed-off-by: xinyual
* create abstract Signed-off-by: xinyual
* create abstract proccesor Signed-off-by: xinyual
* add abstract class Signed-off-by: xinyual
* remove duplicate code Signed-off-by: xinyual
* remove duplicate code Signed-off-by: xinyual
* remove dl process Signed-off-by: xinyual
* move static to abstract class Signed-off-by: xinyual
* update query rewrite logic Signed-off-by: zhichao-aws
* modify header Signed-off-by: zhichao-aws
* merge conflict Signed-off-by: xinyual
* delete index mapper, change to rank_features Signed-off-by: zhichao-aws
* remove unused import Signed-off-by: zhichao-aws
* list return result Signed-off-by: zhichao-aws
* refactor type and listTypeNestedMapKey, tidy Signed-off-by: zhichao-aws
* forbid nested input. tidy. Signed-off-by: zhichao-aws
* tidy Signed-off-by: zhichao-aws
* enable nested Signed-off-by: zhichao-aws
* fix test Signed-off-by: zhichao-aws
* Add ut it to sparse encoding processor (#6)
* fix original UT problem Signed-off-by: xinyual
* add UT IT Signed-off-by: xinyual
* add more UT Signed-off-by: xinyual
* add more ut Signed-off-by: xinyual
* fix typo error Signed-off-by: xinyual
---------
Signed-off-by: xinyual
* utils, tidy Signed-off-by: zhichao-aws
* rename to sparse_encoding query Signed-off-by: zhichao-aws
* add validation and ut Signed-off-by: zhichao-aws
* sparse encoding query builder ut Signed-off-by: zhichao-aws
* rename Signed-off-by: zhichao-aws
* UT for utils Signed-off-by: zhichao-aws
* enrich sparse encoding IT mappings Signed-off-by: zhichao-aws
* add it Signed-off-by: zhichao-aws
* add it Signed-off-by: zhichao-aws
* add integ test Signed-off-by: zhichao-aws
* rename resource file Signed-off-by: zhichao-aws
* tidy Signed-off-by: zhichao-aws
* remove BoundedLinearQuery and TokenScoreUpperBound Signed-off-by: zhichao-aws
* tidy Signed-off-by: zhichao-aws
* add delta to loose the equal Signed-off-by: zhichao-aws
* move SparseEncodingQueryBuilder to upper level path Signed-off-by: zhichao-aws
* tidy Signed-off-by: zhichao-aws
* add it Signed-off-by: zhichao-aws
* Update src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java Co-authored-by: zane-neo Signed-off-by: zhichao-aws
* Update src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java Co-authored-by: zane-neo Signed-off-by: zhichao-aws
* restore gradle.propeties Signed-off-by: zhichao-aws
* add release notes Signed-off-by: zhichao-aws
* change field modifier to private for NLPProcessor Signed-off-by: zhichao-aws
* add comments Signed-off-by: zhichao-aws
* use StringUtils to check Signed-off-by: zhichao-aws
* null check Signed-off-by: zhichao-aws
* modify changelog Signed-off-by: zhichao-aws
* nit Signed-off-by: zhichao-aws
* nit Signed-off-by: zhichao-aws
* remove query tokens from user interface Signed-off-by: zhichao-aws
* fix test Signed-off-by: zhichao-aws
* tidy Signed-off-by: zhichao-aws
* update function name Signed-off-by: zhichao-aws
* add javadoc Signed-off-by: zhichao-aws
* remove debug log including inference result Signed-off-by: zhichao-aws
* make query text and model id required Signed-off-by: zhichao-aws
* minor changes based on comments Signed-off-by: zhichao-aws
* add locale to String.format Signed-off-by: zhichao-aws
* update mock model url Signed-off-by: zhichao-aws
---------
Signed-off-by: zhichao-aws
Signed-off-by: zane-neo
Signed-off-by: xinyual
Co-authored-by: zane-neo
Co-authored-by: xinyual
---
 CHANGELOG.md | 1 +
 build.gradle | 2 +
 .../ml/MLCommonsClientAccessor.java | 55 ++-
 .../neuralsearch/plugin/NeuralSearch.java | 17 +-
 .../neuralsearch/processor/NLPProcessor.java | 325 ++++++++++++++
 .../processor/SparseEncodingProcessor.java | 53 +++
 .../processor/TextEmbeddingProcessor.java | 303 +------------
 .../SparseEncodingProcessorFactory.java | 46 ++
 .../TextEmbeddingProcessorFactory.java | 3 +
 .../query/SparseEncodingQueryBuilder.java | 275 ++++++++++++
 .../neuralsearch/util/TokenWeightUtil.java | 78 ++++
 .../opensearch/neuralsearch/TestUtils.java | 17 +
 .../common/BaseNeuralSearchIT.java | 16 +-
 .../common/BaseSparseEncodingIT.java | 139 ++++++
 .../ml/MLCommonsClientAccessorTests.java | 127 ++++++
 .../processor/SparseEncodingProcessIT.java | 94 ++++
 .../SparseEncodingProcessorTests.java | 167 ++++++++
 .../TextEmbeddingProcessorTests.java | 18 +-
 .../SparseEncodingQueryBuilderTests.java | 404 ++++++++++++++++++
 .../query/SparseEncodingQueryIT.java | 286 +++++++++++++
 .../util/TokenWeightUtilTests.java | 108 +++++
 .../SparseEncodingIndexMappings.json | 26 ++
 .../SparseEncodingPipelineConfiguration.json | 18 +
 .../UploadSparseEncodingModelRequestBody.json | 10 +
 24 files changed, 2273 insertions(+), 315 deletions(-)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java
 create mode 100644 src/test/resources/processor/SparseEncodingIndexMappings.json
 create mode 100644 src/test/resources/processor/SparseEncodingPipelineConfiguration.json
 create mode 100644 src/test/resources/processor/UploadSparseEncodingModelRequestBody.json

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1b3ea389e..da2ae9ec9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,6 +14,7 @@ The format is based on [Keep a
Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x) ### Features +Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/build.gradle b/build.gradle index 1d8eca483..853aa85e7 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,8 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" + runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' + runtimeOnly group: 'org.json', name: 'json', version: '20230227' } // In order to add the jar to the classpath, we need to unzip the diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 768584ec9..6f8b790bb 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import lombok.NonNull; @@ -15,6 +16,7 @@ import lombok.extern.log4j.Log4j2; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; @@ -100,10 +102,38 @@ public void inferenceSentences( @NonNull final List inputText, @NonNull final ActionListener>> listener ) { - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, 0, listener); } - private void inferenceSentencesWithRetry( + public void inferenceSentencesWithMapResult( + @NonNull final String modelId, + @NonNull final List inputText, + @NonNull final ActionListener>> listener + ) { + retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); + } + + private void retryableInferenceSentencesWithMapResult( + final String modelId, + final List inputText, + final int retryTime, + final ActionListener>> listener + ) { + MLInput mlInput = createMLInput(null, inputText); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> result = buildMapResultFromResponse(mlOutput); + listener.onResponse(result); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + retryableInferenceSentencesWithMapResult(modelId, inputText, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private void retryableInferenceSentencesWithVectorResult( final List targetResponseFilters, final String modelId, final List inputText, @@ -113,12 +143,11 @@ private void inferenceSentencesWithRetry( MLInput mlInput = createMLInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); - log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); 
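+ // Intentionally no debug logging of the inference response here; the log that included inference results was removed on purpose (see the "remove debug log including inference result" commit).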
listener.onResponse(vector); }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { final int retryTimeAdd = retryTime + 1; - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); + retryableInferenceSentencesWithVectorResult(targetResponseFilters, modelId, inputText, retryTimeAdd, listener); } else { listener.onFailure(e); } @@ -144,4 +173,22 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } + private List> buildMapResultFromResponse(MLOutput mlOutput) { + final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; + final List tensorOutputList = modelTensorOutput.getMlModelOutputs(); + if (CollectionUtils.isEmpty(tensorOutputList) || CollectionUtils.isEmpty(tensorOutputList.get(0).getMlModelTensors())) { + throw new IllegalStateException( + "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]" + ); + } + List> resultMaps = new ArrayList<>(); + for (ModelTensors tensors : tensorOutputList) { + List tensorList = tensors.getMlModelTensors(); + for (ModelTensor tensor : tensorList) { + resultMaps.add(tensor.getDataAsMap()); + } + } + return resultMaps; + } + } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index e94a2957d..2ac8853e4 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -9,7 +9,6 @@ import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -31,15 +30,18 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ExtensiblePlugin; @@ -62,7 +64,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private MLCommonsClientAccessor clientAccessor; private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); - private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory();; + private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); @Override public Collection 
createComponents( @@ -79,6 +81,7 @@ public Collection createComponents( final Supplier repositoriesServiceSupplier ) { NeuralQueryBuilder.initialize(clientAccessor); + SparseEncodingQueryBuilder.initialize(clientAccessor); normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()); return List.of(clientAccessor); } @@ -87,14 +90,20 @@ public Collection createComponents( public List> getQueries() { return Arrays.asList( new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent), - new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent) + new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent), + new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent) ); } @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); - return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); + return Map.of( + TextEmbeddingProcessor.TYPE, + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + SparseEncodingProcessor.TYPE, + new SparseEncodingProcessorFactory(clientAccessor, parameters.env) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java new file mode 100644 index 000000000..4ac63d419 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java @@ -0,0 +1,325 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; + +/** + * The abstract class for text processing use cases. Users provide a field name map and a model id. + * During ingestion, the processor will use the corresponding model to run inference on the input texts, + * and set the target fields according to the field name map. + */ +@Log4j2 +public abstract class NLPProcessor extends AbstractProcessor { + + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + + private final String type; + + // This field is used for nested knn_vector/rank_features field. The value of the field will be used as the + // default key for the nested object.
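+ // For example (illustrative): with listTypeNestedMapKey "knn", a two-element list input is written back as [{"knn": result1}, {"knn": result2}] by buildNLPResultForListType below.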
+ private final String listTypeNestedMapKey; + + protected final String modelId; + + private final Map fieldMap; + + protected final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + public NLPProcessor( + String tag, + String description, + String type, + String listTypeNestedMapKey, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + this.type = type; + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); + validateEmbeddingConfiguration(fieldMap); + + this.listTypeNestedMapKey = listTypeNestedMapKey; + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.size() == 0 + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + )) { + throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); + } + } + + public abstract void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ); + + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to the processor. + * @param handler {@link BiConsumer} the handler that can be used after the inference task is done.
+ */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + try { + validateEmbeddingFieldsValue(ingestDocument); + Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(ProcessMap); + if (inferenceList.size() == 0) { + handler.accept(ingestDocument, null); + } else { + doExecute(ingestDocument, ProcessMap, inferenceList, handler); + } + } catch (Exception e) { + handler.accept(null, e); + } + } + + @SuppressWarnings({ "unchecked" }) + private List createInferenceList(Map knnKeyMap) { + List texts = new ArrayList<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else if (sourceValue instanceof Map) { + createInferenceListForMapTypeInput(sourceValue, texts); + } else { + texts.add(sourceValue.toString()); + } + }); + return texts; + } + + @SuppressWarnings("unchecked") + private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { + if (sourceValue instanceof Map) { + ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); + } else if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else { + if (sourceValue == null) return; + texts.add(sourceValue.toString()); + } + } + + @VisibleForTesting + Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithProcessorKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); + Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + Map treeRes = new LinkedHashMap<>(); + buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); + } else { + mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + } + } + return mapWithProcessorKeys; + } + + private void buildMapWithProcessorKeyAndOriginalValueForMapType( + String parentKey, + Object processorKey, + Map sourceAndMetadataMap, + Map treeRes + ) { + if (processorKey == null || sourceAndMetadataMap == null) return; + if (processorKey instanceof Map) { + Map next = new LinkedHashMap<>(); + for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { + buildMapWithProcessorKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + treeRes.put(parentKey, next); + } else { + String key = String.valueOf(processorKey); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + } + } + + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if 
(!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); + } + } + } + + protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { + Objects.requireNonNull(results, "embedding failed, inference returns null result!"); + log.debug("Model inference result fetched, starting build vector output!"); + Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); + nlpResult.forEach(ingestDocument::setFieldValue); + } + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { + NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0); + Map result = new LinkedHashMap<>(); + for (Map.Entry knnMapEntry : processorMap.entrySet()) { + String knnKey = knnMapEntry.getKey(); + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof String) { + result.put(knnKey, results.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + result.put(knnKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); + } else if (sourceValue instanceof Map) { + putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); + } + } + return result; + } + + @SuppressWarnings({ "unchecked" }) + private void putNLPResultToSourceMapForMapType( + String processorKey, + Object sourceValue, + List results, + NLPProcessor.IndexWrapper indexWrapper, + Map sourceAndMetadataMap + ) 
{ + if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + putNLPResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + results, + indexWrapper, + (Map) sourceAndMetadataMap.get(processorKey) + ); + } + } else if (sourceValue instanceof String) { + sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); + } + } + + private List> buildNLPResultForListType( + List sourceValue, + List results, + NLPProcessor.IndexWrapper indexWrapper + ) { + List> keyToResult = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + return keyToResult; + } + + @Override + public String getType() { + return type; + } + + /** + * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, + * we need to map the result back to the input one by one in exactly the same order. For nested map type input, we're performing a pre-order + * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the + * order. We also need to ensure the index pointer moves forward during the recursion, so the IndexWrapper is here to store and increment + * the index pointer during the recursion. + * index: the index pointer of the text embedding result. + */ + static class IndexWrapper { + private int index; + + protected IndexWrapper(int index) { + this.index = index; + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java new file mode 100644 index 000000000..275117809 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -0,0 +1,53 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import lombok.extern.log4j.Log4j2; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.TokenWeightUtil; + +/** + * This processor is used for sparse encoding of user input text. model_id can be used to indicate which model the user uses, + * and field_map can be used to indicate which fields need sparse encoding and the corresponding keys for the sparse encoding results.
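+ * An illustrative ingest pipeline definition using this processor (model id and field names are hypothetical): + * { "processors": [ { "sparse_encoding": { "model_id": "<sparse model id>", "field_map": { "passage_text": "passage_embedding" } } } ] }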
*/ +@Log4j2 +public final class SparseEncodingProcessor extends NLPProcessor { + + public static final String TYPE = "sparse_encoding"; + public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding"; + + public SparseEncodingProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); + } + + @Override + public void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ) { + mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, ActionListener.wrap(resultMaps -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index 878a410a8..1df60baea 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -5,49 +5,26 @@ package org.opensearch.neuralsearch.processor; -import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.Objects; import java.util.function.BiConsumer; -import java.util.function.Supplier; -import java.util.stream.IntStream; import lombok.extern.log4j.Log4j2; -import org.apache.commons.lang3.StringUtils; import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; - /** * This processor is used for user input data text embedding processing, model_id can be used to indicate which model the user uses, * and field_map can be used to indicate which fields need text embedding and the corresponding keys for the text embedding results.
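* An illustrative pipeline definition using this processor (model id and field names are hypothetical): * { "processors": [ { "text_embedding": { "model_id": "<dense model id>", "field_map": { "passage_text": "passage_embedding" } } } ] }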
*/ @Log4j2 -public class TextEmbeddingProcessor extends AbstractProcessor { +public final class TextEmbeddingProcessor extends NLPProcessor { public static final String TYPE = "text_embedding"; - public static final String MODEL_ID_FIELD = "model_id"; - public static final String FIELD_MAP_FIELD = "field_map"; - - private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - - @VisibleForTesting - private final String modelId; - - private final Map fieldMap; - - private final MLCommonsClientAccessor mlCommonsClientAccessor; - - private final Environment environment; + public static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; public TextEmbeddingProcessor( String tag, @@ -57,275 +34,19 @@ public TextEmbeddingProcessor( MLCommonsClientAccessor clientAccessor, Environment environment ) { - super(tag, description); - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); - validateEmbeddingConfiguration(fieldMap); - - this.modelId = modelId; - this.fieldMap = fieldMap; - this.mlCommonsClientAccessor = clientAccessor; - this.environment = environment; - } - - private void validateEmbeddingConfiguration(Map fieldMap) { - if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() - .stream() - .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) - )) { - throw new IllegalArgumentException("Unable to create the TextEmbedding processor as field_map has invalid key or value"); - } - } - - @Override - public IngestDocument execute(IngestDocument ingestDocument) { - return ingestDocument; + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); } - /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to - * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. - * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. - */ @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { - // When received a bulk indexing request, the pipeline will be executed in this method, (see - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). - // Before the pipeline execution, the pipeline will be marked as resolved (means executed), - // and then this overriding method will be invoked when executing the text embedding processor. - // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. 
- try { - validateEmbeddingFieldsValue(ingestDocument); - Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(knnMap); - if (inferenceList.size() == 0) { - handler.accept(ingestDocument, null); - } else { - mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, knnMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); - } - } catch (Exception e) { - handler.accept(null, e); - } - - } - - void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { - Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); - log.debug("Text embedding result fetched, starting build vector output!"); - Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); - textEmbeddingResult.forEach(ingestDocument::setFieldValue); - } - - @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { - List texts = new ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); - } - }); - return texts; - } - - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); - } - } - - @VisibleForTesting - Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithKnnKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - Map treeRes = new LinkedHashMap<>(); - buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); - } else { - mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); - } - } - return mapWithKnnKeys; - } - - @SuppressWarnings({ "unchecked" }) - private void buildMapWithKnnKeyAndOriginalValueForMapType( - String parentKey, - Object knnKey, - Map sourceAndMetadataMap, - Map treeRes + public void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler ) { - if (knnKey == null || sourceAndMetadataMap == null) return; - if (knnKey instanceof Map) { - Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) knnKey).entrySet()) { - buildMapWithKnnKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); - } - treeRes.put(parentKey, next); - } else { - String key = String.valueOf(knnKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); - } + 
mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); } - - @SuppressWarnings({ "unchecked" }) - @VisibleForTesting - Map buildTextEmbeddingResult( - Map knnMap, - List> modelTensorList, - Map sourceAndMetadataMap - ) { - IndexWrapper indexWrapper = new IndexWrapper(0); - Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : knnMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof String) { - List modelTensor = modelTensorList.get(indexWrapper.index++); - result.put(knnKey, modelTensor); - } else if (sourceValue instanceof List) { - result.put(knnKey, buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper)); - } else if (sourceValue instanceof Map) { - putTextEmbeddingResultToSourceMapForMapType(knnKey, sourceValue, modelTensorList, indexWrapper, sourceAndMetadataMap); - } - } - return result; - } - - @SuppressWarnings({ "unchecked" }) - private void putTextEmbeddingResultToSourceMapForMapType( - String knnKey, - Object sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putTextEmbeddingResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - modelTensorList, - indexWrapper, - (Map) sourceAndMetadataMap.get(knnKey) - ); - } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put( - knnKey, - buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) - ); - } - } - - private List>> buildTextEmbeddingResultForListType( - List sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper - ) { - List>> numbers = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); - return numbers; - } - - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); - } - } - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = 
maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { - if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); - } - } - } - - @Override - public String getType() { - return TYPE; - } - - /** - * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, - * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order - * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the - * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase - * the index pointer during the recursive. - * index: the index pointer of the text embedding result. - */ - static class IndexWrapper { - private int index; - - protected IndexWrapper(int index) { - this.index = index; - } - } - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java new file mode 100644 index 000000000..104418ec5 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.*; + +import java.util.Map; + +import lombok.extern.log4j.Log4j2; + +import org.opensearch.env.Environment; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; + +/** + * Factory for sparse encoding ingest processor for ingestion pipeline. 
Instantiates processor based on user-provided input. + */ +@Log4j2 +public class SparseEncodingProcessorFactory implements Processor.Factory { + private final MLCommonsClientAccessor clientAccessor; + private final Environment environment; + + public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public SparseEncodingProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); + Map fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + + return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index f805b29e1..0c9a6fa2c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -16,6 +16,9 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +/** + * Factory for text embedding ingest processor for ingestion pipeline. Instantiates processor based on user-provided input. + */ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java new file mode 100644 index 000000000..4b8b6f0d4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java @@ -0,0 +1,275 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.function.Supplier; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.Setter; +import lombok.experimental.Accessors; +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang.StringUtils; +import org.apache.commons.lang.builder.EqualsBuilder; +import org.apache.commons.lang.builder.HashCodeBuilder; +import org.apache.lucene.document.FeatureField; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Query; +import org.opensearch.common.SetOnce; +import org.opensearch.core.ParseField; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.mapper.MappedFieldType; +import org.opensearch.index.query.AbstractQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; +import
org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.TokenWeightUtil; + +import com.google.common.annotations.VisibleForTesting; + +/** + * SparseEncodingQueryBuilder is responsible for handling "sparse_encoding" query types. It uses an ML SPARSE_ENCODING model + * or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. This map is then transformed + * into Lucene FeatureQuery clauses wrapped in a Lucene BooleanQuery. + */ + +@Log4j2 +@Getter +@Setter +@Accessors(chain = true, fluent = true) +@NoArgsConstructor +@AllArgsConstructor +public class SparseEncodingQueryBuilder extends AbstractQueryBuilder { + public static final String NAME = "sparse_encoding"; + @VisibleForTesting + static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); + @VisibleForTesting + static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); + + private static MLCommonsClientAccessor ML_CLIENT; + + public static void initialize(MLCommonsClientAccessor mlClient) { + SparseEncodingQueryBuilder.ML_CLIENT = mlClient; + } + + private String fieldName; + private String queryText; + private String modelId; + private Supplier> queryTokensSupplier; + + /** + * Constructor from stream input + * + * @param in StreamInput to initialize object from + * @throws IOException thrown if unable to read from input stream + */ + public SparseEncodingQueryBuilder(StreamInput in) throws IOException { + super(in); + this.fieldName = in.readString(); + this.queryText = in.readString(); + this.modelId = in.readString(); + } + + @Override + protected void doWriteTo(StreamOutput out) throws IOException { + out.writeString(fieldName); + out.writeString(queryText); + out.writeString(modelId); + } + + @Override + protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException { + xContentBuilder.startObject(NAME); + xContentBuilder.startObject(fieldName); + xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); + xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + printBoostAndQueryName(xContentBuilder); + xContentBuilder.endObject(); + xContentBuilder.endObject(); + } + + /** + * The expected parsing form looks like: + * "SAMPLE_FIELD": { + * "query_text": "string", + * "model_id": "string" + * } + * + * @param parser XContentParser + * @return SparseEncodingQueryBuilder + * @throws IOException can be thrown by parser + */ + public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder(); + if (parser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new ParsingException(parser.getTokenLocation(), "First token of " + NAME + " query must be START_OBJECT"); + } + parser.nextToken(); + sparseEncodingQueryBuilder.fieldName(parser.currentName()); + parser.nextToken(); + parseQueryParams(parser, sparseEncodingQueryBuilder); + if (parser.nextToken() != XContentParser.Token.END_OBJECT) { + throw new ParsingException( + parser.getTokenLocation(), + String.format( + Locale.ROOT, + "[%s] query doesn't support multiple fields, found [%s] and [%s]", + NAME, + sparseEncodingQueryBuilder.fieldName(), + parser.currentName() + ) + ); + } + + requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query"); + requireValue( + sparseEncodingQueryBuilder.queryText(), + String.format(Locale.ROOT, "%s field must be provided for [%s] query",
QUERY_TEXT_FIELD.getPreferredName(), NAME) + ); + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) + ); + + return sparseEncodingQueryBuilder; + } + + private static void parseQueryParams(XContentParser parser, SparseEncodingQueryBuilder sparseEncodingQueryBuilder) throws IOException { + XContentParser.Token token; + String currentFieldName = ""; + while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + currentFieldName = parser.currentName(); + } else if (token.isValue()) { + if (NAME_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseEncodingQueryBuilder.queryName(parser.text()); + } else if (BOOST_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseEncodingQueryBuilder.boost(parser.floatValue()); + } else if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseEncodingQueryBuilder.queryText(parser.text()); + } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + sparseEncodingQueryBuilder.modelId(parser.text()); + } else { + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName) + ); + } + } else { + throw new ParsingException( + parser.getTokenLocation(), + String.format(Locale.ROOT, "[%s] unknown token [%s] after [%s]", NAME, token, currentFieldName) + ); + } + } + } + + @Override + protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws IOException { + // We need to run inference on the sentence to get the queryTokens. The logic is similar to NeuralQueryBuilder. + // If the inference is finished, rewrite to self and call doToQuery; otherwise, continue doRewrite + if (null != queryTokensSupplier) { + return this; + } + + validateForRewrite(queryText, modelId); + SetOnce> queryTokensSetOnce = new SetOnce<>(); + queryRewriteContext.registerAsyncAction( + ((client, actionListener) -> ML_CLIENT.inferenceSentencesWithMapResult( + modelId(), + List.of(queryText), + ActionListener.wrap(mapResultList -> { + queryTokensSetOnce.set(TokenWeightUtil.fetchListOfTokenWeightMap(mapResultList).get(0)); + actionListener.onResponse(null); + }, actionListener::onFailure) + )) + ); + return new SparseEncodingQueryBuilder().fieldName(fieldName) + .queryText(queryText) + .modelId(modelId) + .queryTokensSupplier(queryTokensSetOnce::get); + } + + @Override + protected Query doToQuery(QueryShardContext context) throws IOException { + final MappedFieldType ft = context.fieldMapper(fieldName); + validateFieldType(ft); + + Map queryTokens = queryTokensSupplier.get(); + validateQueryTokens(queryTokens); + + BooleanQuery.Builder builder = new BooleanQuery.Builder(); + for (Map.Entry entry : queryTokens.entrySet()) { + builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD); + } + return builder.build(); + } + + private static void validateForRewrite(String queryText, String modelId) { + if (StringUtils.isBlank(queryText) || StringUtils.isBlank(modelId)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "%s and %s cannot be null", + QUERY_TEXT_FIELD.getPreferredName(), + MODEL_ID_FIELD.getPreferredName() + ) + ); + } + } + + private static void validateFieldType(MappedFieldType fieldType) { + if (null == fieldType ||
!fieldType.typeName().equals("rank_features")) { + throw new IllegalArgumentException("[" + NAME + "] query only works on [rank_features] fields"); + } + } + + private static void validateQueryTokens(Map queryTokens) { + if (null == queryTokens) { + throw new IllegalArgumentException("Query tokens cannot be null."); + } + for (Map.Entry entry : queryTokens.entrySet()) { + if (entry.getValue() <= 0) { + throw new IllegalArgumentException( + "Feature weight must be larger than 0, feature [" + entry.getKey() + "] has non-positive weight [" + entry.getValue() + "]." + ); + } + } + } + + @Override + protected boolean doEquals(SparseEncodingQueryBuilder obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) + .append(queryText, obj.queryText) + .append(modelId, obj.modelId); + return equalsBuilder.isEquals(); + } + + @Override + protected int doHashCode() { + return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode(); + } + + @Override + public String getWriteableName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java new file mode 100644 index 000000000..76ce0fa16 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * Utility class for working with sparse_encoding queries and ingest processor.
+ * Used to fetch the (token, weight) Map from the response returned by {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} + */ + +public class TokenWeightUtil { + public static final String RESPONSE_KEY = "response"; + + /** + * Possible input data formats: + * case remote inference: + * [{ + * "response": [ + * { TOKEN_WEIGHT_MAP }, + * { TOKEN_WEIGHT_MAP } + * ] + * }] + * case local deploy: + * [{ + * "response": [ + * { TOKEN_WEIGHT_MAP } + * ] + * },{ + * "response": [ + * { TOKEN_WEIGHT_MAP } + * ] + * }] + * + * @param mapResultList a list of {@link Map}s, the response from {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor} + * @return a list of (token, weight) maps, one for each inference input + */ + public static List<Map<String, Float>> fetchListOfTokenWeightMap(List<Map<String, ?>> mapResultList) { + if (null == mapResultList || mapResultList.isEmpty()) { + throw new IllegalArgumentException("The inference result cannot be null or empty."); + } + List<Object> results = new ArrayList<>(); + for (Map<String, ?> map : mapResultList) { + if (!map.containsKey(RESPONSE_KEY)) { + throw new IllegalArgumentException("The inference result should be associated with the field [" + RESPONSE_KEY + "]."); + } + if (!List.class.isAssignableFrom(map.get(RESPONSE_KEY).getClass())) { + throw new IllegalArgumentException("The data object associated with field [" + RESPONSE_KEY + "] should be a list."); + } + results.addAll((List<?>) map.get(RESPONSE_KEY)); + } + return results.stream().map(TokenWeightUtil::buildTokenWeightMap).collect(Collectors.toList()); + } + + private static Map<String, Float> buildTokenWeightMap(Object uncastedMap) { + if (!Map.class.isAssignableFrom(uncastedMap.getClass())) { + throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); + } + Map<String, Float> result = new HashMap<>(); + for (Map.Entry<?, ?> entry : ((Map<?, ?>) uncastedMap).entrySet()) { + if (!String.class.isAssignableFrom(entry.getKey().getClass()) || !Number.class.isAssignableFrom(entry.getValue().getClass())) { + throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values."); + } + result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue()); + } + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java index 3b131b886..385855a2e 100644 --- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java +++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java @@ -13,6 +13,8 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -72,6 +74,21 @@ public static float[] createRandomVector(int dimension) { return vector; } + /** + * Create a map keyed by the provided tokens, with random positive float values + * + * @param tokens tokens to use as the map keys + * @return token weight map with random weights > 0 + */ + public static Map<String, Float> createRandomTokenWeightMap(Collection<String> tokens) { + Map<String, Float> resultMap = new HashMap<>(); + for (String token : tokens) { + // use a small shift to ensure value > 0 + resultMap.put(token, Math.abs(randomFloat()) + 1e-3f); + } + return resultMap; + } + /** * Assert results of hybrid query after score normalization and combination * @param querySearchResults collection of query search results after they processed by normalization processor diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index b144ade6c..84672d479 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -54,18 +54,24 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { - private static final Locale LOCALE = Locale.ROOT; + protected static final Locale LOCALE = Locale.ROOT; private static final int MAX_TASK_RESULT_QUERY_TIME_IN_SECOND = 60 * 5; private static final int DEFAULT_TASK_RESULT_QUERY_INTERVAL_IN_MILLISECOND = 1000; - private static final String DEFAULT_USER_AGENT = "Kibana"; + protected static final String DEFAULT_USER_AGENT = "Kibana"; protected static final String DEFAULT_NORMALIZATION_METHOD = "min_max"; protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; + protected String PIPELINE_CONFIGURATION_NAME = "processor/PipelineConfiguration.json"; + protected final ClassLoader classLoader = this.getClass().getClassLoader(); + protected void setPipelineConfigurationName(String pipelineConfigurationName) { + this.PIPELINE_CONFIGURATION_NAME = pipelineConfigurationName; + } + @Before public void setupSettings() { if (isUpdateClusterSettings()) { @@ -237,11 +243,7 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro "/_ingest/pipeline/" + pipelineName, null, toHttpEntity( - String.format( - LOCALE, - Files.readString(Path.of(classLoader.getResource("processor/PipelineConfiguration.json").toURI())), - modelId - ) + String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGURATION_NAME).toURI())), modelId) ), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) ); diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java new file mode 100644 index 000000000..d0231cfe6 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseSparseEncodingIT.java @@ -0,0 +1,139 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.common; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.opensearch.client.Request; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.neuralsearch.util.TokenWeightUtil; + +import com.google.common.collect.ImmutableList; + +public abstract class BaseSparseEncodingIT extends BaseNeuralSearchIT { + + @SneakyThrows + @Override + protected String prepareModel() { + String requestBody = Files.readString( + Path.of(classLoader.getResource("processor/UploadSparseEncodingModelRequestBody.json").toURI()) + ); + String modelId = uploadModel(requestBody); + loadModel(modelId); + return modelId; + } + + @SneakyThrows + protected void prepareSparseEncodingIndex(String indexName, 
List<String> sparseEncodingFieldNames) { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder().startObject().startObject("mappings").startObject("properties"); + + for (String fieldName : sparseEncodingFieldNames) { + xContentBuilder.startObject(fieldName).field("type", "rank_features").endObject(); + } + + xContentBuilder.endObject().endObject().endObject(); + String indexMappings = xContentBuilder.toString(); + createIndexWithConfiguration(indexName, indexMappings, ""); + } + + @SneakyThrows + protected void addSparseEncodingDoc(String index, String docId, List<String> fieldNames, List<Map<String, Float>> docs) { + addSparseEncodingDoc(index, docId, fieldNames, docs, Collections.emptyList(), Collections.emptyList()); + } + + @SneakyThrows + protected void addSparseEncodingDoc( + String index, + String docId, + List<String> fieldNames, + List<Map<String, Float>> docs, + List<String> textFieldNames, + List<String> texts + ) { + Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + for (int i = 0; i < fieldNames.size(); i++) { + builder.field(fieldNames.get(i), docs.get(i)); + } + + for (int i = 0; i < textFieldNames.size(); i++) { + builder.field(textFieldNames.get(i), texts.get(i)); + } + builder.endObject(); + + request.setJsonEntity(builder.toString()); + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.CREATED, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + } + + protected float computeExpectedScore(String modelId, Map<String, Float> tokenWeightMap, String queryText) { + Map<String, Float> queryTokens = runSparseModelInference(modelId, queryText); + return computeExpectedScore(tokenWeightMap, queryTokens); + } + + protected float computeExpectedScore(Map<String, Float> tokenWeightMap, Map<String, Float> queryTokens) { + float score = 0f; + for (Map.Entry<String, Float> entry : queryTokens.entrySet()) { + if (tokenWeightMap.containsKey(entry.getKey())) { + score += entry.getValue() * getFeatureFieldCompressedNumber(tokenWeightMap.get(entry.getKey())); + } + } + return score; + } + + @SneakyThrows + @SuppressWarnings("unchecked") + protected Map<String, Float> runSparseModelInference(String modelId, String queryText) { + Response inferenceResponse = makeRequest( + client(), + "POST", + String.format(LOCALE, "/_plugins/_ml/models/%s/_predict", modelId), + null, + toHttpEntity(String.format(LOCALE, "{\"text_docs\": [\"%s\"]}", queryText)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + + Map<String, Object> inferenceResJson = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(inferenceResponse.getEntity()), + false + ); + + Object inferenceResults = inferenceResJson.get("inference_results"); + assertTrue(inferenceResults instanceof List); + List<Object> inferenceResultsAsMap = (List<Object>) inferenceResults; + assertEquals(1, inferenceResultsAsMap.size()); + Map<String, Object> result = (Map<String, Object>) inferenceResultsAsMap.get(0); + List<Object> output = (List<Object>) result.get("output"); + assertEquals(1, output.size()); + Map<String, Object> map = (Map<String, Object>) output.get(0); + assertEquals(1, map.size()); + Map<String, Object> dataAsMap = (Map<String, Object>) map.get("dataAsMap"); + return TokenWeightUtil.fetchListOfTokenWeightMap(List.of(dataAsMap)).get(0); + } + + // rank_features uses Lucene FeatureField, which compresses the float weight to 16 bits; + // this function simulates the encoding and decoding process in Lucene FeatureField + protected Float getFeatureFieldCompressedNumber(Float originNumber) { + int freqBits = Float.floatToIntBits(originNumber); + // drop the lower 15 bits, keeping only the upper bits that FeatureField stores for a positive float + freqBits = freqBits >> 15; + freqBits = ((int) ((float) freqBits))
<< 15; + return Float.intBitsToFloat(freqBits); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 3ef5431b3..295daa948 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -16,6 +16,7 @@ import java.util.Map; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.InjectMocks; import org.mockito.Mock; import org.mockito.Mockito; @@ -161,6 +162,122 @@ public void testInferenceSentences_whenNotConnectionException_thenNoRetry() { Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceSentencesWithMapResult_whenValidInput_thenSuccess() { + final Map map = Map.of("key", "value"); + final ActionListener>> resultListener = mock(ActionListener.class); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(map)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(List.of(map)); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void testInferenceSentencesWithMapResult_whenTensorOutputListEmpty_thenException() { + final ActionListener>> resultListener = mock(ActionListener.class); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(Collections.emptyList()); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Empty model result produced. 
Expected at least [1] tensor output and [1] model tensor, but got [0]", + argumentCaptor.getValue().getMessage() + ); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void testInferenceSentencesWithMapResult_whenModelTensorListEmpty_thenException() { + final ActionListener>> resultListener = mock(ActionListener.class); + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + tensorsList.add(new ModelTensors(mlModelTensorList)); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(IllegalStateException.class); + Mockito.verify(resultListener).onFailure(argumentCaptor.capture()); + assertEquals( + "Empty model result produced. Expected at least [1] tensor output and [1] model tensor, but got [0]", + argumentCaptor.getValue().getMessage() + ); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void testInferenceSentencesWithMapResult_whenModelTensorListSizeBiggerThan1_thenSuccess() { + final ActionListener>> resultListener = mock(ActionListener.class); + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor("response", null, null, null, null, null, Map.of("key", "value")); + mlModelTensorList.add(tensor); + mlModelTensorList.add(tensor); + tensorsList.add(new ModelTensors(mlModelTensorList)); + final ModelTensorOutput modelTensorOutput = new ModelTensorOutput(tensorsList); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(modelTensorOutput); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onResponse(List.of(Map.of("key", "value"), Map.of("key", "value"))); + Mockito.verifyNoMoreInteractions(resultListener); + } + + public void testInferenceSentencesWithMapResult_whenRetryableException_retry3Times() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + final ActionListener>> resultListener = mock(ActionListener.class); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, 
TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(nodeNodeConnectedException); + } + + public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFail() { + final IllegalStateException illegalStateException = new IllegalStateException("Illegal state"); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(illegalStateException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + final ActionListener>> resultListener = mock(ActionListener.class); + accessor.inferenceSentencesWithMapResult(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_LIST, resultListener); + + Mockito.verify(client, times(1)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(resultListener).onFailure(illegalStateException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); @@ -178,4 +295,14 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) { tensorsList.add(modelTensors); return new ModelTensorOutput(tensorsList); } + + private ModelTensorOutput createModelTensorOutput(final Map map) { + final List tensorsList = new ArrayList<>(); + final List mlModelTensorList = new ArrayList<>(); + final ModelTensor tensor = new ModelTensor("response", null, null, null, null, null, map); + mlModelTensorList.add(tensor); + final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); + tensorsList.add(modelTensors); + return new ModelTensorOutput(tensorsList); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java new file mode 100644 index 000000000..0312eaef7 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.apache.hc.core5.http.HttpHeaders; +import org.apache.hc.core5.http.io.entity.EntityUtils; +import org.apache.hc.core5.http.message.BasicHeader; +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; + +import com.google.common.collect.ImmutableList; + +public class SparseEncodingProcessIT extends BaseSparseEncodingIT { + + private static final String INDEX_NAME = "sparse_encoding_index"; + + private static final String PIPELINE_NAME = "pipeline-sparse-encoding"; + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + /* this is required to minimize chance of model not being deployed due to open memory CB, + * this happens in case we leave model from previous test case. 
We use a new model for every test, and the old model + * can be undeployed and deleted to free resources after each test case. + */ + findDeployedModels().forEach(this::deleteModel); + } + + @Before + public void setPipelineName() { + this.setPipelineConfigurationName("processor/SparseEncodingPipelineConfiguration.json"); + } + + public void testSparseEncodingProcessor() throws Exception { + String modelId = prepareModel(); + createPipelineProcessor(modelId, PIPELINE_NAME); + createSparseEncodingIndex(); + ingestDocument(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + private void createSparseEncodingIndex() throws Exception { + createIndexWithConfiguration( + INDEX_NAME, + Files.readString(Path.of(classLoader.getResource("processor/SparseEncodingIndexMappings.json").toURI())), + PIPELINE_NAME + ); + } + + private void ingestDocument() throws Exception { + String ingestDocument = "{\n" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"favor_list\": [\n" + + " \"test\",\n" + + " \"hello\",\n" + + " \"mock\"\n" + + " ],\n" + + " \"favorites\": {\n" + + " \"game\": \"overwatch\",\n" + + " \"movie\": null\n" + + " }\n" + + "}\n"; + Response response = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map<String, Object> map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals("created", map.get("result")); + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java new file mode 100644 index 000000000..209db58a8 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessorTests.java @@ -0,0 +1,167 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.util.*; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class SparseEncodingProcessorTests extends OpenSearchTestCase { + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + private Environment env; + + @InjectMocks + private SparseEncodingProcessorFactory SparseEncodingProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); +
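// stub the environment so the factory can read the index.mapping.depth.limit setting it uses to validate nested field_map configurations +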
when(env.settings()).thenReturn(settings); + } + + @SneakyThrows + private SparseEncodingProcessor createInstance() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(SparseEncodingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(SparseEncodingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + return SparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + SparseEncodingProcessor processor = createInstance(); + + List> dataAsMapList = createMockMapResult(2); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + @SneakyThrows + public void testExecute_whenInferenceTextListEmpty_SuccessWithoutAnyMap() { + Map sourceAndMetadata = new HashMap<>(); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map registry = new HashMap<>(); + MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + SparseEncodingProcessorFactory sparseEncodingProcessorFactory = new SparseEncodingProcessorFactory(accessor, env); + + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, ImmutableMap.of("key1", "key1Mapped", "key2", "key2Mapped")); + SparseEncodingProcessor processor = sparseEncodingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyList(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + public void testExecute_withListTypeInput_successful() { + List list1 = ImmutableList.of("test1", "test2", "test3"); + List list2 = ImmutableList.of("test4", "test5", "test6"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", list1); + sourceAndMetadata.put("key2", list2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + SparseEncodingProcessor processor = createInstance(); + + List> dataAsMapList = createMockMapResult(6); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + 
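// the mocked ML client fails inference below; the processor is expected to forward that exception to the handler instead of an updated document +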
SparseEncodingProcessor processor = createInstance(); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_withMapTypeInput_successful() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test4", "test5"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("key2", map2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + SparseEncodingProcessor processor = createInstance(); + + List> dataAsMapList = createMockMapResult(2); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(dataAsMapList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(anyString(), anyList(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + + } + + private List> createMockMapResult(int number) { + List> mockSparseEncodingResult = new ArrayList<>(); + IntStream.range(0, number).forEachOrdered(x -> mockSparseEncodingResult.add(ImmutableMap.of("hello", 1.0f))); + + List> mockMapResult = Collections.singletonList(Map.of("response", mockSparseEncodingResult)); + return mockMapResult; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index da0a46954..399cd1eb8 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -75,9 +75,9 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapError_throwIllegalA fieldMap.put("key2", "key2Mapped"); config.put(TextEmbeddingProcessor.FIELD_MAP_FIELD, fieldMap); try { - textEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + textEmbeddingProcessorFactory.create(registry, TextEmbeddingProcessor.TYPE, DESCRIPTION, config); } catch (IllegalArgumentException e) { - assertEquals("Unable to create the TextEmbedding processor as field_map has invalid key or value", e.getMessage()); + assertEquals("Unable to create the processor as field_map has invalid key or value", e.getMessage()); } } @@ -309,7 +309,7 @@ private Map createMaxDepthLimitExceedMap(Supplier maxDe return innerMap; } - public void testExecute_hybridTypeInput_successful() { + public void testExecute_hybridTypeInput_successful() throws Exception { List list1 = ImmutableList.of("test1", "test2"); Map> map1 = ImmutableMap.of("test3", list1); Map sourceAndMetadata = new HashMap<>(); @@ -347,7 +347,7 @@ public void testProcessResponse_successful() throws Exception { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = 
processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); @@ -360,7 +360,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); // To assert the order is not changed between config map and generated map. List configValueList = new LinkedList<>(config.values()); @@ -371,7 +371,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { assertEquals(knnKeyList.get(lastIndex), configValueList.get(lastIndex).toString()); List> modelTensorList = createMockVectorResult(); - Map result = processor.buildTextEmbeddingResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + Map result = processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); assertTrue(result.containsKey("oriKey1_knn")); assertTrue(result.containsKey("oriKey2_knn")); assertTrue(result.containsKey("oriKey3_knn")); @@ -386,9 +386,9 @@ public void testBuildVectorOutput_withNestedMap_successful() { Map config = createNestedMapConfiguration(); IngestDocument ingestDocument = createNestedMapIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.buildTextEmbeddingResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); assertNotNull(favoritesMap); Map favoriteGames = (Map) favoritesMap.get("favorite.games"); @@ -402,7 +402,7 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java new file mode 100644 index 000000000..6cb122c4f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java @@ -0,0 +1,404 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD; +import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD; +import static 
org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap; +import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import lombok.SneakyThrows; + +import org.opensearch.client.Client; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.ParsingException; +import org.opensearch.core.common.io.stream.FilterStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableAwareStreamInput; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.test.OpenSearchTestCase; + +public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase { + + private static final String FIELD_NAME = "testField"; + private static final String QUERY_TEXT = "Hello world!"; + private static final String MODEL_ID = "mfgfgdsfgfdgsde"; + private static final float BOOST = 1.8f; + private static final String QUERY_NAME = "queryName"; + private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); + + @SneakyThrows + public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); + assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); + } + + @SneakyThrows + public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "boost": 10.0, + "_name": "something", + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = 
SparseEncodingQueryBuilder.fromXContent(contentParser); + + assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName()); + assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText()); + assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId()); + assertEquals(BOOST, sparseEncodingQueryBuilder.boost(), 0.0); + assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName()); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "model_id": "string", + "boost": 10.0, + "_name": "something", + }, + "invalid": 10 + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(BOOST_FIELD.getPreferredName(), BOOST) + .field(NAME_FIELD.getPreferredName(), QUERY_NAME) + .endObject() + .field("invalid", 10) + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(ParsingException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMissingQuery_thenFail() { + /* + { + "VECTOR_FIELD": { + "model_id": "string" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMissingModelId_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string", + "query_text": "string", + "model_id": "string", + "model_id": "string" + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IOException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser)); + } + + @SuppressWarnings("unchecked") + @SneakyThrows + public void testToXContent() { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + .modelId(MODEL_ID) + .queryText(QUERY_TEXT); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); + Map out = 
xContentBuilderToMap(builder); + + Object outer = out.get(NAME); + if (!(outer instanceof Map)) { + fail("sparse encoding does not map to nested object"); + } + + Map outerMap = (Map) outer; + + assertEquals(1, outerMap.size()); + assertTrue(outerMap.containsKey(FIELD_NAME)); + + Object secondInner = outerMap.get(FIELD_NAME); + if (!(secondInner instanceof Map)) { + fail("field name does not map to nested object"); + } + + Map secondInnerMap = (Map) secondInner; + + assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName())); + assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName())); + } + + @SneakyThrows + public void testStreams() { + SparseEncodingQueryBuilder original = new SparseEncodingQueryBuilder(); + original.fieldName(FIELD_NAME); + original.queryText(QUERY_TEXT); + original.modelId(MODEL_ID); + original.boost(BOOST); + original.queryName(QUERY_NAME); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + SparseEncodingQueryBuilder copy = new SparseEncodingQueryBuilder(filterStreamInput); + assertEquals(original, copy); + } + + public void testHashAndEquals() { + String fieldName1 = "field 1"; + String fieldName2 = "field 2"; + String queryText1 = "query text 1"; + String queryText2 = "query text 2"; + String modelId1 = "model-1"; + String modelId2 = "model-2"; + float boost1 = 1.8f; + float boost2 = 3.8f; + String queryName1 = "query-1"; + String queryName2 = "query-2"; + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except default boost and query name + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName( + fieldName1 + ).queryText(queryText1).modelId(modelId1); + + // Identical to sparseEncodingQueryBuilder_baseline except diff field name + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except diff query text + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText2) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except diff model ID + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId2) + .boost(boost1) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except diff boost + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new 
SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost2) + .queryName(queryName1); + + // Identical to sparseEncodingQueryBuilder_baseline except diff query name + SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .boost(boost1) + .queryName(queryName2); + + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); + assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); + + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baselineCopy); + assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baselineCopy.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_defaultBoostAndQueryName); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_defaultBoostAndQueryName.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffFieldName); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffFieldName.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryText); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryText.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffModelId); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffModelId.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffBoost); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffBoost.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode()); + } + + @SneakyThrows + public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID); + Map expectedMap = Map.of("1", 1f, "2", 2f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(List.of(Map.of("response", List.of(expectedMap)))); + return null; + }).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any()); + SparseEncodingQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set query tokens supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + SparseEncodingQueryBuilder queryBuilder = (SparseEncodingQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext); + 
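// registerAsyncAction is stubbed to run the inference action inline, so by the time doRewrite returns the token supplier is already populated +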
assertNotNull(queryBuilder.queryTokensSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get()); + } + + @SneakyThrows + public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); + QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); + assertTrue(queryBuilder == sparseEncodingQueryBuilder); + + sparseEncodingQueryBuilder.queryTokensSupplier(() -> null); + queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); + assertTrue(queryBuilder == sparseEncodingQueryBuilder); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java new file mode 100644 index 000000000..54991d7e2 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java @@ -0,0 +1,286 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.query; + +import static org.opensearch.neuralsearch.TestUtils.objectToFloat; + +import java.util.List; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.junit.After; +import org.junit.Before; +import org.opensearch.client.ResponseException; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.MatchAllQueryBuilder; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.common.BaseSparseEncodingIT; + +public class SparseEncodingQueryIT extends BaseSparseEncodingIT { + private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index"; + private static final String TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-multi-field-index"; + private static final String TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-text-and-field-index"; + private static final String TEST_NESTED_INDEX_NAME = "test-sparse-nested-index"; + private static final String TEST_QUERY_TEXT = "Hello world a b"; + private static final String TEST_SPARSE_ENCODING_FIELD_NAME_1 = "test-sparse-encoding-1"; + private static final String TEST_SPARSE_ENCODING_FIELD_NAME_2 = "test-sparse-encoding-2"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field"; + private static final String TEST_SPARSE_ENCODING_FIELD_NAME_NESTED = "nested.sparse_encoding.field"; + + private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c"); + + private static final Float DELTA = 1e-5f; + private final Map testRankFeaturesDoc = TestUtils.createRandomTokenWeightMap(TEST_TOKENS); + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + prepareModel(); + } + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + findDeployedModels().forEach(this::deleteModel); + } + + /** + * Tests basic query: + * { + * "query": { + * "sparse_encoding": { + * "text_sparse": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * } + * } + */ + @SneakyThrows + public void testBasicQueryUsingQueryText() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = 
new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + Map<String, Object> searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map<String, Object> firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } + + /** + * Tests basic query with boost: + * { + * "query": { + * "sparse_encoding": { + * "text_sparse": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd", + * "boost": 2 + * } + * } + * } + * } + */ + @SneakyThrows + public void testBoostQuery() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId).boost(2.0f); + Map<String, Object> searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1); + Map<String, Object> firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } + + /** + * Tests rescore query: + * { + * "query" : { + * "match_all": {} + * }, + * "rescore": { + * "query": { + * "rescore_query": { + * "sparse_encoding": { + * "text_sparse": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * } + * } + * } + * } + */ + @SneakyThrows + public void testRescoreQuery() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder(); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + Map<String, Object> searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, matchAllQueryBuilder, sparseEncodingQueryBuilder, 1); + Map<String, Object> firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } + + /** + * Tests bool should query with query text: + * { + * "query": { + * "bool" : { + * "should": [ + * "sparse_encoding": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * }, + * "sparse_encoding": { + * "field2": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBooleanQuery_withMultipleSparseEncodingQueries() { + initializeIndexIfNotExist(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder1 = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + SparseEncodingQueryBuilder sparseEncodingQueryBuilder2 = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_2 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + +
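// both fields are indexed with the same token-weight map, so the bool-should score is expected to be exactly twice the single-field score +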
boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2); + + Map<String, Object> searchResponseAsMap = search(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map<String, Object> firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = 2 * computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA); + } + + /** + * Tests bool should query combining a sparse_encoding clause with a BM25 match clause: + * { + * "query": { + * "bool" : { + * "should": [ + * "sparse_encoding": { + * "field1": { + * "query_text": "Hello world a b", + * "model_id": "dcsdcasd" + * } + * }, + * "match": { + * "field2": { + * "query": "Hello world a b" + * } + * } + * ] + * } + * } + * } + */ + @SneakyThrows + public void testBooleanQuery_withSparseEncodingAndBM25Queries() { + initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName( + TEST_SPARSE_ENCODING_FIELD_NAME_1 + ).queryText(TEST_QUERY_TEXT).modelId(modelId); + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); + boolQueryBuilder.should(sparseEncodingQueryBuilder).should(matchQueryBuilder); + + Map<String, Object> searchResponseAsMap = search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1); + Map<String, Object> firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float minExpectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT); + assertTrue(minExpectedScore < objectToFloat(firstInnerHit.get("_score"))); + } + + @SneakyThrows + public void testBasicQueryUsingQueryText_whenQueryWrongFieldType_thenFail() { + initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME); + String modelId = getDeployedModelId(); + + SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(TEST_TEXT_FIELD_NAME_1) + .queryText(TEST_QUERY_TEXT) + .modelId(modelId); + + expectThrows(ResponseException.class, () -> search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, sparseEncodingQueryBuilder, 1)); + } + + @SneakyThrows + protected void initializeIndexIfNotExist(String indexName) { + if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testRankFeaturesDoc)); + assertEquals(1, getDocCount(indexName)); + } + + if (TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2)); + addSparseEncodingDoc( + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2), + List.of(testRankFeaturesDoc, testRankFeaturesDoc) + ); + assertEquals(1, getDocCount(indexName)); + } + + if (TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1)); + addSparseEncodingDoc( + indexName, + "1", + List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1),
+                List.of(testRankFeaturesDoc),
+                List.of(TEST_TEXT_FIELD_NAME_1),
+                List.of(TEST_QUERY_TEXT)
+            );
+            assertEquals(1, getDocCount(indexName));
+        }
+
+        if (TEST_NESTED_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
+            prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED));
+            addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), List.of(testRankFeaturesDoc));
+            assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME));
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java
new file mode 100644
index 000000000..a4bc2c495
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/util/TokenWeightUtilTests.java
@@ -0,0 +1,108 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.util;
+
+import java.util.List;
+import java.util.Map;
+
+import org.opensearch.test.OpenSearchTestCase;
+
+public class TokenWeightUtilTests extends OpenSearchTestCase {
+    private static final Map<String, Float> MOCK_DATA = Map.of("hello", 1.f, "world", 2.f);
+
+    public void testFetchListOfTokenWeightMap_singleObject() {
+        /*
+          [{
+              "response": [
+                  {"hello": 1.0, "world": 2.0}
+              ]
+          }]
+        */
+        List<Map<String, ?>> inputData = List.of(Map.of("response", List.of(MOCK_DATA)));
+        assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData), List.of(MOCK_DATA));
+    }
+
+    public void testFetchListOfTokenWeightMap_multipleObjectsInOneResponse() {
+        /*
+          [{
+              "response": [
+                  {"hello": 1.0, "world": 2.0},
+                  {"hello": 1.0, "world": 2.0}
+              ]
+          }]
+        */
+        List<Map<String, ?>> inputData = List.of(Map.of("response", List.of(MOCK_DATA, MOCK_DATA)));
+        assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData), List.of(MOCK_DATA, MOCK_DATA));
+    }
+
+    public void testFetchListOfTokenWeightMap_multipleObjectsInMultipleResponse() {
+        /*
+          [{
+              "response": [
+                  {"hello": 1.0, "world": 2.0}
+              ]
+          },{
+              "response": [
+                  {"hello": 1.0, "world": 2.0}
+              ]
+          }]
+        */
+        List<Map<String, ?>> inputData = List.of(Map.of("response", List.of(MOCK_DATA)), Map.of("response", List.of(MOCK_DATA)));
+        assertEquals(TokenWeightUtil.fetchListOfTokenWeightMap(inputData), List.of(MOCK_DATA, MOCK_DATA));
+    }
+
+    public void testFetchListOfTokenWeightMap_whenResponseValueNotList_thenFail() {
+        /*
+          [{
+              "response": {"hello": 1.0, "world": 2.0}
+          }]
+        */
+        List<Map<String, ?>> inputData = List.of(Map.of("response", MOCK_DATA));
+        expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData));
+    }
+
+    public void testFetchListOfTokenWeightMap_whenNotUseResponseKey_thenFail() {
+        /*
+          [{
+              "some_key": [{"hello": 1.0, "world": 2.0}]
+          }]
+        */
+        List<Map<String, ?>> inputData = List.of(Map.of("some_key", List.of(MOCK_DATA)));
+        expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData));
+    }
+
+    public void testFetchListOfTokenWeightMap_whenInputObjectIsNotMap_thenFail() {
+        /*
+          [{
+              "response": [[{"hello": 1.0, "world": 2.0}]]
+          }]
+        */
+        List<Map<String, ?>> inputData = List.of(Map.of("response", List.of(List.of(MOCK_DATA))));
+        expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData));
+    }
+
+    public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonStringKeys_thenFail() {
+        /*
+          [{
+              "response": [{"hello": 1.0, 2.3: 2.0}]
+          }]
+        */
+        Map<Object, Object> mockData = Map.of("hello", 1.f, 2.3f, 2.f);
+        List<Map<String, ?>> inputData = List.of(Map.of("response", List.of(mockData)));
+        expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData));
+    }
+
+    public void testFetchListOfTokenWeightMap_whenInputTokenMapWithNonFloatValues_thenFail() {
+        /*
+          [{
+              "response": [{"hello": 1.0, "world": "world"}]
+          }]
+        */
+        Map<String, Object> mockData = Map.of("hello", 1.f, "world", "world");
+        List<Map<String, ?>> inputData = List.of(Map.of("response", List.of(mockData)));
+        expectThrows(IllegalArgumentException.class, () -> TokenWeightUtil.fetchListOfTokenWeightMap(inputData));
+    }
+}
diff --git a/src/test/resources/processor/SparseEncodingIndexMappings.json b/src/test/resources/processor/SparseEncodingIndexMappings.json
new file mode 100644
index 000000000..87dee278e
--- /dev/null
+++ b/src/test/resources/processor/SparseEncodingIndexMappings.json
@@ -0,0 +1,26 @@
+{
+  "settings":{
+    "default_pipeline": "pipeline-sparse-encoding"
+  },
+  "mappings": {
+    "properties": {
+      "title_sparse": {
+        "type": "rank_features"
+      },
+      "favor_list_sparse": {
+        "type": "nested",
+        "properties":{
+          "sparse_encoding":{
+            "type": "rank_features"
+          }
+        }
+      },
+      "favorites.game_sparse": {
+        "type": "rank_features"
+      },
+      "favorites.movie_sparse": {
+        "type": "rank_features"
+      }
+    }
+  }
+}
\ No newline at end of file
diff --git a/src/test/resources/processor/SparseEncodingPipelineConfiguration.json b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json
new file mode 100644
index 000000000..82d13c8fe
--- /dev/null
+++ b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json
@@ -0,0 +1,18 @@
+{
+  "description": "An example sparse encoding pipeline",
+  "processors" : [
+    {
+      "sparse_encoding": {
+        "model_id": "%s",
+        "field_map": {
+          "title": "title_sparse",
+          "favor_list": "favor_list_sparse",
+          "favorites": {
+            "game": "game_sparse",
+            "movie": "movie_sparse"
+          }
+        }
+      }
+    }
+  ]
+}
\ No newline at end of file
diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json
new file mode 100644
index 000000000..c45334bae
--- /dev/null
+++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json
@@ -0,0 +1,10 @@
+{
+  "name": "tokenize-idf-0915",
+  "version": "1.0.0",
+  "function_name": "SPARSE_TOKENIZE",
+  "description": "test model",
+  "model_format": "TORCH_SCRIPT",
+  "model_group_id": "",
+  "model_content_hash_value": "b345e9e943b62c405a8dd227ef2c46c84c5ff0a0b71b584be9132b37bce91a9a",
+  "url": "https://github.com/opensearch-project/ml-commons/raw/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/sparse_encoding/sparse_demo.zip"
+}
\ No newline at end of file
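
Note on the score assertions in SparseEncodingQueryIT above: the bool-should tests appear to assume that a sparse_encoding clause scores a rank_features field as the sum, over tokens shared by the query and the document, of query-token weight times stored document-token weight. Under that assumption, two identical should clauses over identically encoded fields yield exactly twice the single-field score, which is what the `2 * computeExpectedScore(...)` assertion checks. The following is a minimal sketch of that assumed scoring model only; SparseScoreSketch and dotProduct are illustrative names, not part of the plugin or its test base classes.

import java.util.Map;

public class SparseScoreSketch {
    // Sum of query-token weight times document-token weight over shared tokens.
    // Tokens missing from the document contribute zero.
    static float dotProduct(Map<String, Float> queryTokens, Map<String, Float> docTokens) {
        float score = 0f;
        for (Map.Entry<String, Float> entry : queryTokens.entrySet()) {
            score += entry.getValue() * docTokens.getOrDefault(entry.getKey(), 0f);
        }
        return score;
    }

    public static void main(String[] args) {
        Map<String, Float> query = Map.of("hello", 1.0f, "world", 2.0f);
        Map<String, Float> doc = Map.of("hello", 0.5f, "world", 1.5f);
        // 1.0 * 0.5 + 2.0 * 1.5 = 3.5; two identical should clauses would score 7.0
        System.out.println(dotProduct(query, doc));
    }
}

This model is also consistent with the mixed sparse/BM25 test asserting only a lower bound: the BM25 match clause can add to, but never subtract from, the sparse clause's score.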
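The TokenWeightUtilTests above pin down the input contract of TokenWeightUtil.fetchListOfTokenWeightMap: each inference result must carry a "response" key whose value is a list of token-to-weight maps with String keys and numeric values, and any other shape is rejected with an IllegalArgumentException. The following is a minimal sketch consistent with those tests, assuming nothing beyond what they assert; TokenWeightUtilSketch and its helper are illustrative stand-ins, not the plugin's actual implementation.

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class TokenWeightUtilSketch {

    public static List<Map<String, Float>> fetchListOfTokenWeightMap(List<Map<String, ?>> mapResultList) {
        List<Map<String, Float>> results = new ArrayList<>();
        for (Map<String, ?> result : mapResultList) {
            // Each inference result must wrap its data in a list under the "response" key;
            // a missing key or a non-list value is rejected.
            if (!(result.get("response") instanceof List)) {
                throw new IllegalArgumentException("The inference result should be a list wrapped under the \"response\" key");
            }
            for (Object data : (List<?>) result.get("response")) {
                results.add(buildTokenWeightMap(data));
            }
        }
        return results;
    }

    private static Map<String, Float> buildTokenWeightMap(Object uncastedMap) {
        // A nested list (or any non-map) inside "response" is rejected.
        if (!(uncastedMap instanceof Map)) {
            throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values");
        }
        Map<String, Float> result = new HashMap<>();
        for (Map.Entry<?, ?> entry : ((Map<?, ?>) uncastedMap).entrySet()) {
            // Reject non-String keys and non-numeric values, as the tests require.
            if (!(entry.getKey() instanceof String) || !(entry.getValue() instanceof Number)) {
                throw new IllegalArgumentException("The expected inference result is a Map with String keys and Float values");
            }
            result.put((String) entry.getKey(), ((Number) entry.getValue()).floatValue());
        }
        return result;
    }
}

Accepting any Number and converting with floatValue() is one way to satisfy the non-float-values test while still tolerating Float or Double inference outputs; the production utility may validate more strictly.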