}
}
+// A cancelTest is a test of request cancellation.
+type cancelTest struct {
+ mode testMode
+ newReq func(req *Request) *Request // prepare the request to cancel
+ cancel func(tr *Transport, req *Request) // cancel the request
+ checkErr func(when string, err error) // verify the expected error
+}
+
+// runCancelTestTransport uses Transport.CancelRequest.
+func runCancelTestTransport(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ t.Run("TransportCancel", func(t *testing.T) {
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ return req
+ },
+ cancel: func(tr *Transport, req *Request) {
+ tr.CancelRequest(req)
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+ t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+ }
+ },
+ })
+ })
+}
+
+// runCancelTestChannel uses Request.Cancel.
+func runCancelTestChannel(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ var cancelOnce sync.Once
+ cancelc := make(chan struct{})
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ req.Cancel = cancelc
+ return req
+ },
+ cancel: func(tr *Transport, req *Request) {
+ cancelOnce.Do(func() {
+ close(cancelc)
+ })
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, ExportErrRequestCanceled) && !errors.Is(err, ExportErrRequestCanceledConn) {
+ t.Errorf("%v error = %v, want errRequestCanceled or errRequestCanceledConn", when, err)
+ }
+ },
+ })
+}
+
+// runCancelTestContext uses a request context.
+func runCancelTestContext(t *testing.T, mode testMode, f func(t *testing.T, test cancelTest)) {
+ ctx, cancel := context.WithCancel(context.Background())
+ f(t, cancelTest{
+ mode: mode,
+ newReq: func(req *Request) *Request {
+ return req.WithContext(ctx)
+ },
+ cancel: func(tr *Transport, req *Request) {
+ cancel()
+ },
+ checkErr: func(when string, err error) {
+ if !errors.Is(err, context.Canceled) {
+ t.Errorf("%v error = %v, want context.Canceled", when, err)
+ }
+ },
+ })
+}
+
+func runCancelTest(t *testing.T, f func(t *testing.T, test cancelTest), opts ...any) {
+ run(t, func(t *testing.T, mode testMode) {
+ if mode == http1Mode {
+ t.Run("TransportCancel", func(t *testing.T) {
+ runCancelTestTransport(t, mode, f)
+ })
+ }
+ t.Run("RequestCancel", func(t *testing.T) {
+ runCancelTestChannel(t, mode, f)
+ })
+ t.Run("ContextCancel", func(t *testing.T) {
+ runCancelTestContext(t, mode, f)
+ })
+ }, opts...)
+}
+
func TestTransportCancelRequest(t *testing.T) {
- run(t, testTransportCancelRequest, []testMode{http1Mode})
+ runCancelTest(t, testTransportCancelRequest)
}
-func testTransportCancelRequest(t *testing.T, mode testMode) {
+func testTransportCancelRequest(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
tr := c.Transport.(*Transport)
req, _ := NewRequest("GET", ts.URL, nil)
+ req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
t.Fatal(err)
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
- tr.CancelRequest(req)
+ test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
+ test.checkErr("Body.Read", err)
+ if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
})
}
-func testTransportCancelRequestInDo(t *testing.T, mode testMode, body io.Reader) {
+func testTransportCancelRequestInDo(t *testing.T, test cancelTest, body io.Reader) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
})).ts
defer close(unblockc)
donec := make(chan bool)
req, _ := NewRequest("GET", ts.URL, body)
+ req = test.newReq(req)
go func() {
defer close(donec)
c.Do(req)
unblockc <- true
waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
- tr.CancelRequest(req)
+ test.cancel(tr, req)
select {
case <-donec:
return true
}
func TestTransportCancelRequestInDo(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testTransportCancelRequestInDo(t, mode, nil)
- }, []testMode{http1Mode})
+ runCancelTest(t, func(t *testing.T, test cancelTest) {
+ testTransportCancelRequestInDo(t, test, nil)
+ })
}
func TestTransportCancelRequestWithBodyInDo(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testTransportCancelRequestInDo(t, mode, bytes.NewBuffer([]byte{0}))
- }, []testMode{http1Mode})
+ runCancelTest(t, func(t *testing.T, test cancelTest) {
+ testTransportCancelRequestInDo(t, test, bytes.NewBuffer([]byte{0}))
+ })
}
func TestTransportCancelRequestInDial(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestInDial)
+}
+func testTransportCancelRequestInDial(t *testing.T, test cancelTest) {
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
cl := &Client{Transport: tr}
gotres := make(chan bool)
req, _ := NewRequest("GET", "http://something.no-network.tld/", nil)
+ req = test.newReq(req)
go func() {
_, err := cl.Do(req)
- eventLog.Printf("Get = %v", err)
+ eventLog.Printf("Get error = %v", err != nil)
+ test.checkErr("Get", err)
gotres <- true
}()
inDial <- true
eventLog.Printf("canceling")
- tr.CancelRequest(req)
- tr.CancelRequest(req) // used to panic on second call
+ test.cancel(tr, req)
+ test.cancel(tr, req) // used to panic on second call to Transport.Cancel
if d, ok := t.Deadline(); ok {
// When the test's deadline is about to expire, log the pending events for
got := logbuf.String()
want := `dial: blocking
canceling
-Get = Get "http://something.no-network.tld/": net/http: request canceled while waiting for connection
+Get error = true
`
if got != want {
t.Errorf("Got events:\n%s\nWant:\n%s", got, want)
}
}
-func TestCancelRequestWithChannel(t *testing.T) { run(t, testCancelRequestWithChannel) }
-func testCancelRequestWithChannel(t *testing.T, mode testMode) {
- if testing.Short() {
- t.Skip("skipping test in -short mode")
- }
-
- const msg = "Hello"
- unblockc := make(chan struct{})
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
- io.WriteString(w, msg)
- w.(Flusher).Flush() // send headers and some body
- <-unblockc
- })).ts
- defer close(unblockc)
-
- c := ts.Client()
- tr := c.Transport.(*Transport)
-
- req, _ := NewRequest("GET", ts.URL, nil)
- cancel := make(chan struct{})
- req.Cancel = cancel
-
- res, err := c.Do(req)
- if err != nil {
- t.Fatal(err)
- }
- body := make([]byte, len(msg))
- n, _ := io.ReadFull(res.Body, body)
- if n != len(body) || !bytes.Equal(body, []byte(msg)) {
- t.Errorf("Body = %q; want %q", body[:n], msg)
- }
- close(cancel)
-
- tail, err := io.ReadAll(res.Body)
- res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
- t.Errorf("Spurious bytes from Body.Read: %q", tail)
- }
-
- // Verify no outstanding requests after readLoop/writeLoop
- // goroutines shut down.
- waitCondition(t, 10*time.Millisecond, func(d time.Duration) bool {
- n := tr.NumPendingRequestsForTesting()
- if n > 0 {
- if d > 0 {
- t.Logf("pending requests = %d after %v (want 0)", n, d)
- }
- return false
- }
- return true
- })
-}
-
// Issue 51354
-func TestCancelRequestWithBodyWithChannel(t *testing.T) {
- run(t, testCancelRequestWithBodyWithChannel, []testMode{http1Mode})
+func TestTransportCancelRequestWithBody(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestWithBody)
}
-func testCancelRequestWithBodyWithChannel(t *testing.T, mode testMode) {
+func testTransportCancelRequestWithBody(t *testing.T, test cancelTest) {
if testing.Short() {
t.Skip("skipping test in -short mode")
}
const msg = "Hello"
unblockc := make(chan struct{})
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
io.WriteString(w, msg)
w.(Flusher).Flush() // send headers and some body
<-unblockc
tr := c.Transport.(*Transport)
req, _ := NewRequest("POST", ts.URL, strings.NewReader("withbody"))
- cancel := make(chan struct{})
- req.Cancel = cancel
+ req = test.newReq(req)
res, err := c.Do(req)
if err != nil {
if n != len(body) || !bytes.Equal(body, []byte(msg)) {
t.Errorf("Body = %q; want %q", body[:n], msg)
}
- close(cancel)
+ test.cancel(tr, req)
tail, err := io.ReadAll(res.Body)
res.Body.Close()
- if err != ExportErrRequestCanceled {
- t.Errorf("Body.Read error = %v; want errRequestCanceled", err)
- } else if len(tail) > 0 {
+ test.checkErr("Body.Read", err)
+ if len(tail) > 0 {
t.Errorf("Spurious bytes from Body.Read: %q", tail)
}
})
}
-func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) {
+func TestTransportCancelRequestBeforeDo(t *testing.T) {
+ // We can't cancel a request that hasn't started using Transport.CancelRequest.
run(t, func(t *testing.T, mode testMode) {
- testCancelRequestWithChannelBeforeDo(t, mode, false)
- })
-}
-func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) {
- run(t, func(t *testing.T, mode testMode) {
- testCancelRequestWithChannelBeforeDo(t, mode, true)
+ t.Run("RequestCancel", func(t *testing.T) {
+ runCancelTestChannel(t, mode, testTransportCancelRequestBeforeDo)
+ })
+ t.Run("ContextCancel", func(t *testing.T) {
+ runCancelTestContext(t, mode, testTransportCancelRequestBeforeDo)
+ })
})
}
-func testCancelRequestWithChannelBeforeDo(t *testing.T, mode testMode, withCtx bool) {
+func testTransportCancelRequestBeforeDo(t *testing.T, test cancelTest) {
unblockc := make(chan bool)
- ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, test.mode, HandlerFunc(func(w ResponseWriter, r *Request) {
<-unblockc
- })).ts
+ }))
defer close(unblockc)
- c := ts.Client()
+ c := cst.ts.Client()
- req, _ := NewRequest("GET", ts.URL, nil)
- if withCtx {
- ctx, cancel := context.WithCancel(context.Background())
- cancel()
- req = req.WithContext(ctx)
- } else {
- ch := make(chan struct{})
- req.Cancel = ch
- close(ch)
- }
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = test.newReq(req)
+ test.cancel(cst.tr, req)
_, err := c.Do(req)
- if ue, ok := err.(*url.Error); ok {
- err = ue.Err
- }
- if withCtx {
- if err != context.Canceled {
- t.Errorf("Do error = %v; want %v", err, context.Canceled)
- }
- } else {
- if err == nil || !strings.Contains(err.Error(), "canceled") {
- t.Errorf("Do error = %v; want cancellation", err)
- }
- }
+ test.checkErr("Do", err)
}
// Issue 11020. The returned error message should be errRequestCanceled
-func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
+func TestTransportCancelRequestBeforeResponseHeaders(t *testing.T) {
+ runCancelTest(t, testTransportCancelRequestBeforeResponseHeaders, []testMode{http1Mode})
+}
+func testTransportCancelRequestBeforeResponseHeaders(t *testing.T, test cancelTest) {
defer afterTest(t)
serverConnCh := make(chan net.Conn, 1)
defer tr.CloseIdleConnections()
errc := make(chan error, 1)
req, _ := NewRequest("GET", "http://example.com/", nil)
+ req = test.newReq(req)
go func() {
_, err := tr.RoundTrip(req)
errc <- err
}
defer sc.Close()
- tr.CancelRequest(req)
+ test.cancel(tr, req)
err := <-errc
if err == nil {
t.Fatalf("unexpected success from RoundTrip")
}
- if err != ExportErrRequestCanceled {
- t.Errorf("RoundTrip error = %v; want ExportErrRequestCanceled", err)
- }
+ test.checkErr("RoundTrip", err)
}
// golang.org/issue/3672 -- Client can't close HTTP stream