Skip to content

Commit

Permalink
Adds random filter bitset generator
Browse files Browse the repository at this point in the history
  • Loading branch information
shatejas committed Jan 9, 2025
1 parent 7ba420d commit a105827
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
17 changes: 16 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public class KNNSettings {
public static final String QUANTIZATION_STATE_CACHE_EXPIRY_TIME_MINUTES = "knn.quantization.cache.expiry.minutes";
public static final String KNN_FAISS_AVX512_DISABLED = "knn.faiss.avx512.disabled";
public static final String KNN_DISK_VECTOR_SHARD_LEVEL_RESCORING_DISABLED = "index.knn.disk.vector.shard_level_rescoring_disabled";

// Index-level setting key; its value is read as the percentage used when generating the random filter bitset.
public static final String KNN_FILTER_PCT = "index.knn.filter.pct";
/**
* Default setting values
*
Expand Down Expand Up @@ -302,6 +302,8 @@ public class KNNSettings {
Dynamic
);

/**
 * Index-scoped, dynamically updatable double setting registered under {@link #KNN_FILTER_PCT}.
 * The three numeric arguments are passed to {@code Setting.doubleSetting} as default, minimum and
 * maximum (0, 0 and 100 respectively).
 */
public static final Setting<Double> KNN_FILTER_PCT_SETTING = Setting.doubleSetting(
    KNN_FILTER_PCT,
    0,
    0,
    100,
    IndexScope,
    Dynamic
);

public static final Setting<Boolean> KNN_FAISS_AVX2_DISABLED_SETTING = Setting.boolSetting(
KNN_FAISS_AVX2_DISABLED,
KNN_DEFAULT_FAISS_AVX2_DISABLED_VALUE,
Expand Down Expand Up @@ -475,6 +477,10 @@ private Setting<?> getSetting(String key) {
return ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_SETTING;
}

if (KNN_FILTER_PCT.equals(key)) {
return KNN_FILTER_PCT_SETTING;
}

if (KNN_FAISS_AVX2_DISABLED.equals(key)) {
return KNN_FAISS_AVX2_DISABLED_SETTING;
}
Expand Down Expand Up @@ -504,6 +510,7 @@ private Setting<?> getSetting(String key) {

public List<Setting<?>> getSettings() {
List<Setting<?>> settings = Arrays.asList(
KNN_FILTER_PCT_SETTING,
INDEX_KNN_SPACE_TYPE,
INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD_SETTING,
INDEX_KNN_ALGO_PARAM_M_SETTING,
Expand Down Expand Up @@ -577,6 +584,14 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

/**
 * Looks up the {@code index.knn.filter.pct} value in the settings of the given index.
 *
 * @param indexName name of the index whose settings are read from the cluster state
 * @return the configured value, or {@code 0.0} when the index does not define the setting
 */
public static Double getKnnFilterPercent(final String indexName) {
    // Plain literal default; there is no need to parse the constant "0.0" on every call.
    final double defaultFilterPercent = 0.0d;
    return KNNSettings.state().clusterService.state()
        .getMetadata()
        .index(indexName)
        .getSettings()
        .getAsDouble(KNN_FILTER_PCT, defaultFilterPercent);
}

public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
Expand Down
17 changes: 15 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ThreadLocalRandom;
import java.util.stream.Collectors;

import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE;
Expand Down Expand Up @@ -132,9 +133,9 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
StopWatch stopWatch = new StopWatch().start();
KNNTimer.FILTER_SCORER_TIME.start();
final BitSet filterBitSet = getFilteredDocsBitSet(context);
final BitSet filterBitSet = randomBitSetGenerator(context.reader().maxDoc());
KNNTimer.FILTER_SCORER_TIME.stop();
log.debug("Filter Query execution time {} ms", stopWatch.stop().totalTime().millis());
log.info("Filter Query execution time {} ms", stopWatch.stop().totalTime().millis());

final int maxDoc = context.reader().maxDoc();
int cardinality = filterBitSet.cardinality();
Expand Down Expand Up @@ -185,6 +186,18 @@ private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOExcep
return createBitSet(scorer.iterator(), liveDocs, maxDoc);
}

/**
 * Builds a bitset of size {@code maxDoc} in which a configured percentage of bits, chosen uniformly
 * at random, is set. The percentage is read from {@link KNNSettings#getKnnFilterPercent(String)} for
 * the query's index and floored to a whole number; a value of 0 selects every document.
 *
 * <p>Exactly {@code (int) (maxDoc * percent / 100.0)} distinct bits are set. Drawing positions with
 * replacement (the previous behaviour) silently loses bits to collisions, so the resulting
 * cardinality was lower than requested, and the "0 means all" case only set roughly 63% of the bits.
 *
 * <p>NOTE(review): live docs are not consulted here, so deleted documents may be selected - confirm
 * whether callers rely on the bitset containing live documents only.
 *
 * @param maxDoc number of bits in the returned bitset (the segment's max doc)
 * @return a bitset with the requested number of randomly chosen bits set
 */
private FixedBitSet randomBitSetGenerator(int maxDoc) {
    final FixedBitSet bitset = new FixedBitSet(maxDoc);
    final int percent = (int) Math.floor(KNNSettings.getKnnFilterPercent(knnQuery.getIndexName()));
    // Runs for every segment of every query, so keep it out of the default log level.
    log.debug("Percent {}", percent);

    final int numBitsToSet = percent == 0 ? maxDoc : (int) (maxDoc * (percent / 100.0));
    if (numBitsToSet <= 0) {
        return bitset;
    }
    if (numBitsToSet >= maxDoc) {
        // Everything is selected: no randomness needed.
        bitset.set(0, maxDoc);
        return bitset;
    }

    // Flip whichever side is smaller so at most half of the bits are touched; every random draw
    // then hits an untouched bit with probability >= 1/2 and the retry loop stays cheap.
    final boolean clearInstead = numBitsToSet > maxDoc / 2;
    if (clearInstead) {
        bitset.set(0, maxDoc);
    }
    int remaining = clearInstead ? maxDoc - numBitsToSet : numBitsToSet;

    final ThreadLocalRandom random = ThreadLocalRandom.current();
    while (remaining > 0) {
        final int doc = random.nextInt(maxDoc);
        // getAndSet/getAndClear return the previous bit value, which tells us whether this draw
        // actually changed the bit or collided with an earlier draw.
        final boolean alreadyFlipped = clearInstead ? !bitset.getAndClear(doc) : bitset.getAndSet(doc);
        if (!alreadyFlipped) {
            remaining--;
        }
    }
    return bitset;
}

private BitSet createBitSet(final DocIdSetIterator filteredDocIdsIterator, final Bits liveDocs, int maxDoc) throws IOException {
if (liveDocs == null && filteredDocIdsIterator instanceof BitSetIterator) {
// If we already have a BitSet and no deletions, reuse the BitSet
Expand Down

0 comments on commit a105827

Please sign in to comment.