]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: replay test recordings without network
authorFilippo Valsorda <filippo@golang.org>
Sun, 23 Jun 2024 12:10:14 +0000 (14:10 +0200)
committerGopher Robot <gobot@golang.org>
Mon, 24 Jun 2024 16:40:26 +0000 (16:40 +0000)
There is no reason to go across a pipe when replaying a conn recording.
This avoids the complexity of using localPipe and goroutines, and makes
handshake benchmarks more accurate, as we don't measure network
overhead.

Also note how it removes the need for -fast: operating locally we know
when the flow is over and can error out immediately, without waiting for
a read from the feeder on the other side of the pipe to timeout.

Avoids some noise in #67979, but doesn't fix the two root causes:
localPipe flakes and testing.B races.

Updates #67979

Change-Id: I153d3fa5a24847f3947823e8c3a7bc639f89bc1d
Reviewed-on: https://go-review.googlesource.com/c/go/+/594255
Auto-Submit: Filippo Valsorda <filippo@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Joedian Reid <joedian@google.com>
src/crypto/tls/handshake_client_test.go
src/crypto/tls/handshake_server_test.go
src/crypto/tls/handshake_test.go

index 4570f5b05edfddba5fc1c427bc7b897c47a915d7..501f9c6755f9e3fea103fd5b8844f620f9bb1bac 100644 (file)
@@ -283,7 +283,7 @@ func (test *clientTest) loadData() (flows [][]byte, err error) {
 }
 
 func (test *clientTest) run(t *testing.T, write bool) {
-       var clientConn, serverConn net.Conn
+       var clientConn net.Conn
        var recordingConn *recordingConn
        var childProcess *exec.Cmd
        var stdin opensslInput
@@ -302,178 +302,138 @@ func (test *clientTest) run(t *testing.T, write bool) {
                        }
                }()
        } else {
-               clientConn, serverConn = localPipe(t)
+               flows, err := test.loadData()
+               if err != nil {
+                       t.Fatalf("failed to load data from %s: %v", test.dataPath(), err)
+               }
+               clientConn = &replayingConn{t: t, flows: flows, reading: false}
        }
 
-       doneChan := make(chan bool)
-       defer func() {
-               clientConn.Close()
-               <-doneChan
-       }()
-       go func() {
-               defer close(doneChan)
+       config := test.config
+       if config == nil {
+               config = testConfig
+       }
+       client := Client(clientConn, config)
+       defer client.Close()
 
-               config := test.config
-               if config == nil {
-                       config = testConfig
-               }
-               client := Client(clientConn, config)
-               defer client.Close()
+       if _, err := client.Write([]byte("hello\n")); err != nil {
+               t.Errorf("Client.Write failed: %s", err)
+               return
+       }
 
-               if _, err := client.Write([]byte("hello\n")); err != nil {
-                       t.Errorf("Client.Write failed: %s", err)
-                       return
+       for i := 1; i <= test.numRenegotiations; i++ {
+               // The initial handshake will generate a
+               // handshakeComplete signal which needs to be quashed.
+               if i == 1 && write {
+                       <-stdout.handshakeComplete
                }
 
-               for i := 1; i <= test.numRenegotiations; i++ {
-                       // The initial handshake will generate a
-                       // handshakeComplete signal which needs to be quashed.
-                       if i == 1 && write {
-                               <-stdout.handshakeComplete
-                       }
-
-                       // OpenSSL will try to interleave application data and
-                       // a renegotiation if we send both concurrently.
-                       // Therefore: ask OpensSSL to start a renegotiation, run
-                       // a goroutine to call client.Read and thus process the
-                       // renegotiation request, watch for OpenSSL's stdout to
-                       // indicate that the handshake is complete and,
-                       // finally, have OpenSSL write something to cause
-                       // client.Read to complete.
-                       if write {
-                               stdin <- opensslRenegotiate
-                       }
-
-                       signalChan := make(chan struct{})
+               // OpenSSL will try to interleave application data and
+               // a renegotiation if we send both concurrently.
+               // Therefore: ask OpensSSL to start a renegotiation, run
+               // a goroutine to call client.Read and thus process the
+               // renegotiation request, watch for OpenSSL's stdout to
+               // indicate that the handshake is complete and,
+               // finally, have OpenSSL write something to cause
+               // client.Read to complete.
+               if write {
+                       stdin <- opensslRenegotiate
+               }
 
-                       go func() {
-                               defer close(signalChan)
+               signalChan := make(chan struct{})
 
-                               buf := make([]byte, 256)
-                               n, err := client.Read(buf)
+               go func() {
+                       defer close(signalChan)
 
-                               if test.checkRenegotiationError != nil {
-                                       newErr := test.checkRenegotiationError(i, err)
-                                       if err != nil && newErr == nil {
-                                               return
-                                       }
-                                       err = newErr
-                               }
+                       buf := make([]byte, 256)
+                       n, err := client.Read(buf)
 
-                               if err != nil {
-                                       t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
+                       if test.checkRenegotiationError != nil {
+                               newErr := test.checkRenegotiationError(i, err)
+                               if err != nil && newErr == nil {
                                        return
                                }
+                               err = newErr
+                       }
 
-                               buf = buf[:n]
-                               if !bytes.Equal([]byte(opensslSentinel), buf) {
-                                       t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
-                               }
-
-                               if expected := i + 1; client.handshakes != expected {
-                                       t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
-                               }
-                       }()
-
-                       if write && test.renegotiationExpectedToFail != i {
-                               <-stdout.handshakeComplete
-                               stdin <- opensslSendSentinel
+                       if err != nil {
+                               t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
+                               return
                        }
-                       <-signalChan
-               }
 
-               if test.sendKeyUpdate {
-                       if write {
-                               <-stdout.handshakeComplete
-                               stdin <- opensslKeyUpdate
+                       buf = buf[:n]
+                       if !bytes.Equal([]byte(opensslSentinel), buf) {
+                               t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
                        }
 
-                       doneRead := make(chan struct{})
+                       if expected := i + 1; client.handshakes != expected {
+                               t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
+                       }
+               }()
 
-                       go func() {
-                               defer close(doneRead)
+               if write && test.renegotiationExpectedToFail != i {
+                       <-stdout.handshakeComplete
+                       stdin <- opensslSendSentinel
+               }
+               <-signalChan
+       }
 
-                               buf := make([]byte, 256)
-                               n, err := client.Read(buf)
+       if test.sendKeyUpdate {
+               if write {
+                       <-stdout.handshakeComplete
+                       stdin <- opensslKeyUpdate
+               }
 
-                               if err != nil {
-                                       t.Errorf("Client.Read failed after KeyUpdate: %s", err)
-                                       return
-                               }
+               doneRead := make(chan struct{})
 
-                               buf = buf[:n]
-                               if !bytes.Equal([]byte(opensslSentinel), buf) {
-                                       t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
-                               }
-                       }()
+               go func() {
+                       defer close(doneRead)
 
-                       if write {
-                               // There's no real reason to wait for the client KeyUpdate to
-                               // send data with the new server keys, except that s_server
-                               // drops writes if they are sent at the wrong time.
-                               <-stdout.readKeyUpdate
-                               stdin <- opensslSendSentinel
-                       }
-                       <-doneRead
+                       buf := make([]byte, 256)
+                       n, err := client.Read(buf)
 
-                       if _, err := client.Write([]byte("hello again\n")); err != nil {
-                               t.Errorf("Client.Write failed: %s", err)
+                       if err != nil {
+                               t.Errorf("Client.Read failed after KeyUpdate: %s", err)
                                return
                        }
-               }
 
-               if test.validate != nil {
-                       if err := test.validate(client.ConnectionState()); err != nil {
-                               t.Errorf("validate callback returned error: %s", err)
+                       buf = buf[:n]
+                       if !bytes.Equal([]byte(opensslSentinel), buf) {
+                               t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
                        }
-               }
+               }()
 
-               // If the server sent us an alert after our last flight, give it a
-               // chance to arrive.
-               if write && test.renegotiationExpectedToFail == 0 {
-                       if err := peekError(client); err != nil {
-                               t.Errorf("final Read returned an error: %s", err)
-                       }
+               if write {
+                       // There's no real reason to wait for the client KeyUpdate to
+                       // send data with the new server keys, except that s_server
+                       // drops writes if they are sent at the wrong time.
+                       <-stdout.readKeyUpdate
+                       stdin <- opensslSendSentinel
                }
-       }()
+               <-doneRead
 
-       if !write {
-               flows, err := test.loadData()
-               if err != nil {
-                       t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
+               if _, err := client.Write([]byte("hello again\n")); err != nil {
+                       t.Errorf("Client.Write failed: %s", err)
+                       return
                }
-               for i, b := range flows {
-                       if i%2 == 1 {
-                               if *fast {
-                                       serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
-                               } else {
-                                       serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
-                               }
-                               serverConn.Write(b)
-                               continue
-                       }
-                       bb := make([]byte, len(b))
-                       if *fast {
-                               serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
-                       } else {
-                               serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
-                       }
-                       _, err := io.ReadFull(serverConn, bb)
-                       if err != nil {
-                               t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
-                       }
-                       if !bytes.Equal(b, bb) {
-                               t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
-                       }
+       }
+
+       if test.validate != nil {
+               if err := test.validate(client.ConnectionState()); err != nil {
+                       t.Errorf("validate callback returned error: %s", err)
                }
        }
 
-       <-doneChan
-       if !write {
-               serverConn.Close()
+       // If the server sent us an alert after our last flight, give it a
+       // chance to arrive.
+       if write && test.renegotiationExpectedToFail == 0 {
+               if err := peekError(client); err != nil {
+                       t.Errorf("final Read returned an error: %s", err)
+               }
        }
 
        if write {
+               clientConn.Close()
                path := test.dataPath()
                out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
                if err != nil {
index 44bc8f1bb7442a5ddc5bb4980620a147052352d2..94d3d0f6dc87bc6f48ce9e592ce1e273cd30182c 100644 (file)
@@ -21,6 +21,7 @@ import (
        "os/exec"
        "path/filepath"
        "runtime"
+       "slices"
        "strings"
        "testing"
        "time"
@@ -659,7 +660,7 @@ func (test *serverTest) loadData() (flows [][]byte, err error) {
 }
 
 func (test *serverTest) run(t *testing.T, write bool) {
-       var clientConn, serverConn net.Conn
+       var serverConn net.Conn
        var recordingConn *recordingConn
        var childProcess *exec.Cmd
 
@@ -676,65 +677,33 @@ func (test *serverTest) run(t *testing.T, write bool) {
                        }
                }()
        } else {
-               clientConn, serverConn = localPipe(t)
+               flows, err := test.loadData()
+               if err != nil {
+                       t.Fatalf("Failed to load data from %s", test.dataPath())
+               }
+               serverConn = &replayingConn{t: t, flows: flows, reading: true}
        }
        config := test.config
        if config == nil {
                config = testConfig
        }
        server := Server(serverConn, config)
-       connStateChan := make(chan ConnectionState, 1)
-       go func() {
-               _, err := server.Write([]byte("hello, world\n"))
-               if len(test.expectHandshakeErrorIncluding) > 0 {
-                       if err == nil {
-                               t.Errorf("Error expected, but no error returned")
-                       } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) {
-                               t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s)
-                       }
-               } else {
-                       if err != nil {
-                               t.Logf("Error from Server.Write: '%s'", err)
-                       }
-               }
-               server.Close()
-               serverConn.Close()
-               connStateChan <- server.ConnectionState()
-       }()
 
-       if !write {
-               flows, err := test.loadData()
-               if err != nil {
-                       t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
+       _, err := server.Write([]byte("hello, world\n"))
+       if len(test.expectHandshakeErrorIncluding) > 0 {
+               if err == nil {
+                       t.Errorf("Error expected, but no error returned")
+               } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) {
+                       t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s)
                }
-               for i, b := range flows {
-                       if i%2 == 0 {
-                               if *fast {
-                                       clientConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
-                               } else {
-                                       clientConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
-                               }
-                               clientConn.Write(b)
-                               continue
-                       }
-                       bb := make([]byte, len(b))
-                       if *fast {
-                               clientConn.SetReadDeadline(time.Now().Add(1 * time.Second))
-                       } else {
-                               clientConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
-                       }
-                       n, err := io.ReadFull(clientConn, bb)
-                       if err != nil {
-                               t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
-                       }
-                       if !bytes.Equal(b, bb) {
-                               t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
-                       }
+       } else {
+               if err != nil {
+                       t.Logf("Error from Server.Write: '%s'", err)
                }
-               clientConn.Close()
        }
+       server.Close()
 
-       connState := <-connStateChan
+       connState := server.ConnectionState()
        peerCerts := connState.PeerCertificates
        if len(peerCerts) == len(test.expectedPeerCerts) {
                for i, peerCert := range peerCerts {
@@ -754,6 +723,7 @@ func (test *serverTest) run(t *testing.T, write bool) {
        }
 
        if write {
+               serverConn.Close()
                path := test.dataPath()
                out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
                if err != nil {
@@ -1330,37 +1300,14 @@ func benchmarkHandshakeServer(b *testing.B, version uint16, cipherSuite uint16,
        serverConn.Close()
        flows := serverConn.(*recordingConn).flows
 
-       feeder := make(chan struct{})
-       clientConn, serverConn = localPipe(b)
-
-       go func() {
-               for range feeder {
-                       for i, f := range flows {
-                               if i%2 == 0 {
-                                       clientConn.Write(f)
-                                       continue
-                               }
-                               ff := make([]byte, len(f))
-                               n, err := io.ReadFull(clientConn, ff)
-                               if err != nil {
-                                       b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f)
-                               }
-                               if !bytes.Equal(f, ff) {
-                                       b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f)
-                               }
-                       }
-               }
-       }()
-
        b.ResetTimer()
        for i := 0; i < b.N; i++ {
-               feeder <- struct{}{}
-               server := Server(serverConn, config)
+               replay := &replayingConn{t: b, flows: slices.Clone(flows), reading: true}
+               server := Server(replay, config)
                if err := server.Handshake(); err != nil {
                        b.Fatalf("handshake failed: %v", err)
                }
        }
-       close(feeder)
 }
 
 func BenchmarkHandshakeServer(b *testing.B) {
index 57fc761dbb81088388aeb53b78939a9fd28e6d14..bc3d23d5adc24efcb1afdca9a4674c55bf7bb98f 100644 (file)
@@ -6,6 +6,7 @@ package tls
 
 import (
        "bufio"
+       "bytes"
        "crypto/ed25519"
        "crypto/x509"
        "encoding/hex"
@@ -42,7 +43,6 @@ import (
 
 var (
        update       = flag.Bool("update", false, "update golden files on failure")
-       fast         = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests")
        keyFile      = flag.String("keylog", "", "destination file for KeyLogWriter")
        bogoMode     = flag.Bool("bogo-mode", false, "Enabled bogo shim mode, ignore everything else")
        bogoFilter   = flag.String("bogo-filter", "", "BoGo test filter")
@@ -223,6 +223,76 @@ func parseTestData(r io.Reader) (flows [][]byte, err error) {
        return flows, nil
 }
 
+// replayingConn is a net.Conn that replays flows recorded by recordingConn.
+type replayingConn struct {
+       t testing.TB
+       sync.Mutex
+       flows   [][]byte
+       reading bool
+}
+
+var _ net.Conn = (*replayingConn)(nil)
+
+func (r *replayingConn) Read(b []byte) (n int, err error) {
+       r.Lock()
+       defer r.Unlock()
+
+       if !r.reading {
+               r.t.Errorf("expected write, got read")
+               return 0, fmt.Errorf("recording expected write, got read")
+       }
+
+       n = copy(b, r.flows[0])
+       r.flows[0] = r.flows[0][n:]
+       if len(r.flows[0]) == 0 {
+               r.flows = r.flows[1:]
+               if len(r.flows) == 0 {
+                       return n, io.EOF
+               } else {
+                       r.reading = false
+               }
+       }
+       return n, nil
+}
+
+func (r *replayingConn) Write(b []byte) (n int, err error) {
+       r.Lock()
+       defer r.Unlock()
+
+       if r.reading {
+               r.t.Errorf("expected read, got write")
+               return 0, fmt.Errorf("recording expected read, got write")
+       }
+
+       if !bytes.HasPrefix(r.flows[0], b) {
+               r.t.Errorf("write mismatch: expected %x, got %x", r.flows[0], b)
+               return 0, fmt.Errorf("write mismatch")
+       }
+       r.flows[0] = r.flows[0][len(b):]
+       if len(r.flows[0]) == 0 {
+               r.flows = r.flows[1:]
+               r.reading = true
+       }
+       return len(b), nil
+}
+
+func (r *replayingConn) Close() error {
+       r.Lock()
+       defer r.Unlock()
+
+       if len(r.flows) > 0 {
+               r.t.Errorf("closed with unfinished flows")
+               return fmt.Errorf("unexpected close")
+       }
+       return nil
+}
+
+func (r *replayingConn) LocalAddr() net.Addr                { return nil }
+func (r *replayingConn) RemoteAddr() net.Addr               { return nil }
+func (r *replayingConn) SetDeadline(t time.Time) error      { return nil }
+func (r *replayingConn) SetReadDeadline(t time.Time) error  { return nil }
+func (r *replayingConn) SetWriteDeadline(t time.Time) error { return nil }
+
 // tempFile creates a temp file containing contents and returns its path.
 func tempFile(contents string) string {
        file, err := os.CreateTemp("", "go-tls-test")