Skip to content

Commit

Permalink
Forcess ann only search with delete docs filter
Browse files Browse the repository at this point in the history
  • Loading branch information
shatejas committed Jan 7, 2025
1 parent 9fb7a5a commit 7ba420d
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 22 deletions.
48 changes: 26 additions & 22 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.FixedBitSet;
import org.opensearch.common.StopWatch;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.common.KNNConstants;
Expand All @@ -40,6 +41,7 @@
import org.opensearch.knn.indices.ModelUtil;
import org.opensearch.knn.jni.JNIService;
import org.opensearch.knn.plugin.stats.KNNCounter;
import org.opensearch.knn.plugin.stats.KNNTimer;

import java.io.IOException;
import java.util.Arrays;
Expand Down Expand Up @@ -128,7 +130,12 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
* @return A Map of docId to scores for top k results
*/
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
StopWatch stopWatch = new StopWatch().start();
KNNTimer.FILTER_SCORER_TIME.start();
final BitSet filterBitSet = getFilteredDocsBitSet(context);
KNNTimer.FILTER_SCORER_TIME.stop();
log.debug("Filter Query execution time {} ms", stopWatch.stop().totalTime().millis());

final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
Expand All @@ -137,41 +144,38 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
if (filterWeight != null && cardinality == 0) {
return PerLeafResult.EMPTY_RESULT;
}
/*
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}

/*
* If filters match all docs in this segment, then null should be passed as filterBitSet
* so that it will not do a bitset look up in bottom search layer.
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) {
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, cardinality) : null;
Map<Integer, Float> result = doExactSearch(context, docs, cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap);
stopWatch = new StopWatch().start();
KNNTimer.ANN_TIME.start();
Map<Integer, Float> result = doANNSearch(context, annFilter, cardinality, k);
KNNTimer.ANN_TIME.stop();
log.debug("Ann search time {} ms", stopWatch.stop().totalTime().millis());

return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
if (this.filterWeight == null) {
final Bits liveDocs = ctx.reader().getLiveDocs();
if (this.filterWeight == null && liveDocs == null) {
return new FixedBitSet(0);
}

final Bits liveDocs = ctx.reader().getLiveDocs();
final int maxDoc = ctx.reader().maxDoc();
if (filterWeight == null) {
// done only to not do refactor, we can always pass bits
final FixedBitSet fixedBitSet = new FixedBitSet(maxDoc);
for (int index = 0; index < liveDocs.length(); index++) {
if (liveDocs.get(index)) {
fixedBitSet.set(index);
}
}
return fixedBitSet;
}

final Scorer scorer = filterWeight.scorer(ctx);
if (scorer == null) {
Expand Down
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNStats.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.knn.plugin.stats.suppliers.KNNCircuitBreakerSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNCounterSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNInnerCacheStatsSupplier;
import org.opensearch.knn.plugin.stats.suppliers.KNNTimerSupplier;
import org.opensearch.knn.plugin.stats.suppliers.LibraryInitializedSupplier;
import org.opensearch.knn.plugin.stats.suppliers.ModelIndexStatusSupplier;
import org.opensearch.knn.plugin.stats.suppliers.ModelIndexingDegradingSupplier;
Expand Down Expand Up @@ -116,6 +117,10 @@ private void addQueryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
StatNames.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS.getName(),
new KNNStat<>(false, new KNNCounterSupplier(KNNCounter.MAX_DISTANCE_QUERY_WITH_FILTER_REQUESTS))
);

for (KNNTimer timer : KNNTimer.values()) {
builder.put(timer.getName(), new KNNStat<>(false, new KNNTimerSupplier(timer)));
}
}

private void addNativeMemoryStats(ImmutableMap.Builder<String, KNNStat<?>> builder) {
Expand Down
39 changes: 39 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/KNNTimer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.stats;

import lombok.Getter;
import lombok.RequiredArgsConstructor;
import org.opensearch.search.profile.Timer;

@Getter
@RequiredArgsConstructor
public enum KNNTimer {

FILTER_WEIGHT_TIME("filter_weight", new Timer()),
FILTER_SCORER_TIME("filter_scorer", new Timer()),
EXACT_SEARCH_TIME("exact_search", new Timer()),
FILTER_ID_SELECTOR_TIME("filter_id_selector", new Timer()),
ANN_TIME("ann_search", new Timer());

private final String name;
private final Timer timer;

public void start() {
timer.start();
}

public void stop() {
timer.stop();
}

public long average() {
if (timer.getCount() > 1) {
return timer.getApproximateTiming() / timer.getCount();
}
return timer.getApproximateTiming();
}
}
5 changes: 5 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/stats/StatNames.java
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ public static Set<String> getNames() {
for (StatNames statName : StatNames.values()) {
names.add(statName.getName());
}

for (KNNTimer knnTimer : KNNTimer.values()) {
names.add(knnTimer.getName());
}

return names;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin.stats.suppliers;

import lombok.RequiredArgsConstructor;
import org.opensearch.knn.plugin.stats.KNNTimer;

import java.util.function.Supplier;

@RequiredArgsConstructor
public class KNNTimerSupplier implements Supplier<Long> {

private final KNNTimer knnTimer;

@Override
public Long get() {
return knnTimer.average();
}
}

0 comments on commit 7ba420d

Please sign in to comment.