From cdd4200a5686ff9fe88e28711a6533b6b6fb1746 Mon Sep 17 00:00:00 2001 From: Marten Seemann Date: Thu, 22 Dec 2022 22:12:31 +1300 Subject: [PATCH] make it possible to call ConnectionState during the handshake --- conn.go | 27 ++++++++++----- handshake_client.go | 1 + handshake_client_tls13.go | 4 ++- handshake_server.go | 1 + handshake_server_test.go | 1 + handshake_server_tls13.go | 4 ++- record_layer_test.go | 70 +++++++++++++++++++++++++++++---------- 7 files changed, 80 insertions(+), 28 deletions(-) diff --git a/conn.go b/conn.go index 49c85f7..656c83c 100644 --- a/conn.go +++ b/conn.go @@ -126,6 +126,9 @@ type Conn struct { used0RTT bool tmp [16]byte + + connStateMutex sync.Mutex + connState ConnectionStateWith0RTT } // Access to net.Conn methods. @@ -1566,19 +1569,16 @@ func (c *Conn) handshakeContext(ctx context.Context) (ret error) { // ConnectionState returns basic TLS details about the connection. func (c *Conn) ConnectionState() ConnectionState { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return c.connectionStateLocked() + c.connStateMutex.Lock() + defer c.connStateMutex.Unlock() + return c.connState.ConnectionState } // ConnectionStateWith0RTT returns basic TLS details (incl. 0-RTT status) about the connection. func (c *Conn) ConnectionStateWith0RTT() ConnectionStateWith0RTT { - c.handshakeMutex.Lock() - defer c.handshakeMutex.Unlock() - return ConnectionStateWith0RTT{ - ConnectionState: c.connectionStateLocked(), - Used0RTT: c.used0RTT, - } + c.connStateMutex.Lock() + defer c.connStateMutex.Unlock() + return c.connState } func (c *Conn) connectionStateLocked() ConnectionState { @@ -1609,6 +1609,15 @@ func (c *Conn) connectionStateLocked() ConnectionState { return toConnectionState(state) } +func (c *Conn) updateConnectionState() { + c.connStateMutex.Lock() + defer c.connStateMutex.Unlock() + c.connState = ConnectionStateWith0RTT{ + Used0RTT: c.used0RTT, + ConnectionState: c.connectionStateLocked(), + } +} + // OCSPResponse returns the stapled OCSP response from the TLS server, if // any. (Only valid for client connections.) func (c *Conn) OCSPResponse() []byte { diff --git a/handshake_client.go b/handshake_client.go index d7baa8b..5b15759 100644 --- a/handshake_client.go +++ b/handshake_client.go @@ -302,6 +302,7 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) { c.config.ClientSessionCache.Put(cacheKey, toClientSessionState(hs.session)) } + c.updateConnectionState() return nil } diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go index a6a0283..60ae299 100644 --- a/handshake_client_tls13.go +++ b/handshake_client_tls13.go @@ -87,6 +87,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { if err := hs.processServerHello(); err != nil { return err } + c.updateConnectionState() if err := hs.sendDummyChangeCipherSpec(); err != nil { return err } @@ -99,6 +100,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { if err := hs.readServerCertificate(); err != nil { return err } + c.updateConnectionState() if err := hs.readServerFinished(); err != nil { return err } @@ -113,7 +115,7 @@ func (hs *clientHandshakeStateTLS13) handshake() error { } c.isHandshakeComplete.Store(true) - + c.updateConnectionState() return nil } diff --git a/handshake_server.go b/handshake_server.go index ea4136b..3443e85 100644 --- a/handshake_server.go +++ b/handshake_server.go @@ -131,6 +131,7 @@ func (hs *serverHandshakeState) handshake() error { c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random) c.isHandshakeComplete.Store(true) + c.updateConnectionState() return nil } diff --git a/handshake_server_test.go b/handshake_server_test.go index dfb8d0b..4ad5646 100644 --- a/handshake_server_test.go +++ b/handshake_server_test.go @@ -1579,6 +1579,7 @@ func TestSNIGivenOnFailure(t *testing.T) { t.Error("No error reported from server") } + hs.c.updateConnectionState() cs := hs.c.ConnectionState() if cs.HandshakeComplete { t.Error("Handshake registered as complete") diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go index 5ce0a80..069cc06 100644 --- a/handshake_server_tls13.go +++ b/handshake_server_tls13.go @@ -56,6 +56,7 @@ func (hs *serverHandshakeStateTLS13) handshake() error { if err := hs.checkForResumption(); err != nil { return err } + c.updateConnectionState() if err := hs.pickCertificate(); err != nil { return err } @@ -78,12 +79,13 @@ func (hs *serverHandshakeStateTLS13) handshake() error { if err := hs.readClientCertificate(); err != nil { return err } + c.updateConnectionState() if err := hs.readClientFinished(); err != nil { return err } c.isHandshakeComplete.Store(true) - + c.updateConnectionState() return nil } diff --git a/record_layer_test.go b/record_layer_test.go index 1879bd2..4f3d28b 100644 --- a/record_layer_test.go +++ b/record_layer_test.go @@ -74,8 +74,26 @@ func TestAlternativeRecordLayer(t *testing.T) { cOut := make(chan interface{}, 10) defer close(cOut) - serverKeyChan := make(chan *exportedKey, 4) // see server loop for the order in which keys are provided + testConfig := testConfig.Clone() + testConfig.NextProtos = []string{"alpn"} + // server side + errChan := make(chan error) + serverConn := Server( + &unusedConn{}, + testConfig, + &ExtraConfig{AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut}}, + ) + go func() { + defer serverConn.Close() + err := serverConn.Handshake() + connState := serverConn.ConnectionState() + if !connState.HandshakeComplete { + t.Fatal("expected the handshake to have completed") + } + errChan <- err + }() + serverKeyChan := make(chan *exportedKey, 4) // see server loop for the order in which keys are provided go func() { var counter int for { @@ -88,6 +106,16 @@ func TestAlternativeRecordLayer(t *testing.T) { if c.([]byte)[0] != typeServerHello { t.Errorf("expected ServerHello") } + connState := serverConn.ConnectionState() + if connState.HandshakeComplete { + t.Error("didn't expect the handshake to be complete yet") + } + if connState.Version != VersionTLS13 { + t.Errorf("expected TLS 1.3, got %x", connState.Version) + } + if connState.NegotiatedProtocol == "" { + t.Error("expected ALPN to be negotiated") + } case 1: keyEv := c.(*exportedKey) if keyEv.typ != "read" || keyEv.encLevel != EncryptionHandshake { @@ -139,6 +167,12 @@ func TestAlternativeRecordLayer(t *testing.T) { }() // client side + clientConn := Client( + &unusedConn{}, + testConfig, + &ExtraConfig{AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut}}, + ) + defer clientConn.Close() go func() { var counter int for { @@ -151,6 +185,13 @@ func TestAlternativeRecordLayer(t *testing.T) { if c.([]byte)[0] != typeClientHello { t.Errorf("expected ClientHello") } + connState := clientConn.ConnectionState() + if connState.HandshakeComplete { + t.Error("didn't expect the handshake to be complete yet") + } + if len(connState.PeerCertificates) != 0 { + t.Error("didn't expect a certificate yet") + } case 1: keyEv := c.(*exportedKey) if keyEv.typ != "write" || keyEv.encLevel != EncryptionHandshake { @@ -189,24 +230,19 @@ func TestAlternativeRecordLayer(t *testing.T) { } }() - errChan := make(chan error) - go func() { - extraConf := &ExtraConfig{ - AlternativeRecordLayer: &recordLayerWithKeys{in: sIn, out: sOut}, - } - tlsConn := Server(&unusedConn{}, testConfig, extraConf) - defer tlsConn.Close() - errChan <- tlsConn.Handshake() - }() - - extraConf := &ExtraConfig{ - AlternativeRecordLayer: &recordLayerWithKeys{in: cIn, out: cOut}, - } - tlsConn := Client(&unusedConn{}, testConfig, extraConf) - defer tlsConn.Close() - if err := tlsConn.Handshake(); err != nil { + if err := clientConn.Handshake(); err != nil { t.Fatalf("Handshake failed: %s", err) } + connState := clientConn.ConnectionState() + if !connState.HandshakeComplete { + t.Fatal("expected the handshake to have completed") + } + if connState.Version != VersionTLS13 { + t.Errorf("expected TLS 1.3, got %x", connState.Version) + } + if len(connState.PeerCertificates) == 0 { + t.Fatal("expected the certificate to be set") + } select { case <-time.After(500 * time.Millisecond):