Skip to content

Commit

Permalink
feature: implement default model id for neural sparse
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Feb 28, 2024
1 parent 38eda5f commit 9b25049
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.query;

/**
* Query builders which calls ml-commons API to do model inference.
* The model inference result is used for search on target field.
*/

public interface ModelInferenceQueryBuilder {
/**
* Get the model id used by ml-commons model inference. Return null if the model id is absent.
*/
public String modelId();

/**
* Set a new model id for the query builder.
*/
public ModelInferenceQueryBuilder modelId(String modelId);

/**
* Get the field name for search.
*/
public String fieldName();
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> {
public class NeuralQueryBuilder extends AbstractQueryBuilder<NeuralQueryBuilder> implements ModelInferenceQueryBuilder {

public static final String NAME = "neural";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder> {
public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQueryBuilder> implements ModelInferenceQueryBuilder {
public static final String NAME = "neural_sparse";
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@
import org.apache.lucene.search.BooleanClause;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilderVisitor;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.ModelInferenceQueryBuilder;

import lombok.AllArgsConstructor;

/**
* Neural Search Query Visitor. It visits each and every component of query buikder tree.
* Neural Search Query Visitor. It visits each and every component of query builder tree.
*/
@AllArgsConstructor
public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {
Expand All @@ -28,16 +28,16 @@ public class NeuralSearchQueryVisitor implements QueryBuilderVisitor {
*/
@Override
public void accept(QueryBuilder queryBuilder) {
if (queryBuilder instanceof NeuralQueryBuilder) {
NeuralQueryBuilder neuralQueryBuilder = (NeuralQueryBuilder) queryBuilder;
if (neuralQueryBuilder.modelId() == null) {
if (queryBuilder instanceof ModelInferenceQueryBuilder) {
ModelInferenceQueryBuilder modelInferenceQueryBuilder = (ModelInferenceQueryBuilder) queryBuilder;
if (modelInferenceQueryBuilder.modelId() == null) {
if (neuralFieldMap != null
&& neuralQueryBuilder.fieldName() != null
&& neuralFieldMap.get(neuralQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(neuralQueryBuilder.fieldName());
neuralQueryBuilder.modelId(fieldDefaultModelId);
&& modelInferenceQueryBuilder.fieldName() != null
&& neuralFieldMap.get(modelInferenceQueryBuilder.fieldName()) != null) {
String fieldDefaultModelId = (String) neuralFieldMap.get(modelInferenceQueryBuilder.fieldName());
modelInferenceQueryBuilder.modelId(fieldDefaultModelId);
} else if (modelId != null) {
neuralQueryBuilder.modelId(modelId);
modelInferenceQueryBuilder.modelId(modelId);
} else {
throw new IllegalArgumentException(
"model id must be provided in neural query or a default model id must be set in search request processor"
Expand Down

0 comments on commit 9b25049

Please sign in to comment.