From fd0be4bc2ef611dc4dc0512b72f398d5b66163f9 Mon Sep 17 00:00:00 2001 From: Jonathan Hoyland Date: Thu, 10 Aug 2023 17:47:47 +0100 Subject: [PATCH] Add mTLS flag support --- src/crypto/tls/handshake_messages.go | 12 +++++++++--- src/crypto/tls/handshake_server_tls13.go | 18 +++++++++++++++++- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/crypto/tls/handshake_messages.go b/src/crypto/tls/handshake_messages.go index e5d70b5ef6b..8d17b955213 100644 --- a/src/crypto/tls/handshake_messages.go +++ b/src/crypto/tls/handshake_messages.go @@ -236,7 +236,9 @@ func (m *clientHelloMsg) marshal() ([]byte, error) { if len(m.tlsFlags) > 0 { exts.AddUint16(extensionTLSFlags) exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) { - exts.AddBytes(m.tlsFlags) + exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) { + exts.AddBytes(m.tlsFlags) + }) }) } if len(m.cookie) > 0 { @@ -568,9 +570,13 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool { m.supportedVersions = append(m.supportedVersions, vers) } case extensionTLSFlags: - for !extData.Empty() { + var flagsList cryptobyte.String + if !extData.ReadUint8LengthPrefixed(&flagsList) || flagsList.Empty() { + return false + } + for !flagsList.Empty() { var flagByte uint8 - if !extData.ReadUint8(&flagByte) { + if !flagsList.ReadUint8(&flagByte) { return false } m.tlsFlags = append(m.tlsFlags, flagByte) diff --git a/src/crypto/tls/handshake_server_tls13.go b/src/crypto/tls/handshake_server_tls13.go index d27f55a0af0..9390f90756f 100644 --- a/src/crypto/tls/handshake_server_tls13.go +++ b/src/crypto/tls/handshake_server_tls13.go @@ -314,7 +314,18 @@ GroupSelection: return errors.New("tls: invalid client key share") } if len(hs.clientHello.tlsFlags) != 0 { - tlsFlags, err := decodeFlags(hs.clientHello.tlsFlags) + supportedFlags, err := encodeFlags(hs.c.config.TLSFlagsSupported) + if err != nil { + return errors.New("tls: invalid server flags") + } + var mutuallySupportedFlags []byte + for i, sFB := range supportedFlags { + if i >= len(hs.clientHello.tlsFlags) { + break + } + mutuallySupportedFlags = append(mutuallySupportedFlags, hs.clientHello.tlsFlags[i]&sFB) + } + tlsFlags, err := decodeFlags(mutuallySupportedFlags) if err == nil { hs.tlsFlags = tlsFlags } @@ -838,6 +849,11 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error { } func (hs *serverHandshakeStateTLS13) requestClientCert() bool { + for _, flag := range hs.tlsFlags { + if flag == FlagSupportMTLS { + return true + } + } return hs.c.config.ClientAuth >= RequestClientCert && !hs.usingPSK }