diff --git a/src/main/java/org/opensearch/neuralsearch/query/ModelInferenceQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/ModelInferenceQueryBuilder.java new file mode 100644 index 000000000..a1001c455 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/query/ModelInferenceQueryBuilder.java @@ -0,0 +1,27 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +/** + * Query builders which calls ml-commons API to do model inference. + * The model inference result is used for search on target field. + */ + +public interface ModelInferenceQueryBuilder { + /** + * Get the model id used by ml-commons model inference. Return null if the model id is absent. + */ + public String modelId(); + + /** + * Set a new model id for the query builder. + */ + public ModelInferenceQueryBuilder modelId(String modelId); + + /** + * Get the field name for search. + */ + public String fieldName(); +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index cda01767e..d74378617 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -58,7 +58,7 @@ @Accessors(chain = true, fluent = true) @NoArgsConstructor @AllArgsConstructor -public class NeuralQueryBuilder extends AbstractQueryBuilder { +public class NeuralQueryBuilder extends AbstractQueryBuilder implements ModelInferenceQueryBuilder { public static final String NAME = "neural"; diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 226594a87..48c722011 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -55,7 +55,7 @@ @Accessors(chain = true, fluent = true) @NoArgsConstructor @AllArgsConstructor -public class NeuralSparseQueryBuilder extends AbstractQueryBuilder { +public class NeuralSparseQueryBuilder extends AbstractQueryBuilder implements ModelInferenceQueryBuilder { public static final String NAME = "neural_sparse"; @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); diff --git a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java index 9dab0a695..6fd4d0708 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java +++ b/src/main/java/org/opensearch/neuralsearch/query/visitor/NeuralSearchQueryVisitor.java @@ -9,12 +9,12 @@ import org.apache.lucene.search.BooleanClause; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilderVisitor; -import org.opensearch.neuralsearch.query.NeuralQueryBuilder; +import org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder; import lombok.AllArgsConstructor; /** - * Neural Search Query Visitor. It visits each and every component of query buikder tree. + * Neural Search Query Visitor. It visits each and every component of query builder tree. */ @AllArgsConstructor public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { @@ -28,16 +28,16 @@ public class NeuralSearchQueryVisitor implements QueryBuilderVisitor { */ @Override public void accept(QueryBuilder queryBuilder) { - if (queryBuilder instanceof NeuralQueryBuilder) { - NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder; - if (neuralQueryBuilder.modelId() == null) { + if (queryBuilder instanceof ModelInferenceQueryBuilder) { + ModelInferenceQueryBuilder modelInferenceQueryBuilder = (ModelInferenceQueryBuilder) queryBuilder; + if (modelInferenceQueryBuilder.modelId() == null) { if (neuralFieldMap != null - && neuralQueryBuilder.fieldName() != null - && neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) { - String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName()); - neuralQueryBuilder.modelId(fieldDefaultModelId); + && modelInferenceQueryBuilder.fieldName() != null + && neuralFieldMap.get(modelInferenceQueryBuilder.fieldName()) != null) { + String fieldDefaultModelId = (String) neuralFieldMap.get(modelInferenceQueryBuilder.fieldName()); + modelInferenceQueryBuilder.modelId(fieldDefaultModelId); } else if (modelId != null) { - neuralQueryBuilder.modelId(modelId); + modelInferenceQueryBuilder.modelId(modelId); } else { throw new IllegalArgumentException( "model id must be provided in neural query or a default model id must be set in search request processor"