Skip to content

Commit

Permalink
iter
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Aug 27, 2024
1 parent ff8a2fd commit b1d5dd6
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 118 deletions.
6 changes: 5 additions & 1 deletion .idea/runConfigurations/Debug_OpenSearch.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.PointValues;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSortSortedNumericDocValuesRangeQuery;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.Query;
Expand All @@ -62,8 +61,8 @@
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.approximate.ApproximateIndexOrDocValuesQuery;
import org.opensearch.search.approximate.ApproximatePointRangeQuery;
import org.opensearch.search.approximate.ApproximateScoreQuery;
import org.opensearch.search.lookup.SearchLookup;

import java.io.IOException;
Expand Down Expand Up @@ -470,43 +469,42 @@ public Query rangeQuery(
Query pointRangeQuery = isSearchable() ? createPointRangeQuery(l, u) : null;
Query dvQuery = hasDocValues() ? SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u) : null;
if (isSearchable() && hasDocValues()) {
Query query = new IndexOrDocValuesQuery(pointRangeQuery, dvQuery);

Query query = new ApproximateIndexOrDocValuesQuery(
pointRangeQuery,
new ApproximatePointRangeQuery(
name(),
pack(new long[] { l }).bytes,
pack(new long[] { u }).bytes,
new long[] { l }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
},
dvQuery
);
if (context.indexSortedOnField(name())) {
query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query);
}
return query;
}
if (hasDocValues()) {
Query query = SortedNumericDocValuesField.newSlowRangeQuery(name(), l, u);
if (context.indexSortedOnField(name())) {
query = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, query);
dvQuery = new IndexSortSortedNumericDocValuesRangeQuery(name(), l, u, dvQuery);
}
return query;
return dvQuery;
}
return pointRangeQuery;
});
}

private Query createPointRangeQuery(long l, long u) {
return new ApproximateScoreQuery(
new PointRangeQuery(name(), pack(new long[] { l }).bytes, pack(new long[] { u }).bytes, new long[] { l }.length) {
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
},
new ApproximatePointRangeQuery(
name(),
pack(new long[] { l }).bytes,
pack(new long[] { u }).bytes,
new long[] { l }.length
) {
@Override
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
return new PointRangeQuery(name(), pack(new long[] { l }).bytes, pack(new long[] { u }).bytes, new long[] { l }.length) {
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
);
};
}

public static Query dateRangeQuery(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.approximate;

import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

import java.io.IOException;

public final class ApproximateIndexOrDocValuesQuery extends ApproximateScoreQuery {

private final ApproximateableQuery approximateIndexQuery;
private final IndexOrDocValuesQuery indexOrDocValuesQuery;

public ApproximateIndexOrDocValuesQuery(Query indexQuery, ApproximateableQuery approximateIndexQuery, Query dvQuery) {
super(new IndexOrDocValuesQuery(indexQuery, dvQuery), approximateIndexQuery);
this.approximateIndexQuery = approximateIndexQuery;
this.indexOrDocValuesQuery = new IndexOrDocValuesQuery(indexQuery, dvQuery);
}

@Override
public String toString(String field) {
return "ApproximateIndexOrDocValuesQuery(indexQuery="
+ indexOrDocValuesQuery.getIndexQuery().toString(field)
+ ", approximateIndexQuery="
+ approximateIndexQuery.toString(field)
+ ", dvQuery="
+ indexOrDocValuesQuery.getRandomAccessQuery().toString(field)
+ ")";
}

@Override
public void visit(QueryVisitor visitor) {
indexOrDocValuesQuery.visit(visitor);
}

@Override
public boolean equals(Object obj) {
if (sameClassAs(obj) == false) {
return false;
}
ApproximateIndexOrDocValuesQuery that = (ApproximateIndexOrDocValuesQuery) obj;
return indexOrDocValuesQuery.getIndexQuery().equals(that.indexOrDocValuesQuery.getIndexQuery())
&& indexOrDocValuesQuery.getRandomAccessQuery().equals(that.indexOrDocValuesQuery.getRandomAccessQuery());
}

@Override
public int hashCode() {
return indexOrDocValuesQuery.hashCode();
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
if (approximateIndexQuery.canApproximate(this.getContext())) {
return approximateIndexQuery.createWeight(searcher, scoreMode, boost);
}
return indexOrDocValuesQuery.createWeight(searcher, scoreMode, boost);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.PointRangeQuery;
import org.apache.lucene.search.QueryVisitor;
Expand Down Expand Up @@ -430,21 +429,17 @@ public boolean canApproximate(SearchContext context) {
if (context.aggregations() != null) {
return false;
}
if (!(context.query() instanceof IndexOrDocValuesQuery
&& ((IndexOrDocValuesQuery) context.query()).getIndexQuery() instanceof ApproximateScoreQuery
&& ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery())
.getOriginalQuery() instanceof PointRangeQuery)) {
if (!(context.query() instanceof ApproximateIndexOrDocValuesQuery)) {
return false;
}
ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) context.query()).getIndexQuery());
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(Math.max(context.size(), context.trackTotalHitsUpTo()));
this.setSize(Math.max(context.size(), context.trackTotalHitsUpTo()));
if (context.request() != null && context.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(context.request().source());
if (primarySortField != null
&& primarySortField.missing() == null
&& primarySortField.getFieldName().equals(((RangeQueryBuilder) context.request().source().query()).fieldName())) {
if (primarySortField.order() == SortOrder.DESC) {
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC);
this.setSortOrder(SortOrder.DESC);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,11 @@

package org.opensearch.search.approximate;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Matches;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Weight;
import org.opensearch.search.internal.SearchContext;

Expand All @@ -28,29 +23,16 @@
*
* This class is heavily inspired by {@link org.apache.lucene.search.IndexOrDocValuesQuery}. It acts as a wrapper that consumer two queries, a regular query and an approximate version of the same. By default, it executes the regular query and returns {@link Weight#scorer} for the original query. At run-time, depending on certain constraints, we can re-write the {@code Weight} to use the approximate weight instead.
*/
public final class ApproximateScoreQuery extends Query {
public class ApproximateScoreQuery extends Query {

private final Query originalQuery;
private final ApproximateableQuery approximationQuery;

private Weight originalQueryWeight, approximationQueryWeight;

private SearchContext context;

public ApproximateScoreQuery(Query originalQuery, ApproximateableQuery approximationQuery) {
this(originalQuery, approximationQuery, null, null);
}

public ApproximateScoreQuery(
Query originalQuery,
ApproximateableQuery approximationQuery,
Weight originalQueryWeight,
Weight approximationQueryWeight
) {
this.originalQuery = originalQuery;
this.approximationQuery = approximationQuery;
this.originalQueryWeight = originalQueryWeight;
this.approximationQueryWeight = approximationQueryWeight;
}

public Query getOriginalQuery() {
Expand All @@ -63,72 +45,20 @@ public ApproximateableQuery getApproximationQuery() {

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
originalQueryWeight = originalQuery.createWeight(searcher, scoreMode, boost);
approximationQueryWeight = approximationQuery.createWeight(searcher, scoreMode, boost);

return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext leafReaderContext, int doc) throws IOException {
return originalQueryWeight.explain(leafReaderContext, doc);
}

@Override
public Matches matches(LeafReaderContext leafReaderContext, int doc) throws IOException {
return originalQueryWeight.matches(leafReaderContext, doc);
}

@Override
public ScorerSupplier scorerSupplier(LeafReaderContext leafReaderContext) throws IOException {
final ScorerSupplier originalQueryScoreSupplier = originalQueryWeight.scorerSupplier(leafReaderContext);
final ScorerSupplier approximationQueryScoreSupplier = approximationQueryWeight.scorerSupplier(leafReaderContext);
if (originalQueryScoreSupplier == null || approximationQueryScoreSupplier == null) {
return null;
}

return new ScorerSupplier() {
@Override
public Scorer get(long l) throws IOException {
if (approximationQuery.canApproximate(context)) {
return approximationQueryScoreSupplier.get(l);
}
return originalQueryScoreSupplier.get(l);
}

@Override
public long cost() {
return originalQueryScoreSupplier.cost();
}
};
}

@Override
public Scorer scorer(LeafReaderContext leafReaderContext) throws IOException {
ScorerSupplier scorerSupplier = scorerSupplier(leafReaderContext);
if (scorerSupplier == null) {
return null;
}
return scorerSupplier.get(Long.MAX_VALUE);
}

@Override
public boolean isCacheable(LeafReaderContext leafReaderContext) {
return originalQueryWeight.isCacheable(leafReaderContext);
}

@Override
public int count(LeafReaderContext leafReaderContext) throws IOException {
if (approximationQuery.canApproximate(context)) {
return approximationQueryWeight.count(leafReaderContext);
}
return originalQueryWeight.count(leafReaderContext);
}
};
if (approximationQuery.canApproximate(context)) {
return approximationQuery.createWeight(searcher, scoreMode, boost);
}
return originalQuery.createWeight(searcher, scoreMode, boost);
}

public void setContext(SearchContext context) {
this.context = context;
};

public SearchContext getContext() {
return context;
};

@Override
public String toString(String s) {
return "ApproximateScoreQuery(originalQuery="
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.lucene.search.ConjunctionUtils;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexOrDocValuesQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
Expand Down Expand Up @@ -220,6 +219,9 @@ public Weight createWeight(Query query, ScoreMode scoreMode, float boost) throws
profiler.pollLastElement();
}
return new ProfileWeight(query, weight, profile);
} else if (query instanceof ApproximateScoreQuery) {
((ApproximateScoreQuery) query).setContext(searchContext);
return query.createWeight(this, scoreMode, boost);
} else {
return super.createWeight(query, scoreMode, boost);
}
Expand Down Expand Up @@ -329,11 +331,6 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto
// catch early terminated exception and rethrow?
Bits liveDocs = ctx.reader().getLiveDocs();
BitSet liveDocsBitSet = getSparseBitSetOrNull(liveDocs);
if (searchContext.query() instanceof IndexOrDocValuesQuery
&& ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery() instanceof ApproximateScoreQuery) {
ApproximateScoreQuery query = ((ApproximateScoreQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery());
query.setContext(searchContext);
}
if (liveDocsBitSet == null) {
BulkScorer bulkScorer = weight.bulkScorer(ctx);
if (bulkScorer != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;

import static org.apache.lucene.document.LongPoint.pack;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class ApproximatePointRangeQueryTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -303,4 +306,11 @@ protected String toString(int dimension, byte[] value) {
}
}
}

public void testCanApproximate() throws IOException {
SearchContext context = mock(SearchContext.class);
ApproximatePointRangeQuery approximatePointRangeQuery = mock(ApproximatePointRangeQuery.class);
when(approximatePointRangeQuery.canApproximate(null)).thenReturn(false);
when(approximatePointRangeQuery.canApproximate(context)).thenReturn(false);
}
}

0 comments on commit b1d5dd6

Please sign in to comment.