Skip to content

Commit

Permalink
Gates efSearch behind 2.15 version to make sure there is backward
Browse files Browse the repository at this point in the history
compatibility during upgrades

Signed-off-by: Tejas Shah <[email protected]>
  • Loading branch information
shatejas committed May 20, 2024
1 parent 5c82673 commit 5065709
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 65 deletions.
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class KNNConstants {
public static final String KNN = "knn";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String EF_SEARCH = "ef_Search";
public static final String TYPE_KNN_VECTOR = "knn_vector";
public static final String METHOD_PARAMETER_EF_SEARCH = "ef_search";
public static final String METHOD_PARAMETER_EF_CONSTRUCTION = "ef_construction";
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/IndexUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public class IndexUtil {
put(MODEL_NODE_ASSIGNMENT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_NODE_ASSIGNMENT);
put(MODEL_METHOD_COMPONENT_CONTEXT_KEY, MINIMAL_SUPPORTED_VERSION_FOR_MODEL_METHOD_COMPONENT_CONTEXT);
put(KNNConstants.RADIAL_SEARCH_KEY, MINIMAL_SUPPORTED_VERSION_FOR_RADIAL_SEARCH);
put(KNNConstants.EF_SEARCH, Version.V_2_15_0);
}
};

Expand Down
33 changes: 11 additions & 22 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,17 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
*/
private final String fieldName;
private final float[] vector;
@Getter
private int k;
@Getter
private Integer efSearch;
@Getter
private Float maxDistance;
@Getter
private Float minScore;
@Getter
private QueryBuilder filter;
@Getter
private boolean ignoreUnmapped;

/**
Expand Down Expand Up @@ -300,7 +305,6 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
vector = in.readFloatArray();
k = in.readInt();
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
efSearch = in.readOptionalInt();
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
ignoreUnmapped = in.readOptionalBoolean();
}
Expand All @@ -310,6 +314,9 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
minScore = in.readOptionalFloat();
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.EF_SEARCH)) {
efSearch = in.readOptionalInt();
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
Expand Down Expand Up @@ -404,7 +411,6 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeFloatArray(vector);
out.writeInt(k);
out.writeOptionalNamedWriteable(filter);
out.writeOptionalInt(efSearch);
if (isClusterOnOrAfterMinRequiredVersion("ignore_unmapped")) {
out.writeOptionalBoolean(ignoreUnmapped);
}
Expand All @@ -414,6 +420,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.RADIAL_SEARCH_KEY)) {
out.writeOptionalFloat(minScore);
}
if (isClusterOnOrAfterMinRequiredVersion(KNNConstants.EF_SEARCH)) {
out.writeOptionalInt(efSearch);
}
}

/**
Expand All @@ -430,26 +439,6 @@ public Object vector() {
return this.vector;
}

public int getK() {
return this.k;
}

public float getMaxDistance() {
return this.maxDistance;
}

public float getMinScore() {
return this.minScore;
}

public QueryBuilder getFilter() {
return this.filter;
}

public boolean getIgnoreUnmapped() {
return this.ignoreUnmapped;
}

@Override
public void doXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject(NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.IndexSettings;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.index.query.*;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.index.Index;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.MethodComponentContext;
Expand Down Expand Up @@ -970,27 +967,38 @@ public void testDoToQuery_InvalidZeroByteVector() {

public void testSerialization() throws Exception {
// For k-NN search
assertSerialization(Version.CURRENT, Optional.empty(), K, null, null);
assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, null, null);
assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null);
assertSerialization(Version.CURRENT, Optional.empty(), K, null, null, null);
assertSerialization(Version.CURRENT, Optional.empty(), K, EF_SEARCH, null, null);
assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), K, EF_SEARCH, null, null);
assertSerialization(Version.V_2_3_0, Optional.empty(), K, EF_SEARCH, null, null);
assertSerialization(Version.V_2_3_0, Optional.empty(), K, null, null, null);

// For distance threshold search
assertSerialization(Version.CURRENT, Optional.empty(), null, MAX_DISTANCE, null);
assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, MAX_DISTANCE, null);
assertSerialization(Version.CURRENT, Optional.empty(), null, null, MAX_DISTANCE, null);
assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MAX_DISTANCE, null);

// For score threshold search
assertSerialization(Version.CURRENT, Optional.empty(), null, null, MIN_SCORE);
assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, MIN_SCORE);
assertSerialization(Version.CURRENT, Optional.empty(), null, null, null, MIN_SCORE);
assertSerialization(Version.CURRENT, Optional.of(TERM_QUERY), null, null, null, MIN_SCORE);
}

private void assertSerialization(
final Version version,
final Optional<QueryBuilder> queryBuilderOptional,
Integer k,
Integer efSearch,
Float distance,
Float score
) throws Exception {
final KNNQueryBuilder knnQueryBuilder = getKnnQueryBuilder(queryBuilderOptional, k, distance, score);
final KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(QUERY_VECTOR)
.maxDistance(distance)
.minScore(score)
.k(k)
.efSearch(efSearch)
.filter(queryBuilderOptional.orElse(null))
.build();

final ClusterService clusterService = mockClusterService(version);

Expand All @@ -1011,6 +1019,12 @@ private void assertSerialization(
assertArrayEquals(QUERY_VECTOR, (float[]) deserializedKnnQueryBuilder.vector(), 0.0f);
if (k != null) {
assertEquals(k.intValue(), deserializedKnnQueryBuilder.getK());
// Verifies efSearch
if (version.onOrAfter(Version.V_2_15_0)) {
assertEquals(efSearch, deserializedKnnQueryBuilder.getEfSearch());
} else {
assertNull(deserializedKnnQueryBuilder.getEfSearch());
}
} else if (distance != null) {
assertEquals(distance.floatValue(), deserializedKnnQueryBuilder.getMaxDistance(), 0.0f);
} else {
Expand All @@ -1026,44 +1040,14 @@ private void assertSerialization(
}
}

private static KNNQueryBuilder getKnnQueryBuilder(Optional<QueryBuilder> queryBuilderOptional, Integer k, Float distance, Float score) {
final KNNQueryBuilder knnQueryBuilder;
if (k != null) {
knnQueryBuilder = queryBuilderOptional.isPresent()
? new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k, queryBuilderOptional.get())
: new KNNQueryBuilder(FIELD_NAME, QUERY_VECTOR, k);
} else if (distance != null) {
knnQueryBuilder = queryBuilderOptional.isPresent()
? KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(QUERY_VECTOR)
.maxDistance(distance)
.filter(queryBuilderOptional.get())
.build()
: KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).maxDistance(distance).build();
} else if (score != null) {
knnQueryBuilder = queryBuilderOptional.isPresent()
? KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(QUERY_VECTOR)
.minScore(score)
.filter(queryBuilderOptional.get())
.build()
: KNNQueryBuilder.builder().fieldName(FIELD_NAME).vector(QUERY_VECTOR).minScore(score).build();
} else {
throw new IllegalArgumentException("Either k or distance must be provided");
}
return knnQueryBuilder;
}

public void testIgnoreUnmapped() throws IOException {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder.Builder knnQueryBuilder = KNNQueryBuilder.builder()
.fieldName(FIELD_NAME)
.vector(queryVector)
.k(K)
.ignoreUnmapped(true);
assertTrue(knnQueryBuilder.build().getIgnoreUnmapped());
assertTrue(knnQueryBuilder.build().isIgnoreUnmapped());
Query query = knnQueryBuilder.build().doToQuery(mock(QueryShardContext.class));
assertNotNull(query);
assertThat(query, instanceOf(MatchNoDocsQuery.class));
Expand Down

0 comments on commit 5065709

Please sign in to comment.