diff --git a/pkg/pair/pair.go b/pkg/pair/pair.go index 066fb88..c5c6c12 100644 --- a/pkg/pair/pair.go +++ b/pkg/pair/pair.go @@ -5,6 +5,7 @@ import ( "crypto/sha512" "errors" "hash" + mrandv2 "math/rand/v2" "github.com/gtank/ristretto255" ) @@ -119,3 +120,13 @@ func (pk *PrivateKey) Decrypt(ciphertext []byte) ([]byte, error) { return cipher.MarshalText() } + +// Shuffle shuffles the data in place by using the Fisher-Yates algorithm. +// Note that ideally, it should be called with less than 2^32-1 (4 billion) elements. +func Shuffle(data [][]byte) { + // NOTE: since go 1.20, math.Rand seeds the global random number generator. + // V2 uses ChaCha8 generator as the global one. + mrandv2.Shuffle(len(data), func(i, j int) { + data[i], data[j] = data[j], data[i] + }) +} diff --git a/pkg/pair/pair_test.go b/pkg/pair/pair_test.go index f9f7429..1542e0d 100644 --- a/pkg/pair/pair_test.go +++ b/pkg/pair/pair_test.go @@ -1,8 +1,10 @@ package pair import ( + "bytes" "crypto/rand" "crypto/sha512" + "slices" "strings" "testing" @@ -59,3 +61,45 @@ func TestPAIR(t *testing.T) { t.Fatalf("want: %s, got: %s", string(ciphertext), string(decrypted)) } } + +func genData(n int) [][]byte { + data := make([][]byte, n) + for i := 0; i < n; i++ { + // marshaled ristretto255.Scalar is 44 bytes + data[i] = make([]byte, 44) + rand.Read(data[i]) + } + return data +} + +func TestShuffle(t *testing.T) { + data := genData(1 << 10) // 1k + orig := make([][]byte, len(data)) + copy(orig, data) + + // shuffle the data in place + Shuffle(data) + + once := make([][]byte, len(data)) + copy(once, data) + + if slices.EqualFunc(data, orig, bytes.Equal) { + t.Fatalf("data not shuffled") + } + + // shuffle again + Shuffle(data) + + if slices.EqualFunc(data, once, bytes.Equal) { + t.Fatalf("data not shuffled") + } +} + +func BenchmarkShuffleOneMillionIDs(b *testing.B) { + data := genData(1 << 20) // 1m + b.ResetTimer() + + for i := 0; i < b.N; i++ { + Shuffle(data) + } +}