diff --git a/CHANGELOG.md b/CHANGELOG.md index f240141ee..972c571c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.12...2.x) ### Features +- Enabled support for applying default modelId in neural sparse query ([#614](https://github.com/opensearch-project/neural-search/pull/614) ### Enhancements - Adding aggregations in hybrid query ([#630](https://github.com/opensearch-project/neural-search/pull/630)) - Support for post filter in hybrid query ([#633](https://github.com/opensearch-project/neural-search/pull/633)) diff --git a/qa/restart-upgrade/build.gradle b/qa/restart-upgrade/build.gradle index a4ca03b65..1a6d0a104 100644 --- a/qa/restart-upgrade/build.gradle +++ b/qa/restart-upgrade/build.gradle @@ -65,12 +65,21 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT tests from neural search version 2.9 and 2.10 because these features were released in 2.11 version. + //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { excludeTestsMatching "org.opensearch.neuralsearch.bwc.MultiModalSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralSparseSearchIT.*" + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.*" + } + } + + // Excluding the test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow" } } @@ -98,12 +107,21 @@ task testAgainstNewCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.is_old_cluster', 'false' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT tests from neural search version 2.9 and 2.10 because these features were released in 2.11 version. + //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { excludeTestsMatching "org.opensearch.neuralsearch.bwc.MultiModalSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralSparseSearchIT.*" + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.*" + } + } + + // Excluding the test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow" } } diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java new file mode 100644 index 000000000..d27b9d0f3 --- /dev/null +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.bwc; + +import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; + +import org.opensearch.common.settings.Settings; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +public class NeuralQueryEnricherProcessorIT extends AbstractRestartUpgradeRestTestCase { + // add prefix to avoid conflicts with other IT class, since we don't wipe resources after first round + private static final String SPARSE_INGEST_PIPELINE_NAME = "nqep-nlp-ingest-pipeline-sparse"; + private static final String DENSE_INGEST_PIPELINE_NAME = "nqep-nlp-ingest-pipeline-dense"; + private static final String SPARSE_SEARCH_PIPELINE_NAME = "nqep-nlp-search-pipeline-sparse"; + private static final String DENSE_SEARCH_PIPELINE_NAME = "nqep-nlp-search-pipeline-dense"; + private static final String TEST_ENCODING_FIELD = "passage_embedding"; + private static final String TEST_TEXT_FIELD = "passage_text"; + private static final String TEXT_1 = "Hello world a b"; + + // Test restart-upgrade neural_query_enricher in restart-upgrade scenario + public void testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + NeuralSparseQueryBuilder sparseEncodingQueryBuilderWithoutModelId = new NeuralSparseQueryBuilder().fieldName(TEST_ENCODING_FIELD) + .queryText(TEXT_1); + // will set the model_id after we obtain the id + NeuralSparseQueryBuilder sparseEncodingQueryBuilderWithModelId = new NeuralSparseQueryBuilder().fieldName(TEST_ENCODING_FIELD) + .queryText(TEXT_1); + + if (isRunningAgainstOldCluster()) { + String modelId = uploadSparseEncodingModel(); + loadModel(modelId); + sparseEncodingQueryBuilderWithModelId.modelId(modelId); + createPipelineForSparseEncodingProcessor(modelId, SPARSE_INGEST_PIPELINE_NAME); + createIndexWithConfiguration( + getIndexNameForTest(), + Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())), + SPARSE_INGEST_PIPELINE_NAME + ); + + addSparseEncodingDoc(getIndexNameForTest(), "0", List.of(), List.of(), List.of(TEST_TEXT_FIELD), List.of(TEXT_1)); + + createSearchRequestProcessor(modelId, SPARSE_SEARCH_PIPELINE_NAME); + updateIndexSettings( + getIndexNameForTest(), + Settings.builder().put("index.search.default_pipeline", SPARSE_SEARCH_PIPELINE_NAME) + ); + } else { + String modelId = null; + try { + modelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_INGEST_PIPELINE_NAME), SPARSE_ENCODING_PROCESSOR); + loadModel(modelId); + sparseEncodingQueryBuilderWithModelId.modelId(modelId); + assertEquals( + search(getIndexNameForTest(), sparseEncodingQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), sparseEncodingQueryBuilderWithModelId, 1).get("hits") + ); + } finally { + wipeOfTestResources(getIndexNameForTest(), SPARSE_INGEST_PIPELINE_NAME, modelId, SPARSE_SEARCH_PIPELINE_NAME); + } + } + } + + public void testNeuralQueryEnricherProcessor_NeuralSearch_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + NeuralQueryBuilder neuralQueryBuilderWithoutModelId = new NeuralQueryBuilder().fieldName(TEST_ENCODING_FIELD).queryText(TEXT_1); + NeuralQueryBuilder neuralQueryBuilderWithModelId = new NeuralQueryBuilder().fieldName(TEST_ENCODING_FIELD).queryText(TEXT_1); + + if (isRunningAgainstOldCluster()) { + String modelId = uploadTextEmbeddingModel(); + loadModel(modelId); + neuralQueryBuilderWithModelId.modelId(modelId); + createPipelineProcessor(modelId, DENSE_INGEST_PIPELINE_NAME); + createIndexWithConfiguration( + getIndexNameForTest(), + Files.readString(Path.of(classLoader.getResource("processor/IndexMappingMultipleShard.json").toURI())), + DENSE_INGEST_PIPELINE_NAME + ); + + addDocument(getIndexNameForTest(), "0", TEST_TEXT_FIELD, TEXT_1, null, null); + + createSearchRequestProcessor(modelId, DENSE_SEARCH_PIPELINE_NAME); + updateIndexSettings(getIndexNameForTest(), Settings.builder().put("index.search.default_pipeline", DENSE_SEARCH_PIPELINE_NAME)); + assertEquals( + search(getIndexNameForTest(), neuralQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), neuralQueryBuilderWithModelId, 1).get("hits") + ); + } else { + String modelId = null; + try { + modelId = TestUtils.getModelId(getIngestionPipeline(DENSE_INGEST_PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR); + loadModel(modelId); + neuralQueryBuilderWithModelId.modelId(modelId); + + assertEquals( + search(getIndexNameForTest(), neuralQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), neuralQueryBuilderWithModelId, 1).get("hits") + ); + } finally { + wipeOfTestResources(getIndexNameForTest(), DENSE_INGEST_PIPELINE_NAME, modelId, DENSE_SEARCH_PIPELINE_NAME); + } + } + } +} diff --git a/qa/restart-upgrade/src/test/resources/processor/SearchRequestPipelineConfiguration.json b/qa/restart-upgrade/src/test/resources/processor/SearchRequestPipelineConfiguration.json new file mode 100644 index 000000000..5cd9988cb --- /dev/null +++ b/qa/restart-upgrade/src/test/resources/processor/SearchRequestPipelineConfiguration.json @@ -0,0 +1,11 @@ +{ + "request_processors": [ + { + "neural_query_enricher": { + "tag": "tag1", + "description": "This processor is going to restrict to publicly visible documents", + "default_model_id": "%s" + } + } + ] +} diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle index 175924523..591e83d58 100644 --- a/qa/rolling-upgrade/build.gradle +++ b/qa/rolling-upgrade/build.gradle @@ -65,12 +65,21 @@ task testAgainstOldCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version systemProperty 'tests.skip_delete_model_index', 'true' - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT tests from neural search version 2.9 and 2.10 because these features were released in 2.11 version. + //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { excludeTestsMatching "org.opensearch.neuralsearch.bwc.MultiModalSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralSparseSearchIT.*" + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.*" + } + } + + // Excluding the test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow" } } @@ -99,12 +108,21 @@ task testAgainstOneThirdUpgradedCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT tests from neural search version 2.9 and 2.10 because these features were released in 2.11 version. + //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { excludeTestsMatching "org.opensearch.neuralsearch.bwc.MultiModalSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralSparseSearchIT.*" + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.*" + } + } + + // Excluding the test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow" } } @@ -132,12 +150,21 @@ task testAgainstTwoThirdsUpgradedCluster(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT tests from neural search version 2.9 and 2.10 because these features were released in 2.11 version. + //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { excludeTestsMatching "org.opensearch.neuralsearch.bwc.MultiModalSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralSparseSearchIT.*" + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.*" + } + } + + // Excluding the test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow" } } @@ -165,12 +192,21 @@ task testRollingUpgrade(type: StandaloneRestIntegTestTask) { systemProperty 'tests.skip_delete_model_index', 'true' systemProperty 'tests.plugin_bwc_version', ext.neural_search_bwc_version - //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT tests from neural search version 2.9 and 2.10 because these features were released in 2.11 version. + //Excluding MultiModalSearchIT, HybridSearchIT, NeuralSparseSearchIT, NeuralQueryEnricherProcessorIT tests from neural search version 2.9 and 2.10 + // because these features were released in 2.11 version. if (ext.neural_search_bwc_version.startsWith("2.9") || ext.neural_search_bwc_version.startsWith("2.10")){ filter { excludeTestsMatching "org.opensearch.neuralsearch.bwc.MultiModalSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.HybridSearchIT.*" excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralSparseSearchIT.*" + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.*" + } + } + + // Excluding the test because we introduce this feature in 2.13 + if (ext.neural_search_bwc_version.startsWith("2.11") || ext.neural_search_bwc_version.startsWith("2.12")){ + filter { + excludeTestsMatching "org.opensearch.neuralsearch.bwc.NeuralQueryEnricherProcessorIT.testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow" } } diff --git a/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java new file mode 100644 index 000000000..7a98b7322 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/NeuralQueryEnricherProcessorIT.java @@ -0,0 +1,148 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.bwc; + +import org.opensearch.common.settings.Settings; +import org.opensearch.neuralsearch.TestUtils; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; + +import static org.opensearch.neuralsearch.TestUtils.NODES_BWC_CLUSTER; +import static org.opensearch.neuralsearch.TestUtils.SPARSE_ENCODING_PROCESSOR; +import static org.opensearch.neuralsearch.TestUtils.TEXT_EMBEDDING_PROCESSOR; + +public class NeuralQueryEnricherProcessorIT extends AbstractRollingUpgradeTestCase { + // add prefix to avoid conflicts with other IT class, since we don't wipe resources after first round + private static final String SPARSE_INGEST_PIPELINE_NAME = "nqep-nlp-ingest-pipeline-sparse"; + private static final String DENSE_INGEST_PIPELINE_NAME = "nqep-nlp-ingest-pipeline-dense"; + private static final String SPARSE_SEARCH_PIPELINE_NAME = "nqep-nlp-search-pipeline-sparse"; + private static final String DENSE_SEARCH_PIPELINE_NAME = "nqep-nlp-search-pipeline-dense"; + private static final String TEST_ENCODING_FIELD = "passage_embedding"; + private static final String TEST_TEXT_FIELD = "passage_text"; + private static final String TEXT_1 = "Hello world a b"; + private String sparseModelId = ""; + private String denseModelId = ""; + + // test of NeuralQueryEnricherProcessor supports neural_sparse query default model_id + // the feature is introduced from 2.13 + public void testNeuralQueryEnricherProcessor_NeuralSparseSearch_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + NeuralSparseQueryBuilder sparseEncodingQueryBuilderWithoutModelId = new NeuralSparseQueryBuilder().fieldName(TEST_ENCODING_FIELD) + .queryText(TEXT_1); + // will set the model_id after we obtain the id + NeuralSparseQueryBuilder sparseEncodingQueryBuilderWithModelId = new NeuralSparseQueryBuilder().fieldName(TEST_ENCODING_FIELD) + .queryText(TEXT_1); + + switch (getClusterType()) { + case OLD: + sparseModelId = uploadSparseEncodingModel(); + loadModel(sparseModelId); + sparseEncodingQueryBuilderWithModelId.modelId(sparseModelId); + createPipelineForSparseEncodingProcessor(sparseModelId, SPARSE_INGEST_PIPELINE_NAME); + createIndexWithConfiguration( + getIndexNameForTest(), + Files.readString(Path.of(classLoader.getResource("processor/SparseIndexMappings.json").toURI())), + SPARSE_INGEST_PIPELINE_NAME + ); + + addSparseEncodingDoc(getIndexNameForTest(), "0", List.of(), List.of(), List.of(TEST_TEXT_FIELD), List.of(TEXT_1)); + createSearchRequestProcessor(sparseModelId, SPARSE_SEARCH_PIPELINE_NAME); + updateIndexSettings( + getIndexNameForTest(), + Settings.builder().put("index.search.default_pipeline", SPARSE_SEARCH_PIPELINE_NAME) + ); + break; + case MIXED: + sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_INGEST_PIPELINE_NAME), SPARSE_ENCODING_PROCESSOR); + loadModel(sparseModelId); + sparseEncodingQueryBuilderWithModelId.modelId(sparseModelId); + break; + case UPGRADED: + try { + sparseModelId = TestUtils.getModelId(getIngestionPipeline(SPARSE_INGEST_PIPELINE_NAME), SPARSE_ENCODING_PROCESSOR); + loadModel(sparseModelId); + sparseEncodingQueryBuilderWithModelId.modelId(sparseModelId); + assertEquals( + search(getIndexNameForTest(), sparseEncodingQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), sparseEncodingQueryBuilderWithModelId, 1).get("hits") + ); + } finally { + wipeOfTestResources(getIndexNameForTest(), SPARSE_INGEST_PIPELINE_NAME, sparseModelId, SPARSE_SEARCH_PIPELINE_NAME); + } + break; + default: + throw new IllegalStateException("Unexpected value: " + getClusterType()); + } + } + + // test of NeuralQueryEnricherProcessor supports neural query default model_id + // the feature is introduced from 2.11 + public void testNeuralQueryEnricherProcessor_NeuralSearch_E2EFlow() throws Exception { + waitForClusterHealthGreen(NODES_BWC_CLUSTER); + NeuralQueryBuilder neuralQueryBuilderWithoutModelId = new NeuralQueryBuilder().fieldName(TEST_ENCODING_FIELD).queryText(TEXT_1); + NeuralQueryBuilder neuralQueryBuilderWithModelId = new NeuralQueryBuilder().fieldName(TEST_ENCODING_FIELD).queryText(TEXT_1); + + switch (getClusterType()) { + case OLD: + denseModelId = uploadTextEmbeddingModel(); + loadModel(denseModelId); + neuralQueryBuilderWithModelId.modelId(denseModelId); + createPipelineProcessor(denseModelId, DENSE_INGEST_PIPELINE_NAME); + createIndexWithConfiguration( + getIndexNameForTest(), + Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), + DENSE_INGEST_PIPELINE_NAME + ); + + addDocument(getIndexNameForTest(), "0", TEST_TEXT_FIELD, TEXT_1, null, null); + + createSearchRequestProcessor(denseModelId, DENSE_SEARCH_PIPELINE_NAME); + updateIndexSettings( + getIndexNameForTest(), + Settings.builder().put("index.search.default_pipeline", DENSE_SEARCH_PIPELINE_NAME) + ); + assertEquals( + search(getIndexNameForTest(), neuralQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), neuralQueryBuilderWithModelId, 1).get("hits") + ); + break; + case MIXED: + denseModelId = TestUtils.getModelId(getIngestionPipeline(DENSE_INGEST_PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR); + loadModel(denseModelId); + neuralQueryBuilderWithModelId.modelId(denseModelId); + + createSearchRequestProcessor(denseModelId, DENSE_SEARCH_PIPELINE_NAME); + updateIndexSettings( + getIndexNameForTest(), + Settings.builder().put("index.search.default_pipeline", DENSE_SEARCH_PIPELINE_NAME) + ); + assertEquals( + search(getIndexNameForTest(), neuralQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), neuralQueryBuilderWithModelId, 1).get("hits") + ); + break; + case UPGRADED: + try { + denseModelId = TestUtils.getModelId(getIngestionPipeline(DENSE_INGEST_PIPELINE_NAME), TEXT_EMBEDDING_PROCESSOR); + loadModel(denseModelId); + neuralQueryBuilderWithModelId.modelId(denseModelId); + + assertEquals( + search(getIndexNameForTest(), neuralQueryBuilderWithoutModelId, 1).get("hits"), + search(getIndexNameForTest(), neuralQueryBuilderWithModelId, 1).get("hits") + ); + } finally { + wipeOfTestResources(getIndexNameForTest(), DENSE_INGEST_PIPELINE_NAME, denseModelId, DENSE_SEARCH_PIPELINE_NAME); + } + break; + default: + throw new IllegalStateException("Unexpected value: " + getClusterType()); + } + } +} diff --git a/qa/rolling-upgrade/src/test/resources/processor/SearchRequestPipelineConfiguration.json b/qa/rolling-upgrade/src/test/resources/processor/SearchRequestPipelineConfiguration.json new file mode 100644 index 000000000..5cd9988cb --- /dev/null +++ b/qa/rolling-upgrade/src/test/resources/processor/SearchRequestPipelineConfiguration.json @@ -0,0 +1,11 @@ +{ + "request_processors": [ + { + "neural_query_enricher": { + "tag": "tag1", + "description": "This processor is going to restrict to publicly visible documents", + "default_model_id": "%s" + } + } + ] +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/ModelInferenceQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ModelInferenceQueryBuilder.java new file mode 100644 index 000000000..23db567d4 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/ModelInferenceQueryBuilder.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +/** + * Query builders which calls ml-commons API to do model inference. + * The model inference result is used for search on target field. + */ +public interface ModelInferenceQueryBuilder { + /** + * Get the model id used by ml-commons model inference. Return null if the model id is absent. + */ + public String modelId(); + + /** + * Set a new model id for the query builder. + */ + public ModelInferenceQueryBuilder modelId(String modelId); + + /** + * Get the field name for search. + */ + public String fieldName(); +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index cda01767e..d27061e36 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -12,6 +12,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.function.Supplier; import org.apache.commons.lang.StringUtils; @@ -58,7 +59,7 @@ @Accessors(chain = true, fluent = true) @NoArgsConstructor @AllArgsConstructor -public class NeuralQueryBuilder extends AbstractQueryBuilder { +public class NeuralQueryBuilder extends AbstractQueryBuilder implements ModelInferenceQueryBuilder { public static final String NAME = "neural"; @@ -133,11 +134,11 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); - if (modelId != null) { + if (Objects.nonNull(modelId)) { xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); } xContentBuilder.field(K_FIELD.getPreferredName(), k); - if (filter != null) { + if (Objects.nonNull(filter)) { xContentBuilder.field(FILTER_FIELD.getPreferredName(), filter); } printBoostAndQueryName(xContentBuilder); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 42498f1fd..542024fc4 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -8,6 +8,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.Objects; import java.util.function.Supplier; import org.apache.commons.lang.StringUtils; @@ -17,6 +18,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionListener; @@ -31,6 +33,7 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import com.google.common.annotations.VisibleForTesting; @@ -54,7 +57,7 @@ @Accessors(chain = true, fluent = true) @NoArgsConstructor @AllArgsConstructor -public class NeuralSparseQueryBuilder extends AbstractQueryBuilder { +public class NeuralSparseQueryBuilder extends AbstractQueryBuilder implements ModelInferenceQueryBuilder { public static final String NAME = "neural_sparse"; @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); @@ -77,6 +80,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String modelId; private Float maxTokenScore; private Supplier> queryTokensSupplier; + private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0; /** * Constructor from stream input @@ -88,7 +92,11 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.queryText = in.readString(); - this.modelId = in.readString(); + if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + this.modelId = in.readOptionalString(); + } else { + this.modelId = in.readString(); + } this.maxTokenScore = in.readOptionalFloat(); if (in.readBoolean()) { Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); @@ -100,9 +108,13 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeString(queryText); - out.writeString(modelId); + if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + out.writeOptionalString(this.modelId); + } else { + out.writeString(this.modelId); + } out.writeOptionalFloat(maxTokenScore); - if (queryTokensSupplier != null && queryTokensSupplier.get() != null) { + if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) { out.writeBoolean(true); out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); } else { @@ -115,7 +127,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); - xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + if (Objects.nonNull(modelId)) { + xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + } if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore); printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); @@ -161,11 +175,12 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw sparseEncodingQueryBuilder.queryText(), String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) ); - requireValue( - sparseEncodingQueryBuilder.modelId(), - String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) - ); - + if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) + ); + } return sparseEncodingQueryBuilder; } @@ -304,4 +319,8 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java index 9dab0a695..b869db10a 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java +++ b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java @@ -5,16 +5,17 @@ package org.opensearch.neuralsearch.query.visitor; import java.util.Map; +import java.util.Objects; import org.apache.lucene.search.BooleanClause; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilderVisitor; -import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder; import lombok.AllArgsConstructor; /** - * Neural Search Query Visitor. It visits each and every component of query buikder tree. + * Neural Search Query Visitor. It visits each and every component of query builder tree. */ @AllArgsConstructor public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { @@ -22,22 +23,27 @@ public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { private final String modelId; private final Map neuralFieldMap; + private Boolean isFieldDefaultModelIdApplied(String fieldName) { + if (Objects.nonNull(neuralFieldMap) && Objects.nonNull(fieldName) && Objects.nonNull(neuralFieldMap.get(fieldName))) { + return true; + } + return false; + } + /** * Accept method accepts every query builder from the search request, * and processes it if the required conditions in accept method are satisfied. */ @Override public void accept(QueryBuilder queryBuilder) { - if (queryBuilder instanceof NeuralQueryBuilder) { - NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder; - if (neuralQueryBuilder.modelId() == null) { - if (neuralFieldMap != null - && neuralQueryBuilder.fieldName() != null - && neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) { - String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName()); - neuralQueryBuilder.modelId(fieldDefaultModelId); - } else if (modelId != null) { - neuralQueryBuilder.modelId(modelId); + if (queryBuilder instanceof ModelInferenceQueryBuilder) { + ModelInferenceQueryBuilder modelInferenceQueryBuilder = (ModelInferenceQueryBuilder) queryBuilder; + if (modelInferenceQueryBuilder.modelId() == null) { + if (isFieldDefaultModelIdApplied(modelInferenceQueryBuilder.fieldName())) { + String fieldDefaultModelId = (String) neuralFieldMap.get(modelInferenceQueryBuilder.fieldName()); + modelInferenceQueryBuilder.modelId(fieldDefaultModelId); + } else if (Objects.nonNull(modelId)) { + modelInferenceQueryBuilder.modelId(modelId); } else { throw new IllegalArgumentException( "model id must be provided in neural query or a default model id must be set in search request processor" diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index a2fd41ac6..6d31ea6a6 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -328,7 +328,7 @@ public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenR .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); } - + public void testInferenceMultimodal_whenInvalidInputAndEmptyTensorOutput_thenFail() { List tensorsList = new ArrayList<>(); List mlModelTensorList = List.of( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java index e414b8c24..e54576583 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryEnricherProcessorIT.java @@ -9,6 +9,7 @@ import static org.opensearch.neuralsearch.TestUtils.createRandomVector; import java.util.Collections; +import java.util.List; import java.util.Map; import org.apache.http.util.EntityUtils; @@ -22,6 +23,7 @@ import org.opensearch.neuralsearch.BaseNeuralSearchIT; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import com.google.common.primitives.Floats; @@ -30,9 +32,11 @@ public class NeuralQueryEnricherProcessorIT extends BaseNeuralSearchIT { private static final String index = "my-nlp-index"; + private static final String sparseIndex = "my-nlp-index-sparse"; private static final String search_pipeline = "search-pipeline"; private static final String ingest_pipeline = "nlp-pipeline"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_RANK_FEATURES_FIELD_NAME_1 = "test-rank-features-1"; private final float[] testVector = createRandomVector(TEST_DIMENSION); @Before @@ -61,6 +65,25 @@ public void testNeuralQueryEnricherProcessor_whenNoModelIdPassed_thenSuccess() { } } + @SneakyThrows + public void testNeuralQueryEnricherProcessor_whenNoModelIdPassedInNeuralSparseQuery_thenSuccess() { + String modelId = null; + try { + initializeIndexIfNotExist(sparseIndex); + modelId = prepareSparseEncodingModel(); + createSearchRequestProcessor(modelId, search_pipeline); + createPipelineProcessor(modelId, ingest_pipeline, ProcessorType.SPARSE_ENCODING); + updateIndexSettings(sparseIndex, Settings.builder().put("index.search.default_pipeline", search_pipeline)); + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName(TEST_RANK_FEATURES_FIELD_NAME_1); + neuralSparseQueryBuilder.queryText("hello"); + Map response = search(sparseIndex, neuralSparseQueryBuilder, 2); + assertFalse(response.isEmpty()); + } finally { + wipeOfTestResources(sparseIndex, ingest_pipeline, modelId, search_pipeline); + } + } + @SneakyThrows public void testNeuralQueryEnricherProcessor_whenGetEmptyQueryBody_thenSuccess() { String modelId = null; @@ -120,5 +143,11 @@ private void initializeIndexIfNotExist(String indexName) { ); assertEquals(1, getDocCount(indexName)); } + + if (sparseIndex.equals(indexName) && !indexExists(indexName)) { + prepareSparseEncodingIndex(indexName, List.of(TEST_RANK_FEATURES_FIELD_NAME_1)); + addSparseEncodingDoc(indexName, "1", List.of(TEST_RANK_FEATURES_FIELD_NAME_1), List.of(Map.of("hi", 1.0f, "hello", 1.1f))); + assertEquals(1, getDocCount(indexName)); + } } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index d66f98f5b..90b635e7c 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -26,7 +26,10 @@ import org.apache.lucene.document.FeatureField; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.junit.Before; +import org.opensearch.Version; import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.SetOnce; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; @@ -44,6 +47,8 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -58,6 +63,11 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase { private static final Float MAX_TOKEN_SCORE = 123f; private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f); + @Before + public void setupClusterServiceToCurrentVersion() { + setUpClusterService(Version.CURRENT); + } + @SneakyThrows public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() { /* @@ -195,7 +205,7 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() { } @SneakyThrows - public void testFromXContent_whenBuildWithMissingModelId_thenFail() { + public void testFromXContent_whenBuildWithMissingModelIdInCurrentVersion_thenSuccess() { /* { "VECTOR_FIELD": { @@ -210,6 +220,30 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() { .endObject() .endObject(); + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser); + + assertNull(sparseEncodingQueryBuilder.modelId()); + } + + @SneakyThrows + public void testFromXContent_whenBuildWithMissingModelIdInOldVersion_thenFail() { + /* + { + "VECTOR_FIELD": { + "query_text": "string" + } + } + */ + setUpClusterService(Version.V_2_12_0); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .endObject() + .endObject(); + XContentParser contentParser = createParser(xContentBuilder); contentParser.nextToken(); expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser)); @@ -276,6 +310,11 @@ public void testToXContent() { assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0); } + public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() { + setUpClusterService(Version.V_2_12_0); + testStreams(); + } + @SneakyThrows public void testStreams() { NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder(); @@ -495,11 +534,16 @@ public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() { .modelId(MODEL_ID) .queryTokensSupplier(QUERY_TOKENS_SUPPLIER); QueryBuilder queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - assertTrue(queryBuilder == sparseEncodingQueryBuilder); + assertSame(queryBuilder, sparseEncodingQueryBuilder); sparseEncodingQueryBuilder.queryTokensSupplier(() -> null); queryBuilder = sparseEncodingQueryBuilder.doRewrite(null); - assertTrue(queryBuilder == sparseEncodingQueryBuilder); + assertSame(queryBuilder, sparseEncodingQueryBuilder); + } + + private void setUpClusterService(Version version) { + ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); + NeuralSearchClusterUtil.instance().initialize(clusterService); } @SneakyThrows diff --git a/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java b/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java index e513ab035..26dbb289b 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitorTests.java @@ -9,6 +9,7 @@ import org.apache.lucene.search.BooleanClause; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; import org.opensearch.test.OpenSearchTestCase; public class NeuralSearchQueryVisitorTests extends OpenSearchTestCase { @@ -38,11 +39,49 @@ public void testAccept_whenNeuralQueryBuilderWithoutFieldModelId_thenSetFieldMod assertEquals("bdcvjkcdjvkddcjxdjsc", neuralQueryBuilder.modelId()); } + public void testAccept_whenNeuralSparseQueryBuilderWithModelId_theDoNothing() { + String modelId1 = "bdcvjkcdjvkddcjxdjsc"; + String modelId2 = "45dfsnfoiqwrjcjxdjsc"; + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("passage_text"); + neuralSparseQueryBuilder.modelId(modelId1); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(modelId2, null); + neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder); + + assertEquals(modelId1, neuralSparseQueryBuilder.modelId()); + } + + public void testAccept_whenNeuralSparseQueryBuilderWithoutModelId_thenSetModelId() { + String modelId = "bdcvjkcdjvkddcjxdjsc"; + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("passage_text"); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(modelId, null); + neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder); + + assertEquals(modelId, neuralSparseQueryBuilder.modelId()); + } + + public void testAccept_whenNeuralSparseQueryBuilderWithoutFieldModelId_thenSetFieldModelId() { + Map neuralInfoMap = new HashMap<>(); + neuralInfoMap.put("passage_text", "bdcvjkcdjvkddcjxdjsc"); + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); + neuralSparseQueryBuilder.fieldName("passage_text"); + + NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, neuralInfoMap); + neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder); + + assertEquals("bdcvjkcdjvkddcjxdjsc", neuralSparseQueryBuilder.modelId()); + } + public void testAccept_whenNullValuesInVisitor_thenFail() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder(); NeuralSearchQueryVisitor neuralSearchQueryVisitor = new NeuralSearchQueryVisitor(null, null); expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralQueryBuilder)); + expectThrows(IllegalArgumentException.class, () -> neuralSearchQueryVisitor.accept(neuralSparseQueryBuilder)); } public void testGetChildVisitor() { diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 0c59b23a1..ea006db57 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -1185,7 +1185,13 @@ protected void wipeOfTestResources( deleteSearchPipeline(searchPipeline); } if (modelId != null) { - deleteModel(modelId); + try { + deleteModel(modelId); + } catch (AssertionError e) { + // sometimes we have flaky test that the model state doesn't change after call undeploy api + // for this case we can call undeploy api one more time + deleteModel(modelId); + } } if (indexName != null) { deleteIndex(indexName);