Skip to content

Commit

Permalink
PAIR: Add Shuffle (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
juanli16 authored Aug 30, 2024
1 parent 2955e3c commit f600d47
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
11 changes: 11 additions & 0 deletions pkg/pair/pair.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha512"
"errors"
"hash"
mrandv2 "math/rand/v2"

"github.com/gtank/ristretto255"
)
Expand Down Expand Up @@ -119,3 +120,13 @@ func (pk *PrivateKey) Decrypt(ciphertext []byte) ([]byte, error) {

return cipher.MarshalText()
}

// Shuffle shuffles the data in place by using the Fisher-Yates algorithm.
// Note that ideally, it should be called with less than 2^32-1 (4 billion) elements.
func Shuffle(data [][]byte) {
// NOTE: since go 1.20, math.Rand seeds the global random number generator.
// V2 uses ChaCha8 generator as the global one.
mrandv2.Shuffle(len(data), func(i, j int) {
data[i], data[j] = data[j], data[i]
})
}
44 changes: 44 additions & 0 deletions pkg/pair/pair_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package pair

import (
"bytes"
"crypto/rand"
"crypto/sha512"
"slices"
"strings"
"testing"

Expand Down Expand Up @@ -59,3 +61,45 @@ func TestPAIR(t *testing.T) {
t.Fatalf("want: %s, got: %s", string(ciphertext), string(decrypted))
}
}

func genData(n int) [][]byte {
data := make([][]byte, n)
for i := 0; i < n; i++ {
// marshaled ristretto255.Scalar is 44 bytes
data[i] = make([]byte, 44)
rand.Read(data[i])
}
return data
}

func TestShuffle(t *testing.T) {
data := genData(1 << 10) // 1k
orig := make([][]byte, len(data))
copy(orig, data)

// shuffle the data in place
Shuffle(data)

once := make([][]byte, len(data))
copy(once, data)

if slices.EqualFunc(data, orig, bytes.Equal) {
t.Fatalf("data not shuffled")
}

// shuffle again
Shuffle(data)

if slices.EqualFunc(data, once, bytes.Equal) {
t.Fatalf("data not shuffled")
}
}

func BenchmarkShuffleOneMillionIDs(b *testing.B) {
data := genData(1 << 20) // 1m
b.ResetTimer()

for i := 0; i < b.N; i++ {
Shuffle(data)
}
}

0 comments on commit f600d47

Please sign in to comment.