]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: run TestServerShutdownStateNew in a synctest bubble
authorDamien Neil <dneil@google.com>
Tue, 11 Jun 2024 20:44:43 +0000 (13:44 -0700)
committerDamien Neil <dneil@google.com>
Mon, 25 Nov 2024 22:02:07 +0000 (22:02 +0000)
Took ~12s previously, ~0s now.

Change-Id: I72580fbde73482a40142cf84cd3d78a50afb9f44
Reviewed-on: https://go-review.googlesource.com/c/go/+/630382
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
src/net/http/async_test.go [new file with mode: 0644]
src/net/http/clientserver_test.go
src/net/http/netconn_test.go [new file with mode: 0644]
src/net/http/serve_test.go

diff --git a/src/net/http/async_test.go b/src/net/http/async_test.go
new file mode 100644 (file)
index 0000000..545cbcf
--- /dev/null
@@ -0,0 +1,52 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "errors"
+       "internal/synctest"
+)
+
+var errStillRunning = errors.New("async op still running")
+
+type asyncResult[T any] struct {
+       donec chan struct{}
+       res   T
+       err   error
+}
+
+// runAsync runs f in a new goroutine.
+// It returns an asyncResult which acts as a future.
+//
+// Must be called from within a synctest bubble.
+func runAsync[T any](f func() (T, error)) *asyncResult[T] {
+       r := &asyncResult[T]{
+               donec: make(chan struct{}),
+       }
+       go func() {
+               defer close(r.donec)
+               r.res, r.err = f()
+       }()
+       synctest.Wait()
+       return r
+}
+
+// done reports whether the function has returned.
+func (r *asyncResult[T]) done() bool {
+       _, err := r.result()
+       return err != errStillRunning
+}
+
+// result returns the result of the function.
+// If the function hasn't completed yet, it returns errStillRunning.
+func (r *asyncResult[T]) result() (T, error) {
+       select {
+       case <-r.donec:
+               return r.res, r.err
+       default:
+               var zero T
+               return zero, errStillRunning
+       }
+}
index 606715a25c2816934a05c338259e0f06b656304c..08730387578297e7fc3e7cbde90233fcdb93e5ef 100644 (file)
@@ -15,6 +15,7 @@ import (
        "crypto/tls"
        "fmt"
        "hash"
+       "internal/synctest"
        "io"
        "log"
        "maps"
@@ -93,6 +94,37 @@ func run[T TBRun[T]](t T, f func(t T, mode testMode), opts ...any) {
        }
 }
 
+// cleanupT wraps a testing.T and adds its own Cleanup method.
+// Used to execute cleanup functions within a synctest bubble.
+type cleanupT struct {
+       *testing.T
+       cleanups []func()
+}
+
+// Cleanup replaces T.Cleanup.
+func (t *cleanupT) Cleanup(f func()) {
+       t.cleanups = append(t.cleanups, f)
+}
+
+func (t *cleanupT) done() {
+       for _, f := range slices.Backward(t.cleanups) {
+               f()
+       }
+}
+
+// runSynctest is run combined with synctest.Run.
+//
+// The TB passed to f arranges for cleanup functions to be run in the synctest bubble.
+func runSynctest(t *testing.T, f func(t testing.TB, mode testMode), opts ...any) {
+       run(t, func(t *testing.T, mode testMode) {
+               synctest.Run(func() {
+                       ct := &cleanupT{T: t}
+                       defer ct.done()
+                       f(ct, mode)
+               })
+       }, opts...)
+}
+
 type clientServerTest struct {
        t  testing.TB
        h2 bool
diff --git a/src/net/http/netconn_test.go b/src/net/http/netconn_test.go
new file mode 100644 (file)
index 0000000..251b919
--- /dev/null
@@ -0,0 +1,410 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "bytes"
+       "context"
+       "internal/synctest"
+       "io"
+       "math"
+       "net"
+       "net/netip"
+       "os"
+       "sync"
+       "time"
+)
+
+func fakeNetListen() *fakeNetListener {
+       li := &fakeNetListener{
+               setc:   make(chan struct{}, 1),
+               unsetc: make(chan struct{}, 1),
+               addr:   net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")),
+       }
+       li.unsetc <- struct{}{}
+       return li
+}
+
+type fakeNetListener struct {
+       setc, unsetc chan struct{}
+       queue        []net.Conn
+       closed       bool
+       addr         net.Addr
+}
+
+func (li *fakeNetListener) lock() {
+       select {
+       case <-li.setc:
+       case <-li.unsetc:
+       }
+}
+
+func (li *fakeNetListener) unlock() {
+       if li.closed || len(li.queue) > 0 {
+               li.setc <- struct{}{}
+       } else {
+               li.unsetc <- struct{}{}
+       }
+}
+
+func (li *fakeNetListener) connect() *fakeNetConn {
+       li.lock()
+       defer li.unlock()
+       c0, c1 := fakeNetPipe()
+       li.queue = append(li.queue, c0)
+       return c1
+}
+
+func (li *fakeNetListener) Accept() (net.Conn, error) {
+       <-li.setc
+       defer li.unlock()
+       if li.closed {
+               return nil, net.ErrClosed
+       }
+       c := li.queue[0]
+       li.queue = li.queue[1:]
+       return c, nil
+}
+
+func (li *fakeNetListener) Close() error {
+       li.lock()
+       defer li.unlock()
+       li.closed = true
+       return nil
+}
+
+func (li *fakeNetListener) Addr() net.Addr {
+       return li.addr
+}
+
+// fakeNetPipe creates an in-memory, full duplex network connection.
+//
+// Unlike net.Pipe, the connection is not synchronous.
+// Writes are made to a buffer, and return immediately.
+// By default, the buffer size is unlimited.
+func fakeNetPipe() (r, w *fakeNetConn) {
+       s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000"))
+       s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001"))
+       s1 := newSynctestNetConnHalf(s1addr)
+       s2 := newSynctestNetConnHalf(s2addr)
+       return &fakeNetConn{loc: s1, rem: s2},
+               &fakeNetConn{loc: s2, rem: s1}
+}
+
+// A fakeNetConn is one endpoint of the connection created by fakeNetPipe.
+type fakeNetConn struct {
+       // local and remote connection halves.
+       // Each half contains a buffer.
+       // Reads pull from the local buffer, and writes push to the remote buffer.
+       loc, rem *fakeNetConnHalf
+
+       // When set, synctest.Wait is automatically called before reads and after writes.
+       autoWait bool
+}
+
+// Read reads data from the connection.
+func (c *fakeNetConn) Read(b []byte) (n int, err error) {
+       if c.autoWait {
+               synctest.Wait()
+       }
+       return c.loc.read(b)
+}
+
+// Peek returns the available unread read buffer,
+// without consuming its contents.
+func (c *fakeNetConn) Peek() []byte {
+       if c.autoWait {
+               synctest.Wait()
+       }
+       return c.loc.peek()
+}
+
+// Write writes data to the connection.
+func (c *fakeNetConn) Write(b []byte) (n int, err error) {
+       if c.autoWait {
+               defer synctest.Wait()
+       }
+       return c.rem.write(b)
+}
+
+// IsClosed reports whether the peer has closed its end of the connection.
+func (c *fakeNetConn) IsClosedByPeer() bool {
+       if c.autoWait {
+               synctest.Wait()
+       }
+       c.rem.lock()
+       defer c.rem.unlock()
+       // If the remote half of the conn is returning ErrClosed,
+       // the peer has closed the connection.
+       return c.rem.readErr == net.ErrClosed
+}
+
+// Close closes the connection.
+func (c *fakeNetConn) Close() error {
+       // Local half of the conn is now closed.
+       c.loc.lock()
+       c.loc.writeErr = net.ErrClosed
+       c.loc.readErr = net.ErrClosed
+       c.loc.buf.Reset()
+       c.loc.unlock()
+       // Remote half of the connection reads EOF after reading any remaining data.
+       c.rem.lock()
+       if c.rem.readErr != nil {
+               c.rem.readErr = io.EOF
+       }
+       c.rem.unlock()
+       if c.autoWait {
+               synctest.Wait()
+       }
+       return nil
+}
+
+// LocalAddr returns the (fake) local network address.
+func (c *fakeNetConn) LocalAddr() net.Addr {
+       return c.loc.addr
+}
+
+// LocalAddr returns the (fake) remote network address.
+func (c *fakeNetConn) RemoteAddr() net.Addr {
+       return c.rem.addr
+}
+
+// SetDeadline sets the read and write deadlines for the connection.
+func (c *fakeNetConn) SetDeadline(t time.Time) error {
+       c.SetReadDeadline(t)
+       c.SetWriteDeadline(t)
+       return nil
+}
+
+// SetReadDeadline sets the read deadline for the connection.
+func (c *fakeNetConn) SetReadDeadline(t time.Time) error {
+       c.loc.rctx.setDeadline(t)
+       return nil
+}
+
+// SetWriteDeadline sets the write deadline for the connection.
+func (c *fakeNetConn) SetWriteDeadline(t time.Time) error {
+       c.rem.wctx.setDeadline(t)
+       return nil
+}
+
+// SetReadBufferSize sets the read buffer limit for the connection.
+// Writes by the peer will block so long as the buffer is full.
+func (c *fakeNetConn) SetReadBufferSize(size int) {
+       c.loc.setReadBufferSize(size)
+}
+
+// fakeNetConnHalf is one data flow in the connection created by fakeNetPipe.
+// Each half contains a buffer. Writes to the half push to the buffer, and reads pull from it.
+type fakeNetConnHalf struct {
+       addr net.Addr
+
+       // Read and write timeouts.
+       rctx, wctx deadlineContext
+
+       // A half can be readable and/or writable.
+       //
+       // These four channels act as a lock,
+       // and allow waiting for readability/writability.
+       // When the half is unlocked, exactly one channel contains a value.
+       // When the half is locked, all channels are empty.
+       lockr  chan struct{} // readable
+       lockw  chan struct{} // writable
+       lockrw chan struct{} // readable and writable
+       lockc  chan struct{} // neither readable nor writable
+
+       bufMax   int // maximum buffer size
+       buf      bytes.Buffer
+       readErr  error // error returned by reads
+       writeErr error // error returned by writes
+}
+
+func newSynctestNetConnHalf(addr net.Addr) *fakeNetConnHalf {
+       h := &fakeNetConnHalf{
+               addr:   addr,
+               lockw:  make(chan struct{}, 1),
+               lockr:  make(chan struct{}, 1),
+               lockrw: make(chan struct{}, 1),
+               lockc:  make(chan struct{}, 1),
+               bufMax: math.MaxInt, // unlimited
+       }
+       h.unlock()
+       return h
+}
+
+// lock locks h.
+func (h *fakeNetConnHalf) lock() {
+       select {
+       case <-h.lockw: // writable
+       case <-h.lockr: // readable
+       case <-h.lockrw: // readable and writable
+       case <-h.lockc: // neither readable nor writable
+       }
+}
+
+// h unlocks h.
+func (h *fakeNetConnHalf) unlock() {
+       canRead := h.readErr != nil || h.buf.Len() > 0
+       canWrite := h.writeErr != nil || h.bufMax > h.buf.Len()
+       switch {
+       case canRead && canWrite:
+               h.lockrw <- struct{}{} // readable and writable
+       case canRead:
+               h.lockr <- struct{}{} // readable
+       case canWrite:
+               h.lockw <- struct{}{} // writable
+       default:
+               h.lockc <- struct{}{} // neither readable nor writable
+       }
+}
+
+// waitAndLockForRead waits until h is readable and locks it.
+func (h *fakeNetConnHalf) waitAndLockForRead() error {
+       // First a non-blocking select to see if we can make immediate progress.
+       // This permits using a canceled context for a non-blocking operation.
+       select {
+       case <-h.lockr:
+               return nil // readable
+       case <-h.lockrw:
+               return nil // readable and writable
+       default:
+       }
+       ctx := h.rctx.context()
+       select {
+       case <-h.lockr:
+               return nil // readable
+       case <-h.lockrw:
+               return nil // readable and writable
+       case <-ctx.Done():
+               return context.Cause(ctx)
+       }
+}
+
+// waitAndLockForWrite waits until h is writable and locks it.
+func (h *fakeNetConnHalf) waitAndLockForWrite() error {
+       // First a non-blocking select to see if we can make immediate progress.
+       // This permits using a canceled context for a non-blocking operation.
+       select {
+       case <-h.lockw:
+               return nil // writable
+       case <-h.lockrw:
+               return nil // readable and writable
+       default:
+       }
+       ctx := h.wctx.context()
+       select {
+       case <-h.lockw:
+               return nil // writable
+       case <-h.lockrw:
+               return nil // readable and writable
+       case <-ctx.Done():
+               return context.Cause(ctx)
+       }
+}
+
+func (h *fakeNetConnHalf) peek() []byte {
+       h.lock()
+       defer h.unlock()
+       return h.buf.Bytes()
+}
+
+func (h *fakeNetConnHalf) read(b []byte) (n int, err error) {
+       if err := h.waitAndLockForRead(); err != nil {
+               return 0, err
+       }
+       defer h.unlock()
+       if h.buf.Len() == 0 && h.readErr != nil {
+               return 0, h.readErr
+       }
+       return h.buf.Read(b)
+}
+
+func (h *fakeNetConnHalf) setReadBufferSize(size int) {
+       h.lock()
+       defer h.unlock()
+       h.bufMax = size
+}
+
+func (h *fakeNetConnHalf) write(b []byte) (n int, err error) {
+       for n < len(b) {
+               nn, err := h.writePartial(b[n:])
+               n += nn
+               if err != nil {
+                       return n, err
+               }
+       }
+       return n, nil
+}
+
+func (h *fakeNetConnHalf) writePartial(b []byte) (n int, err error) {
+       if err := h.waitAndLockForWrite(); err != nil {
+               return 0, err
+       }
+       defer h.unlock()
+       if h.writeErr != nil {
+               return 0, h.writeErr
+       }
+       writeMax := h.bufMax - h.buf.Len()
+       if writeMax < len(b) {
+               b = b[:writeMax]
+       }
+       return h.buf.Write(b)
+}
+
+// deadlineContext converts a changable deadline (as in net.Conn.SetDeadline) into a Context.
+type deadlineContext struct {
+       mu     sync.Mutex
+       ctx    context.Context
+       cancel context.CancelCauseFunc
+       timer  *time.Timer
+}
+
+// context returns a Context which expires when the deadline does.
+func (t *deadlineContext) context() context.Context {
+       t.mu.Lock()
+       defer t.mu.Unlock()
+       if t.ctx == nil {
+               t.ctx, t.cancel = context.WithCancelCause(context.Background())
+       }
+       return t.ctx
+}
+
+// setDeadline sets the current deadline.
+func (t *deadlineContext) setDeadline(deadline time.Time) {
+       t.mu.Lock()
+       defer t.mu.Unlock()
+       // If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled
+       // and we should create a new one.
+       if t.ctx == nil || t.cancel == nil {
+               t.ctx, t.cancel = context.WithCancelCause(context.Background())
+       }
+       // Stop any existing deadline from expiring.
+       if t.timer != nil {
+               t.timer.Stop()
+       }
+       if deadline.IsZero() {
+               // No deadline.
+               return
+       }
+       now := time.Now()
+       if !deadline.After(now) {
+               // Deadline has already expired.
+               t.cancel(os.ErrDeadlineExceeded)
+               t.cancel = nil
+               return
+       }
+       if t.timer != nil {
+               // Reuse existing deadline timer.
+               t.timer.Reset(deadline.Sub(now))
+               return
+       }
+       // Create a new timer to cancel the context at the deadline.
+       t.timer = time.AfterFunc(deadline.Sub(now), func() {
+               t.mu.Lock()
+               defer t.mu.Unlock()
+               t.cancel(os.ErrDeadlineExceeded)
+               t.cancel = nil
+       })
+}
index 4d71eb0498f5fdbb191a86e23a77cb63d06c5ff0..0c46b1ecc3e0b186e0a31685e37d2ddc6b382ce4 100644 (file)
@@ -16,6 +16,7 @@ import (
        "encoding/json"
        "errors"
        "fmt"
+       "internal/synctest"
        "internal/testenv"
        "io"
        "log"
@@ -5805,70 +5806,59 @@ func testServerShutdown(t *testing.T, mode testMode) {
        }
 }
 
-func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) }
-func testServerShutdownStateNew(t *testing.T, mode testMode) {
+func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
+func testServerShutdownStateNew(t testing.TB, mode testMode) {
        if testing.Short() {
                t.Skip("test takes 5-6 seconds; skipping in short mode")
        }
 
-       var connAccepted sync.WaitGroup
+       listener := fakeNetListen()
+       defer listener.Close()
+
        ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                // nothing.
        }), func(ts *httptest.Server) {
-               ts.Config.ConnState = func(conn net.Conn, state ConnState) {
-                       if state == StateNew {
-                               connAccepted.Done()
-                       }
-               }
+               ts.Listener.Close()
+               ts.Listener = listener
+               // Ignore irrelevant error about TLS handshake failure.
+               ts.Config.ErrorLog = log.New(io.Discard, "", 0)
        }).ts
 
        // Start a connection but never write to it.
-       connAccepted.Add(1)
-       c, err := net.Dial("tcp", ts.Listener.Addr().String())
-       if err != nil {
-               t.Fatal(err)
-       }
+       c := listener.connect()
        defer c.Close()
+       synctest.Wait()
 
-       // Wait for the connection to be accepted by the server. Otherwise, if
-       // Shutdown happens to run first, the server will be closed when
-       // encountering the connection, in which case it will be rejected
-       // immediately.
-       connAccepted.Wait()
-
-       shutdownRes := make(chan error, 1)
-       go func() {
-               shutdownRes <- ts.Config.Shutdown(context.Background())
-       }()
-       readRes := make(chan error, 1)
-       go func() {
-               _, err := c.Read([]byte{0})
-               readRes <- err
-       }()
+       shutdownRes := runAsync(func() (struct{}, error) {
+               return struct{}{}, ts.Config.Shutdown(context.Background())
+       })
 
        // TODO(#59037): This timeout is hard-coded in closeIdleConnections.
        // It is undocumented, and some users may find it surprising.
        // Either document it, or switch to a less surprising behavior.
        const expectTimeout = 5 * time.Second
 
-       t0 := time.Now()
-       select {
-       case got := <-shutdownRes:
-               d := time.Since(t0)
-               if got != nil {
-                       t.Fatalf("shutdown error after %v: %v", d, err)
-               }
-               if d < expectTimeout/2 {
-                       t.Errorf("shutdown too soon after %v", d)
-               }
-       case <-time.After(expectTimeout * 3 / 2):
-               t.Fatalf("timeout waiting for shutdown")
+       // Wait until just before the expected timeout.
+       time.Sleep(expectTimeout - 1)
+       synctest.Wait()
+       if shutdownRes.done() {
+               t.Fatal("shutdown too soon")
+       }
+       if c.IsClosedByPeer() {
+               t.Fatal("connection was closed by server too soon")
        }
 
-       // Wait for c.Read to unblock; should be already done at this point,
-       // or within a few milliseconds.
-       if err := <-readRes; err == nil {
-               t.Error("expected error from Read")
+       // closeIdleConnections isn't precise about its actual shutdown time.
+       // Wait long enough for it to definitely have shut down.
+       //
+       // (It would be good to make closeIdleConnections less sloppy.)
+       time.Sleep(2 * time.Second)
+       synctest.Wait()
+       if _, err := shutdownRes.result(); err != nil {
+               t.Fatalf("Shutdown() = %v, want complete", err)
+       }
+       if !c.IsClosedByPeer() {
+               t.Fatalf("connection was not closed by server after shutdown")
        }
 }