--- /dev/null
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+ "bytes"
+ "context"
+ "internal/synctest"
+ "io"
+ "math"
+ "net"
+ "net/netip"
+ "os"
+ "sync"
+ "time"
+)
+
+func fakeNetListen() *fakeNetListener {
+ li := &fakeNetListener{
+ setc: make(chan struct{}, 1),
+ unsetc: make(chan struct{}, 1),
+ addr: net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")),
+ }
+ li.unsetc <- struct{}{}
+ return li
+}
+
+type fakeNetListener struct {
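+ // setc and unsetc act as a lock, and record the listener's state:
+ // "set" when it is closed or has queued connections, "unset" otherwise.
+ // When the listener is unlocked, exactly one channel holds a value.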
+ setc, unsetc chan struct{}
+ queue []net.Conn
+ closed bool
+ addr net.Addr
+}
+
+func (li *fakeNetListener) lock() {
+ select {
+ case <-li.setc:
+ case <-li.unsetc:
+ }
+}
+
+func (li *fakeNetListener) unlock() {
+ if li.closed || len(li.queue) > 0 {
+ li.setc <- struct{}{}
+ } else {
+ li.unsetc <- struct{}{}
+ }
+}
+
+func (li *fakeNetListener) connect() *fakeNetConn {
+ li.lock()
+ defer li.unlock()
+ c0, c1 := fakeNetPipe()
+ li.queue = append(li.queue, c0)
+ return c1
+}
+
+func (li *fakeNetListener) Accept() (net.Conn, error) {
+ <-li.setc
+ defer li.unlock()
+ if li.closed {
+ return nil, net.ErrClosed
+ }
+ c := li.queue[0]
+ li.queue = li.queue[1:]
+ return c, nil
+}
+
+func (li *fakeNetListener) Close() error {
+ li.lock()
+ defer li.unlock()
+ li.closed = true
+ return nil
+}
+
+func (li *fakeNetListener) Addr() net.Addr {
+ return li.addr
+}
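+
+// A usage sketch (illustrative only): the server side accepts from the
+// fake listener while the test dials in-memory; no real sockets are involved.
+//
+// li := fakeNetListen()
+// go srv.Serve(li) // srv is assumed to be an *http.Server
+// c := li.connect() // client end of a connection queued for Accept
+// defer c.Close()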
+
+// fakeNetPipe creates an in-memory, full duplex network connection.
+//
+// Unlike net.Pipe, the connection is not synchronous.
+// Writes are made to a buffer, and return immediately.
+// By default, the buffer size is unlimited.
+func fakeNetPipe() (r, w *fakeNetConn) {
+ s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000"))
+ s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001"))
+ s1 := newSynctestNetConnHalf(s1addr)
+ s2 := newSynctestNetConnHalf(s2addr)
+ return &fakeNetConn{loc: s1, rem: s2},
+ &fakeNetConn{loc: s2, rem: s1}
+}
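+
+// For example (a sketch; writes are buffered, so no concurrent reader is
+// needed):
+//
+// c1, c2 := fakeNetPipe()
+// c1.Write([]byte("hello")) // returns immediately
+// buf := make([]byte, 5)
+// n, _ := c2.Read(buf) // n == 5, buf holds "hello"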
+
+// A fakeNetConn is one endpoint of the connection created by fakeNetPipe.
+type fakeNetConn struct {
+ // local and remote connection halves.
+ // Each half contains a buffer.
+ // Reads pull from the local buffer, and writes push to the remote buffer.
+ loc, rem *fakeNetConnHalf
+
+ // When set, synctest.Wait is automatically called before reads and after writes.
+ autoWait bool
+}
+
+// Read reads data from the connection.
+func (c *fakeNetConn) Read(b []byte) (n int, err error) {
+ if c.autoWait {
+ synctest.Wait()
+ }
+ return c.loc.read(b)
+}
+
+// Peek returns the unread contents of the read buffer,
+// without consuming them.
+func (c *fakeNetConn) Peek() []byte {
+ if c.autoWait {
+ synctest.Wait()
+ }
+ return c.loc.peek()
+}
+
+// Write writes data to the connection.
+func (c *fakeNetConn) Write(b []byte) (n int, err error) {
+ if c.autoWait {
+ defer synctest.Wait()
+ }
+ return c.rem.write(b)
+}
+
+// IsClosedByPeer reports whether the peer has closed its end of the connection.
+func (c *fakeNetConn) IsClosedByPeer() bool {
+ if c.autoWait {
+ synctest.Wait()
+ }
+ c.rem.lock()
+ defer c.rem.unlock()
+ // If the remote half of the conn is returning ErrClosed,
+ // the peer has closed the connection.
+ return c.rem.readErr == net.ErrClosed
+}
+
+// Close closes the connection.
+func (c *fakeNetConn) Close() error {
+ // Local half of the conn is now closed.
+ c.loc.lock()
+ c.loc.writeErr = net.ErrClosed
+ c.loc.readErr = net.ErrClosed
+ c.loc.buf.Reset()
+ c.loc.unlock()
+ // Remote half of the connection reads EOF after reading any remaining data.
+ c.rem.lock()
+ if c.rem.readErr == nil {
+ c.rem.readErr = io.EOF
+ }
+ c.rem.unlock()
+ if c.autoWait {
+ synctest.Wait()
+ }
+ return nil
+}
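+
+// Note that after Close, reads and writes on this end fail immediately with
+// net.ErrClosed, while the peer may still drain buffered data before reading
+// io.EOF.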
+
+// LocalAddr returns the (fake) local network address.
+func (c *fakeNetConn) LocalAddr() net.Addr {
+ return c.loc.addr
+}
+
+// RemoteAddr returns the (fake) remote network address.
+func (c *fakeNetConn) RemoteAddr() net.Addr {
+ return c.rem.addr
+}
+
+// SetDeadline sets the read and write deadlines for the connection.
+func (c *fakeNetConn) SetDeadline(t time.Time) error {
+ c.SetReadDeadline(t)
+ c.SetWriteDeadline(t)
+ return nil
+}
+
+// SetReadDeadline sets the read deadline for the connection.
+func (c *fakeNetConn) SetReadDeadline(t time.Time) error {
+ c.loc.rctx.setDeadline(t)
+ return nil
+}
+
+// SetWriteDeadline sets the write deadline for the connection.
+func (c *fakeNetConn) SetWriteDeadline(t time.Time) error {
+ c.rem.wctx.setDeadline(t)
+ return nil
+}
+
+// SetReadBufferSize sets the read buffer limit for the connection.
+// Writes by the peer will block so long as the buffer is full.
+func (c *fakeNetConn) SetReadBufferSize(size int) {
+ c.loc.setReadBufferSize(size)
+}
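+
+// For example (sketch): with a 4-byte limit, a larger write by the peer
+// blocks until the reader frees space.
+//
+// c1, c2 := fakeNetPipe()
+// c1.SetReadBufferSize(4)
+// go c2.Write(make([]byte, 8)) // buffers 4 bytes, then blocks
+// c1.Read(make([]byte, 4)) // frees space; the write completes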
+
+// fakeNetConnHalf is one data flow in the connection created by fakeNetPipe.
+// Each half contains a buffer. Writes to the half push to the buffer, and reads pull from it.
+type fakeNetConnHalf struct {
+ addr net.Addr
+
+ // Read and write timeouts.
+ rctx, wctx deadlineContext
+
+ // A half can be readable and/or writable.
+ //
+ // These four channels act as a lock,
+ // and allow waiting for readability/writability.
+ // When the half is unlocked, exactly one channel contains a value.
+ // When the half is locked, all channels are empty.
+ lockr chan struct{} // readable
+ lockw chan struct{} // writable
+ lockrw chan struct{} // readable and writable
+ lockc chan struct{} // neither readable nor writable
+
+ bufMax int // maximum buffer size
+ buf bytes.Buffer
+ readErr error // error returned by reads
+ writeErr error // error returned by writes
+}
+
+func newSynctestNetConnHalf(addr net.Addr) *fakeNetConnHalf {
+ h := &fakeNetConnHalf{
+ addr: addr,
+ lockw: make(chan struct{}, 1),
+ lockr: make(chan struct{}, 1),
+ lockrw: make(chan struct{}, 1),
+ lockc: make(chan struct{}, 1),
+ bufMax: math.MaxInt, // unlimited
+ }
+ h.unlock()
+ return h
+}
+
+// lock locks h.
+func (h *fakeNetConnHalf) lock() {
+ select {
+ case <-h.lockw: // writable
+ case <-h.lockr: // readable
+ case <-h.lockrw: // readable and writable
+ case <-h.lockc: // neither readable nor writable
+ }
+}
+
+// unlock unlocks h.
+func (h *fakeNetConnHalf) unlock() {
+ canRead := h.readErr != nil || h.buf.Len() > 0
+ canWrite := h.writeErr != nil || h.bufMax > h.buf.Len()
+ switch {
+ case canRead && canWrite:
+ h.lockrw <- struct{}{} // readable and writable
+ case canRead:
+ h.lockr <- struct{}{} // readable
+ case canWrite:
+ h.lockw <- struct{}{} // writable
+ default:
+ h.lockc <- struct{}{} // neither readable nor writable
+ }
+}
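+
+// For example, a half with buffered data and spare buffer capacity parks its
+// token in lockrw; lock consumes whichever token is present, and unlock
+// re-parks one reflecting the half's new state.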
+
+// waitAndLockForRead waits until h is readable and locks it.
+func (h *fakeNetConnHalf) waitAndLockForRead() error {
+ // First a non-blocking select to see if we can make immediate progress.
+ // This permits using a canceled context for a non-blocking operation.
+ select {
+ case <-h.lockr:
+ return nil // readable
+ case <-h.lockrw:
+ return nil // readable and writable
+ default:
+ }
+ ctx := h.rctx.context()
+ select {
+ case <-h.lockr:
+ return nil // readable
+ case <-h.lockrw:
+ return nil // readable and writable
+ case <-ctx.Done():
+ return context.Cause(ctx)
+ }
+}
+
+// waitAndLockForWrite waits until h is writable and locks it.
+func (h *fakeNetConnHalf) waitAndLockForWrite() error {
+ // First a non-blocking select to see if we can make immediate progress.
+ // This permits using a canceled context for a non-blocking operation.
+ select {
+ case <-h.lockw:
+ return nil // writable
+ case <-h.lockrw:
+ return nil // readable and writable
+ default:
+ }
+ ctx := h.wctx.context()
+ select {
+ case <-h.lockw:
+ return nil // writable
+ case <-h.lockrw:
+ return nil // readable and writable
+ case <-ctx.Done():
+ return context.Cause(ctx)
+ }
+}
+
+func (h *fakeNetConnHalf) peek() []byte {
+ h.lock()
+ defer h.unlock()
+ return h.buf.Bytes()
+}
+
+func (h *fakeNetConnHalf) read(b []byte) (n int, err error) {
+ if err := h.waitAndLockForRead(); err != nil {
+ return 0, err
+ }
+ defer h.unlock()
+ if h.buf.Len() == 0 && h.readErr != nil {
+ return 0, h.readErr
+ }
+ return h.buf.Read(b)
+}
+
+func (h *fakeNetConnHalf) setReadBufferSize(size int) {
+ h.lock()
+ defer h.unlock()
+ h.bufMax = size
+}
+
+func (h *fakeNetConnHalf) write(b []byte) (n int, err error) {
+ for n < len(b) {
+ nn, err := h.writePartial(b[n:])
+ n += nn
+ if err != nil {
+ return n, err
+ }
+ }
+ return n, nil
+}
+
+func (h *fakeNetConnHalf) writePartial(b []byte) (n int, err error) {
+ if err := h.waitAndLockForWrite(); err != nil {
+ return 0, err
+ }
+ defer h.unlock()
+ if h.writeErr != nil {
+ return 0, h.writeErr
+ }
+ writeMax := h.bufMax - h.buf.Len()
+ if writeMax < len(b) {
+ b = b[:writeMax]
+ }
+ return h.buf.Write(b)
+}
+
+// deadlineContext converts a changeable deadline (as in net.Conn.SetDeadline) into a Context.
+type deadlineContext struct {
+ mu sync.Mutex
+ ctx context.Context
+ cancel context.CancelCauseFunc
+ timer *time.Timer
+}
+
+// context returns a Context which expires when the deadline does.
+func (t *deadlineContext) context() context.Context {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ if t.ctx == nil {
+ t.ctx, t.cancel = context.WithCancelCause(context.Background())
+ }
+ return t.ctx
+}
+
+// setDeadline sets the current deadline.
+func (t *deadlineContext) setDeadline(deadline time.Time) {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ // If t.ctx is non-nil and t.cancel is nil, then t.ctx was canceled
+ // and we should create a new one.
+ if t.ctx == nil || t.cancel == nil {
+ t.ctx, t.cancel = context.WithCancelCause(context.Background())
+ }
+ // Stop any existing deadline from expiring.
+ if t.timer != nil {
+ t.timer.Stop()
+ }
+ if deadline.IsZero() {
+ // No deadline.
+ return
+ }
+ now := time.Now()
+ if !deadline.After(now) {
+ // Deadline has already expired.
+ t.cancel(os.ErrDeadlineExceeded)
+ t.cancel = nil
+ return
+ }
+ if t.timer != nil {
+ // Reuse existing deadline timer.
+ t.timer.Reset(deadline.Sub(now))
+ return
+ }
+ // Create a new timer to cancel the context at the deadline.
+ t.timer = time.AfterFunc(deadline.Sub(now), func() {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.cancel(os.ErrDeadlineExceeded)
+ t.cancel = nil
+ })
+}
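+
+// A usage sketch (illustrative only):
+//
+// var dc deadlineContext
+// dc.setDeadline(time.Now().Add(time.Second))
+// <-dc.context().Done() // fires when the deadline expires
+// err := context.Cause(dc.context()) // err is os.ErrDeadlineExceeded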
"encoding/json"
"errors"
"fmt"
+ "internal/synctest"
"internal/testenv"
"io"
"log"
}
}
-func TestServerShutdownStateNew(t *testing.T) { run(t, testServerShutdownStateNew) }
-func testServerShutdownStateNew(t *testing.T, mode testMode) {
+func TestServerShutdownStateNew(t *testing.T) { runSynctest(t, testServerShutdownStateNew) }
+func testServerShutdownStateNew(t testing.TB, mode testMode) {
- if testing.Short() {
- t.Skip("test takes 5-6 seconds; skipping in short mode")
- }
- var connAccepted sync.WaitGroup
+ listener := fakeNetListen()
+ defer listener.Close()
+
ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
// nothing.
}), func(ts *httptest.Server) {
- ts.Config.ConnState = func(conn net.Conn, state ConnState) {
- if state == StateNew {
- connAccepted.Done()
- }
- }
+ ts.Listener.Close()
+ ts.Listener = listener
+ // Ignore irrelevant error about TLS handshake failure.
+ ts.Config.ErrorLog = log.New(io.Discard, "", 0)
}).ts
// Start a connection but never write to it.
- connAccepted.Add(1)
- c, err := net.Dial("tcp", ts.Listener.Addr().String())
- if err != nil {
- t.Fatal(err)
- }
+ c := listener.connect()
defer c.Close()
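+ // Wait for the server to accept the connection before starting Shutdown.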
+ synctest.Wait()
- // Wait for the connection to be accepted by the server. Otherwise, if
- // Shutdown happens to run first, the server will be closed when
- // encountering the connection, in which case it will be rejected
- // immediately.
- connAccepted.Wait()
-
- shutdownRes := make(chan error, 1)
- go func() {
- shutdownRes <- ts.Config.Shutdown(context.Background())
- }()
- readRes := make(chan error, 1)
- go func() {
- _, err := c.Read([]byte{0})
- readRes <- err
- }()
+ shutdownRes := runAsync(func() (struct{}, error) {
+ return struct{}{}, ts.Config.Shutdown(context.Background())
+ })
// TODO(#59037): This timeout is hard-coded in closeIdleConnections.
// It is undocumented, and some users may find it surprising.
// Either document it, or switch to a less surprising behavior.
const expectTimeout = 5 * time.Second
- t0 := time.Now()
- select {
- case got := <-shutdownRes:
- d := time.Since(t0)
- if got != nil {
- t.Fatalf("shutdown error after %v: %v", d, err)
- }
- if d < expectTimeout/2 {
- t.Errorf("shutdown too soon after %v", d)
- }
- case <-time.After(expectTimeout * 3 / 2):
- t.Fatalf("timeout waiting for shutdown")
+ // Wait until just before the expected timeout.
+ time.Sleep(expectTimeout - time.Nanosecond)
+ synctest.Wait()
+ if shutdownRes.done() {
+ t.Fatal("shutdown too soon")
+ }
+ if c.IsClosedByPeer() {
+ t.Fatal("connection was closed by server too soon")
}
- // Wait for c.Read to unblock; should be already done at this point,
- // or within a few milliseconds.
- if err := <-readRes; err == nil {
- t.Error("expected error from Read")
+ // closeIdleConnections isn't precise about its actual shutdown time.
+ // Wait long enough for it to definitely have shut down.
+ //
+ // (It would be good to make closeIdleConnections less sloppy.)
+ time.Sleep(2 * time.Second)
+ synctest.Wait()
+ if _, err := shutdownRes.result(); err != nil {
+ t.Fatalf("Shutdown() = %v, want complete", err)
+ }
+ if !c.IsClosedByPeer() {
+ t.Fatalf("connection was not closed by server after shutdown")
}
}