From a888ac35c2f9878dcd8d0c69e75aa6541e5b6b89 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 25 May 2023 19:00:56 +0200 Subject: [PATCH] crypto/tls: add new X25519Kyber768Draft00 code point * Point tls.X25519Kyber768Draft00 to the new 0x6399 identifier while the old 0xfe31 identifier is available as tls.X25519Kyber768Draft00Old. * Make sure that the kem.PrivateKey can always be mapped to the CurveID that was linked to it. This is needed since we now have two ID aliasing to the same scheme, and clients need to be able to detect whether the key share presented by the server actually matches the key share that the client originally sent. * Update tests, add the new identifier and remove unnecessary code. Link: https://mailarchive.ietf.org/arch/msg/tls/HAWpNpgptl--UZNSYuvsjB-Pc2k/ Link: https://datatracker.ietf.org/doc/draft-tls-westerbaan-xyber768d00/02/ --- src/crypto/tls/cfkem.go | 50 +++++++------------ src/crypto/tls/cfkem_test.go | 63 +++++++++--------------- src/crypto/tls/handshake_client.go | 2 +- src/crypto/tls/handshake_client_tls13.go | 7 ++- 4 files changed, 47 insertions(+), 75 deletions(-) diff --git a/src/crypto/tls/cfkem.go b/src/crypto/tls/cfkem.go index 083f2921b64..8d440e4c3c5 100644 --- a/src/crypto/tls/cfkem.go +++ b/src/crypto/tls/cfkem.go @@ -6,13 +6,12 @@ // To enable set CurvePreferences with the desired scheme as the first element: // // import ( -// "github.com/cloudflare/circl/kem/tls" -// "github.com/cloudflare/circl/kem/hybrid" +// "crypto/tls" // // [...] // // config.CurvePreferences = []tls.CurveID{ -// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(), +// tls.X25519Kyber768Draft00, // tls.X25519, // tls.P256, // } @@ -29,38 +28,27 @@ import ( "github.com/cloudflare/circl/kem/hybrid" ) -// Either ecdheParameters or kem.PrivateKey +// Either *ecdh.PrivateKey or *kemPrivateKey type clientKeySharePrivate interface{} +type kemPrivateKey struct { + secretKey kem.PrivateKey + curveID CurveID +} + var ( - X25519Kyber512Draft00 = CurveID(0xfe30) - X25519Kyber768Draft00 = CurveID(0xfe31) - P256Kyber768Draft00 = CurveID(0xfe32) - invalidCurveID = CurveID(0) + X25519Kyber512Draft00 = CurveID(0xfe30) + X25519Kyber768Draft00 = CurveID(0x6399) + X25519Kyber768Draft00Old = CurveID(0xfe31) + P256Kyber768Draft00 = CurveID(0xfe32) + invalidCurveID = CurveID(0) ) -func kemSchemeKeyToCurveID(s kem.Scheme) CurveID { - switch s.Name() { - case "Kyber512-X25519": - return X25519Kyber512Draft00 - case "Kyber768-X25519": - return X25519Kyber768Draft00 - case "P256Kyber768Draft00": - return P256Kyber768Draft00 - default: - return invalidCurveID - } -} - // Extract CurveID from clientKeySharePrivate func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { switch v := ks.(type) { - case kem.PrivateKey: - ret := kemSchemeKeyToCurveID(v.Scheme()) - if ret == invalidCurveID { - panic("cfkem: internal error: don't know CurveID for this KEM") - } - return ret + case *kemPrivateKey: + return v.curveID case *ecdh.PrivateKey: ret, ok := curveIDForCurve(v.Curve()) if !ok { @@ -77,7 +65,7 @@ func curveIdToCirclScheme(id CurveID) kem.Scheme { switch id { case X25519Kyber512Draft00: return hybrid.Kyber512X25519() - case X25519Kyber768Draft00: + case X25519Kyber768Draft00, X25519Kyber768Draft00Old: return hybrid.Kyber768X25519() case P256Kyber768Draft00: return hybrid.P256Kyber768Draft00() @@ -102,12 +90,12 @@ func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) ( } // Generate a new keypair using randomness from rnd. -func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) ( - kem.PublicKey, kem.PrivateKey, error) { +func generateKemKeyPair(scheme kem.Scheme, curveID CurveID, rnd io.Reader) ( + kem.PublicKey, *kemPrivateKey, error) { seed := make([]byte, scheme.SeedSize()) if _, err := io.ReadFull(rnd, seed); err != nil { return nil, nil, err } pk, sk := scheme.DeriveKeyPair(seed) - return pk, sk, nil + return pk, &kemPrivateKey{sk, curveID}, nil } diff --git a/src/crypto/tls/cfkem_test.go b/src/crypto/tls/cfkem_test.go index fb2156aa963..85da45ede8c 100644 --- a/src/crypto/tls/cfkem_test.go +++ b/src/crypto/tls/cfkem_test.go @@ -7,28 +7,16 @@ import ( "context" "fmt" "testing" - - "github.com/cloudflare/circl/kem" - "github.com/cloudflare/circl/kem/hybrid" ) -func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, +func testHybridKEX(t *testing.T, curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { var clientSelectedKEX *CurveID var retry bool - rsaCert := Certificate{ - Certificate: [][]byte{testRSACertificate}, - PrivateKey: testRSAPrivateKey, - } - serverCerts := []Certificate{rsaCert} - clientConfig := testConfig.Clone() if clientPQ { - clientConfig.CurvePreferences = []CurveID{ - kemSchemeKeyToCurveID(scheme), - X25519, - } + clientConfig.CurvePreferences = []CurveID{curveID, X25519} } clientCFEventHandler := func(ev CFEvent) { switch e := ev.(type) { @@ -44,15 +32,13 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, serverConfig := testConfig.Clone() if serverPQ { - serverConfig.CurvePreferences = []CurveID{ - kemSchemeKeyToCurveID(scheme), - X25519, - } + serverConfig.CurvePreferences = []CurveID{curveID, X25519} + } else { + serverConfig.CurvePreferences = []CurveID{X25519} } if serverTLS12 { serverConfig.MaxVersion = VersionTLS12 } - serverConfig.Certificates = serverCerts c, s := localPipe(t) done := make(chan error) @@ -78,7 +64,7 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, var expectedRetry bool if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 { - expectedKEX = kemSchemeKeyToCurveID(scheme) + expectedKEX = curveID } else { expectedKEX = X25519 } @@ -86,36 +72,35 @@ func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, expectedRetry = true } + if expectedRetry != retry { + t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) + } if clientSelectedKEX == nil { t.Error("No KEX happened?") - } - - if *clientSelectedKEX != expectedKEX { + } else if *clientSelectedKEX != expectedKEX { t.Errorf("failed to negotiate: expected %d, got %d", expectedKEX, *clientSelectedKEX) } - if expectedRetry != retry { - t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) - } } func TestHybridKEX(t *testing.T) { - run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { - t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(), + run := func(curveID CurveID, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { + t.Run(fmt.Sprintf("%#04x serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", uint16(curveID), serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) { - testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12) + testHybridKEX(t, curveID, clientPQ, serverPQ, clientTLS12, serverTLS12) }) } - for _, scheme := range []kem.Scheme{ - hybrid.Kyber512X25519(), - hybrid.Kyber768X25519(), - hybrid.P256Kyber768Draft00(), + for _, curveID := range []CurveID{ + X25519Kyber512Draft00, + X25519Kyber768Draft00, + X25519Kyber768Draft00Old, + P256Kyber768Draft00, } { - run(scheme, true, true, false, false) - run(scheme, true, false, false, false) - run(scheme, false, true, false, false) - run(scheme, true, true, true, false) - run(scheme, true, true, false, true) - run(scheme, true, true, true, true) + run(curveID, true, true, false, false) + run(curveID, true, false, false, false) + run(curveID, false, true, false, false) + run(curveID, true, true, true, false) + run(curveID, true, true, false, true) + run(curveID, true, true, true, true) } } diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index cfe1c25ec31..7c026f9284a 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -136,7 +136,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha curveID := config.curvePreferences()[0] if scheme := curveIdToCirclScheme(curveID); scheme != nil { - pk, sk, err := generateKemKeyPair(scheme, config.rand()) + pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand()) if err != nil { return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", scheme.Name(), err) diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index 6dd1627f08d..74780d5a01c 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -16,8 +16,6 @@ import ( "fmt" "hash" "time" - - circlKem "github.com/cloudflare/circl/kem" ) type clientHandshakeStateTLS13 struct { @@ -382,7 +380,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } if scheme := curveIdToCirclScheme(curveID); scheme != nil { - pk, sk, err := generateKemKeyPair(scheme, c.config.rand()) + pk, sk, err := generateKemKeyPair(scheme, curveID, c.config.rand()) if err != nil { c.sendAlert(alertInternalError) return fmt.Errorf("HRR generateKemKeyPair %s: %w", @@ -610,7 +608,8 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { if err == nil { sharedKey, _ = key.ECDH(peerKey) } - } else if sk, ok := hs.keySharePrivate.(circlKem.PrivateKey); ok { + } else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok { + sk := key.secretKey sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) if err != nil { c.sendAlert(alertIllegalParameter)