func testClientTimeoutKillsConn_BeforeHeaders(t *testing.T, mode testMode) {
timeout := 1 * time.Millisecond
for {
- inHandler := make(chan net.Conn, 1)
- handlerReadReturned := make(chan bool, 1)
+ inHandler := make(chan bool)
+ cancelHandler := make(chan struct{})
+ handlerDone := make(chan bool)
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ <-r.Context().Done()
+
+ select {
+ case <-cancelHandler:
+ return
+ case inHandler <- true:
+ }
+ defer func() { handlerDone <- true }()
+
+ // Read from the conn until EOF to verify that it was correctly closed.
conn, _, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
return
}
- inHandler <- conn
n, err := conn.Read([]byte{0})
if n != 0 || err != io.EOF {
t.Errorf("unexpected Read result: %v, %v", n, err)
}
- handlerReadReturned <- true
+ conn.Close()
}))
cst.c.Timeout = timeout
_, err := cst.c.Get(cst.ts.URL)
if err == nil {
+ close(cancelHandler)
t.Fatal("unexpected Get succeess")
}
- var c net.Conn
tooSlow := time.NewTimer(timeout * 10)
select {
case <-tooSlow.C:
// just slow and the Get failed in that time but never made it to the
// server. That's fine; we'll try again with a longer timout.
t.Logf("no handler seen in %v; retrying with longer timout", timeout)
- timeout *= 2
+ close(cancelHandler)
cst.close()
+ timeout *= 2
continue
- case c = <-inHandler:
+ case <-inHandler:
tooSlow.Stop()
+ <-handlerDone
}
- <-handlerReadReturned
- c.Close()
break
}
}
run(t, testClientTimeoutKillsConn_AfterHeaders, []testMode{http1Mode})
}
func testClientTimeoutKillsConn_AfterHeaders(t *testing.T, mode testMode) {
- inHandler := make(chan net.Conn, 1)
- handlerResult := make(chan error, 1)
+ inHandler := make(chan bool)
+ cancelHandler := make(chan struct{})
+ handlerDone := make(chan bool)
cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Length", "100")
w.(Flusher).Flush()
+
+ select {
+ case <-cancelHandler:
+ return
+ case inHandler <- true:
+ }
+ defer func() { handlerDone <- true }()
+
conn, _, err := w.(Hijacker).Hijack()
if err != nil {
t.Error(err)
return
}
conn.Write([]byte("foo"))
- inHandler <- conn
+
n, err := conn.Read([]byte{0})
// The error should be io.EOF or "read tcp
// 127.0.0.1:35827->127.0.0.1:40290: read: connection
// care that it returns at all. But if it returns with
// data, that's weird.
if n != 0 || err == nil {
- handlerResult <- fmt.Errorf("unexpected Read result: %v, %v", n, err)
- return
+ t.Errorf("unexpected Read result: %v, %v", n, err)
}
- handlerResult <- nil
+ conn.Close()
}))
// Set Timeout to something very long but non-zero to exercise
// the codepaths that check for it. But rather than wait for it to fire
// (which would make the test slow), we send on the req.Cancel channel instead,
// which happens to exercise the same code paths.
- cst.c.Timeout = time.Minute // just to be non-zero, not to hit it.
+ cst.c.Timeout = 24 * time.Hour // just to be non-zero, not to hit it.
req, _ := NewRequest("GET", cst.ts.URL, nil)
- cancel := make(chan struct{})
- req.Cancel = cancel
+ cancelReq := make(chan struct{})
+ req.Cancel = cancelReq
res, err := cst.c.Do(req)
if err != nil {
- select {
- case <-inHandler:
- t.Fatalf("Get error: %v", err)
- default:
- // Failed before entering handler. Ignore result.
- t.Skip("skipping test on slow builder")
- }
+ close(cancelHandler)
+ t.Fatalf("Get error: %v", err)
}
- close(cancel)
+ // Cancel the request while the handler is still blocked on sending to the
+ // inHandler channel. Then read it until it fails, to verify that the
+ // connection is broken before the handler itself closes it.
+ close(cancelReq)
got, err := io.ReadAll(res.Body)
if err == nil {
- t.Fatalf("unexpected success; read %q, nil", got)
+ t.Errorf("unexpected success; read %q, nil", got)
}
- c := <-inHandler
- if err := <-handlerResult; err != nil {
- t.Errorf("handler: %v", err)
- }
- c.Close()
+ // Now unblock the handler and wait for it to complete.
+ <-inHandler
+ <-handlerDone
}
func TestTransportResponseBodyWritableOnProtocolSwitch(t *testing.T) {