diff --git a/CHANGELOG.md b/CHANGELOG.md index f57bc71fe..8291a4953 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317) * Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367) * Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403) +* Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder [#1397](https://github.com/opensearch-project/k-NN/pull/1397) ### Infrastructure * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) * Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307) diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java index 74a289994..9c78b18a1 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQuery.java @@ -5,6 +5,8 @@ package org.opensearch.knn.index.query; +import java.util.Arrays; +import java.util.Objects; import lombok.Getter; import lombok.Setter; import org.apache.lucene.search.BooleanClause; @@ -127,7 +129,7 @@ public String toString(String field) { @Override public int hashCode() { - return field.hashCode() ^ queryVector.hashCode() ^ k; + return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery); } @Override @@ -136,6 +138,10 @@ public boolean equals(Object other) { } private boolean equalsTo(KNNQuery other) { - return this.field.equals(other.getField()) && this.queryVector.equals(other.getQueryVector()) && this.k == other.getK(); + return Objects.equals(field, other.field) + && Arrays.equals(queryVector, other.queryVector) + && Objects.equals(k, other.k) + && Objects.equals(indexName, other.indexName) + && Objects.equals(filterQuery, other.filterQuery); } -}; +} diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 096c2e30b..eb43f67f7 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.query; +import java.util.Arrays; import lombok.extern.log4j.Log4j2; import org.apache.lucene.search.MatchNoDocsQuery; import org.opensearch.core.common.Strings; @@ -46,7 +47,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder { public static final ParseField K_FIELD = new ParseField("k"); public static final ParseField FILTER_FIELD = new ParseField("filter"); public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped"); - public static int K_MAX = 10000; + public static final int K_MAX = 10000; /** * The name for the knn query */ @@ -346,12 +347,16 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie @Override protected boolean doEquals(KNNQueryBuilder other) { - return Objects.equals(fieldName, other.fieldName) && Objects.equals(vector, other.vector) && Objects.equals(k, other.k); + return Objects.equals(fieldName, other.fieldName) + && Arrays.equals(vector, other.vector) + && Objects.equals(k, other.k) + && Objects.equals(filter, other.filter) + && Objects.equals(ignoreUnmapped, other.ignoreUnmapped); } @Override protected int doHashCode() { - return Objects.hash(fieldName, vector, k); + return Objects.hash(fieldName, Arrays.hashCode(vector), k, filter, ignoreUnmapped); } @Override diff --git a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java index a981c684e..3ea469ada 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNQueryBuilderTests.java @@ -90,7 +90,7 @@ public void testEmptyVector() { expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K)); } - public void testFromXcontent() throws Exception { + public void testFromXContent() throws Exception { float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f }; KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K); XContentBuilder builder = XContentFactory.jsonBuilder(); @@ -103,10 +103,10 @@ public void testFromXcontent() throws Exception { XContentParser contentParser = createParser(builder); contentParser.nextToken(); KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - actualBuilder.equals(knnQueryBuilder); + assertEquals(knnQueryBuilder, actualBuilder); } - public void testFromXcontent_WithFilter() throws Exception { + public void testFromXContent_WithFilter() throws Exception { final ClusterService clusterService = mockClusterService(Version.CURRENT); final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance(); @@ -125,7 +125,7 @@ public void testFromXcontent_WithFilter() throws Exception { XContentParser contentParser = createParser(builder); contentParser.nextToken(); KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser); - actualBuilder.equals(knnQueryBuilder); + assertEquals(knnQueryBuilder, actualBuilder); } public void testFromXContent_invalidQueryVectorType() throws Exception {