]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: scale rstAvoidanceDelay to reduce test flakiness
author: Bryan C. Mills <bcmills@google.com>
Mon, 11 Sep 2023 20:17:03 +0000 (16:17 -0400)
committer: Gopher Robot <gobot@golang.org>
Wed, 13 Sep 2023 20:57:25 +0000 (20:57 +0000)
As far as I can tell, some flakiness is unavoidable in tests
that race a large client request write against a server's response
when the server doesn't read the full request.
It does not appear to be possible to simultaneously ensure that
well-behaved clients see EOF instead of ECONNRESET and also prevent
misbehaving clients from consuming arbitrary server resources.
(See RFC 7230 §6.6 for more detail.)

Since there doesn't appear to be a way to cleanly eliminate
this source of flakiness, we can instead work around it:
we can allow the test to adjust the hard-coded delay if it
sees a plausibly-related failure, so that the test can retry
with a longer delay.

As a nice side benefit, this also allows the tests to run more quickly
in the typical case: since the test will retry in case of spurious
failures, we can start with an aggressively short delay, and only back
off to a longer one if it is really needed on the specific machine
running the test.

Fixes #57084.
Fixes #51104.
For #58398.

Change-Id: Ia4050679f0777e5eeba7670307a77d93cfce856f
Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest-race,gotip-linux-amd64-race,gotip-windows-amd64-race
Reviewed-on: https://go-review.googlesource.com/c/go/+/527196
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>

src/net/http/export_test.go
src/net/http/serve_test.go
src/net/http/server.go
src/net/http/transport_test.go

index 5d198f3f8947d1e33de2d2865e499d1bfdbcb3fd..7e6d3d8e304495aa075adf5b698ad4d0cce6ff7f 100644 (file)
@@ -315,3 +315,21 @@ func ResponseWriterConnForTesting(w ResponseWriter) (c net.Conn, ok bool) {
        }
        return nil, false
 }
+
+func init() {
+       // Set the default rstAvoidanceDelay to the minimum possible value to shake
+       // out tests that unexpectedly depend on it. Such tests should use
+       // runTimeSensitiveTest and SetRSTAvoidanceDelay to explicitly raise the delay
+       // if needed.
+       rstAvoidanceDelay = 1 * time.Nanosecond
+}
+
+// SetRSTAvoidanceDelay sets how long we are willing to wait between calling
+// CloseWrite on a connection and fully closing the connection.
+func SetRSTAvoidanceDelay(t *testing.T, d time.Duration) {
+       prevDelay := rstAvoidanceDelay
+       t.Cleanup(func() {
+               rstAvoidanceDelay = prevDelay
+       })
+       rstAvoidanceDelay = d
+}
index f26a6b3190a442d9396a4f1f74a5c1fa6966dd68..9fe99a37a06b83869c38f64d6622f276351998dd 100644 (file)
@@ -646,19 +646,15 @@ func benchmarkServeMux(b *testing.B, runHandler bool) {
 
 func TestServerTimeouts(t *testing.T) { run(t, testServerTimeouts, []testMode{http1Mode}) }
 func testServerTimeouts(t *testing.T, mode testMode) {
-       // Try three times, with increasing timeouts.
-       tries := []time.Duration{250 * time.Millisecond, 500 * time.Millisecond, 1 * time.Second}
-       for i, timeout := range tries {
-               err := testServerTimeoutsWithTimeout(t, timeout, mode)
-               if err == nil {
-                       return
-               }
-               t.Logf("failed at %v: %v", timeout, err)
-               if i != len(tries)-1 {
-                       t.Logf("retrying at %v ...", tries[i+1])
-               }
-       }
-       t.Fatal("all attempts failed")
+       runTimeSensitiveTest(t, []time.Duration{
+               10 * time.Millisecond,
+               50 * time.Millisecond,
+               100 * time.Millisecond,
+               500 * time.Millisecond,
+               1 * time.Second,
+       }, func(t *testing.T, timeout time.Duration) error {
+               return testServerTimeoutsWithTimeout(t, timeout, mode)
+       })
 }
 
 func testServerTimeoutsWithTimeout(t *testing.T, timeout time.Duration, mode testMode) error {
@@ -3101,47 +3097,68 @@ func TestServerBufferedChunking(t *testing.T) {
 // closing the TCP connection, causing the client to get a RST.
 // See https://golang.org/issue/3595
 func TestServerGracefulClose(t *testing.T) {
-       run(t, testServerGracefulClose, []testMode{http1Mode})
+       // Not parallel: modifies the global rstAvoidanceDelay.
+       run(t, testServerGracefulClose, []testMode{http1Mode}, testNotParallel)
 }
 func testServerGracefulClose(t *testing.T, mode testMode) {
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
-               Error(w, "bye", StatusUnauthorized)
-       })).ts
+       runTimeSensitiveTest(t, []time.Duration{
+               1 * time.Millisecond,
+               5 * time.Millisecond,
+               10 * time.Millisecond,
+               50 * time.Millisecond,
+               100 * time.Millisecond,
+               500 * time.Millisecond,
+               time.Second,
+               5 * time.Second,
+       }, func(t *testing.T, timeout time.Duration) error {
+               SetRSTAvoidanceDelay(t, timeout)
+               t.Logf("set RST avoidance delay to %v", timeout)
 
-       conn, err := net.Dial("tcp", ts.Listener.Addr().String())
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer conn.Close()
-       const bodySize = 5 << 20
-       req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
-       for i := 0; i < bodySize; i++ {
-               req = append(req, 'x')
-       }
-       writeErr := make(chan error)
-       go func() {
-               _, err := conn.Write(req)
-               writeErr <- err
-       }()
-       br := bufio.NewReader(conn)
-       lineNum := 0
-       for {
-               line, err := br.ReadString('\n')
-               if err == io.EOF {
-                       break
+               const bodySize = 5 << 20
+               req := []byte(fmt.Sprintf("POST / HTTP/1.1\r\nHost: foo.com\r\nContent-Length: %d\r\n\r\n", bodySize))
+               for i := 0; i < bodySize; i++ {
+                       req = append(req, 'x')
                }
+
+               cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+                       Error(w, "bye", StatusUnauthorized)
+               }))
+               // We need to close cst explicitly here so that in-flight server
+               // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+               defer cst.close()
+               ts := cst.ts
+
+               conn, err := net.Dial("tcp", ts.Listener.Addr().String())
                if err != nil {
-                       t.Fatalf("ReadLine: %v", err)
+                       return err
                }
-               lineNum++
-               if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
-                       t.Errorf("Response line = %q; want a 401", line)
+               defer conn.Close()
+               writeErr := make(chan error)
+               go func() {
+                       _, err := conn.Write(req)
+                       writeErr <- err
+               }()
+               br := bufio.NewReader(conn)
+               lineNum := 0
+               for {
+                       line, err := br.ReadString('\n')
+                       if err == io.EOF {
+                               break
+                       }
+                       if err != nil {
+                               return fmt.Errorf("ReadLine: %v", err)
+                       }
+                       lineNum++
+                       if lineNum == 1 && !strings.Contains(line, "401 Unauthorized") {
+                               t.Errorf("Response line = %q; want a 401", line)
+                       }
                }
-       }
-       // Wait for write to finish. This is a broken pipe on both
-       // Darwin and Linux, but checking this isn't the point of
-       // the test.
-       <-writeErr
+               // Wait for write to finish. This is a broken pipe on both
+               // Darwin and Linux, but checking this isn't the point of
+               // the test.
+               <-writeErr
+               return nil
+       })
 }
 
 func TestCaseSensitiveMethod(t *testing.T) { run(t, testCaseSensitiveMethod) }
@@ -3923,91 +3940,78 @@ func TestContentTypeOkayOn204(t *testing.T) {
 // and the http client), and both think they can close it on failure.
 // Therefore, all incoming server requests Bodies need to be thread-safe.
 func TestTransportAndServerSharedBodyRace(t *testing.T) {
-       run(t, testTransportAndServerSharedBodyRace)
+       run(t, testTransportAndServerSharedBodyRace, testNotParallel)
 }
 func testTransportAndServerSharedBodyRace(t *testing.T, mode testMode) {
-       const bodySize = 1 << 20
-
-       // errorf is like t.Errorf, but also writes to println. When
-       // this test fails, it hangs. This helps debugging and I've
-       // added this enough times "temporarily".  It now gets added
-       // full time.
-       errorf := func(format string, args ...any) {
-               v := fmt.Sprintf(format, args...)
-               println(v)
-               t.Error(v)
-       }
-
-       unblockBackend := make(chan bool)
-       backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
-               gone := rw.(CloseNotifier).CloseNotify()
-               didCopy := make(chan any)
-               go func() {
+       // The proxy server in the middle of the stack for this test potentially
+       // returns from its handler after only reading half of the body.
+       // That can trigger https://go.dev/issue/3595, which is otherwise
+       // irrelevant to this test.
+       runTimeSensitiveTest(t, []time.Duration{
+               1 * time.Millisecond,
+               5 * time.Millisecond,
+               10 * time.Millisecond,
+               50 * time.Millisecond,
+               100 * time.Millisecond,
+               500 * time.Millisecond,
+               time.Second,
+               5 * time.Second,
+       }, func(t *testing.T, timeout time.Duration) error {
+               SetRSTAvoidanceDelay(t, timeout)
+               t.Logf("set RST avoidance delay to %v", timeout)
+
+               const bodySize = 1 << 20
+
+               backend := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
                        n, err := io.CopyN(rw, req.Body, bodySize)
-                       didCopy <- []any{n, err}
-               }()
-               isGone := false
-       Loop:
-               for {
-                       select {
-                       case <-didCopy:
-                               break Loop
-                       case <-gone:
-                               isGone = true
-                       case <-time.After(time.Second):
-                               println("1 second passes in backend, proxygone=", isGone)
+                       t.Logf("backend CopyN: %v, %v", n, err)
+                       <-req.Context().Done()
+               }))
+               // We need to close explicitly here so that in-flight server
+               // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+               defer backend.close()
+
+               var proxy *clientServerTest
+               proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+                       req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
+                       req2.ContentLength = bodySize
+                       cancel := make(chan struct{})
+                       req2.Cancel = cancel
+
+                       bresp, err := proxy.c.Do(req2)
+                       if err != nil {
+                               t.Errorf("Proxy outbound request: %v", err)
+                               return
                        }
-               }
-               <-unblockBackend
-       }))
-       defer backend.close()
-
-       backendRespc := make(chan *Response, 1)
-       var proxy *clientServerTest
-       proxy = newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
-               req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
-               req2.ContentLength = bodySize
-               cancel := make(chan struct{})
-               req2.Cancel = cancel
+                       _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
+                       if err != nil {
+                               t.Errorf("Proxy copy error: %v", err)
+                               return
+                       }
+                       t.Cleanup(func() { bresp.Body.Close() })
+
+                       // Try to cause a race. Canceling the client request will cause the client
+                       // transport to close req2.Body. Returning from the server handler will
+                       // cause the server to close req.Body. Since they are the same underlying
+                       // ReadCloser, that will result in concurrent calls to Close (and possibly a
+                       // Read concurrent with a Close).
+                       if mode == http2Mode {
+                               close(cancel)
+                       } else {
+                               proxy.c.Transport.(*Transport).CancelRequest(req2)
+                       }
+                       rw.Write([]byte("OK"))
+               }))
+               defer proxy.close()
 
-               bresp, err := proxy.c.Do(req2)
-               if err != nil {
-                       errorf("Proxy outbound request: %v", err)
-                       return
-               }
-               _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2)
+               req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
+               res, err := proxy.c.Do(req)
                if err != nil {
-                       errorf("Proxy copy error: %v", err)
-                       return
-               }
-               backendRespc <- bresp // to close later
-
-               // Try to cause a race: Both the Transport and the proxy handler's Server
-               // will try to read/close req.Body (aka req2.Body)
-               if mode == http2Mode {
-                       close(cancel)
-               } else {
-                       proxy.c.Transport.(*Transport).CancelRequest(req2)
+                       return fmt.Errorf("original request: %v", err)
                }
-               rw.Write([]byte("OK"))
-       }))
-       defer proxy.close()
-
-       defer close(unblockBackend)
-       req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
-       res, err := proxy.c.Do(req)
-       if err != nil {
-               t.Fatalf("Original request: %v", err)
-       }
-
-       // Cleanup, so we don't leak goroutines.
-       res.Body.Close()
-       select {
-       case res := <-backendRespc:
                res.Body.Close()
-       default:
-               // We failed earlier. (e.g. on proxy.c.Do(req2))
-       }
+               return nil
+       })
 }
 
 // Test that a hanging Request.Body.Read from another goroutine can't
@@ -4342,7 +4346,8 @@ func (c *closeWriteTestConn) CloseWrite() error {
 }
 
 func TestCloseWrite(t *testing.T) {
-       setParallel(t)
+       SetRSTAvoidanceDelay(t, 1*time.Millisecond)
+
        var srv Server
        var testConn closeWriteTestConn
        c := ExportServerNewConn(&srv, &testConn)
@@ -5382,49 +5387,73 @@ func testServerIdleTimeout(t *testing.T, mode testMode) {
        if testing.Short() {
                t.Skip("skipping in short mode")
        }
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
-               io.Copy(io.Discard, r.Body)
-               io.WriteString(w, r.RemoteAddr)
-       }), func(ts *httptest.Server) {
-               ts.Config.ReadHeaderTimeout = 1 * time.Second
-               ts.Config.IdleTimeout = 2 * time.Second
-       }).ts
-       c := ts.Client()
+       runTimeSensitiveTest(t, []time.Duration{
+               10 * time.Millisecond,
+               100 * time.Millisecond,
+               1 * time.Second,
+               10 * time.Second,
+       }, func(t *testing.T, readHeaderTimeout time.Duration) error {
+               ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+                       io.Copy(io.Discard, r.Body)
+                       io.WriteString(w, r.RemoteAddr)
+               }), func(ts *httptest.Server) {
+                       ts.Config.ReadHeaderTimeout = readHeaderTimeout
+                       ts.Config.IdleTimeout = 2 * readHeaderTimeout
+               }).ts
+               t.Logf("ReadHeaderTimeout = %v", ts.Config.ReadHeaderTimeout)
+               t.Logf("IdleTimeout = %v", ts.Config.IdleTimeout)
+               c := ts.Client()
 
-       get := func() string {
-               res, err := c.Get(ts.URL)
+               get := func() (string, error) {
+                       res, err := c.Get(ts.URL)
+                       if err != nil {
+                               return "", err
+                       }
+                       defer res.Body.Close()
+                       slurp, err := io.ReadAll(res.Body)
+                       if err != nil {
+                               // If we're at this point the headers have definitely already been
+                               // read and the server is not idle, so neither timeout applies:
+                               // this should never fail.
+                               t.Fatal(err)
+                       }
+                       return string(slurp), nil
+               }
+
+               a1, err := get()
                if err != nil {
-                       t.Fatal(err)
+                       return err
                }
-               defer res.Body.Close()
-               slurp, err := io.ReadAll(res.Body)
+               a2, err := get()
                if err != nil {
-                       t.Fatal(err)
+                       return err
+               }
+               if a1 != a2 {
+                       return fmt.Errorf("did requests on different connections")
+               }
+               time.Sleep(ts.Config.IdleTimeout * 3 / 2)
+               a3, err := get()
+               if err != nil {
+                       return err
+               }
+               if a2 == a3 {
+                       return fmt.Errorf("request three unexpectedly on same connection")
                }
-               return string(slurp)
-       }
 
-       a1, a2 := get(), get()
-       if a1 != a2 {
-               t.Fatalf("did requests on different connections")
-       }
-       time.Sleep(3 * time.Second)
-       a3 := get()
-       if a2 == a3 {
-               t.Fatal("request three unexpectedly on same connection")
-       }
+               // And test that ReadHeaderTimeout still works:
+               conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+               if err != nil {
+                       return err
+               }
+               defer conn.Close()
+               conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
+               time.Sleep(ts.Config.ReadHeaderTimeout * 2)
+               if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
+                       return fmt.Errorf("copy byte succeeded; want err")
+               }
 
-       // And test that ReadHeaderTimeout still works:
-       conn, err := net.Dial("tcp", ts.Listener.Addr().String())
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer conn.Close()
-       conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n"))
-       time.Sleep(2 * time.Second)
-       if _, err := io.CopyN(io.Discard, conn, 1); err == nil {
-               t.Fatal("copy byte succeeded; want err")
-       }
+               return nil
+       })
 }
 
 func get(t *testing.T, c *Client, url string) string {
@@ -5773,9 +5802,10 @@ func runTimeSensitiveTest(t *testing.T, durations []time.Duration, test func(t *
                if err == nil {
                        return
                }
-               if i == len(durations)-1 {
+               if i == len(durations)-1 || t.Failed() {
                        t.Fatalf("failed with duration %v: %v", d, err)
                }
+               t.Logf("retrying after error with duration %v: %v", d, err)
        }
 }
 
@@ -6620,7 +6650,7 @@ func testQuerySemicolon(t *testing.T, mode testMode, query string, wantX string,
 }
 
 func TestMaxBytesHandler(t *testing.T) {
-       setParallel(t)
+       // Not parallel: modifies the global rstAvoidanceDelay.
        defer afterTest(t)
 
        for _, maxSize := range []int64{100, 1_000, 1_000_000} {
@@ -6629,77 +6659,99 @@ func TestMaxBytesHandler(t *testing.T) {
                                func(t *testing.T) {
                                        run(t, func(t *testing.T, mode testMode) {
                                                testMaxBytesHandler(t, mode, maxSize, requestSize)
-                                       })
+                                       }, testNotParallel)
                                })
                }
        }
 }
 
 func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64) {
-       var (
-               handlerN   int64
-               handlerErr error
-       )
-       echo := HandlerFunc(func(w ResponseWriter, r *Request) {
-               var buf bytes.Buffer
-               handlerN, handlerErr = io.Copy(&buf, r.Body)
-               io.Copy(w, &buf)
-       })
-
-       ts := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize)).ts
-       defer ts.Close()
+       runTimeSensitiveTest(t, []time.Duration{
+               1 * time.Millisecond,
+               5 * time.Millisecond,
+               10 * time.Millisecond,
+               50 * time.Millisecond,
+               100 * time.Millisecond,
+               500 * time.Millisecond,
+               time.Second,
+               5 * time.Second,
+       }, func(t *testing.T, timeout time.Duration) error {
+               SetRSTAvoidanceDelay(t, timeout)
+               t.Logf("set RST avoidance delay to %v", timeout)
+
+               var (
+                       handlerN   int64
+                       handlerErr error
+               )
+               echo := HandlerFunc(func(w ResponseWriter, r *Request) {
+                       var buf bytes.Buffer
+                       handlerN, handlerErr = io.Copy(&buf, r.Body)
+                       io.Copy(w, &buf)
+               })
 
-       c := ts.Client()
+               cst := newClientServerTest(t, mode, MaxBytesHandler(echo, maxSize))
+               // We need to close cst explicitly here so that in-flight server
+               // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+               defer cst.close()
+               ts := cst.ts
+               c := ts.Client()
 
-       body := strings.Repeat("a", int(requestSize))
-       var wg sync.WaitGroup
-       defer wg.Wait()
-       getBody := func() (io.ReadCloser, error) {
-               wg.Add(1)
-               body := &wgReadCloser{
-                       Reader: strings.NewReader(body),
-                       wg:     &wg,
+               body := strings.Repeat("a", int(requestSize))
+               var wg sync.WaitGroup
+               defer wg.Wait()
+               getBody := func() (io.ReadCloser, error) {
+                       wg.Add(1)
+                       body := &wgReadCloser{
+                               Reader: strings.NewReader(body),
+                               wg:     &wg,
+                       }
+                       return body, nil
                }
-               return body, nil
-       }
-       reqBody, _ := getBody()
-       req, err := NewRequest("POST", ts.URL, reqBody)
-       if err != nil {
-               reqBody.Close()
-               t.Fatal(err)
-       }
-       req.ContentLength = int64(len(body))
-       req.GetBody = getBody
-       req.Header.Set("Content-Type", "text/plain")
+               reqBody, _ := getBody()
+               req, err := NewRequest("POST", ts.URL, reqBody)
+               if err != nil {
+                       reqBody.Close()
+                       t.Fatal(err)
+               }
+               req.ContentLength = int64(len(body))
+               req.GetBody = getBody
+               req.Header.Set("Content-Type", "text/plain")
 
-       var buf strings.Builder
-       res, err := c.Do(req)
-       if err != nil {
-               t.Errorf("unexpected connection error: %v", err)
-       } else {
-               _, err = io.Copy(&buf, res.Body)
-               res.Body.Close()
+               var buf strings.Builder
+               res, err := c.Do(req)
                if err != nil {
-                       t.Errorf("unexpected read error: %v", err)
+                       return fmt.Errorf("unexpected connection error: %v", err)
+               } else {
+                       _, err = io.Copy(&buf, res.Body)
+                       res.Body.Close()
+                       if err != nil {
+                               return fmt.Errorf("unexpected read error: %v", err)
+                       }
                }
-       }
-       if handlerN > maxSize {
-               t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
-       }
-       if requestSize > maxSize && handlerErr == nil {
-               t.Error("expected error on handler side; got nil")
-       }
-       if requestSize <= maxSize {
-               if handlerErr != nil {
-                       t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+               // We don't expect any of the errors after this point to occur due
+               // to rstAvoidanceDelay being too short, so we use t.Errorf for those
+               // instead of returning a (retriable) error.
+
+               if handlerN > maxSize {
+                       t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
                }
-               if handlerN != requestSize {
-                       t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+               if requestSize > maxSize && handlerErr == nil {
+                       t.Error("expected error on handler side; got nil")
                }
-       }
-       if buf.Len() != int(handlerN) {
-               t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
-       }
+               if requestSize <= maxSize {
+                       if handlerErr != nil {
+                               t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+                       }
+                       if handlerN != requestSize {
+                               t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+                       }
+               }
+               if buf.Len() != int(handlerN) {
+                       t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
+               }
+
+               return nil
+       })
 }
 
 func TestEarlyHints(t *testing.T) {
index 26df238495290fd57920b6499b93ce38d0033439..6fe917e0867d4d7e30af07dcca7b933f597e7c8b 100644 (file)
@@ -1750,8 +1750,12 @@ func (c *conn) close() {
 // and processes its final data before they process the subsequent RST
 // from closing a connection with known unread data.
 // This RST seems to occur mostly on BSD systems. (And Windows?)
-// This timeout is somewhat arbitrary (~latency around the planet).
-const rstAvoidanceDelay = 500 * time.Millisecond
+// This timeout is somewhat arbitrary (~latency around the planet),
+// and may be modified by tests.
+//
+// TODO(bcmills): This should arguably be a server configuration parameter,
+// not a hard-coded value.
+var rstAvoidanceDelay = 500 * time.Millisecond
 
 type closeWriter interface {
        CloseWrite() error
@@ -1770,6 +1774,27 @@ func (c *conn) closeWriteAndWait() {
        if tcp, ok := c.rwc.(closeWriter); ok {
                tcp.CloseWrite()
        }
+
+       // When we return from closeWriteAndWait, the caller will fully close the
+       // connection. If the client is still writing to the connection, this will cause
+       // the write to fail with ECONNRESET or similar. Unfortunately, many TCP
+       // implementations will also drop unread packets from the client's read buffer
+       // when a write fails, causing our final response to be truncated away too.
+       //
+       // As a result, https://www.rfc-editor.org/rfc/rfc7230#section-6.6 recommends
+       // that “[t]he server … continues to read from the connection until it
+       // receives a corresponding close by the client, or until the server is
+       // reasonably certain that its own TCP stack has received the client's
+       // acknowledgement of the packet(s) containing the server's last response.”
+       //
+       // Unfortunately, we have no straightforward way to be “reasonably certain”
+       // that we have received the client's ACK, and at any rate we don't want to
+       // allow a misbehaving client to soak up server connections indefinitely by
+       // withholding an ACK, nor do we want to go through the complexity or overhead
+       // of using low-level APIs to figure out when a TCP round-trip has completed.
+       //
+       // Instead, we declare that we are “reasonably certain” that we received the
+       // ACK if rstAvoidanceDelay has elapsed.
        time.Sleep(rstAvoidanceDelay)
 }
 
index 9f086172d3c4c303bbf56d325744429c48cadba4..8c09de70ff5e2b6b4618a7d4613e76e36e6e1334 100644 (file)
@@ -2099,25 +2099,50 @@ func testIssue3644(t *testing.T, mode testMode) {
 
 // Test that a client receives a server's reply, even if the server doesn't read
 // the entire request body.
-func TestIssue3595(t *testing.T) { run(t, testIssue3595) }
+func TestIssue3595(t *testing.T) {
+       // Not parallel: modifies the global rstAvoidanceDelay.
+       run(t, testIssue3595, testNotParallel)
+}
 func testIssue3595(t *testing.T, mode testMode) {
-       const deniedMsg = "sorry, denied."
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
-               Error(w, deniedMsg, StatusUnauthorized)
-       })).ts
-       c := ts.Client()
-       res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
-       if err != nil {
-               t.Errorf("Post: %v", err)
-               return
-       }
-       got, err := io.ReadAll(res.Body)
-       if err != nil {
-               t.Fatalf("Body ReadAll: %v", err)
-       }
-       if !strings.Contains(string(got), deniedMsg) {
-               t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
-       }
+       runTimeSensitiveTest(t, []time.Duration{
+               1 * time.Millisecond,
+               5 * time.Millisecond,
+               10 * time.Millisecond,
+               50 * time.Millisecond,
+               100 * time.Millisecond,
+               500 * time.Millisecond,
+               time.Second,
+               5 * time.Second,
+       }, func(t *testing.T, timeout time.Duration) error {
+               SetRSTAvoidanceDelay(t, timeout)
+               t.Logf("set RST avoidance delay to %v", timeout)
+
+               const deniedMsg = "sorry, denied."
+               cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+                       Error(w, deniedMsg, StatusUnauthorized)
+               }))
+               // We need to close cst explicitly here so that in-flight server
+               // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+               defer cst.close()
+               ts := cst.ts
+               c := ts.Client()
+
+               res, err := c.Post(ts.URL, "application/octet-stream", neverEnding('a'))
+               if err != nil {
+                       return fmt.Errorf("Post: %v", err)
+               }
+               got, err := io.ReadAll(res.Body)
+               if err != nil {
+                       return fmt.Errorf("Body ReadAll: %v", err)
+               }
+               t.Logf("server response:\n%s", got)
+               if !strings.Contains(string(got), deniedMsg) {
+                       // If we got an RST packet too early, we should have seen an error
+                       // from io.ReadAll, not a silently-truncated body.
+                       t.Errorf("Known bug: response %q does not contain %q", got, deniedMsg)
+               }
+               return nil
+       })
 }
 
 // From https://golang.org/issue/4454 ,
@@ -4327,68 +4352,78 @@ func (c *wgReadCloser) Close() error {
 
 // Issue 11745.
 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
-       run(t, testTransportPrefersResponseOverWriteError)
+       // Not parallel: modifies the global rstAvoidanceDelay.
+       run(t, testTransportPrefersResponseOverWriteError, testNotParallel)
 }
 func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
        if testing.Short() {
                t.Skip("skipping in short mode")
        }
-       const contentLengthLimit = 1024 * 1024 // 1MB
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
-               if r.ContentLength >= contentLengthLimit {
-                       w.WriteHeader(StatusBadRequest)
-                       r.Body.Close()
-                       return
-               }
-               w.WriteHeader(StatusOK)
-       })).ts
-       c := ts.Client()
 
-       fail := 0
-       count := 100
+       runTimeSensitiveTest(t, []time.Duration{
+               1 * time.Millisecond,
+               5 * time.Millisecond,
+               10 * time.Millisecond,
+               50 * time.Millisecond,
+               100 * time.Millisecond,
+               500 * time.Millisecond,
+               time.Second,
+               5 * time.Second,
+       }, func(t *testing.T, timeout time.Duration) error {
+               SetRSTAvoidanceDelay(t, timeout)
+               t.Logf("set RST avoidance delay to %v", timeout)
+
+               const contentLengthLimit = 1024 * 1024 // 1MB
+               cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+                       if r.ContentLength >= contentLengthLimit {
+                               w.WriteHeader(StatusBadRequest)
+                               r.Body.Close()
+                               return
+                       }
+                       w.WriteHeader(StatusOK)
+               }))
+               // We need to close cst explicitly here so that in-flight server
+               // requests don't race with the call to SetRSTAvoidanceDelay for a retry.
+               defer cst.close()
+               ts := cst.ts
+               c := ts.Client()
 
-       bigBody := strings.Repeat("a", contentLengthLimit*2)
-       var wg sync.WaitGroup
-       defer wg.Wait()
-       getBody := func() (io.ReadCloser, error) {
-               wg.Add(1)
-               body := &wgReadCloser{
-                       Reader: strings.NewReader(bigBody),
-                       wg:     &wg,
-               }
-               return body, nil
-       }
+               count := 100
 
-       for i := 0; i < count; i++ {
-               reqBody, _ := getBody()
-               req, err := NewRequest("PUT", ts.URL, reqBody)
-               if err != nil {
-                       reqBody.Close()
-                       t.Fatal(err)
+               bigBody := strings.Repeat("a", contentLengthLimit*2)
+               var wg sync.WaitGroup
+               defer wg.Wait()
+               getBody := func() (io.ReadCloser, error) {
+                       wg.Add(1)
+                       body := &wgReadCloser{
+                               Reader: strings.NewReader(bigBody),
+                               wg:     &wg,
+                       }
+                       return body, nil
                }
-               req.ContentLength = int64(len(bigBody))
-               req.GetBody = getBody
 
-               resp, err := c.Do(req)
-               if err != nil {
-                       fail++
-                       t.Logf("%d = %#v", i, err)
-                       if ue, ok := err.(*url.Error); ok {
-                               t.Logf("urlErr = %#v", ue.Err)
-                               if ne, ok := ue.Err.(*net.OpError); ok {
-                                       t.Logf("netOpError = %#v", ne.Err)
-                               }
+               for i := 0; i < count; i++ {
+                       reqBody, _ := getBody()
+                       req, err := NewRequest("PUT", ts.URL, reqBody)
+                       if err != nil {
+                               reqBody.Close()
+                               t.Fatal(err)
                        }
-               } else {
-                       resp.Body.Close()
-                       if resp.StatusCode != 400 {
-                               t.Errorf("Expected status code 400, got %v", resp.Status)
+                       req.ContentLength = int64(len(bigBody))
+                       req.GetBody = getBody
+
+                       resp, err := c.Do(req)
+                       if err != nil {
+                               return fmt.Errorf("Do %d: %v", i, err)
+                       } else {
+                               resp.Body.Close()
+                               if resp.StatusCode != 400 {
+                                       t.Errorf("Expected status code 400, got %v", resp.Status)
+                               }
                        }
                }
-       }
-       if fail > 0 {
-               t.Errorf("Failed %v out of %v\n", fail, count)
-       }
+               return nil
+       })
 }
 
 func TestTransportAutomaticHTTP2(t *testing.T) {