Skip to content

Commit

Permalink
crypto/tls: add new X25519Kyber768Draft00 code point
Browse files Browse the repository at this point in the history
* Point tls.X25519Kyber768Draft00 to the new 0x6399 identifier while the
  old 0xfe31 identifier is available as tls.X25519Kyber768Draft00Old.
* Make sure that the kem.PrivateKey can always be mapped to the CurveID
  that was linked to it. This is needed since we now have two ID
  aliasing to the same scheme, and clients need to be able to detect
  whether the key share presented by the server actually matches the key
  share that the client originally sent.
* Update tests, add the new identifier and remove unnecessary code.

Link: https://mailarchive.ietf.org/arch/msg/tls/HAWpNpgptl--UZNSYuvsjB-Pc2k/
Link: https://datatracker.ietf.org/doc/draft-tls-westerbaan-xyber768d00/02/
  • Loading branch information
Lekensteyn authored and bwesterb committed Jul 12, 2023
1 parent 4c13101 commit a888ac3
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 75 deletions.
50 changes: 19 additions & 31 deletions src/crypto/tls/cfkem.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@
// To enable set CurvePreferences with the desired scheme as the first element:
//
// import (
// "github.com/cloudflare/circl/kem/tls"
// "github.com/cloudflare/circl/kem/hybrid"
// "crypto/tls"
//
// [...]
//
// config.CurvePreferences = []tls.CurveID{
// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(),
// tls.X25519Kyber768Draft00,
// tls.X25519,
// tls.P256,
// }
Expand All @@ -29,38 +28,27 @@ import (
"github.com/cloudflare/circl/kem/hybrid"
)

// Either ecdheParameters or kem.PrivateKey
// Either *ecdh.PrivateKey or *kemPrivateKey
type clientKeySharePrivate interface{}

type kemPrivateKey struct {
secretKey kem.PrivateKey
curveID CurveID
}

var (
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
X25519Kyber512Draft00 = CurveID(0xfe30)
X25519Kyber768Draft00 = CurveID(0x6399)
X25519Kyber768Draft00Old = CurveID(0xfe31)
P256Kyber768Draft00 = CurveID(0xfe32)
invalidCurveID = CurveID(0)
)

func kemSchemeKeyToCurveID(s kem.Scheme) CurveID {
switch s.Name() {
case "Kyber512-X25519":
return X25519Kyber512Draft00
case "Kyber768-X25519":
return X25519Kyber768Draft00
case "P256Kyber768Draft00":
return P256Kyber768Draft00
default:
return invalidCurveID
}
}

// Extract CurveID from clientKeySharePrivate
func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID {
switch v := ks.(type) {
case kem.PrivateKey:
ret := kemSchemeKeyToCurveID(v.Scheme())
if ret == invalidCurveID {
panic("cfkem: internal error: don't know CurveID for this KEM")
}
return ret
case *kemPrivateKey:
return v.curveID
case *ecdh.PrivateKey:
ret, ok := curveIDForCurve(v.Curve())
if !ok {
Expand All @@ -77,7 +65,7 @@ func curveIdToCirclScheme(id CurveID) kem.Scheme {
switch id {
case X25519Kyber512Draft00:
return hybrid.Kyber512X25519()
case X25519Kyber768Draft00:
case X25519Kyber768Draft00, X25519Kyber768Draft00Old:
return hybrid.Kyber768X25519()
case P256Kyber768Draft00:
return hybrid.P256Kyber768Draft00()
Expand All @@ -102,12 +90,12 @@ func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) (
}

// Generate a new keypair using randomness from rnd.
func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) (
kem.PublicKey, kem.PrivateKey, error) {
func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) (
kem.PublicKey, *kemPrivateKey, error) {
seed := make([]byte, scheme.SeedSize())
if _, err := io.ReadFull(rnd, seed); err != nil {
return nil, nil, err
}
pk, sk := scheme.DeriveKeyPair(seed)
return pk, sk, nil
return pk, &kemPrivateKey{sk, curveID}, nil
}
63 changes: 24 additions & 39 deletions src/crypto/tls/cfkem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,16 @@ import (
"context"
"fmt"
"testing"

"github.com/cloudflare/circl/kem"
"github.com/cloudflare/circl/kem/hybrid"
)

func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ,
clientTLS12, serverTLS12 bool) {
var clientSelectedKEX *CurveID
var retry bool

rsaCert := Certificate{
Certificate: [][]byte{testRSACertificate},
PrivateKey: testRSAPrivateKey,
}
serverCerts := []Certificate{rsaCert}

clientConfig := testConfig.Clone()
if clientPQ {
clientConfig.CurvePreferences = []CurveID{
kemSchemeKeyToCurveID(scheme),
X25519,
}
clientConfig.CurvePreferences = []CurveID{curveID, X25519}
}
clientCFEventHandler := func(ev CFEvent) {
switch e := ev.(type) {
Expand All @@ -44,15 +32,13 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,

serverConfig := testConfig.Clone()
if serverPQ {
serverConfig.CurvePreferences = []CurveID{
kemSchemeKeyToCurveID(scheme),
X25519,
}
serverConfig.CurvePreferences = []CurveID{curveID, X25519}
} else {
serverConfig.CurvePreferences = []CurveID{X25519}
}
if serverTLS12 {
serverConfig.MaxVersion = VersionTLS12
}
serverConfig.Certificates = serverCerts

c, s := localPipe(t)
done := make(chan error)
Expand All @@ -78,44 +64,43 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ,
var expectedRetry bool

if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 {
expectedKEX = kemSchemeKeyToCurveID(scheme)
expectedKEX = curveID
} else {
expectedKEX = X25519
}
if !clientTLS12 && clientPQ && !serverPQ {
expectedRetry = true
}

if expectedRetry != retry {
t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry)
}
if clientSelectedKEX == nil {
t.Error("No KEX happened?")
}

if *clientSelectedKEX != expectedKEX {
} else if *clientSelectedKEX != expectedKEX {
t.Errorf("failed to negotiate: expected %d, got %d",
expectedKEX, *clientSelectedKEX)
}
if expectedRetry != retry {
t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry)
}
}

func TestHybridKEX(t *testing.T) {
run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(),
run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) {
t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID),
serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) {
testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12)
testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12)
})
}
for _, scheme := range []kem.Scheme{
hybrid.Kyber512X25519(),
hybrid.Kyber768X25519(),
hybrid.P256Kyber768Draft00(),
for _, curveID := range []CurveID{
X25519Kyber512Draft00,
X25519Kyber768Draft00,
X25519Kyber768Draft00Old,
P256Kyber768Draft00,
} {
run(scheme, true, true, false, false)
run(scheme, true, false, false, false)
run(scheme, false, true, false, false)
run(scheme, true, true, true, false)
run(scheme, true, true, false, true)
run(scheme, true, true, true, true)
run(curveID, true, true, false, false)
run(curveID, true, false, false, false)
run(curveID, false, true, false, false)
run(curveID, true, true, true, false)
run(curveID, true, true, false, true)
run(curveID, true, true, true, true)
}
}
2 changes: 1 addition & 1 deletion src/crypto/tls/handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha

curveID := config.curvePreferences()[0]
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, config.rand())
pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand())
if err != nil {
return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w",
scheme.Name(), err)
Expand Down
7 changes: 3 additions & 4 deletions src/crypto/tls/handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ import (
"fmt"
"hash"
"time"

circlKem "github.com/cloudflare/circl/kem"
)

type clientHandshakeStateTLS13 struct {
Expand Down Expand Up @@ -382,7 +380,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share")
}
if scheme := curveIdToCirclScheme(curveID); scheme != nil {
pk, sk, err := generateKemKeyPair(scheme, c.config.rand())
pk, sk, err := generateKemKeyPair(scheme, curveID, c.config.rand())
if err != nil {
c.sendAlert(alertInternalError)
return fmt.Errorf("HRR generateKemKeyPair %s: %w",
Expand Down Expand Up @@ -610,7 +608,8 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
if err == nil {
sharedKey, _ = key.ECDH(peerKey)
}
} else if sk, ok := hs.keySharePrivate.(circlKem.PrivateKey); ok {
} else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok {
sk := key.secretKey
sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data)
if err != nil {
c.sendAlert(alertIllegalParameter)
Expand Down

0 comments on commit a888ac3

Please sign in to comment.