]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: test for racing idle conn closure and new requests
authorDamien Neil <dneil@google.com>
Mon, 25 Nov 2024 19:27:33 +0000 (11:27 -0800)
committerGopher Robot <gobot@golang.org>
Tue, 26 Nov 2024 18:05:09 +0000 (18:05 +0000)
TestTransportRemovesH2ConnsAfterIdle is experiencing flaky
failures due to a bug in idle connection handling.
Upon inspection, TestTransportRemovesH2ConnsAfterIdle
is slow and (I think) not currently testing the condition
that it was added to test.

Using the new synctest package, this CL:

- Adds a test for the failure causing flakes in this test.
- Rewrites the existing test to use synctest to avoid sleeps.
- Adds a new test that covers the condition the test was
  intended to examine.

The new TestTransportIdleConnRacesRequest exercises the
scenario where a never-used connection is closed by the
idle-conn timer at the same time as a new request attempts
to use it. In this race, the new request should either
successfully use the old connection (superseding the
idle timer) or should use a new connection; it should not
use the closing connection and fail.

TestTransportRemovesConnsAfterIdle verifies that
a connection is reused before the idle timer expires,
and not reused after.

TestTransportRemovesConnsAfterBroken verifies
that a connection is not reused after it encounters
an error. This exercises the bug fixed in CL 196665,
which introduced TestTransportRemovesH2ConnsAfterIdle.

For #70515

Change-Id: Id23026d2903fb15ef9a831b2df71177ea177b096
Reviewed-on: https://go-review.googlesource.com/c/go/+/631795
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Damien Neil <dneil@google.com>

src/net/http/clientserver_test.go
src/net/http/netconn_test.go
src/net/http/transport_test.go

index 08730387578297e7fc3e7cbde90233fcdb93e5ef..32d97ea9f083ce3094d2edf52fc4d8d613001b18 100644 (file)
@@ -40,9 +40,10 @@ import (
 type testMode string
 
 const (
-       http1Mode  = testMode("h1")     // HTTP/1.1
-       https1Mode = testMode("https1") // HTTPS/1.1
-       http2Mode  = testMode("h2")     // HTTP/2
+       http1Mode            = testMode("h1")            // HTTP/1.1
+       https1Mode           = testMode("https1")        // HTTPS/1.1
+       http2Mode            = testMode("h2")            // HTTP/2
+       http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2
 )
 
 type testNotParallelOpt struct{}
@@ -132,6 +133,7 @@ type clientServerTest struct {
        ts *httptest.Server
        tr *Transport
        c  *Client
+       li *fakeNetListener
 }
 
 func (t *clientServerTest) close() {
@@ -169,6 +171,8 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
        }
 }
 
+var optFakeNet = new(struct{})
+
 // newClientServerTest creates and starts an httptest.Server.
 //
 // The mode parameter selects the implementation to test:
@@ -180,6 +184,9 @@ func optWithServerLog(lg *log.Logger) func(*httptest.Server) {
 //
 //     func(*httptest.Server) // run before starting the server
 //     func(*http.Transport)
+//
+// The optFakeNet option configures the server and client to use a fake network implementation,
+// suitable for use in testing/synctest tests.
 func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *clientServerTest {
        if mode == http2Mode {
                CondSkipHTTP2(t)
@@ -189,9 +196,31 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c
                h2: mode == http2Mode,
                h:  h,
        }
-       cst.ts = httptest.NewUnstartedServer(h)
 
        var transportFuncs []func(*Transport)
+
+       if idx := slices.Index(opts, any(optFakeNet)); idx >= 0 {
+               opts = slices.Delete(opts, idx, idx+1)
+               cst.li = fakeNetListen()
+               cst.ts = &httptest.Server{
+                       Config:   &Server{Handler: h},
+                       Listener: cst.li,
+               }
+               transportFuncs = append(transportFuncs, func(tr *Transport) {
+                       tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
+                               return cst.li.connect(), nil
+                       }
+               })
+       } else {
+               cst.ts = httptest.NewUnstartedServer(h)
+       }
+
+       if mode == http2UnencryptedMode {
+               p := &Protocols{}
+               p.SetUnencryptedHTTP2(true)
+               cst.ts.Config.Protocols = p
+       }
+
        for _, opt := range opts {
                switch opt := opt.(type) {
                case func(*Transport):
@@ -212,6 +241,9 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c
                cst.ts.Start()
        case https1Mode:
                cst.ts.StartTLS()
+       case http2UnencryptedMode:
+               ExportHttp2ConfigureServer(cst.ts.Config, nil)
+               cst.ts.Start()
        case http2Mode:
                ExportHttp2ConfigureServer(cst.ts.Config, nil)
                cst.ts.TLS = cst.ts.Config.TLSConfig
@@ -221,7 +253,7 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c
        }
        cst.c = cst.ts.Client()
        cst.tr = cst.c.Transport.(*Transport)
-       if mode == http2Mode {
+       if mode == http2Mode || mode == http2UnencryptedMode {
                if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
                        t.Fatal(err)
                }
@@ -229,6 +261,13 @@ func newClientServerTest(t testing.TB, mode testMode, h Handler, opts ...any) *c
        for _, f := range transportFuncs {
                f(cst.tr)
        }
+
+       if mode == http2UnencryptedMode {
+               p := &Protocols{}
+               p.SetUnencryptedHTTP2(true)
+               cst.tr.Protocols = p
+       }
+
        t.Cleanup(func() {
                cst.close()
        })
@@ -246,9 +285,19 @@ func (w testLogWriter) Write(b []byte) (int, error) {
 
 // Testing the newClientServerTest helper itself.
 func TestNewClientServerTest(t *testing.T) {
-       run(t, testNewClientServerTest, []testMode{http1Mode, https1Mode, http2Mode})
+       modes := []testMode{http1Mode, https1Mode, http2Mode}
+       t.Run("realnet", func(t *testing.T) {
+               run(t, func(t *testing.T, mode testMode) {
+                       testNewClientServerTest(t, mode)
+               }, modes)
+       })
+       t.Run("synctest", func(t *testing.T) {
+               runSynctest(t, func(t testing.TB, mode testMode) {
+                       testNewClientServerTest(t, mode, optFakeNet)
+               }, modes)
+       })
 }
-func testNewClientServerTest(t *testing.T, mode testMode) {
+func testNewClientServerTest(t testing.TB, mode testMode, opts ...any) {
        var got struct {
                sync.Mutex
                proto  string
@@ -260,7 +309,7 @@ func testNewClientServerTest(t *testing.T, mode testMode) {
                got.proto = r.Proto
                got.hasTLS = r.TLS != nil
        })
-       cst := newClientServerTest(t, mode, h)
+       cst := newClientServerTest(t, mode, h, opts...)
        if _, err := cst.c.Head(cst.ts.URL); err != nil {
                t.Fatal(err)
        }
index 251b919f6724fbe1cbd77b98a200048088b8e13e..ed02b98d43ee9ddfc8188eae5f2074ebc9f6f242 100644 (file)
@@ -19,9 +19,10 @@ import (
 
 func fakeNetListen() *fakeNetListener {
        li := &fakeNetListener{
-               setc:   make(chan struct{}, 1),
-               unsetc: make(chan struct{}, 1),
-               addr:   net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000")),
+               setc:    make(chan struct{}, 1),
+               unsetc:  make(chan struct{}, 1),
+               addr:    netip.MustParseAddrPort("127.0.0.1:8000"),
+               locPort: 10000,
        }
        li.unsetc <- struct{}{}
        return li
@@ -31,7 +32,13 @@ type fakeNetListener struct {
        setc, unsetc chan struct{}
        queue        []net.Conn
        closed       bool
-       addr         net.Addr
+       addr         netip.AddrPort
+       locPort      uint16
+
+       onDial func() // called when making a new connection
+
+       trackConns bool // set this to record all created conns
+       conns      []*fakeNetConn
 }
 
 func (li *fakeNetListener) lock() {
@@ -50,10 +57,18 @@ func (li *fakeNetListener) unlock() {
 }
 
 func (li *fakeNetListener) connect() *fakeNetConn {
+       if li.onDial != nil {
+               li.onDial()
+       }
        li.lock()
        defer li.unlock()
-       c0, c1 := fakeNetPipe()
+       locAddr := netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), li.locPort)
+       li.locPort++
+       c0, c1 := fakeNetPipe(li.addr, locAddr)
        li.queue = append(li.queue, c0)
+       if li.trackConns {
+               li.conns = append(li.conns, c0)
+       }
        return c1
 }
 
@@ -76,7 +91,7 @@ func (li *fakeNetListener) Close() error {
 }
 
 func (li *fakeNetListener) Addr() net.Addr {
-       return li.addr
+       return net.TCPAddrFromAddrPort(li.addr)
 }
 
 // fakeNetPipe creates an in-memory, full duplex network connection.
@@ -84,13 +99,16 @@ func (li *fakeNetListener) Addr() net.Addr {
 // Unlike net.Pipe, the connection is not synchronous.
 // Writes are made to a buffer, and return immediately.
 // By default, the buffer size is unlimited.
-func fakeNetPipe() (r, w *fakeNetConn) {
-       s1addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8000"))
-       s2addr := net.TCPAddrFromAddrPort(netip.MustParseAddrPort("127.0.0.1:8001"))
+func fakeNetPipe(s1ap, s2ap netip.AddrPort) (r, w *fakeNetConn) {
+       s1addr := net.TCPAddrFromAddrPort(s1ap)
+       s2addr := net.TCPAddrFromAddrPort(s2ap)
        s1 := newSynctestNetConnHalf(s1addr)
        s2 := newSynctestNetConnHalf(s2addr)
-       return &fakeNetConn{loc: s1, rem: s2},
-               &fakeNetConn{loc: s2, rem: s1}
+       c1 := &fakeNetConn{loc: s1, rem: s2}
+       c2 := &fakeNetConn{loc: s2, rem: s1}
+       c1.peer = c2
+       c2.peer = c1
+       return c1, c2
 }
 
 // A fakeNetConn is one endpoint of the connection created by fakeNetPipe.
@@ -102,6 +120,11 @@ type fakeNetConn struct {
 
        // When set, synctest.Wait is automatically called before reads and after writes.
        autoWait bool
+
+       // peer is the other endpoint.
+       peer *fakeNetConn
+
+       onClose func() // called when closing
 }
 
 // Read reads data from the connection.
@@ -143,6 +166,9 @@ func (c *fakeNetConn) IsClosedByPeer() bool {
 
 // Close closes the connection.
 func (c *fakeNetConn) Close() error {
+       if c.onClose != nil {
+               c.onClose()
+       }
        // Local half of the conn is now closed.
        c.loc.lock()
        c.loc.writeErr = net.ErrClosed
index d742b78cf81b541a03b61e0bb534a8b4098a526e..2963255b874378ca031b27ba9ac37deecdb5a1f9 100644 (file)
@@ -22,6 +22,7 @@ import (
        "fmt"
        "go/token"
        "internal/nettrace"
+       "internal/synctest"
        "io"
        "log"
        mrand "math/rand"
@@ -4219,53 +4220,175 @@ func TestTransportTraceGotConnH2IdleConns(t *testing.T) {
        wantIdle("after round trip", 1)
 }
 
-func TestTransportRemovesH2ConnsAfterIdle(t *testing.T) {
-       run(t, testTransportRemovesH2ConnsAfterIdle, []testMode{http2Mode})
+// https://go.dev/issue/70515
+//
+// When the first request on a new connection fails, we do not retry the request.
+// If the first request on a connection races with IdleConnTimeout,
+// we should not fail the request.
+func TestTransportIdleConnRacesRequest(t *testing.T) {
+       // Use unencrypted HTTP/2, since the *tls.Conn interfers with our ability to
+       // block the connection closing.
+       runSynctest(t, testTransportIdleConnRacesRequest, []testMode{http1Mode, http2UnencryptedMode})
+}
+func testTransportIdleConnRacesRequest(t testing.TB, mode testMode) {
+       if mode == http2UnencryptedMode {
+               t.Skip("remove skip when #70515 is fixed")
+       }
+       timeout := 1 * time.Millisecond
+       trFunc := func(tr *Transport) {
+               tr.IdleConnTimeout = timeout
+       }
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+       }), trFunc, optFakeNet)
+       cst.li.trackConns = true
+
+       // We want to put a connection into the pool which has never had a request made on it.
+       //
+       // Make a request and cancel it before the dial completes.
+       // Then complete the dial.
+       dialc := make(chan struct{})
+       cst.li.onDial = func() {
+               <-dialc
+       }
+       ctx, cancel := context.WithCancel(context.Background())
+       req1c := make(chan error)
+       go func() {
+               req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+               resp, err := cst.c.Do(req)
+               if err == nil {
+                       resp.Body.Close()
+               }
+               req1c <- err
+       }()
+       // Wait for the connection attempt to start.
+       synctest.Wait()
+       // Cancel the request.
+       cancel()
+       synctest.Wait()
+       if err := <-req1c; err == nil {
+               t.Fatal("expected request to fail, but it succeeded")
+       }
+       // Unblock the dial, placing a new, unused connection into the Transport's pool.
+       close(dialc)
+
+       // We want IdleConnTimeout to race with a new request.
+       //
+       // There's no perfect way to do this, but the following exercises the bug in #70515:
+       // Block net.Conn.Close, wait until IdleConnTimeout occurs, and make a request while
+       // the connection close is still blocked.
+       //
+       // First: Wait for IdleConnTimeout. The net.Conn.Close blocks.
+       synctest.Wait()
+       closec := make(chan struct{})
+       cst.li.conns[0].peer.onClose = func() {
+               <-closec
+       }
+       time.Sleep(timeout)
+       synctest.Wait()
+       // Make a request, which will use a new connection (since the existing one is closing).
+       req2c := make(chan error)
+       go func() {
+               resp, err := cst.c.Get(cst.ts.URL)
+               if err == nil {
+                       resp.Body.Close()
+               }
+               req2c <- err
+       }()
+       // Don't synctest.Wait here: The HTTP/1 transport closes the idle conn
+       // with a mutex held, and we'll end up in a deadlock.
+       close(closec)
+       if err := <-req2c; err != nil {
+               t.Fatalf("Get: %v", err)
+       }
+}
+
+func TestTransportRemovesConnsAfterIdle(t *testing.T) {
+       runSynctest(t, testTransportRemovesConnsAfterIdle)
 }
-func testTransportRemovesH2ConnsAfterIdle(t *testing.T, mode testMode) {
+func testTransportRemovesConnsAfterIdle(t testing.TB, mode testMode) {
        if testing.Short() {
                t.Skip("skipping in short mode")
        }
 
-       timeout := 1 * time.Millisecond
-       retry := true
-       for retry {
-               trFunc := func(tr *Transport) {
-                       tr.MaxConnsPerHost = 1
-                       tr.MaxIdleConnsPerHost = 1
-                       tr.IdleConnTimeout = timeout
-               }
-               cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {}), trFunc)
-
-               retry = false
-               tooShort := func(err error) bool {
-                       if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
-                               return false
-                       }
-                       if !retry {
-                               t.Helper()
-                               t.Logf("idle conn timeout %v may be too short; retrying with longer", timeout)
-                               timeout *= 2
-                               retry = true
-                               cst.close()
-                       }
-                       return true
-               }
+       timeout := 1 * time.Second
+       trFunc := func(tr *Transport) {
+               tr.MaxConnsPerHost = 1
+               tr.MaxIdleConnsPerHost = 1
+               tr.IdleConnTimeout = timeout
+       }
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("X-Addr", r.RemoteAddr)
+       }), trFunc, optFakeNet)
 
-               if _, err := cst.c.Get(cst.ts.URL); err != nil {
-                       if tooShort(err) {
-                               continue
-                       }
+       // makeRequest returns the local address a request was made from
+       // (unique for each connection).
+       makeRequest := func() string {
+               resp, err := cst.c.Get(cst.ts.URL)
+               if err != nil {
                        t.Fatalf("got error: %s", err)
                }
+               resp.Body.Close()
+               return resp.Header.Get("X-Addr")
+       }
 
-               time.Sleep(10 * timeout)
-               if _, err := cst.c.Get(cst.ts.URL); err != nil {
-                       if tooShort(err) {
-                               continue
-                       }
+       addr1 := makeRequest()
+
+       time.Sleep(timeout / 2)
+       synctest.Wait()
+       addr2 := makeRequest()
+       if addr1 != addr2 {
+               t.Fatalf("two requests made within IdleConnTimeout should have used the same conn, but used %v, %v", addr1, addr2)
+       }
+
+       time.Sleep(timeout)
+       synctest.Wait()
+       addr3 := makeRequest()
+       if addr1 == addr3 {
+               t.Fatalf("two requests made more than IdleConnTimeout apart should have used different conns, but used %v, %v", addr1, addr3)
+       }
+}
+
+func TestTransportRemovesConnsAfterBroken(t *testing.T) {
+       runSynctest(t, testTransportRemovesConnsAfterBroken)
+}
+func testTransportRemovesConnsAfterBroken(t testing.TB, mode testMode) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+
+       trFunc := func(tr *Transport) {
+               tr.MaxConnsPerHost = 1
+               tr.MaxIdleConnsPerHost = 1
+       }
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Set("X-Addr", r.RemoteAddr)
+       }), trFunc, optFakeNet)
+       cst.li.trackConns = true
+
+       // makeRequest returns the local address a request was made from
+       // (unique for each connection).
+       makeRequest := func() string {
+               resp, err := cst.c.Get(cst.ts.URL)
+               if err != nil {
                        t.Fatalf("got error: %s", err)
                }
+               resp.Body.Close()
+               return resp.Header.Get("X-Addr")
+       }
+
+       addr1 := makeRequest()
+       addr2 := makeRequest()
+       if addr1 != addr2 {
+               t.Fatalf("successive requests should have used the same conn, but used %v, %v", addr1, addr2)
+       }
+
+       // The connection breaks.
+       synctest.Wait()
+       cst.li.conns[0].peer.Close()
+       synctest.Wait()
+       addr3 := makeRequest()
+       if addr1 == addr3 {
+               t.Fatalf("successive requests made with conn broken between should have used different conns, but used %v, %v", addr1, addr3)
        }
 }