diff --git a/.idea/runConfigurations/Debug_OpenSearch.xml b/.idea/runConfigurations/Debug_OpenSearch.xml
index 0d8bf59823acf..c18046f873477 100644
--- a/.idea/runConfigurations/Debug_OpenSearch.xml
+++ b/.idea/runConfigurations/Debug_OpenSearch.xml
@@ -6,6 +6,10 @@
+
+
+
+
-
+
\ No newline at end of file
diff --git a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java
index 91d96b3e2c2f7..90a188c9d6743 100644
--- a/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java
+++ b/server/src/main/java/org/opensearch/index/mapper/DateFieldMapper.java
@@ -38,7 +38,6 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.BoostQuery;
-import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSortSortedNumericDocValuesRangeQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
@@ -62,8 +61,8 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
+import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery;
import org.opensearch.search.approximate.ApproximatePointRangeQuery;
-import org.opensearch.search.approximate.ApproximateScoreQuery;
import org.opensearch.search.lookup.SearchLookup;
import java.io.IOException;
@@ -470,43 +469,42 @@ public Query rangeQuery(
Query pointRangeQuery = isSearchable() ? createPointRangeQuery(l, u) : null;
Query dvQuery = hasDocValues() ? SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u) : null;
if (isSearchable() && hasDocValues()) {
- Query query = new IndexOrDocValuesQuery(pointRangeQuery, dvQuery);
-
+ Query query = new ApproximateIndexOrDocValuesQuery(
+ pointRangeQuery,
+ new ApproximatePointRangeQuery(
+ name(),
+ pack(new long[] { l }).bytes,
+ pack(new long[] { u }).bytes,
+ new long[] { l }.length
+ ) {
+ @Override
+ protected String toString(int dimension, byte[] value) {
+ return Long.toString(LongPoint.decodeDimension(value, 0));
+ }
+ },
+ dvQuery
+ );
if (context.indexSortedOnField(name())) {
query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query);
}
return query;
}
if (hasDocValues()) {
- Query query = SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u);
if (context.indexSortedOnField(name())) {
- query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query);
+ dvQuery = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, dvQuery);
}
- return query;
+ return dvQuery;
}
return pointRangeQuery;
});
}
private Query createPointRangeQuery(long l, long u) {
- return new ApproximateScoreQuery(
- new PointRangeQuery(name(), pack(new long[] { l }).bytes, pack(new long[] { u }).bytes, new long[] { l }.length) {
- protected String toString(int dimension, byte[] value) {
- return Long.toString(LongPoint.decodeDimension(value, 0));
- }
- },
- new ApproximatePointRangeQuery(
- name(),
- pack(new long[] { l }).bytes,
- pack(new long[] { u }).bytes,
- new long[] { l }.length
- ) {
- @Override
- protected String toString(int dimension, byte[] value) {
- return Long.toString(LongPoint.decodeDimension(value, 0));
- }
+ return new PointRangeQuery(name(), pack(new long[] { l }).bytes, pack(new long[] { u }).bytes, new long[] { l }.length) {
+ protected String toString(int dimension, byte[] value) {
+ return Long.toString(LongPoint.decodeDimension(value, 0));
}
- );
+ };
}
public static Query dateRangeQuery(
diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java
new file mode 100644
index 0000000000000..f646216054dc4
--- /dev/null
+++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateIndexOrDocValuesQuery.java
@@ -0,0 +1,69 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.search.approximate;
+
+import org.apache.lucene.search.IndexOrDocValuesQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Weight;
+
+import java.io.IOException;
+
+public final class ApproximateIndexOrDocValuesQuery extends ApproximateScoreQuery {
+
+ private final ApproximateableQuery approximateIndexQuery;
+ private final IndexOrDocValuesQuery indexOrDocValuesQuery;
+
+ public ApproximateIndexOrDocValuesQuery(Query indexQuery, ApproximateableQuery approximateIndexQuery, Query dvQuery) {
+ super(new IndexOrDocValuesQuery(indexQuery, dvQuery), approximateIndexQuery);
+ this.approximateIndexQuery = approximateIndexQuery;
+ this.indexOrDocValuesQuery = new IndexOrDocValuesQuery(indexQuery, dvQuery);
+ }
+
+ @Override
+ public String toString(String field) {
+ return "ApproximateIndexOrDocValuesQuery(indexQuery="
+ + indexOrDocValuesQuery.getIndexQuery().toString(field)
+ + ", approximateIndexQuery="
+ + approximateIndexQuery.toString(field)
+ + ", dvQuery="
+ + indexOrDocValuesQuery.getRandomAccessQuery().toString(field)
+ + ")";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {
+ indexOrDocValuesQuery.visit(visitor);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (sameClassAs(obj) == false) {
+ return false;
+ }
+ ApproximateIndexOrDocValuesQuery that = (ApproximateIndexOrDocValuesQuery) obj;
+ return indexOrDocValuesQuery.getIndexQuery().equals(that.indexOrDocValuesQuery.getIndexQuery())
+ && indexOrDocValuesQuery.getRandomAccessQuery().equals(that.indexOrDocValuesQuery.getRandomAccessQuery());
+ }
+
+ @Override
+ public int hashCode() {
+ return indexOrDocValuesQuery.hashCode();
+ }
+
+ @Override
+ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+ if (approximateIndexQuery.canApproximate(this.getContext())) {
+ return approximateIndexQuery.createWeight(searcher, scoreMode, boost);
+ }
+ return indexOrDocValuesQuery.createWeight(searcher, scoreMode, boost);
+ }
+}
diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java
index 3932d2d04767b..2b7b4f576b240 100644
--- a/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java
+++ b/server/src/main/java/org/opensearch/search/approximate/ApproximatePointRangeQuery.java
@@ -14,7 +14,6 @@
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.QueryVisitor;
@@ -430,21 +429,17 @@ public boolean canApproximate(SearchContext context) {
if (context.aggregations() != null) {
return false;
}
- if (!(context.query() instanceof IndexOrDocValuesQuery
- && ((IndexOrDocValuesQuery) context.query()).getIndexQuery() instanceof ApproximateScoreQuery
- && ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery())
- .getOriginalQuery() instanceof PointRangeQuery)) {
+ if (!(context.query() instanceof ApproximateIndexOrDocValuesQuery)) {
return false;
}
- ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery());
- ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(Math.max(context.size(), context.trackTotalHitsUpTo()));
+ this.setSize(Math.max(context.size(), context.trackTotalHitsUpTo()));
if (context.request() != null && context.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source());
if (primarySortField != null
&& primarySortField.missing() == null
&& primarySortField.getFieldName().equals(((RangeQueryBuilder) context.request().source().query()).fieldName())) {
if (primarySortField.order() == SortOrder.DESC) {
- ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC);
+ this.setSortOrder(SortOrder.DESC);
}
}
}
diff --git a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java
index 1673d345ac1bf..bfe6fc2c59c40 100644
--- a/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java
+++ b/server/src/main/java/org/opensearch/search/approximate/ApproximateScoreQuery.java
@@ -8,16 +8,11 @@
package org.opensearch.search.approximate;
-import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
-import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
-import org.apache.lucene.search.Matches;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Scorer;
-import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.search.internal.SearchContext;
@@ -28,29 +23,16 @@
*
* This class is heavily inspired by {@link org.apache.lucene.search.IndexOrDocValuesQuery}. It acts as a wrapper that consumer two queries, a regular query and an approximate version of the same. By default, it executes the regular query and returns {@link Weight#scorer} for the original query. At run-time, depending on certain constraints, we can re-write the {@code Weight} to use the approximate weight instead.
*/
-public final class ApproximateScoreQuery extends Query {
+public class ApproximateScoreQuery extends Query {
private final Query originalQuery;
private final ApproximateableQuery approximationQuery;
- private Weight originalQueryWeight, approximationQueryWeight;
-
private SearchContext context;
public ApproximateScoreQuery(Query originalQuery, ApproximateableQuery approximationQuery) {
- this(originalQuery, approximationQuery, null, null);
- }
-
- public ApproximateScoreQuery(
- Query originalQuery,
- ApproximateableQuery approximationQuery,
- Weight originalQueryWeight,
- Weight approximationQueryWeight
- ) {
this.originalQuery = originalQuery;
this.approximationQuery = approximationQuery;
- this.originalQueryWeight = originalQueryWeight;
- this.approximationQueryWeight = approximationQueryWeight;
}
public Query getOriginalQuery() {
@@ -63,72 +45,20 @@ public ApproximateableQuery getApproximationQuery() {
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
- originalQueryWeight = originalQuery.createWeight(searcher, scoreMode, boost);
- approximationQueryWeight = approximationQuery.createWeight(searcher, scoreMode, boost);
-
- return new Weight(this) {
- @Override
- public Explanation explain(LeafReaderContext leafReaderContext, int doc) throws IOException {
- return originalQueryWeight.explain(leafReaderContext, doc);
- }
-
- @Override
- public Matches matches(LeafReaderContext leafReaderContext, int doc) throws IOException {
- return originalQueryWeight.matches(leafReaderContext, doc);
- }
-
- @Override
- public ScorerSupplier scorerSupplier(LeafReaderContext leafReaderContext) throws IOException {
- final ScorerSupplier originalQueryScoreSupplier = originalQueryWeight.scorerSupplier(leafReaderContext);
- final ScorerSupplier approximationQueryScoreSupplier = approximationQueryWeight.scorerSupplier(leafReaderContext);
- if (originalQueryScoreSupplier == null || approximationQueryScoreSupplier == null) {
- return null;
- }
-
- return new ScorerSupplier() {
- @Override
- public Scorer get(long l) throws IOException {
- if (approximationQuery.canApproximate(context)) {
- return approximationQueryScoreSupplier.get(l);
- }
- return originalQueryScoreSupplier.get(l);
- }
-
- @Override
- public long cost() {
- return originalQueryScoreSupplier.cost();
- }
- };
- }
-
- @Override
- public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
- ScorerSupplier scorerSupplier = scorerSupplier(leafReaderContext);
- if (scorerSupplier == null) {
- return null;
- }
- return scorerSupplier.get(Long.MAX_VALUE);
- }
-
- @Override
- public boolean isCacheable(LeafReaderContext leafReaderContext) {
- return originalQueryWeight.isCacheable(leafReaderContext);
- }
-
- @Override
- public int count(LeafReaderContext leafReaderContext) throws IOException {
- if (approximationQuery.canApproximate(context)) {
- return approximationQueryWeight.count(leafReaderContext);
- }
- return originalQueryWeight.count(leafReaderContext);
- }
- };
+ if (approximationQuery.canApproximate(context)) {
+ return approximationQuery.createWeight(searcher, scoreMode, boost);
+ }
+ return originalQuery.createWeight(searcher, scoreMode, boost);
}
public void setContext(SearchContext context) {
this.context = context;
};
+ public SearchContext getContext() {
+ return context;
+ };
+
@Override
public String toString(String s) {
return "ApproximateScoreQuery(originalQuery="
diff --git a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java
index 023cdd8320a1f..6942b3dedabe6 100644
--- a/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java
+++ b/server/src/main/java/org/opensearch/search/internal/ContextIndexSearcher.java
@@ -46,7 +46,6 @@
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
@@ -220,6 +219,9 @@ public Weight createWeight(Query query, ScoreMode scoreMode, float boost) throws
profiler.pollLastElement();
}
return new ProfileWeight(query, weight, profile);
+ } else if (query instanceof ApproximateScoreQuery) {
+ ((ApproximateScoreQuery) query).setContext(searchContext);
+ return query.createWeight(this, scoreMode, boost);
} else {
return super.createWeight(query, scoreMode, boost);
}
@@ -329,11 +331,6 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto
// catch early terminated exception and rethrow?
Bits liveDocs = ctx.reader().getLiveDocs();
BitSet liveDocsBitSet = getSparseBitSetOrNull(liveDocs);
- if (searchContext.query() instanceof IndexOrDocValuesQuery
- && ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery() instanceof ApproximateScoreQuery) {
- ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery());
- query.setContext(searchContext);
- }
if (liveDocsBitSet == null) {
BulkScorer bulkScorer = weight.bulkScorer(ctx);
if (bulkScorer != null) {
diff --git a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java
index 6460d69aea4d4..bb3a1ebaa697d 100644
--- a/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java
+++ b/server/src/test/java/org/opensearch/search/approximate/ApproximatePointRangeQueryTests.java
@@ -21,12 +21,15 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.test.OpenSearchTestCase;
import java.io.IOException;
import static org.apache.lucene.document.LongPoint.pack;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
public class ApproximatePointRangeQueryTests extends OpenSearchTestCase {
@@ -303,4 +306,11 @@ protected String toString(int dimension, byte[] value) {
}
}
}
+
+ public void testCanApproximate() throws IOException {
+ SearchContext context = mock(SearchContext.class);
+ ApproximatePointRangeQuery approximatePointRangeQuery = mock(ApproximatePointRangeQuery.class);
+ when(approximatePointRangeQuery.canApproximate(null)).thenReturn(false);
+ when(approximatePointRangeQuery.canApproximate(context)).thenReturn(false);
+ }
}