]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: avoid leaking writer goroutines in tests
authorBryan C. Mills <bcmills@google.com>
Wed, 12 Apr 2023 13:58:54 +0000 (13:58 +0000)
committerGopher Robot <gobot@golang.org>
Wed, 12 Apr 2023 16:04:39 +0000 (16:04 +0000)
In TestTransportPrefersResponseOverWriteError and TestMaxBytesHandler,
the server may respond to an incoming request without ever reading the
request body. The client's Do method will return as soon as the
server's response headers are read, but the Transport will remain
active until it notices that the server has closed the connection,
which may be arbitrarily later.

When the server has closed the connection, it will call the Close
method on the request body (if it has such a method). So we can use
that method to find out when the Transport is close enough to done for
the test to complete without interfering too much with other tests.

For #57612.
For #59526.

Change-Id: Iddc7a3b7b09429113ad76ccc1c090ebc9e1835a1
Reviewed-on: https://go-review.googlesource.com/c/go/+/483895
Run-TryBot: Bryan Mills <bcmills@google.com>
Commit-Queue: Bryan Mills <bcmills@google.com>
Auto-Submit: Bryan Mills <bcmills@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Reviewed-by: Michael Pratt <mpratt@google.com>
src/net/http/serve_test.go
src/net/http/transport_test.go

index 164b18287f179ec6bee542bba0dfa9152b268905..9b8496e7ad091de58f354358e06d9ceb0abee9ed 100644 (file)
@@ -6546,9 +6546,30 @@ func testMaxBytesHandler(t *testing.T, mode testMode, maxSize, requestSize int64
        defer ts.Close()
 
        c := ts.Client()
+
+       body := strings.Repeat("a", int(requestSize))
+       var wg sync.WaitGroup
+       defer wg.Wait()
+       getBody := func() (io.ReadCloser, error) {
+               wg.Add(1)
+               body := &wgReadCloser{
+                       Reader: strings.NewReader(body),
+                       wg:     &wg,
+               }
+               return body, nil
+       }
+       reqBody, _ := getBody()
+       req, err := NewRequest("POST", ts.URL, reqBody)
+       if err != nil {
+               reqBody.Close()
+               t.Fatal(err)
+       }
+       req.ContentLength = int64(len(body))
+       req.GetBody = getBody
+       req.Header.Set("Content-Type", "text/plain")
+
        var buf strings.Builder
-       body := strings.NewReader(strings.Repeat("a", int(requestSize)))
-       res, err := c.Post(ts.URL, "text/plain", body)
+       res, err := c.Do(req)
        if err != nil {
                t.Errorf("unexpected connection error: %v", err)
        } else {
index 6f57629eff892ac6b20a9e5defb913b4d3d4e255..f9e8a285c5e5e5c96f5ab9d0a32126b24b4cd897 100644 (file)
@@ -4250,6 +4250,21 @@ func testTransportFlushesRequestHeader(t *testing.T, mode testMode) {
        <-gotRes
 }
 
+type wgReadCloser struct {
+       io.Reader
+       wg     *sync.WaitGroup
+       closed bool
+}
+
+func (c *wgReadCloser) Close() error {
+       if c.closed {
+               return net.ErrClosed
+       }
+       c.closed = true
+       c.wg.Done()
+       return nil
+}
+
 // Issue 11745.
 func TestTransportPrefersResponseOverWriteError(t *testing.T) {
        run(t, testTransportPrefersResponseOverWriteError)
@@ -4271,12 +4286,29 @@ func testTransportPrefersResponseOverWriteError(t *testing.T, mode testMode) {
 
        fail := 0
        count := 100
+
        bigBody := strings.Repeat("a", contentLengthLimit*2)
+       var wg sync.WaitGroup
+       defer wg.Wait()
+       getBody := func() (io.ReadCloser, error) {
+               wg.Add(1)
+               body := &wgReadCloser{
+                       Reader: strings.NewReader(bigBody),
+                       wg:     &wg,
+               }
+               return body, nil
+       }
+
        for i := 0; i < count; i++ {
-               req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
+               reqBody, _ := getBody()
+               req, err := NewRequest("PUT", ts.URL, reqBody)
                if err != nil {
+                       reqBody.Close()
                        t.Fatal(err)
                }
+               req.ContentLength = int64(len(bigBody))
+               req.GetBody = getBody
+
                resp, err := c.Do(req)
                if err != nil {
                        fail++