diff --git a/src/circl/kem/hybrid/hybrid.go b/src/circl/kem/hybrid/hybrid.go index 44e312747d7..49ffcad2fb6 100644 --- a/src/circl/kem/hybrid/hybrid.go +++ b/src/circl/kem/hybrid/hybrid.go @@ -33,7 +33,6 @@ package hybrid import ( "errors" - "circl/hpke" "circl/internal/sha3" "circl/kem" "circl/kem/kyber/kyber1024" @@ -46,28 +45,37 @@ var ErrUninitialized = errors.New("public or private key not initialized") // Returns the hybrid KEM of Kyber512 and X25519. func Kyber512X25519() kem.Scheme { return kyber512X } +// Returns the hybrid KEM of Kyber768 and X25519. +func Kyber768X25519() kem.Scheme { return kyber768X } + // Returns the hybrid KEM of Kyber768 and X448. -func Kyber768X448() kem.Scheme { return kyber768X } +func Kyber768X448() kem.Scheme { return kyber768X4 } // Returns the hybrid KEM of Kyber1024 and X448. func Kyber1024X448() kem.Scheme { return kyber1024X } var kyber512X kem.Scheme = &scheme{ "Kyber512-X25519", + x25519Kem, kyber512.Scheme(), - hpke.KEM_X25519_HKDF_SHA256.Scheme(), } var kyber768X kem.Scheme = &scheme{ + "Kyber768-X25519", + x25519Kem, + kyber768.Scheme(), +} + +var kyber768X4 kem.Scheme = &scheme{ "Kyber768-X448", + x448Kem, kyber768.Scheme(), - hpke.KEM_X448_HKDF_SHA512.Scheme(), } var kyber1024X kem.Scheme = &scheme{ "Kyber1024-X448", + x448Kem, kyber1024.Scheme(), - hpke.KEM_X448_HKDF_SHA512.Scheme(), } // Public key of a hybrid KEM. diff --git a/src/circl/kem/hybrid/xkem.go b/src/circl/kem/hybrid/xkem.go new file mode 100644 index 00000000000..427f11300f9 --- /dev/null +++ b/src/circl/kem/hybrid/xkem.go @@ -0,0 +1,208 @@ +package hybrid + +import ( + "bytes" + cryptoRand "crypto/rand" + "crypto/subtle" + + "circl/dh/x25519" + "circl/dh/x448" + "circl/internal/sha3" + "circl/kem" +) + +type xPublicKey struct { + scheme *xScheme + key []byte +} +type xPrivateKey struct { + scheme *xScheme + key []byte +} +type xScheme struct { + size int +} + +var ( + x25519Kem = &xScheme{x25519.Size} + x448Kem = &xScheme{x448.Size} +) + +func (sch *xScheme) Name() string { + switch sch.size { + case x25519.Size: + return "X25519" + case x448.Size: + return "X448" + } + panic(kem.ErrTypeMismatch) +} + +func (sch *xScheme) PublicKeySize() int { return sch.size } +func (sch *xScheme) PrivateKeySize() int { return sch.size } +func (sch *xScheme) SeedSize() int { return sch.size } +func (sch *xScheme) SharedKeySize() int { return sch.size } +func (sch *xScheme) CiphertextSize() int { return sch.size } +func (sch *xScheme) EncapsulationSeedSize() int { return sch.size } + +func (sk *xPrivateKey) Scheme() kem.Scheme { return sk.scheme } +func (pk *xPublicKey) Scheme() kem.Scheme { return pk.scheme } + +func (sk *xPrivateKey) MarshalBinary() ([]byte, error) { + ret := make([]byte, len(sk.key)) + copy(ret, sk.key) + return ret, nil +} + +func (sk *xPrivateKey) Equal(other kem.PrivateKey) bool { + oth, ok := other.(*xPrivateKey) + if !ok { + return false + } + if oth.scheme != sk.scheme { + return false + } + return subtle.ConstantTimeCompare(oth.key, sk.key) == 1 +} + +func (sk *xPrivateKey) Public() kem.PublicKey { + pk := xPublicKey{sk.scheme, make([]byte, sk.scheme.size)} + switch sk.scheme.size { + case x25519.Size: + var sk2, pk2 x25519.Key + copy(sk2[:], sk.key) + x25519.KeyGen(&pk2, &sk2) + copy(pk.key, pk2[:]) + case x448.Size: + var sk2, pk2 x448.Key + copy(sk2[:], sk.key) + x448.KeyGen(&pk2, &sk2) + copy(pk.key, pk2[:]) + } + return &pk +} + +func (pk *xPublicKey) Equal(other kem.PublicKey) bool { + oth, ok := other.(*xPublicKey) + if !ok { + return false + } + if oth.scheme != pk.scheme { + return false + } + return bytes.Equal(oth.key, pk.key) +} + +func (pk *xPublicKey) MarshalBinary() ([]byte, error) { + ret := make([]byte, pk.scheme.size) + copy(ret, pk.key) + return ret, nil +} + +func (sch *xScheme) GenerateKeyPair() (kem.PublicKey, kem.PrivateKey, error) { + seed := make([]byte, sch.SeedSize()) + _, err := cryptoRand.Read(seed) + if err != nil { + return nil, nil, err + } + pk, sk := sch.DeriveKeyPair(seed) + return pk, sk, nil +} + +func (sch *xScheme) DeriveKeyPair(seed []byte) (kem.PublicKey, kem.PrivateKey) { + if len(seed) != sch.SeedSize() { + panic(kem.ErrSeedSize) + } + sk := xPrivateKey{scheme: sch, key: make([]byte, sch.size)} + + h := sha3.NewShake256() + _, _ = h.Write(seed) + _, _ = h.Read(sk.key) + + return sk.Public(), &sk +} + +func (sch *xScheme) Encapsulate(pk kem.PublicKey) (ct, ss []byte, err error) { + seed := make([]byte, sch.EncapsulationSeedSize()) + _, err = cryptoRand.Read(seed) + if err != nil { + return + } + return sch.EncapsulateDeterministically(pk, seed) +} + +func (pk *xPublicKey) X(sk *xPrivateKey) []byte { + if pk.scheme != sk.scheme { + panic(kem.ErrTypeMismatch) + } + + switch pk.scheme.size { + case x25519.Size: + var ss2, pk2, sk2 x25519.Key + copy(pk2[:], pk.key) + copy(sk2[:], sk.key) + x25519.Shared(&ss2, &sk2, &pk2) + return ss2[:] + case x448.Size: + var ss2, pk2, sk2 x448.Key + copy(pk2[:], pk.key) + copy(sk2[:], sk.key) + x448.Shared(&ss2, &sk2, &pk2) + return ss2[:] + } + panic(kem.ErrTypeMismatch) +} + +func (sch *xScheme) EncapsulateDeterministically( + pk kem.PublicKey, seed []byte, +) (ct, ss []byte, err error) { + if len(seed) != sch.EncapsulationSeedSize() { + return nil, nil, kem.ErrSeedSize + } + pub, ok := pk.(*xPublicKey) + if !ok || pub.scheme != sch { + return nil, nil, kem.ErrTypeMismatch + } + + pk2, sk2 := sch.DeriveKeyPair(seed) + ss = pub.X(sk2.(*xPrivateKey)) + ct, _ = pk2.MarshalBinary() + return +} + +func (sch *xScheme) Decapsulate(sk kem.PrivateKey, ct []byte) ([]byte, error) { + if len(ct) != sch.CiphertextSize() { + return nil, kem.ErrCiphertextSize + } + + priv, ok := sk.(*xPrivateKey) + if !ok || priv.scheme != sch { + return nil, kem.ErrTypeMismatch + } + + pk, err := sch.UnmarshalBinaryPublicKey(ct) + if err != nil { + return nil, err + } + + ss := pk.(*xPublicKey).X(priv) + return ss, nil +} + +func (sch *xScheme) UnmarshalBinaryPublicKey(buf []byte) (kem.PublicKey, error) { + if len(buf) != sch.PublicKeySize() { + return nil, kem.ErrPubKeySize + } + ret := xPublicKey{sch, make([]byte, sch.size)} + copy(ret.key, buf) + return &ret, nil +} + +func (sch *xScheme) UnmarshalBinaryPrivateKey(buf []byte) (kem.PrivateKey, error) { + if len(buf) != sch.PrivateKeySize() { + return nil, kem.ErrPrivKeySize + } + ret := xPrivateKey{sch, make([]byte, sch.size)} + copy(ret.key, buf) + return &ret, nil +} diff --git a/src/circl/kem/schemes/schemes.go b/src/circl/kem/schemes/schemes.go index 59b5d8906fa..6872d9c7137 100644 --- a/src/circl/kem/schemes/schemes.go +++ b/src/circl/kem/schemes/schemes.go @@ -41,6 +41,7 @@ var allSchemes = [...]kem.Scheme{ sikep503.Scheme(), sikep751.Scheme(), hybrid.Kyber512X25519(), + hybrid.Kyber768X25519(), hybrid.Kyber768X448(), hybrid.Kyber1024X448(), } diff --git a/src/circl/kem/schemes/schemes_test.go b/src/circl/kem/schemes/schemes_test.go index 78210549e12..36333a96f64 100644 --- a/src/circl/kem/schemes/schemes_test.go +++ b/src/circl/kem/schemes/schemes_test.go @@ -159,6 +159,7 @@ func Example_schemes() { // SIKEp503 // SIKEp751 // Kyber512-X25519 + // Kyber768-X25519 // Kyber768-X448 // Kyber1024-X448 } diff --git a/src/circl/pke/kyber/internal/common/ntt.go b/src/circl/pke/kyber/internal/common/ntt.go index b6c1f7405fd..94df2e1f00a 100644 --- a/src/circl/pke/kyber/internal/common/ntt.go +++ b/src/circl/pke/kyber/internal/common/ntt.go @@ -59,7 +59,7 @@ var InvNTTReductions = [...]int{ // their proper order by calling Detangle(). func (p *Poly) nttGeneric() { // Note that ℤ_q does not have a primitive 512ᵗʰ root of unity (as 512 - // does not divide into q) and so we cannot do a regular NTT. ℤ_q + // does not divide into q-1) and so we cannot do a regular NTT. ℤ_q // does have a primitive 256ᵗʰ root of unity, the smallest of which // is ζ := 17. // @@ -73,12 +73,12 @@ func (p *Poly) nttGeneric() { // ⋮ // = (x² - ζ)(x² + ζ)(x² - ζ⁶⁵)(x² + ζ⁶⁵) … (x² + ζ¹²⁷) // - // Note that the powers of ζ that appear (from th second line down) are + // Note that the powers of ζ that appear (from the second line down) are // in binary // - // 010000 110000 - // 001000 101000 011000 111000 - // 000100 100100 010100 110100 001100 101100 011100 111100 + // 0100000 1100000 + // 0010000 1010000 0110000 1110000 + // 0001000 1001000 0101000 1101000 0011000 1011000 0111000 1111000 // … // // That is: brv(2), brv(3), brv(4), …, where brv(x) denotes the 7-bit @@ -89,7 +89,7 @@ func (p *Poly) nttGeneric() { // // ℤ_q[x]/(x²⁵⁶+1) → ℤ_q[x]/(x²-ζ) x … x ℤ_q[x]/(x²+ζ¹²⁷) // - // given by a ↦ ( a mod x²-z, …, a mod x²+z¹²⁷ ) + // given by a ↦ ( a mod x²-ζ, …, a mod x²+ζ¹²⁷ ) // is an isomorphism, which is the "NTT". It can be efficiently computed by // // @@ -105,7 +105,7 @@ func (p *Poly) nttGeneric() { // // Each cross is a Cooley-Tukey butterfly: it's the map // - // (a, b) ↦ (a + ζ, a - ζ) + // (a, b) ↦ (a + ζb, a - ζb) // // for the appropriate power ζ for that column and row group. diff --git a/src/crypto/tls/cfkem.go b/src/crypto/tls/cfkem.go new file mode 100644 index 00000000000..afe66f91dfe --- /dev/null +++ b/src/crypto/tls/cfkem.go @@ -0,0 +1,102 @@ +// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code +// is governed by a BSD-style license that can be found in the LICENSE file. +// +// Glue to add Circl's (post-quantum) hybrid KEMs. +// +// To enable set CurvePreferences with the desired scheme as the first element: +// +// import ( +// "github.com/cloudflare/circl/kem/tls" +// "github.com/cloudflare/circl/kem/hybrid" +// +// [...] +// +// config.CurvePreferences = []tls.CurveID{ +// hybrid.X25519Kyber512Draft00().(tls.TLSScheme).TLSCurveID(), +// tls.X25519, +// tls.P256, +// } + +package tls + +import ( + "fmt" + "io" + + "circl/kem" + "circl/kem/hybrid" +) + +// Either ecdheParameters or kem.PrivateKey +type clientKeySharePrivate interface{} + +var ( + X25519Kyber512Draft00 = CurveID(0xfe30) + X25519Kyber768Draft00 = CurveID(0xfe31) + invalidCurveID = CurveID(0) +) + +func kemSchemeKeyToCurveID(s kem.Scheme) CurveID { + switch s.Name() { + case "Kyber512-X25519": + return X25519Kyber512Draft00 + case "Kyber768-X25519": + return X25519Kyber768Draft00 + default: + return invalidCurveID + } +} + +// Extract CurveID from clientKeySharePrivate +func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { + switch v := ks.(type) { + case kem.PrivateKey: + ret := kemSchemeKeyToCurveID(v.Scheme()) + if ret == invalidCurveID { + panic("cfkem: internal error: don't know CurveID for this KEM") + } + return ret + case ecdheParameters: + return v.CurveID() + default: + panic("cfkem: internal error: unknown clientKeySharePrivate") + } +} + +// Returns scheme by CurveID if supported by Circl +func curveIdToCirclScheme(id CurveID) kem.Scheme { + switch id { + case X25519Kyber512Draft00: + return hybrid.Kyber512X25519() + case X25519Kyber768Draft00: + return hybrid.Kyber768X25519() + } + return nil +} + +// Generate a new shared secret and encapsulates it for the packed +// public key in ppk using randomness from rnd. +func encapsulateForKem(scheme kem.Scheme, rnd io.Reader, ppk []byte) ( + ct, ss []byte, alert alert, err error) { + pk, err := scheme.UnmarshalBinaryPublicKey(ppk) + if err != nil { + return nil, nil, alertIllegalParameter, fmt.Errorf("unpack pk: %w", err) + } + seed := make([]byte, scheme.EncapsulationSeedSize()) + if _, err := io.ReadFull(rnd, seed); err != nil { + return nil, nil, alertInternalError, fmt.Errorf("random: %w", err) + } + ct, ss, err = scheme.EncapsulateDeterministically(pk, seed) + return ct, ss, alertIllegalParameter, err +} + +// Generate a new keypair using randomness from rnd. +func generateKemKeyPair(scheme kem.Scheme, rnd io.Reader) ( + kem.PublicKey, kem.PrivateKey, error) { + seed := make([]byte, scheme.SeedSize()) + if _, err := io.ReadFull(rnd, seed); err != nil { + return nil, nil, err + } + pk, sk := scheme.DeriveKeyPair(seed) + return pk, sk, nil +} diff --git a/src/crypto/tls/cfkem_test.go b/src/crypto/tls/cfkem_test.go new file mode 100644 index 00000000000..033acaca55f --- /dev/null +++ b/src/crypto/tls/cfkem_test.go @@ -0,0 +1,118 @@ +// Copyright 2022 Cloudflare, Inc. All rights reserved. Use of this source code +// is governed by a BSD-style license that can be found in the LICENSE file. + +package tls + +import ( + "fmt" + "testing" + + "circl/kem" + "circl/kem/hybrid" +) + +func testHybridKEX(t *testing.T, scheme kem.Scheme, clientPQ, serverPQ, + clientTLS12, serverTLS12 bool) { + var clientSelectedKEX *CurveID + var retry bool + + rsaCert := Certificate{ + Certificate: [][]byte{testRSACertificate}, + PrivateKey: testRSAPrivateKey, + } + serverCerts := []Certificate{rsaCert} + + clientConfig := testConfig.Clone() + if clientPQ { + clientConfig.CurvePreferences = []CurveID{ + kemSchemeKeyToCurveID(scheme), + X25519, + } + } + clientConfig.CFEventHandler = func(ev CFEvent) { + switch e := ev.(type) { + case CFEventTLSNegotiatedNamedKEX: + clientSelectedKEX = &e.KEX + case CFEventTLS13HRR: + retry = true + } + } + if clientTLS12 { + clientConfig.MaxVersion = VersionTLS12 + } + + serverConfig := testConfig.Clone() + if serverPQ { + serverConfig.CurvePreferences = []CurveID{ + kemSchemeKeyToCurveID(scheme), + X25519, + } + } + if serverTLS12 { + serverConfig.MaxVersion = VersionTLS12 + } + serverConfig.Certificates = serverCerts + + c, s := localPipe(t) + done := make(chan error) + defer c.Close() + + go func() { + defer s.Close() + done <- Server(s, serverConfig).Handshake() + }() + + cli := Client(c, clientConfig) + clientErr := cli.Handshake() + serverErr := <-done + if clientErr != nil { + t.Errorf("client error: %s", clientErr) + } + if serverErr != nil { + t.Errorf("server error: %s", serverErr) + } + + var expectedKEX CurveID + var expectedRetry bool + + if clientPQ && serverPQ && !clientTLS12 && !serverTLS12 { + expectedKEX = kemSchemeKeyToCurveID(scheme) + } else { + expectedKEX = X25519 + } + if !clientTLS12 && clientPQ && !serverPQ { + expectedRetry = true + } + + if clientSelectedKEX == nil { + t.Error("No KEX happened?") + } + + if *clientSelectedKEX != expectedKEX { + t.Errorf("failed to negotiate: expected %d, got %d", + expectedKEX, *clientSelectedKEX) + } + if expectedRetry != retry { + t.Errorf("Expected retry=%v, got retry=%v", expectedRetry, retry) + } +} + +func TestHybridKEX(t *testing.T) { + run := func(scheme kem.Scheme, clientPQ, serverPQ, clientTLS12, serverTLS12 bool) { + t.Run(fmt.Sprintf("%s serverPQ:%v clientPQ:%v serverTLS12:%v clientTLS12:%v", scheme.Name(), + serverPQ, clientPQ, serverTLS12, clientTLS12), func(t *testing.T) { + testHybridKEX(t, scheme, clientPQ, serverPQ, clientTLS12, serverTLS12) + }) + } + for _, scheme := range []kem.Scheme{ + hybrid.Kyber512X25519(), + hybrid.Kyber768X25519(), + } { + run(scheme, true, true, false, false) + run(scheme, true, false, false, false) + run(scheme, false, true, false, false) + run(scheme, true, true, true, false) + run(scheme, true, true, false, true) + run(scheme, true, true, true, true) + } +} diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index fa75eefa194..fc4022edb45 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -36,7 +36,7 @@ type clientHandshakeState struct { session *ClientSessionState } -func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, ecdheParameters, error) { +func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySharePrivate, error) { config := c.config if len(config.ServerName) == 0 && !config.InsecureSkipVerify { return nil, nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") @@ -122,7 +122,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, ecdheParamet hello.supportedSignatureAlgorithms = config.supportedSignatureAlgorithms() } - var params ecdheParameters + var secret clientKeySharePrivate if hello.supportedVersions[0] == VersionTLS13 { if hasAESGCMHardwareSupport { hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13...) @@ -131,19 +131,36 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, ecdheParamet } curveID := config.curvePreferences()[0] - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err = generateECDHEParameters(config.rand(), curveID) - if err != nil { - return nil, nil, err + if scheme := curveIdToCirclScheme(curveID); scheme != nil { + pk, sk, err := generateKemKeyPair(scheme, config.rand()) + if err != nil { + return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", + scheme.Name(), err) + } + packedPk, err := pk.MarshalBinary() + if err != nil { + return nil, nil, fmt.Errorf("pack circl public key %s: %w", + scheme.Name(), err) + } + hello.keyShares = []keyShare{{group: curveID, data: packedPk}} + secret = sk + } else { + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + secret = params } - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} + hello.delegatedCredentialSupported = config.SupportDelegatedCredential hello.supportedSignatureAlgorithmsDC = supportedSignatureAlgorithmsDC } - return hello, params, nil + return hello, secret, nil } func (c *Conn) clientHandshake(ctx context.Context) (err error) { @@ -230,16 +247,16 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { if c.vers == VersionTLS13 { hs := &clientHandshakeStateTLS13{ - c: c, - ctx: ctx, - serverHello: serverHello, - hello: hello, - helloInner: helloInner, - ecdheParams: ecdheParams, - session: session, - earlySecret: earlySecret, - binderKey: binderKey, - hsTimings: hsTimings, + c: c, + ctx: ctx, + serverHello: serverHello, + hello: hello, + helloInner: helloInner, + keySharePrivate: ecdheParams, + session: session, + earlySecret: earlySecret, + binderKey: binderKey, + hsTimings: hsTimings, } // In TLS 1.3, session tickets are delivered after the handshake. @@ -565,6 +582,12 @@ func (hs *clientHandshakeState) doFullHandshake() error { return err } + if eccKex, ok := keyAgreement.(*ecdheKeyAgreement); ok { + c.handleCFEvent(CFEventTLSNegotiatedNamedKEX{ + KEX: eccKex.params.CurveID(), + }) + } + msg, err = c.readHandshake() if err != nil { return err diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index b5ae137f206..d35076c0b0d 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -16,19 +16,22 @@ import ( "hash" "sync/atomic" "time" + + circlKem "circl/kem" ) type clientHandshakeStateTLS13 struct { - c *Conn - ctx context.Context - serverHello *serverHelloMsg - hello *clientHelloMsg - helloInner *clientHelloMsg - ecdheParams ecdheParameters - - session *ClientSessionState - earlySecret []byte - binderKey []byte + c *Conn + ctx context.Context + serverHello *serverHelloMsg + hello *clientHelloMsg + helloInner *clientHelloMsg + keySharePrivate clientKeySharePrivate + + session *ClientSessionState + earlySecret []byte + binderKey []byte + selectedGroup CurveID certReq *certificateRequestMsgTLS13 usingPSK bool @@ -98,7 +101,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // Consistency check on the presence of a keyShare and its parameters. - if hs.ecdheParams == nil || len(hs.hello.keyShares) != 1 { + if hs.keySharePrivate == nil || len(hs.hello.keyShares) != 1 { return c.sendAlert(alertInternalError) } @@ -270,6 +273,8 @@ func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error { func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c := hs.c + c.handleCFEvent(CFEventTLS13HRR{}) + // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. (The idea is that the server might offload transcript // storage to the client in the cookie.) See RFC 8446, Section 4.4.1. @@ -351,21 +356,38 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } - if hs.ecdheParams.CurveID() == curveID { + if clientKeySharePrivateCurveID(hs.keySharePrivate) == curveID { c.sendAlert(alertIllegalParameter) return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } - if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { - c.sendAlert(alertInternalError) - return errors.New("tls: CurvePreferences includes unsupported curve") - } - params, err := generateECDHEParameters(c.config.rand(), curveID) - if err != nil { - c.sendAlert(alertInternalError) - return err + if scheme := curveIdToCirclScheme(curveID); scheme != nil { + pk, sk, err := generateKemKeyPair(scheme, c.config.rand()) + if err != nil { + c.sendAlert(alertInternalError) + return fmt.Errorf("HRR generateKemKeyPair %s: %w", + scheme.Name(), err) + } + packedPk, err := pk.MarshalBinary() + if err != nil { + c.sendAlert(alertInternalError) + return fmt.Errorf("HRR pack circl public key %s: %w", + scheme.Name(), err) + } + hs.keySharePrivate = sk + hello.keyShares = []keyShare{{group: curveID, data: packedPk}} + } else { + if _, ok := curveForCurveID(curveID); curveID != X25519 && !ok { + c.sendAlert(alertInternalError) + return errors.New("tls: CurvePreferences includes unsupported curve") + } + params, err := generateECDHEParameters(c.config.rand(), curveID) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.keySharePrivate = params + hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} } - hs.ecdheParams = params - hello.keyShares = []keyShare{{group: curveID, data: params.PublicKey()}} } hello.raw = nil @@ -491,11 +513,15 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server did not send a key share") } - if hs.serverHello.serverShare.group != hs.ecdheParams.CurveID() { + if hs.serverHello.serverShare.group != clientKeySharePrivateCurveID(hs.keySharePrivate) { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } + c.handleCFEvent(CFEventTLSNegotiatedNamedKEX{ + KEX: hs.serverHello.serverShare.group, + }) + if !hs.serverHello.selectedIdentityPresent { return nil } @@ -539,10 +565,21 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { c := hs.c - sharedKey := hs.ecdheParams.SharedKey(hs.serverHello.serverShare.data) + var sharedKey []byte + if params, ok := hs.keySharePrivate.(ecdheParameters); ok { + sharedKey = params.SharedKey(hs.serverHello.serverShare.data) + } else if sk, ok := hs.keySharePrivate.(circlKem.PrivateKey); ok { + var err error + sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) + if err != nil { + c.sendAlert(alertIllegalParameter) + return fmt.Errorf("%s decaps: %w", sk.Scheme().Name(), err) + } + } + if sharedKey == nil { c.sendAlert(alertIllegalParameter) - return errors.New("tls: invalid server key share") + return fmt.Errorf("tls: invalid server key share") } earlySecret := hs.earlySecret diff --git a/src/crypto/tls/handshake_server.go b/src/crypto/tls/handshake_server.go index e4f75b1cb92..7ccd5ab039a 100644 --- a/src/crypto/tls/handshake_server.go +++ b/src/crypto/tls/handshake_server.go @@ -638,6 +638,11 @@ func (hs *serverHandshakeState) doFullHandshake() error { c.sendAlert(alertHandshakeFailure) return err } + if eccKex, ok := keyAgreement.(*ecdheKeyAgreement); ok { + c.handleCFEvent(CFEventTLSNegotiatedNamedKEX{ + KEX: eccKex.params.CurveID(), + }) + } hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) if err := c.config.writeKeyLog(keyLogLabelTLS12, hs.clientHello.random, hs.masterSecret); err != nil { c.sendAlert(alertInternalError) diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index 378bb17c9a0..904aa61c460 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -34,6 +34,7 @@ type serverHandshakeStateTLS13 struct { suite *cipherSuiteTLS13 cert *Certificate sigAlg SignatureScheme + selectedGroup CurveID earlySecret []byte sharedKey []byte handshakeSecret []byte @@ -280,23 +281,36 @@ GroupSelection: clientKeyShare = &hs.clientHello.keyShares[0] } - if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && !ok { + if _, ok := curveForCurveID(selectedGroup); selectedGroup != X25519 && curveIdToCirclScheme(selectedGroup) == nil && !ok { c.sendAlert(alertInternalError) return errors.New("tls: CurvePreferences includes unsupported curve") } - params, err := generateECDHEParameters(c.config.rand(), selectedGroup) - if err != nil { - c.sendAlert(alertInternalError) - return err + if kem := curveIdToCirclScheme(selectedGroup); kem != nil { + ct, ss, alert, err := encapsulateForKem(kem, c.config.rand(), clientKeyShare.data) + if err != nil { + c.sendAlert(alert) + return fmt.Errorf("%s encap: %w", kem.Name(), err) + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: ct} + hs.sharedKey = ss + } else { + params, err := generateECDHEParameters(c.config.rand(), selectedGroup) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} + hs.sharedKey = params.SharedKey(clientKeyShare.data) } - hs.hello.serverShare = keyShare{group: selectedGroup, data: params.PublicKey()} - hs.sharedKey = params.SharedKey(clientKeyShare.data) if hs.sharedKey == nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: invalid client key share") } c.serverName = hs.clientHello.serverName + c.handleCFEvent(CFEventTLSNegotiatedNamedKEX{ + KEX: selectedGroup, + }) hs.hsTimings.ProcessClientHello = hs.hsTimings.elapsedTime() @@ -524,6 +538,8 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error { func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error { c := hs.c + c.handleCFEvent(CFEventTLS13HRR{}) + // The first ClientHello gets double-hashed into the transcript upon a // HelloRetryRequest. See RFC 8446, Section 4.4.1. hs.transcript.Write(hs.clientHello.marshal()) diff --git a/src/crypto/tls/key_agreement.go b/src/crypto/tls/key_agreement.go index 630e3df5ae0..85789c9736a 100644 --- a/src/crypto/tls/key_agreement.go +++ b/src/crypto/tls/key_agreement.go @@ -168,7 +168,7 @@ type ecdheKeyAgreement struct { func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { var curveID CurveID for _, c := range clientHello.supportedCurves { - if config.supportsCurve(c) { + if config.supportsCurve(c) && curveIdToCirclScheme(c) == nil { curveID = c break } diff --git a/src/crypto/tls/tls_cf.go b/src/crypto/tls/tls_cf.go index b7e36dc8e1e..c60ee2862d7 100644 --- a/src/crypto/tls/tls_cf.go +++ b/src/crypto/tls/tls_cf.go @@ -218,3 +218,24 @@ type CFEventECHPublicNameMismatch struct{} func (e CFEventECHPublicNameMismatch) Name() string { return "ech public name does not match outer sni" } + +// For backwards compatibility. +type CFEventTLS13NegotiatedKEX = CFEventTLSNegotiatedNamedKEX + +// CFEventTLSNegotiatedNamedKEX is emitted when a key agreement mechanism has been +// established that uses a named group. This includes all key agreements +// in TLSv1.3, but excludes RSA and DH in TLS 1.2 and earlier. +type CFEventTLSNegotiatedNamedKEX struct { + KEX CurveID +} + +func (e CFEventTLSNegotiatedNamedKEX) Name() string { + return "CFEventTLSNegotiatedNamedKEX" +} + +// CFEventTLS13HRR is emitted when a HRR is sent or received +type CFEventTLS13HRR struct{} + +func (e CFEventTLS13HRR) Name() string { + return "CFEventTLS13HRR" +}