From 96b49ed879ca4138f9fdc8bd55e00f092dee578f Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Mon, 26 Feb 2024 16:10:17 +0800 Subject: [PATCH] feature: implement default model id for neural sparse Signed-off-by: zhichao-aws --- .../query/NeuralSparseQueryBuilder.java | 34 ++++++++++++++----- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 48c722011..6e87f4c38 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -18,6 +18,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Query; +import org.opensearch.Version; import org.opensearch.common.SetOnce; import org.opensearch.core.ParseField; import org.opensearch.core.action.ActionListener; @@ -32,6 +33,7 @@ import org.opensearch.index.query.QueryRewriteContext; import org.opensearch.index.query.QueryShardContext; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import com.google.common.annotations.VisibleForTesting; @@ -72,6 +74,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String queryText; private String modelId; private Supplier> queryTokensSupplier; + private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0; /** * Constructor from stream input @@ -83,7 +86,11 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { super(in); this.fieldName = in.readString(); this.queryText = in.readString(); - this.modelId = in.readString(); + if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + this.modelId = in.readOptionalString(); + } else { + this.modelId = in.readString(); + } if (in.readBoolean()) { Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); this.queryTokensSupplier = () -> queryTokens; @@ -94,7 +101,11 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(fieldName); out.writeString(queryText); - out.writeString(modelId); + if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + out.writeOptionalString(this.modelId); + } else { + out.writeString(this.modelId); + } if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) { out.writeBoolean(true); out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); @@ -108,7 +119,9 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws xContentBuilder.startObject(NAME); xContentBuilder.startObject(fieldName); xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText); - xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + if (modelId != null) { + xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId); + } printBoostAndQueryName(xContentBuilder); xContentBuilder.endObject(); xContentBuilder.endObject(); @@ -152,11 +165,12 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw sparseEncodingQueryBuilder.queryText(), String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME) ); - requireValue( - sparseEncodingQueryBuilder.modelId(), - String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) - ); - + if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { + requireValue( + sparseEncodingQueryBuilder.modelId(), + String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME) + ); + } return sparseEncodingQueryBuilder; } @@ -291,4 +305,8 @@ protected int doHashCode() { public String getWriteableName() { return NAME; } + + private static boolean isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport() { + return NeuralSearchClusterUtil.instance().getClusterMinVersion().onOrAfter(MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID); + } }