diff --git a/README.md b/README.md index c80df33..944f719 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,35 @@ # match [![CircleCI](https://circleci.com/gh/Optable/match/tree/main.svg?style=svg)](https://circleci.com/gh/Optable/match/tree/main) +[![Go Report Card](https://goreportcard.com/badge/github.com/optable/match)](https://goreportcard.com/report/github.com/optable/match) [![GoDoc](https://godoc.org/github.com/optable/match?status.svg)](https://godoc.org/github.com/optable/match) -An open-source set intersection protocols library written in golang. +An open-source set intersection protocols library written in golang. Currently only compatible with **x86-64**. The goal of the match library is to provide production level implementations of various set intersection protocols. Protocols will typically tradeoff security for performance. For example, a private set intersection (PSI) protocol provides cryptographic guarantees to participants concerning their private and non-intersecting data records, and is suitable for scenarios where participants trust each other to be honest in adhering to the protocol, but still want to protect their private data while performing the intersection operation. -The standard match operation under consideration involves a *sender* and a *receiver*. The sender performs an intersection match with a receiver, such that the receiver learns the result of the intersection, and the sender learns nothing. Protocols such as PSI allow the sender and the receiver to protect, to varying degrees of security guarantees and without a trusted third-party, the private data records that are used as inputs in performing the intersection match. +The standard match operation under consideration involves a *sender* and a *receiver*. The sender performs an intersection match with a receiver, such that the receiver learns the result of the intersection, and the sender learns nothing. Protocols such as PSI allow the sender and receiver to protect, to varying degrees of security guarantees and without a trusted third-party, the private data records that are used as inputs in performing the intersection match. The protocols that are currently provided by the match library are listed below, along with an overview of their characteristics. ## dhpsi -Diffie-Hellman based PSI (DH-based PSI) is an implementation of private set intersection. It provides strong protections to participants regarding their non-intersecting data records. See documentation [here](pkg/dhpsi/README.md). +Diffie-Hellman based PSI (DH-based PSI) is an implementation of private set intersection. It provides strong protections to participants regarding their non-intersecting data records. Documentation located [here](pkg/dhpsi/README.md). ## npsi -The naive, [highway hash](https://github.com/google/highwayhash) based PSI: an *insecure* but fast solution for PSI. Documentation located [here](pkg/npsi/README.md). +The naive, [MetroHash](http://www.jandrewrogers.com/2015/05/27/metrohash/) based PSI: an *insecure* but fast solution for PSI. Documentation located [here](pkg/npsi/README.md). ## bpsi -The [bloomfilter](https://en.wikipedia.org/wiki/Bloom_filter) based PSI: an *insecure* but fast with lower communication overhead than [npsi](pkg/npsi/README.md) solution for PSI. Take a look [here](pkg/bpsi/README.md) to consult the documentation. +The [bloomfilter](https://en.wikipedia.org/wiki/Bloom_filter) based PSI: an *insecure* but fast with lower communication overhead than [npsi](pkg/npsi/README.md) solution for PSI. Documentation located [here](pkg/bpsi/README.md). + +## kkrtpsi + +Similar to the dhpsi protocol, the KKRT PSI, also known as the Batched-OPRF PSI, is a semi-honest secure PSI protocol that has significantly less computation cost, but requires more network communication. An extensive description of the protocol is available [here](pkg/kkrtpsi/README.md). ## logging -[logr](https://github.com/go-logr/logr) is used internally for logging, which accepts a `logr.Logger` object. See the [documentation](https://github.com/go-logr/logr#implementations-non-exhaustive) on `logr` for various concrete implementation of logging api. Example implementation of match sender and receiver uses [stdr](https://github.com/go-logr/stdr) which logs to `os.Stderr`. +[logr](https://github.com/go-logr/logr) is used internally for logging, which accepts a `logr.Logger` object. See the [documentation](https://github.com/go-logr/logr#implementations-non-exhaustive) on `logr` for various concrete implementations of logging api. Example implementation of match sender and receiver uses [stdr](https://github.com/go-logr/stdr) which logs to `os.Stderr`. ### pass logger to sender or receiver To pass a logger to a sender or a receiver, create a new context with the parent context and `logr.Logger` object as follows @@ -71,7 +76,7 @@ $go run examples/sender/main.go -proto dhpsi -v 1 # testing -A complete test suite for all PSIs is present [here](test/psi). Don't hesitate to take a look and help us improve the quality of the testing by reporting problems and observations! +A complete test suite for all PSIs is present [here](test/psi). Don't hesitate to take a look and help us improve the quality of the testing by reporting problems and observations! The PSIs have only been tested on **x86-64**. # benchmarks diff --git a/benchmark/KKRT.md b/benchmark/KKRT.md new file mode 100644 index 0000000..7940a60 --- /dev/null +++ b/benchmark/KKRT.md @@ -0,0 +1,32 @@ +# KKRT Benchmarks + +## Runtime with varying system threads +This heatmap compares runtimes when the sender and receiver have been limited to a set number of system threads (on an n2-standard-64 VM). Both sender and receiver have 100m (million) records with an intersection size of 50m. The receiver's datasets are represented row-wise while the sender's datasets are represented column-wise. + +
+ +
+ +As shown in the above performance results where the number of system threads is increased, there is an up to 15% improvement in performance from the sender’s perspective, but very little effect on the receiver. Additionally, as the number of system threads is increased beyond approximately 8, there is a slight *degradation* in performance. Since KKRT does not benefit much from multi-thread parallelism, we recommend sizing your hardware primarily according to the memory requirements (see below). + +## Memory +These heatmaps compare memory usage when sender and receiver use the same type of VM (n2-standard-64) but have differing number of records (50m, 100m, 200m, 300m, 400m and 500m). The receiver's datasets are represented row-wise while the sender's datasets are represented column-wise. All match attempts performed have an intersection size of 50m. + ++ +
+ ++ +
+ +## GC calls +These heatmaps compare number of garbage collector calls when sender and receiver use the same type of VM (n2-standard-64) but have differing number of records (50m, 100m, 200m, 300m, 400m and 500m). The receiver's datasets are represented row-wise while the sender's datasets are represented column-wise. All match attempts performed have an intersection size of 50m. + ++ +
+ ++ +
\ No newline at end of file diff --git a/benchmark/README.md b/benchmark/README.md index a41314a..1e58f7c 100644 --- a/benchmark/README.md +++ b/benchmark/README.md @@ -1,16 +1,12 @@ # Benchmarks -The following scatter plots show the results of benchmarking match attempts using different PSI algorithms on Google Cloud n2-standard-64 [general-purpose virtual machines (VMs)](https://cloud.google.com/compute/docs/general-purpose-machines#n2-standard). For each benchmark, the sender and the receiver use same type of VMs. In the first plot, the receiver has 100m records while the sender has varying datasets of 50m, 100m, 200m, 300m, 400m and 500m records where as in the second plot, the sender's dataset is of 100m records while the receiver has varying datasets of 50m, 100m, 200m, 300m, 400m and 500m records. The BPSI used for these experiments has a false positive rate fixed at 1e-6. Also, all the match attempts performed have an intersection size of 50m. +The following scatter plot shows the results of benchmarking match attempts using different PSI algorithms on Google Cloud n2-standard-64 [general-purpose virtual machines (VMs)](https://cloud.google.com/compute/docs/general-purpose-machines#n2_machines). For each benchmark, the sender and the receiver use the same type of VM. The plot shows runtime for various PSI algorithms when the sender and receiver have an equal number of records. The BPSI used for these experiments has a false positive rate fixed at 1e-6. All the match attempts performed have an intersection size of 50m (million). [Detailed benchmarks of the KKRT protocol can be found here](KKRT.md).- +
-- -
- -The results for match attempts using different PSI algorithms are provided below. Both sender and receiver used n2-standard-64 VMs with datasets containing 50m, 100m, 200m, 300m, 400m and 500m records. The receiver's datasets are represented row-wise while the sender's datasets are represented column-wise. +The runtimes for match attempts using different PSI algorithms are provided below. Both sender and receiver used n2-standard-64 VMs with datasets containing 50m, 100m, 200m, 300m, 400m and 500m records. The receiver's datasets are represented row-wise while the sender's datasets are represented column-wise.@@ -21,5 +17,9 @@ The results for match attempts using different PSI algorithms are provided below
- +
+ ++ +
\ No newline at end of file diff --git a/benchmark/heatmap_bpsi.png b/benchmark/heatmap_bpsi.png index 7b352e7..5e2f1d2 100644 Binary files a/benchmark/heatmap_bpsi.png and b/benchmark/heatmap_bpsi.png differ diff --git a/benchmark/heatmap_dhpsi.png b/benchmark/heatmap_dhpsi.png index 7df8bf6..b976ae6 100644 Binary files a/benchmark/heatmap_dhpsi.png and b/benchmark/heatmap_dhpsi.png differ diff --git a/benchmark/heatmap_kkrt.png b/benchmark/heatmap_kkrt.png new file mode 100644 index 0000000..2cc4d4c Binary files /dev/null and b/benchmark/heatmap_kkrt.png differ diff --git a/benchmark/heatmap_kkrt_procs.png b/benchmark/heatmap_kkrt_procs.png new file mode 100644 index 0000000..689fe0a Binary files /dev/null and b/benchmark/heatmap_kkrt_procs.png differ diff --git a/benchmark/heatmap_kkrt_rec_gc.png b/benchmark/heatmap_kkrt_rec_gc.png new file mode 100644 index 0000000..c9a66bd Binary files /dev/null and b/benchmark/heatmap_kkrt_rec_gc.png differ diff --git a/benchmark/heatmap_kkrt_rec_mem.png b/benchmark/heatmap_kkrt_rec_mem.png new file mode 100644 index 0000000..de78615 Binary files /dev/null and b/benchmark/heatmap_kkrt_rec_mem.png differ diff --git a/benchmark/heatmap_kkrt_sen_gc.png b/benchmark/heatmap_kkrt_sen_gc.png new file mode 100644 index 0000000..43fcda2 Binary files /dev/null and b/benchmark/heatmap_kkrt_sen_gc.png differ diff --git a/benchmark/heatmap_kkrt_sen_mem.png b/benchmark/heatmap_kkrt_sen_mem.png new file mode 100644 index 0000000..a5e05ee Binary files /dev/null and b/benchmark/heatmap_kkrt_sen_mem.png differ diff --git a/benchmark/heatmap_npsi.png b/benchmark/heatmap_npsi.png index 10c5c61..7988e81 100644 Binary files a/benchmark/heatmap_npsi.png and b/benchmark/heatmap_npsi.png differ diff --git a/benchmark/scatter_equal_sets.png b/benchmark/scatter_equal_sets.png new file mode 100644 index 0000000..1371b8c Binary files /dev/null and b/benchmark/scatter_equal_sets.png differ diff --git a/benchmark/scatter_fixed_receiver.png b/benchmark/scatter_fixed_receiver.png deleted file mode 100644 index e0574ea..0000000 Binary files a/benchmark/scatter_fixed_receiver.png and /dev/null differ diff --git a/benchmark/scatter_plot_sender_fixed.png b/benchmark/scatter_plot_sender_fixed.png deleted file mode 100644 index 253167b..0000000 Binary files a/benchmark/scatter_plot_sender_fixed.png and /dev/null differ diff --git a/examples/README.md b/examples/README.md index 508607a..dff9a83 100644 --- a/examples/README.md +++ b/examples/README.md @@ -2,12 +2,12 @@ The standard match operation involves a *sender* and a *receiver*. The sender performs an intersection match with a receiver, such that the receiver learns the result of the intersection, and the sender learns nothing. Protocols such as PSI allow the sender and receiver to protect, to varying degrees of security guarantees and without a trusted third-party, private data records that are used as inputs in performing the intersection match. -The examples support dhpsi, npsi and bpsi: the protocol can be selected with the *-proto* argument. Note that *npsi* is the default. +The examples support kkrt, dhpsi, npsi and bpsi: the protocol can be selected with the *-proto* argument. Note that *npsi* is the default. ## 1. generate some data `go run generate.go` -This will create two files, `sender-ids.txt` and `receiver-ids.txt` with 100 *IDs* in common between them. You can confirm the communality by running: +This will create two files, `sender-ids.txt` and `receiver-ids.txt` with 100 *IDs* in common between them. You can confirm the commonality by running: `comm -12 <(sort sender-ids.txt) <(sort receiver-ids.txt) | wc -l` diff --git a/examples/format/format.go b/examples/format/format.go new file mode 100644 index 0000000..d5daf92 --- /dev/null +++ b/examples/format/format.go @@ -0,0 +1,49 @@ +package format + +import ( + "math" + "os" + "runtime" + + "github.com/go-logr/logr" + "github.com/go-logr/stdr" +) + +// GetLogger returns a stdr.Logger that implements the logr.Logger interface +// and sets the verbosity of the returned logger. +// set v to 0 for info level messages, +// 1 for debug messages and 2 for trace level message. +// any other verbosity level will default to 0. +func GetLogger(v int) logr.Logger { + logger := stdr.New(nil) + // bound check + if v > 2 || v < 0 { + v = 0 + logger.Info("Invalid verbosity, setting logger to display info level messages only.") + } + stdr.SetVerbosity(v) + + return logger +} + +// ShowUsageAndExit displays the usage message to stdout and exit +func ShowUsageAndExit(usage func(), exitcode int) { + usage() + os.Exit(exitcode) +} + +// MemUsageToStdErr logs the total PSI memory usage, and garbage collector calls +func MemUsageToStdErr(logger logr.Logger) { + var m runtime.MemStats + runtime.ReadMemStats(&m) // https://cs.opensource.google/go/go/+/go1.17.1:src/runtime/mstats.go;l=107 + logger.V(1).Info("Final stats", "total memory (GiB)", math.Round(float64(m.Sys)*100/(1024*1024*1024))/100) + logger.V(1).Info("Final stats", "garbage collector calls", m.NumGC) +} + +// ExitOnErr logs the error and exit if error is not nil +func ExitOnErr(logger logr.Logger, err error, msg string) { + if err != nil { + logger.Error(err, msg) + os.Exit(1) + } +} diff --git a/examples/generate.go b/examples/generate.go index 4283eb4..9a40cf3 100644 --- a/examples/generate.go +++ b/examples/generate.go @@ -21,7 +21,7 @@ func main() { var ws sync.WaitGroup fmt.Printf("generating %d sender(s) and %d receiver(s) IDs with %d in common\r\n", senderCardinality, receiverCardinality, commonCardinality) // make the common part - common := emails.Common(commonCardinality) + common := emails.Common(commonCardinality, emails.HashLen) // do advertisers & publishers in parallel ws.Add(2) go output(senderFileName, common, senderCardinality-commonCardinality, &ws) @@ -34,7 +34,7 @@ func output(filename string, common []byte, n int, ws *sync.WaitGroup) { if f, err := os.Create(filename); err == nil { defer f.Close() // exhaust out - for matchable := range emails.Mix(common, n) { + for matchable := range emails.Mix(common, n, emails.HashLen) { // add \n out := append(matchable, "\n"...) // and write it diff --git a/examples/receiver/main.go b/examples/receiver/main.go index 5e86065..d9cc523 100644 --- a/examples/receiver/main.go +++ b/examples/receiver/main.go @@ -11,7 +11,7 @@ import ( "time" "github.com/go-logr/logr" - "github.com/go-logr/stdr" + "github.com/optable/match/examples/format" "github.com/optable/match/internal/util" "github.com/optable/match/pkg/psi" ) @@ -28,40 +28,11 @@ func usage() { flag.PrintDefaults() } -func showUsageAndExit(exitcode int) { - usage() - os.Exit(exitcode) -} - -func exitOnErr(logger logr.Logger, err error, msg string) { - if err != nil { - logger.Error(err, msg) - os.Exit(1) - } -} - -// getLogger returns a stdr.Logger that implements the logr.Logger interface -// and sets the verbosity of the returned logger. -// set v to 0 for info level messages, -// 1 for debug messages and 2 for trace level message. -// any other verbosity level will default to 0. -func getLogger(v int) logr.Logger { - logger := stdr.New(nil) - // bound check - if v > 2 || v < 0 { - v = 0 - logger.Info("Invalid verbosity, setting logger to display info level messages only.") - } - stdr.SetVerbosity(v) - - return logger -} - var out *string func main() { var wg sync.WaitGroup - var protocol = flag.String("proto", defaultProtocol, "the psi protocol (dhpsi,npsi)") + var protocol = flag.String("proto", defaultProtocol, "the psi protocol (bpsi,npsi,dhpsi,kkrt)") var port = flag.String("p", defaultPort, "The receiver port") var file = flag.String("in", defaultSenderFileName, "A list of IDs terminated with a newline") out = flag.String("out", defaultCommonFileName, "A list of IDs that intersect between the receiver and the sender") @@ -74,7 +45,7 @@ func main() { flag.Parse() if *showHelp { - showUsageAndExit(0) + format.ShowUsageAndExit(usage, 0) } // validate protocol @@ -86,48 +57,49 @@ func main() { psiType = psi.ProtocolNPSI case "dhpsi": psiType = psi.ProtocolDHPSI + case "kkrt": + psiType = psi.ProtocolKKRTPSI default: psiType = psi.ProtocolUnsupported } log.Printf("operating with protocol %s", psiType) // fetch stdr logger - mlog := getLogger(*verbose) + mlog := format.GetLogger(*verbose) // open file f, err := os.Open(*file) - exitOnErr(mlog, err, "failed to open file") + format.ExitOnErr(mlog, err, "failed to open file") defer f.Close() // count lines log.Printf("counting lines in %s", *file) t := time.Now() n, err := util.Count(f) - exitOnErr(mlog, err, "failed to count") + format.ExitOnErr(mlog, err, "failed to count") log.Printf("that took %v", time.Since(t)) log.Printf("operating on %s with %d IDs", *file, n) // get a listener l, err := net.Listen("tcp", *port) - exitOnErr(mlog, err, "failed to listen on tcp port") + format.ExitOnErr(mlog, err, "failed to listen on tcp port") log.Printf("receiver listening on %s", *port) for { if c, err := l.Accept(); err != nil { - exitOnErr(mlog, err, "failed to accept incoming connection") + format.ExitOnErr(mlog, err, "failed to accept incoming connection") } else { log.Printf("handling sender %s", c.RemoteAddr()) f, err := os.Open(*file) - exitOnErr(mlog, err, "failed to open file") + format.ExitOnErr(mlog, err, "failed to open file") // enable nagle switch v := c.(type) { - // enable nagle case *net.TCPConn: v.SetNoDelay(false) } - // make the receiver + // make the receiver receiver, err := psi.NewReceiver(psiType, c) - exitOnErr(mlog, err, "failed to create receiver") + format.ExitOnErr(mlog, err, "failed to create receiver") // and hand it off wg.Add(1) go func() { @@ -148,10 +120,12 @@ func main() { func handle(r psi.Receiver, n int64, f io.ReadCloser, ctx context.Context) { defer f.Close() ids := util.Exhaust(n, f) - logger, _ := logr.FromContext(ctx) + logger := logr.FromContextOrDiscard(ctx) if i, err := r.Intersect(ctx, n, ids); err != nil { - log.Printf("intersect failed (%d): %v", len(i), err) + format.ExitOnErr(logger, err, "intersect failed") } else { + // write memory usage to stderr + format.MemUsageToStdErr(logger) // write out to common-ids.txt log.Printf("intersected %d IDs, writing out to %s", len(i), *out) if f, err := os.Create(*out); err == nil { @@ -159,11 +133,11 @@ func handle(r psi.Receiver, n int64, f io.ReadCloser, ctx context.Context) { for _, id := range i { // and write it if _, err := f.Write(append(id, "\n"...)); err != nil { - exitOnErr(logger, err, "failed to write intersected ID to file") + format.ExitOnErr(logger, err, "failed to write intersected ID to file") } } } else { - exitOnErr(logger, err, "failed to perform PSI") + format.ExitOnErr(logger, err, "failed to perform PSI") } } } diff --git a/examples/sender/main.go b/examples/sender/main.go index 13ddbd2..8fb6d76 100644 --- a/examples/sender/main.go +++ b/examples/sender/main.go @@ -9,7 +9,7 @@ import ( "os" "github.com/go-logr/logr" - "github.com/go-logr/stdr" + "github.com/optable/match/examples/format" "github.com/optable/match/internal/util" "github.com/optable/match/pkg/psi" ) @@ -25,37 +25,8 @@ func usage() { flag.PrintDefaults() } -func showUsageAndExit(exitcode int) { - usage() - os.Exit(exitcode) -} - -func exitOnErr(logger logr.Logger, err error, msg string) { - if err != nil { - logger.Error(err, msg) - os.Exit(1) - } -} - -// getLogger returns a stdr.Logger that implements the logr.Logger interface -// and sets the verbosity of the returned logger. -// set v to 0 for info level messages, -// 1 for debug messages and 2 for trace level message. -// any other verbosity level will default to 0. -func getLogger(v int) logr.Logger { - logger := stdr.New(nil) - // bound check - if v > 2 || v < 0 { - v = 0 - logger.Info("Invalid verbosity, setting logger to display info level messages only.") - } - stdr.SetVerbosity(v) - - return logger -} - func main() { - var protocol = flag.String("proto", defaultProtocol, "the psi protocol (dhpsi,npsi)") + var protocol = flag.String("proto", defaultProtocol, "the psi protocol (bpsi,npsi,dhpsi,kkrt)") var addr = flag.String("a", defaultAddress, "The receiver address") var file = flag.String("in", defaultSenderFileName, "A list of IDs terminated with a newline") var verbose = flag.Int("v", 0, "Verbosity level, default to -v 0 for info level messages, -v 1 for debug messages, and -v 2 for trace level message.") @@ -66,7 +37,7 @@ func main() { flag.Parse() if *showHelp { - showUsageAndExit(0) + format.ShowUsageAndExit(usage, 0) } // validate protocol @@ -78,29 +49,31 @@ func main() { psiType = psi.ProtocolNPSI case "dhpsi": psiType = psi.ProtocolDHPSI + case "kkrt": + psiType = psi.ProtocolKKRTPSI default: psiType = psi.ProtocolUnsupported } log.Printf("operating with protocol %s", psiType) // fetch stdr logger - slog := getLogger(*verbose) + slog := format.GetLogger(*verbose) // open file f, err := os.Open(*file) - exitOnErr(slog, err, "failed to open file") + format.ExitOnErr(slog, err, "failed to open file") // count lines log.Printf("counting lines in %s", *file) n, err := util.Count(f) - exitOnErr(slog, err, "failed to count") + format.ExitOnErr(slog, err, "failed to count") log.Printf("operating on %s with %d IDs", *file, n) // rewind f.Seek(0, io.SeekStart) c, err := net.Dial("tcp", *addr) - exitOnErr(slog, err, "failed to dial") + format.ExitOnErr(slog, err, "failed to dial") defer c.Close() // enable nagle switch v := c.(type) { @@ -109,8 +82,9 @@ func main() { } s, err := psi.NewSender(psiType, c) - exitOnErr(slog, err, "failed to create sender") + format.ExitOnErr(slog, err, "failed to create sender") ids := util.Exhaust(n, f) err = s.Send(logr.NewContext(context.Background(), slog), n, ids) - exitOnErr(slog, err, "failed to perform PSI") + format.ExitOnErr(slog, err, "failed to perform PSI") + format.MemUsageToStdErr(slog) } diff --git a/go.mod b/go.mod index 2e020d0..54f6501 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,15 @@ module github.com/optable/match go 1.16 require ( + github.com/alecthomas/unsafeslice v0.1.0 github.com/bits-and-blooms/bloom/v3 v3.0.1 github.com/bwesterb/go-ristretto v1.2.0 - github.com/cespare/xxhash v1.1.0 - github.com/dchest/siphash v1.2.2 - github.com/dgryski/go-highway v0.0.0-20210309212254-61406496927c + github.com/dgryski/go-metro v0.0.0-20211015221634-2661b20a2446 github.com/go-logr/logr v1.2.0 github.com/go-logr/stdr v1.2.0 github.com/gtank/ristretto255 v0.1.2 - github.com/intel-go/cpuid v0.0.0-20210602155658-5747e5cec0d9 // indirect - github.com/spaolacci/murmur3 v1.1.0 + github.com/twmb/murmur3 v1.1.6 + github.com/zeebo/blake3 v0.2.0 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c + golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf // indirect ) diff --git a/go.sum b/go.sum index aee3474..f471fbf 100644 --- a/go.sum +++ b/go.sum @@ -1,25 +1,41 @@ -github.com/OneOfOne/xxhash v1.2.2 h1:KMrpdQIwFcEqXDklaen+P1axHaj9BSKzvpUUfnHldSE= -github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= +github.com/alecthomas/unsafeslice v0.1.0 h1:bZlgt0CcjLz1OxjIS//2B5MTfifcGPKmkzvQ7OarOvg= +github.com/alecthomas/unsafeslice v0.1.0/go.mod h1:H7s9N0gAbfiwu02rQEexZbN/YMxm+2l3rVRa/zE2DM8= github.com/bits-and-blooms/bitset v1.2.0 h1:Kn4yilvwNtMACtf1eYDlG8H77R07mZSPbMjLyS07ChA= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= github.com/bits-and-blooms/bloom/v3 v3.0.1 h1:Inlf0YXbgehxVjMPmCGv86iMCKMGPPrPSHtBF5yRHwA= github.com/bits-and-blooms/bloom/v3 v3.0.1/go.mod h1:MC8muvBzzPOFsrcdND/A7kU7kMhkqb9KI70JlZCP+C8= github.com/bwesterb/go-ristretto v1.2.0 h1:xxWOVbN5m8NNKiSDZXE1jtZvZnC6JSJ9cYFADiZcWtw= github.com/bwesterb/go-ristretto v1.2.0/go.mod h1:fUIoIZaG73pV5biE2Blr2xEzDoMj7NFEuV9ekS419A0= -github.com/cespare/xxhash v1.1.0 h1:a6HrQnmkObjyL+Gs60czilIUGqrzKutQD6XZog3p+ko= -github.com/cespare/xxhash v1.1.0/go.mod h1:XrSqR1VqqWfGrhpAt58auRo0WTKS1nRRg3ghfAqPWnc= -github.com/dchest/siphash v1.2.2 h1:9DFz8tQwl9pTVt5iok/9zKyzA1Q6bRGiF3HPiEEVr9I= -github.com/dchest/siphash v1.2.2/go.mod h1:q+IRvb2gOSrUnYoPqHiyHXS0FOBBOdl6tONBlVnOnt4= -github.com/dgryski/go-highway v0.0.0-20210309212254-61406496927c h1:BU5wmdaJQ45rcu5N8BLEBnLc6F8cjg25rRCQqCID6ns= -github.com/dgryski/go-highway v0.0.0-20210309212254-61406496927c/go.mod h1:k3huN2j0rapqSEtGr9Yy2/RnTtn4jTiufYz6j6Kv+wM= +github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-metro v0.0.0-20211015221634-2661b20a2446 h1:QnWGyQI3H080vbC9E4jlr6scOYEnALtvV/69oATYzOo= +github.com/dgryski/go-metro v0.0.0-20211015221634-2661b20a2446/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= github.com/go-logr/logr v1.2.0 h1:QK40JKJyMdUDz+h+xvCsru/bJhvG0UxvePV0ufL/AcE= github.com/go-logr/logr v1.2.0/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/stdr v1.2.0 h1:j4LrlVXgrbIWO83mmQUnK0Hi+YnbD+vzrE1z/EphbFE= github.com/go-logr/stdr v1.2.0/go.mod h1:YkVgnZu1ZjjL7xTxrfm/LLZBfkhTqSR1ydtm6jTKKwI= github.com/gtank/ristretto255 v0.1.2 h1:JEqUCPA1NvLq5DwYtuzigd7ss8fwbYay9fi4/5uMzcc= github.com/gtank/ristretto255 v0.1.2/go.mod h1:Ph5OpO6c7xKUGROZfWVLiJf9icMDwUeIvY4OmlYW69o= -github.com/intel-go/cpuid v0.0.0-20210602155658-5747e5cec0d9 h1:x9HFDMDCsaxTvC4X3o0ZN6mw99dT/wYnTItGwhBRmg0= -github.com/intel-go/cpuid v0.0.0-20210602155658-5747e5cec0d9/go.mod h1:RmeVYf9XrPRbRc3XIx0gLYA8qOFvNoPOfaEZduRlEp4= -github.com/spaolacci/murmur3 v0.0.0-20180118202830-f09979ecbc72/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/twmb/murmur3 v1.1.6 h1:mqrRot1BRxm+Yct+vavLMou2/iJt0tNVTTC0QoIjaZg= +github.com/twmb/murmur3 v1.1.6/go.mod h1:Qq/R7NUyOfr65zD+6Q5IHKsJLwP7exErjN6lyyq3OSQ= +github.com/zeebo/assert v1.1.0 h1:hU1L1vLTHsnO8x8c9KAR5GmM5QscxHg5RNU5z5qbUWY= +github.com/zeebo/assert v1.1.0/go.mod h1:Pq9JiuJQpG8JLJdtkwrJESF0Foym2/D9XMU5ciN/wJ0= +github.com/zeebo/blake3 v0.2.0 h1:1SGx3IvKWFUU/xl+/7kjdcjjMcvVSm+3dMo/N42afC8= +github.com/zeebo/blake3 v0.2.0/go.mod h1:G9pM4qQwjRzF1/v7+vabMj/c5mWpGZ2Wzo3Eb4z0pb4= +github.com/zeebo/pcg v1.0.0 h1:dt+dx+HvX8g7Un32rY9XWoYnd0NmKmrIzpHF7qiTDj0= +github.com/zeebo/pcg v1.0.0/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20201014080544-cc95f250f6bc/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf h1:2ucpDCmfkl8Bd/FsLtiD653Wf96cW37s+iGx93zsu4k= +golang.org/x/sys v0.0.0-20210823070655-63515b42dcdf/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/internal/crypto/cipher.go b/internal/crypto/cipher.go new file mode 100644 index 0000000..ae6a1d2 --- /dev/null +++ b/internal/crypto/cipher.go @@ -0,0 +1,68 @@ +package crypto + +import ( + "crypto/aes" + "crypto/cipher" + + "github.com/alecthomas/unsafeslice" + "github.com/optable/match/internal/util" + "github.com/twmb/murmur3" + "github.com/zeebo/blake3" +) + +// PseudorandomCode is implemented as follows: +// C(x) = AES(1||h(x)[:15]) || +// AES(2||h(x)[:15]) || +// AES(3||h(x)[:15]) || +// AES(4||h(x)[:15]) +// where h() is the Murmur3 hashing function. +// PseudorandomCode is passed the src as well as the associated hash +// index. It also requires an AES block cipher. +// The full pseudorandom code consists of four 16 byte encrypted AES +// blocks that are encoded into a slice of 64 bytes. The hash function is +// constructed with the hash index as its two seeds. It is fed the full +// ID source. It returns two uint64s which are cast to a slice of bytes. +// The output is shifted right to allow prepending of the block index. +// For each block, the prepended value is changed to indicate the block +// index (1, 2, 3, 4) before being used as the source for the AES encode. +func PseudorandomCode(aesBlock cipher.Block, src []byte, hIdx byte) []byte { + // prepare destination + dst := make([]byte, aes.BlockSize*4) + + // hash id and the hash index + lo, hi := murmur3.SeedSum128(uint64(hIdx), uint64(hIdx), src) + + // store in scratch slice + s := unsafeslice.ByteSliceFromUint64Slice([]uint64{lo, hi}) + copy(s[1:], s) // shift for prepending + + // encrypt + s[0] = 1 + aesBlock.Encrypt(dst[:aes.BlockSize], s) + s[0] = 2 + aesBlock.Encrypt(dst[aes.BlockSize:aes.BlockSize*2], s) + s[0] = 3 + aesBlock.Encrypt(dst[aes.BlockSize*2:aes.BlockSize*3], s) + s[0] = 4 + aesBlock.Encrypt(dst[aes.BlockSize*3:], s) + return dst +} + +// XorCipherWithBlake3 uses the output of Blake3 XOF as pseudorandom +// bytes to perform a XOR cipher. +func XorCipherWithBlake3(key []byte, ind byte, src []byte) []byte { + hash := make([]byte, len(src)) + getBlake3Hash(key, ind, hash) + util.ConcurrentBitOp(util.Xor, hash, src) + return hash +} + +func getBlake3Hash(key []byte, ind byte, dst []byte) { + h := blake3.New() + h.Write(key) + h.Write([]byte{ind}) + + // convert to *digest to take a snapshot of the hashstate for XOF + d := h.Digest() + d.Read(dst) +} diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go new file mode 100644 index 0000000..e1d3e0c --- /dev/null +++ b/internal/crypto/crypto_test.go @@ -0,0 +1,224 @@ +package crypto + +import ( + "crypto/aes" + "crypto/rand" + "encoding/hex" + "fmt" + "testing" +) + +type pseudorandomCodeTest struct { + out string // hex + in string + aesKey string // hex + hIdx byte +} + +var aesTestStrings = []pseudorandomCodeTest{ + {"7d040a001a48af28aa3a1837ed04864935f4b73a9a54ad7c0decd361f2f6ee30590c4a43c61873cc39c45be00ac71d31b0cab6d39e167971622aa3ed41c8b406", "", "6680dc641356cbdb590c370f747d4e9f", 0}, + {"7c9c0421edbb2cb9583fe0af30a82b21c9424a31af3d13c092edb0fefbd9e2a23cb6969021c2e42f927c52d479957a18cf98b6c041d2a3620e4a690ad401cc2b", "", "6680dc641356cbdb590c370f747d4e9f", 1}, + {"43757caabedfe32a7d1aad2a1966633e11fde0067d2666a96d2a65fb8e8cb62e167c2c3372778ac3c3ebea4c5c495b5bd30f3bc53e3d26c314cc04d8d18fdf26", "", "6680dc641356cbdb590c370f747d4e9f", 2}, + {"04aaa630968b3e58bd381208856009de37a1a629b99f68007a1b060f8439d2e19b4c858dd1be11cced3c9a82bfb135b5f24d96da6c0aeb141d862f164b631c25", "", "6680dc641356cbdb590c370f747d4e9f", 3}, + {"b6d67803e1154d2afcca1906b9e72f848c454ab1345d098d78d798aae8f6eedf85b2f20fc28e9f590e3ad5dce9a49f39ed5dba3f1166023c3fc25bb090527ec7", "", "5f2d0b92a398c22fb9816d0de476db2c", 0}, + {"38201a5379641fcdede51196fa2a3be8aef7ce5dbb33db95c77751414085128d260fb4c8acc6cd5f522f932dded2a025d60984445a12123274f8b137be2a74a5", "", "5f2d0b92a398c22fb9816d0de476db2c", 1}, + {"f4c9a342bc66ee04dc517d604564faa96942f5ada68b2ff8e83cd33eb362761a199f3e368a854d3a6dc7be9385a5035bc8f36c1f99276981ef6fdc20dc2bc8dd", "", "5f2d0b92a398c22fb9816d0de476db2c", 2}, + {"aa843dd67a9ef7bb501f53b0ab18f013e39da3aee28b0646016420e13a28e0e476b31355e862544b6c3ec8350f69d321628a125c6481efbb2aa91d492708b180", "", "5f2d0b92a398c22fb9816d0de476db2c", 3}, + {"98c919064c1023e35bcf4b5e326dc339dd5373848e06c37063fb5c737c9207a3fcc5f740909170a94eee137a1850b11ae32adc1c1b51074bd9c71af701fb390f", "", "6f612185e1b2d64a0657fc056e156a89", 0}, + {"b3ac585807ec5666c63071e712c169f43f207cc15cd40cf7ab0d7dcb992c81074d6a85079291daf34ddbbbe355fe634d4c3a9cf8a81fca9e637fbc30dc0299dd", "", "6f612185e1b2d64a0657fc056e156a89", 1}, + {"9ae39eb0b7ad0053f44cf6840ef7b2fa139cdc50a3efeeee2958c11c78f83527d255b01fe0d12100a299bb41b568922a96e03fa842b8c1b0bb93da8149d1b138", "", "6f612185e1b2d64a0657fc056e156a89", 2}, + {"07f4746c6ebc748052ca6f6356794c4738df720874600ebab6bb09772f8e4c2e916bbe3aeebce61b38d0c15a57a18daf49a25dd70a504548bff7ca135ed88f76", "", "6f612185e1b2d64a0657fc056e156a89", 3}, + {"473a27dc50637a128be257d46864333862c12ab5b4280a50f0523ebfd91cc00be773f3faba976a0a6dab807a4b62b848ebd680b42ff8fed6eb5bf2f15f4ce144", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6680dc641356cbdb590c370f747d4e9f", 0}, + {"3271e4ba95d28b07c19849cad13e1ec6a686038e98335b2323cc0c7d55beb031474ad0c0a2d93eb5765c319f2af5c70e948a640e8e2f612ed43beb07ffe4abc4", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6680dc641356cbdb590c370f747d4e9f", 1}, + {"ec6917ea8f475d915fbf944fce360da40597f7c5c6489e587c8dfd53c5528c775e26839a952a29805de742b36d6e0141c7d5a3bc9e6b1f927f6318932ad148f2", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6680dc641356cbdb590c370f747d4e9f", 2}, + {"20d77c710d53437a56ad0ec8541271916ea92d8458b8013bc6a2579afae606af0026fc85363e177aa243ac9cf6d1a64f4da7b24cb9243e8e5c7949af1b6a9d2e", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6680dc641356cbdb590c370f747d4e9f", 3}, + {"f8d545b70906d5ffcedec23901b8f664e2a7c4936704ec79aaad8690b83ed06c475576a0398dc34d3495e9a058f88b655c3e770e01e6154657e76f57cc6d985c", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "5f2d0b92a398c22fb9816d0de476db2c", 0}, + {"7a82c33c1f4d4e47bd232e781e97a08bf4a0742a56419d6702a4dbd31eaaaf479afb1ac8bb52342ebdd76e95b1b56dc640a8c46dd138a3137db4928b6569c90c", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "5f2d0b92a398c22fb9816d0de476db2c", 1}, + {"634f08d0ec1096f745e858f004d8aed9da7c239676a5662a6abdc63a55c5b78b85fa9879d939a37599be5c0cd85e4e630b981889b6fba80142f42a1179cd0351", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "5f2d0b92a398c22fb9816d0de476db2c", 2}, + {"a8fd8a5393a0ae669ed82eea4b1640d794166266d1aad5c3f50e9ee09373569e077e7b779d7d7f2799b643364cfa51f3105a92bed83faa8efab26cfd6a6f9e0d", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "5f2d0b92a398c22fb9816d0de476db2c", 3}, + {"912ab94cb0dba6ef52548502915de1ab467bf72ba74e9594f97d3bc0fc35ac1c716e8a2e2047b7f2ce2029d3a54355eeaf809360e8acf9461bf99cd7e6450d1b", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6f612185e1b2d64a0657fc056e156a89", 0}, + {"8b76c6654dfb5acfbe99ed7e9fa0914579b44b4409cfe15e3d5b9b1cc9943e3a1eb34a1b4b3f8eb01f5b09d667ab44e8f05cff88c18ddc42a14d25a1ad0be605", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6f612185e1b2d64a0657fc056e156a89", 1}, + {"c63b00c909dfcf2490066de4de4341c7fdeafda48513121b75ac6c32ed8ed4ebc239ebe444884431f95baa3367c2d68be35d6402dc4337f840306d6c731c05ac", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6f612185e1b2d64a0657fc056e156a89", 2}, + {"c06f380eeb2cd7e93f16f215f739650bd98253b3234d4a1b6b5878eb2c8b8bd3b117d4a78a4faad53a6fc35cb5d5701669884950bbda57338b7b027ebba32bc8", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "6f612185e1b2d64a0657fc056e156a89", 3}, + {"2b13d5c18268b7f7b432a36e8c4d255a1371ca0ccac3a712a6f8d048dbcb2306179d458987a595287be7cb97ece82cb6e3ae7420cfc8ea90572b29c8ab750634", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6680dc641356cbdb590c370f747d4e9f", 0}, + {"c8a12ca9be5bd3215ea53e78c0be09b409f161723544b29c94960782fdff206557ae579d44a380bb34f3bc3edeaf1a30050781e4dae459176ddc56f2bca7a27d", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6680dc641356cbdb590c370f747d4e9f", 1}, + {"d1370dbe6544cf2fc35d8f0fbe48c86516df3d3f97bb87f0227778823eabba19cf4be97208f88d58d9561876e9f706cb108552c7e5c049ff318ce8dc2fe9c163", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6680dc641356cbdb590c370f747d4e9f", 2}, + {"7b1dac8f699e6d1e1c8a7d490b26bc12e47f219d2cf257504ade2657286ad2e8db4fd4b9cf5470e77a3622d64718c058809fa551fc2791734419dedebc0f4c7f", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6680dc641356cbdb590c370f747d4e9f", 3}, + {"bb381f67c8ddc711d2d45d6bf46d09390e6888c36283773cad7c954585dfd2ddc572dc59397099eb6eab3d9b68a9be7f17f309e9e655b52762136a01fd871acf", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "5f2d0b92a398c22fb9816d0de476db2c", 0}, + {"3f6ac072b23dac2eed840a855734527fb4b293a3bcd79a9344e39810d39308d964615050ce9ca605fa2ed50b38812da96485cd6ad54d0f531c70dd95208de6f6", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "5f2d0b92a398c22fb9816d0de476db2c", 1}, + {"98cf45e94daa57c637e32436ab81fe8cfab8f96de55a085840fc546c5299fe0e0b4b7027a62de87a90200e790f0c9b2e2cae418a0a934dacd9ccbc634431bf25", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "5f2d0b92a398c22fb9816d0de476db2c", 2}, + {"8fb793ca77ba1b033876ce7d555ec7188278dc11b6d2e53471ff5d7171cf52ed7c36a581e2204e227f08abdc0d534a3885c585bbd289d022404900ef68efd771", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "5f2d0b92a398c22fb9816d0de476db2c", 3}, + {"762fef8c12ec980a70f0b89d11490a0016d9b385a7414d9dc24b21cc9ab0285a79cc7ef346f6562beca5c597ed40363d390c865b9cbfd78ad1096a13fa4c67f0", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6f612185e1b2d64a0657fc056e156a89", 0}, + {"a9046d7afd68be3dc45447c42e644d16114d3b09c111184a236c58d58da7b04011f9a522e6ca8df776130faf0170a8ad7a79e064779508ed9440b30e0b68d1a3", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6f612185e1b2d64a0657fc056e156a89", 1}, + {"ca17a038d904817d3fd01131557d301048ac80cd1a42627281fd159c657fa5ce737736be000a823ffc1443caf6de93d6f084314386270a851e76da38a8dd1b3d", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6f612185e1b2d64a0657fc056e156a89", 2}, + {"7a7f4a139f07d13aa12130c0d8d992f090f21e7c60a41dd632cd8e0a79e35402fef66d83168805629e737943c352ebda30a724d68e9c483f6827a0cf6cc4c087", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "6f612185e1b2d64a0657fc056e156a89", 3}, + {"400a4ad961e475d8c2e20f07784a02634a9d10f2f7d81e6add27c60f62e03b06027459b547ac5ebf68bd01838d38c9ec93a5d44a9fe0945eb5d298fe882e6582", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6680dc641356cbdb590c370f747d4e9f", 0}, + {"02818be2a3c5d4fe10bfae41112ea228755d3e93d5aeb951bb6f79b754f01c176caeae76afff7cfd13b8d10b095e40ddde2bc342c69cd05ce52fb686dd1d761e", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6680dc641356cbdb590c370f747d4e9f", 1}, + {"44361dde703bcf5291bc538a396dcacbcebad3562bb16e58711573fea60b0eae67c9bfe64ce41f157b5442cc64f06e299df522ae86966eda86845ba057b22c30", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6680dc641356cbdb590c370f747d4e9f", 2}, + {"1368587df1fbd0ae5b02586f08bfbc9a8526bc5ea59f6f5895e05d6e19bed085ae8dd10aa829d04b2067d0ec1d1a99a1dc73492dbe64bad76dc9796557dac831", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6680dc641356cbdb590c370f747d4e9f", 3}, + {"3abc9c2f7c33fce44ddb3f38fba586da04f9dbc77f1319caebdfa7bd9b4fb7cd5e20c3e84f9fc91dd1ca164ec466ba47208ef01174ff23cbd9223a3e8f39a5ea", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "5f2d0b92a398c22fb9816d0de476db2c", 0}, + {"e37f360d2a8476a9fc485cff64130f9dcb0e6aae4be970fd21c5e0022869826a31fc744523e93690470beabce06ff705019aae108b6deb6c1998ec6f2dade395", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "5f2d0b92a398c22fb9816d0de476db2c", 1}, + {"c8c31f8c69c66db28ef1108261c372ef66cd5bb1ae099acf017bbda14a4dbbc8530a16b6edde0d0c52fa9087e1511597eeb59b97c9209baeb6d3aef14f9f37bb", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "5f2d0b92a398c22fb9816d0de476db2c", 2}, + {"edd3538741f35d76f97cee03f3f14b289a4264f52192300b5f2c8744c80b61b9bda7c2e0da48ce0692c2fcd22d6361bf7b96a19b8e311995d3fab9e9cab37553", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "5f2d0b92a398c22fb9816d0de476db2c", 3}, + {"ac80c57cebd18facf12b979349029964c3cf512d080bed818cb9ec42a4c64d782af8b36e4a39ecd65fbb3a6f4e12c8799dfd928af589e83fad1e313020c76d7a", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6f612185e1b2d64a0657fc056e156a89", 0}, + {"43704b603a76c16e08e69546dbff0b57fdc8d5d4cb67588b4a52675f7fb327e1d5ee565dd0b58100f4d14f9deb7fafe9ff13fda30b80a4f196021b20d17b50fe", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6f612185e1b2d64a0657fc056e156a89", 1}, + {"b8831aa6cccefe526c1054701da0268b4f6ea8e2ae5418b2df976ca4e924c0ae75d7dc901dc64bc53cb69a8483b5c51a47ab0043f018f23e575ff470669697e3", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6f612185e1b2d64a0657fc056e156a89", 2}, + {"86156e15c1e6511d2627384d76d38f3f3e3402be0f34c901afb10ea118d5fb6b400a07c5b93e06f5a15a710731e1a694b1e21248ee830bc0774a400312881e7c", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "6f612185e1b2d64a0657fc056e156a89", 3}, +} + +type xorCipherTest struct { + cipherText string // hex + plainText string + xorKey string // hex + choice byte +} + +var xorCipherTestStrings = []xorCipherTest{ + {"32c207fb7951ac2f8edb334b120f6337279f19af323b27be976e520796a8f7499e420d67ec78c58ce0d34274cc7ef46b2d5b16d0a5d012452b", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "3d1fe501030418491fde1223b3cf05094996fe655139934b538095715b7c68d5", 0}, + {"e2c4bf84b04050e41983d021c91254cd38a0cbf26abaf2a863a102a8b33b9f48c5a033df4a68b1bcaa46acd3d6abebbe113d18c664f22feabd", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "3d1fe501030418491fde1223b3cf05094996fe655139934b538095715b7c68d5", 1}, + {"ea70ec5f79ce3df7c390e804f5940198fb33f64094a13354e9fd8881f85cb85047674d24010576753f7457f0ee4e0da927a129e2406c986edc", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "4a8ec3b4f381da6b8ad2ad122ea71ea4064b90e280f65075e676feb9c40806c1", 0}, + {"388b761ddb6c9ea235cc63862fa1f4cb4aa1e71d2d80010720067184fea081ca03c0dd231febff6d84bb564d5992f663edeba580fd09970ddb", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "4a8ec3b4f381da6b8ad2ad122ea71ea4064b90e280f65075e676feb9c40806c1", 1}, + {"4ef613000dc9612bead97ef1c802233dc311ecfb33ae2c9e063e9eebc939dc9740ad7adea498debe7000ebc928f2008ff9b7c86605959c0a8b", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "2e87ae384abbb3b5d593385100bcee15a25766097128e4930353b88dc2a5e328", 0}, + {"0d729e57468445fe8aa5b7344d0dc330822fab5efc27237b6295a4ebfa9cd3b9553d6ada3f6a6fa0cb0ec2903589ecb6a47ee04cdde8e86bc7", "Free! Free!/A trip/to Mars/for 900/empty jars/Burma Shave", "2e87ae384abbb3b5d593385100bcee15a25766097128e4930353b88dc2a5e328", 1}, + {"118a5bfd6910dc6bde892505374d22747e8c50ee3c2c5de9d67f1902ccbbe745ca431567b73c8997f18f41308d32d57a3b0441c0c281440077f2c2d29e1c5749ac00", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "3d1fe501030418491fde1223b3cf05094996fe655139934b538095715b7c68d5", 0}, + {"c18ce382a00120a049d1c66fec50158e61b382b364ad88ff22b049ade9288f4491a12bdf112cfda7bb1aaf9797e7caaf07624fd603a379afe12cc258e2a7ea9e0d3f", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "3d1fe501030418491fde1223b3cf05094996fe655139934b538095715b7c68d5", 1}, + {"c938b059698f4db393c2fe4ad0d640dba220bf019ab64903a8ecc384a24fa85c136655245a413a6e2e2854b4af022cb831fe7ef2273dce2b808c6df2dd5835923ad1", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "4a8ec3b4f381da6b8ad2ad122ea71ea4064b90e280f65075e676feb9c40806c1", 0}, + {"1bc32a1bcb2deee6659e75c80ae3b58813b2ae5c23977b5061173a81a4b391c657c1c52344afb37695e7550918ded772fbb4f2909a58c148871eea941d34093bb002", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "4a8ec3b4f381da6b8ad2ad122ea71ea4064b90e280f65075e676feb9c40806c1", 1}, + {"6dbe4f061d88116fba8b68bfed40627e9a02a5ba3db956c9472fd5ee932acc9b14ac62deffdc92a5615ce88d69be219eefe89f7662c4ca4fd71916c06d28c0626d05", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "2e87ae384abbb3b5d593385100bcee15a25766097128e4930353b88dc2a5e328", 0}, + {"2e3ac25156c535badaf7a17a684f8273db3ce21ff230592c2384efeea08fc3b5013c72da642e23bbda52c1d474c5cda7b221b75cbab9be2e9b38439239b5446cdfa7", "e:9c1a66577adb510cf5a7763bdc5a05d17e648b16b62ccdd260497394536662d9", "2e87ae384abbb3b5d593385100bcee15a25766097128e4930353b88dc2a5e328", 1}, + {"20d807be3e048d3c88d7661d734071652fcf55b433681eb69168180f8dfabe1e8e13026fe870c580b2dc0369d971d17f2c5304d097cc53526ea19e97cd44410fad5467fc8334b777404a59242f86a8f7ecb27ba243eb0d89537dcae3dfa701ebdeabca4a7c814763fa4b556325f490687e1435d0e380e35e4321cd82ac51bd53fc94f84f", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "3d1fe501030418491fde1223b3cf05094996fe655139934b538095715b7c68d5", 0}, + {"f0debfc1f71571f71f8f8577a85d469f30f087e96be9cba065a748a0a869d61fd5f13cd74e60b1b0f849edcec3a4ceaa10350ac656ee6efdf87f9e1db1fffcd80c6b65a59651e0e7bda457b1317369b5189753b0e42a799d22d57df3d8539614b53ff0eb81f6f7149dd57b79f55b2618b962c6ea30e99469e28d51166698af1820137a8d", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "3d1fe501030418491fde1223b3cf05094996fe655139934b538095715b7c68d5", 1}, + {"f86aec1a3e9b1ce4c59cbd5294db13caf363ba5b95f20a5ceffbc289e30ef1075736422c050d76796d7b16edfb4128bd26a93be27270d97999df31b78e0023d43b857280aab29764e65e2b83bbf7a4ebcad4c9c40050ba8200bfe613588506dd2063d90bae16637d4ea237d77ae85b661f3fa7650bfecb1a6ba1d705ebe347ec3f9d6a08", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "4a8ec3b4f381da6b8ad2ad122ea71ea4064b90e280f65075e676feb9c40806c1", 0}, + {"2a9176589c39bfb133c036d04eeee69942f1ab062cd3380f26003b8ce5f2c89d1391d22b1be3ff61d6b417504c9dd377ece3b780cf15d61a9e4db6d14e6c1f7db156fbab84a5cc4dd648ba58582611d813a38c96f627b35a0404178c6653a198fbf696193fc2c88e582faf1370c16885a13d09ecc3a28770b08cb96f97b19aa0215ff0c7", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "4a8ec3b4f381da6b8ad2ad122ea71ea4064b90e280f65075e676feb9c40806c1", 1}, + {"5cec13454a9c4038ecd52ba7a94d316fcb41a0e032fd15960038d4e3d26b95c050fc75d6a090deb2220faad43dfd259bf8bfda663789dd1dce4a4a853e70d6246c51614bf4d45c82dbb5b5b12385a629efc4d14a249cb280c18b4e57e35c1366fc1b8db9ed1b1bbafc8c681195b82a25c074a729995693e17ff973fdfa0ebaedd4f624d6", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "2e87ae384abbb3b5d593385100bcee15a25766097128e4930353b88dc2a5e328", 0}, + {"1f689e1201d164ed8ca9e2622c42d1628a7fe745fd741a736493eee3e1ce9aee456c65d23b626fac9901838d2086c9a2a576f24ceff4a97c826b1fd76aed522adef31a71971fe9349567d1865f31a146c2829818c20a10b572782b59830ddd40b37817c22988d95965a50a52cecccf89a4ebe8e34922ea1feeb643786dee8fc579a663fe", "The fugacity of a constituent in a mixture of gases at a given temperature is proportional to its mole fraction. Lewis-Randall Rule", "2e87ae384abbb3b5d593385100bcee15a25766097128e4930353b88dc2a5e328", 1}, +} + +func TestPseudorandomCode(t *testing.T) { + for _, s := range aesTestStrings { + // instantiate aes block + aesKey, err := hex.DecodeString(s.aesKey) + if err != nil { + t.Fatal(err) + } + aesBlock, err := aes.NewCipher(aesKey) + if err != nil { + t.Fatal(err) + } + + // validate output of PseudorandomCode + enc := PseudorandomCode(aesBlock, []byte(s.in), s.hIdx) + if s.out != fmt.Sprintf("%x", enc) { + t.Errorf("AES block encoding did not return expected result with hash index %v for\ninput: %v\nreturned: %x\nexpected: %v", s.hIdx, s.in, string(enc), s.out) + } + } +} + +func TestEncryptionWithXorCipherWithBlake3(t *testing.T) { + for _, s := range xorCipherTestStrings { + xorKey, err := hex.DecodeString(s.xorKey) + if err != nil { + t.Fatal(err) + } + + cipherText := XorCipherWithBlake3(xorKey, s.choice, []byte(s.plainText)) + if s.cipherText != fmt.Sprintf("%x", cipherText) { + t.Fatalf("Encryption via XOR cipher with Blake 3 did not return expected result with choice bit %v for\ninput: %v\nreturned: %x\nexpected: %v", s.choice, s.plainText, string(cipherText), s.cipherText) + } + } +} + +func TestDecryptionWithXorCipherWithBlake3(t *testing.T) { + for _, s := range xorCipherTestStrings { + xorKey, err := hex.DecodeString(s.xorKey) + if err != nil { + t.Fatal(err) + } + + cipherText, err := hex.DecodeString(s.cipherText) + if err != nil { + t.Fatal(err) + } + plainText := XorCipherWithBlake3(xorKey, s.choice, cipherText) + if s.plainText != string(plainText) { + t.Fatalf("Decryption via XOR cipher with Blake 3 did not return expected result with choice bit %v for\ninput: %v\nreturned: %x\nexpected: %v", s.choice, s.cipherText, string(plainText), s.plainText) + } + } +} + +func TestEncryptDecrypt(t *testing.T) { + for _, s := range xorCipherTestStrings { + xorKey, err := hex.DecodeString(s.xorKey) + if err != nil { + t.Fatal(err) + } + + startText, err := hex.DecodeString(s.cipherText) + if err != nil { + t.Fatal(err) + } + + encryptText := XorCipherWithBlake3(xorKey, s.choice, startText) + decryptText := XorCipherWithBlake3(xorKey, s.choice, encryptText) + + if fmt.Sprintf("%x", startText) != fmt.Sprintf("%x", decryptText) { + t.Fatalf("Encryption followed by decryption via XOR cipher with Blake 3 did not return the original with choice bit %v for\noriginal: %v\nreturned: %x", s.cipherText, s.plainText, decryptText) + } + } +} + +func BenchmarkPseudorandomCode(b *testing.B) { + // the normal input is a 64 byte digest with a byte indicating + // which hash function is used to compute the cuckoo hash + in := make([]byte, 64) + aesKey := make([]byte, 16) + if _, err := rand.Read(in); err != nil { + b.Fatal(err) + } + if _, err := rand.Read(aesKey); err != nil { + b.Fatal(err) + } + + aesBlock, err := aes.NewCipher(aesKey) + if err != nil { + b.Fatal(err) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = PseudorandomCode(aesBlock, in, 0) + } +} + +func BenchmarkEncryptionWithXorCipherWithBlake3(b *testing.B) { + xorKey := make([]byte, 32) + p := make([]byte, 64) + if _, err := rand.Read(xorKey); err != nil { + b.Fatal(err) + } + if _, err := rand.Read(p); err != nil { + b.Fatal(err) + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + XorCipherWithBlake3(xorKey, 0, p) + } +} + +func BenchmarkDecryptionWithXorCipherWithBlake3(b *testing.B) { + xorKey := make([]byte, 32) + p := make([]byte, 64) + if _, err := rand.Read(xorKey); err != nil { + b.Fatal(err) + } + if _, err := rand.Read(p); err != nil { + b.Fatal(err) + } + + c := XorCipherWithBlake3(xorKey, 0, p) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + XorCipherWithBlake3(xorKey, 0, c) + } +} diff --git a/internal/crypto/point.go b/internal/crypto/point.go new file mode 100644 index 0000000..1fd76a4 --- /dev/null +++ b/internal/crypto/point.go @@ -0,0 +1,131 @@ +package crypto + +import ( + "crypto/elliptic" + "crypto/rand" + "fmt" + "io" + "math/big" + + "github.com/zeebo/blake3" +) + +/* +High level api for operating on P256 elliptic curve Points. +*/ + +var ( + curve = elliptic.P256() + encodeLen = encodeLenWithCurve(curve) +) + +// encodeLenWithCurve returns the number of bytes needed to encode a point +func encodeLenWithCurve(curve elliptic.Curve) int { + return len(elliptic.MarshalCompressed(curve, curve.Params().Gx, curve.Params().Gy)) +} + +// Point represents a point on the P256 elliptic curve +type Point struct { + x *big.Int + y *big.Int +} + +// NewPoint returns a Point +func NewPoint() *Point { + return &Point{x: new(big.Int), y: new(big.Int)} +} + +// Marshal converts a Point to a byte slice representation +func (p *Point) Marshal() []byte { + return elliptic.MarshalCompressed(curve, p.x, p.y) +} + +// Unmarshal takes in a marshaledPoint byte slice and extracts the Point object +func (p *Point) Unmarshal(marshaledPoint []byte) error { + x, y := elliptic.UnmarshalCompressed(curve, marshaledPoint) + + // on error of Unmarshal, x is nil + if x == nil { + return fmt.Errorf("error unmarshalling elliptic curve point") + } + + p.x.Set(x) + p.y.Set(y) + return nil +} + +// Add adds two points +func (p *Point) Add(q *Point) *Point { + x, y := curve.Add(p.x, p.y, q.x, q.y) + return &Point{x: x, y: y} +} + +// ScalarMult multiplies a point with a scalar +func (p *Point) ScalarMult(scalar []byte) *Point { + x, y := curve.ScalarMult(p.x, p.y, scalar) + return &Point{x: x, y: y} +} + +// Sub substracts point p from q +func (p *Point) Sub(q *Point) *Point { + // p - q = p.x + q.x, p.y - q.y + x, y := curve.Add(p.x, p.y, q.x, new(big.Int).Neg(q.y)) + return &Point{x: x, y: y} +} + +// DeriveKeyFromECPoint returns a key of 32 byte +func (p *Point) DeriveKeyFromECPoint() []byte { + key := blake3.Sum256(p.x.Bytes()) + return key[:] +} + +// GenerateKey returns a secret and public key pair +func GenerateKey() ([]byte, *Point, error) { + secret, x, y, err := elliptic.GenerateKey(curve, rand.Reader) + if err != nil { + return nil, nil, err + } + + return secret, &Point{x: x, y: y}, nil +} + +// pointWriter for elliptic curve points +type pointWriter struct { + w io.Writer +} + +// pointReader for elliptic curve points +type pointReader struct { + r io.Reader +} + +// NewECPointWriter returns an elliptic curve point writer +func NewECPointWriter(w io.Writer) *pointWriter { + return &pointWriter{w: w} +} + +// NewECPointReader returns an elliptic curve point reader +func NewECPointReader(r io.Reader) *pointReader { + return &pointReader{r: r} +} + +// Write writes the marshalled elliptic curve point to writer +func (w *pointWriter) Write(p *Point) (err error) { + _, err = w.w.Write(p.Marshal()) + return err +} + +// Read reads a marshalled elliptic curve point from reader and stores it in point +func (r *pointReader) Read(p *Point) (err error) { + pt := make([]byte, encodeLen) + if _, err = io.ReadFull(r.r, pt); err != nil { + return err + } + + return p.Unmarshal(pt) +} + +// Equal returns true when 2 points are equal +func (p *Point) equal(q *Point) bool { + return p.x.Cmp(q.x) == 0 && p.y.Cmp(q.y) == 0 +} diff --git a/internal/crypto/point_test.go b/internal/crypto/point_test.go new file mode 100644 index 0000000..d01e262 --- /dev/null +++ b/internal/crypto/point_test.go @@ -0,0 +1,182 @@ +package crypto + +import ( + "fmt" + "math/big" + "testing" +) + +type addTest struct { + xLeft, yLeft string + xRight, yRight string + xOut, yOut string +} + +var addTests = []addTest{ + { + "48439561293906451759052585252797914202762949526041747995844080717082404635286", // base point X + "36134250956749795798585127919587881956611106672985015071877198253568414405109", // base point Y + "48439561293906451759052585252797914202762949526041747995844080717082404635286", + "36134250956749795798585127919587881956611106672985015071877198253568414405109", + "56515219790691171413109057904011688695424810155802929973526481321309856242040", // 2x + "3377031843712258259223711451491452598088675519751548567112458094635497583569", // 2y + }, + { + "48439561293906451759052585252797914202762949526041747995844080717082404635286", // base point X + "36134250956749795798585127919587881956611106672985015071877198253568414405109", // base point Y + "102369864249653057322725350723741461599905180004905897298779971437827381725266", // 4x + "101744491111635190512325668403432589740384530506764148840112137220732283181254", // 4y + "36794669340896883012101473439538929759152396476648692591795318194054580155373", // 5x + "101659946828913883886577915207667153874746613498030835602133042203824767462820", // 5y + }, +} + +func TestAdd(t *testing.T) { + for i, e := range addTests { + xL, _ := new(big.Int).SetString(e.xLeft, 10) + yL, _ := new(big.Int).SetString(e.yLeft, 10) + xR, _ := new(big.Int).SetString(e.xRight, 10) + yR, _ := new(big.Int).SetString(e.yRight, 10) + pointL := &Point{x: xL, y: yL} + pointR := &Point{x: xR, y: yR} + expectedX, _ := new(big.Int).SetString(e.xOut, 10) + expectedY, _ := new(big.Int).SetString(e.yOut, 10) + sum1 := pointL.Add(pointR) + if !sum1.equal(&Point{x: expectedX, y: expectedY}) { + t.Errorf("#%d: got (%s, %s), want (%s, %s)", i, sum1.x.String(), sum1.y.String(), expectedX.String(), expectedY.String()) + } + + sum2 := pointL.Add(pointR) + if !sum2.equal(&Point{x: expectedX, y: expectedY}) { + t.Errorf("#%d: got (%s, %s), want (%s, %s)", i, sum2.x.String(), sum2.y.String(), expectedX.String(), expectedY.String()) + } + } +} + +type scalarMultTest struct { + k string + xIn, yIn string + xOut, yOut string +} + +var scalarMultTests = []scalarMultTest{ + { + "2a265f8bcbdcaf94d58519141e578124cb40d64a501fba9c11847b28965bc737", + "023819813ac969847059028ea88a1f30dfbcde03fc791d3a252c6b41211882ea", + "f93e4ae433cc12cf2a43fc0ef26400c0e125508224cdb649380f25479148a4ad", + "4d4de80f1534850d261075997e3049321a0864082d24a917863366c0724f5ae3", + "a22d2b7f7818a3563e0f7a76c9bf0921ac55e06e2e4d11795b233824b1db8cc0", + }, + { + "313f72ff9fe811bf573176231b286a3bdb6f1b14e05c40146590727a71c3bccd", + "cc11887b2d66cbae8f4d306627192522932146b42f01d3c6f92bd5c8ba739b06", + "a2f08a029cd06b46183085bae9248b0ed15b70280c7ef13a457f5af382426031", + "831c3f6b5f762d2f461901577af41354ac5f228c2591f84f8a6e51e2e3f17991", + "93f90934cd0ef2c698cc471c60a93524e87ab31ca2412252337f364513e43684", + }, +} + +func TestScalarMult(t *testing.T) { + for i, e := range scalarMultTests { + x, _ := new(big.Int).SetString(e.xIn, 16) + y, _ := new(big.Int).SetString(e.yIn, 16) + k, _ := new(big.Int).SetString(e.k, 16) + point := &Point{x: x, y: y} + expectedX, _ := new(big.Int).SetString(e.xOut, 16) + expectedY, _ := new(big.Int).SetString(e.yOut, 16) + + kPoint := point.ScalarMult(k.Bytes()) + if !kPoint.equal(&Point{x: expectedX, y: expectedY}) { + t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, kPoint.x, kPoint.y, expectedX, expectedY) + } + } +} + +func TestSub(t *testing.T) { + for i, e := range addTests { + expectedX, _ := new(big.Int).SetString(e.xLeft, 10) + expectedY, _ := new(big.Int).SetString(e.yLeft, 10) + x, _ := new(big.Int).SetString(e.xRight, 10) + y, _ := new(big.Int).SetString(e.yRight, 10) + xSum, _ := new(big.Int).SetString(e.xOut, 10) + ySum, _ := new(big.Int).SetString(e.yOut, 10) + point := &Point{x: x, y: y} + sum := &Point{x: xSum, y: ySum} + + diff := sum.Sub(point) + if !diff.equal(&Point{x: expectedX, y: expectedY}) { + t.Errorf("#%d: got (%s, %s), want (%s, %s)", i, diff.x.String(), diff.y.String(), expectedX.String(), expectedY.String()) + } + } +} + +type keyMarshalTest struct { + x, y string + key string + marshal string +} + +var keyMarshalTests = []keyMarshalTest{ + { + "48439561293906451759052585252797914202762949526041747995844080717082404635286", // base point X + "36134250956749795798585127919587881956611106672985015071877198253568414405109", // base point Y + "f7f366d0495aeb2267bc93104770f2ccc28929f610575a0ccf43b5a8fa53febb", + "036b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", + }, + { + "102369864249653057322725350723741461599905180004905897298779971437827381725266", // 4x + "101744491111635190512325668403432589740384530506764148840112137220732283181254", // 4y + "4354ebf0fc87975a838e3d8af6f5bb074e5cba896843ec4f0ddf670f343c0e8a", + "02e2534a3532d08fbba02dde659ee62bd0031fe2db785596ef509302446b030852", + }, +} + +func TestDeriveKeyPoint(t *testing.T) { + for i, e := range keyMarshalTests { + x, _ := new(big.Int).SetString(e.x, 10) + y, _ := new(big.Int).SetString(e.y, 10) + point := &Point{x: x, y: y} + key := point.DeriveKeyFromECPoint() + + if fmt.Sprintf("%x", key) != e.key { + t.Errorf("#%d: got %x, want %v", i, key, e.key) + } + } +} + +func TestMarshalUnmarshal(t *testing.T) { + for i, e := range keyMarshalTests { + x, _ := new(big.Int).SetString(e.x, 10) + y, _ := new(big.Int).SetString(e.y, 10) + point := &Point{x: x, y: y} + + marshaled := point.Marshal() + if fmt.Sprintf("%x", marshaled) != e.marshal { + t.Errorf("#%d: got %x, want %v", i, marshaled, e.marshal) + } + + unmarshalPoint := NewPoint() + unmarshalPoint.Unmarshal(marshaled) + if !point.equal(unmarshalPoint) { + t.Errorf("#%d: got (%x, %x), want (%x, %x)", i, unmarshalPoint.x, unmarshalPoint.y, point.x, point.y) + } + } +} + +func BenchmarkDeriveKey(b *testing.B) { + p := &Point{x: big.NewInt(1), y: big.NewInt(2)} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + p.DeriveKeyFromECPoint() + } +} + +func BenchmarkSub(b *testing.B) { + p := &Point{x: big.NewInt(1), y: big.NewInt(2)} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + p.Sub(p) + } +} diff --git a/internal/crypto/prg.go b/internal/crypto/prg.go new file mode 100644 index 0000000..d383839 --- /dev/null +++ b/internal/crypto/prg.go @@ -0,0 +1,29 @@ +package crypto + +import ( + "github.com/zeebo/blake3" +) + +// PseudorandomGenerate is a pseudorandom generator (PRG) +// using a deterministic random bit generator (DRBG) as +// specified by NIST - Special Publication 800-90A Revision +// 1. Blake3 is not normally used as a DRBG but we've applied +// it here for performance reasons. +func PseudorandomGenerate(dst []byte, seed []byte, h *blake3.Hasher) error { + if len(dst) < len(seed) { + copy(dst, seed) + return nil + } + + // reset internal state + h.Reset() + if _, err := h.Write(seed); err != nil { + return err + } + + drbg := h.Digest() + + _, err := drbg.Read(dst) + + return err +} diff --git a/internal/cuckoo/README.md b/internal/cuckoo/README.md new file mode 100644 index 0000000..81da5b5 --- /dev/null +++ b/internal/cuckoo/README.md @@ -0,0 +1,21 @@ +# Cuckoo hash tables + +## Description +Cuckoo hash tables [1] is an optimized hash table data structure with O(1) look up times in worst case scenario, and O(1) insertion time with amortized costs. We implement a variant of cuckoo hash tables that uses *3* hash functions to limit the probability of hashing faillure to _2σ_, where _σ_ is a security parameter that is set to _40_. + +## Benchmark +``` +go test -bench=. -benchmem ./internal/cuckoo/... +goos: darwin +goarch: amd64 +pkg: github.com/optable/match/internal/cuckoo +cpu: Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz +BenchmarkCuckooInsert-12 2627492 498.0 ns/op 0 B/op 0 allocs/op +BenchmarkCuckooExists-12 7582648 203.0 ns/op 0 B/op 0 allocs/op +PASS +ok github.com/optable/match/internal/cuckoo 7.896s +``` + +## References + +[1] Pagh, R., and Rodler, F. F. Cuckoo hashing. J. Algorithms 51, 2 (2004), 122–144. diff --git a/internal/cuckoo/cuckoo.go b/internal/cuckoo/cuckoo.go new file mode 100644 index 0000000..7b80330 --- /dev/null +++ b/internal/cuckoo/cuckoo.go @@ -0,0 +1,236 @@ +package cuckoo + +import ( + "bytes" + "fmt" + "math/rand" + + "github.com/optable/match/internal/hash" +) + +const ( + // Nhash is the number of hash function used for cuckoo hash + Nhash = 3 + // ReInsertLimit is the maximum number of reinsertions. + // Each reinsertion kicks off 1 egg (item) and replace it + // with the item being reinserted, and then reinserts the + // kicked off egg + ReInsertLimit = 200 + // Factor is the multiplicative factor of items to be + // inserted which represents the capacity overhead of + // the hash table to reduce risk of failure on insertion. + Factor = 1.4 +) + +// CuckooHasher is the building block of a Cuckoo hash table. It only holds +// the bucket size and the hashers. +type CuckooHasher struct { + // Total bucket count, len(bucket) + bucketSize uint64 + // 3 hash functions h_0, h_1, h_2 + hashers [Nhash]hash.Hasher +} + +// NewCuckooHasher instantiates a CuckooHasher struct. +func NewCuckooHasher(size uint64, seeds [Nhash][]byte) *CuckooHasher { + bSize := max(1, uint64(Factor*float64(size))) + var hashers [Nhash]hash.Hasher + var err error + for i, s := range seeds { + if hashers[i], err = hash.NewMetroHasher(s); err != nil { + panic(err) + } + } + + return &CuckooHasher{ + bucketSize: bSize, + hashers: hashers, + } +} + +// GetHasher returns the first seeded hash function from a CuckooHasher struct. +func (h *CuckooHasher) GetHasher() hash.Hasher { + return h.hashers[0] +} + +// BucketIndices returns the 3 possible bucket indices of an item +func (h *CuckooHasher) BucketIndices(item []byte) (idxs [Nhash]uint64) { + for i := range idxs { + idxs[i] = h.hashers[i].Hash64(item) % h.bucketSize + } + + return idxs +} + +// Cuckoo represents a 3-way Cuckoo hash table data structure +// that contains the items, bucket indices of each item and the 3 +// hash functions. The bucket lookup is a lookup table on items which +// tells us which item should be in the bucket at that index. Upon +// construction the items slice has an additional nil value prepended +// so the index of the Cuckoo.items slice is +1 compared to the index +// of the input slice you use. The number of inserted items is also +// tracked. +type Cuckoo struct { + items [][]byte + inserted uint64 + hashIndices []byte + bucketLookup []uint64 + *CuckooHasher +} + +// NewCuckoo instantiates a Cuckoo struct with a bucket of size Factor * size, +// and a CuckooHasher for the 3-way cuckoo hashing. +func NewCuckoo(size uint64, seeds [Nhash][]byte) *Cuckoo { + cuckooHasher := NewCuckooHasher(size, seeds) + + return &Cuckoo{ + // extra element is "keeper" to which the bucketLookup can be directed + // when there is no element present in the bucket. + make([][]byte, size+1), + 0, + make([]byte, size+1), + make([]uint64, cuckooHasher.bucketSize), + cuckooHasher, + } +} + +// GetBucket returns the index in a given bucket which represents the value in +// the list of identifiers to which it points. +func (c *Cuckoo) GetBucket(bIdx uint64) uint64 { + if bIdx > c.bucketSize { + panic(fmt.Errorf("failed to retrieve item in bucket #%v", bIdx)) + } + return c.bucketLookup[bIdx] +} + +// GetItemWithHash returns the item at a given index along with its +// hash index. Panic if the index is greater than the number of items. +func (c *Cuckoo) GetItemWithHash(idx uint64) (item []byte, hIdx uint8) { + if idx > uint64(len(c.items)-1) { + panic(fmt.Errorf("index greater than number of items")) + } + + return c.items[idx], c.hashIndices[idx] +} + +// Exists returns true if an item is inserted in cuckoo, false otherwise +func (c *Cuckoo) Exists(item []byte) (bool, byte) { + bucketIndices := c.BucketIndices(item) + + for hIdx, bIdx := range bucketIndices { + if bytes.Equal(c.items[c.bucketLookup[bIdx]], item) { + return true, byte(hIdx) + } + } + return false, 0 +} + +// Insert tries to insert a given item at the next index to the bucket +// in available slots, otherwise, it evicts a random occupied slot, +// and reinserts evicted item. +// Returns an error msg if all failed. +func (c *Cuckoo) Insert(item []byte) error { + if int(c.inserted) == len(c.items) { + return fmt.Errorf("%v of %v items have already been inserted into the cuckoo hash table. Cannot insert again", c.inserted, len(c.items)) + } + c.items[c.inserted+1] = item + bucketIndices := c.BucketIndices(item) + + // check if item has already been inserted + if found, _ := c.Exists(item); found { + return nil + } + + // add to free slots + if c.tryAdd(c.inserted+1, bucketIndices, false, 0) { + c.inserted++ + return nil + } + + // force insert by cuckoo (eviction) + homelessIdx, added := c.tryGreedyAdd(c.inserted+1, bucketIndices) + if added { + c.inserted++ + return nil + } + + return fmt.Errorf("failed to Insert item %v, results in homeless item #%v", item, homelessIdx) +} + +// tryAdd finds a free slot and inserts the item (at index, idx) +// if ignore is true, it will not insert into exceptBIdx +func (c *Cuckoo) tryAdd(idx uint64, bucketIndices [Nhash]uint64, ignore bool, exceptBIdx uint64) (added bool) { + for hIdx, bIdx := range bucketIndices { + if ignore && exceptBIdx == bIdx { + continue + } + + if c.isEmpty(bIdx) { + // this is a free slot + c.bucketLookup[bIdx] = idx + c.hashIndices[idx] = uint8(hIdx) + return true + } + } + return false +} + +// tryGreedyAdd evicts a random occupied slot, inserts the item to the evicted slot +// and reinserts the evicted item. If reinsertions fail after ReInsertLimit tries +// return false and the last evicted item. +func (c *Cuckoo) tryGreedyAdd(idx uint64, bucketIndices [Nhash]uint64) (homeLessItem uint64, added bool) { + for i := 1; i < ReInsertLimit; i++ { + // select a random slot to be evicted + evictedHIdx := rand.Int31n(Nhash) + evictedBIdx := bucketIndices[evictedHIdx] + evictedIdx := c.bucketLookup[evictedBIdx] + // insert the item in the evicted slot + c.bucketLookup[evictedBIdx] = idx + c.hashIndices[idx] = byte(evictedHIdx) + + evictedBucketIndices := c.BucketIndices(c.items[evictedIdx]) + // try to reinsert the evicted items + // ignore the evictedBIdx since we just inserted there + if c.tryAdd(evictedIdx, evictedBucketIndices, true, evictedBIdx) { + return 0, true + } + + // insertion of evicted item unsuccessful, recurse + idx = evictedIdx + bucketIndices = evictedBucketIndices + } + + return idx, false +} + +// LoadFactor returns the ratio of occupied buckets with the overall bucketSize +func (c *Cuckoo) LoadFactor() (factor float64) { + occupation := 0 + for _, v := range c.bucketLookup { + if v != 0 { + occupation += 1 + } + } + + return float64(occupation) / float64(c.bucketSize) +} + +// Len returns the total size of the cuckoo struct which is equal +// to bucketSize +func (c *Cuckoo) Len() uint64 { + return c.bucketSize +} + +// isEmpty returns true if bucket at bidx does not contain the index +// of an identifier +func (c *Cuckoo) isEmpty(bidx uint64) bool { + return c.bucketLookup[bidx] == 0 +} + +func max(a, b uint64) uint64 { + if a > b { + return a + } + + return b +} diff --git a/internal/cuckoo/cuckoo_test.go b/internal/cuckoo/cuckoo_test.go new file mode 100644 index 0000000..b632a15 --- /dev/null +++ b/internal/cuckoo/cuckoo_test.go @@ -0,0 +1,111 @@ +package cuckoo + +import ( + "bytes" + "crypto/rand" + "math" + "testing" + "time" +) + +var testN = uint64(1e6) // 1 Million + +func makeSeeds() [Nhash][]byte { + var seeds [Nhash][]byte + + for i := range seeds { + seeds[i] = make([]byte, 32) + rand.Read(seeds[i]) + } + + return seeds +} + +func TestNewCuckoo(t *testing.T) { + cuckooTests := []struct { + size uint64 + bSize uint64 //bucketSize + }{ + {uint64(0), uint64(1)}, + {uint64(math.Pow(2, 4)), uint64(Factor * math.Pow(2, 4))}, + {uint64(math.Pow(2, 8)), uint64(Factor * math.Pow(2, 8))}, + {uint64(math.Pow(2, 16)), uint64(Factor * math.Pow(2, 16))}, + } + + seeds := makeSeeds() + + for _, tt := range cuckooTests { + c := NewCuckoo(tt.size, seeds) + if c.CuckooHasher.bucketSize != tt.bSize { + t.Errorf("cuckoo bucketsize: want: %d, got: %d", tt.bSize, c.CuckooHasher.bucketSize) + } + } +} + +func TestInsertAndGetHashIdx(t *testing.T) { + cuckoo := NewCuckoo(testN, makeSeeds()) + errCount := 0 + testData := genBytes(int(testN)) + + insertTime := time.Now() + for _, item := range testData { + if err := cuckoo.Insert(item); err != nil { + errCount += 1 + } + } + + t.Logf("To be inserted: %d, bucketSize: %d, load factor: %f, failure insertion: %d, taken %v", + testN, cuckoo.bucketSize, cuckoo.LoadFactor(), errCount, time.Since(insertTime)) + + //test GetHashIdx + for i, item := range testData { + bIndices := cuckoo.BucketIndices(item) + found, hIdx := cuckoo.Exists(item) + if !found { + t.Fatalf("Cuckoo GetHashIdx, %dth item: %v not inserted.", i+1, item) + } + + checkIndex := cuckoo.GetBucket(bIndices[hIdx]) + checkItem, _ := cuckoo.GetItemWithHash(checkIndex) + if !bytes.Equal(checkItem, item) { + t.Fatalf("Cuckoo GetHashIdx, hashIdx not correct for item: %v, with hIdx: %d, item : %v", item, hIdx, checkItem) + } + } +} + +func BenchmarkCuckooInsert(b *testing.B) { + seeds := makeSeeds() + benchCuckoo := NewCuckoo(uint64(b.N), seeds) + benchData := genBytes(int(b.N)) + b.ResetTimer() + + for i := 1; i < b.N; i++ { + idx := uint64(i % int(b.N)) + if err := benchCuckoo.Insert(benchData[idx]); err != nil { + b.Fatal(err) + } + } +} + +// Benchmark finding hash index and checking existance +func BenchmarkCuckooExists(b *testing.B) { + seeds := makeSeeds() + benchCuckoo := NewCuckoo(uint64(b.N), seeds) + benchData := genBytes(int(b.N)) + b.ResetTimer() + + for i := 1; i < b.N; i++ { + idx := uint64(i % int(b.N)) + benchCuckoo.Exists(benchData[idx]) + } +} + +func genBytes(n int) [][]byte { + data := make([][]byte, n) + for i := 0; i < n; i++ { + data[i] = make([]byte, 64) + rand.Read(data[i]) + } + + return data +} diff --git a/internal/hash/hash.go b/internal/hash/hash.go index 65283db..a28b62d 100644 --- a/internal/hash/hash.go +++ b/internal/hash/hash.go @@ -5,25 +5,18 @@ import ( "fmt" "log" - "github.com/cespare/xxhash" - "github.com/dchest/siphash" - "github.com/dgryski/go-highway" - "github.com/spaolacci/murmur3" + metro "github.com/dgryski/go-metro" + "github.com/optable/match/internal/util" + "github.com/twmb/murmur3" ) -const ( - SaltLength = 32 - - SIP = iota - Murmur3 - XX - Highway -) +// SaltLength is the number of bytes which should be used as salt in +// hashing functions +const SaltLength = 32 -var ( - ErrUnknownHash = fmt.Errorf("cannot create a hasher of unknown hash type") - ErrSaltLengthMismatch = fmt.Errorf("provided salt is not %d length", SaltLength) -) +// ErrSaltLengthMismath is used when the provided salt does not match +// the expected SaltLength +var ErrSaltLengthMismatch = fmt.Errorf("provided salt is not %d length", SaltLength) func init() { if SaltLength != 32 { @@ -31,72 +24,19 @@ func init() { } } -// extractSalt a length SaltLength (32 fixed tho) slice of bytes into 4 uint64 -// -func extractKeys(salt []byte) (keys []uint64) { - for i := 0; i < 4; i++ { - var key = binary.BigEndian.Uint64(salt[i*8 : i*8+8]) - keys = append(keys, key) - } - return -} - -// Hasher implements different non cryptographic -// hashing functions +// Hasher implements different non cryptographic hashing functions type Hasher interface { Hash64([]byte) uint64 } -// New creates a hasher of type t -func New(t int, salt []byte) (Hasher, error) { - switch t { - case SIP: - return NewSIPHasher(salt) - case Murmur3: - return NewMurmur3Hasher(salt) - case XX: - return NewXXHasher(salt) - case Highway: - return NewHighwayHasher(salt) - default: - return nil, ErrUnknownHash - } -} - -// sipHash implementation of Hasher -type siphash64 struct { - key0, key1 uint64 -} - -// NewSIPHasher returns a SIP hasher -// that uses the salt as a key -func NewSIPHasher(salt []byte) (siphash64, error) { - if len(salt) != SaltLength { - return siphash64{}, ErrSaltLengthMismatch - } - // extract the keys - keys := extractKeys(salt) - // xor key0 and key1 into key0, key2 and key3 into key1 - key0 := keys[0] ^ keys[1] - key1 := keys[2] ^ keys[3] - - return siphash64{key0: key0, key1: key1}, nil -} - -func (s siphash64) Hash64(p []byte) uint64 { - // hash using key0, key1 in s - return siphash.Hash(s.key0, s.key1, p) -} - -// murmur3 implementation of Hasher +// Murmur3 implementation of Hasher type murmur64 struct { salt []byte } -// NewMurmur3Hasher returns a Murmur3 hasher -// that uses salt as a prefix to the +// NewMurmur3Hasher returns a Murmur3 hasher that uses salt as a prefix to the // bytes being summed -func NewMurmur3Hasher(salt []byte) (murmur64, error) { +func NewMurmur3Hasher(salt []byte) (Hasher, error) { if len(salt) != SaltLength { return murmur64{}, ErrSaltLengthMismatch } @@ -104,56 +44,33 @@ func NewMurmur3Hasher(salt []byte) (murmur64, error) { return murmur64{salt: salt}, nil } -func (m murmur64) Hash64(p []byte) uint64 { +func (t murmur64) Hash64(p []byte) uint64 { // prepend the salt in m and then Sum - return murmur3.Sum64(append(m.salt, p...)) + return murmur3.Sum64(append(t.salt, p...)) } -// xxHash implementation of Hasher -type xxHash struct { - salt []byte +// Metro Hash implementation of Hasher +type metro64 struct { + seed uint64 } -// NewXXHasher returns a xxHash hasher that uses salt -// as a prefix to the bytes being summed -func NewXXHasher(salt []byte) (xxHash, error) { +// NewMetroHasher returns a metro hasher that uses salt as a +// prefix to the bytes being summed +func NewMetroHasher(salt []byte) (Hasher, error) { if len(salt) != SaltLength { - return xxHash{}, ErrSaltLengthMismatch + return metro64{}, ErrSaltLengthMismatch } - return xxHash{salt: salt}, nil -} - -func (x xxHash) Hash64(p []byte) uint64 { - // prepend the salt in x and then Sum - return xxhash.Sum64(append(x.salt, p...)) -} + // condense 32 byte salt to a uint64 + seed := make([]byte, 8) + copy(seed, salt) + util.Xor(seed, salt[8:16]) + util.Xor(seed, salt[16:24]) + util.Xor(seed, salt[24:]) -// highway hash implementation of Hasher -type hw struct { - key highway.Lanes + return metro64{seed: binary.LittleEndian.Uint64(seed)}, nil } -// NewHighwayHasher returns a highwayHash hasher that uses salt -// as the 4 lanes for the hashing -func NewHighwayHasher(salt []byte) (hw, error) { - if len(salt) != SaltLength { - return hw{}, ErrSaltLengthMismatch - } - - // extract the keys - keys := extractKeys(salt) - // turn into lanes - var key highway.Lanes - key[0] = keys[0] - key[1] = keys[1] - key[2] = keys[2] - key[3] = keys[3] - - return hw{key: key}, nil -} - -func (h hw) Hash64(p []byte) uint64 { - // prepend the salt in m and then Sum - return highway.Hash(h.key, p) +func (m metro64) Hash64(p []byte) uint64 { + return metro.Hash64(p, m.seed) } diff --git a/internal/hash/hash_test.go b/internal/hash/hash_test.go index 82f29de..fc28f56 100644 --- a/internal/hash/hash_test.go +++ b/internal/hash/hash_test.go @@ -4,6 +4,9 @@ import ( "crypto/rand" "fmt" "testing" + + "github.com/alecthomas/unsafeslice" + "github.com/twmb/murmur3" ) var xxx = []byte("e:0e1f461bbefa6e07cc2ef06b9ee1ed25101e24d4345af266ed2f5a58bcd26c5e") @@ -20,90 +23,29 @@ func makeSalt() ([]byte, error) { } } -func BenchmarkSipHash(b *testing.B) { - s, _ := makeSalt() - h, _ := New(SIP, s) - for i := 0; i < b.N; i++ { - h.Hash64(xxx) - } -} - func BenchmarkMurmur3(b *testing.B) { s, _ := makeSalt() - h, _ := New(Murmur3, s) + h, _ := NewMurmur3Hasher(s) + b.ResetTimer() for i := 0; i < b.N; i++ { h.Hash64(xxx) } } -func BenchmarkXXHasher(b *testing.B) { +func BenchmarkMetro(b *testing.B) { s, _ := makeSalt() - h, _ := New(XX, s) + h, _ := NewMetroHasher(s) + b.ResetTimer() for i := 0; i < b.N; i++ { h.Hash64(xxx) } } -func BenchmarkHighwayHash(b *testing.B) { - s, _ := makeSalt() - h, _ := New(Highway, s) +func BenchmarkMurmur316Unsafe(b *testing.B) { + src := make([]byte, 66) + b.ResetTimer() for i := 0; i < b.N; i++ { - h.Hash64(xxx) - } -} - -func TestUnknownHasher(t *testing.T) { - s, _ := makeSalt() - h, err := New(666, s) - if err != ErrUnknownHash { - t.Fatalf("requested impossible hasher and got %v", h) - } -} - -func TestGetSIP(t *testing.T) { - s, _ := makeSalt() - h, err := New(SIP, s) - if err != nil { - t.Fatalf("got error %v while requesting SIP hash", err) - } - - if _, ok := h.(siphash64); !ok { - t.Fatalf("expected type siphash64 and got %T", h) - } -} - -func TestGetMurmur3(t *testing.T) { - s, _ := makeSalt() - h, err := New(Murmur3, s) - if err != nil { - t.Fatalf("got error %v while requesting murmur3 hash", err) - } - - if _, ok := h.(murmur64); !ok { - t.Fatalf("expected type murmur64 and got %T", h) - } -} - -func TestGetxxHash(t *testing.T) { - s, _ := makeSalt() - h, err := New(XX, s) - if err != nil { - t.Fatalf("got error %v while requesting xxHash hash", err) - } - - if _, ok := h.(xxHash); !ok { - t.Fatalf("expected type xxHash and got %T", h) - } -} - -func TestGetHighwayHash(t *testing.T) { - s, _ := makeSalt() - h, err := New(Highway, s) - if err != nil { - t.Fatalf("got error %v while requesting highway hash", err) - } - - if _, ok := h.(hw); !ok { - t.Fatalf("expected type hw and got %T", h) + hi, lo := murmur3.SeedSum128(0, 2, src) + unsafeslice.ByteSliceFromUint64Slice([]uint64{hi, lo}) } } diff --git a/internal/oprf/README.md b/internal/oprf/README.md new file mode 100644 index 0000000..0e2fc5f --- /dev/null +++ b/internal/oprf/README.md @@ -0,0 +1,16 @@ +# Oblivious Pseudorandom Function (OPRF) + +# Introduction +An Oblivious Pseudorandom Function (OPRF) is a two-party protocol for computing the output of a pseudorandom function (PRF). A PRF F(k, x) is an efficiently computable function taking a secret key k and an input x that produces a pseudorandom output. This function is pseudorandom if the keyed function is indistinguishable from a randomly sampled function acting on the same domain and range. In the KKRT OPRF [1], one party (the sender) holds the PRF secret key, and the other (the receiver) holds the PRF output evaluated using the secret key on his inputs. The sender can later on use the same secret key to evaluate the OPRF output on any input. The 'obliviousness' property ensures that the sender does not learn anything about the receiver's input during the evaluation. The receiver should also not learn anything about the sender's secret PRF key. This can be efficiently implemented by slightly modifying the KKRT 1 out of n OT extension protocol. + +## Implementation +We have implemented an OPRF that is inspired by [1], [2] and [3] that uses Naor-Pinkas as its underlying [baseOT](../ot/README.md). + +## References + +[1] V. Kolesnikov, R. Kumaresan, M. Rosulek, N.Trieu. Efficient Batched Oblivious PRF with Applications to Private Set Intersection. Source: https://eprint.iacr.org/2016/799.pdf, and ACM version at https://dl.acm.org/doi/pdf/10.1145/2976749.2978381. + +[2] Y. Ishai, J. Kilian, K. Nissim, E. Petrank. "Extending oblivious transfers efficiently." In Annual International Cryptology Conference (pp. 145-161). Springer, Berlin, Heidelberg, 2003. Source: https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf + +[3] G. Asharov, Y. Lindell, T. Schneider, M. Zohner. "More Efficient Oblivious Transfer Extensions". Source: https://dl.acm.org/doi/10.1007/s00145-016-9236-6 + diff --git a/internal/oprf/oprf.go b/internal/oprf/oprf.go new file mode 100644 index 0000000..0b95083 --- /dev/null +++ b/internal/oprf/oprf.go @@ -0,0 +1,229 @@ +package oprf + +/* +Improved oblivious pseudorandom function (OPRF) +based on KKRT 1 out of 2 OT extension +from the paper: "Efficient Batched Oblivious PRF with Applications to Private Set Intersection" +by Vladimir Kolesnikov, Ranjit Kumaresan, Mike Rosulek, and Ni Treu in 2016, and +the paper "More Efficient Oblivious Transfer Extensions" +by Gilad Asharov, Yehuda Lindell, Thomas Schneider, and Michael Zohner +and the paper "Extending oblivious transfers efficiently" +by Yuval Ishai, Joe Kilian, Kobbi Nissim, and Erez Petrank for ot-extension using +short secrets. + +References: +- http://dx.doi.org/10.1145/2976749.2978381 (KKRT) +- https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf (IKNP) +- https://dl.acm.org/doi/10.1007/s00145-016-9236-6 (ALSZ) + +*/ + +import ( + "crypto/aes" + "crypto/rand" + "io" + "runtime" + + "github.com/optable/match/internal/crypto" + "github.com/optable/match/internal/cuckoo" + "github.com/optable/match/internal/ot" + "github.com/optable/match/internal/util" + "github.com/zeebo/blake3" +) + +const ( + // width of base OT binary matrix as well as the output + // length of PseudorandomCode (in bits) + baseOTCount = aes.BlockSize * 4 * 8 + baseOTCountBitmapWidth = aes.BlockSize * 4 +) + +// Key contains the relaxed OPRF keys: (C, s), (j, q_j) +// oprfKeys is the received OT extension matrix oprfKeys +// chosen with choice bytes secret. +type Key struct { + secret []byte // secret choice bits + oprfKeys [][]byte // m x k bit matrice +} + +// OPRF implements the oprf struct containing the base OT +// as well as the number of message tuples. +type OPRF struct { + baseOT ot.OT // base OT under the hood + m int // number of message tuples +} + +// NewOPRF returns an OPRF where m specifies the number +// of message tuples being exchanged. +func NewOPRF(m int) *OPRF { + // send k columns of messages of length k/8 (64 bytes) + baseMsgLens := make([]int, baseOTCount) + for i := range baseMsgLens { + baseMsgLens[i] = baseOTCountBitmapWidth // 64 bytes + } + + return &OPRF{baseOT: ot.NewNaorPinkas(baseMsgLens), m: m} +} + +// Send returns the OPRF keys +func (ext *OPRF) Send(rw io.ReadWriter) (*Key, error) { + // sample choice bits for baseOT + choices := make([]byte, baseOTCountBitmapWidth) + if _, err := rand.Read(choices); err != nil { + return nil, err + } + + // act as receiver in baseOT to receive k x k seeds for the pseudorandom generator + seeds := make([][]byte, baseOTCount) + if err := ext.baseOT.Receive(choices, seeds, rw); err != nil { + return nil, err + } + + // receive masked columns oprfMask + paddedLen := util.PadBitMap(ext.m, baseOTCount) + oprfMask := make([]byte, paddedLen) + oprfKeys := make([][]byte, baseOTCount) + prg := blake3.New() + for col := range oprfKeys { + if _, err := io.ReadFull(rw, oprfMask); err != nil { + return nil, err + } + + oprfKeys[col] = make([]byte, paddedLen) + if err := crypto.PseudorandomGenerate(oprfKeys[col], seeds[col], prg); err != nil { + return nil, err + } + + // Binary AND of each byte in oprfMask with the test bit + // if bit is 1, we get whole row oprfMask to XOR with + // oprfKeys[row] if bit is 0, we get a row of 0s which when + // XORed with oprfKeys[row] just returns the same row, so + // no need to do an operation + if util.IsBitSet(choices, col) { + util.ConcurrentBitOp(util.Xor, oprfKeys[col], oprfMask) + } + } + runtime.GC() + oprfKeys = util.ConcurrentTransposeWide(oprfKeys)[:ext.m] + + // store oprf keys + return &Key{secret: choices, oprfKeys: oprfKeys}, nil +} + +// Receive returns the hashes of OPRF encodings of choice strings embedded +// in the cuckoo hash table using OPRF keys +func (ext *OPRF) Receive(choices *cuckoo.Cuckoo, secretKey []byte, rw io.ReadWriter) ([]map[uint64]uint64, error) { + if int(choices.Len()) != ext.m { + return nil, ot.ErrBaseCountMissMatch + } + + // compute code word using PseudorandomCode on choice strings in a separate thread + aesBlock, err := aes.NewCipher(secretKey) + if err != nil { + return nil, err + } + var pseudorandomChan = make(chan [][]byte) + go func() { + defer close(pseudorandomChan) + bitMapLen := util.Pad(ext.m, baseOTCount) + pseudorandomEncoding := make([][]byte, bitMapLen) + i := 0 + for ; i < ext.m; i++ { + idx := choices.GetBucket(uint64(i)) + item, hIdx := choices.GetItemWithHash(idx) + pseudorandomEncoding[i] = crypto.PseudorandomCode(aesBlock, item, hIdx) + } + // pad matrix to ensure the number of rows is divisible by baseOTCount for transposition + for ; i < len(pseudorandomEncoding); i++ { + pseudorandomEncoding[i] = make([]byte, baseOTCountBitmapWidth) + } + pseudorandomChan <- util.ConcurrentTransposeTall(pseudorandomEncoding) + }() + + // sample random OT messages + baseMsgs, err := sampleRandomOTMessages() + if err != nil { + return nil, err + } + + // act as sender in baseOT to send k columns + if err = ext.baseOT.Send(baseMsgs, rw); err != nil { + return nil, err + } + + // read pseudorandomEncodings + pseudorandomEncoding := <-pseudorandomChan + + oprfEncodings := make([][]byte, baseOTCount) + paddedLen := util.PadBitMap(ext.m, baseOTCount) + oprfMask := make([]byte, paddedLen) + // oprfMask = G(seeds[1]) + // oprfEncoding = G(seeds[0]) ^ oprfMask ^ pseudorandomEncoding + prg := blake3.New() + for col := range pseudorandomEncoding { + oprfEncodings[col] = make([]byte, paddedLen) + err = crypto.PseudorandomGenerate(oprfEncodings[col], baseMsgs[col][0], prg) + if err != nil { + return nil, err + } + + err = crypto.PseudorandomGenerate(oprfMask, baseMsgs[col][1], prg) + if err != nil { + return nil, err + } + + util.ConcurrentDoubleBitOp(util.DoubleXor, oprfMask, oprfEncodings[col], pseudorandomEncoding[col]) + + // send oprfMask + if _, err = rw.Write(oprfMask); err != nil { + return nil, err + } + } + + runtime.GC() + oprfEncodings = util.ConcurrentTransposeWide(oprfEncodings)[:ext.m] + + // Hash and index all local encodings + // the hash value of the oprfEncodings is the key + // the index of the corresponding ID in the cuckoo hash table is the value + encodings := make([]map[uint64]uint64, cuckoo.Nhash) + for i := range encodings { + encodings[i] = make(map[uint64]uint64, ext.m) + } + hasher := choices.GetHasher() + // hash local oprf output + for bIdx := uint64(0); bIdx < uint64(len(oprfEncodings)); bIdx++ { + // check if it was an empty input + if idx := choices.GetBucket(bIdx); idx != 0 { + // insert into proper map + _, hIdx := choices.GetItemWithHash(idx) + encodings[hIdx][hasher.Hash64(oprfEncodings[bIdx])] = idx + } + } + + return encodings, nil +} + +// Encode computes and returns the OPRF encoding of a byte slice using an OPRF Key +func (k Key) Encode(rowIdx uint64, pseudorandomEncoding []byte) { + util.ConcurrentDoubleBitOp(util.AndXor, pseudorandomEncoding, k.secret, k.oprfKeys[rowIdx]) +} + +// sampleRandomOTMessage allocates a slice of OTMessage, each OTMessage contains a pair of messages. +// Extra elements are added to each column to be a multiple of 512. Every slice is filled with pseudorandom bytes +// values from a rand reader. +func sampleRandomOTMessages() ([]ot.OTMessage, error) { + // instantiate matrix + matrix := make([]ot.OTMessage, baseOTCount) + for row := range matrix { + for col := range matrix[row] { + matrix[row][col] = make([]byte, baseOTCountBitmapWidth) + // fill + if _, err := rand.Read(matrix[row][col]); err != nil { + return nil, err + } + } + } + + return matrix, nil +} diff --git a/internal/oprf/oprf_test.go b/internal/oprf/oprf_test.go new file mode 100644 index 0000000..ba369fd --- /dev/null +++ b/internal/oprf/oprf_test.go @@ -0,0 +1,176 @@ +package oprf + +import ( + "bytes" + "crypto/aes" + "crypto/rand" + "fmt" + "net" + "testing" + "time" + + "github.com/optable/match/internal/crypto" + "github.com/optable/match/internal/cuckoo" + "github.com/optable/match/internal/hash" +) + +const msgCount = 1 << 16 + +func genChoiceString() [][]byte { + choices := make([][]byte, msgCount) + for i := range choices { + choices[i] = make([]byte, 66) + rand.Read(choices[i]) + } + return choices +} + +func makeCuckoo(choices [][]byte, seeds [cuckoo.Nhash][]byte) (*cuckoo.Cuckoo, error) { + c := cuckoo.NewCuckoo(uint64(msgCount), seeds) + for _, id := range choices { + if err := c.Insert(id); err != nil { + return nil, err + } + } + return c, nil +} + +func testEncodings(encodedHashMap []map[uint64]uint64, key *Key, sk []byte, seeds [cuckoo.Nhash][]byte, choicesCuckoo *cuckoo.Cuckoo, choices [][]byte) error { + senderCuckoo := cuckoo.NewCuckooHasher(uint64(msgCount), seeds) + hasher := senderCuckoo.GetHasher() + var hashes [cuckoo.Nhash]uint64 + + aesBlock, err := aes.NewCipher(sk) + if err != nil { + return err + } + for i, id := range choices { + // compute encoding and hash + for hIdx, bIdx := range senderCuckoo.BucketIndices(id) { + pseudorandId := crypto.PseudorandomCode(aesBlock, id, byte(hIdx)) + key.Encode(bIdx, pseudorandId) + hashes[hIdx] = hasher.Hash64(pseudorandId) + } + + // test hashes + var found bool + for hIdx, hashed := range hashes { + if idx, ok := encodedHashMap[hIdx][hashed]; ok { + found = true + id, _ := choicesCuckoo.GetItemWithHash(idx) + if id == nil { + return fmt.Errorf("failed to retrieve item #%v", idx) + } + + if !bytes.Equal(id, choices[i]) { + return fmt.Errorf("oprf failed, got: %v, want %v", id, choices[i]) + } + } + } + + if !found { + return fmt.Errorf("failed to find proper encoding.") + } + } + + return nil +} + +func TestOPRF(t *testing.T) { + outBus := make(chan []map[uint64]uint64, cuckoo.Nhash) + keyBus := make(chan *Key) + errs := make(chan error, 1) + sk := make([]byte, 16) + choices := genChoiceString() + + // start timer + start := time.Now() + // sample seeds + var seeds [cuckoo.Nhash][]byte + for i := range seeds { + seeds[i] = make([]byte, hash.SaltLength) + rand.Read(seeds[i]) + } + + // generate oprf Input + choicesCuckoo, err := makeCuckoo(choices, seeds) + if err != nil { + t.Fatal(err) + } + oprfInputSize := int(choicesCuckoo.Len()) + + // generate AES secret key (16-byte) + if _, err := rand.Read(sk); err != nil { + t.Fatal(err) + } + + // create client, server connections + senderConn, receiverConn := net.Pipe() + + // sender + go func() { + defer close(errs) + defer close(keyBus) + keys, err := NewOPRF(oprfInputSize).Send(senderConn) + if err != nil { + errs <- fmt.Errorf("Send encountered error: %s", err) + close(outBus) + } + + keyBus <- keys + }() + + // receiver + go func() { + defer close(outBus) + out, err := NewOPRF(oprfInputSize).Receive(choicesCuckoo, sk, receiverConn) + if err != nil { + errs <- err + } + outBus <- out + }() + + // any errors? + select { + case err := <-errs: + t.Fatal(err) + default: + } + + // Receive keys + keys := <-keyBus + + // Receive msg + encodedHashMap := <-outBus + + // stop timer + end := time.Now() + t.Logf("Time taken for %d OPRF is: %v\n", msgCount, end.Sub(start)) + + // Testing encodings + err = testEncodings(encodedHashMap, keys, sk, seeds, choicesCuckoo, choices) + if err != nil { + t.Fatal(err) + } +} + +func BenchmarkEncode(b *testing.B) { + sk := make([]byte, 16) + s := make([]byte, 64) + q := make([][]byte, 1) + q[0] = make([]byte, 64) + rand.Read(sk) + rand.Read(s) + rand.Read(q[0]) + aesBlock, err := aes.NewCipher(sk) + if err != nil { + b.Fatal(err) + } + key := Key{secret: s, oprfKeys: q} + bytes := crypto.PseudorandomCode(aesBlock, s, 0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + key.Encode(0, bytes) + } +} diff --git a/internal/ot/README.md b/internal/ot/README.md new file mode 100644 index 0000000..4619095 --- /dev/null +++ b/internal/ot/README.md @@ -0,0 +1,11 @@ +# Oblivious Transfer (OT) +## Introduction +Oblivious transfer is a cryptographic primitive crucial to building secure multiparty computation (MPC) protocols. A secure OT protocol allows for two untrusted parties, a sender and a receiver, to perform data exchange in the following way. A sender has as input two messages _M0_, _M1_, and a receiver has a selection bit _b_. After the OT protocol, the receiver will learn only the message _Mb_ and not _M1-b_, while the sender does not learn the selection bit _b_. This way the receiver does not learn the unintended message (protect against malicious receiver), and the sender cannot forge messages, since he does not know which message will be learnt by the receiver (protect against malicious sender). +After 40 years since its invention, two notable base OT protocols are the Naor-Pinkas OT[1] and the Simplest Protocol for OT[2]. +The Naor-Pinkas[1] OT protocol using `crypto/elliptic` is implemented here. + +## References + +[1] M. Naor, B. Pinkas. "Efficient oblivious transfer protocols." In SODA (Vol. 1, pp. 448-457), 2001. Paper available here: https://link.springer.com/content/pdf/10.1007/978-3-662-46800-5_26.pdf + +[2] T. Chou, O. Claudio. "The simplest protocol for oblivious transfer." In International Conference on Cryptology and Information Security in Latin America (pp. 40-58). Springer, Cham, 2015. Paper available here: https://eprint.iacr.org/2015/267.pdf diff --git a/internal/ot/naor_pinkas.go b/internal/ot/naor_pinkas.go new file mode 100644 index 0000000..41d8252 --- /dev/null +++ b/internal/ot/naor_pinkas.go @@ -0,0 +1,159 @@ +package ot + +import ( + "fmt" + "io" + + "github.com/optable/match/internal/crypto" + "github.com/optable/match/internal/util" +) + +/* +1 out of 2 base OT +from the paper: Efficient Oblivious Transfer Protocol +by Moni Naor and Benny Pinkas in 2001. +reference: https://dl.acm.org/doi/abs/10.5555/365411.365502 +*/ + +type naorPinkas struct { + // msgLen holds the length of each pair of OT message + // it serves to inform the receiver, how many bytes it is + // expected to read + msgLens []int +} + +func NewNaorPinkas(msgLens []int) OT { + return naorPinkas{msgLens: msgLens} +} + +func (n naorPinkas) Send(otMessages []OTMessage, rw io.ReadWriter) (err error) { + if len(n.msgLens) != len(otMessages) { + return ErrBaseCountMissMatch + } + + // instantiate reader, writer + reader := crypto.NewECPointReader(rw) + writer := crypto.NewECPointWriter(rw) + + // generate sender point A w/o secret, since a is never used. + _, pointA, err := crypto.GenerateKey() + if err != nil { + return fmt.Errorf("error generating keys: %w", err) + } + + // generate sender secret public key pairs used for encryption. + secretR, pointR, err := crypto.GenerateKey() + if err != nil { + return fmt.Errorf("error generating keys: %w", err) + } + + // send point A to receiver + if err := writer.Write(pointA); err != nil { + return fmt.Errorf("error writing point: %w", err) + } + + // send point R to receiver + if err := writer.Write(pointR); err != nil { + return fmt.Errorf("error writing point: %w", err) + } + + // precompute A = rA + pointA = pointA.ScalarMult(secretR) + + // encrypt plaintext messages and send them. + for i := range otMessages { + keyMaterial := crypto.NewPoint() + // read keyMaterials to derive keys + if err := reader.Read(keyMaterial); err != nil { + return fmt.Errorf("error reading point: %w", err) + } + + // compute and derive key for first OT message + var keys [2]*crypto.Point + // K0 = rK0 + keys[0] = keyMaterial.ScalarMult(secretR) + // compute and derive key for second OT message + // K1 = rA - rK0 + keys[1] = pointA.Sub(keys[0]) + + // encrypt plaintext message with keys + for choice, plaintext := range otMessages[i] { + // encryption + ciphertext := crypto.XorCipherWithBlake3(keys[choice].DeriveKeyFromECPoint(), uint8(choice), plaintext) + + // send ciphertext + if _, err = rw.Write(ciphertext); err != nil { + return fmt.Errorf("error writing bytes: %w", err) + } + } + } + + return +} + +func (n naorPinkas) Receive(choices []uint8, messages [][]byte, rw io.ReadWriter) (err error) { + if len(choices)*8 != len(messages) || len(choices)*8 != len(n.msgLens) { + return ErrBaseCountMissMatch + } + + // instantiate Reader, Writer + reader := crypto.NewECPointReader(rw) + writer := crypto.NewECPointWriter(rw) + + // receive point A from sender + pointA := crypto.NewPoint() + if err := reader.Read(pointA); err != nil { + return fmt.Errorf("error reading point: %w", err) + } + // recieve point R from sender + pointR := crypto.NewPoint() + if err := reader.Read(pointR); err != nil { + return fmt.Errorf("error reading point: %w", err) + } + + for i := range messages { + // generate receiver priv/pub key pairs going to take a long time. + secretB, pointB, err := crypto.GenerateKey() + if err != nil { + return fmt.Errorf("error generating keys: %w", err) + } + + // for each choice bit, compute the key material corresponding to + // the choice bit and sent it. + if !util.IsBitSet(choices, i) { + // K0 = Kc = B + if err := writer.Write(pointB); err != nil { + return fmt.Errorf("error writing point: %w", err) + } + } else { + // K1 = Kc = B + // K0 = K1-c = A - B + if err := writer.Write(pointA.Sub(pointB)); err != nil { + return fmt.Errorf("error writing point: %w", err) + } + } + + // receive encrypted messages, and decrypt it. + var encryptedOTMessages OTMessage + // read both msg + encryptedOTMessages[0] = make([]byte, n.msgLens[i]) + if _, err := io.ReadFull(rw, encryptedOTMessages[0]); err != nil { + return fmt.Errorf("error reading bytes: %w", err) + } + + encryptedOTMessages[1] = make([]byte, n.msgLens[i]) + if _, err := io.ReadFull(rw, encryptedOTMessages[1]); err != nil { + return fmt.Errorf("error writing point: %w", err) + } + + // build keys for decryption + // K = bR + pointK := pointR.ScalarMult(secretB) + + // decrypt the message indexed by choice bit + choiceBit := util.BitExtract(choices, i) + messages[i] = crypto.XorCipherWithBlake3(pointK.DeriveKeyFromECPoint(), choiceBit, encryptedOTMessages[choiceBit]) + } + + return +} diff --git a/internal/ot/ot.go b/internal/ot/ot.go new file mode 100644 index 0000000..023cd5b --- /dev/null +++ b/internal/ot/ot.go @@ -0,0 +1,28 @@ +package ot + +import ( + "errors" + "io" +) + +/* +OT interface +*/ + +var ( + ErrBaseCountMissMatch = errors.New("provided slices is not the same length as the number of base OT") + ErrEmptyMessage = errors.New("attempt to perform OT on empty messages") +) + +// OT implements a BaseOT +type OT interface { + Send(messages []OTMessage, rw io.ReadWriter) error + Receive(choices []uint8, messages [][]byte, rw io.ReadWriter) error +} + +// OTMessage represent a pair of messages +// where an OT receiver with choice bit 0 will +// correctly decode the first message +// and an OT receiver with choice bit 1 will +// correctly decode the second message +type OTMessage [2][]byte diff --git a/internal/ot/ot_test.go b/internal/ot/ot_test.go new file mode 100644 index 0000000..91ba9b6 --- /dev/null +++ b/internal/ot/ot_test.go @@ -0,0 +1,106 @@ +package ot + +import ( + "bytes" + "crypto/rand" + "fmt" + "net" + "testing" + "time" + + "github.com/optable/match/internal/util" +) + +const ( + baseCount = 512 + otExtensionCount = 1400 +) + +func genMsg(n, t int) []OTMessage { + data := make([]OTMessage, n) + for i := 0; i < n; i++ { + for j := range data[i] { + data[i][j] = make([]byte, otExtensionCount) + rand.Read(data[i][j]) + } + } + + return data +} + +func genChoiceBits(n int) []uint8 { + choices := make([]uint8, n) + rand.Read(choices) + return choices +} + +func TestNaorPinkas(t *testing.T) { + messages := genMsg(baseCount, 2) + msgLen := make([]int, len(messages)) + choices := genChoiceBits(baseCount / 8) + + for i, m := range messages { + msgLen[i] = len(m[0]) + } + + msgBus := make(chan []byte) + errs := make(chan error, 5) + + // start timer + start := time.Now() + + // create client, server connections + senderConn, receiverConn := net.Pipe() + + // sender + go func() { + senderOT := NewNaorPinkas(msgLen) + if err := senderOT.Send(messages, senderConn); err != nil { + errs <- fmt.Errorf("Send encountered error: %s", err) + close(msgBus) + } + }() + + // receiver + go func() { + defer close(msgBus) + receiverOT := NewNaorPinkas(msgLen) + + msg := make([][]byte, baseCount) + if err := receiverOT.Receive(choices, msg, receiverConn); err != nil { + errs <- err + } + + for _, m := range msg { + msgBus <- m + } + }() + + // Receive msg + var msg [][]byte + for m := range msgBus { + msg = append(msg, m) + } + + select { + case err := <-errs: + t.Fatal(err) + default: + } + + // stop timer + end := time.Now() + t.Logf("Time taken for NaorPinkas OT of %d OTs is: %v\n", baseCount, end.Sub(start)) + + // verify if the received msgs are correct: + if len(msg) == 0 { + t.Fatal("OT failed, did not receive any messages") + } + + for i, m := range msg { + bit := util.BitExtract(choices, i) + if !bytes.Equal(m, messages[i][bit]) { + t.Fatalf("OT failed got: %s, want %s", m, messages[i][bit]) + } + } +} diff --git a/internal/permutations/kensler.go b/internal/permutations/kensler.go index 7d2a89a..9d14485 100644 --- a/internal/permutations/kensler.go +++ b/internal/permutations/kensler.go @@ -43,7 +43,7 @@ func NewKensler(l int64) (kensler, error) { // since this only works on uint32 size. // // As long as the number of items being matched is not >4b -// its not an issue. +// it is not an issue. func (k kensler) Shuffle(n int64) int64 { var l = k.l var p = k.p diff --git a/internal/permutations/nil.go b/internal/permutations/nil.go index 952b0b9..f703a7e 100644 --- a/internal/permutations/nil.go +++ b/internal/permutations/nil.go @@ -2,10 +2,6 @@ package permutations type null int -func NewNil(n int64) (null, error) { - return 0, nil -} - // Shuffle using the nil method // just return the same value func (k null) Shuffle(n int64) int64 { diff --git a/internal/permutations/types.go b/internal/permutations/types.go index 75bd090..3044067 100644 --- a/internal/permutations/types.go +++ b/internal/permutations/types.go @@ -1,5 +1,7 @@ package permutations +// Permutations is an interface satisfied by anything with a proper +// Shuffle method type Permutations interface { Shuffle(n int64) int64 } diff --git a/internal/util/bits.go b/internal/util/bits.go new file mode 100644 index 0000000..15d478a --- /dev/null +++ b/internal/util/bits.go @@ -0,0 +1,237 @@ +package util + +import ( + "crypto/rand" + "fmt" + "runtime" + "sync" + + "github.com/alecthomas/unsafeslice" +) + +var ErrByteLengthMissMatch = fmt.Errorf("provided bytes do not have the same length for bit operations") + +// Xor casts the first part of the byte slices (length divisible +// by 8) into uint64 and then performs XOR on the slices of uint64. +// The excess elements that could not be cast are XORed conventionally. +// The whole operation is performed in place. Panic if a and dst do +// not have the same length. +// Only tested on x86-64. +func Xor(dst, a []byte) { + if len(dst) != len(a) { + panic(ErrByteLengthMissMatch) + } + + castDst := unsafeslice.Uint64SliceFromByteSlice(dst) + castA := unsafeslice.Uint64SliceFromByteSlice(a) + + for i := range castDst { + castDst[i] ^= castA[i] + } + + // deal with excess bytes which could not be cast to uint64 + // in the conventional manner + for j := 0; j < len(dst)%8; j++ { + dst[len(dst)-j-1] ^= a[len(a)-j-1] + } +} + +// And casts the first part of the byte slices (length divisible +// by 8) into uint64 and then performs AND on the slices of uint64. +// The excess elements that could not be cast are ANDed conventionally. +// The whole operation is performed in place. Panic if a and dst do +// not have the same length. +// Only tested on x86-64. +func And(dst, a []byte) { + if len(dst) != len(a) { + panic(ErrByteLengthMissMatch) + } + + castDst := unsafeslice.Uint64SliceFromByteSlice(dst) + castA := unsafeslice.Uint64SliceFromByteSlice(a) + + for i := range castDst { + castDst[i] &= castA[i] + } + + // deal with excess bytes which could not be cast to uint64 + // in the conventional manner + for j := 0; j < len(dst)%8; j++ { + dst[len(dst)-j-1] &= a[len(a)-j-1] + } +} + +// DoubleXor casts the first part of the byte slices (length divisible +// by 8) into uint64 and then performs XOR on the slices of uint64 +// (first with a and then with b). The excess elements that could not +// be cast are XORed conventionally. The whole operation is performed +// in place. Panic if a, b and dst do not have the same length. +// Only tested on x86-64. +func DoubleXor(dst, a, b []byte) { + if len(dst) != len(a) || len(dst) != len(b) { + panic(ErrByteLengthMissMatch) + } + + castDst := unsafeslice.Uint64SliceFromByteSlice(dst) + castA := unsafeslice.Uint64SliceFromByteSlice(a) + castB := unsafeslice.Uint64SliceFromByteSlice(b) + + for i := range castDst { + castDst[i] ^= castA[i] + castDst[i] ^= castB[i] + } + + // deal with excess bytes which could not be cast to uint64 + // in the conventional manner + for j := 0; j < len(dst)%8; j++ { + dst[len(dst)-j-1] ^= a[len(a)-j-1] + dst[len(dst)-j-1] ^= b[len(b)-j-1] + } +} + +// AndXor casts the first part of the byte slices (length divisible +// by 8) into uint64 and then performs AND on the slices of uint64 +// (with a) and then performs XOR (with b). The excess elements +// that could not be cast are operated on conventionally. The whole +// operation is performed in place. Panic if a, b and dst do not +// have the same length. +// Only tested on x86-64. +func AndXor(dst, a, b []byte) { + if len(dst) != len(a) || len(dst) != len(b) { + panic(ErrByteLengthMissMatch) + } + + castDst := unsafeslice.Uint64SliceFromByteSlice(dst) + castA := unsafeslice.Uint64SliceFromByteSlice(a) + castB := unsafeslice.Uint64SliceFromByteSlice(b) + + for i := range castDst { + castDst[i] &= castA[i] + castDst[i] ^= castB[i] + } + + // deal with excess bytes which could not be cast to uint64 + // in the conventional manner + for j := 0; j < len(dst)%8; j++ { + dst[len(dst)-j-1] &= a[len(a)-j-1] + dst[len(dst)-j-1] ^= b[len(b)-j-1] + } +} + +// ConcurrentBitOp performs an in-place bitwise operation, f, on each +// byte from a with dst if they are both the same length. +func ConcurrentBitOp(f func([]byte, []byte), dst, a []byte) { + nworkers := runtime.GOMAXPROCS(0) + + // no need to split into goroutines + if len(dst) < nworkers*16384 { + f(dst, a) + } else { + + // determine number of blocks to split original matrix + blockSize := len(dst) / nworkers + + // Run a worker pool + var wg sync.WaitGroup + wg.Add(nworkers) + for w := 0; w < nworkers; w++ { + w := w + go func() { + defer wg.Done() + step := blockSize * w + if w == nworkers-1 { // last block + f(dst[step:], a[step:]) + } else { + f(dst[step:step+blockSize], a[step:step+blockSize]) + } + }() + } + + wg.Wait() + } +} + +// ConcurrentDoubleBitOp performs an in-place double bitwise operation, f, +// on each byte from a with dst if they are both the same length +func ConcurrentDoubleBitOp(f func([]byte, []byte, []byte), dst, a, b []byte) { + nworkers := runtime.GOMAXPROCS(0) + + // no need to split into goroutines + if len(dst) < nworkers*16384 { + f(dst, a, b) + } else { + + // determine number of blocks to split original matrix + blockSize := len(dst) / nworkers + + // Run a worker pool + var wg sync.WaitGroup + wg.Add(nworkers) + for w := 0; w < nworkers; w++ { + w := w + go func() { + defer wg.Done() + step := blockSize * w + if w == nworkers-1 { // last block + f(dst[step:], a[step:], b[step:]) + } else { + f(dst[step:step+blockSize], a[step:step+blockSize], b[step:step+blockSize]) + } + }() + } + + wg.Wait() + } +} + +// IsBitSet returns true if bit i is set in a byte slice. +// It extracts bits from the least significant bit (i = 0) to the +// most significant bit (i = 7). +func IsBitSet(b []byte, i int) bool { + return b[i/8]&(1<<(i%8)) > 0 +} + +// BitExtract returns the ith bit in b +func BitExtract(b []byte, i int) byte { + if IsBitSet(b, i) { + return 1 + } + + return 0 +} + +// SampleRandomBitMatrix allocates a 2D byte matrix of dimension row x col, +// and adds extra rows of 0s to have the number of rows be a multiple of 512, +// fills each entry in the byte matrix with pseudorandom byte values from a rand reader. +func SampleRandomBitMatrix(row, col int) ([][]uint8, error) { + // instantiate matrix + matrix := make([][]uint8, row) + for row := range matrix { + matrix[row] = make([]uint8, (col+Pad(col, 512))/8) + } + // fill matrix + for row := range matrix { + if _, err := rand.Read(matrix[row]); err != nil { + return nil, err + } + } + + return matrix, nil +} + +// Pad returns the total padded length such that n is padded to a multiple of +// multiple. +func Pad(n, multiple int) int { + p := n % multiple + if p == 0 { + return n + } + + return n + (multiple - p) +} + +// PadBitMap returns the total padded length such that n is padded to a multiple of +// multiple bytes to fit in a bitmap. +func PadBitMap(n, multiple int) int { + return Pad(n, multiple) / 8 +} diff --git a/internal/util/bits_test.go b/internal/util/bits_test.go new file mode 100644 index 0000000..b92fc8a --- /dev/null +++ b/internal/util/bits_test.go @@ -0,0 +1,438 @@ +package util + +import ( + "math/rand" + "reflect" + "testing" + "testing/quick" +) + +const benchmarkBytes = 1000000 + +func genBytes(size int) []byte { + bytes := make([]byte, size) + if _, err := rand.Read(bytes); err != nil { + panic("error generating random bytes") + } + + return bytes +} + +type bitSets struct { + Scratch []byte + A []byte + B []byte + C []byte +} + +// Generate creates a bitSets struct with three byte slices of +// equal length +func (bitSets) Generate(r *rand.Rand, size int) reflect.Value { + var sets bitSets + sets.Scratch = make([]byte, size) + sets.A = genBytes(size) + sets.B = genBytes(size) + sets.C = genBytes(size) + return reflect.ValueOf(sets) +} + +func TestXor(t *testing.T) { + fast := func(b bitSets) []byte { + copy(b.Scratch, b.A) + Xor(b.Scratch, b.B) + return b.Scratch + } + + naive := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] ^= b.B[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(fast, naive, nil); err != nil { + t.Errorf("fast XOR != naive XOR: %v", err) + } + + commutative := func(b bitSets) bool { + copy(b.Scratch, b.A) + Xor(b.Scratch, b.B) + Xor(b.B, b.A) + // check + for i := range b.Scratch { + if b.Scratch[i] != b.B[i] { + return false + } + } + return true + } + + if err := quick.Check(commutative, nil); err != nil { + t.Errorf("A ^ B != B ^ A (commutative): %v", err) + } + + associative := func(b bitSets) bool { + copy(b.Scratch, b.B) + Xor(b.Scratch, b.C) + Xor(b.Scratch, b.A) + + // check + Xor(b.A, b.B) + Xor(b.A, b.C) + + for i := range b.Scratch { + if b.Scratch[i] != b.A[i] { + return false + } + } + return true + } + + if err := quick.Check(associative, nil); err != nil { + t.Errorf("A ^ (B ^ C) != (A ^ B) ^ C (associative): %v", err) + } + + identityElement := func(b bitSets) bool { + Xor(b.Scratch, b.A) + // check + for i := range b.Scratch { + if b.Scratch[i] != b.A[i] { + return false + } + } + return true + } + + if err := quick.Check(identityElement, nil); err != nil { + t.Errorf("A ^ 0 != A (identity): %v", err) + } + + selfInverse := func(b bitSets) bool { + Xor(b.A, b.A) + // check + for i := range b.A { + if b.A[i] != 0 { + return false + } + } + return true + } + + if err := quick.Check(selfInverse, nil); err != nil { + t.Errorf("A ^ A != 0 (self-inverse): %v", err) + } +} + +func TestAnd(t *testing.T) { + fast := func(b bitSets) []byte { + copy(b.Scratch, b.A) + And(b.Scratch, b.B) + return b.Scratch + } + + naive := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] &= b.B[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(fast, naive, nil); err != nil { + t.Errorf("fast AND != naive AND: %v", err) + } + + annulment := func(b bitSets) bool { + And(b.Scratch, b.A) + // check + for i := range b.Scratch { + if b.Scratch[i] != 0 { + return false + } + } + return true + } + + if err := quick.Check(annulment, nil); err != nil { + t.Errorf("A & 0 != 0 (annulment): %v", err) + } + + commutative := func(b bitSets) bool { + copy(b.Scratch, b.A) + And(b.Scratch, b.B) + And(b.B, b.A) + // check + for i := range b.Scratch { + if b.Scratch[i] != b.B[i] { + return false + } + } + return true + } + + if err := quick.Check(commutative, nil); err != nil { + t.Errorf("A & B != B & A (commutative): %v", err) + } + + associative := func(b bitSets) bool { + copy(b.Scratch, b.B) + And(b.Scratch, b.C) + And(b.Scratch, b.A) + + // check + And(b.A, b.B) + And(b.A, b.C) + + for i := range b.Scratch { + if b.Scratch[i] != b.A[i] { + return false + } + } + return true + } + + if err := quick.Check(associative, nil); err != nil { + t.Errorf("A & (B & C) != (A & B) & C (associative): %v", err) + } + + identityElement := func(b bitSets) bool { + for i := range b.Scratch { + b.Scratch[i] = 255 + } + And(b.Scratch, b.A) + // check + for i := range b.Scratch { + if b.Scratch[i] != b.A[i] { + return false + } + } + return true + } + + if err := quick.Check(identityElement, nil); err != nil { + t.Errorf("A & 1 != A (identity): %v", err) + } + + idempotent := func(b bitSets) bool { + copy(b.Scratch, b.A) + And(b.Scratch, b.A) + // check + for i := range b.Scratch { + if b.Scratch[i] != b.A[i] { + return false + } + } + return true + } + + if err := quick.Check(idempotent, nil); err != nil { + t.Errorf("A & A != A (idempotent): %v", err) + } +} + +func TestDoubleXor(t *testing.T) { + fast := func(b bitSets) []byte { + copy(b.Scratch, b.A) + DoubleXor(b.Scratch, b.B, b.C) + return b.Scratch + } + + naive := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] &= b.B[i] + b.Scratch[i] &= b.C[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(fast, naive, nil); err != nil { + t.Errorf("fast double XOR != naive double XOR: %v", err) + } +} + +func TestAndXor(t *testing.T) { + fast := func(b bitSets) []byte { + copy(b.Scratch, b.A) + AndXor(b.Scratch, b.B, b.C) + return b.Scratch + } + + naive := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] &= b.B[i] + b.Scratch[i] ^= b.C[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(fast, naive, nil); err != nil { + t.Errorf("fast AND followed by XOR != naive AND followed by XOR: %v", err) + } +} + +func TestConcurrentBitOp(t *testing.T) { + concXor := func(b bitSets) []byte { + copy(b.Scratch, b.A) + ConcurrentBitOp(Xor, b.Scratch, b.B) + return b.Scratch + } + + naiveXor := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] ^= b.B[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(concXor, naiveXor, nil); err != nil { + t.Errorf("concurrent fast XOR != naive XOR: %v", err) + } + + concAnd := func(b bitSets) []byte { + copy(b.Scratch, b.A) + ConcurrentBitOp(And, b.Scratch, b.B) + return b.Scratch + } + + naiveAnd := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] &= b.B[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(concAnd, naiveAnd, nil); err != nil { + t.Errorf("concurrent fast AND != naive AND: %v", err) + } +} + +func TestConcurrentDoubleBitOp(t *testing.T) { + concDoubleXor := func(b bitSets) []byte { + copy(b.Scratch, b.A) + ConcurrentDoubleBitOp(DoubleXor, b.Scratch, b.B, b.C) + return b.Scratch + } + + naiveDoubleXor := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] ^= b.B[i] + b.Scratch[i] ^= b.C[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(concDoubleXor, naiveDoubleXor, nil); err != nil { + t.Errorf("concurrent fast double XOR != naive double XOR: %v", err) + } + + concAndXor := func(b bitSets) []byte { + copy(b.Scratch, b.A) + ConcurrentDoubleBitOp(AndXor, b.Scratch, b.B, b.C) + return b.Scratch + } + + naiveAndXor := func(b bitSets) []byte { + copy(b.Scratch, b.A) + for i := range b.Scratch { + b.Scratch[i] &= b.B[i] + b.Scratch[i] ^= b.C[i] + } + return b.Scratch + } + + if err := quick.CheckEqual(concAndXor, naiveAndXor, nil); err != nil { + t.Errorf("concurrent fast AND followed by XOR != naive AND followed by XOR: %v", err) + } +} + +func TestTestBitSetInByte(t *testing.T) { + b := []byte{1} + + for i := 0; i < 8; i++ { + if i == 0 { + if !IsBitSet(b, i) { + t.Fatal("bit extraction failed") + } + } else { + if IsBitSet(b, i) { + t.Fatal("bit extraction failed") + } + } + } + + b = []byte{161} + for i := 0; i < 8; i++ { + if i == 0 || i == 7 || i == 5 { + if !IsBitSet(b, i) { + t.Fatal("bit extraction failed") + } + } else { + if IsBitSet(b, i) { + t.Fatal("bit extraction failed") + } + } + } +} + +func BenchmarkXor(b *testing.B) { + src := genBytes(benchmarkBytes) + dst := genBytes(benchmarkBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + Xor(dst, src) + } +} + +func BenchmarkAnd(b *testing.B) { + src := genBytes(benchmarkBytes) + dst := genBytes(benchmarkBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + And(dst, src) + } +} + +func BenchmarkDoubleXor(b *testing.B) { + src := genBytes(benchmarkBytes) + src2 := genBytes(benchmarkBytes) + dst := genBytes(benchmarkBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + DoubleXor(dst, src, src2) + } +} + +func BenchmarkAndXor(b *testing.B) { + src := genBytes(benchmarkBytes) + src2 := genBytes(benchmarkBytes) + dst := genBytes(benchmarkBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + AndXor(dst, src, src2) + } +} + +func BenchmarkConcurrentBitOp(b *testing.B) { + src := genBytes(benchmarkBytes) + dst := genBytes(benchmarkBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + ConcurrentBitOp(Xor, dst, src) + } +} + +func BenchmarkConcurrentDoubleBitOp(b *testing.B) { + src := genBytes(benchmarkBytes) + src2 := genBytes(benchmarkBytes) + dst := genBytes(benchmarkBytes) + b.ResetTimer() + for i := 0; i < b.N; i++ { + ConcurrentDoubleBitOp(AndXor, dst, src, src2) + } +} diff --git a/internal/util/bitvect.go b/internal/util/bitvect.go new file mode 100644 index 0000000..b915a0b --- /dev/null +++ b/internal/util/bitvect.go @@ -0,0 +1,572 @@ +package util + +import ( + "runtime" + "sync" + + "github.com/alecthomas/unsafeslice" +) + +// A BitVect is a matrix of 512 by 512 bits encoded into a contiguous slice of +// uint64 elements. +type BitVect struct { + set [512 * 8]uint64 +} + +// unravelTall populates a BitVect from a 2D matrix of bytes. The matrix +// must have 64 columns and a multiple of 512 rows. idx is the block target. +// Only tested on x86-64. +func (b *BitVect) unravelTall(matrix [][]byte, idx int) { + for i := 0; i < 512; i++ { + copy(b.set[(i)*8:(i+1)*8], unsafeslice.Uint64SliceFromByteSlice(matrix[(512*idx)+i])) + } +} + +// unravelWide populates a BitVect from a 2D matrix of bytes. The matrix +// must have a multiple of 64 columns and 512 rows. idx is the block target. +// Only tested on x86-64. +func (b *BitVect) unravelWide(matrix [][]byte, idx int) { + for i := 0; i < 512; i++ { + copy(b.set[i*8:(i+1)*8], unsafeslice.Uint64SliceFromByteSlice(matrix[i][idx*64:(64*idx)+64])) + } +} + +// ravelToTall reconstructs a subsection of a tall (mx64) matrix from a BitVect. +// Only tested on x86-64. +func (b *BitVect) ravelToTall(matrix [][]byte, idx int) { + for i := 0; i < 512; i++ { + copy(matrix[(idx*512)+i][:], unsafeslice.ByteSliceFromUint64Slice(b.set[i*8:(i+1)*8])) + } +} + +// ravelToWide reconstructs a subsection of a wide (512xn) matrix from a BitVect. +// Only tested on x86-64. +func (b *BitVect) ravelToWide(matrix [][]byte, idx int) { + for i := 0; i < 512; i++ { + copy(matrix[i][idx*64:(idx+1)*64], unsafeslice.ByteSliceFromUint64Slice(b.set[(i*8):(i+1)*8])) + } +} + +// ConcurrentTransposeTall tranposes a tall (64 column) matrix. If the input +// matrix does not have a multiple of 512 rows (tall), panic. First it +// determines how many 512x512 bit blocks are necessary to contain the matrix. +// The blocks are divided among the number of workers. If there are fewer blocks +// than workers, this function operates as though it were single-threaded. For +// those small sets, performance could be improved by limiting the number of +// workers to the number of blocks but this incurs a performance penalty and it +// is much more likely that there will be more blocks than workers/cpu cores. +// Each goroutine, iterates over the blocks for which it is responsible. For +// each block it generates a BitVect from the matrix at the appropriate index, +// performs a cache-oblivious, in-place, contiguous transpose on the BitVect, +// and finally writes the result to a shared final output matrix. The last +// worker is responsible for any excess blocks which were not evenly divisible +// into the number of workers. +func ConcurrentTransposeTall(matrix [][]byte) [][]byte { + if len(matrix)%512 != 0 { + panic("rows of input matrix not a multiple of 512") + } + + nworkers := runtime.GOMAXPROCS(0) + + // number of blocks to split original matrix + nblks := len(matrix) / 512 + + // how many blocks each worker is responsible for + workerResp := nblks / nworkers + + // build output matrix + trans := make([][]byte, 512) + for r := range trans { + trans[r] = make([]byte, len(matrix)/8) + } + + // Run a worker pool + var wg sync.WaitGroup + wg.Add(nworkers) + for w := 0; w < nworkers; w++ { + w := w + go func() { + defer wg.Done() + step := workerResp * w + var b BitVect + if w == nworkers-1 { // last worker has extra work + for i := step; i < nblks; i++ { + b.unravelTall(matrix, i) + b.transpose() + b.ravelToWide(trans, i) + } + } else { + for i := step; i < step+workerResp; i++ { + b.unravelTall(matrix, i) + b.transpose() + b.ravelToWide(trans, i) + } + } + }() + } + + wg.Wait() + + return trans +} + +// ConcurrentTransposeWide tranposes a wide (512 row) matrix. If the input +// matrix does not have a multiple of 64 columns (wide), panic. First it +// determines how many 512x512 bit blocks are necessary to contain the matrix. +// The blocks are divided among the number of workers. If there are fewer blocks +// than workers, this function operates as though it were single-threaded. For +// those small sets, performance could be improved by limiting the number of +// workers to the number of blocks but this incurs a performance penalty and it +// is much more likely that there will be more blocks than workers/cpu cores. +// Each goroutine iterates over the blocks for which it is responsible. For +// each block it generates a BitVect from the matrix at the appropriate index, +// performs a cache-oblivious, in-place, contiguous transpose on the BitVect, +// and finally writes the result to a shared final output matrix. The last +// worker is responsible for any excess blocks which were not evenly divisible +// into the number of workers. +func ConcurrentTransposeWide(matrix [][]byte) [][]byte { + if len(matrix[0])%64 != 0 { + panic("columns of input matrix not a multiple of 64") + } + + nworkers := runtime.GOMAXPROCS(0) + + // determine number of blocks to split original matrix + nblks := len(matrix[0]) / 64 + + // how many blocks each worker is responsible for + workerResp := nblks / nworkers + + // build output matrix + trans := make([][]byte, len(matrix[0])*8) + for r := range trans { + trans[r] = make([]byte, 64) + } + + // Run a worker pool + var wg sync.WaitGroup + wg.Add(nworkers) + for w := 0; w < nworkers; w++ { + w := w + go func() { + defer wg.Done() + step := workerResp * w + var b BitVect + if w == nworkers-1 { // last worker has extra work + for i := step; i < nblks; i++ { + b.unravelWide(matrix, i) + b.transpose() + b.ravelToTall(trans, i) + } + } else { + for i := step; i < step+workerResp; i++ { + b.unravelWide(matrix, i) + b.transpose() + b.ravelToTall(trans, i) + } + } + }() + } + + wg.Wait() + + return trans +} + +// transpose performs a cache-oblivious, in-place, contiguous transpose. +// Since a BitVect represents a 512 by 512 square bit matrix, transposition will +// be performed blockwise starting with blocks of 256 x 4, swapped about the +// principle diagonal. Then block size will decrease by half until it is only +// 64 x 1. The remaining transposition steps are performed using bit masks and +// shifts. Operations are performed on blocks of bits of size 32, 16, 8, 4, 2, +// and 1. Since the input is square, the transposition can be performed in place. +func (b *BitVect) transpose() { + tmp := make([]uint64, 4) + // Transpose 4 x 256 blocks + var jmp int + for i := 0; i < 256; i++ { + jmp = i * 8 + copy(tmp, b.set[jmp+4:jmp+8]) + copy(b.set[jmp+4:jmp+8], b.set[(256*8)+jmp:(256*8)+jmp+4]) + copy(b.set[(256*8)+jmp:(256*8)+jmp+4], tmp) + } + + // Transpose 2 x 128 blocks + for j := 0; j < 128; j++ { + jmp = j * 8 + copy(tmp, b.set[jmp+2:jmp+4]) + copy(b.set[jmp+2:jmp+4], b.set[(128*8)+jmp:(128*8)+jmp+2]) + copy(b.set[(128*8)+jmp:(128*8)+jmp+2], tmp[:2]) + + copy(tmp, b.set[jmp+6:jmp+8]) + copy(b.set[jmp+6:jmp+8], b.set[(128*8)+jmp+4:(128*8)+jmp+6]) + copy(b.set[(128*8)+jmp+4:(128*8)+jmp+6], tmp[:2]) + + copy(tmp, b.set[(256*8)+jmp+2:(256*8)+jmp+4]) + copy(b.set[(256*8)+jmp+2:(256*8)+jmp+4], b.set[(384*8)+jmp:(384*8)+jmp+2]) + copy(b.set[(384*8)+jmp:(384*8)+jmp+2], tmp[:2]) + + copy(tmp, b.set[(256*8)+jmp+6:(256*8)+jmp+8]) + copy(b.set[(256*8)+jmp+6:(256*8)+jmp+8], b.set[(384*8)+jmp+4:(384*8)+jmp+6]) + copy(b.set[(384*8)+jmp+4:(384*8)+jmp+6], tmp[:2]) + } + + // Transpose 1 x 64 blocks + for k := 0; k < 64; k++ { + jmp = k * 8 + copy(tmp, b.set[jmp+1:jmp+2]) + copy(b.set[jmp+1:jmp+2], b.set[(64*8)+jmp:(64*8)+jmp+1]) + copy(b.set[(64*8)+jmp:(64*8)+jmp+1], tmp[:1]) + + copy(tmp, b.set[jmp+3:jmp+4]) + copy(b.set[jmp+3:jmp+4], b.set[(64*8)+jmp+2:(64*8)+jmp+3]) + copy(b.set[(64*8)+jmp+2:(64*8)+jmp+3], tmp[:1]) + + copy(tmp, b.set[jmp+5:jmp+6]) + copy(b.set[jmp+5:jmp+6], b.set[(64*8)+jmp+4:(64*8)+jmp+5]) + copy(b.set[(64*8)+jmp+4:(64*8)+jmp+5], tmp[:1]) + + copy(tmp, b.set[jmp+7:jmp+8]) + copy(b.set[jmp+7:jmp+8], b.set[(64*8)+jmp+6:(64*8)+jmp+7]) + copy(b.set[(64*8)+jmp+6:(64*8)+jmp+7], tmp[:1]) + + copy(tmp, b.set[(128*8)+jmp+1:(128*8)+jmp+2]) + copy(b.set[(128*8)+jmp+1:(128*8)+jmp+2], b.set[(192*8)+jmp:(192*8)+jmp+1]) + copy(b.set[(192*8)+jmp:(192*8)+jmp+1], tmp[:1]) + + copy(tmp, b.set[(128*8)+jmp+3:(128*8)+jmp+4]) + copy(b.set[(128*8)+jmp+3:(128*8)+jmp+4], b.set[(192*8)+jmp+2:(192*8)+jmp+3]) + copy(b.set[(192*8)+jmp+2:(192*8)+jmp+3], tmp[:1]) + + copy(tmp, b.set[(128*8)+jmp+5:(128*8)+jmp+6]) + copy(b.set[(128*8)+jmp+5:(128*8)+jmp+6], b.set[(192*8)+jmp+4:(192*8)+jmp+5]) + copy(b.set[(192*8)+jmp+4:(192*8)+jmp+5], tmp[:1]) + + copy(tmp, b.set[(128*8)+jmp+7:(128*8)+jmp+8]) + copy(b.set[(128*8)+jmp+7:(128*8)+jmp+8], b.set[(192*8)+jmp+6:(192*8)+jmp+7]) + copy(b.set[(192*8)+jmp+6:(192*8)+jmp+7], tmp[:1]) + + copy(tmp, b.set[(256*8)+jmp+1:(256*8)+jmp+2]) + copy(b.set[(256*8)+jmp+1:(256*8)+jmp+2], b.set[(320*8)+jmp:(320*8)+jmp+1]) + copy(b.set[(320*8)+jmp:(320*8)+jmp+1], tmp[:1]) + + copy(tmp, b.set[(256*8)+jmp+3:(256*8)+jmp+4]) + copy(b.set[(256*8)+jmp+3:(256*8)+jmp+4], b.set[(320*8)+jmp+2:(320*8)+jmp+3]) + copy(b.set[(320*8)+jmp+2:(320*8)+jmp+3], tmp[:1]) + + copy(tmp, b.set[(256*8)+jmp+5:(256*8)+jmp+6]) + copy(b.set[(256*8)+jmp+5:(256*8)+jmp+6], b.set[(320*8)+jmp+4:(320*8)+jmp+5]) + copy(b.set[(320*8)+jmp+4:(320*8)+jmp+5], tmp[:1]) + + copy(tmp, b.set[(256*8)+jmp+7:(256*8)+jmp+8]) + copy(b.set[(256*8)+jmp+7:(256*8)+jmp+8], b.set[(320*8)+jmp+6:(320*8)+jmp+7]) + copy(b.set[(320*8)+jmp+6:(320*8)+jmp+7], tmp[:1]) + + copy(tmp, b.set[(384*8)+jmp+1:(384*8)+jmp+2]) + copy(b.set[(384*8)+jmp+1:(384*8)+jmp+2], b.set[(448*8)+jmp:(448*8)+jmp+1]) + copy(b.set[(448*8)+jmp:(448*8)+jmp+1], tmp[:1]) + + copy(tmp, b.set[(384*8)+jmp+3:(384*8)+jmp+4]) + copy(b.set[(384*8)+jmp+3:(384*8)+jmp+4], b.set[(448*8)+jmp+2:(448*8)+jmp+3]) + copy(b.set[(448*8)+jmp+2:(448*8)+jmp+3], tmp[:1]) + + copy(tmp, b.set[(384*8)+jmp+5:(384*8)+jmp+6]) + copy(b.set[(384*8)+jmp+5:(384*8)+jmp+6], b.set[(448*8)+jmp+4:(448*8)+jmp+5]) + copy(b.set[(448*8)+jmp+4:(448*8)+jmp+5], tmp[:1]) + + copy(tmp, b.set[(384*8)+jmp+7:(384*8)+jmp+8]) + copy(b.set[(384*8)+jmp+7:(384*8)+jmp+8], b.set[(448*8)+jmp+6:(448*8)+jmp+7]) + copy(b.set[(448*8)+jmp+6:(448*8)+jmp+7], tmp[:1]) + + } + + // Bitwise transposition + for blk := 0; blk < 8; blk++ { + for col := 0; col < 8; col++ { + transpose64(b, blk, col) + } + } +} + +// swap swaps two rows of masked binary elements in a 64x64 bit matrix which is +// held as a contiguous uint64 array in a BitVect. +func swap(a, b int, vect *BitVect, mask uint64, width int) { + t := (vect.set[a] ^ (vect.set[b] << width)) & mask + vect.set[a] = vect.set[a] ^ t + vect.set[b] = vect.set[b] ^ (t >> width) +} + +// transpose64 performs a bitwise transpose on a 64x64 bit matrix which +// is held as a contiguous uint64 array in a BitVect. Instead of looping and +// generating the mask with each loop, the unrolled version is fully declared +// which boosts performance at the expense of verbosity. +func transpose64(vect *BitVect, vblock, col int) { + jmp := vblock*(64*8) + col + // 32x32 swap + var mask uint64 = 0xFFFFFFFF00000000 + var width int = 32 + swap(jmp+(8*0), jmp+(8*32), vect, mask, width) // 0 and 32 + swap(jmp+(8*1), jmp+(8*33), vect, mask, width) // 1 and 33 + swap(jmp+(8*2), jmp+(8*34), vect, mask, width) // 2 and 34 + swap(jmp+(8*3), jmp+(8*35), vect, mask, width) // 3 and 35 + swap(jmp+(8*4), jmp+(8*36), vect, mask, width) // 4 and 36 + swap(jmp+(8*5), jmp+(8*37), vect, mask, width) // 5 and 37 + swap(jmp+(8*6), jmp+(8*38), vect, mask, width) // 6 and 38 + swap(jmp+(8*7), jmp+(8*39), vect, mask, width) // 7 and 39 + swap(jmp+(8*8), jmp+(8*40), vect, mask, width) // 8 and 40 + swap(jmp+(8*9), jmp+(8*41), vect, mask, width) // 9 and 41 + swap(jmp+(8*10), jmp+(8*42), vect, mask, width) // 10 and 42 + swap(jmp+(8*11), jmp+(8*43), vect, mask, width) // 11 and 43 + swap(jmp+(8*12), jmp+(8*44), vect, mask, width) // 12 and 44 + swap(jmp+(8*13), jmp+(8*45), vect, mask, width) // 13 and 45 + swap(jmp+(8*14), jmp+(8*46), vect, mask, width) // 14 and 46 + swap(jmp+(8*15), jmp+(8*47), vect, mask, width) // 15 and 47 + swap(jmp+(8*16), jmp+(8*48), vect, mask, width) // 16 and 48 + swap(jmp+(8*17), jmp+(8*49), vect, mask, width) // 17 and 49 + swap(jmp+(8*18), jmp+(8*50), vect, mask, width) // 18 and 50 + swap(jmp+(8*19), jmp+(8*51), vect, mask, width) // 19 and 51 + swap(jmp+(8*20), jmp+(8*52), vect, mask, width) // 20 and 52 + swap(jmp+(8*21), jmp+(8*53), vect, mask, width) // 21 and 53 + swap(jmp+(8*22), jmp+(8*54), vect, mask, width) // 22 and 54 + swap(jmp+(8*23), jmp+(8*55), vect, mask, width) // 23 and 55 + swap(jmp+(8*24), jmp+(8*56), vect, mask, width) // 24 and 56 + swap(jmp+(8*25), jmp+(8*57), vect, mask, width) // 25 and 57 + swap(jmp+(8*26), jmp+(8*58), vect, mask, width) // 26 and 58 + swap(jmp+(8*27), jmp+(8*59), vect, mask, width) // 27 and 29 + swap(jmp+(8*28), jmp+(8*60), vect, mask, width) // 28 and 60 + swap(jmp+(8*29), jmp+(8*61), vect, mask, width) // 29 and 61 + swap(jmp+(8*30), jmp+(8*62), vect, mask, width) // 30 and 62 + swap(jmp+(8*31), jmp+(8*63), vect, mask, width) // 31 and 63 + // 16x16 swap + mask = 0xFFFF0000FFFF0000 + width = 16 + swap(jmp+(8*0), jmp+(8*16), vect, mask, width) // 0 and 16 + swap(jmp+(8*1), jmp+(8*17), vect, mask, width) // 1 and 17 + swap(jmp+(8*2), jmp+(8*18), vect, mask, width) // 2 and 18 + swap(jmp+(8*3), jmp+(8*19), vect, mask, width) // 3 and 19 + swap(jmp+(8*4), jmp+(8*20), vect, mask, width) // 4 and 20 + swap(jmp+(8*5), jmp+(8*21), vect, mask, width) // 5 and 21 + swap(jmp+(8*6), jmp+(8*22), vect, mask, width) // 6 and 22 + swap(jmp+(8*7), jmp+(8*23), vect, mask, width) // 7 and 23 + swap(jmp+(8*8), jmp+(8*24), vect, mask, width) // 8 and 24 + swap(jmp+(8*9), jmp+(8*25), vect, mask, width) // 9 and 25 + swap(jmp+(8*10), jmp+(8*26), vect, mask, width) // 10 and 26 + swap(jmp+(8*11), jmp+(8*27), vect, mask, width) // 11 and 27 + swap(jmp+(8*12), jmp+(8*28), vect, mask, width) // 12 and 28 + swap(jmp+(8*13), jmp+(8*29), vect, mask, width) // 13 and 29 + swap(jmp+(8*14), jmp+(8*30), vect, mask, width) // 14 and 30 + swap(jmp+(8*15), jmp+(8*31), vect, mask, width) // 15 and 31 + + swap(jmp+(8*32), jmp+(8*48), vect, mask, width) // 32 and 48 + swap(jmp+(8*33), jmp+(8*49), vect, mask, width) // 33 and 49 + swap(jmp+(8*34), jmp+(8*50), vect, mask, width) // 34 and 50 + swap(jmp+(8*35), jmp+(8*51), vect, mask, width) // 35 and 51 + swap(jmp+(8*36), jmp+(8*52), vect, mask, width) // 36 and 52 + swap(jmp+(8*37), jmp+(8*53), vect, mask, width) // 37 and 53 + swap(jmp+(8*38), jmp+(8*54), vect, mask, width) // 38 and 54 + swap(jmp+(8*39), jmp+(8*55), vect, mask, width) // 39 and 55 + swap(jmp+(8*40), jmp+(8*56), vect, mask, width) // 40 and 56 + swap(jmp+(8*41), jmp+(8*57), vect, mask, width) // 41 and 57 + swap(jmp+(8*42), jmp+(8*58), vect, mask, width) // 42 and 58 + swap(jmp+(8*43), jmp+(8*59), vect, mask, width) // 43 and 59 + swap(jmp+(8*44), jmp+(8*60), vect, mask, width) // 44 and 60 + swap(jmp+(8*45), jmp+(8*61), vect, mask, width) // 45 and 61 + swap(jmp+(8*46), jmp+(8*62), vect, mask, width) // 46 and 62 + swap(jmp+(8*47), jmp+(8*63), vect, mask, width) // 47 and 63 + // 8x8 swap + mask = 0xFF00FF00FF00FF00 + width = 8 + swap(jmp+(8*0), jmp+(8*8), vect, mask, width) // 0 and 8 + swap(jmp+(8*1), jmp+(8*9), vect, mask, width) // 1 and 9 + swap(jmp+(8*2), jmp+(8*10), vect, mask, width) // 2 and 10 + swap(jmp+(8*3), jmp+(8*11), vect, mask, width) // 3 and 11 + swap(jmp+(8*4), jmp+(8*12), vect, mask, width) // 4 and 12 + swap(jmp+(8*5), jmp+(8*13), vect, mask, width) // 5 and 13 + swap(jmp+(8*6), jmp+(8*14), vect, mask, width) // 6 and 14 + swap(jmp+(8*7), jmp+(8*15), vect, mask, width) // 7 and 15 + + swap(jmp+(8*16), jmp+(8*24), vect, mask, width) // 16 and 24 + swap(jmp+(8*17), jmp+(8*25), vect, mask, width) // 17 and 25 + swap(jmp+(8*18), jmp+(8*26), vect, mask, width) // 18 and 26 + swap(jmp+(8*19), jmp+(8*27), vect, mask, width) // 19 and 27 + swap(jmp+(8*20), jmp+(8*28), vect, mask, width) // 20 and 28 + swap(jmp+(8*21), jmp+(8*29), vect, mask, width) // 21 and 29 + swap(jmp+(8*22), jmp+(8*30), vect, mask, width) // 22 and 30 + swap(jmp+(8*23), jmp+(8*31), vect, mask, width) // 23 and 31 + + swap(jmp+(8*32), jmp+(8*40), vect, mask, width) // 32 and 40 + swap(jmp+(8*33), jmp+(8*41), vect, mask, width) // 33 and 41 + swap(jmp+(8*34), jmp+(8*42), vect, mask, width) // 34 and 42 + swap(jmp+(8*35), jmp+(8*43), vect, mask, width) // 35 and 43 + swap(jmp+(8*36), jmp+(8*44), vect, mask, width) // 36 and 44 + swap(jmp+(8*37), jmp+(8*45), vect, mask, width) // 37 and 45 + swap(jmp+(8*38), jmp+(8*46), vect, mask, width) // 38 and 46 + swap(jmp+(8*39), jmp+(8*47), vect, mask, width) // 39 and 47 + + swap(jmp+(8*48), jmp+(8*56), vect, mask, width) // 48 and 56 + swap(jmp+(8*49), jmp+(8*57), vect, mask, width) // 49 and 57 + swap(jmp+(8*50), jmp+(8*58), vect, mask, width) // 50 and 58 + swap(jmp+(8*51), jmp+(8*59), vect, mask, width) // 51 and 59 + swap(jmp+(8*52), jmp+(8*60), vect, mask, width) // 52 and 60 + swap(jmp+(8*53), jmp+(8*61), vect, mask, width) // 53 and 61 + swap(jmp+(8*54), jmp+(8*62), vect, mask, width) // 54 and 62 + swap(jmp+(8*55), jmp+(8*63), vect, mask, width) // 55 and 63 + // 4x4 swap + mask = 0xF0F0F0F0F0F0F0F0 + width = 4 + swap(jmp+(8*0), jmp+(8*4), vect, mask, width) // 0 and 4 + swap(jmp+(8*1), jmp+(8*5), vect, mask, width) // 1 and 5 + swap(jmp+(8*2), jmp+(8*6), vect, mask, width) // 2 and 6 + swap(jmp+(8*3), jmp+(8*7), vect, mask, width) // 3 and 7 + + swap(jmp+(8*8), jmp+(8*12), vect, mask, width) // 8 and 12 + swap(jmp+(8*9), jmp+(8*13), vect, mask, width) // 9 and 13 + swap(jmp+(8*10), jmp+(8*14), vect, mask, width) // 10 and 14 + swap(jmp+(8*11), jmp+(8*15), vect, mask, width) // 11 and 15 + + swap(jmp+(8*16), jmp+(8*20), vect, mask, width) // 16 and 20 + swap(jmp+(8*17), jmp+(8*21), vect, mask, width) // 17 and 21 + swap(jmp+(8*18), jmp+(8*22), vect, mask, width) // 18 and 22 + swap(jmp+(8*19), jmp+(8*23), vect, mask, width) // 19 and 23 + + swap(jmp+(8*24), jmp+(8*28), vect, mask, width) // 24 and 28 + swap(jmp+(8*25), jmp+(8*29), vect, mask, width) // 25 and 29 + swap(jmp+(8*26), jmp+(8*30), vect, mask, width) // 26 and 30 + swap(jmp+(8*27), jmp+(8*31), vect, mask, width) // 27 and 31 + + swap(jmp+(8*32), jmp+(8*36), vect, mask, width) // 32 and 36 + swap(jmp+(8*33), jmp+(8*37), vect, mask, width) // 33 and 37 + swap(jmp+(8*34), jmp+(8*38), vect, mask, width) // 34 and 38 + swap(jmp+(8*35), jmp+(8*39), vect, mask, width) // 35 and 39 + + swap(jmp+(8*40), jmp+(8*44), vect, mask, width) // 40 and 44 + swap(jmp+(8*41), jmp+(8*45), vect, mask, width) // 41 and 45 + swap(jmp+(8*42), jmp+(8*46), vect, mask, width) // 42 and 46 + swap(jmp+(8*43), jmp+(8*47), vect, mask, width) // 43 and 47 + + swap(jmp+(8*48), jmp+(8*52), vect, mask, width) // 48 and 52 + swap(jmp+(8*49), jmp+(8*53), vect, mask, width) // 49 and 53 + swap(jmp+(8*50), jmp+(8*54), vect, mask, width) // 50 and 54 + swap(jmp+(8*51), jmp+(8*55), vect, mask, width) // 51 and 55 + + swap(jmp+(8*56), jmp+(8*60), vect, mask, width) // 56 and 60 + swap(jmp+(8*57), jmp+(8*61), vect, mask, width) // 57 and 61 + swap(jmp+(8*58), jmp+(8*62), vect, mask, width) // 58 and 62 + swap(jmp+(8*59), jmp+(8*63), vect, mask, width) // 59 and 63 + // 2x2 swap + mask = 0xcccccccccccccccc + width = 2 + swap(jmp+(8*0), jmp+(8*2), vect, mask, width) // 0 and 2 + swap(jmp+(8*1), jmp+(8*3), vect, mask, width) // 1 and 3 + + swap(jmp+(8*4), jmp+(8*6), vect, mask, width) // 4 and 6 + swap(jmp+(8*5), jmp+(8*7), vect, mask, width) // 5 and 7 + + swap(jmp+(8*8), jmp+(8*10), vect, mask, width) // 8 and 10 + swap(jmp+(8*9), jmp+(8*11), vect, mask, width) // 9 and 11 + + swap(jmp+(8*12), jmp+(8*14), vect, mask, width) // 12 and 14 + swap(jmp+(8*13), jmp+(8*15), vect, mask, width) // 13 and 15 + + swap(jmp+(8*16), jmp+(8*18), vect, mask, width) // 16 and 18 + swap(jmp+(8*17), jmp+(8*19), vect, mask, width) // 17 and 19 + + swap(jmp+(8*20), jmp+(8*22), vect, mask, width) // 20 and 22 + swap(jmp+(8*21), jmp+(8*23), vect, mask, width) // 21 and 23 + + swap(jmp+(8*24), jmp+(8*26), vect, mask, width) // 24 and 26 + swap(jmp+(8*25), jmp+(8*27), vect, mask, width) // 25 and 27 + + swap(jmp+(8*28), jmp+(8*30), vect, mask, width) // 28 and 30 + swap(jmp+(8*29), jmp+(8*31), vect, mask, width) // 29 and 31 + + swap(jmp+(8*32), jmp+(8*34), vect, mask, width) // 32 and 34 + swap(jmp+(8*33), jmp+(8*35), vect, mask, width) // 33 and 35 + + swap(jmp+(8*36), jmp+(8*38), vect, mask, width) // 36 and 38 + swap(jmp+(8*37), jmp+(8*39), vect, mask, width) // 37 and 39 + + swap(jmp+(8*40), jmp+(8*42), vect, mask, width) // 40 and 42 + swap(jmp+(8*41), jmp+(8*43), vect, mask, width) // 41 and 43 + + swap(jmp+(8*44), jmp+(8*46), vect, mask, width) // 44 and 46 + swap(jmp+(8*45), jmp+(8*47), vect, mask, width) // 45 and 47 + + swap(jmp+(8*48), jmp+(8*50), vect, mask, width) // 48 and 50 + swap(jmp+(8*49), jmp+(8*51), vect, mask, width) // 49 and 51 + + swap(jmp+(8*52), jmp+(8*54), vect, mask, width) // 52 and 54 + swap(jmp+(8*53), jmp+(8*55), vect, mask, width) // 53 and 55 + + swap(jmp+(8*56), jmp+(8*58), vect, mask, width) // 56 and 58 + swap(jmp+(8*57), jmp+(8*59), vect, mask, width) // 57 and 59 + + swap(jmp+(8*60), jmp+(8*62), vect, mask, width) // 60 and 62 + swap(jmp+(8*61), jmp+(8*63), vect, mask, width) // 61 and 63 + // 1x1 swap + mask = 0xaaaaaaaaaaaaaaaa + width = 1 + swap(jmp+(8*0), jmp+(8*1), vect, mask, width) // 0 and 1 + + swap(jmp+(8*2), jmp+(8*3), vect, mask, width) // 2 and 3 + + swap(jmp+(8*4), jmp+(8*5), vect, mask, width) // 4 and 5 + + swap(jmp+(8*6), jmp+(8*7), vect, mask, width) // 6 and 7 + + swap(jmp+(8*8), jmp+(8*9), vect, mask, width) // 8 and 9 + + swap(jmp+(8*10), jmp+(8*11), vect, mask, width) // 10 and 11 + + swap(jmp+(8*12), jmp+(8*13), vect, mask, width) // 12 and 13 + + swap(jmp+(8*14), jmp+(8*15), vect, mask, width) // 14 and 15 + + swap(jmp+(8*16), jmp+(8*17), vect, mask, width) // 16 and 17 + + swap(jmp+(8*18), jmp+(8*19), vect, mask, width) // 18 and 19 + + swap(jmp+(8*20), jmp+(8*21), vect, mask, width) // 20 and 21 + + swap(jmp+(8*22), jmp+(8*23), vect, mask, width) // 22 and 23 + + swap(jmp+(8*24), jmp+(8*25), vect, mask, width) // 24 and 25 + + swap(jmp+(8*26), jmp+(8*27), vect, mask, width) // 26 and 27 + + swap(jmp+(8*28), jmp+(8*29), vect, mask, width) // 28 and 29 + + swap(jmp+(8*30), jmp+(8*31), vect, mask, width) // 30 and 31 + + swap(jmp+(8*32), jmp+(8*33), vect, mask, width) // 32 and 33 + + swap(jmp+(8*34), jmp+(8*35), vect, mask, width) // 34 and 35 + + swap(jmp+(8*36), jmp+(8*37), vect, mask, width) // 36 and 37 + + swap(jmp+(8*38), jmp+(8*39), vect, mask, width) // 38 and 39 + + swap(jmp+(8*40), jmp+(8*41), vect, mask, width) // 40 and 41 + + swap(jmp+(8*42), jmp+(8*43), vect, mask, width) // 42 and 43 + + swap(jmp+(8*44), jmp+(8*45), vect, mask, width) // 44 and 45 + + swap(jmp+(8*46), jmp+(8*47), vect, mask, width) // 46 and 47 + + swap(jmp+(8*48), jmp+(8*49), vect, mask, width) // 48 and 49 + + swap(jmp+(8*50), jmp+(8*51), vect, mask, width) // 50 and 51 + + swap(jmp+(8*52), jmp+(8*53), vect, mask, width) // 52 and 53 + + swap(jmp+(8*54), jmp+(8*55), vect, mask, width) // 54 and 55 + + swap(jmp+(8*56), jmp+(8*57), vect, mask, width) // 56 and 57 + + swap(jmp+(8*58), jmp+(8*59), vect, mask, width) // 58 and 59 + + swap(jmp+(8*60), jmp+(8*61), vect, mask, width) // 60 and 61 + + swap(jmp+(8*62), jmp+(8*63), vect, mask, width) // 62 and 63 +} diff --git a/internal/util/bitvect_test.go b/internal/util/bitvect_test.go new file mode 100644 index 0000000..b048f99 --- /dev/null +++ b/internal/util/bitvect_test.go @@ -0,0 +1,207 @@ +package util + +import ( + "crypto/rand" + "testing" +) + +var ( + nmsg = 1 << 20 +) + +// genZebraBlock creates a 512x512 bit block where every bit position +// alternates between 0 and 1. When transposed, this block should +// consists of rows of all 0s alternating with rows of all 1s. +func genZebraBlock() BitVect { + zebraBlock2D := make([][]byte, 512) + var b BitVect + for row := range zebraBlock2D { + zebraBlock2D[row] = make([]byte, 64) + for c := 0; c < 64; c++ { + zebraBlock2D[row][c] = 0b01010101 + } + } + b.unravelTall(zebraBlock2D, 0) + return b +} + +// sampleRandomTall fills an m by 64 byte matrix (512 bits wide) with +// pseudorandom bytes. +func sampleRandomTall(m int) [][]byte { + // instantiate matrix + matrix := make([][]byte, m) + + for row := range matrix { + matrix[row] = make([]byte, 64) + rand.Read(matrix[row]) + } + + return matrix +} + +// sampleRandomWide fills a 512 by n byte matrix (512 bits tall) with +// pseudorandom bytes. +func sampleRandomWide(n int) [][]byte { + // instantiate matrix + matrix := make([][]byte, 512) + + for row := range matrix { + matrix[row] = make([]byte, n) + rand.Read(matrix[row]) + } + + return matrix +} + +func TestUnReRavelingTall(t *testing.T) { + trange := []int{512, 512 * 2, 512 * 3, 512 * 4} + var b BitVect + for _, a := range trange { + matrix := sampleRandomTall(a) + // determine number of blocks to split original matrix (m x 64) + nblks := len(matrix) / 512 + + rerav := make([][]byte, len(matrix)) + for r := range rerav { + rerav[r] = make([]byte, len(matrix[0])) + } + + for id := 0; id < nblks; id++ { + b.unravelTall(matrix, id) + b.ravelToTall(rerav, id) + } + + // check + for k := range rerav { + for l := range rerav[k] { + if rerav[k][l] != matrix[k][l] { + t.Fatal("Unraveled and reraveled tall (", a, ") matrix did not match with original at row", k, ".") + } + } + } + } +} + +func TestUnReRavelingWide(t *testing.T) { + trange := []int{64, 128, 512} + var b BitVect + for _, a := range trange { + matrix := sampleRandomWide(a) + // determine number of blocks to split original matrix (512 x n) + nblks := len(matrix[0]) / 64 + + trans := make([][]byte, len(matrix)) + for r := range trans { + trans[r] = make([]byte, len(matrix[0])) + } + + for id := 0; id < nblks; id++ { + b.unravelWide(matrix, id) + b.ravelToWide(trans, id) + } + + // check + for k := range trans { + for l := range trans[k] { + if trans[k][l] != matrix[k][l] { + t.Fatal("Unraveled and reraveled wide (", a, ") matrix did not match with original at row", k, ".") + } + } + } + } +} + +// Test single block transposition +func TestTranspose512x512(t *testing.T) { + var tr BitVect + tr.unravelTall(sampleRandomTall(nmsg), 0) + orig := BitVect{tr.set} // copy to check after + + tr.transpose() + tr.transpose() + // check if transpose is correct + if tr != orig { + t.Fatalf("Block incorrectly transposed.") + } +} + +func TestIfLittleEndianTranspose(t *testing.T) { + tr := genZebraBlock() + // 0101.... + // 0101.... + // 0101.... + tr.transpose() + // If Little Endian, we expect the resulting rows to be + // 1111.... + // 0000.... + // 1111.... + + // check if Little Endian + for i := 0; i < 512; i++ { + if i%2 == 1 { // odd + if tr.set[i*8] != 0 { + t.Fatalf("transpose appears to be Big Endian") + } + } else { + if tr.set[i*8] != 0xFFFFFFFFFFFFFFFF { + t.Fatalf("transpose appears to be Big Endian") + } + } + } +} + +func TestConcurrentTransposeTall(t *testing.T) { + trange := []int{512, 512 * 2, 512 * 3, 512 * 4} + for _, m := range trange { + orig := sampleRandomTall(m) + tr := ConcurrentTransposeTall(orig) + dtr := ConcurrentTransposeWide(tr) + // test + for k := range orig { + for l := range orig[k] { + if orig[k][l] != dtr[k][l] { + t.Fatal("Doubly-transposed tall (", m, ") matrix did not match with original at row", k, ".") + } + } + } + } +} + +func TestConcurrentTransposeWide(t *testing.T) { + trange := []int{64, 64 * 2, 64 * 3, 64 * 4} + for _, m := range trange { + orig := sampleRandomWide(m) + tr := ConcurrentTransposeWide(orig) + dtr := ConcurrentTransposeTall(tr) + //test + for k := range dtr { + for l := range dtr[k] { + if dtr[k][l] != orig[k][l] { + t.Fatal("Doubly-transposed wide (", m, ") matrix did not match with original at row", k, ".") + } + } + } + } +} + +// BenchmarkTranspose512x512 benchmarks transposing a single +// BitVect block. +func BenchmarkTranspose512x512(b *testing.B) { + var tr BitVect + tr.unravelTall(sampleRandomTall(nmsg), 0) + b.ResetTimer() + for i := 0; i < b.N; i++ { + tr.transpose() + } +} + +// BenchmarkConcurrentTranspose tests the BitVect transpose with the +// overhead of having to pull the blocks out of a larger matrix and +// write to a new transposed matrix. +func BenchmarkConcurrentTranspose(b *testing.B) { + byteBlock := sampleRandomTall(nmsg) + b.ResetTimer() + for i := 0; i < b.N; i++ { + ConcurrentTransposeTall(byteBlock) + } +} diff --git a/internal/util/framing.go b/internal/util/framing.go index c73fa89..2c6337d 100644 --- a/internal/util/framing.go +++ b/internal/util/framing.go @@ -44,24 +44,3 @@ func Exhaust(n int64, r io.Reader) <-chan []byte { return identifiers } - -// Exhaust2 scans for identifiers in r, -// It expects that each indentifier is line separated with \n -// at the end of each line. -func Exhaust2(n int64, r io.Reader) <-chan []byte { - // make the output channel - var identifiers = make(chan []byte) - // wrap r in a bufio reader - src := bufio.NewScanner(r) - go func() { - defer close(identifiers) - for src.Scan() { - identifiers <- src.Bytes() - } - if err := src.Err(); err != nil { - log.Printf("error reading identifiers: %v", err) - } - }() - - return identifiers -} diff --git a/internal/util/select.go b/internal/util/select.go index 67875f2..d5ed022 100644 --- a/internal/util/select.go +++ b/internal/util/select.go @@ -4,6 +4,7 @@ import ( "context" ) +// Sel runs a single stage for protocol func Sel(ctx context.Context, f func() error) error { var d = make(chan error) go func() { @@ -18,6 +19,7 @@ func Sel(ctx context.Context, f func() error) error { } } +// Sels runs multiple stages for a protocol func Sels(fs ...func() error) chan error { var d = make(chan error, len(fs)) for _, f := range fs { diff --git a/internal/util/select_test.go b/internal/util/select_test.go index 7548622..35c877a 100644 --- a/internal/util/select_test.go +++ b/internal/util/select_test.go @@ -21,7 +21,7 @@ func TestSel(t *testing.T) { t.Errorf("expected %v, got %v", err1, err) } - // check context cancelled + // check context canceled cancel() if err := Sel(ctx, f1); err != context.Canceled { t.Errorf("expected context.Canceled, got %v", err) diff --git a/pkg/bpsi/README.md b/pkg/bpsi/README.md index e16daf6..c95c2a3 100644 --- a/pkg/bpsi/README.md +++ b/pkg/bpsi/README.md @@ -2,7 +2,7 @@ ## protocol -The bloomfilter private set intersection (BPSI) is another naive and insecure protocol, but it is highly efficient and has lower communication cost than [NPSI](../npsi/README.md). It is based on [bloomfilter](https://en.wikipedia.org/wiki/Bloom_filter) [1], a probablistic data structure that uses _k_ independent hash functions to compactly represent a set of _n_ elements with only _m_ bits. It supports _O(1)_ set insertion and provides _O(1)_set membership queries at the cost of a small and tunable false positive rate. This means that we can know for certain an element is not in the bloomfilter, and we know an element is in the bloomfilter except with a small false positive probability. +The bloomfilter private set intersection (BPSI) is another naive and insecure protocol, but it is highly efficient and has lower communication cost than [NPSI](../npsi/README.md). It is based on [bloomfilter](https://en.wikipedia.org/wiki/Bloom_filter) [1], a probablistic data structure that uses _k_ independent hash functions to compactly represent a set of _n_ elements with only _m_ bits. It supports _O(1)_ set insertion and provides _O(1)_ set membership queries at the cost of a small and tunable false positive rate. This means that we can know for certain an element is not in the bloomfilter, and we know an element is in the bloomfilter except with a small false positive probability. In the protocol, the sender _P1_ inserts all its elements _X_ into a bloomfilter, and sends it to the receiver _P2_. To compute the intersection, _P2_ needs to simply check the set membership of each of his elements _Y_ with the received bloomfilter. @@ -20,4 +20,4 @@ BF(X): Bloomfilter bit set of inputs X # References -[1] Bloom, Burton H. "Space/time trade-offs in hash coding with allowable errors." Communications of the ACM 13.7 (1970): 422-426. \ No newline at end of file +[1] Bloom, Burton H. "Space/time trade-offs in hash coding with allowable errors." Communications of the ACM 13.7 (1970): 422-426. diff --git a/pkg/bpsi/receiver.go b/pkg/bpsi/receiver.go index 5f4857c..c10751a 100644 --- a/pkg/bpsi/receiver.go +++ b/pkg/bpsi/receiver.go @@ -31,12 +31,11 @@ func NewReceiver(rw io.ReadWriter) *Receiver { // returning the matching intersection, using the NPSI protocol. // The format of an indentifier is // string -func (r *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan []byte) ([][]byte, error) { +func (r *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan []byte) (intersection [][]byte, err error) { // fetch and set up logger logger := logr.FromContextOrDiscard(ctx) logger = logger.WithValues("protocol", "bpsi") var bf bloomfilter - var intersected [][]byte // stage 1: read the bloomfilter from the remote side stage1 := func() error { @@ -58,7 +57,7 @@ func (r *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan [] logger.V(1).Info("Starting stage 2") for identifier := range identifiers { if bf.Check(identifier) { - intersected = append(intersected, identifier) + intersection = append(intersection, identifier) } } @@ -68,14 +67,14 @@ func (r *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan [] // run stage1 if err := util.Sel(ctx, stage1); err != nil { - return intersected, err + return intersection, err } // run stage2 if err := util.Sel(ctx, stage2); err != nil { - return intersected, err + return intersection, err } - logger.V(1).Info("receiver finished", "intersected", len(intersected)) - return intersected, nil + logger.V(1).Info("receiver finished", "intersected", len(intersection)) + return intersection, nil } diff --git a/pkg/bpsi/sender.go b/pkg/bpsi/sender.go index 0481c4f..9ca8f36 100644 --- a/pkg/bpsi/sender.go +++ b/pkg/bpsi/sender.go @@ -66,6 +66,6 @@ func (s *Sender) Send(ctx context.Context, n int64, identifiers <-chan []byte) e return err } - logger.V(1).Info("sender finished.") + logger.V(1).Info("sender finished") return nil } diff --git a/pkg/dhpsi/README.md b/pkg/dhpsi/README.md index d9fd739..7d69c21 100644 --- a/pkg/dhpsi/README.md +++ b/pkg/dhpsi/README.md @@ -2,7 +2,7 @@ ## protocol -The Diffie-Hellman private set intersection (DHPSI) [1] is one of the first PSI protocol and is communication efficient, but requires expensive computations from both parties: a sender and a receiver. We implement DHPSI using elliptic curve (specifically `ristretto255` [2]) instead of finite field exponentiation for performance reasons. The point operation of _kP_ is the multiplication of a ristretto point _P_ with a scalar _k_ over an ellipic curve (Curve25519). +The Diffie-Hellman private set intersection (DHPSI) [1] is one of the first PSI protocols and is communication efficient, but requires expensive computations from both parties: a sender and a receiver. We implement DHPSI using elliptic curve (specifically `ristretto255` [2]) instead of finite field exponentiation for performance reasons. The point operation of _kP_ is the multiplication of a ristretto point _P_ with a scalar _k_ over an ellipic curve (Curve25519). 1. the receiver and the sender agree on a preset elliptic curve _E_ (Curve25519). 1. the sender generates his private key (*scalar*) _a_, and hashes each identifier from his input audience list to obtains points _xi ∈ X_ on _E_. (*Derive*) diff --git a/pkg/dhpsi/dhpsi.go b/pkg/dhpsi/dhpsi.go index a929dfd..98702b1 100644 --- a/pkg/dhpsi/dhpsi.go +++ b/pkg/dhpsi/dhpsi.go @@ -1,11 +1,9 @@ package dhpsi import ( - "crypto/rand" "encoding/binary" "fmt" "io" - "math/big" "github.com/optable/match/internal/permutations" ) @@ -26,7 +24,7 @@ var ( // // DeriveMultiplyShuffler contains the necessary -// machineries to derive identifiers into ristretto point +// machineries to derive identifiers into ristretto point, // multiply them with secret key and permute them. type DeriveMultiplyShuffler struct { w io.Writer @@ -162,8 +160,8 @@ func NewMultiplyReader(r io.Reader, gr Ristretto) (*MultiplyReader, error) { return &MultiplyReader{r: rr, gr: gr}, nil } -// Read reads a point from the underlying reader, multiply it with ristretto -// and write it into point. Returns io.EOF when +// Read reads a point from the underlying reader, multiplies it with ristretto +// and writes it into point. Returns io.EOF when // the sequence has been completely read. func (r *MultiplyReader) Read(point *[EncodedLen]byte) (err error) { var b [EncodedLen]byte @@ -192,7 +190,7 @@ func NewReader(r io.Reader) (*Reader, error) { } // Read reads a point from the underlying reader and -// write it into p. Returns io.EOF when +// writes it into p. Returns io.EOF when // the sequence has been completely read. func (r *Reader) Read(point *[EncodedLen]byte) (err error) { // ignore any read past the max size @@ -213,29 +211,3 @@ func (r *Reader) Read(point *[EncodedLen]byte) (err error) { func (r *Reader) Max() int64 { return r.max } - -// init the permutations slice matrix -func initP(n int64) []int64 { - var p = make([]int64, n) - var max = big.NewInt(n - 1) - // Chooses a uniform random int64 - choose := func() int64 { - i, err := rand.Int(rand.Reader, max) - if err != nil { - return 0 - } - return i.Int64() - } - // Initialize a trivial permutation - for i := int64(0); i < n; i++ { - p[i] = i - } - // and then shuffle it by random swaps - for i := int64(0); i < n; i++ { - if j := choose(); j != i { - p[j], p[i] = p[i], p[j] - } - } - - return p -} diff --git a/pkg/dhpsi/dhpsi_parallel.go b/pkg/dhpsi/dhpsi_parallel.go index 139c818..e66578b 100644 --- a/pkg/dhpsi/dhpsi_parallel.go +++ b/pkg/dhpsi/dhpsi_parallel.go @@ -272,8 +272,8 @@ func copyOut(b mBatch, c chan [EncodedLen]byte, done chan bool) (n int64) { return } -// Read reads a point from the underlying reader, multiply it with ristretto -// and write it into point. Returns io.EOF when +// Read reads a point from the underlying reader, multiplies it with ristretto +// and writes it into point. Returns io.EOF when // the sequence has been completely read. func (dec *MultiplyParallelReader) Read(point *[EncodedLen]byte) (err error) { // ignore any read past the max size diff --git a/pkg/dhpsi/dhpsi_parallel_test.go b/pkg/dhpsi/dhpsi_parallel_test.go index fbb3a7c..cddc0fa 100644 --- a/pkg/dhpsi/dhpsi_parallel_test.go +++ b/pkg/dhpsi/dhpsi_parallel_test.go @@ -18,8 +18,8 @@ func TestDeriveMultiplyParallelShuffler(t *testing.T) { // get an io pipe to read results rcv, snd := io.Pipe() // setup a matchables generator - common := emails.Common(DHPSITestCommonLen) - matchables := emails.Mix(common, DHPSITestBodyLen) + common := emails.Common(DHPSITestCommonLen, emails.HashLen) + matchables := emails.Mix(common, DHPSITestBodyLen, emails.HashLen) // save test matchables var sent [][]byte @@ -89,8 +89,8 @@ func TestMultiplyParallelReader(t *testing.T) { // get an io pipe to read results rcv, snd := io.Pipe() // setup a matchables generator - common := emails.Common(DHPSITestCommonLen) - matchables := emails.Mix(common, DHPSITestBodyLen) + common := emails.Common(DHPSITestCommonLen, emails.HashLen) + matchables := emails.Mix(common, DHPSITestBodyLen, emails.HashLen) // save sent encoded ristretto points var sent [][EncodedLen]byte diff --git a/pkg/dhpsi/dhpsi_test.go b/pkg/dhpsi/dhpsi_test.go index 3738497..5d39a3a 100644 --- a/pkg/dhpsi/dhpsi_test.go +++ b/pkg/dhpsi/dhpsi_test.go @@ -38,9 +38,8 @@ func compare(b1 [EncodedLen]byte, b2 []byte) bool { func sender(e ShufflerEncoder, r Ristretto, matchables <-chan []byte) ([][]byte, permutations.Permutations, error) { // save test matchables var sent [][]byte - var encoder ShufflerEncoder // setup stage 1 - encoder = e + var encoder = e // save the permutations p := encoder.Permutations() for matchable := range matchables { @@ -97,8 +96,8 @@ func TestDeriveMultiplyShuffler(t *testing.T) { // get an io pipe to read results rcv, snd := io.Pipe() // setup a matchables generator - common := emails.Common(DHPSITestCommonLen) - matchables := emails.Mix(common, DHPSITestBodyLen) + common := emails.Common(DHPSITestCommonLen, emails.HashLen) + matchables := emails.Mix(common, DHPSITestBodyLen, emails.HashLen) // save test matchables var sent [][]byte diff --git a/pkg/dhpsi/receiver.go b/pkg/dhpsi/receiver.go index 66f12ba..5758799 100644 --- a/pkg/dhpsi/receiver.go +++ b/pkg/dhpsi/receiver.go @@ -65,7 +65,7 @@ func (s *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan [] // pick a ristretto implementation gr, _ := NewRistretto(RistrettoTypeR255) - // step1 : reads the identifiers from the sender, encrypt them and index the encoded ristretto point in a map + // step1 : reads the identifiers from the sender, encrypts them and indexes the encoded ristretto point in a map stage1 := func() error { logger.V(1).Info("Starting stage 1") diff --git a/pkg/dhpsi/sender.go b/pkg/dhpsi/sender.go index b62d17c..f6cc8e4 100644 --- a/pkg/dhpsi/sender.go +++ b/pkg/dhpsi/sender.go @@ -70,7 +70,7 @@ func (s *Sender) Send(ctx context.Context, n int64, identifiers <-chan []byte) e return nil } - // stage2 : reads the identifiers from the receiver, encrypt them and send them back + // stage2 : reads the identifiers from the receiver, encrypts them and sends them back stage2 := func() error { logger.V(1).Info("Starting stage 2") diff --git a/pkg/kkrtpsi/README.md b/pkg/kkrtpsi/README.md new file mode 100644 index 0000000..d823a6f --- /dev/null +++ b/pkg/kkrtpsi/README.md @@ -0,0 +1,49 @@ +# kkrtpsi implementation + +# protocol +The KKRT PSI (Batched-OPRF PSI) [1] is one of the most efficient OT-extension ([oblivious transfer](https://en.wikipedia.org/wiki/Oblivious_transfer)) based PSI protocol that boasts more than 100 times speed up in single core performance (300s for DHPSI vs 3s for KKRT for a match between two 1 million records dataset). It is secure against semi-honest adversaries, a malicious party that adheres to the protocol honestly but wants to learn/extract the other party's private information from the data being exchanged. + +1. the sender generates [CuckooHash](https://en.wikipedia.org/wiki/Cuckoo_hashing) parameters and exchange with the receiver. +2. the receiver inserts his input set _Y_ to the Cuckoo Hash Table. +3. the receiver acts as the sender in the OPRF protocol and samples two matrices _T_ and _U_ such that the matrix _T_ is a uniformly random bit matrix, and the matrix _U_ is the Pseudorandom Code (linear correcting code) _C_ on cuckoohashed inputs. The receiver outputs the matrix _T_ as the OPRF evaluation of his inputs _Y_. +4. the sender acts as the receiver in the OPRF protocol with input secret choice bits _s_, and receives matrix _Q_, with columns of _Q_ correspond to either matrix _T_ or _U_ depending on the value of _s_, and outputs the matrix _Q_. Each row of the column _Q_ along with the secret choice bit _s_ serves as the OPRF keys to encode his own input _X_. +5. the sender uses the key _k_ to encode his own input _X_, and sends it to the receiver. +6. the receiver receives the OPRF evaluation of _X_, and compares with his own OPRF evaluation of _Y_, and outputs the intersection. + + +## data flow +``` + Sender Receiver + X Y + + + +Stage 1 Cuckoo Hash Stage 1 + ───────────────────CukooHashParam───────────────► cuckoo.Insert(Y) + + + +Stage 2.1 Stage 2.1 + oprf.Receive() ◄─────────────────────T, U─────────────────────── oprf.Send() + + + K = Q ────────────────────────────────────────────────► OPRF(K, Y) = T + + + +Stage 3 OPRF(K, X) ────────────────────OPRF(K, X)──────────────────► Stage 3 + + +K: OPRF keys +OPRF(K, Y): OPRF evaluation of input Y with key K +``` + +## References + +[1] V. Kolesnikov, R. Kumaresan, M. Rosulek, N.Trieu. "Efficient Batched Oblivious PRF with Applications to Private Set Intersection." In Proceedings of the 2016 ACM SIGSAC Conference on Computer and Communications Security (pp. 818-829),2016. Paper available here: https://dl.acm.org/doi/pdf/10.1145/2976749.2978381. + +[2] M. Naor, B. Pinkas. "Efficient oblivious transfer protocols." In SODA (Vol. 1, pp. 448-457), 2001. Paper available here: https://link.springer.com/content/pdf/10.1007/978-3-662-46800-5_26.pdf + +[3] T. Chou, O. Claudio. "The simplest protocol for oblivious transfer." In International Conference on Cryptology and Information Security in Latin America (pp. 40-58). Springer, Cham, 2015. Paper available here: https://eprint.iacr.org/2015/267.pdf + +[4] Y. Ishai and J. Kilian and K. Nissim and E. Petrank, Extending Oblivious Transfers Efficiently. https://www.iacr.org/archive/crypto2003/27290145/27290145.pdf diff --git a/pkg/kkrtpsi/kkrtpsi.go b/pkg/kkrtpsi/kkrtpsi.go new file mode 100644 index 0000000..385df6e --- /dev/null +++ b/pkg/kkrtpsi/kkrtpsi.go @@ -0,0 +1,44 @@ +package kkrtpsi + +import ( + "encoding/binary" + "io" + "math" + "runtime" + "time" + + "github.com/go-logr/logr" + "github.com/optable/match/internal/cuckoo" + "github.com/optable/match/internal/hash" + "github.com/optable/match/internal/oprf" +) + +// HashRead reads one hash +func EncodingsRead(r io.Reader, u *[cuckoo.Nhash]uint64) error { + return binary.Read(r, binary.BigEndian, u) +} + +// HashWrite writes one hash out +func EncodingsWrite(w io.Writer, u [cuckoo.Nhash]uint64) error { + return binary.Write(w, binary.BigEndian, u) +} + +func (input *inputToOprfEncode) encodeAndHash(oprfKeys *oprf.Key, hasher hash.Hasher) (hashes [cuckoo.Nhash]uint64) { + // oprfInput is instantiated at the required size + for hIdx, bucketIdx := range input.bucketIdx { + oprfKeys.Encode(bucketIdx, input.prcEncoded[hIdx]) + hashes[hIdx] = hasher.Hash64(input.prcEncoded[hIdx]) + } + + return hashes +} + +func printStageStats(log logr.Logger, stage int, prevTime, startTime time.Time, prevMem uint64) (time.Time, uint64) { + endTime := time.Now() + log.V(2).Info("stats", "stage", stage, "time", time.Since(prevTime).String(), "cumulative time", time.Since(startTime).String()) + var m runtime.MemStats + runtime.ReadMemStats(&m) // https://cs.opensource.google/go/go/+/go1.17.1:src/runtime/mstats.go;l=107 + log.V(2).Info("stats", "stage", stage, "total memory from OS (MiB)", math.Round(float64(m.Sys-prevMem)*100/(1024*1024))/100) + log.V(2).Info("stats", "stage", stage, "cumulative GC calls", m.NumGC) + return endTime, m.Sys +} diff --git a/pkg/kkrtpsi/receiver.go b/pkg/kkrtpsi/receiver.go new file mode 100644 index 0000000..fdf6219 --- /dev/null +++ b/pkg/kkrtpsi/receiver.go @@ -0,0 +1,161 @@ +package kkrtpsi + +import ( + "bufio" + "context" + "encoding/binary" + "fmt" + "io" + "time" + + "github.com/go-logr/logr" + "github.com/optable/match/internal/cuckoo" + "github.com/optable/match/internal/hash" + "github.com/optable/match/internal/oprf" + "github.com/optable/match/internal/util" +) + +// stage 1: read hash seeds for cuckoo hash, read local IDs until exhaustion +// and insert them all into a cuckoo hash table +// stage 2: OPRF Receive +// stage 3: receive sender's OPRF encodings and intersect + +// Receiver side of the KKRTPSI protocol +type Receiver struct { + rw io.ReadWriter +} + +// NewReceiver returns a KKRT receiver initialized to +// use rw as the communication layer +func NewReceiver(rw io.ReadWriter) *Receiver { + return &Receiver{rw: rw} +} + +// Intersect on matchables read from the identifiers channel, +// returning the matching intersection, using the KKRTPSI protocol. +// The format of an indentifier is string +// example: +// 0e1f461bbefa6e07cc2ef06b9ee1ed25101e24d4345af266ed2f5a58bcd26c5e +func (r *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan []byte) (intersection [][]byte, err error) { + // fetch and set up logger + logger := logr.FromContextOrDiscard(ctx) + logger = logger.WithValues("protocol", "kkrtpsi") + + // start timer: + start := time.Now() + timer := time.Now() + var mem uint64 + + var seeds [cuckoo.Nhash][]byte + var oprfOutput = make([]map[uint64]uint64, cuckoo.Nhash) + var cuckooHashTable *cuckoo.Cuckoo + var secretKey []byte + + // stage 1: read the hash seeds from the remote side + // initiate a cuckoo hash table and insert all local + // IDs into the cuckoo hash table. + stage1 := func() error { + logger.V(1).Info("Starting stage 1") + for i := range seeds { + seeds[i] = make([]byte, hash.SaltLength) + if _, err := io.ReadFull(r.rw, seeds[i]); err != nil { + return fmt.Errorf("stage1: %v", err) + } + } + + // send size + if err := binary.Write(r.rw, binary.BigEndian, &n); err != nil { + return err + } + + // instantiate cuckoo hash table + cuckooHashTable = cuckoo.NewCuckoo(uint64(n), seeds) + for id := range identifiers { + if err = cuckooHashTable.Insert(id); err != nil { + return err + } + } + + // receive secret key for AES-128 (16 byte) + secretKey = make([]byte, 16) + if _, err := io.ReadFull(r.rw, secretKey); err != nil { + return fmt.Errorf("stage1: %v", err) + } + + // end stage1 + timer, mem = printStageStats(logger, 1, start, start, 0) + logger.V(1).Info("Finished stage 1") + return nil + } + + // stage 2: prepare OPRF receive input and run Receive to get local OPRF encodings + stage2 := func() error { + logger.V(1).Info("Starting stage 2") + oprfInputSize := int(cuckooHashTable.Len()) + oprfOutput, err = oprf.NewOPRF(oprfInputSize).Receive(cuckooHashTable, secretKey, r.rw) + if err != nil { + return err + } + + // end stage2 + timer, mem = printStageStats(logger, 2, timer, start, mem) + logger.V(1).Info("Finished stage 2") + return nil + } + + // stage 3: read remote encoded identifiers and compare + // to produce intersections + stage3 := func() error { + logger.V(1).Info("Starting stage 3") + // read number of remote IDs + var remoteN int64 + if err := binary.Read(r.rw, binary.BigEndian, &remoteN); err != nil { + return err + } + + // Add a buffer of 64k to amortize syscalls cost + var bufferedReader = bufio.NewReaderSize(r.rw, 1024*64) + + // read remote encodings and intersect + for i := int64(0); i < remoteN; i++ { + // read cuckoo.Nhash possible encodings + var remoteEncoding [cuckoo.Nhash]uint64 + if err := EncodingsRead(bufferedReader, &remoteEncoding); err != nil { + return err + } + // intersect + for hashIdx, remoteHash := range remoteEncoding { + if idx, ok := oprfOutput[hashIdx][remoteHash]; ok { + id, _ := cuckooHashTable.GetItemWithHash(idx) + if id == nil { + return fmt.Errorf("failed to retrieve item #%v", idx) + } + intersection = append(intersection, id) + // dedup + delete(oprfOutput[hashIdx], remoteHash) + } + } + } + // end stage3 + _, _ = printStageStats(logger, 3, timer, start, mem) + logger.V(1).Info("Finished stage 3") + return nil + } + + // run stage1 + if err := util.Sel(ctx, stage1); err != nil { + return intersection, err + } + + // run stage2 + if err := util.Sel(ctx, stage2); err != nil { + return intersection, err + } + + // run stage3 + if err := util.Sel(ctx, stage3); err != nil { + return intersection, err + } + + return intersection, nil +} diff --git a/pkg/kkrtpsi/sender.go b/pkg/kkrtpsi/sender.go new file mode 100644 index 0000000..271e2d0 --- /dev/null +++ b/pkg/kkrtpsi/sender.go @@ -0,0 +1,284 @@ +package kkrtpsi + +import ( + "bufio" + "context" + "crypto/aes" + "crypto/rand" + "encoding/binary" + "fmt" + "io" + "runtime" + "time" + + "github.com/go-logr/logr" + "github.com/optable/match/internal/crypto" + "github.com/optable/match/internal/cuckoo" + "github.com/optable/match/internal/hash" + "github.com/optable/match/internal/oprf" + "github.com/optable/match/internal/util" + "golang.org/x/sync/errgroup" +) + +// stage 1: samples cuckoo.Nhash hash seeds and sends them to receiver for cuckoo hash +// stage 2: act as sender in OPRF, and receive OPRF keys +// stage 3: compute OPRF(k, id) and send them to receiver for intersection. + +// Sender side of the KKRTPSI protocol +type Sender struct { + rw io.ReadWriter +} + +// inputToOprfEncode stores the possible bucket +// indexes in the receiver cuckoo hash table +type inputToOprfEncode struct { + prcEncoded [cuckoo.Nhash][]byte // PseudoRandom Code + bucketIdx [cuckoo.Nhash]uint64 +} + +// stage1Result is used to pass the OPRF encoded +// inputs along with the hasher from stage 1 to stage 3 +type stage1Result struct { + inputs []inputToOprfEncode + hasher hash.Hasher +} + +// NewSender returns a KKRTPSI sender initialized to +// use rw as the communication layer +func NewSender(rw io.ReadWriter) *Sender { + return &Sender{rw: rw} +} + +// Send initiates a KKRTPSI exchange +// that reads local IDs from identifiers, until identifiers closes. +// The format of an indentifier is string +// example: +// 0e1f461bbefa6e07cc2ef06b9ee1ed25101e24d4345af266ed2f5a58bcd26c5e +func (s *Sender) Send(ctx context.Context, n int64, identifiers <-chan []byte) (err error) { + // fetch and set up logger + logger := logr.FromContextOrDiscard(ctx) + logger = logger.WithValues("protocol", "kkrtpsi") + + // statistics + start := time.Now() + timer := time.Now() + var mem uint64 + + var seeds [cuckoo.Nhash][]byte + var remoteN int64 // receiver size + var oprfInputSize int // nb of OPRF keys + + var oprfKey *oprf.Key + var encodedInputChan = make(chan stage1Result) + + // stage 1: sample hash seeds and write them to receiver + // for cuckoo hashing parameters agreement. + // read local ids and store the potential bucket indexes for each id. + stage1 := func() error { + logger.V(1).Info("Starting stage 1") + + // sample cuckoo.Nhash hash seeds + for i := range seeds { + seeds[i] = make([]byte, hash.SaltLength) + if _, err := rand.Read(seeds[i]); err != nil { + return err + } + // write it into rw + if _, err := s.rw.Write(seeds[i]); err != nil { + return err + } + } + + // read remote input size + if err := binary.Read(s.rw, binary.BigEndian, &remoteN); err != nil { + return err + } + + // sample random 16 byte secret key for AES-128 and send to the receiver + secretKey := make([]byte, aes.BlockSize) + if _, err = rand.Read(secretKey); err != nil { + return err + } + + // send the secret key + if _, err := s.rw.Write(secretKey); err != nil { + return err + } + + // calculate number of OPRF from the receiver based on + // number of buckets in cuckooHashTable + oprfInputSize = int(cuckoo.Factor * float64(remoteN)) + if 1 > oprfInputSize { + oprfInputSize = 1 + } + + // instantiate an AES block + aesBlock, err := aes.NewCipher(secretKey) + if err != nil { + return err + } + + // exhaust local ids, and precompute all potential + // hashes and store them using the same + // cuckoo hash table parameters as the receiver. + go func() { + cuckooHasher := cuckoo.NewCuckooHasher(uint64(remoteN), seeds) + + // prepare struct to send inputs and hasher to stage 3 + var result stage1Result + result.inputs = make([]inputToOprfEncode, n) + + var i int + for id := range identifiers { + // hash and calculate pseudorandom code given each possible hash index + var bytes [cuckoo.Nhash][]byte + for hIdx := 0; hIdx < cuckoo.Nhash; hIdx++ { + bytes[hIdx] = crypto.PseudorandomCode(aesBlock, id, byte(hIdx)) + } + result.inputs[i] = inputToOprfEncode{prcEncoded: bytes, bucketIdx: cuckooHasher.BucketIndices(id)} + i++ + } + + result.hasher = cuckooHasher.GetHasher() + encodedInputChan <- result + }() + + // end stage1 + timer, mem = printStageStats(logger, 1, start, start, 0) + logger.V(1).Info("Finished stage 1") + return nil + } + + // stage 2: act as sender in OPRF, and receive OPRF keys + stage2 := func() error { + logger.V(1).Info("Starting stage 2") + + // instantiate OPRF sender with agreed parameters + oprfKey, err = oprf.NewOPRF(oprfInputSize).Send(s.rw) + if err != nil { + return err + } + + // end stage2 + timer, mem = printStageStats(logger, 2, timer, start, mem) + logger.V(1).Info("Finished stage 2") + return nil + } + + // stage 3: compute all possible OPRF output using keys obtained from stage2 + stage3 := func() error { + logger.V(1).Info("Starting stage 3") + + // inform the receiver the number of local ID + if err := binary.Write(s.rw, binary.BigEndian, &n); err != nil { + return err + } + + message := <-encodedInputChan + nWorkers := runtime.GOMAXPROCS(0) + var localEncodings = make(chan [][cuckoo.Nhash]uint64, nWorkers*2) + + batchSize := 2048 + nBatches := len(message.inputs) / batchSize + workerResp := nBatches / nWorkers + + g, ctx := errgroup.WithContext(ctx) + + // each worker is responsible to encode and hash workerResp batches and send them out + for w := 0; w < nWorkers; w++ { + w := w + g.Go(func() error { + for batchNumber := 0; batchNumber < workerResp; batchNumber++ { + batch := make([][cuckoo.Nhash]uint64, batchSize) + step := (w*workerResp + batchNumber) * batchSize + for bIdx := 0; bIdx < batchSize; bIdx++ { + batch[bIdx] = message.inputs[step+bIdx].encodeAndHash(oprfKey, message.hasher) + } + + select { + case <-ctx.Done(): + return ctx.Err() + // batch is filled; send it out + case localEncodings <- batch: + } + } + return nil + }) + } + + // Extra worker deals with the remaining inputs + // In the case that nBatches < nWorkers, this worker will be responsible + // for the entire set of inputs. In addition, since it aggregates its + // inputs into a single large batch, this process is not pipelined. + g.Go(func() error { + workLeft := len(message.inputs) - (workerResp * nWorkers * batchSize) + if workLeft == 0 { + return nil + } + + lastBatch := make([][cuckoo.Nhash]uint64, workLeft) + for bIdx := 0; bIdx < workLeft; bIdx++ { + lastBatch[bIdx] = message.inputs[len(message.inputs)-workLeft+bIdx].encodeAndHash(oprfKey, message.hasher) + } + + select { + case <-ctx.Done(): + return ctx.Err() + // final batch filled; send it out + case localEncodings <- lastBatch: + return nil + } + }) + + g.Go(func() error { + // Add a buffer of 64k to amortize syscalls cost + var bufferedWriter = bufio.NewWriterSize(s.rw, 1024*64) + defer bufferedWriter.Flush() + var sent int + // no message + if len(message.inputs) == 0 { + close(localEncodings) + } + + for batch := range localEncodings { + for _, hashedEncodings := range batch { + // send all encodings of an ID at once + if err := EncodingsWrite(bufferedWriter, hashedEncodings); err != nil { + return fmt.Errorf("stage3: %v", err) + } + sent++ + } + if sent == len(message.inputs) { + close(localEncodings) + } + } + return nil + }) + + if err := g.Wait(); err != nil { + return err + } + + // end stage3 + _, _ = printStageStats(logger, 3, timer, start, mem) + logger.V(1).Info("Finished stage 3") + return nil + } + + // run stage1 + if err := util.Sel(ctx, stage1); err != nil { + return err + } + + // run stage2 + if err := util.Sel(ctx, stage2); err != nil { + return err + } + + // run stage3 + if err := util.Sel(ctx, stage3); err != nil { + return err + } + + return nil +} diff --git a/pkg/npsi/README.md b/pkg/npsi/README.md index 7044114..857079e 100644 --- a/pkg/npsi/README.md +++ b/pkg/npsi/README.md @@ -2,9 +2,9 @@ ## protocol -In the naive private set intersection (NPSI) [1], both parties agree on a non-cryptographic hash function, apply it to their inputs and then compare the resulting hashes. It is the most commonly used protocol due to its efficiency and ease for implementation, but it is *insecure*. The protocol has a major security flaw if the elements are taken from a small domain or a domain that does not have high entroy. In that case, _P2_ (the receiver) can recover all elements in the set of _P1_ (the sender) by running a brute force attack. +In the naive private set intersection (NPSI) [1], both parties agree on a non-cryptographic hash function, apply it to their inputs and then compare the resulting hashes. It is the most commonly used protocol due to its efficiency and ease for implementation, but it is *insecure*. The protocol has a major security flaw if the elements are taken from a small domain or a domain that does not have high entropy. In that case, _P2_ (the receiver) can recover all elements in the set of _P1_ (the sender) by running a brute force attack. -In the protocol, _P2_ samples a random 32 bytes salt _K_ and sends it to _P1_. Both parties then use a non-cryptographic hash function ([Highway Hash](github.com/dgryski/go-highway)) to hash their input identifiers seeded with _K_. _P1_ sends the hash values _Hx_ to _P2_, who computes the intersection of both hashed identifiers. +In the protocol, _P2_ samples a random 32 bytes salt _K_ and sends it to _P1_. Both parties then use a non-cryptographic hash function ([MetroHash](http://www.jandrewrogers.com/2015/05/27/metrohash/)) to hash their input identifiers seeded with _K_. _P1_ sends the hash values _Hx_ to _P2_, who computes the intersection of both hashed identifiers. ## data flow @@ -14,9 +14,9 @@ X Y receive K <------------------------------ generate K (32 bytes) -hwh(K,X) -> H_X ------------------------------> intersect(H_X, hwh(K,Y) -> H_Y)) +mh(K,X) -> H_X ------------------------------> intersect(H_X, mh(K,Y) -> H_Y)) -hwh(K,I): Highway hash of input I seeded with K +mh(K,I): Metro hash of input I seeded with K ``` # References diff --git a/pkg/npsi/npsi_parallel.go b/pkg/npsi/npsi_parallel.go index 90eeff8..cb9ade4 100644 --- a/pkg/npsi/npsi_parallel.go +++ b/pkg/npsi/npsi_parallel.go @@ -7,10 +7,6 @@ import ( "github.com/optable/match/internal/hash" ) -// -// parallel hashing engine -// - const ( batchSize = 512 ) diff --git a/pkg/npsi/receiver.go b/pkg/npsi/receiver.go index 21edefe..9d62e4d 100644 --- a/pkg/npsi/receiver.go +++ b/pkg/npsi/receiver.go @@ -61,7 +61,7 @@ func (r *Receiver) Intersect(ctx context.Context, n int64, identifiers <-chan [] var localIDs = make(map[uint64][]byte) var remoteIDs = make(map[uint64]bool) // get a hasher - h, err := hash.New(hash.Highway, k) + h, err := hash.NewMetroHasher(k) if err != nil { return err } diff --git a/pkg/npsi/sender.go b/pkg/npsi/sender.go index 5d11bd5..eefbb25 100644 --- a/pkg/npsi/sender.go +++ b/pkg/npsi/sender.go @@ -54,7 +54,7 @@ func (s *Sender) Send(ctx context.Context, n int64, identifiers <-chan []byte) e stage2 := func() error { logger.V(1).Info("Starting stage 2") // get a hasher - h, err := hash.New(hash.Highway, k) + h, err := hash.NewMetroHasher(k) if err != nil { return err } diff --git a/pkg/psi/psi.go b/pkg/psi/psi.go index 79dbbca..8678017 100644 --- a/pkg/psi/psi.go +++ b/pkg/psi/psi.go @@ -7,6 +7,7 @@ import ( "github.com/optable/match/pkg/bpsi" "github.com/optable/match/pkg/dhpsi" + "github.com/optable/match/pkg/kkrtpsi" "github.com/optable/match/pkg/npsi" ) @@ -18,6 +19,7 @@ const ( ProtocolDHPSI ProtocolNPSI ProtocolBPSI + ProtocolKKRTPSI ) var ErrUnsupportedPSIProtocol = errors.New("unsupported PSI protocol") @@ -40,6 +42,8 @@ func NewSender(protocol Protocol, rw io.ReadWriter) (Sender, error) { return npsi.NewSender(rw), nil case ProtocolBPSI: return bpsi.NewSender(rw), nil + case ProtocolKKRTPSI: + return kkrtpsi.NewSender(rw), nil case ProtocolUnsupported: fallthrough default: @@ -55,6 +59,8 @@ func NewReceiver(protocol Protocol, rw io.ReadWriter) (Receiver, error) { return npsi.NewReceiver(rw), nil case ProtocolBPSI: return bpsi.NewReceiver(rw), nil + case ProtocolKKRTPSI: + return kkrtpsi.NewReceiver(rw), nil case ProtocolUnsupported: fallthrough default: @@ -70,6 +76,8 @@ func (p Protocol) String() string { return "npsi" case ProtocolBPSI: return "bpsi" + case ProtocolKKRTPSI: + return "kkrtpsi" case ProtocolUnsupported: fallthrough default: diff --git a/test/emails/generate.go b/test/emails/generate.go index e90f838..600787d 100644 --- a/test/emails/generate.go +++ b/test/emails/generate.go @@ -8,13 +8,15 @@ import ( ) const ( - Prefix = "e:" + // Prefix is value to be prepended to each generated email + Prefix = "e:" + // HashLen is the number of bytes to generate HashLen = 32 ) // Common generates the common matchable identifiers -func Common(n int) (common []byte) { - common = make([]byte, n*HashLen) +func Common(n, hashLen int) (common []byte) { + common = make([]byte, n*hashLen) if _, err := rand.Read(common); err != nil { log.Fatalf("could not generate %d hashes for the common portion", n) } @@ -22,20 +24,20 @@ func Common(n int) (common []byte) { } // Mix mixes identifiers from common and n new fresh matchables -func Mix(common []byte, n int) <-chan []byte { +func Mix(common []byte, n, hashLen int) <-chan []byte { // setup the streams - c1 := commons(common) - c2 := freshes(n) + c1 := commons(common, hashLen) + c2 := freshes(n, hashLen) return mixes(c1, c2) } // commons will write HashLen chunks from b to a channel and then close it -func commons(b []byte) <-chan []byte { +func commons(b []byte, hashLen int) <-chan []byte { out := make(chan []byte) go func() { defer close(out) - for i := 0; i < len(b)/HashLen; i++ { - hash := b[i*HashLen : i*HashLen+HashLen] + for i := 0; i < len(b)/hashLen; i++ { + hash := b[i*hashLen : i*hashLen+hashLen] out <- hash } }() @@ -43,12 +45,12 @@ func commons(b []byte) <-chan []byte { } // freshes will write a total number of fresh hashes to a channel and then close it -func freshes(total int) <-chan []byte { +func freshes(total, hashLen int) <-chan []byte { out := make(chan []byte) go func() { defer close(out) for i := 0; i < total; i++ { - b := make([]byte, HashLen) + b := make([]byte, hashLen) if _, err := rand.Read(b); err == nil { out <- b } diff --git a/test/emails/generate_test.go b/test/emails/generate_test.go index 7fc09a5..415bcf9 100644 --- a/test/emails/generate_test.go +++ b/test/emails/generate_test.go @@ -17,7 +17,7 @@ func initDataSource(common []byte) *bufio.Reader { i, o := io.Pipe() b := bufio.NewReader(i) go func() { - matchables := Mix(common, Cardinality-CommonCardinality) + matchables := Mix(common, Cardinality-CommonCardinality, HashLen) for matchable := range matchables { out := append(matchable, "\n"...) if _, err := o.Write(out); err != nil { @@ -30,7 +30,7 @@ func initDataSource(common []byte) *bufio.Reader { func TestGenerate(t *testing.T) { // generate common data - common := Common(CommonCardinality) + common := Common(CommonCardinality, HashLen) r := initDataSource(common) // read N matchables from r diff --git a/test/psi/receiver_test.go b/test/psi/receiver_test.go index 5a52191..6c62921 100644 --- a/test/psi/receiver_test.go +++ b/test/psi/receiver_test.go @@ -14,7 +14,7 @@ import ( ) // test receiver and return the addr string -func r_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen int, intersectionsBus chan<- []byte, errs chan<- error) (addr string, err error) { +func r_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen, hashLen int, intersectionsBus chan<- []byte, errs chan<- error) (addr string, err error) { ln, err := net.Listen("tcp", "127.0.0.1:") if err != nil { return "", err @@ -24,16 +24,17 @@ func r_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen conn, err := ln.Accept() if err != nil { // handle error + errs <- err } - go r_receiverHandle(protocol, common, commonLen, receiverLen, conn, intersectionsBus, errs) + go r_receiverHandle(protocol, common, commonLen, receiverLen, hashLen, conn, intersectionsBus, errs) } }() return ln.Addr().String(), nil } -func r_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverLen int, conn net.Conn, intersectionsBus chan<- []byte, errs chan<- error) { +func r_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverLen, hashLen int, conn net.Conn, intersectionsBus chan<- []byte, errs chan<- error) { defer close(intersectionsBus) - r := initTestDataSource(common, receiverLen-commonLen) + r := initTestDataSource(common, receiverLen-commonLen, hashLen) rec, _ := psi.NewReceiver(protocol, conn) ii, err := rec.Intersect(context.Background(), int64(receiverLen), r) @@ -41,21 +42,20 @@ func r_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverL intersectionsBus <- intersection } if err != nil { - // hmm - send this to the main thread with a channel errs <- err } } // take the common chunk from the emails generator // and turn it into prefixed sha512 hashes -func parseCommon(b []byte) (out []string) { - for i := 0; i < len(b)/emails.HashLen; i++ { +func parseCommon(b []byte, hashLen int) (out []string) { + for i := 0; i < len(b)/hashLen; i++ { // make one - one := make([]byte, len(emails.Prefix)+hex.EncodedLen(len(b[i*emails.HashLen:i*emails.HashLen+emails.HashLen]))) + one := make([]byte, len(emails.Prefix)+hex.EncodedLen(len(b[i*hashLen:i*hashLen+hashLen]))) // copy the prefix first and then the // hex string copy(one, emails.Prefix) - hex.Encode(one[len(emails.Prefix):], b[i*emails.HashLen:i*emails.HashLen+emails.HashLen]) + hex.Encode(one[len(emails.Prefix):], b[i*hashLen:i*hashLen+hashLen]) out = append(out, string(one)) } return @@ -65,14 +65,14 @@ func testReceiver(protocol psi.Protocol, common []byte, s test_size, determinist // setup channels var intersectionsBus = make(chan []byte) var errs = make(chan error, 2) - addr, err := r_receiverInit(protocol, common, s.commonLen, s.receiverLen, intersectionsBus, errs) + addr, err := r_receiverInit(protocol, common, s.commonLen, s.receiverLen, s.hashLen, intersectionsBus, errs) if err != nil { return err } // send operation go func() { - r := initTestDataSource(common, s.senderLen-s.commonLen) + r := initTestDataSource(common, s.senderLen-s.commonLen, s.hashLen) conn, err := net.Dial("tcp", addr) if err != nil { errs <- fmt.Errorf("sender: %v", err) @@ -98,7 +98,7 @@ func testReceiver(protocol psi.Protocol, common []byte, s test_size, determinist // turn the common chunk into a slice of // string IDs - var c = parseCommon(common) + var c = parseCommon(common, s.hashLen) // is this a deterministic PSI? if not remove all false positives first if !deterministic { // filter out intersections to @@ -107,8 +107,8 @@ func testReceiver(protocol psi.Protocol, common []byte, s test_size, determinist } // right amount? - if len(common)/emails.HashLen != len(intersections) { - return fmt.Errorf("expected %d intersections and got %d", len(common)/emails.HashLen, len(intersections)) + if len(common)/s.hashLen != len(intersections) { + return fmt.Errorf("expected %d intersections and got %d", len(common)/s.hashLen, len(intersections)) } // sort intersections sort.Slice(intersections, func(i, j int) bool { @@ -152,34 +152,94 @@ func TestDHPSIReceiver(t *testing.T) { for _, s := range test_sizes { t.Logf("testing scenario %s", s.scenario) // generate common data - common := emails.Common(s.commonLen) + common := emails.Common(s.commonLen, s.hashLen) // test if err := testReceiver(psi.ProtocolDHPSI, common, s, true); err != nil { t.Fatalf("%s: %v", s.scenario, err) } } + + for _, hashLen := range hashLenSizes { + hashLenTest := test_size{"same size with hash digest length", 100, 100, 200, hashLen} + scenario := hashLenTest.scenario + " with hash digest length: " + fmt.Sprint(hashDigestLen(hashLen)) + t.Logf("testing scenario %s", scenario) + // generate common data + common := emails.Common(hashLenTest.commonLen, hashLen) + // test + if err := testReceiver(psi.ProtocolDHPSI, common, hashLenTest, true); err != nil { + t.Fatalf("%s: %v", hashLenTest.scenario, err) + } + } } func TestNPSIReceiver(t *testing.T) { for _, s := range test_sizes { t.Logf("testing scenario %s", s.scenario) // generate common data - common := emails.Common(s.commonLen) + common := emails.Common(s.commonLen, s.hashLen) // test if err := testReceiver(psi.ProtocolNPSI, common, s, true); err != nil { t.Fatalf("%s: %v", s.scenario, err) } } + + for _, hashLen := range hashLenSizes { + hashLenTest := test_size{"same size with hash digest length", 100, 100, 200, hashLen} + scenario := hashLenTest.scenario + " with hash digest length: " + fmt.Sprint(hashDigestLen(hashLen)) + t.Logf("testing scenario %s", scenario) + // generate common data + common := emails.Common(hashLenTest.commonLen, hashLen) + // test + if err := testReceiver(psi.ProtocolNPSI, common, hashLenTest, true); err != nil { + t.Fatalf("%s: %v", hashLenTest.scenario, err) + } + } } func TestBPSIReceiver(t *testing.T) { for _, s := range test_sizes { t.Logf("testing scenario %s", s.scenario) // generate common data - common := emails.Common(s.commonLen) + common := emails.Common(s.commonLen, s.hashLen) // test if err := testReceiver(psi.ProtocolBPSI, common, s, false); err != nil { t.Fatalf("%s: %v", s.scenario, err) } } + + for _, hashLen := range hashLenSizes { + hashLenTest := test_size{"same size with hash digest length", 100, 100, 200, hashLen} + scenario := hashLenTest.scenario + " with hash digest length: " + fmt.Sprint(hashDigestLen(hashLen)) + t.Logf("testing scenario %s", scenario) + // generate common data + common := emails.Common(hashLenTest.commonLen, hashLen) + // test + if err := testReceiver(psi.ProtocolBPSI, common, hashLenTest, false); err != nil { + t.Fatalf("%s: %v", hashLenTest.scenario, err) + } + } +} + +func TestKKRTReceiver(t *testing.T) { + for _, s := range test_sizes { + t.Logf("testing scenario %s", s.scenario) + // generate common data + common := emails.Common(s.commonLen, s.hashLen) + // test + if err := testReceiver(psi.ProtocolKKRTPSI, common, s, true); err != nil { + t.Fatalf("%s: %v", s.scenario, err) + } + } + + for _, hashLen := range hashLenSizes { + hashLenTest := test_size{"same size with hash length", 100, 100, 200, hashLen} + scenario := hashLenTest.scenario + " with hash length: " + fmt.Sprint(hashDigestLen(hashLen)) + t.Logf("testing scenario %s", scenario) + // generate common data + common := emails.Common(hashLenTest.commonLen, hashLen) + // test + if err := testReceiver(psi.ProtocolKKRTPSI, common, hashLenTest, true); err != nil { + t.Fatalf("%s: %v", hashLenTest.scenario, err) + } + } } diff --git a/test/psi/sender_test.go b/test/psi/sender_test.go index 885341f..fa1d551 100644 --- a/test/psi/sender_test.go +++ b/test/psi/sender_test.go @@ -3,7 +3,7 @@ package psi_test import ( "context" - "log" + "fmt" "net" "testing" @@ -12,12 +12,12 @@ import ( ) // will output len(common)+bodyLen identifiers -func initTestDataSource(common []byte, bodyLen int) <-chan []byte { - return emails.Mix(common, bodyLen) +func initTestDataSource(common []byte, bodyLen, hashLen int) <-chan []byte { + return emails.Mix(common, bodyLen, hashLen) } // test receiver and return the addr string -func s_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen int) (addr string, err error) { +func s_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen, hashLen int, errs chan<- error) (addr string, err error) { ln, err := net.Listen("tcp", "127.0.0.1:") if err != nil { return "", err @@ -27,27 +27,27 @@ func s_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen conn, err := ln.Accept() if err != nil { // handle error + errs <- err } - go s_receiverHandle(protocol, common, commonLen, receiverLen, conn) + go s_receiverHandle(protocol, common, commonLen, receiverLen, hashLen, conn, errs) } }() return ln.Addr().String(), nil } -func s_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverLen int, conn net.Conn) { - r := initTestDataSource(common, receiverLen-commonLen) +func s_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverLen, hashLen int, conn net.Conn, errs chan<- error) { + r := initTestDataSource(common, receiverLen-commonLen, hashLen) // do a nil receive, ignore the results rec, _ := psi.NewReceiver(protocol, conn) _, err := rec.Intersect(context.Background(), int64(receiverLen), r) if err != nil { - // hmm - send this to the main thread with a channel - log.Print(err) + errs <- err } } -func testSender(protocol psi.Protocol, addr string, common []byte, commonLen, senderLen int) error { +func testSender(protocol psi.Protocol, addr string, common []byte, commonLen, senderLen, hashLen int) error { // test sender - r := initTestDataSource(common, senderLen-commonLen) + r := initTestDataSource(common, senderLen-commonLen, hashLen) conn, err := net.Dial("tcp", addr) if err != nil { return err @@ -61,21 +61,57 @@ func testSender(protocol psi.Protocol, addr string, common []byte, commonLen, se } func testSenderByProtocol(p psi.Protocol, t *testing.T) { + var errs = make(chan error, 2) + defer close(errs) for _, s := range test_sizes { t.Logf("testing scenario %s", s.scenario) // generate common data - common := emails.Common(s.commonLen) + common := emails.Common(s.commonLen, s.hashLen) // init a test receiver server - addr, err := s_receiverInit(p, common, s.commonLen, s.receiverLen) + addr, err := s_receiverInit(p, common, s.commonLen, s.receiverLen, s.hashLen, errs) if err != nil { t.Fatalf("%s: %v", s.scenario, err) } + + // errors? + select { + case err := <-errs: + t.Fatalf("%s: %v", s.scenario, err) + default: + } + // test sender - err = testSender(p, addr, common, s.commonLen, s.senderLen) + err = testSender(p, addr, common, s.commonLen, s.senderLen, 32) if err != nil { t.Fatalf("%s: %v", s.scenario, err) } } + + for _, hashLen := range hashLenSizes { + hashLenTest := test_size{"same size with hash length", 100, 100, 200, hashLen} + scenario := hashLenTest.scenario + " with hash length: " + fmt.Sprint(hashDigestLen(hashLen)) + t.Logf("testing scenario %s", scenario) + // generate common data + common := emails.Common(hashLenTest.commonLen, hashLen) + // init a test receiver server + addr, err := s_receiverInit(p, common, hashLenTest.commonLen, hashLenTest.receiverLen, hashLenTest.hashLen, errs) + if err != nil { + t.Fatalf("%s: %v", scenario, err) + } + + // errors? + select { + case err := <-errs: + t.Fatalf("%s: %v", scenario, err) + default: + } + + // test sender + err = testSender(p, addr, common, hashLenTest.commonLen, hashLenTest.senderLen, hashLenTest.hashLen) + if err != nil { + t.Fatalf("%s: %v", hashLenTest.scenario, err) + } + } } func TestDHPSISender(t *testing.T) { @@ -89,3 +125,7 @@ func TestNPSISender(t *testing.T) { func TestBPSISender(t *testing.T) { testSenderByProtocol(psi.ProtocolBPSI, t) } + +func TestKKRTPSISender(t *testing.T) { + testSenderByProtocol(psi.ProtocolKKRTPSI, t) +} diff --git a/test/psi/types_test.go b/test/psi/types_test.go index be8a1af..0829262 100644 --- a/test/psi/types_test.go +++ b/test/psi/types_test.go @@ -1,9 +1,11 @@ // black box testing of all PSIs package psi_test +import "github.com/optable/match/test/emails" + type test_size struct { - scenario string - commonLen, senderLen, receiverLen int + scenario string + commonLen, senderLen, receiverLen, hashLen int } // test scenarios @@ -14,11 +16,18 @@ type test_size struct { // composed of the common part // var test_sizes = []test_size{ - {"sender100receiver200", 100, 100, 200}, - {"emptySenderSize", 0, 0, 1000}, - {"emptyReceiverSize", 0, 1000, 0}, - {"sameSize", 100, 100, 100}, - {"smallSize", 100, 10000, 1000}, - {"mediumSize", 1000, 100000, 10000}, - {"bigSize", 10000, 100000, 10000}, + {"sender100receiver200", 100, 100, 200, emails.HashLen}, + {"emptySenderSize", 0, 0, 1000, emails.HashLen}, + {"emptyReceiverSize", 0, 1000, 0, emails.HashLen}, + {"sameSize", 100, 100, 100, emails.HashLen}, + {"smallSize", 100, 10000, 1000, emails.HashLen}, + {"mediumSize", 1000, 100000, 10000, emails.HashLen}, + {"bigSize", 10000, 100000, 100000, emails.HashLen}, +} + +var hashLenSizes = []int{4, 8, 16, 32, 64} + +func hashDigestLen(hashLen int) int { + // hex encode + 2 byte for prefix + return 2*hashLen + 2 }