Skip to content

Commit

Permalink
update to new mp spdz version (#15)
Browse files Browse the repository at this point in the history
Signed-off-by: Johannes Graf <[email protected]>
Signed-off-by: Petra Scherer <[email protected]>
Signed-off-by: Timo Klenk <[email protected]>
  • Loading branch information
grafjo authored May 6, 2022
1 parent 7c3f0e0 commit c43d17c
Show file tree
Hide file tree
Showing 14 changed files with 154 additions and 51 deletions.
2 changes: 1 addition & 1 deletion .ko.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# github.com/carbynestack/ephemeral/cmd/ephemeral: ghcr.io/carbynestack/ephemeral-spdz-base-image:cleared-20210827
defaultBaseImage: ghcr.io/carbynestack/ubuntu:20.04-20210827-nonroot
baseImageOverrides:
github.com/carbynestack/ephemeral/cmd/ephemeral: ghcr.io/carbynestack/spdz:20210827
github.com/carbynestack/ephemeral/cmd/ephemeral: ghcr.io/carbynestack/spdz:642d11f
5 changes: 3 additions & 2 deletions cmd/discovery/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package main

import (
"context"
"errors"
"fmt"
"io/ioutil"
Expand Down Expand Up @@ -57,7 +58,7 @@ var _ = Describe("Main", func() {
})
Context("all required parameters are specified", func() {
AfterEach(func() {
_, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./")
_, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./")
Expect(err).NotTo(HaveOccurred())
})
Context("parameters are plausible", func() {
Expand Down Expand Up @@ -100,7 +101,7 @@ var _ = Describe("Main", func() {
Context("one of the required parameters is missing", func() {
Context("when no frontendURL is defined", func() {
AfterEach(func() {
_, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./")
_, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./")
Expect(err).NotTo(HaveOccurred())
})
It("returns an error", func() {
Expand Down
3 changes: 2 additions & 1 deletion cmd/ephemeral/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package main_test

import (
"context"
"fmt"
"io/ioutil"
"math/rand"
Expand Down Expand Up @@ -43,7 +44,7 @@ var _ = Describe("Main", func() {
path = fmt.Sprintf("/tmp/test-%d", random)
})
AfterEach(func() {
_, _, err := cmder.CallCMD([]string{fmt.Sprintf("rm %s", path)}, "./")
_, _, err := cmder.CallCMD(context.TODO(), []string{fmt.Sprintf("rm %s", path)}, "./")
Expect(err).NotTo(HaveOccurred())
})
Context("when it succeeds", func() {
Expand Down
5 changes: 3 additions & 2 deletions pkg/ephemeral/fake_spdz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package ephemeral

import (
"context"
"errors"
"github.com/carbynestack/ephemeral/pkg/discovery/fsm"
pb "github.com/carbynestack/ephemeral/pkg/discovery/transport/proto"
Expand Down Expand Up @@ -93,14 +94,14 @@ func (f *FakePlayer) PublishEvent(name, topic string, event *pb.Event) {
type FakeExecutor struct {
}

func (f *FakeExecutor) CallCMD(cmd []string, dir string) ([]byte, []byte, error) {
func (f *FakeExecutor) CallCMD(ctx context.Context, cmd []string, dir string) ([]byte, []byte, error) {
return []byte{}, []byte{}, nil
}

type BrokenFakeExecutor struct {
}

func (f *BrokenFakeExecutor) CallCMD(cmd []string, dir string) ([]byte, []byte, error) {
func (f *BrokenFakeExecutor) CallCMD(ctx context.Context, cmd []string, dir string) ([]byte, []byte, error) {
return []byte{}, []byte{}, errors.New("some error")
}

Expand Down
58 changes: 55 additions & 3 deletions pkg/ephemeral/io/carrier.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ package io

import (
"context"
"encoding/binary"
"errors"
"fmt"
"github.com/carbynestack/ephemeral/pkg/amphora"
"io"
"io/ioutil"
"net"
)
Expand All @@ -21,7 +24,7 @@ type Result struct {

// AbstractCarrier is the carriers interface.
type AbstractCarrier interface {
Connect(context.Context, string, string) error
Connect(context.Context, int32, string, string) error
Close() error
Send([]amphora.SecretShare) error
Read(ResponseConverter, bool) (*Result, error)
Expand All @@ -42,16 +45,54 @@ type Config struct {
}

// Connect establishes a TCP connection to a socket on a given host and port.
func (c *Carrier) Connect(ctx context.Context, host, port string) error {
func (c *Carrier) Connect(ctx context.Context, playerID int32, host string, port string) error {
conn, err := c.Dialer(ctx, host, port)
c.Conn = conn
if err != nil {
return err
}
c.Conn = conn
_, err = conn.Write(c.buildHeader(playerID))
if err != nil {
return err
}
if playerID == 0 {
err = c.readPrime()
if err != nil {
return err
}
}
c.connected = true
return nil
}

// readPrime reads the file header from the MP-SPDZ connection
// In MP-SPDZ connection, this will only be used when player0 connects as client to MP-SPDZ
//
// For the header composition, check:
// https://github.com/data61/MP-SPDZ/issues/418#issuecomment-975424591
//
// It is made up as follows:
// - Careful: The other header parts are not part of this communication, they are only used when reading tuple files
// - length of the prime as 4-byte number little-endian (e.g. 16),
// - prime in big-endian (e.g. 170141183460469231731687303715885907969)
func (c Carrier) readPrime() error {
const size = 4
readBytes := make([]byte, size)
_, err := io.LimitReader(c.Conn, size).Read(readBytes)
if err != nil {
return err
}

sizeOfHeader := binary.LittleEndian.Uint32(readBytes)
readBytes = make([]byte, sizeOfHeader)
_, err = io.LimitReader(c.Conn, int64(sizeOfHeader)).Read(readBytes)
if err != nil {
return err
}
//ToDo, compare read PRIME with prime number from config?
return nil
}

// Close closes the underlying TCP connection.
func (c *Carrier) Close() error {
if c.connected {
Expand All @@ -78,6 +119,17 @@ func (c *Carrier) Send(secret []amphora.SecretShare) error {
return nil
}

// Returns a new Slice with the header appended
// The header consists of the clientId as string:
// - 1 Long (4 Byte) that contains the length of the string in bytes
// - Then come X Bytes for the String
func (c *Carrier) buildHeader(playerID int32) []byte {
playerIDString := []byte(fmt.Sprintf("%d", playerID))
lengthOfString := make([]byte, 4)
binary.LittleEndian.PutUint32(lengthOfString, uint32(len(playerIDString)))
return append(lengthOfString, playerIDString...)
}

// Read reads the response from the TCP connection and unmarshals it.
func (c *Carrier) Read(conv ResponseConverter, bulkObjects bool) (*Result, error) {
resp := []byte{}
Expand Down
81 changes: 64 additions & 17 deletions pkg/ephemeral/io/carrier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ package io_test
import (
"context"
"fmt"
"net"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"

"github.com/carbynestack/ephemeral/pkg/amphora"
. "github.com/carbynestack/ephemeral/pkg/ephemeral/io"
. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
"net"
"sync"
)

var _ = Describe("Carrier", func() {
var ctx = context.TODO()
var playerID = int32(1) // PlayerID 1, since PlayerID==0 contains another check when connecting

It("connects to a socket", func() {
var connected bool
conn := FakeNetConnection{}
Expand All @@ -30,7 +31,7 @@ var _ = Describe("Carrier", func() {
carrier := Carrier{
Dialer: fakeDialer,
}
err := carrier.Connect(context.TODO(), "", "")
err := carrier.Connect(context.TODO(), playerID, "", "")
Expect(connected).To(BeTrue())
Expect(err).NotTo(HaveOccurred())
})
Expand All @@ -42,24 +43,26 @@ var _ = Describe("Carrier", func() {
carrier := Carrier{
Dialer: fakeDialer,
}
err := carrier.Connect(context.TODO(), "", "")
err := carrier.Connect(context.TODO(), playerID, "", "")
Expect(err).NotTo(HaveOccurred())
err = carrier.Close()
Expect(err).NotTo(HaveOccurred())
Expect(conn.Closed).To(BeTrue())
})

var (
secret []amphora.SecretShare
output []byte
client, server net.Conn
dialer func(ctx context.Context, addr, port string) (net.Conn, error)
secret []amphora.SecretShare
output []byte
connectionOutput []byte //Will contain (length 4 byte, playerID 1 byte)
client, server net.Conn
dialer func(ctx context.Context, addr, port string) (net.Conn, error)
)
BeforeEach(func() {
secret = []amphora.SecretShare{
amphora.SecretShare{},
}
output = make([]byte, 1)
connectionOutput = make([]byte, 5)
client, server = net.Pipe()
dialer = func(ctx context.Context, addr, port string) (net.Conn, error) {
return client, nil
Expand All @@ -75,20 +78,23 @@ var _ = Describe("Carrier", func() {
Dialer: dialer,
Packer: packer,
}
carrier.Connect(ctx, "", "")
go server.Read(connectionOutput)
carrier.Connect(ctx, playerID, "", "")
go server.Read(output)
err := carrier.Send(secret)
carrier.Close()
Expect(err).NotTo(HaveOccurred())
Expect(output[0]).To(Equal(byte(1)))
Expect(connectionOutput).To(Equal([]byte{1, 0, 0, 0, fmt.Sprintf("%d", playerID)[0]}))
})
It("returns an error when it fails to marshal the object", func() {
packer := &FakeBrokenPacker{}
carrier := Carrier{
Dialer: dialer,
Packer: packer,
}
carrier.Connect(ctx, "", "")
go server.Read(connectionOutput)
carrier.Connect(ctx, playerID, "", "")
go server.Read(output)
err := carrier.Send(secret)
carrier.Close()
Expand All @@ -103,7 +109,8 @@ var _ = Describe("Carrier", func() {
Dialer: dialer,
Packer: packer,
}
carrier.Connect(ctx, "", "")
go server.Read(connectionOutput)
carrier.Connect(ctx, playerID, "", "")
// Closing the connection to trigger a failure due to writing into the closed socket.
server.Close()
err := carrier.Send(secret)
Expand All @@ -123,7 +130,8 @@ var _ = Describe("Carrier", func() {
Dialer: dialer,
Packer: &packer,
}
carrier.Connect(ctx, "", "")
go server.Read(connectionOutput)
carrier.Connect(ctx, playerID, "", "")
go func() {
server.Write(serverResponse)
server.Close()
Expand All @@ -143,7 +151,8 @@ var _ = Describe("Carrier", func() {
Dialer: dialer,
Packer: &packer,
}
carrier.Connect(ctx, "", "")
go server.Read(connectionOutput)
carrier.Connect(ctx, playerID, "", "")
server.Close()
anyConverter := &PlaintextConverter{}
_, err := carrier.Read(anyConverter, false)
Expand All @@ -156,7 +165,8 @@ var _ = Describe("Carrier", func() {
Dialer: dialer,
Packer: packer,
}
carrier.Connect(ctx, "", "")
go server.Read(connectionOutput)
carrier.Connect(ctx, playerID, "", "")
go func() {
server.Write(serverResponse)
server.Close()
Expand All @@ -166,4 +176,41 @@ var _ = Describe("Carrier", func() {
Expect(err).To(HaveOccurred())
})
})

Context("when connecting as Player0", func() {
playerID := int32(0)
It("will receive and handle the server's fileHeader", func() {
// Arrange
// ToDo: Better Response for real-life scenario?
serverResponse := []byte{1, 0, 0, 0, 1} // 4 byte length + header, in this case "1". In real case Descriptor + Prime
packer := &FakeBrokenPacker{}
carrier := Carrier{
Dialer: dialer,
Packer: packer,
}
waitGroup := sync.WaitGroup{}
waitGroup.Add(1)
go server.Read(connectionOutput)

// Act
var errConnecting error
go func() {
errConnecting = carrier.Connect(ctx, playerID, "", "")
waitGroup.Done()
}()

numberOfBytesWritten, errWrite := server.Write(serverResponse)
errClose := server.Close()

// Make sure we wait until the Connect and Write are done
waitGroup.Wait()

// Assert
Expect(connectionOutput).To(Equal([]byte{1, 0, 0, 0, fmt.Sprintf("%d", playerID)[0]}))
Expect(errConnecting).NotTo(HaveOccurred())
Expect(errWrite).NotTo(HaveOccurred())
Expect(numberOfBytesWritten).To(Equal(len(serverResponse)))
Expect(errClose).NotTo(HaveOccurred())
})
})
})
2 changes: 1 addition & 1 deletion pkg/ephemeral/io/feeder.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ func (f *AmphoraFeeder) feedAndRead(params []string, port string, ctx *CtxConfig
default:
return nil, fmt.Errorf("no output config is given, either %s, %s or %s must be defined", PlainText, SecretShare, AmphoraSecret)
}
err := f.carrier.Connect(ctx.Context, "localhost", port)
err := f.carrier.Connect(ctx.Context, ctx.Spdz.PlayerID, "localhost", port)
defer f.carrier.Close()
if err != nil {
return nil, err
Expand Down
7 changes: 4 additions & 3 deletions pkg/ephemeral/io/feeder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var _ = Describe("Feeder", func() {
conf = &CtxConfig{
Act: act,
Context: context.TODO(),
Spdz: &SPDZEngineTypedConfig{PlayerCount: 2},
}
})

Expand Down Expand Up @@ -211,7 +212,7 @@ type FakeCarrier struct {
isBulk bool
}

func (f *FakeCarrier) Connect(context.Context, string, string) error {
func (f *FakeCarrier) Connect(context.Context, int32, string, string) error {
return nil
}

Expand All @@ -232,7 +233,7 @@ type BrokenConnectFakeCarrier struct {
isBulk bool
}

func (f *BrokenConnectFakeCarrier) Connect(context.Context, string, string) error {
func (f *BrokenConnectFakeCarrier) Connect(context.Context, int32, string, string) error {
return errors.New("carrier connect error")
}

Expand All @@ -253,7 +254,7 @@ type BrokenSendFakeCarrier struct {
isBulk bool
}

func (f *BrokenSendFakeCarrier) Connect(context.Context, string, string) error {
func (f *BrokenSendFakeCarrier) Connect(context.Context, int32, string, string) error {
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/ephemeral/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,5 @@ func (c *Callbacker) sendEvent(name, topic string, e interface{}) {
},
}
c.pb.PublishWithBody(name, topic, event, c.playerParams.GameID)
c.logger.Debugf("Sending event %v to topic %s\n", event.Name, topic)
c.logger.Debugw("Sending event", "event", event, "topic", topic)
}
Loading

0 comments on commit c43d17c

Please sign in to comment.