From c06d07a0fa271b0cc02885bf92370fd348c85b38 Mon Sep 17 00:00:00 2001 From: Varun Jain Date: Mon, 25 Sep 2023 17:54:06 -0700 Subject: [PATCH] Support for default model Id Signed-off-by: Varun Jain --- .../neuralsearch/plugin/NeuralSearch.java | 12 ++--- .../processor/NeuralQueryProcessor.java | 36 ++++++------- .../query/NeuralQueryBuilder.java | 34 ++++++------- .../visitor/NeuralSearchQueryVisitor.java | 8 +-- .../util/NeuralSearchClusterUtil.java | 12 +++-- .../common/BaseNeuralSearchIT.java | 23 +++++++++ .../plugin/NeuralSearchTests.java | 12 +++++ .../processor/NeuralQueryProcessorTests.java | 50 ++++++++++++++++++ .../util/NeuralSearchClusterTestUtils.java | 32 ++++++++++++ .../util/NeuralSearchClusterUtilTests.java | 51 +++++++++++++++++++ .../SearchRequestPipelineConfiguration.json | 11 ++++ 11 files changed, 233 insertions(+), 48 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java create mode 100644 src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java create mode 100644 src/test/resources/processor/SearchRequestPipelineConfiguration.json diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index b01ef703d..57ac4feb5 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -5,8 +5,8 @@ package org.opensearch.neuralsearch.plugin; -import org.opensearch.neuralsearch.processor.NeuralQueryProcessor; import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; + import java.util.Arrays; import java.util.Collection; import java.util.Collections; @@ -29,6 +29,7 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.NeuralQueryProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; @@ -132,10 +133,9 @@ public List> getSettings() { } @Override - public Map> getRequestProcessors(Parameters parameters) { - return Map.of( - NeuralQueryProcessor.TYPE, - new NeuralQueryProcessor.Factory() - ); + public Map> getRequestProcessors( + Parameters parameters + ) { + return Map.of(NeuralQueryProcessor.TYPE, new NeuralQueryProcessor.Factory()); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java index 2b9906e84..59a8688ca 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java @@ -20,11 +20,11 @@ public class NeuralQueryProcessor extends AbstractProcessor implements SearchReq /** * Key to reference this processor type from a search pipeline. */ - public static final String TYPE = "default_query"; + public static final String TYPE = "neural_query"; - private final String modelId; + final String modelId; - private final Map fieldInfoMap; + final Map neuralFieldMap; /** * Returns the type of the processor. @@ -37,39 +37,39 @@ public String getType() { } protected NeuralQueryProcessor( - String tag, - String description, - boolean ignoreFailure, - String modelId, - Map fieldInfoMap + String tag, + String description, + boolean ignoreFailure, + String modelId, + Map fieldInfoMap ) { super(tag, description, ignoreFailure); this.modelId = modelId; - this.fieldInfoMap = fieldInfoMap; + this.neuralFieldMap = fieldInfoMap; } @Override public SearchRequest processRequest(SearchRequest searchRequest) throws Exception { QueryBuilder queryBuilder = searchRequest.source().query(); - queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, fieldInfoMap)); + queryBuilder.visit(new NeuralSearchQueryVisitor(modelId, neuralFieldMap)); return searchRequest; } public static class Factory implements Processor.Factory { private static final String DEFAULT_MODEL_ID = "default_model_id"; - private static final String NEURAL_FIELD_MAP = "neural_field_map"; + private static final String NEURAL_FIELD_DEFAULT_ID = "neural_field_default_id"; @Override public NeuralQueryProcessor create( - Map> processorFactories, - String tag, - String description, - boolean ignoreFailure, - Map config, - PipelineContext pipelineContext + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext ) throws Exception { String modelId = (String) config.remove(DEFAULT_MODEL_ID); - Map neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_MAP); + Map neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID); if (modelId == null && neuralInfoMap == null) { throw new IllegalArgumentException("model Id or neural info map either of them should be provided"); diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index c59a1173c..799cc22de 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -160,14 +160,14 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx parseQueryParams(parser, neuralQueryBuilder); if (parser.nextToken() != XContentParser.Token.END_OBJECT) { throw new ParsingException( - parser.getTokenLocation(), - "[" - + NAME - + "] query doesn't support multiple fields, found [" - + neuralQueryBuilder.fieldName() - + "] and [" - + parser.currentName() - + "]" + parser.getTokenLocation(), + "[" + + NAME + + "] query doesn't support multiple fields, found [" + + neuralQueryBuilder.fieldName() + + "] and [" + + parser.currentName() + + "]" ); } requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query"); @@ -197,8 +197,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n neuralQueryBuilder.boost(parser.floatValue()); } else { throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] query does not support [" + currentFieldName + "]" + parser.getTokenLocation(), + "[" + NAME + "] query does not support [" + currentFieldName + "]" ); } } else if (token == XContentParser.Token.START_OBJECT) { @@ -207,8 +207,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } } else { throw new ParsingException( - parser.getTokenLocation(), - "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" + parser.getTokenLocation(), + "[" + NAME + "] unknown token [" + token + "] after [" + currentFieldName + "]" ); } } @@ -231,10 +231,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { SetOnce vectorSetOnce = new SetOnce<>(); queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> { - vectorSetOnce.set(vectorAsListToArray(floatList)); - actionListener.onResponse(null); - }, actionListener::onFailure))) + ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> { + vectorSetOnce.set(vectorAsListToArray(floatList)); + actionListener.onResponse(null); + }, actionListener::onFailure))) ); return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter()); } @@ -271,4 +271,4 @@ public String getWriteableName() { private static boolean isClusterOnOrAfterMinRequiredVersion() { return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java index c1f66ee9e..2d0b53eaa 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java +++ b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java @@ -27,15 +27,15 @@ public void accept(QueryBuilder queryBuilder) { if (queryBuilder instanceof NeuralQueryBuilder) { NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder; if (neuralFieldMap != null - && neuralQueryBuilder.fieldName() != null - && neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) { + && neuralQueryBuilder.fieldName() != null + && neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) { String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName()); neuralQueryBuilder.modelId(fieldDefaultModelId); } else if (modelId != null) { neuralQueryBuilder.modelId(modelId); } else { throw new IllegalArgumentException( - "model id must be provided in neural query or a default model id must be set in search request processor" + "model id must be provided in neural query or a default model id must be set in search request processor" ); } } @@ -45,4 +45,4 @@ public void accept(QueryBuilder queryBuilder) { public QueryBuilderVisitor getChildVisitor(BooleanClause.Occur occur) { return this; } -} \ No newline at end of file +} diff --git a/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java b/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java index 1e0b22094..f9e044c53 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.util; +import java.util.Locale; + import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -47,11 +49,15 @@ public Version getClusterMinVersion() { return this.clusterService.state().getNodes().getMinNodeVersion(); } catch (Exception exception) { log.error( - String.format("Failed to get cluster minimum node version, returning current node version %s instead.", Version.CURRENT), - exception + String.format( + Locale.ROOT, + "Failed to get cluster minimum node version, returning current node version %s instead.", + Version.CURRENT + ), + exception ); return Version.CURRENT; } } -} \ No newline at end of file +} diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index b144ade6c..fac32b538 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -253,6 +253,29 @@ protected void createPipelineProcessor(String modelId, String pipelineName) thro assertEquals("true", node.get("acknowledged").toString()); } + protected void createSearchRequestProcessor(String modelId, String pipelineName) throws Exception { + Response pipelineCreateResponse = makeRequest( + client(), + "PUT", + "/_search/pipeline/" + pipelineName, + null, + toHttpEntity( + String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource("processor/SearchRequestPipelineConfiguration.json").toURI())), + modelId + ) + ), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + Map node = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(pipelineCreateResponse.getEntity()), + false + ); + assertEquals("true", node.get("acknowledged").toString()); + } + /** * Get the number of documents in a particular index * diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 8918e174c..9c8a23d03 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -12,6 +12,7 @@ import java.util.Optional; import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.processor.NeuralQueryProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; @@ -22,6 +23,7 @@ import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.query.QueryPhaseSearcher; public class NeuralSearchTests extends OpenSearchQueryTestCase { @@ -73,4 +75,14 @@ public void testSearchPhaseResultsProcessors() { ); assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); } + + public void testRequestProcessors() { + NeuralSearch plugin = new NeuralSearch(); + SearchPipelinePlugin.Parameters parameters = mock(SearchPipelinePlugin.Parameters.class); + Map> processors = plugin.getRequestProcessors( + parameters + ); + assertNotNull(processors); + assertNotNull(processors.get(NeuralQueryProcessor.TYPE)); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java new file mode 100644 index 000000000..0cd9bea3a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.test.OpenSearchTestCase; + +public class NeuralQueryProcessorTests extends OpenSearchTestCase { + + public void testFactory() throws Exception { + NeuralQueryProcessor.Factory factory = new NeuralQueryProcessor.Factory(); + NeuralQueryProcessor processor = createTestProcessor(factory); + assertEquals("vasdcvkcjkbldbjkd", processor.modelId); + assertEquals("bahbkcdkacb", processor.neuralFieldMap.get("fieldName").toString()); + + // Missing "query" parameter: + expectThrows( + IllegalArgumentException.class, + () -> factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null) + ); + } + + public void testProcessRequest() throws Exception { + NeuralQueryProcessor.Factory factory = new NeuralQueryProcessor.Factory(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder().query(neuralQueryBuilder)); + NeuralQueryProcessor processor = createTestProcessor(factory); + SearchRequest processSearchRequest = processor.processRequest(searchRequest); + assertEquals(processSearchRequest, searchRequest); + } + + public NeuralQueryProcessor createTestProcessor(NeuralQueryProcessor.Factory factory) throws Exception { + Map configMap = new HashMap<>(); + configMap.put("default_model_id", "vasdcvkcjkbldbjkd"); + configMap.put("neural_field_default_id", Map.of("fieldName", "bahbkcdkacb")); + NeuralQueryProcessor processor = factory.create(Collections.emptyMap(), null, null, false, configMap, null); + return processor; + } + +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java new file mode 100644 index 000000000..30399cfea --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterTestUtils.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import org.opensearch.Version; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.node.DiscoveryNodes; +import org.opensearch.cluster.service.ClusterService; + +public class NeuralSearchClusterTestUtils { + + /** + * Create new mock for ClusterService + * @param version min version for cluster nodes + * @return + */ + public static ClusterService mockClusterService(final Version version) { + ClusterService clusterService = mock(ClusterService.class); + ClusterState clusterState = mock(ClusterState.class); + when(clusterService.state()).thenReturn(clusterState); + DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); + when(clusterState.getNodes()).thenReturn(discoveryNodes); + when(discoveryNodes.getMinNodeVersion()).thenReturn(version); + return clusterService; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java new file mode 100644 index 000000000..9f619d4c2 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.util; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils.mockClusterService; + +import org.opensearch.Version; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.test.OpenSearchTestCase; + +public class NeuralSearchClusterUtilTests extends OpenSearchTestCase { + + public void testSingleNodeCluster() { + ClusterService clusterService = mockClusterService(Version.V_2_4_0); + + final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance(); + neuralSearchClusterUtil.initialize(clusterService); + + final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion(); + + assertTrue(Version.V_2_4_0.equals(minVersion)); + } + + public void testMultipleNodesCluster() { + ClusterService clusterService = mockClusterService(Version.V_2_3_0); + + final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance(); + neuralSearchClusterUtil.initialize(clusterService); + + final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion(); + + assertTrue(Version.V_2_3_0.equals(minVersion)); + } + + public void testWhenErrorOnClusterStateDiscover() { + ClusterService clusterService = mock(ClusterService.class); + when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); + + final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance(); + neuralSearchClusterUtil.initialize(clusterService); + + final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion(); + + assertTrue(Version.CURRENT.equals(minVersion)); + } +} diff --git a/src/test/resources/processor/SearchRequestPipelineConfiguration.json b/src/test/resources/processor/SearchRequestPipelineConfiguration.json new file mode 100644 index 000000000..3f3675ce3 --- /dev/null +++ b/src/test/resources/processor/SearchRequestPipelineConfiguration.json @@ -0,0 +1,11 @@ +{ + "request_processors": [ + { + "default_query": { + "tag": "tag1", + "description": "This processor is going to restrict to publicly visible documents", + "default_model_id": "%s" + } + } + ] +} \ No newline at end of file