]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: return correct error when reading from a canceled request body
authorDamien Neil <dneil@google.com>
Thu, 16 May 2024 23:26:04 +0000 (16:26 -0700)
committerGopher Robot <gobot@golang.org>
Fri, 17 May 2024 15:58:58 +0000 (15:58 +0000)
CL 546676 inadvertently changed the error returned when reading
from the body of a canceled request. Fix it.

Rework various request cancelation tests to exercise all three ways
of canceling a request.

Fixes #67439

Change-Id: I14ecaf8bff9452eca4a05df923d57d768127a90c
Cq-Include-Trybots: luci.golang.try:gotip-linux-amd64-longtest,gotip-linux-amd64-longtest-race
Reviewed-on: https://go-review.googlesource.com/c/go/+/586315
Auto-Submit: Damien Neil <dneil@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/net/http/transport.go
src/net/http/transport_test.go

index f7a7092ef749fe7d6afdeaa3bc1ee809a320c851..0d4332c3445ba8d2d93350d33f15406f3ff4846c 100644 (file)
@@ -2333,7 +2333,7 @@ func (pc *persistConn) readLoop() {
                        }
                case <-rc.treq.ctx.Done():
                        alive = false
-                       pc.cancelRequest(errRequestCanceled)
+                       pc.cancelRequest(context.Cause(rc.treq.ctx))
                case <-pc.closech:
                        alive = false
                }
index fa147e164ed843181980cbcde07e2e4b6144dff7..25876e8d16879d13ad8035f70a92dd76c4abd150 100644 (file)
@@ -2507,17 +2507,103 @@ func testTransportResponseHeaderTimeout(t *testing.T, mode testMode) {
        }
 }
 
+// A cancelTest is a test of request cancellation.
+type cancelTest struct {
+       mode     testMode
+       newReq   func(req *Request) *Request       // prepare the request to cancel
+       cancel   func(tr *Transport, req *Request) // cancel the request
+       checkErr func(when string, err error)      // verify the expected error
+}
+
+// runCancelTestTransport uses Transport.CancelRequest.
+func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+       t.Run("TransportCancel", func(t *testing.T) {
+               f(t, cancelTest{
+                       mode: mode,
+                       newReq: func(req *Request) *Request {
+                               return req
+                       },
+                       cancel: func(tr *Transport, req *Request) {
+                               tr.CancelRequest(req)
+                       },
+                       checkErr: func(when string, err error) {
+                               if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+                                       t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+                               }
+                       },
+               })
+       })
+}
+
+// runCancelTestChannel uses Request.Cancel.
+func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+       var cancelOnce sync.Once
+       cancelc := make(chan struct{})
+       f(t, cancelTest{
+               mode: mode,
+               newReq: func(req *Request) *Request {
+                       req.Cancel = cancelc
+                       return req
+               },
+               cancel: func(tr *Transport, req *Request) {
+                       cancelOnce.Do(func() {
+                               close(cancelc)
+                       })
+               },
+               checkErr: func(when string, err error) {
+                       if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+                               t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+                       }
+               },
+       })
+}
+
+// runCancelTestContext uses a request context.
+func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+       ctx, cancel := context.WithCancel(context.Background())
+       f(t, cancelTest{
+               mode: mode,
+               newReq: func(req *Request) *Request {
+                       return req.WithContext(ctx)
+               },
+               cancel: func(tr *Transport, req *Request) {
+                       cancel()
+               },
+               checkErr: func(when string, err error) {
+                       if !errors.Is(err, context.Canceled) {
+                               t.Errorf("%v error = %v, want context.Canceled", when, err)
+                       }
+               },
+       })
+}
+
+func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
+       run(t, func(t *testing.T, mode testMode) {
+               if mode == http1Mode {
+                       t.Run("TransportCancel", func(t *testing.T) {
+                               runCancelTestTransport(t, mode, f)
+                       })
+               }
+               t.Run("RequestCancel", func(t *testing.T) {
+                       runCancelTestChannel(t, mode, f)
+               })
+               t.Run("ContextCancel", func(t *testing.T) {
+                       runCancelTestContext(t, mode, f)
+               })
+       }, opts...)
+}
+
 func TestTransportCancelRequest(t *testing.T) {
-       run(t, testTransportCancelRequest, []testMode{http1Mode})
+       runCancelTest(t, testTransportCancelRequest)
 }
-func testTransportCancelRequest(t *testing.T, mode testMode) {
+func testTransportCancelRequest(t *testing.T, test cancelTest) {
        if testing.Short() {
                t.Skip("skipping test in -short mode")
        }
 
        const msg = "Hello"
        unblockc := make(chan bool)
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+       ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                io.WriteString(w, msg)
                w.(Flusher).Flush() // send headers and some body
                <-unblockc
@@ -2528,6 +2614,7 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
        tr := c.Transport.(*Transport)
 
        req, _ := NewRequest("GET", ts.URL, nil)
+       req = test.newReq(req)
        res, err := c.Do(req)
        if err != nil {
                t.Fatal(err)
@@ -2537,13 +2624,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
        if n != len(body) || !bytes.Equal(body, []byte(msg)) {
                t.Errorf("Body = %q; want %q", body[:n], msg)
        }
-       tr.CancelRequest(req)
+       test.cancel(tr, req)
 
        tail, err := io.ReadAll(res.Body)
        res.Body.Close()
-       if err != ExportErrRequestCanceled {
-               t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
-       } else if len(tail) > 0 {
+       test.checkErr("Body.Read", err)
+       if len(tail) > 0 {
                t.Errorf("Spurious bytes from Body.Read: %q", tail)
        }
 
@@ -2561,12 +2647,12 @@ func testTransportCancelRequest(t *testing.T, mode testMode) {
        })
 }
 
-func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
+func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
        if testing.Short() {
                t.Skip("skipping test in -short mode")
        }
        unblockc := make(chan bool)
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+       ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                <-unblockc
        })).ts
        defer close(unblockc)
@@ -2576,6 +2662,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
 
        donec := make(chan bool)
        req, _ := NewRequest("GET", ts.URL, body)
+       req = test.newReq(req)
        go func() {
                defer close(donec)
                c.Do(req)
@@ -2583,7 +2670,7 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
 
        unblockc <- true
        waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
-               tr.CancelRequest(req)
+               test.cancel(tr, req)
                select {
                case <-donec:
                        return true
@@ -2597,18 +2684,21 @@ func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader)
 }
 
 func TestTransportCancelRequestInDo(t *testing.T) {
-       run(t, func(t *testing.T, mode testMode) {
-               testTransportCancelRequestInDo(t, mode, nil)
-       }, []testMode{http1Mode})
+       runCancelTest(t, func(t *testing.T, test cancelTest) {
+               testTransportCancelRequestInDo(t, test, nil)
+       })
 }
 
 func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
-       run(t, func(t *testing.T, mode testMode) {
-               testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
-       }, []testMode{http1Mode})
+       runCancelTest(t, func(t *testing.T, test cancelTest) {
+               testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
+       })
 }
 
 func TestTransportCancelRequestInDial(t *testing.T) {
+       runCancelTest(t, testTransportCancelRequestInDial)
+}
+func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
        defer afterTest(t)
        if testing.Short() {
                t.Skip("skipping test in -short mode")
@@ -2633,17 +2723,19 @@ func TestTransportCancelRequestInDial(t *testing.T) {
        cl := &Client{Transport: tr}
        gotres := make(chan bool)
        req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+       req = test.newReq(req)
        go func() {
                _, err := cl.Do(req)
-               eventLog.Printf("Get = %v", err)
+               eventLog.Printf("Get error = %v", err != nil)
+               test.checkErr("Get", err)
                gotres <- true
        }()
 
        inDial <- true
 
        eventLog.Printf("canceling")
-       tr.CancelRequest(req)
-       tr.CancelRequest(req) // used to panic on second call
+       test.cancel(tr, req)
+       test.cancel(tr, req) // used to panic on second call to Transport.Cancel
 
        if d, ok := t.Deadline(); ok {
                // When the test's deadline is about to expire, log the pending events for
@@ -2659,80 +2751,25 @@ func TestTransportCancelRequestInDial(t *testing.T) {
        got := logbuf.String()
        want := `dial: blocking
 canceling
-Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
+Get error = true
 `
        if got != want {
                t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
        }
 }
 
-func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
-func testCancelRequestWithChannel(t *testing.T, mode testMode) {
-       if testing.Short() {
-               t.Skip("skipping test in -short mode")
-       }
-
-       const msg = "Hello"
-       unblockc := make(chan struct{})
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
-               io.WriteString(w, msg)
-               w.(Flusher).Flush() // send headers and some body
-               <-unblockc
-       })).ts
-       defer close(unblockc)
-
-       c := ts.Client()
-       tr := c.Transport.(*Transport)
-
-       req, _ := NewRequest("GET", ts.URL, nil)
-       cancel := make(chan struct{})
-       req.Cancel = cancel
-
-       res, err := c.Do(req)
-       if err != nil {
-               t.Fatal(err)
-       }
-       body := make([]byte, len(msg))
-       n, _ := io.ReadFull(res.Body, body)
-       if n != len(body) || !bytes.Equal(body, []byte(msg)) {
-               t.Errorf("Body = %q; want %q", body[:n], msg)
-       }
-       close(cancel)
-
-       tail, err := io.ReadAll(res.Body)
-       res.Body.Close()
-       if err != ExportErrRequestCanceled {
-               t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
-       } else if len(tail) > 0 {
-               t.Errorf("Spurious bytes from Body.Read: %q", tail)
-       }
-
-       // Verify no outstanding requests after readLoop/writeLoop
-       // goroutines shut down.
-       waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
-               n := tr.NumPendingRequestsForTesting()
-               if n > 0 {
-                       if d > 0 {
-                               t.Logf("pending requests = %d after %v (want 0)", n, d)
-                       }
-                       return false
-               }
-               return true
-       })
-}
-
 // Issue 51354
-func TestCancelRequestWithBodyWithChannel(t *testing.T) {
-       run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
+func TestTransportCancelRequestWithBody(t *testing.T) {
+       runCancelTest(t, testTransportCancelRequestWithBody)
 }
-func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
+func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
        if testing.Short() {
                t.Skip("skipping test in -short mode")
        }
 
        const msg = "Hello"
        unblockc := make(chan struct{})
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+       ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                io.WriteString(w, msg)
                w.(Flusher).Flush() // send headers and some body
                <-unblockc
@@ -2743,8 +2780,7 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
        tr := c.Transport.(*Transport)
 
        req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
-       cancel := make(chan struct{})
-       req.Cancel = cancel
+       req = test.newReq(req)
 
        res, err := c.Do(req)
        if err != nil {
@@ -2755,13 +2791,12 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
        if n != len(body) || !bytes.Equal(body, []byte(msg)) {
                t.Errorf("Body = %q; want %q", body[:n], msg)
        }
-       close(cancel)
+       test.cancel(tr, req)
 
        tail, err := io.ReadAll(res.Body)
        res.Body.Close()
-       if err != ExportErrRequestCanceled {
-               t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
-       } else if len(tail) > 0 {
+       test.checkErr("Body.Read", err)
+       if len(tail) > 0 {
                t.Errorf("Spurious bytes from Body.Read: %q", tail)
        }
 
@@ -2779,53 +2814,39 @@ func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
        })
 }
 
-func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
+func TestTransportCancelRequestBeforeDo(t *testing.T) {
+       // We can't cancel a request that hasn't started using Transport.CancelRequest.
        run(t, func(t *testing.T, mode testMode) {
-               testCancelRequestWithChannelBeforeDo(t, mode, false)
-       })
-}
-func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
-       run(t, func(t *testing.T, mode testMode) {
-               testCancelRequestWithChannelBeforeDo(t, mode, true)
+               t.Run("RequestCancel", func(t *testing.T) {
+                       runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
+               })
+               t.Run("ContextCancel", func(t *testing.T) {
+                       runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
+               })
        })
 }
-func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
+func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
        unblockc := make(chan bool)
-       ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+       cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
                <-unblockc
-       })).ts
+       }))
        defer close(unblockc)
 
-       c := ts.Client()
+       c := cst.ts.Client()
 
-       req, _ := NewRequest("GET", ts.URL, nil)
-       if withCtx {
-               ctx, cancel := context.WithCancel(context.Background())
-               cancel()
-               req = req.WithContext(ctx)
-       } else {
-               ch := make(chan struct{})
-               req.Cancel = ch
-               close(ch)
-       }
+       req, _ := NewRequest("GET", cst.ts.URL, nil)
+       req = test.newReq(req)
+       test.cancel(cst.tr, req)
 
        _, err := c.Do(req)
-       if ue, ok := err.(*url.Error); ok {
-               err = ue.Err
-       }
-       if withCtx {
-               if err != context.Canceled {
-                       t.Errorf("Do error = %v; want %v", err, context.Canceled)
-               }
-       } else {
-               if err == nil || !strings.Contains(err.Error(), "canceled") {
-                       t.Errorf("Do error = %v; want cancellation", err)
-               }
-       }
+       test.checkErr("Do", err)
 }
 
 // Issue 11020. The returned error message should be errRequestCanceled
-func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
+func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
+       runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
+}
+func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
        defer afterTest(t)
 
        serverConnCh := make(chan net.Conn, 1)
@@ -2839,6 +2860,7 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
        defer tr.CloseIdleConnections()
        errc := make(chan error, 1)
        req, _ := NewRequest("GET", "http://example.com/", nil)
+       req = test.newReq(req)
        go func() {
                _, err := tr.RoundTrip(req)
                errc <- err
@@ -2854,15 +2876,13 @@ func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
        }
        defer sc.Close()
 
-       tr.CancelRequest(req)
+       test.cancel(tr, req)
 
        err := <-errc
        if err == nil {
                t.Fatalf("unexpected success from RoundTrip")
        }
-       if err != ExportErrRequestCanceled {
-               t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
-       }
+       test.checkErr("RoundTrip", err)
 }
 
 // golang.org/issue/3672 -- Client can't close HTTP stream