Skip to content

Commit

Permalink
add unsupported psi type as default value (#55)
Browse files Browse the repository at this point in the history
* add unsupported psi type as default value

* change Protocol type from int to a byte

* remove PSIProtocol varoables

* rm comments
  • Loading branch information
juanli16 authored Dec 6, 2021
1 parent a2d92dd commit 545c3de
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 48 deletions.
14 changes: 7 additions & 7 deletions examples/receiver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,16 @@ func main() {
var psiType psi.Protocol
switch *protocol {
case "bpsi":
psiType = psi.BPSI
psiType = psi.ProtocolBPSI
case "npsi":
psiType = psi.NPSI
psiType = psi.ProtocolNPSI
case "dhpsi":
psiType = psi.DHPSI
psiType = psi.ProtocolDHPSI
default:
log.Printf("unsupported protocol %s", *protocol)
showUsageAndExit(0)
psiType = psi.ProtocolUnsupported
}

log.Printf("operating with protocol %s", *protocol)
log.Printf("operating with protocol %s", psiType)
// fetch stdr logger
mlog := getLogger(*verbose)

Expand Down Expand Up @@ -127,7 +126,8 @@ func main() {
}
// make the receiver

receiver, _ := psi.NewReceiver(psiType, c)
receiver, err := psi.NewReceiver(psiType, c)
exitOnErr(mlog, err, "failed to create receiver")
// and hand it off
wg.Add(1)
go func() {
Expand Down
14 changes: 7 additions & 7 deletions examples/sender/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,16 @@ func main() {
var psiType psi.Protocol
switch *protocol {
case "bpsi":
psiType = psi.BPSI
psiType = psi.ProtocolBPSI
case "npsi":
psiType = psi.NPSI
psiType = psi.ProtocolNPSI
case "dhpsi":
psiType = psi.DHPSI
psiType = psi.ProtocolDHPSI
default:
log.Printf("unsupported protocol %s", *protocol)
showUsageAndExit(0)
psiType = psi.ProtocolUnsupported
}

log.Printf("operating with protocol %s", *protocol)
log.Printf("operating with protocol %s", psiType)
// fetch stdr logger
slog := getLogger(*verbose)

Expand All @@ -109,7 +108,8 @@ func main() {
v.SetNoDelay(false)
}

s, _ := psi.NewSender(psiType, c)
s, err := psi.NewSender(psiType, c)
exitOnErr(slog, err, "failed to create sender")
ids := util.Exhaust(n, f)
err = s.Send(logr.NewContext(context.Background(), slog), n, ids)
exitOnErr(slog, err, "failed to perform PSI")
Expand Down
35 changes: 18 additions & 17 deletions pkg/psi/psi.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,26 @@ package psi

import (
"context"
"fmt"
"errors"
"io"

"github.com/optable/match/pkg/bpsi"
"github.com/optable/match/pkg/dhpsi"
"github.com/optable/match/pkg/npsi"
)

const (
DHPSI = iota
NPSI
BPSI
)

// Protocol is the matching protocol enumeration
type Protocol int
type Protocol byte

var (
ProtocolDHPSI Protocol = DHPSI
ProtocolNPSI Protocol = NPSI
ProtocolBPSI Protocol = BPSI
const (
ProtocolUnsupported Protocol = iota
ProtocolDHPSI
ProtocolNPSI
ProtocolBPSI
)

var ErrUnsupportedPSIProtocol = errors.New("unsupported PSI protocol")

// Sender is the sender side of the PSI operation
type Sender interface {
Send(ctx context.Context, n int64, identifiers <-chan []byte) error
Expand All @@ -43,9 +40,10 @@ func NewSender(protocol Protocol, rw io.ReadWriter) (Sender, error) {
return npsi.NewSender(rw), nil
case ProtocolBPSI:
return bpsi.NewSender(rw), nil

case ProtocolUnsupported:
fallthrough
default:
return nil, fmt.Errorf("PSI sender protocol %d not supported", protocol)
return nil, ErrUnsupportedPSIProtocol
}
}

Expand All @@ -57,9 +55,10 @@ func NewReceiver(protocol Protocol, rw io.ReadWriter) (Receiver, error) {
return npsi.NewReceiver(rw), nil
case ProtocolBPSI:
return bpsi.NewReceiver(rw), nil

case ProtocolUnsupported:
fallthrough
default:
return nil, fmt.Errorf("PSI receiver protocol %d not supported", protocol)
return nil, ErrUnsupportedPSIProtocol
}
}

Expand All @@ -71,7 +70,9 @@ func (p Protocol) String() string {
return "npsi"
case ProtocolBPSI:
return "bpsi"
case ProtocolUnsupported:
fallthrough
default:
return "undefined"
return "unsupported"
}
}
16 changes: 8 additions & 8 deletions test/psi/receiver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

// test receiver and return the addr string
func r_receiverInit(protocol int, common []byte, commonLen, receiverLen int, intersectionsBus chan<- []byte, errs chan<- error) (addr string, err error) {
func r_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen int, intersectionsBus chan<- []byte, errs chan<- error) (addr string, err error) {
ln, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
return "", err
Expand All @@ -31,11 +31,11 @@ func r_receiverInit(protocol int, common []byte, commonLen, receiverLen int, int
return ln.Addr().String(), nil
}

func r_receiverHandle(protocol int, common []byte, commonLen, receiverLen int, conn net.Conn, intersectionsBus chan<- []byte, errs chan<- error) {
func r_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverLen int, conn net.Conn, intersectionsBus chan<- []byte, errs chan<- error) {
defer close(intersectionsBus)
r := initTestDataSource(common, receiverLen-commonLen)

rec, _ := psi.NewReceiver(psi.Protocol(protocol), conn)
rec, _ := psi.NewReceiver(protocol, conn)
ii, err := rec.Intersect(context.Background(), int64(receiverLen), r)
for _, intersection := range ii {
intersectionsBus <- intersection
Expand All @@ -61,7 +61,7 @@ func parseCommon(b []byte) (out []string) {
return
}

func testReceiver(protocol int, common []byte, s test_size, deterministic bool) error {
func testReceiver(protocol psi.Protocol, common []byte, s test_size, deterministic bool) error {
// setup channels
var intersectionsBus = make(chan []byte)
var errs = make(chan error, 2)
Expand All @@ -77,7 +77,7 @@ func testReceiver(protocol int, common []byte, s test_size, deterministic bool)
if err != nil {
errs <- fmt.Errorf("sender: %v", err)
}
snd, _ := psi.NewSender(psi.Protocol(protocol), conn)
snd, _ := psi.NewSender(protocol, conn)
err = snd.Send(context.Background(), int64(s.senderLen), r)
if err != nil {
errs <- fmt.Errorf("sender: %v", err)
Expand Down Expand Up @@ -154,7 +154,7 @@ func TestDHPSIReceiver(t *testing.T) {
// generate common data
common := emails.Common(s.commonLen)
// test
if err := testReceiver(psi.DHPSI, common, s, true); err != nil {
if err := testReceiver(psi.ProtocolDHPSI, common, s, true); err != nil {
t.Fatalf("%s: %v", s.scenario, err)
}
}
Expand All @@ -166,7 +166,7 @@ func TestNPSIReceiver(t *testing.T) {
// generate common data
common := emails.Common(s.commonLen)
// test
if err := testReceiver(psi.NPSI, common, s, true); err != nil {
if err := testReceiver(psi.ProtocolNPSI, common, s, true); err != nil {
t.Fatalf("%s: %v", s.scenario, err)
}
}
Expand All @@ -178,7 +178,7 @@ func TestBPSIReceiver(t *testing.T) {
// generate common data
common := emails.Common(s.commonLen)
// test
if err := testReceiver(psi.BPSI, common, s, false); err != nil {
if err := testReceiver(psi.ProtocolBPSI, common, s, false); err != nil {
t.Fatalf("%s: %v", s.scenario, err)
}
}
Expand Down
18 changes: 9 additions & 9 deletions test/psi/sender_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ func initTestDataSource(common []byte, bodyLen int) <-chan []byte {
}

// test receiver and return the addr string
func s_receiverInit(protocol int, common []byte, commonLen, receiverLen int) (addr string, err error) {
func s_receiverInit(protocol psi.Protocol, common []byte, commonLen, receiverLen int) (addr string, err error) {
ln, err := net.Listen("tcp", "127.0.0.1:")
if err != nil {
return "", err
Expand All @@ -34,33 +34,33 @@ func s_receiverInit(protocol int, common []byte, commonLen, receiverLen int) (ad
return ln.Addr().String(), nil
}

func s_receiverHandle(protocol int, common []byte, commonLen, receiverLen int, conn net.Conn) {
func s_receiverHandle(protocol psi.Protocol, common []byte, commonLen, receiverLen int, conn net.Conn) {
r := initTestDataSource(common, receiverLen-commonLen)
// do a nil receive, ignore the results
rec, _ := psi.NewReceiver(psi.Protocol(protocol), conn)
rec, _ := psi.NewReceiver(protocol, conn)
_, err := rec.Intersect(context.Background(), int64(receiverLen), r)
if err != nil {
// hmm - send this to the main thread with a channel
log.Print(err)
}
}

func testSender(protocol int, addr string, common []byte, commonLen, senderLen int) error {
func testSender(protocol psi.Protocol, addr string, common []byte, commonLen, senderLen int) error {
// test sender
r := initTestDataSource(common, senderLen-commonLen)
conn, err := net.Dial("tcp", addr)
if err != nil {
return err
}
snd, _ := psi.NewSender(psi.Protocol(protocol), conn)
snd, _ := psi.NewSender(protocol, conn)
err = snd.Send(context.Background(), int64(senderLen), r)
if err != nil {
return err
}
return nil
}

func testSenderByProtocol(p int, t *testing.T) {
func testSenderByProtocol(p psi.Protocol, t *testing.T) {
for _, s := range test_sizes {
t.Logf("testing scenario %s", s.scenario)
// generate common data
Expand All @@ -79,13 +79,13 @@ func testSenderByProtocol(p int, t *testing.T) {
}

func TestDHPSISender(t *testing.T) {
testSenderByProtocol(psi.DHPSI, t)
testSenderByProtocol(psi.ProtocolDHPSI, t)
}

func TestNPSISender(t *testing.T) {
testSenderByProtocol(psi.NPSI, t)
testSenderByProtocol(psi.ProtocolNPSI, t)
}

func TestBPSISender(t *testing.T) {
testSenderByProtocol(psi.BPSI, t)
testSenderByProtocol(psi.ProtocolBPSI, t)
}

0 comments on commit 545c3de

Please sign in to comment.