diff --git a/src/crypto/tls/cfkem.go b/src/crypto/tls/cfkem.go index 8d440e4c3c5..b300ce15c57 100644 --- a/src/crypto/tls/cfkem.go +++ b/src/crypto/tls/cfkem.go @@ -22,14 +22,14 @@ import ( "fmt" "io" - "crypto/ecdh" - "github.com/cloudflare/circl/kem" "github.com/cloudflare/circl/kem/hybrid" ) // Either *ecdh.PrivateKey or *kemPrivateKey -type clientKeySharePrivate interface{} +type singleClientKeySharePrivate interface{} + +type clientKeySharePrivate map[CurveID]singleClientKeySharePrivate type kemPrivateKey struct { secretKey kem.PrivateKey @@ -44,20 +44,9 @@ var ( invalidCurveID = CurveID(0) ) -// Extract CurveID from clientKeySharePrivate -func clientKeySharePrivateCurveID(ks clientKeySharePrivate) CurveID { - switch v := ks.(type) { - case *kemPrivateKey: - return v.curveID - case *ecdh.PrivateKey: - ret, ok := curveIDForCurve(v.Curve()) - if !ok { - panic("cfkem: internal error: unknown curve") - } - return ret - default: - panic("cfkem: internal error: unknown clientKeySharePrivate") - } +func singleClientKeySharePrivateFor(ks clientKeySharePrivate, group CurveID) singleClientKeySharePrivate { + ret, _ := ks[group] + return ret } // Returns scheme by CurveID if supported by Circl diff --git a/src/crypto/tls/cfkem_test.go b/src/crypto/tls/cfkem_test.go index 85da45ede8c..821cc674aff 100644 --- a/src/crypto/tls/cfkem_test.go +++ b/src/crypto/tls/cfkem_test.go @@ -104,3 +104,48 @@ func TestHybridKEX(t *testing.T) { run(curveID, true, true, true, true) } } + +func TestClientCurveGuess(t *testing.T) { + run := func(guess, clientPrefs, serverPrefs []CurveID) { + t.Run( + fmt.Sprintf("guess=%v clientPrefs=%v serverPrefs=%v", + guess, clientPrefs, serverPrefs), + func(t *testing.T) { + testClientCurveGuess(t, guess, clientPrefs, serverPrefs) + }) + } + both := []CurveID{X25519Kyber768Draft00, X25519} + run([]CurveID{}, []CurveID{X25519}, both) + run([]CurveID{X25519}, []CurveID{X25519}, both) + run([]CurveID{X25519Kyber768Draft00}, both, []CurveID{X25519}) + run(both, both, both) + run(both, both, []CurveID{X25519}) + run(both, both, []CurveID{X25519Kyber768Draft00}) +} + +func testClientCurveGuess(t *testing.T, guess, clientPrefs, serverPrefs []CurveID) { + clientConfig := testConfig.Clone() + serverConfig := testConfig.Clone() + serverConfig.CurvePreferences = serverPrefs + clientConfig.CurvePreferences = clientPrefs + clientConfig.ClientCurveGuess = guess + + c, s := localPipe(t) + done := make(chan error) + defer c.Close() + + go func() { + defer s.Close() + done <- Server(s, serverConfig).Handshake() + }() + + cli := Client(c, clientConfig) + clientErr := cli.HandshakeContext(context.Background()) + serverErr := <-done + if clientErr != nil { + t.Errorf("client error: %v", clientErr) + } + if serverErr != nil { + t.Errorf("server error: %v", serverErr) + } +} diff --git a/src/crypto/tls/common.go b/src/crypto/tls/common.go index ebd42f506ca..bc02ab755c2 100644 --- a/src/crypto/tls/common.go +++ b/src/crypto/tls/common.go @@ -837,6 +837,18 @@ type Config struct { // which is currently TLS 1.3. MaxVersion uint16 + // ClientCurveGuess contains the "curves" for which the client will create + // a keyshare in the initial ClientHello for TLS 1.3. If the client + // guesses incorrectly, and the server does not support or does not + // prefer those keyshares, then the server will return a HelloRetryRequest + // incurring an extra roundtrip. + // + // If empty, no keyshares will be included in the ClientHello. + // + // If nil (default), will send the single most preferred keyshare + // as configurable via CurvePreferences. + ClientCurveGuess []CurveID + // CurvePreferences contains the elliptic curves that will be used in // an ECDHE handshake, in preference order. If empty, the default will // be used. The client will use the first preference as the type for @@ -974,6 +986,7 @@ func (c *Config) Clone() *Config { MinVersion: c.MinVersion, MaxVersion: c.MaxVersion, CurvePreferences: c.CurvePreferences, + ClientCurveGuess: c.ClientCurveGuess, PQSignatureSchemesEnabled: c.PQSignatureSchemesEnabled, DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, Renegotiation: c.Renegotiation, diff --git a/src/crypto/tls/handshake_client.go b/src/crypto/tls/handshake_client.go index 14eedb6248c..b31c253e32b 100644 --- a/src/crypto/tls/handshake_client.go +++ b/src/crypto/tls/handshake_client.go @@ -134,7 +134,7 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha hello.supportedSignatureAlgorithms = testingOnlyForceClientHelloSignatureAlgorithms } - var secret clientKeySharePrivate + secret := make(clientKeySharePrivate) if hello.supportedVersions[0] == VersionTLS13 { // Reset the list of ciphers when the client only supports TLS 1.3. if len(hello.supportedVersions) == 1 { @@ -146,30 +146,74 @@ func (c *Conn) makeClientHello(minVersion uint16) (*clientHelloMsg, clientKeySha hello.cipherSuites = append(hello.cipherSuites, defaultCipherSuitesTLS13NoAES...) } - curveID := config.curvePreferences()[0] - if scheme := curveIdToCirclScheme(curveID); scheme != nil { - pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand()) - if err != nil { - return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", - scheme.Name(), err) - } - packedPk, err := pk.MarshalBinary() - if err != nil { - return nil, nil, fmt.Errorf("pack circl public key %s: %w", - scheme.Name(), err) + curveIDs := []CurveID{config.curvePreferences()[0]} + + if config.ClientCurveGuess != nil { + curveIDs = config.ClientCurveGuess + } + + hello.keyShares = make([]keyShare, 0, len(curveIDs)) + + // Check whether ClientCurveGuess is a subsequence of CurvePreferences + // as is required by RFC8446 ยง4.2.8 + offset := 0 + curvePreferences := config.curvePreferences() + found := 0 + CurveGuessCheck: + for _, curveID := range curveIDs { + for { + if offset == len(curvePreferences) { + break CurveGuessCheck + } + + if curvePreferences[offset] == curveID { + found++ + break + } + + offset++ } - hello.keyShares = []keyShare{{group: curveID, data: packedPk}} - secret = sk - } else { - if _, ok := curveForCurveID(curveID); !ok { - return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + if found != len(curveIDs) { + return nil, nil, errors.New("tls: ClientCurveGuess not a subsequence of CurvePreferences") + } + + for _, curveID := range curveIDs { + var ( + singleSecret interface{} + singleShare []byte + ) + + if _, ok := secret[curveID]; ok { + return nil, nil, errors.New("tls: ClientCurveGuess contains duplicate") } - key, err := generateECDHEKey(config.rand(), curveID) - if err != nil { - return nil, nil, err + + if scheme := curveIdToCirclScheme(curveID); scheme != nil { + pk, sk, err := generateKemKeyPair(scheme, curveID, config.rand()) + if err != nil { + return nil, nil, fmt.Errorf("generateKemKeyPair %s: %w", + scheme.Name(), err) + } + packedPk, err := pk.MarshalBinary() + if err != nil { + return nil, nil, fmt.Errorf("pack circl public key %s: %w", + scheme.Name(), err) + } + singleShare = packedPk + singleSecret = sk + } else { + if _, ok := curveForCurveID(curveID); !ok { + return nil, nil, errors.New("tls: CurvePreferences includes unsupported curve") + } + key, err := generateECDHEKey(config.rand(), curveID) + if err != nil { + return nil, nil, err + } + singleShare = key.PublicKey().Bytes() + singleSecret = key } - hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} - secret = key + hello.keyShares = append(hello.keyShares, keyShare{group: curveID, data: singleShare}) + secret[curveID] = singleSecret } hello.delegatedCredentialSupported = config.SupportDelegatedCredential diff --git a/src/crypto/tls/handshake_client_tls13.go b/src/crypto/tls/handshake_client_tls13.go index eac73fcc685..04a272c4a91 100644 --- a/src/crypto/tls/handshake_client_tls13.go +++ b/src/crypto/tls/handshake_client_tls13.go @@ -103,7 +103,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } // Consistency check on the presence of a keyShare and its parameters. - if hs.keySharePrivate == nil || len(hs.hello.keyShares) != 1 { + if hs.keySharePrivate == nil { return c.sendAlert(alertInternalError) } @@ -379,7 +379,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } - if clientKeySharePrivateCurveID(hs.keySharePrivate) == curveID { + if singleClientKeySharePrivateFor(hs.keySharePrivate, curveID) != nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: server sent an unnecessary HelloRetryRequest key_share") } @@ -396,7 +396,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { return fmt.Errorf("HRR pack circl public key %s: %w", scheme.Name(), err) } - hs.keySharePrivate = sk + hs.keySharePrivate = clientKeySharePrivate{curveID: sk} hello.keyShares = []keyShare{{group: curveID, data: packedPk}} } else { if _, ok := curveForCurveID(curveID); !ok { @@ -408,7 +408,7 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error { c.sendAlert(alertInternalError) return err } - hs.keySharePrivate = key + hs.keySharePrivate = clientKeySharePrivate{curveID: key} hello.keyShares = []keyShare{{group: curveID, data: key.PublicKey().Bytes()}} } } @@ -558,7 +558,7 @@ func (hs *clientHandshakeStateTLS13) processServerHello() error { c.sendAlert(alertIllegalParameter) return errors.New("tls: server did not send a key share") } - if hs.serverHello.serverShare.group != clientKeySharePrivateCurveID(hs.keySharePrivate) { + if singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group) == nil { c.sendAlert(alertIllegalParameter) return errors.New("tls: server selected unsupported group") } @@ -613,12 +613,16 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error { var sharedKey []byte var err error - if key, ok := hs.keySharePrivate.(*ecdh.PrivateKey); ok { + + // We already checked that ks isn't nil in processServerHello() + ks := singleClientKeySharePrivateFor(hs.keySharePrivate, hs.serverHello.serverShare.group) + + if key, ok := ks.(*ecdh.PrivateKey); ok { peerKey, err := key.Curve().NewPublicKey(hs.serverHello.serverShare.data) if err == nil { sharedKey, _ = key.ECDH(peerKey) } - } else if key, ok := hs.keySharePrivate.(*kemPrivateKey); ok { + } else if key, ok := ks.(*kemPrivateKey); ok { sk := key.secretKey sharedKey, err = sk.Scheme().Decapsulate(sk, hs.serverHello.serverShare.data) if err != nil { diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index 40c54cca58b..b3840f6c10f 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -248,7 +248,7 @@ func (m *clientHelloMsg) marshal() ([]byte, error) { }) }) } - if len(m.keyShares) > 0 { + if m.keyShares != nil { // RFC 8446, Section 4.2.8 exts.AddUint16(extensionKeyShare) exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { diff --git a/src/crypto/tls/tls_test.go b/src/crypto/tls/tls_test.go index 946a64c7c9e..9882cdc79a3 100644 --- a/src/crypto/tls/tls_test.go +++ b/src/crypto/tls/tls_test.go @@ -865,6 +865,8 @@ func TestCloneNonFuncFields(t *testing.T) { f.Set(reflect.ValueOf([]uint16{1, 2})) case "CurvePreferences": f.Set(reflect.ValueOf([]CurveID{CurveP256})) + case "ClientCurveGuess": + f.Set(reflect.ValueOf([]CurveID{CurveP256})) case "PQSignatureSchemesEnabled": f.Set(reflect.ValueOf(true)) case "Renegotiation":