diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java index 9c479b48e..e38d96941 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessor.java @@ -7,7 +7,11 @@ import java.util.Map; +import lombok.Getter; +import lombok.Setter; + import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.Nullable; import org.opensearch.index.query.QueryBuilder; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.neuralsearch.query.visitor.NeuralSearchQueryVisitor; @@ -16,18 +20,21 @@ import org.opensearch.search.pipeline.SearchRequestProcessor; /** - * Neural Search Query Request Processor + * Neural Search Query Request Processor, It modifies the search request with neural query clause + * and adds model Id if not present in the search query. */ +@Setter +@Getter public class NeuralQueryProcessor extends AbstractProcessor implements SearchRequestProcessor { /** * Key to reference this processor type from a search pipeline. */ - public static final String TYPE = "neural_query"; + public static final String TYPE = "enriching_query_defaults"; - final String modelId; + private final String modelId; - final Map neuralFieldDefaultIdMap; + private final Map neuralFieldDefaultIdMap; /** * Returns the type of the processor. @@ -39,12 +46,12 @@ public String getType() { return TYPE; } - protected NeuralQueryProcessor( + private NeuralQueryProcessor( String tag, String description, boolean ignoreFailure, - String modelId, - Map neuralFieldDefaultIdMap + @Nullable String modelId, + @Nullable Map neuralFieldDefaultIdMap ) { super(tag, description, ignoreFailure); this.modelId = modelId; @@ -81,7 +88,12 @@ public NeuralQueryProcessor create( Map config, PipelineContext pipelineContext ) throws IllegalArgumentException { - String modelId = (String) config.remove(DEFAULT_MODEL_ID); + String modelId; + try { + modelId = (String) config.remove(DEFAULT_MODEL_ID); + } catch (ClassCastException e) { + throw new IllegalArgumentException("model Id must of String type"); + } Map neuralInfoMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, NEURAL_FIELD_DEFAULT_ID); if (modelId == null && neuralInfoMap == null) { diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 6701b5835..7b78be269 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -96,6 +96,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.queryText = in.readString(); + // If cluster version is on or after 2.11 then default model Id support is enabled if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { this.modelId = in.readOptionalString(); } else { @@ -109,6 +110,7 @@ public NeuralQueryBuilder(StreamInput in) throws IOException { protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(this.fieldName); out.writeString(this.queryText); + // If cluster version is on or after 2.11 then default model Id support is enabled if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { out.writeOptionalString(this.modelId); } else { diff --git a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java index 9d746190c..febb35294 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java +++ b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java @@ -20,8 +20,8 @@ @AllArgsConstructor public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { - private String modelId; - private Map neuralFieldMap; + private final String modelId; + private final Map neuralFieldMap; /** * Accept method accepts every query builder from the search request, diff --git a/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java b/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java index c91cb0fae..5a97120e0 100644 --- a/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtil.java @@ -5,8 +5,6 @@ package org.opensearch.neuralsearch.util; -import java.util.Locale; - import lombok.AccessLevel; import lombok.NoArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -48,19 +46,7 @@ public void initialize(final ClusterService clusterService) { * @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version */ public Version getClusterMinVersion() { - try { - return this.clusterService.state().getNodes().getMinNodeVersion(); - } catch (Exception exception) { - log.error( - String.format( - Locale.ROOT, - "Failed to get cluster minimum node version, returning current node version %s instead.", - Version.CURRENT - ), - exception - ); - return Version.CURRENT; - } + return this.clusterService.state().getNodes().getMinNodeVersion(); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java index 176ffbad3..0a90d5c70 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NeuralQueryProcessorTests.java @@ -19,8 +19,8 @@ public class NeuralQueryProcessorTests extends OpenSearchTestCase { public void testFactory() throws Exception { NeuralQueryProcessor.Factory factory = new NeuralQueryProcessor.Factory(); NeuralQueryProcessor processor = createTestProcessor(factory); - assertEquals("vasdcvkcjkbldbjkd", processor.modelId); - assertEquals("bahbkcdkacb", processor.neuralFieldDefaultIdMap.get("fieldName").toString()); + assertEquals("vasdcvkcjkbldbjkd", processor.getModelId()); + assertEquals("bahbkcdkacb", processor.getNeuralFieldDefaultIdMap().get("fieldName").toString()); // Missing "query" parameter: expectThrows( @@ -39,7 +39,7 @@ public void testProcessRequest() throws Exception { assertEquals(processSearchRequest, searchRequest); } - public NeuralQueryProcessor createTestProcessor(NeuralQueryProcessor.Factory factory) throws Exception { + private NeuralQueryProcessor createTestProcessor(NeuralQueryProcessor.Factory factory) throws Exception { Map configMap = new HashMap<>(); configMap.put("default_model_id", "vasdcvkcjkbldbjkd"); configMap.put("neural_field_default_id", Map.of("fieldName", "bahbkcdkacb")); diff --git a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java index 548f651e8..f85be25d5 100644 --- a/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/util/NeuralSearchClusterUtilTests.java @@ -5,8 +5,6 @@ package org.opensearch.neuralsearch.util; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; import static org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils.mockClusterService; import org.opensearch.Version; @@ -36,16 +34,4 @@ public void testMinNodeVersion_whenMultipleNodesCluster_thenSuccess() { assertTrue(Version.V_2_3_0.equals(minVersion)); } - - public void testMinNodeVersion_WhenErrorOnClusterState_thenMatchCurrentVersion() { - ClusterService clusterService = mock(ClusterService.class); - when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); - - final NeuralSearchClusterUtil neuralSearchClusterUtil = NeuralSearchClusterUtil.instance(); - neuralSearchClusterUtil.initialize(clusterService); - - final Version minVersion = neuralSearchClusterUtil.getClusterMinVersion(); - - assertTrue(Version.CURRENT.equals(minVersion)); - } }