From 23b9650075da79fd992e23bb1c00a4f6f4ef2098 Mon Sep 17 00:00:00 2001 From: Stephen Buttolph Date: Sat, 23 Dec 2023 14:17:30 -0500 Subject: [PATCH] Unexport fields from gossip.BloomFilter --- network/p2p/gossip/bloom.go | 34 +++++++++++++++++++------------ network/p2p/gossip/bloom_test.go | 15 ++++++++++++-- network/p2p/gossip/handler.go | 6 +++--- network/p2p/gossip/test_gossip.go | 3 +-- vms/avm/network/gossip.go | 3 +-- 5 files changed, 39 insertions(+), 22 deletions(-) diff --git a/network/p2p/gossip/bloom.go b/network/p2p/gossip/bloom.go index c113396bf0b3..bb588b23bb4f 100644 --- a/network/p2p/gossip/bloom.go +++ b/network/p2p/gossip/bloom.go @@ -32,35 +32,43 @@ func NewBloomFilter( salt, err := randomSalt() return &BloomFilter{ - Bloom: bloom, - Salt: salt, + bloom: bloom, + salt: salt, }, err } type BloomFilter struct { - Bloom *bloomfilter.Filter - // Salt is provided to eventually unblock collisions in Bloom. It's possible + bloom *bloomfilter.Filter + // salt is provided to eventually unblock collisions in Bloom. It's possible // that conflicting Gossipable items collide in the bloom filter, so a salt // is generated to eventually resolve collisions. - Salt ids.ID + salt ids.ID } func (b *BloomFilter) Add(gossipable Gossipable) { h := gossipable.GossipID() salted := &hasher{ hash: h[:], - salt: b.Salt, + salt: b.salt, } - b.Bloom.Add(salted) + b.bloom.Add(salted) } func (b *BloomFilter) Has(gossipable Gossipable) bool { h := gossipable.GossipID() salted := &hasher{ hash: h[:], - salt: b.Salt, + salt: b.salt, } - return b.Bloom.Contains(salted) + return b.bloom.Contains(salted) +} + +func (b *BloomFilter) Marshal() ([]byte, []byte, error) { + bloomBytes, err := b.bloom.MarshalBinary() + // salt must be copied here to ensure the bytes aren't overwritten if salt + // is later modified. + salt := b.salt + return bloomBytes, salt[:], err } // ResetBloomFilterIfNeeded resets a bloom filter if it breaches a target false @@ -69,11 +77,11 @@ func ResetBloomFilterIfNeeded( bloomFilter *BloomFilter, falsePositiveProbability float64, ) (bool, error) { - if bloomFilter.Bloom.FalsePosititveProbability() < falsePositiveProbability { + if bloomFilter.bloom.FalsePosititveProbability() < falsePositiveProbability { return false, nil } - newBloom, err := bloomfilter.New(bloomFilter.Bloom.M(), bloomFilter.Bloom.K()) + newBloom, err := bloomfilter.New(bloomFilter.bloom.M(), bloomFilter.bloom.K()) if err != nil { return false, err } @@ -82,8 +90,8 @@ func ResetBloomFilterIfNeeded( return false, err } - bloomFilter.Bloom = newBloom - bloomFilter.Salt = salt + bloomFilter.bloom = newBloom + bloomFilter.salt = salt return true, nil } diff --git a/network/p2p/gossip/bloom_test.go b/network/p2p/gossip/bloom_test.go index 860d2d5e936e..1a05a7eb9bd5 100644 --- a/network/p2p/gossip/bloom_test.go +++ b/network/p2p/gossip/bloom_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/require" + "golang.org/x/exp/slices" + "github.com/ava-labs/avalanchego/ids" ) @@ -49,16 +51,25 @@ func TestBloomFilterRefresh(t *testing.T) { b, err := bloomfilter.New(10, 1) require.NoError(err) bloom := BloomFilter{ - Bloom: b, + bloom: b, } for _, item := range tt.add { + bloomBytes, saltBytes, err := bloom.Marshal() + require.NoError(err) + + initialBloomBytes := slices.Clone(bloomBytes) + initialSaltBytes := slices.Clone(saltBytes) + _, err = ResetBloomFilterIfNeeded(&bloom, tt.falsePositiveProbability) require.NoError(err) bloom.Add(item) + + require.Equal(initialBloomBytes, bloomBytes) + require.Equal(initialSaltBytes, saltBytes) } - require.Equal(uint64(len(tt.expected)), bloom.Bloom.N()) + require.Equal(uint64(len(tt.expected)), bloom.bloom.N()) for _, expected := range tt.expected { require.True(bloom.Has(expected)) diff --git a/network/p2p/gossip/handler.go b/network/p2p/gossip/handler.go index de74d78169cf..0cea0c98ab71 100644 --- a/network/p2p/gossip/handler.go +++ b/network/p2p/gossip/handler.go @@ -71,10 +71,10 @@ func (h Handler[T]) AppRequest(_ context.Context, _ ids.NodeID, _ time.Time, req } filter := &BloomFilter{ - Bloom: &bloomfilter.Filter{}, - Salt: salt, + bloom: &bloomfilter.Filter{}, + salt: salt, } - if err := filter.Bloom.UnmarshalBinary(request.Filter); err != nil { + if err := filter.bloom.UnmarshalBinary(request.Filter); err != nil { return nil, err } diff --git a/network/p2p/gossip/test_gossip.go b/network/p2p/gossip/test_gossip.go index 83021730d444..4603333ba28f 100644 --- a/network/p2p/gossip/test_gossip.go +++ b/network/p2p/gossip/test_gossip.go @@ -65,6 +65,5 @@ func (t *testSet) Iterate(f func(gossipable *testTx) bool) { } func (t *testSet) GetFilter() ([]byte, []byte, error) { - bloom, err := t.bloom.Bloom.MarshalBinary() - return bloom, t.bloom.Salt[:], err + return t.bloom.Marshal() } diff --git a/vms/avm/network/gossip.go b/vms/avm/network/gossip.go index e4e145d830eb..ad5d1589d25d 100644 --- a/vms/avm/network/gossip.go +++ b/vms/avm/network/gossip.go @@ -152,6 +152,5 @@ func (g *gossipMempool) GetFilter() (bloom []byte, salt []byte, err error) { g.lock.RLock() defer g.lock.RUnlock() - bloomBytes, err := g.bloom.Bloom.MarshalBinary() - return bloomBytes, g.bloom.Salt[:], err + return g.bloom.Marshal() }