From b0cbe34bdd3f3e3feb205df25cce1804c0ad13c4 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Fri, 8 Mar 2024 16:04:51 -0800 Subject: [PATCH] Refactor adaptive BF and UT Signed-off-by: Chen Dai --- .../adaptive/AdaptiveBloomFilter.java | 122 +++++++++++------- .../adaptive/AdaptiveBloomFilterTest.java | 46 +++++-- 2 files changed, 112 insertions(+), 56 deletions(-) diff --git a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java index 9274eca64..923f74b6b 100644 --- a/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java +++ b/flint-core/src/main/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilter.java @@ -11,53 +11,90 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; +import java.util.Iterator; import java.util.Objects; +import java.util.function.Function; import org.opensearch.flint.core.field.bloomfilter.BloomFilter; import org.opensearch.flint.core.field.bloomfilter.classic.ClassicBloomFilter; /** - * Adaptive bloom filter implementation that generates a series of bloom filter candidate + * Adaptive BloomFilter implementation that generates a series of bloom filter candidate * with different expected number of item (NDV) and at last choose the best one. */ public class AdaptiveBloomFilter implements BloomFilter { + /** + * Initial expected number of items for the first candidate. + */ + private static final int INITIAL_EXPECTED_NUM_ITEMS = 1024; + /** * Total number of distinct items seen so far. */ - private int total = 0; + private int cardinality = 0; /** - * Bloom filter candidates. + * BloomFilter candidates. */ final BloomFilterCandidate[] candidates; - public AdaptiveBloomFilter(int numCandidate, double fpp) { - this.candidates = new BloomFilterCandidate[numCandidate]; + /** + * Construct adaptive BloomFilter instance with the given algorithm parameters. + * + * @param numCandidates number of candidate + * @param fpp false positive probability + */ + public AdaptiveBloomFilter(int numCandidates, double fpp) { + this.candidates = initializeCandidates(numCandidates, expectedNumItems -> new ClassicBloomFilter(expectedNumItems, fpp)); + } - int expectedNumItems = 1024; - for (int i = 0; i < candidates.length; i++) { - candidates[i] = - new BloomFilterCandidate( - expectedNumItems, - new ClassicBloomFilter(expectedNumItems, fpp)); - expectedNumItems *= 2; - } + /** + * Construct adaptive BloomFilter instance from deserialized content. + * + * @param cardinality total number of distinct items + * @param candidates BloomFilter candidates + */ + AdaptiveBloomFilter(int cardinality, BloomFilter[] candidates) { + this.cardinality = cardinality; + Iterator it = Arrays.stream(candidates).iterator(); + this.candidates = initializeCandidates(candidates.length, expectedNumItems -> it.next()); } - AdaptiveBloomFilter(int total, BloomFilter[] candidates) { - this.total = total; - this.candidates = new BloomFilterCandidate[candidates.length]; + /** + * Deserialize adaptive BloomFilter instance from input stream. + * + * @param numCandidates number of candidates + * @param in input stream of serialized adaptive BloomFilter instance + * @return adaptive BloomFilter instance + */ + public static BloomFilter readFrom(int numCandidates, InputStream in) { + try { + // Read total distinct counter + int cardinality = new DataInputStream(in).readInt(); - int expectedNumItems = 1024; - for (int i = 0; i < candidates.length; i++) { - this.candidates[i] = new BloomFilterCandidate(expectedNumItems, candidates[i]); - expectedNumItems *= 2; + // Read BloomFilter candidate array + BloomFilter[] candidates = new BloomFilter[numCandidates]; + for (int i = 0; i < numCandidates; i++) { + candidates[i] = ClassicBloomFilter.readFrom(in); + } + return new AdaptiveBloomFilter(cardinality, candidates); + } catch (IOException e) { + throw new RuntimeException("Failed to deserialize adaptive BloomFilter", e); } } + /** + * @return best BloomFilter candidate which has expected number of item right above total distinct counter. + */ + public BloomFilterCandidate bestCandidate() { + return candidates[bestCandidateIndex()]; + } + @Override public long bitSize() { - return Arrays.stream(candidates).map(candidate -> candidate.bloomFilter.bitSize()).count(); + return Arrays.stream(candidates) + .mapToLong(c -> c.bloomFilter.bitSize()) + .sum(); } @Override @@ -68,9 +105,9 @@ public boolean put(long item) { bitChanged = candidates[i].bloomFilter.put(item); } - // Use the last candidate's put result which is most accurate + // Use the last candidate's put result which is the most accurate if (bitChanged) { - total++; + cardinality++; } return bitChanged; } @@ -78,12 +115,12 @@ public boolean put(long item) { @Override public BloomFilter merge(BloomFilter other) { AdaptiveBloomFilter otherBf = (AdaptiveBloomFilter) other; - total += otherBf.total; + cardinality += otherBf.cardinality; for (int i = 0; i < candidates.length; i++) { candidates[i].bloomFilter.merge(otherBf.candidates[i].bloomFilter); } - return null; + return this; } @Override @@ -93,31 +130,29 @@ public boolean mightContain(long item) { @Override public void writeTo(OutputStream out) throws IOException { - new DataOutputStream(out).writeInt(total); + // Serialized cardinality counter first + new DataOutputStream(out).writeInt(cardinality); + + // Serialize classic BloomFilter array for (BloomFilterCandidate candidate : candidates) { candidate.bloomFilter.writeTo(out); } } - public static BloomFilter readFrom(int numCandidates, InputStream in) { - try { - int total = new DataInputStream(in).readInt(); - BloomFilter[] candidates = new BloomFilter[numCandidates]; - for (int i = 0; i < numCandidates; i++) { - candidates[i] = ClassicBloomFilter.readFrom(in); - } - return new AdaptiveBloomFilter(total, candidates); - } catch (IOException e) { - throw new RuntimeException(e); - } - } + private BloomFilterCandidate[] initializeCandidates(int numCandidates, + Function initializer) { + BloomFilterCandidate[] candidates = new BloomFilterCandidate[numCandidates]; + int ndv = INITIAL_EXPECTED_NUM_ITEMS; - public BloomFilterCandidate bestCandidate() { - return candidates[bestCandidateIndex()]; + // Initialize candidate with NDV doubled in each iteration + for (int i = 0; i < numCandidates; i++, ndv *= 2) { + candidates[i] = new BloomFilterCandidate(ndv, initializer.apply(ndv)); + } + return candidates; } private int bestCandidateIndex() { - int index = Arrays.binarySearch(candidates, new BloomFilterCandidate(total, null)); + int index = Arrays.binarySearch(candidates, new BloomFilterCandidate(cardinality, null)); if (index < 0) { index = -(index + 1); } @@ -125,7 +160,6 @@ private int bestCandidateIndex() { } public static class BloomFilterCandidate implements Comparable { - int expectedNumItems; BloomFilter bloomFilter; @@ -166,12 +200,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AdaptiveBloomFilter that = (AdaptiveBloomFilter) o; - return total == that.total && Arrays.equals(candidates, that.candidates); + return cardinality == that.cardinality && Arrays.equals(candidates, that.candidates); } @Override public int hashCode() { - int result = Objects.hash(total); + int result = Objects.hash(cardinality); result = 31 * result + Arrays.hashCode(candidates); return result; } diff --git a/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java index 8b7b30377..e22e791b6 100644 --- a/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java +++ b/flint-core/src/test/java/org/opensearch/flint/core/field/bloomfilter/adaptive/AdaptiveBloomFilterTest.java @@ -7,7 +7,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter.*; +import static org.opensearch.flint.core.field.bloomfilter.adaptive.AdaptiveBloomFilter.readFrom; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; @@ -18,32 +18,29 @@ public class AdaptiveBloomFilterTest { - private final int numCandidates = 10; + private final int numCandidates = 5; - private final AdaptiveBloomFilter bloomFilter = new AdaptiveBloomFilter(numCandidates, 0.01); + private final AdaptiveBloomFilter bloomFilter = new AdaptiveBloomFilter(numCandidates, 0.03); @Test public void shouldChooseBestCandidateAdaptively() { - // Insert 500 items + // Insert 500 items should choose 1st candidate for (int i = 0; i < 500; i++) { bloomFilter.put(i); } - BloomFilterCandidate candidate1 = bloomFilter.bestCandidate(); - assertEquals(1024, candidate1.expectedNumItems); + assertEquals(1024, bloomFilter.bestCandidate().expectedNumItems); - // Insert 1000 (total 1500) + // Insert 1000 (total 1500) should choose 2nd candidate for (int i = 500; i < 1500; i++) { bloomFilter.put(i); } - BloomFilterCandidate candidate2 = bloomFilter.bestCandidate(); - assertEquals(2048, candidate2.expectedNumItems); + assertEquals(2048, bloomFilter.bestCandidate().expectedNumItems); - // Insert 4000 (total 5500) + // Insert 4000 (total 5500) should choose 4th candidate for (int i = 1500; i < 5500; i++) { bloomFilter.put(i); } - BloomFilterCandidate candidate3 = bloomFilter.bestCandidate(); - assertEquals(8192, candidate3.expectedNumItems); + assertEquals(8192, bloomFilter.bestCandidate().expectedNumItems); } @Test @@ -52,10 +49,35 @@ public void shouldBeTheSameAfterWriteToAndReadFrom() throws IOException { bloomFilter.put(456L); bloomFilter.put(789L); + // Serialize and deserialize and assert the equality ByteArrayOutputStream out = new ByteArrayOutputStream(); bloomFilter.writeTo(out); InputStream in = new ByteArrayInputStream(out.toByteArray()); BloomFilter newBloomFilter = readFrom(numCandidates, in); assertEquals(bloomFilter, newBloomFilter); } + + @Test + public void shouldMergeTwoFiltersCorrectly() { + AdaptiveBloomFilter bloomFilter2 = new AdaptiveBloomFilter(numCandidates, 0.03); + + // Insert items into the first filter + for (int i = 0; i < 1000; i++) { + bloomFilter.put(i); + } + + // Insert different items into the second filter + for (int i = 1000; i < 2000; i++) { + bloomFilter2.put(i); + } + + // Merge the second filter into the first one + bloomFilter.merge(bloomFilter2); + + // Check if the merged filter contains items from both filters + for (int i = 0; i < 2000; i++) { + assertTrue(bloomFilter.mightContain(i)); + } + assertEquals(2048, bloomFilter.bestCandidate().expectedNumItems); + } } \ No newline at end of file