Skip to content
This repository has been archived by the owner on Jan 31, 2024. It is now read-only.

Commit

Permalink
make it possible to call ConnectionState during the handshake
Browse files Browse the repository at this point in the history
  • Loading branch information
marten-seemann committed Mar 18, 2023
1 parent 9b58331 commit 8fcfb95
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 28 deletions.
27 changes: 18 additions & 9 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ type Conn struct {
used0RTT bool

tmp [16]byte

connStateMutex sync.Mutex
connState ConnectionStateWith0RTT
}

// Access to net.Conn methods.
Expand Down Expand Up @@ -1566,19 +1569,16 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) {

// ConnectionState returns basic TLS details about the connection.
func (c *Conn) ConnectionState() ConnectionState {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return c.connectionStateLocked()
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
return c.connState.ConnectionState
}

// ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection.
func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT {
c.handshakeMutex.Lock()
defer c.handshakeMutex.Unlock()
return ConnectionStateWith0RTT{
ConnectionState: c.connectionStateLocked(),
Used0RTT: c.used0RTT,
}
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
return c.connState
}

func (c *Conn) connectionStateLocked() ConnectionState {
Expand Down Expand Up @@ -1609,6 +1609,15 @@ func (c *Conn) connectionStateLocked() ConnectionState {
return toConnectionState(state)
}

func (c *Conn) updateConnectionState() {
c.connStateMutex.Lock()
defer c.connStateMutex.Unlock()
c.connState = ConnectionStateWith0RTT{
Used0RTT: c.used0RTT,
ConnectionState: c.connectionStateLocked(),
}
}

// OCSPResponse returns the stapled OCSP response from the TLS server, if
// any. (Only valid for client connections.)
func (c *Conn) OCSPResponse() []byte {
Expand Down
1 change: 1 addition & 0 deletions handshake_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session))
}

c.updateConnectionState()
return nil
}

Expand Down
4 changes: 3 additions & 1 deletion handshake_client_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if err := hs.processServerHello(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.sendDummyChangeCipherSpec(); err != nil {
return err
}
Expand All @@ -99,6 +100,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
if err := hs.readServerCertificate(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.readServerFinished(); err != nil {
return err
}
Expand All @@ -113,7 +115,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
}

c.isHandshakeComplete.Store(true)

c.updateConnectionState()
return nil
}

Expand Down
1 change: 1 addition & 0 deletions handshake_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ func (hs *serverHandshakeState) handshake() error {
c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
c.isHandshakeComplete.Store(true)

c.updateConnectionState()
return nil
}

Expand Down
1 change: 1 addition & 0 deletions handshake_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1507,6 +1507,7 @@ func TestSNIGivenOnFailure(t *testing.T) {
t.Error("No error reported from server")
}

hs.c.updateConnectionState()
cs := hs.c.ConnectionState()
if cs.HandshakeComplete {
t.Error("Handshake registered as complete")
Expand Down
4 changes: 3 additions & 1 deletion handshake_server_tls13.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
if err := hs.checkForResumption(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.pickCertificate(); err != nil {
return err
}
Expand All @@ -78,12 +79,13 @@ func (hs *serverHandshakeStateTLS13) handshake() error {
if err := hs.readClientCertificate(); err != nil {
return err
}
c.updateConnectionState()
if err := hs.readClientFinished(); err != nil {
return err
}

c.isHandshakeComplete.Store(true)

c.updateConnectionState()
return nil
}

Expand Down
70 changes: 53 additions & 17 deletions record_layer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,26 @@ func TestAlternativeRecordLayer(t *testing.T) {
cOut := make(chan interface{}, 10)
defer close(cOut)

serverKeyChan := make(chan *exportedKey, 4) // see server loop for the order in which keys are provided
testConfig := testConfig.Clone()
testConfig.NextProtos = []string{"alpn"}

// server side
errChan := make(chan error)
serverConn := Server(
&unusedConn{},
testConfig,
&ExtraConfig{AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut}},
)
go func() {
defer serverConn.Close()
err := serverConn.Handshake()
connState := serverConn.ConnectionState()
if !connState.HandshakeComplete {
t.Fatal("expected the handshake to have completed")
}
errChan <- err
}()
serverKeyChan := make(chan *exportedKey, 4) // see server loop for the order in which keys are provided
go func() {
var counter int
for {
Expand All @@ -88,6 +106,16 @@ func TestAlternativeRecordLayer(t *testing.T) {
if c.([]byte)[0] != typeServerHello {
t.Errorf("expected ServerHello")
}
connState := serverConn.ConnectionState()
if connState.HandshakeComplete {
t.Error("didn't expect the handshake to be complete yet")
}
if connState.Version != VersionTLS13 {
t.Errorf("expected TLS 1.3, got %x", connState.Version)
}
if connState.NegotiatedProtocol == "" {
t.Error("expected ALPN to be negotiated")
}
case 1:
keyEv := c.(*exportedKey)
if keyEv.typ != "read" || keyEv.encLevel != EncryptionHandshake {
Expand Down Expand Up @@ -139,6 +167,12 @@ func TestAlternativeRecordLayer(t *testing.T) {
}()

// client side
clientConn := Client(
&unusedConn{},
testConfig,
&ExtraConfig{AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut}},
)
defer clientConn.Close()
go func() {
var counter int
for {
Expand All @@ -151,6 +185,13 @@ func TestAlternativeRecordLayer(t *testing.T) {
if c.([]byte)[0] != typeClientHello {
t.Errorf("expected ClientHello")
}
connState := clientConn.ConnectionState()
if connState.HandshakeComplete {
t.Error("didn't expect the handshake to be complete yet")
}
if len(connState.PeerCertificates) != 0 {
t.Error("didn't expect a certificate yet")
}
case 1:
keyEv := c.(*exportedKey)
if keyEv.typ != "write" || keyEv.encLevel != EncryptionHandshake {
Expand Down Expand Up @@ -189,24 +230,19 @@ func TestAlternativeRecordLayer(t *testing.T) {
}
}()

errChan := make(chan error)
go func() {
extraConf := &ExtraConfig{
AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut},
}
tlsConn := Server(&unusedConn{}, testConfig, extraConf)
defer tlsConn.Close()
errChan <- tlsConn.Handshake()
}()

extraConf := &ExtraConfig{
AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut},
}
tlsConn := Client(&unusedConn{}, testConfig, extraConf)
defer tlsConn.Close()
if err := tlsConn.Handshake(); err != nil {
if err := clientConn.Handshake(); err != nil {
t.Fatalf("Handshake failed: %s", err)
}
connState := clientConn.ConnectionState()
if !connState.HandshakeComplete {
t.Fatal("expected the handshake to have completed")
}
if connState.Version != VersionTLS13 {
t.Errorf("expected TLS 1.3, got %x", connState.Version)
}
if len(connState.PeerCertificates) == 0 {
t.Fatal("expected the certificate to be set")
}

select {
case <-time.After(500 * time.Millisecond):
Expand Down

0 comments on commit 8fcfb95

Please sign in to comment.