Skip to content

Commit

Permalink
Refactor adaptive BF and UT
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Mar 9, 2024
1 parent 21d5434 commit b0cbe34
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,53 +11,90 @@
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import java.util.function.Function;
import org.opensearch.flint.core.field.bloomfilter.BloomFilter;
import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter;

/**
* Adaptive bloom filter implementation that generates a series of bloom filter candidate
* Adaptive BloomFilter implementation that generates a series of bloom filter candidates
* with different expected numbers of items (NDV) and finally chooses the best one.
*/
public class AdaptiveBloomFilter implements BloomFilter {

/**
* Initial expected number of items for the first candidate.
*/
private static final int INITIAL_EXPECTED_NUM_ITEMS = 1024;

/**
* Total number of distinct items seen so far.
*/
private int total = 0;
private int cardinality = 0;

/**
* Bloom filter candidates.
* BloomFilter candidates.
*/
final BloomFilterCandidate[] candidates;

public AdaptiveBloomFilter(int numCandidate, double fpp) {
this.candidates = new BloomFilterCandidate[numCandidate];
/**
 * Constructs an adaptive BloomFilter instance with the given algorithm parameters.
 *
 * <p>Every candidate is a {@link ClassicBloomFilter} created for its own expected number of
 * items, and all candidates share the same false positive probability.
 *
 * @param numCandidates number of candidates; must be positive
 * @param fpp false positive probability passed to every candidate
 * @throws IllegalArgumentException if {@code numCandidates} is not positive
 */
public AdaptiveBloomFilter(int numCandidates, double fpp) {
  // Fail fast: with an empty candidate array, selecting the best candidate would
  // later compute index -1 and fail with an ArrayIndexOutOfBoundsException.
  if (numCandidates <= 0) {
    throw new IllegalArgumentException(
        "Number of candidates must be positive but was " + numCandidates);
  }
  this.candidates =
      initializeCandidates(
          numCandidates, expectedNumItems -> new ClassicBloomFilter(expectedNumItems, fpp));
}

int expectedNumItems = 1024;
for (int i = 0; i < candidates.length; i++) {
candidates[i] =
new BloomFilterCandidate(
expectedNumItems,
new ClassicBloomFilter(expectedNumItems, fpp));
expectedNumItems *= 2;
}
/**
 * Constructs an adaptive BloomFilter instance around already-built candidates, such as
 * the ones read back from serialized content.
 *
 * @param cardinality total number of distinct items
 * @param candidates BloomFilter candidates, consumed in array order
 */
AdaptiveBloomFilter(int cardinality, BloomFilter[] candidates) {
  // Hand out the given filters one by one, in order, while the candidates are wrapped;
  // the expected-number-of-items argument is not needed to pick the next filter.
  Iterator<BloomFilter> remaining = Arrays.asList(candidates).iterator();
  this.candidates =
      initializeCandidates(candidates.length, ignoredExpectedNumItems -> remaining.next());
  this.cardinality = cardinality;
}

AdaptiveBloomFilter(int total, BloomFilter[] candidates) {
this.total = total;
this.candidates = new BloomFilterCandidate[candidates.length];
/**
* Deserialize adaptive BloomFilter instance from input stream.
*
* @param numCandidates number of candidates
* @param in input stream of serialized adaptive BloomFilter instance
* @return adaptive BloomFilter instance
*/
public static BloomFilter readFrom(int numCandidates, InputStream in) {
try {
// Read total distinct counter
int cardinality = new DataInputStream(in).readInt();

int expectedNumItems = 1024;
for (int i = 0; i < candidates.length; i++) {
this.candidates[i] = new BloomFilterCandidate(expectedNumItems, candidates[i]);
expectedNumItems *= 2;
// Read BloomFilter candidate array
BloomFilter[] candidates = new BloomFilter[numCandidates];
for (int i = 0; i < numCandidates; i++) {
candidates[i] = ClassicBloomFilter.readFrom(in);
}
return new AdaptiveBloomFilter(cardinality, candidates);
} catch (IOException e) {
throw new RuntimeException("Failed to deserialize adaptive BloomFilter", e);
}
}

/**
 * Looks up the candidate at the index computed by {@code bestCandidateIndex()}.
 *
 * @return best BloomFilter candidate, i.e. the one whose expected number of items is right
 *     above the total distinct counter
 */
public BloomFilterCandidate bestCandidate() {
  int bestIndex = bestCandidateIndex();
  return candidates[bestIndex];
}

@Override
public long bitSize() {
return Arrays.stream(candidates).map(candidate -> candidate.bloomFilter.bitSize()).count();
return Arrays.stream(candidates)
.mapToLong(c -> c.bloomFilter.bitSize())
.sum();
}

@Override
Expand All @@ -68,22 +105,22 @@ public boolean put(long item) {
bitChanged = candidates[i].bloomFilter.put(item);
}

// Use the last candidate's put result which is most accurate
// Use the last candidate's put result which is the most accurate
if (bitChanged) {
total++;
cardinality++;
}
return bitChanged;
}

@Override
public BloomFilter merge(BloomFilter other) {
AdaptiveBloomFilter otherBf = (AdaptiveBloomFilter) other;
total += otherBf.total;
cardinality += otherBf.cardinality;

for (int i = 0; i < candidates.length; i++) {
candidates[i].bloomFilter.merge(otherBf.candidates[i].bloomFilter);
}
return null;
return this;
}

@Override
Expand All @@ -93,39 +130,36 @@ public boolean mightContain(long item) {

@Override
public void writeTo(OutputStream out) throws IOException {
new DataOutputStream(out).writeInt(total);
// Serialize cardinality counter first
new DataOutputStream(out).writeInt(cardinality);

// Serialize classic BloomFilter array
for (BloomFilterCandidate candidate : candidates) {
candidate.bloomFilter.writeTo(out);
}
}

public static BloomFilter readFrom(int numCandidates, InputStream in) {
try {
int total = new DataInputStream(in).readInt();
BloomFilter[] candidates = new BloomFilter[numCandidates];
for (int i = 0; i < numCandidates; i++) {
candidates[i] = ClassicBloomFilter.readFrom(in);
}
return new AdaptiveBloomFilter(total, candidates);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
private BloomFilterCandidate[] initializeCandidates(int numCandidates,
Function<Integer, BloomFilter> initializer) {
BloomFilterCandidate[] candidates = new BloomFilterCandidate[numCandidates];
int ndv = INITIAL_EXPECTED_NUM_ITEMS;

public BloomFilterCandidate bestCandidate() {
return candidates[bestCandidateIndex()];
// Initialize candidate with NDV doubled in each iteration
for (int i = 0; i < numCandidates; i++, ndv *= 2) {
candidates[i] = new BloomFilterCandidate(ndv, initializer.apply(ndv));
}
return candidates;
}

private int bestCandidateIndex() {
int index = Arrays.binarySearch(candidates, new BloomFilterCandidate(total, null));
int index = Arrays.binarySearch(candidates, new BloomFilterCandidate(cardinality, null));
if (index < 0) {
index = -(index + 1);
}
return Math.min(index, candidates.length - 1);
}

public static class BloomFilterCandidate implements Comparable<BloomFilterCandidate> {

int expectedNumItems;
BloomFilter bloomFilter;

Expand Down Expand Up @@ -166,12 +200,12 @@ public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
AdaptiveBloomFilter that = (AdaptiveBloomFilter) o;
return total == that.total && Arrays.equals(candidates, that.candidates);
return cardinality == that.cardinality && Arrays.equals(candidates, that.candidates);
}

@Override
public int hashCode() {
int result = Objects.hash(total);
int result = Objects.hash(cardinality);
result = 31 * result + Arrays.hashCode(candidates);
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import static org.junit.Assert.assertTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter.*;
import static org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter.readFrom;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
Expand All @@ -18,32 +18,29 @@

public class AdaptiveBloomFilterTest {

private final int numCandidates = 10;
private final int numCandidates = 5;

private final AdaptiveBloomFilter bloomFilter = new AdaptiveBloomFilter(numCandidates, 0.01);
private final AdaptiveBloomFilter bloomFilter = new AdaptiveBloomFilter(numCandidates, 0.03);

@Test
public void shouldChooseBestCandidateAdaptively() {
// Insert 500 items
// Insert 500 items should choose 1st candidate
for (int i = 0; i < 500; i++) {
bloomFilter.put(i);
}
BloomFilterCandidate candidate1 = bloomFilter.bestCandidate();
assertEquals(1024, candidate1.expectedNumItems);
assertEquals(1024, bloomFilter.bestCandidate().expectedNumItems);

// Insert 1000 (total 1500)
// Insert 1000 (total 1500) should choose 2nd candidate
for (int i = 500; i < 1500; i++) {
bloomFilter.put(i);
}
BloomFilterCandidate candidate2 = bloomFilter.bestCandidate();
assertEquals(2048, candidate2.expectedNumItems);
assertEquals(2048, bloomFilter.bestCandidate().expectedNumItems);

// Insert 4000 (total 5500)
// Insert 4000 (total 5500) should choose 4th candidate
for (int i = 1500; i < 5500; i++) {
bloomFilter.put(i);
}
BloomFilterCandidate candidate3 = bloomFilter.bestCandidate();
assertEquals(8192, candidate3.expectedNumItems);
assertEquals(8192, bloomFilter.bestCandidate().expectedNumItems);
}

@Test
Expand All @@ -52,10 +49,35 @@ public void shouldBeTheSameAfterWriteToAndReadFrom() throws IOException {
bloomFilter.put(456L);
bloomFilter.put(789L);

// Serialize and deserialize and assert the equality
ByteArrayOutputStream out = new ByteArrayOutputStream();
bloomFilter.writeTo(out);
InputStream in = new ByteArrayInputStream(out.toByteArray());
BloomFilter newBloomFilter = readFrom(numCandidates, in);
assertEquals(bloomFilter, newBloomFilter);
}

@Test
public void shouldMergeTwoFiltersCorrectly() {
  AdaptiveBloomFilter other = new AdaptiveBloomFilter(numCandidates, 0.03);

  // Populate the two filters with disjoint ranges: [0, 1000) goes into the first
  // filter and [1000, 2000) into the second one
  for (int i = 0; i < 1000; i++) {
    bloomFilter.put(i);
    other.put(i + 1000);
  }

  // Fold the second filter into the first one
  bloomFilter.merge(other);

  // Every item from either range must now be reported as possibly present
  int item = 0;
  while (item < 2000) {
    assertTrue(bloomFilter.mightContain(item));
    item++;
  }

  // The merged filter is expected to pick the candidate sized for 2048 items
  assertEquals(2048, bloomFilter.bestCandidate().expectedNumItems);
}
}

0 comments on commit b0cbe34

Please sign in to comment.