Skip to content

Commit

Permalink
Change tree walk
Browse files Browse the repository at this point in the history
Signed-off-by: Harsha Vamsi Kalluri <[email protected]>
  • Loading branch information
harshavamsi committed Jul 19, 2024
1 parent 8db8bfb commit a135cca
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

package org.opensearch.search.approximate;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PointValues;
Expand All @@ -25,6 +27,7 @@
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.DocIdSetBuilder;
import org.apache.lucene.util.IntsRef;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.Arrays;
Expand All @@ -47,13 +50,15 @@ public abstract class ApproximatePointRangeQuery extends Query {

private int size;

private SortOrder sortOrder;

private long[] docCount = { 0 };

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims) {
this(field, lowerPoint, upperPoint, numDims, 10_000);
this(field, lowerPoint, upperPoint, numDims, 10_000, SortOrder.ASC);
}

protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size) {
protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upperPoint, int numDims, int size, SortOrder sortOrder) {
checkArgs(field, lowerPoint, upperPoint);
this.field = field;
if (numDims <= 0) {
Expand All @@ -76,6 +81,7 @@ protected ApproximatePointRangeQuery(String field, byte[] lowerPoint, byte[] upp
this.lowerPoint = lowerPoint;
this.upperPoint = upperPoint;
this.size = size;
this.sortOrder = sortOrder;
}

public int getSize() {
Expand All @@ -86,6 +92,14 @@ public void setSize(int size) {
this.size = size;
}

/**
 * Returns the sort order currently stored on this query.
 *
 * @return the value of the {@code sortOrder} field, as set by the constructor or {@link #setSortOrder(SortOrder)}
 */
public SortOrder getSortOrder() {
    return sortOrder;
}

/**
 * Replaces the sort order stored on this query.
 * The scorer supplier built by this query calls {@code intersectLeft} when the stored order equals
 * {@link SortOrder#ASC} and {@code intersectRight} otherwise.
 *
 * @param sortOrder the sort order to store; it is not validated here
 */
public void setSortOrder(SortOrder sortOrder) {
// NOTE(review): a null value is accepted here but would fail later at the
// sortOrder.equals(SortOrder.ASC) check in the scorer supplier -- confirm callers never pass null.
this.sortOrder = sortOrder;
}

@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
Expand Down Expand Up @@ -210,12 +224,17 @@ private boolean checkValidPointValues(PointValues values) throws IOException {
return true;
}

private void intersect(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, int count) throws IOException {
intersect(visitor, pointTree, count);
/**
 * Entry point for the walk used when the sort order is ASC: delegates to the recursive
 * {@code intersectLeft(visitor, pointTree, count)} overload with the same arguments.
 *
 * @param pointTree tree cursor to walk
 * @param visitor   visitor handed to the recursive overload
 * @param count     passed through unchanged as the recursive overload's limit
 * @throws IOException propagated from the recursive walk
 */
private void intersectLeft(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, int count) throws IOException {
intersectLeft(visitor, pointTree, count);
// Evaluated only when assertions are enabled. Checks that moveToParent() reports false after the
// walk -- presumably meaning the cursor is back at the root; confirm against PointValues.PointTree.
assert pointTree.moveToParent() == false;
}

private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count) throws IOException {
/**
 * Entry point for the walk used when the sort order is not ASC: delegates to the recursive
 * {@code intersectRight(visitor, pointTree, count)} overload with the same arguments.
 *
 * @param pointTree tree cursor to walk
 * @param visitor   visitor handed to the recursive overload
 * @param count     passed through unchanged as the recursive overload's limit
 * @throws IOException propagated from the recursive walk
 */
private void intersectRight(PointValues.PointTree pointTree, PointValues.IntersectVisitor visitor, int count) throws IOException {
intersectRight(visitor, pointTree, count);
// Evaluated only when assertions are enabled. Checks that moveToParent() reports false after the
// walk -- presumably meaning the cursor is back at the root; confirm against PointValues.PointTree.
assert pointTree.moveToParent() == false;
}

private long intersectLeft(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count) throws IOException {
PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
if (docCount[0] >= count) {
return 0;
Expand All @@ -225,16 +244,80 @@ private long intersect(PointValues.IntersectVisitor visitor, PointValues.PointTr
// This cell is fully outside the query shape: stop recursing
break;
case CELL_INSIDE_QUERY:
// This cell is fully inside the query shape: recursively add all points in this cell
// without filtering
pointTree.visitDocIDs(visitor);
return pointTree.size();
// The cell is fully inside the query: keep descending into its children until we reach a leaf or have collected enough docs
if (pointTree.moveToChild()) {
do {
docCount[0] += intersectLeft(visitor, pointTree, count);
} while (pointTree.moveToSibling() && docCount[0] <= count);
pointTree.moveToParent();
} else {
// TODO: we can assert that the first value here in fact matches what the pointTree
// claimed?
// Leaf node; scan and filter all points in this block:
if (docCount[0] <= count) {
pointTree.visitDocIDs(visitor);
docCount[0] += pointTree.size();
return docCount[0];
} else break;
}
break;
case CELL_CROSSES_QUERY:
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
// through and do full filtering:
if (pointTree.moveToChild()) {
do {
docCount[0] += intersect(visitor, pointTree, count);
docCount[0] += intersectLeft(visitor, pointTree, count);
} while (pointTree.moveToSibling() && docCount[0] <= count);
pointTree.moveToParent();
} else {
// TODO: we can assert that the first value here in fact matches what the pointTree
// claimed?
// Leaf node; scan and filter all points in this block:
if (docCount[0] <= count) {
pointTree.visitDocValues(visitor);
} else break;
}
break;
default:
throw new IllegalArgumentException("Unreachable code");
}
// docCount can be updated by the local visitor so we ensure that we return docCount after pointTree.visitDocValues(visitor)
return docCount[0] > 0 ? docCount[0] : 0;
}

private long intersectRight(PointValues.IntersectVisitor visitor, PointValues.PointTree pointTree, int count) throws IOException {
PointValues.Relation r = visitor.compare(pointTree.getMinPackedValue(), pointTree.getMaxPackedValue());
if (docCount[0] >= count) {
return 0;
}
switch (r) {
case CELL_OUTSIDE_QUERY:
// This cell is fully outside the query shape: stop recursing
break;
case CELL_INSIDE_QUERY:
// The cell is fully inside the query: keep descending into its children until we reach a leaf or have collected enough docs
if (pointTree.moveToChild()) {
while (pointTree.moveToSibling() && docCount[0] <= count){
docCount[0] += intersectRight(visitor, pointTree, count);
}
pointTree.moveToParent();
} else {
// TODO: we can assert that the first value here in fact matches what the pointTree
// claimed?
// Leaf node; scan and filter all points in this block:
if (docCount[0] <= count) {
pointTree.visitDocIDs(visitor);
docCount[0] += pointTree.size();
return docCount[0];
} else break;
}
break;
case CELL_CROSSES_QUERY:
// The cell crosses the shape boundary, or the cell fully contains the query, so we fall
// through and do full filtering:
if (pointTree.moveToChild()) {
do {
docCount[0] += intersectRight(visitor, pointTree, count);
} while (pointTree.moveToSibling() && docCount[0] <= count);
pointTree.moveToParent();
} else {
Expand Down Expand Up @@ -312,6 +395,31 @@ public long cost() {
}
};
} else {
if (sortOrder.equals(SortOrder.ASC)) {
return new ScorerSupplier() {

final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
final PointValues.IntersectVisitor visitor = getIntersectVisitor(result);
long cost = -1;

/**
 * Collects matching documents by walking the point tree with {@code intersectLeft}, using the
 * query's {@code size} as the limit, then wraps the collected doc ids in a constant-score scorer.
 *
 * @param leadCost not used by this implementation
 * @return a {@link ConstantScoreScorer} over the doc ids gathered into {@code result}
 * @throws IOException if the tree walk or building the doc id set fails
 */
@Override
public Scorer get(long leadCost) throws IOException {
    // Fill the shared builder first; the iterator below reflects whatever the walk collected.
    intersectLeft(values.getPointTree(), visitor, size);
    final DocIdSetIterator collectedDocs = result.build().iterator();
    return new ConstantScoreScorer(weight, score(), scoreMode, collectedDocs);
}

/**
 * Returns the estimated number of matching documents, computing it on first use and caching it.
 *
 * @return the cached value of {@code values.estimateDocCount(visitor)}
 */
@Override
public long cost() {
    // -1 marks "not computed yet"; once computed, reuse the cached value.
    if (cost != -1) {
        return cost;
    }
    // The estimate may be expensive, so it is deferred until someone actually asks for it.
    cost = values.estimateDocCount(visitor);
    assert cost >= 0;
    return cost;
}
};
}
return new ScorerSupplier() {

final DocIdSetBuilder result = new DocIdSetBuilder(reader.maxDoc(), values, field);
Expand All @@ -320,7 +428,7 @@ public long cost() {

@Override
public Scorer get(long leadCost) throws IOException {
intersect(values.getPointTree(), visitor, size);
intersectRight(values.getPointTree(), visitor, size);
DocIdSetIterator iterator = result.build().iterator();
return new ConstantScoreScorer(weight, score(), scoreMode, iterator);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.sort.FieldSortBuilder;
import org.opensearch.search.sort.MinAndMax;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -328,6 +329,15 @@ private void searchLeaf(LeafReaderContext ctx, Weight weight, Collector collecto
if (isApproximateableRangeQuery()) {
ApproximateableQuery query = ((ApproximateableQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery());
if (searchContext.size() > 10_000) ((ApproximatePointRangeQuery) query.getApproximationQuery()).setSize(searchContext.size());
if (searchContext.request() != null && searchContext.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(searchContext.request().source());
if (primarySortField != null
&& primarySortField.missing() == null) {
if(primarySortField.order() == SortOrder.DESC){
((ApproximatePointRangeQuery) query.getApproximationQuery()).setSortOrder(SortOrder.DESC);
}
}
}
weight = query.getApproximationQueryWeight();
}
if (liveDocsBitSet == null) {
Expand Down Expand Up @@ -422,23 +432,10 @@ private static BitSet getSparseBitSetOrNull(Bits liveDocs) {
}

private boolean isApproximateableRangeQuery() {
boolean isTopLevelRangeQuery = searchContext.query() instanceof IndexOrDocValuesQuery
return searchContext.query() instanceof IndexOrDocValuesQuery
&& ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery() instanceof ApproximateableQuery
&& ((ApproximateableQuery) ((IndexOrDocValuesQuery) searchContext.query()).getIndexQuery())
.getOriginalQuery() instanceof PointRangeQuery;

boolean hasSort = false;

if (searchContext.request() != null && searchContext.request().source() != null) {
FieldSortBuilder primarySortField = FieldSortBuilder.getPrimaryFieldSortOrNull(searchContext.request().source());
if (primarySortField != null
&& primarySortField.missing() == null
&& Objects.equals(searchContext.trackTotalHitsUpTo(), SearchContext.TRACK_TOTAL_HITS_DISABLED)) {
hasSort = true;
}
}

return isTopLevelRangeQuery && !hasSort;
}

static void intersectScorerAndBitSet(Scorer scorer, BitSet acceptDocs, LeafCollector collector, Runnable checkCancelled)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.opensearch.search.sort.SortOrder;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
Expand Down Expand Up @@ -100,7 +101,8 @@ public void testApproximateRangeWithSize() throws IOException {
pack(lower).bytes,
pack(upper).bytes,
dims,
10
10,
SortOrder.ASC
) {
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
Expand All @@ -111,7 +113,8 @@ protected String toString(int dimension, byte[] value) {
pack(lower).bytes,
pack(upper).bytes,
dims,
100
100,
SortOrder.ASC
) {
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
Expand Down Expand Up @@ -156,7 +159,7 @@ public void testApproximateRangeShortCircuit() throws IOException {
try {
long lower = 0;
long upper = 100;
Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10) {
Query approximateQuery = new ApproximatePointRangeQuery("point", pack(lower).bytes, pack(upper).bytes, dims, 10, SortOrder.ASC) {
protected String toString(int dimension, byte[] value) {
return Long.toString(LongPoint.decodeDimension(value, 0));
}
Expand Down

0 comments on commit a135cca

Please sign in to comment.