}
func (test *clientTest) run(t *testing.T, write bool) {
- var clientConn, serverConn net.Conn
+ var clientConn net.Conn
var recordingConn *recordingConn
var childProcess *exec.Cmd
var stdin opensslInput
}
}()
} else {
- clientConn, serverConn = localPipe(t)
+ flows, err := test.loadData()
+ if err != nil {
+ t.Fatalf("failed to load data from %s: %v", test.dataPath(), err)
+ }
+ clientConn = &replayingConn{t: t, flows: flows, reading: false}
}
- doneChan := make(chan bool)
- defer func() {
- clientConn.Close()
- <-doneChan
- }()
- go func() {
- defer close(doneChan)
+ config := test.config
+ if config == nil {
+ config = testConfig
+ }
+ client := Client(clientConn, config)
+ defer client.Close()
- config := test.config
- if config == nil {
- config = testConfig
- }
- client := Client(clientConn, config)
- defer client.Close()
+ if _, err := client.Write([]byte("hello\n")); err != nil {
+ t.Errorf("Client.Write failed: %s", err)
+ return
+ }
- if _, err := client.Write([]byte("hello\n")); err != nil {
- t.Errorf("Client.Write failed: %s", err)
- return
+ for i := 1; i <= test.numRenegotiations; i++ {
+ // The initial handshake will generate a
+ // handshakeComplete signal which needs to be quashed.
+ if i == 1 && write {
+ <-stdout.handshakeComplete
}
- for i := 1; i <= test.numRenegotiations; i++ {
- // The initial handshake will generate a
- // handshakeComplete signal which needs to be quashed.
- if i == 1 && write {
- <-stdout.handshakeComplete
- }
-
- // OpenSSL will try to interleave application data and
- // a renegotiation if we send both concurrently.
- // Therefore: ask OpensSSL to start a renegotiation, run
- // a goroutine to call client.Read and thus process the
- // renegotiation request, watch for OpenSSL's stdout to
- // indicate that the handshake is complete and,
- // finally, have OpenSSL write something to cause
- // client.Read to complete.
- if write {
- stdin <- opensslRenegotiate
- }
-
- signalChan := make(chan struct{})
+ // OpenSSL will try to interleave application data and
+ // a renegotiation if we send both concurrently.
+ // Therefore: ask OpensSSL to start a renegotiation, run
+ // a goroutine to call client.Read and thus process the
+ // renegotiation request, watch for OpenSSL's stdout to
+ // indicate that the handshake is complete and,
+ // finally, have OpenSSL write something to cause
+ // client.Read to complete.
+ if write {
+ stdin <- opensslRenegotiate
+ }
- go func() {
- defer close(signalChan)
+ signalChan := make(chan struct{})
- buf := make([]byte, 256)
- n, err := client.Read(buf)
+ go func() {
+ defer close(signalChan)
- if test.checkRenegotiationError != nil {
- newErr := test.checkRenegotiationError(i, err)
- if err != nil && newErr == nil {
- return
- }
- err = newErr
- }
+ buf := make([]byte, 256)
+ n, err := client.Read(buf)
- if err != nil {
- t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
+ if test.checkRenegotiationError != nil {
+ newErr := test.checkRenegotiationError(i, err)
+ if err != nil && newErr == nil {
return
}
+ err = newErr
+ }
- buf = buf[:n]
- if !bytes.Equal([]byte(opensslSentinel), buf) {
- t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
- }
-
- if expected := i + 1; client.handshakes != expected {
- t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
- }
- }()
-
- if write && test.renegotiationExpectedToFail != i {
- <-stdout.handshakeComplete
- stdin <- opensslSendSentinel
+ if err != nil {
+ t.Errorf("Client.Read failed after renegotiation #%d: %s", i, err)
+ return
}
- <-signalChan
- }
- if test.sendKeyUpdate {
- if write {
- <-stdout.handshakeComplete
- stdin <- opensslKeyUpdate
+ buf = buf[:n]
+ if !bytes.Equal([]byte(opensslSentinel), buf) {
+ t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
}
- doneRead := make(chan struct{})
+ if expected := i + 1; client.handshakes != expected {
+ t.Errorf("client should have recorded %d handshakes, but believes that %d have occurred", expected, client.handshakes)
+ }
+ }()
- go func() {
- defer close(doneRead)
+ if write && test.renegotiationExpectedToFail != i {
+ <-stdout.handshakeComplete
+ stdin <- opensslSendSentinel
+ }
+ <-signalChan
+ }
- buf := make([]byte, 256)
- n, err := client.Read(buf)
+ if test.sendKeyUpdate {
+ if write {
+ <-stdout.handshakeComplete
+ stdin <- opensslKeyUpdate
+ }
- if err != nil {
- t.Errorf("Client.Read failed after KeyUpdate: %s", err)
- return
- }
+ doneRead := make(chan struct{})
- buf = buf[:n]
- if !bytes.Equal([]byte(opensslSentinel), buf) {
- t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
- }
- }()
+ go func() {
+ defer close(doneRead)
- if write {
- // There's no real reason to wait for the client KeyUpdate to
- // send data with the new server keys, except that s_server
- // drops writes if they are sent at the wrong time.
- <-stdout.readKeyUpdate
- stdin <- opensslSendSentinel
- }
- <-doneRead
+ buf := make([]byte, 256)
+ n, err := client.Read(buf)
- if _, err := client.Write([]byte("hello again\n")); err != nil {
- t.Errorf("Client.Write failed: %s", err)
+ if err != nil {
+ t.Errorf("Client.Read failed after KeyUpdate: %s", err)
return
}
- }
- if test.validate != nil {
- if err := test.validate(client.ConnectionState()); err != nil {
- t.Errorf("validate callback returned error: %s", err)
+ buf = buf[:n]
+ if !bytes.Equal([]byte(opensslSentinel), buf) {
+ t.Errorf("Client.Read returned %q, but wanted %q", string(buf), opensslSentinel)
}
- }
+ }()
- // If the server sent us an alert after our last flight, give it a
- // chance to arrive.
- if write && test.renegotiationExpectedToFail == 0 {
- if err := peekError(client); err != nil {
- t.Errorf("final Read returned an error: %s", err)
- }
+ if write {
+ // There's no real reason to wait for the client KeyUpdate to
+ // send data with the new server keys, except that s_server
+ // drops writes if they are sent at the wrong time.
+ <-stdout.readKeyUpdate
+ stdin <- opensslSendSentinel
}
- }()
+ <-doneRead
- if !write {
- flows, err := test.loadData()
- if err != nil {
- t.Fatalf("%s: failed to load data from %s: %v", test.name, test.dataPath(), err)
+ if _, err := client.Write([]byte("hello again\n")); err != nil {
+ t.Errorf("Client.Write failed: %s", err)
+ return
}
- for i, b := range flows {
- if i%2 == 1 {
- if *fast {
- serverConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
- } else {
- serverConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
- }
- serverConn.Write(b)
- continue
- }
- bb := make([]byte, len(b))
- if *fast {
- serverConn.SetReadDeadline(time.Now().Add(1 * time.Second))
- } else {
- serverConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
- }
- _, err := io.ReadFull(serverConn, bb)
- if err != nil {
- t.Fatalf("%s, flow %d: %s", test.name, i+1, err)
- }
- if !bytes.Equal(b, bb) {
- t.Fatalf("%s, flow %d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
- }
+ }
+
+ if test.validate != nil {
+ if err := test.validate(client.ConnectionState()); err != nil {
+ t.Errorf("validate callback returned error: %s", err)
}
}
- <-doneChan
- if !write {
- serverConn.Close()
+ // If the server sent us an alert after our last flight, give it a
+ // chance to arrive.
+ if write && test.renegotiationExpectedToFail == 0 {
+ if err := peekError(client); err != nil {
+ t.Errorf("final Read returned an error: %s", err)
+ }
}
if write {
+ clientConn.Close()
path := test.dataPath()
out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
"os/exec"
"path/filepath"
"runtime"
+ "slices"
"strings"
"testing"
"time"
}
func (test *serverTest) run(t *testing.T, write bool) {
- var clientConn, serverConn net.Conn
+ var serverConn net.Conn
var recordingConn *recordingConn
var childProcess *exec.Cmd
}
}()
} else {
- clientConn, serverConn = localPipe(t)
+ flows, err := test.loadData()
+ if err != nil {
+ t.Fatalf("Failed to load data from %s", test.dataPath())
+ }
+ serverConn = &replayingConn{t: t, flows: flows, reading: true}
}
config := test.config
if config == nil {
config = testConfig
}
server := Server(serverConn, config)
- connStateChan := make(chan ConnectionState, 1)
- go func() {
- _, err := server.Write([]byte("hello, world\n"))
- if len(test.expectHandshakeErrorIncluding) > 0 {
- if err == nil {
- t.Errorf("Error expected, but no error returned")
- } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) {
- t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s)
- }
- } else {
- if err != nil {
- t.Logf("Error from Server.Write: '%s'", err)
- }
- }
- server.Close()
- serverConn.Close()
- connStateChan <- server.ConnectionState()
- }()
- if !write {
- flows, err := test.loadData()
- if err != nil {
- t.Fatalf("%s: failed to load data from %s", test.name, test.dataPath())
+ _, err := server.Write([]byte("hello, world\n"))
+ if len(test.expectHandshakeErrorIncluding) > 0 {
+ if err == nil {
+ t.Errorf("Error expected, but no error returned")
+ } else if s := err.Error(); !strings.Contains(s, test.expectHandshakeErrorIncluding) {
+ t.Errorf("Error expected containing '%s' but got '%s'", test.expectHandshakeErrorIncluding, s)
}
- for i, b := range flows {
- if i%2 == 0 {
- if *fast {
- clientConn.SetWriteDeadline(time.Now().Add(1 * time.Second))
- } else {
- clientConn.SetWriteDeadline(time.Now().Add(1 * time.Minute))
- }
- clientConn.Write(b)
- continue
- }
- bb := make([]byte, len(b))
- if *fast {
- clientConn.SetReadDeadline(time.Now().Add(1 * time.Second))
- } else {
- clientConn.SetReadDeadline(time.Now().Add(1 * time.Minute))
- }
- n, err := io.ReadFull(clientConn, bb)
- if err != nil {
- t.Fatalf("%s #%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", test.name, i+1, err, n, len(bb), bb[:n], b)
- }
- if !bytes.Equal(b, bb) {
- t.Fatalf("%s #%d: mismatch on read: got:%x want:%x", test.name, i+1, bb, b)
- }
+ } else {
+ if err != nil {
+ t.Logf("Error from Server.Write: '%s'", err)
}
- clientConn.Close()
}
+ server.Close()
- connState := <-connStateChan
+ connState := server.ConnectionState()
peerCerts := connState.PeerCertificates
if len(peerCerts) == len(test.expectedPeerCerts) {
for i, peerCert := range peerCerts {
}
if write {
+ serverConn.Close()
path := test.dataPath()
out, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
serverConn.Close()
flows := serverConn.(*recordingConn).flows
- feeder := make(chan struct{})
- clientConn, serverConn = localPipe(b)
-
- go func() {
- for range feeder {
- for i, f := range flows {
- if i%2 == 0 {
- clientConn.Write(f)
- continue
- }
- ff := make([]byte, len(f))
- n, err := io.ReadFull(clientConn, ff)
- if err != nil {
- b.Errorf("#%d: %s\nRead %d, wanted %d, got %x, wanted %x\n", i+1, err, n, len(ff), ff[:n], f)
- }
- if !bytes.Equal(f, ff) {
- b.Errorf("#%d: mismatch on read: got:%x want:%x", i+1, ff, f)
- }
- }
- }
- }()
-
b.ResetTimer()
for i := 0; i < b.N; i++ {
- feeder <- struct{}{}
- server := Server(serverConn, config)
+ replay := &replayingConn{t: b, flows: slices.Clone(flows), reading: true}
+ server := Server(replay, config)
if err := server.Handshake(); err != nil {
b.Fatalf("handshake failed: %v", err)
}
}
- close(feeder)
}
func BenchmarkHandshakeServer(b *testing.B) {
import (
"bufio"
+ "bytes"
"crypto/ed25519"
"crypto/x509"
"encoding/hex"
var (
update = flag.Bool("update", false, "update golden files on failure")
- fast = flag.Bool("fast", false, "impose a quick, possibly flaky timeout on recorded tests")
keyFile = flag.String("keylog", "", "destination file for KeyLogWriter")
bogoMode = flag.Bool("bogo-mode", false, "Enabled bogo shim mode, ignore everything else")
bogoFilter = flag.String("bogo-filter", "", "BoGo test filter")
return flows, nil
}
+// replayingConn is a net.Conn that replays flows recorded by recordingConn.
+type replayingConn struct {
+ t testing.TB
+ sync.Mutex
+ flows [][]byte
+ reading bool
+}
+
+var _ net.Conn = (*replayingConn)(nil)
+
+func (r *replayingConn) Read(b []byte) (n int, err error) {
+ r.Lock()
+ defer r.Unlock()
+
+ if !r.reading {
+ r.t.Errorf("expected write, got read")
+ return 0, fmt.Errorf("recording expected write, got read")
+ }
+
+ n = copy(b, r.flows[0])
+ r.flows[0] = r.flows[0][n:]
+ if len(r.flows[0]) == 0 {
+ r.flows = r.flows[1:]
+ if len(r.flows) == 0 {
+ return n, io.EOF
+ } else {
+ r.reading = false
+ }
+ }
+ return n, nil
+}
+
+func (r *replayingConn) Write(b []byte) (n int, err error) {
+ r.Lock()
+ defer r.Unlock()
+
+ if r.reading {
+ r.t.Errorf("expected read, got write")
+ return 0, fmt.Errorf("recording expected read, got write")
+ }
+
+ if !bytes.HasPrefix(r.flows[0], b) {
+ r.t.Errorf("write mismatch: expected %x, got %x", r.flows[0], b)
+ return 0, fmt.Errorf("write mismatch")
+ }
+ r.flows[0] = r.flows[0][len(b):]
+ if len(r.flows[0]) == 0 {
+ r.flows = r.flows[1:]
+ r.reading = true
+ }
+ return len(b), nil
+}
+
+func (r *replayingConn) Close() error {
+ r.Lock()
+ defer r.Unlock()
+
+ if len(r.flows) > 0 {
+ r.t.Errorf("closed with unfinished flows")
+ return fmt.Errorf("unexpected close")
+ }
+ return nil
+}
+
+func (r *replayingConn) LocalAddr() net.Addr { return nil }
+func (r *replayingConn) RemoteAddr() net.Addr { return nil }
+func (r *replayingConn) SetDeadline(t time.Time) error { return nil }
+func (r *replayingConn) SetReadDeadline(t time.Time) error { return nil }
+func (r *replayingConn) SetWriteDeadline(t time.Time) error { return nil }
+
// tempFile creates a temp file containing contents and returns its path.
func tempFile(contents string) string {
file, err := os.CreateTemp("", "go-tls-test")