Skip to content

Commit

Permalink
expose KeyPair.Public() key
Browse files Browse the repository at this point in the history
  • Loading branch information
ekoby committed Feb 14, 2020
1 parent 3ba9bbc commit 8f7d94a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 21 deletions.
28 changes: 17 additions & 11 deletions kx/kx.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ var cryptoError = errors.New("crypto error")
var notImplemented = errors.New("not implemented")

type KeyPair struct {
pk []byte
sk []byte
pk [SessionKeyBytes]byte
sk [PublicKeyBytes]byte
}

func NewKeyPair() (*KeyPair, error) {
Expand All @@ -39,24 +39,30 @@ func newKeyPairFromSeed(seed []byte) (*KeyPair, error) {

hash, _ := blake2b.New(SecretKeyBytes, nil)
hash.Write(seed)
kp.sk = hash.Sum(nil)

if len(kp.sk) != SecretKeyBytes {
sk := hash.Sum(nil)
if len(sk) != SecretKeyBytes {
return nil, cryptoError
}
copy(kp.sk[:], sk)

kp.pk, err = curve25519.X25519(kp.sk, curve25519.Basepoint)
pk, err := curve25519.X25519(kp.sk[:], curve25519.Basepoint)
if err != nil {
return nil, err
}
if len(kp.pk) != PublicKeyBytes {
if len(pk) != PublicKeyBytes {
return nil, cryptoError
}
copy(kp.pk[:], pk)

return kp, nil
}

func (pair *KeyPair) Public() []byte {
return pair.pk[:]
}

func (pair *KeyPair) ClientSessionKeys(server_pk []byte) (rx []byte, tx []byte, err error) {
q, err := curve25519.X25519(pair.sk, server_pk)
q, err := curve25519.X25519(pair.sk[:], server_pk)
if err != nil {
return nil, nil, err
}
Expand All @@ -66,7 +72,7 @@ func (pair *KeyPair) ClientSessionKeys(server_pk []byte) (rx []byte, tx []byte,
return nil, nil, err
}

for _, b := range [][]byte{q, pair.pk, server_pk} {
for _, b := range [][]byte{q, pair.Public(), server_pk} {
if _, err = h.Write(b); err != nil {
return nil, nil, err
}
Expand All @@ -80,7 +86,7 @@ func (pair *KeyPair) ClientSessionKeys(server_pk []byte) (rx []byte, tx []byte,

func (pair *KeyPair) ServerSessionKeys(client_pk []byte) (rx []byte, tx []byte, err error) {

q, err := curve25519.X25519(pair.sk, client_pk)
q, err := curve25519.X25519(pair.sk[:], client_pk)
if err != nil {
return nil, nil, err
}
Expand All @@ -90,7 +96,7 @@ func (pair *KeyPair) ServerSessionKeys(client_pk []byte) (rx []byte, tx []byte,
return nil, nil, err
}

for _, b := range [][]byte{q, client_pk, pair.pk} {
for _, b := range [][]byte{q, client_pk, pair.Public()} {
if _, err = h.Write(b); err != nil {
return nil, nil, err
}
Expand Down
20 changes: 10 additions & 10 deletions kx/kx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ func seedIncrement(s []byte) []byte {
func TestNewKeyPair(t *testing.T) {
pk, _ := hex.DecodeString("0e0216223f147143d32615a91189c288c1728cba3cc5f9f621b1026e03d83129")
sk, _ := hex.DecodeString("cb2f5160fc1f7e05a55ef49d340b48da2e5a78099d53393351cd579dd42503d6")
kp := &KeyPair{}
copy(kp.pk[:], pk)
copy(kp.sk[:], sk)

type args struct {
seed []byte
Expand All @@ -45,12 +48,9 @@ func TestNewKeyPair(t *testing.T) {
wantErr bool
}{
{
name: "pre-seeded key",
args: args{seed: seed},
want: &KeyPair{
pk: pk,
sk: sk,
},
name: "pre-seeded key",
args: args{seed: seed},
want: kp,
wantErr: false,
},
}
Expand Down Expand Up @@ -84,7 +84,7 @@ func TestKeyExchange_Seeded(t *testing.T) {
clt_rx, _ := hex.DecodeString("749519c68059bce69f7cfcc7b387a3de1a1e8237d110991323bf62870115731a")
clt_tx, _ := hex.DecodeString("62c8f4fa81800abd0577d99918d129b65deb789af8c8351f391feb0cbf238604")

client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.pk)
client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.Public())
if err != nil {
t.Errorf("ClientSessionKeys: error = %v", err)
return
Expand All @@ -97,7 +97,7 @@ func TestKeyExchange_Seeded(t *testing.T) {
t.Errorf("ClientSessionKeys(): TX got = %v, want %v", client_tx, clt_tx)
}

server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.pk)
server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.Public())
if err != nil {
t.Errorf("ServerSessionKeys: error = %v", err)
return
Expand All @@ -124,13 +124,13 @@ func TestKeyExchange(t *testing.T) {
return
}

client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.pk)
client_rx, client_tx, err := client_pair.ClientSessionKeys(server_pair.Public())
if err != nil {
t.Errorf("ClientSessionKeys: error = %v", err)
return
}

server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.pk)
server_rx, server_tx, err := server_pair.ServerSessionKeys(client_pair.Public())
if err != nil {
t.Errorf("ServerSessionKeys: error = %v", err)
return
Expand Down

0 comments on commit 8f7d94a

Please sign in to comment.