diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index d284015af..99d683a2e 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -24,6 +24,7 @@ public class KNNConstants { public static final String KNN = "knn"; public static final String VECTOR = "vector"; public static final String K = "k"; + public static final String EF_SEARCH = "ef_Search"; public static final String TYPE_KNN_VECTOR = "knn_vector"; public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search"; public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction"; diff --git a/src/main/java/org/opensearch/knn/index/IndexUtil.java b/src/main/java/org/opensearch/knn/index/IndexUtil.java index 208da2ea9..145fb6edd 100644 --- a/src/main/java/org/opensearch/knn/index/IndexUtil.java +++ b/src/main/java/org/opensearch/knn/index/IndexUtil.java @@ -51,6 +51,7 @@ public class IndexUtil { put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT); put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT); put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH); + put(KNNConstants.EF_SEARCH, Version.V_2_15_0); } }; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 0a1f2d513..57a338e1e 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -71,12 +71,17 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { */ private final String fieldName; private final float[] vector; + @Getter private int k; @Getter private Integer efSearch; + @Getter private Float maxDistance; + @Getter private Float minScore; + @Getter private QueryBuilder filter; + @Getter private boolean ignoreUnmapped; /** @@ -300,7 +305,6 @@ public KNNQueryBuilder(StreamInput in) throws IOException { vector = in.readFloatArray(); k = in.readInt(); filter = in.readOptionalNamedWriteable(QueryBuilder.class); - efSearch = in.readOptionalInt(); if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { ignoreUnmapped = in.readOptionalBoolean(); } @@ -310,6 +314,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { minScore = in.readOptionalFloat(); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.EF_SEARCH)) { + efSearch = in.readOptionalInt(); + } } catch (IOException ex) { throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex); } @@ -404,7 +411,6 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeFloatArray(vector); out.writeInt(k); out.writeOptionalNamedWriteable(filter); - out.writeOptionalInt(efSearch); if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) { out.writeOptionalBoolean(ignoreUnmapped); } @@ -414,6 +420,9 @@ protected void doWriteTo(StreamOutput out) throws IOException { if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) { out.writeOptionalFloat(minScore); } + if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.EF_SEARCH)) { + out.writeOptionalInt(efSearch); + } } /** @@ -430,26 +439,6 @@ public Object vector() { return this.vector; } - public int getK() { - return this.k; - } - - public float getMaxDistance() { - return this.maxDistance; - } - - public float getMinScore() { - return this.minScore; - } - - public QueryBuilder getFilter() { - return this.filter; - } - - public boolean getIgnoreUnmapped() { - return this.ignoreUnmapped; - } - @Override public void doXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(NAME); diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index 3d1c0be8f..d20f3e6bb 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -20,16 +20,13 @@ import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.index.IndexSettings; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.index.query.QueryBuilders; -import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.index.query.*; import org.opensearch.knn.KNNTestCase; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentParser; import org.opensearch.core.index.Index; import org.opensearch.index.mapper.NumberFieldMapper; -import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.KNNClusterUtil; import org.opensearch.knn.index.KNNMethodContext; import org.opensearch.knn.index.MethodComponentContext; @@ -970,27 +967,38 @@ public void testDoToQuery_InvalidZeroByteVector() { public void testSerialization() throws Exception { // For k-NN search - assertSerialization(Version.CURRENT, Optional.empty(), K, null, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null); - assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null); + assertSerialization(Version.CURRENT, Optional.empty(), K, EF_SEARCH, null, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, EF_SEARCH, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, EF_SEARCH, null, null); + assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null); // For distance threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, MAX_DISTANCE, null); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MAX_DISTANCE, null); // For score threshold search - assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE); - assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE); + assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE); } private void assertSerialization( final Version version, final Optional queryBuilderOptional, Integer k, + Integer efSearch, Float distance, Float score ) throws Exception { - final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score); + final KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(FIELD_NAME) + .vector(QUERY_VECTOR) + .maxDistance(distance) + .minScore(score) + .k(k) + .efSearch(efSearch) + .filter(queryBuilderOptional.orElse(null)) + .build(); final ClusterService clusterService = mockClusterService(version); @@ -1011,6 +1019,12 @@ private void assertSerialization( assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f); if (k != null) { assertEquals(k.intValue(), deserializedKnnQueryBuilder.getK()); + // Verifies efSearch + if (version.onOrAfter(Version.V_2_15_0)) { + assertEquals(efSearch, deserializedKnnQueryBuilder.getEfSearch()); + } else { + assertNull(deserializedKnnQueryBuilder.getEfSearch()); + } } else if (distance != null) { assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getMaxDistance(), 0.0f); } else { @@ -1026,36 +1040,6 @@ private void assertSerialization( } } - private static KNNQueryBuilder getKnnQueryBuilder(Optional queryBuilderOptional, Integer k, Float distance, Float score) { - final KNNQueryBuilder knnQueryBuilder; - if (k != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k, queryBuilderOptional.get()) - : new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k); - } else if (distance != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(QUERY_VECTOR) - .maxDistance(distance) - .filter(queryBuilderOptional.get()) - .build() - : KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).maxDistance(distance).build(); - } else if (score != null) { - knnQueryBuilder = queryBuilderOptional.isPresent() - ? KNNQueryBuilder.builder() - .fieldName(FIELD_NAME) - .vector(QUERY_VECTOR) - .minScore(score) - .filter(queryBuilderOptional.get()) - .build() - : KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(score).build(); - } else { - throw new IllegalArgumentException("Either k or distance must be provided"); - } - return knnQueryBuilder; - } - public void testIgnoreUnmapped() throws IOException { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder.Builder knnQueryBuilder = KNNQueryBuilder.builder() @@ -1063,7 +1047,7 @@ public void testIgnoreUnmapped() throws IOException { .vector(queryVector) .k(K) .ignoreUnmapped(true); - assertTrue(knnQueryBuilder.build().getIgnoreUnmapped()); + assertTrue(knnQueryBuilder.build().isIgnoreUnmapped()); Query query = knnQueryBuilder.build().doToQuery(mock(QueryShardContext.class)); assertNotNull(query); assertThat(query, instanceOf(MatchNoDocsQuery.class));