]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: clean up Transport.RoundTrip error handling
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 27 Feb 2017 05:41:50 +0000 (05:41 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 2 Mar 2017 22:06:09 +0000 (22:06 +0000)
If I put a 10 millisecond sleep at testHookWaitResLoop, before the big
select in (*persistConn).roundTrip, two flakes immediately started
happening, TestTransportBodyReadError (#19231) and
TestTransportPersistConnReadLoopEOF.

The problem was that there are many ways for a RoundTrip call to fail
(errors reading from Request.Body while writing the response, errors
writing the response, errors reading the response due to server
closes, errors due to servers sending malformed responses,
cancelations, timeouts, etc.), and many of those failures then tear
down the TCP connection, causing more failures, since there are always
at least three goroutines involved (reading, writing, RoundTripping).

Because the errors were communicated over buffered channels to a giant
select, the error returned to the caller was a function of which
random select case was called, which was why a 10ms delay before the
select brought out so many bugs. (several fixed in my previous CLs the past
few days).

Instead, track the error explicitly in the transportRequest, guarded
by a mutex.

In addition, this CL now:

* differentiates between the two ways writing a request can fail: the
  io.Copy reading from the Request.Body or the io.Copy writing to the
  network. A new io.Reader type notes read errors from the
  Request.Body. The read-from-body vs write-to-network errors are now
  prioritized differently.

* unifies the two mapRoundTripErrorFromXXX methods into one
  mapRoundTripError method since their logic is now the same.

* adds a (*Request).WithT(*testing.T) method in export_test.go, usable
  by tests, to call t.Logf at points during RoundTrip. This is disabled
  behind a constant except when debugging.

* documents and deflakes TestClientRedirectContext

I've tested this CL with high -count values, with/without -race,
with/without delays before the select, etc. So far it seems robust.

Fixes #19231 (TestTransportBodyReadError flake)
Updates #14203 (source of errors unclear; they're now tracked more)
Updates #15935 (document Transport errors more; at least understood more now)

Change-Id: I3cccc3607f369724b5344763e35ad2b7ea415738
Reviewed-on: https://go-review.googlesource.com/37495
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Josh Bleecher Snyder <josharian@gmail.com>
src/net/http/client_test.go
src/net/http/export_test.go
src/net/http/request.go
src/net/http/transfer.go
src/net/http/transport.go
src/net/http/transport_internal_test.go
src/net/http/transport_test.go

index 534986e86753aed3c74e8e0db092c924f5ec8a32..c75456ae5378dc894ebe1b5c83421ca9d8029895 100644 (file)
@@ -304,6 +304,7 @@ func TestClientRedirects(t *testing.T) {
        }
 }
 
+// Tests that Client redirects' contexts are derived from the original request's context.
 func TestClientRedirectContext(t *testing.T) {
        setParallel(t)
        defer afterTest(t)
@@ -320,10 +321,12 @@ func TestClientRedirectContext(t *testing.T) {
                Transport: tr,
                CheckRedirect: func(req *Request, via []*Request) error {
                        cancel()
-                       if len(via) > 2 {
-                               return errors.New("too many redirects")
+                       select {
+                       case <-req.Context().Done():
+                               return nil
+                       case <-time.After(5 * time.Second):
+                               return errors.New("redirected request's context never expired after root request canceled")
                        }
-                       return nil
                },
        }
        req, _ := NewRequest("GET", ts.URL, nil)
@@ -1818,6 +1821,7 @@ func TestTransportBodyReadError(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       req = req.WithT(t)
        _, err = tr.RoundTrip(req)
        if err != someErr {
                t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr)
index b61f58b2db4317f317a923b682177ca5aa0efb2f..596171f5f034b7df2d72876340994bcd8ab79e3b 100644 (file)
@@ -8,9 +8,11 @@
 package http
 
 import (
+       "context"
        "net"
        "sort"
        "sync"
+       "testing"
        "time"
 )
 
@@ -199,3 +201,7 @@ func (s *Server) ExportAllConnsIdle() bool {
        }
        return true
 }
+
+func (r *Request) WithT(t *testing.T) *Request {
+       return r.WithContext(context.WithValue(r.Context(), tLogKey{}, t.Logf))
+}
index 168c03e86c3786559b642043bbdd7623919bda37..09d998dacfbfc3ad7e218bd0b783e39b3b0efe47 100644 (file)
@@ -621,6 +621,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
        // Write body and trailer
        err = tw.WriteBody(w)
        if err != nil {
+               if tw.bodyReadError == err {
+                       err = requestBodyReadError{err}
+               }
                return err
        }
 
@@ -630,6 +633,11 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
        return nil
 }
 
+// requestBodyReadError wraps an error from (*Request).write to indicate
+// that the error came from a Read call on the Request.Body.
+// This error type should not escape the net/http package to users.
+type requestBodyReadError struct{ error }
+
 func idnaASCII(v string) (string, error) {
        if isASCII(v) {
                return v, nil
index 4f47637aa76a6af7005a885f020cfec5484320ae..2a021154c95e1986434cd9b74da5c9738766d437 100644 (file)
@@ -51,6 +51,19 @@ func (br *byteReader) Read(p []byte) (n int, err error) {
        return 1, io.EOF
 }
 
+// transferBodyReader is an io.Reader that reads from tw.Body
+// and records any non-EOF error in tw.bodyReadError.
+// It is exactly 1 pointer wide to avoid allocations into interfaces.
+type transferBodyReader struct{ tw *transferWriter }
+
+func (br transferBodyReader) Read(p []byte) (n int, err error) {
+       n, err = br.tw.Body.Read(p)
+       if err != nil && err != io.EOF {
+               br.tw.bodyReadError = err
+       }
+       return
+}
+
 // transferWriter inspects the fields of a user-supplied Request or Response,
 // sanitizes them without changing the user object and provides methods for
 // writing the respective header, body and trailer in wire format.
@@ -64,6 +77,7 @@ type transferWriter struct {
        TransferEncoding []string
        Trailer          Header
        IsResponse       bool
+       bodyReadError    error // any non-EOF error from reading Body
 
        FlushHeaders bool            // flush headers to network before body
        ByteReadCh   chan readResult // non-nil if probeRequestBody called
@@ -304,24 +318,25 @@ func (t *transferWriter) WriteBody(w io.Writer) error {
 
        // Write body
        if t.Body != nil {
+               var body = transferBodyReader{t}
                if chunked(t.TransferEncoding) {
                        if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
                                w = &internal.FlushAfterChunkWriter{Writer: bw}
                        }
                        cw := internal.NewChunkedWriter(w)
-                       _, err = io.Copy(cw, t.Body)
+                       _, err = io.Copy(cw, body)
                        if err == nil {
                                err = cw.Close()
                        }
                } else if t.ContentLength == -1 {
-                       ncopy, err = io.Copy(w, t.Body)
+                       ncopy, err = io.Copy(w, body)
                } else {
-                       ncopy, err = io.Copy(w, io.LimitReader(t.Body, t.ContentLength))
+                       ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength))
                        if err != nil {
                                return err
                        }
                        var nextra int64
-                       nextra, err = io.Copy(ioutil.Discard, t.Body)
+                       nextra, err = io.Copy(ioutil.Discard, body)
                        ncopy += nextra
                }
                if err != nil {
index 2aa00de50a2a0621c18c01a0806e1111ffc9e5e2..0d4f427a57e860e28370bb7ff2a54a1050b373b3 100644 (file)
@@ -303,11 +303,15 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
 }
 
 // transportRequest is a wrapper around a *Request that adds
-// optional extra headers to write.
+// optional extra headers to write and stores any error to return
+// from roundTrip.
 type transportRequest struct {
        *Request                        // original request, not to be mutated
        extra    Header                 // extra headers to write, or nil
        trace    *httptrace.ClientTrace // optional
+
+       mu  sync.Mutex // guards err
+       err error      // first setError value for mapRoundTripError to consider
 }
 
 func (tr *transportRequest) extraHeaders() Header {
@@ -317,6 +321,14 @@ func (tr *transportRequest) extraHeaders() Header {
        return tr.extra
 }
 
+func (tr *transportRequest) setError(err error) {
+       tr.mu.Lock()
+       if tr.err == nil {
+               tr.err = err
+       }
+       tr.mu.Unlock()
+}
+
 // RoundTrip implements the RoundTripper interface.
 //
 // For higher-level HTTP client support (such as handling of cookies
@@ -1420,22 +1432,41 @@ func (pc *persistConn) closeConnIfStillIdle() {
        pc.close(errIdleConnTimeout)
 }
 
-// mapRoundTripErrorFromReadLoop maps the provided readLoop error into
-// the error value that should be returned from persistConn.roundTrip.
+// mapRoundTripError returns the appropriate error value for
+// persistConn.roundTrip.
+//
+// The provided err is the first error that (*persistConn).roundTrip
+// happened to receive from its select statement.
 //
 // The startBytesWritten value should be the value of pc.nwrite before the roundTrip
 // started writing the request.
-func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) {
+func (pc *persistConn) mapRoundTripError(req *transportRequest, startBytesWritten int64, err error) error {
        if err == nil {
                return nil
        }
-       if err := pc.canceled(); err != nil {
-               return err
+
+       // If the request was canceled, that's better than network
+       // failures that were likely the result of tearing down the
+       // connection.
+       if cerr := pc.canceled(); cerr != nil {
+               return cerr
+       }
+
+       // See if an error was set explicitly.
+       req.mu.Lock()
+       reqErr := req.err
+       req.mu.Unlock()
+       if reqErr != nil {
+               return reqErr
        }
+
        if err == errServerClosedIdle {
+               // Don't decorate
                return err
        }
+
        if _, ok := err.(transportReadFromServerError); ok {
+               // Don't decorate
                return err
        }
        if pc.isBroken() {
@@ -1443,40 +1474,11 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWri
                if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
                        return nothingWrittenError{err}
                }
+               return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err)
        }
        return err
 }
 
-// mapRoundTripErrorAfterClosed returns the error value to be propagated
-// up to Transport.RoundTrip method when persistConn.roundTrip sees
-// its pc.closech channel close, indicating the persistConn is dead.
-// (after closech is closed, pc.closed is valid).
-func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error {
-       if err := pc.canceled(); err != nil {
-               return err
-       }
-       err := pc.closed
-       if err == errServerClosedIdle {
-               // Don't decorate
-               return err
-       }
-       if _, ok := err.(transportReadFromServerError); ok {
-               // Don't decorate
-               return err
-       }
-
-       // Wait for the writeLoop goroutine to terminated, and then
-       // see if we actually managed to write anything. If not, we
-       // can retry the request.
-       <-pc.writeLoopDone
-       if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 {
-               return nothingWrittenError{err}
-       }
-
-       return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err)
-
-}
-
 func (pc *persistConn) readLoop() {
        closeErr := errReadLoopExiting // default value, if not changed below
        defer func() {
@@ -1746,6 +1748,17 @@ func (pc *persistConn) writeLoop() {
                case wr := <-pc.writech:
                        startBytesWritten := pc.nwrite
                        err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
+                       if bre, ok := err.(requestBodyReadError); ok {
+                               err = bre.error
+                               // Errors reading from the user's
+                               // Request.Body are high priority.
+                               // Set it here before sending on the
+                               // channels below or calling
+                               // pc.close() which tears town
+                               // connections and causes other
+                               // errors.
+                               wr.req.setError(err)
+                       }
                        if err == nil {
                                err = pc.bw.Flush()
                        }
@@ -1913,6 +1926,14 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        gone := make(chan struct{})
        defer close(gone)
 
+       defer func() {
+               if err != nil {
+                       pc.t.setReqCanceler(req.Request, nil)
+               }
+       }()
+
+       const debugRoundTrip = false
+
        // Write the request concurrently with waiting for a response,
        // in case the server decides to reply before reading our full
        // request body.
@@ -1929,38 +1950,50 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
                callerGone: gone,
        }
 
-       var re responseAndError
        var respHeaderTimer <-chan time.Time
        cancelChan := req.Request.Cancel
        ctxDoneChan := req.Context().Done()
-WaitResponse:
        for {
                testHookWaitResLoop()
                select {
                case err := <-writeErrCh:
+                       if debugRoundTrip {
+                               req.logf("writeErrCh resv: %T/%#v", err, err)
+                       }
                        if err != nil {
-                               if cerr := pc.canceled(); cerr != nil {
-                                       err = cerr
-                               }
-                               re = responseAndError{err: err}
                                pc.close(fmt.Errorf("write error: %v", err))
-                               break WaitResponse
+                               return nil, pc.mapRoundTripError(req, startBytesWritten, err)
                        }
                        if d := pc.t.ResponseHeaderTimeout; d > 0 {
+                               if debugRoundTrip {
+                                       req.logf("starting timer for %v", d)
+                               }
                                timer := time.NewTimer(d)
                                defer timer.Stop() // prevent leaks
                                respHeaderTimer = timer.C
                        }
                case <-pc.closech:
-                       re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)}
-                       break WaitResponse
+                       if debugRoundTrip {
+                               req.logf("closech recv: %T %#v", pc.closed, pc.closed)
+                       }
+                       return nil, pc.mapRoundTripError(req, startBytesWritten, pc.closed)
                case <-respHeaderTimer:
+                       if debugRoundTrip {
+                               req.logf("timeout waiting for response headers.")
+                       }
                        pc.close(errTimeout)
-                       re = responseAndError{err: errTimeout}
-                       break WaitResponse
-               case re = <-resc:
-                       re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err)
-                       break WaitResponse
+                       return nil, errTimeout
+               case re := <-resc:
+                       if (re.res == nil) == (re.err == nil) {
+                               panic(fmt.Sprintf("internal error: exactly one of res or err should be set; nil=%v", re.res == nil))
+                       }
+                       if debugRoundTrip {
+                               req.logf("resc recv: %p, %T/%#v", re.res, re.err, re.err)
+                       }
+                       if re.err != nil {
+                               return nil, pc.mapRoundTripError(req, startBytesWritten, re.err)
+                       }
+                       return re.res, nil
                case <-cancelChan:
                        pc.t.CancelRequest(req.Request)
                        cancelChan = nil
@@ -1970,14 +2003,16 @@ WaitResponse:
                        ctxDoneChan = nil
                }
        }
+}
 
-       if re.err != nil {
-               pc.t.setReqCanceler(req.Request, nil)
-       }
-       if (re.res == nil) == (re.err == nil) {
-               panic("internal error: exactly one of res or err should be set")
+// tLogKey is a context WithValue key for test debugging contexts containing
+// a t.Logf func. See export_test.go's Request.WithT method.
+type tLogKey struct{}
+
+func (r *transportRequest) logf(format string, args ...interface{}) {
+       if logf, ok := r.Request.Context().Value(tLogKey{}).(func(string, ...interface{})); ok {
+               logf(time.Now().Format(time.RFC3339Nano)+": "+format, args...)
        }
-       return re.res, re.err
 }
 
 // markReused marks this connection as having been successfully used for a
index 3d24fc127d45e315e031d2231123aa2e277eac39..262d8b4ac51030274c8b28a83d54accee1030601 100644 (file)
@@ -30,6 +30,7 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
 
        tr := new(Transport)
        req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil)
+       req = req.WithT(t)
        treq := &transportRequest{Request: req}
        cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()}
        pc, err := tr.getConn(treq, cm)
@@ -47,13 +48,13 @@ func TestTransportPersistConnReadLoopEOF(t *testing.T) {
 
        _, err = pc.roundTrip(treq)
        if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
-               t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
+               t.Errorf("roundTrip = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
        }
 
        <-pc.closech
        err = pc.closed
        if !isTransportReadFromServerError(err) && err != errServerClosedIdle {
-               t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err)
+               t.Errorf("pc.closed = %#v, %v; want errServerClosedIdle or transportReadFromServerError", err, err)
        }
 }
 
index ce98157ed5b949ba721f3c53b7615800fdc53f83..cb315f14f464e32c6a91d8c742172a1559e4975c 100644 (file)
@@ -1625,7 +1625,9 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
                {path: "/fast", want: 200},
        }
        for i, tt := range tests {
-               res, err := c.Get(ts.URL + tt.path)
+               req, _ := NewRequest("GET", ts.URL+tt.path, nil)
+               req = req.WithT(t)
+               res, err := c.Do(req)
                select {
                case <-inHandler:
                case <-time.After(5 * time.Second):