]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't reuse Transport connection unless Request.Write finished
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 10 Apr 2014 04:50:24 +0000 (21:50 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 10 Apr 2014 04:50:24 +0000 (21:50 -0700)
In a typical HTTP request, the client writes the request, and
then the server replies. Go's HTTP client code (Transport) has
two goroutines per connection: one writing, and one reading. A
third goroutine (the one initiating the HTTP request)
coordinates with those two.

Because most HTTP requests are done when the server replies,
the Go code has always handled connection reuse purely in the
readLoop goroutine.

But if a client is writing a large request and the server
replies before it's consumed the entire request (e.g. it
replied with a 403 Forbidden and had no use for the body), it
was possible for Go to re-select that connection for a
subsequent request before we were done writing the first. That
wasn't actually a data race; the second HTTP request would
just get enqueued to write its request on the writeLoop. But
because the previous writeLoop didn't finish writing (and
might not ever), that connection is in a weird state. We
really just don't want to get into a state where we're
re-using a connection when the server spoke out of turn.

This CL changes the readLoop goroutine to verify that the
writeLoop finished before returning the connection.

In the process, it also fixes a potential goroutine leak where
a connection could close but the recycling logic could be
blocked forever waiting for the client to read to EOF or
error. Now it also selects on the persistConn's close channel,
and the closer of that is no longer the readLoop (which was
dead locking in some cases before). It's now closed at the
same place the underlying net.Conn is closed. This likely fixes
or helps Issue 7620.

Also addressed some small cosmetic things in the process.

Update #7620
Fixes #7569

LGTM=adg
R=golang-codereviews, adg
CC=dsymonds, golang-codereviews, rsc
https://golang.org/cl/86290043

src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index 3759b88fe0d8c0177c2be0c7b6aba243c6c83e87..de0ff9ce57ca4347d15c630f571ecab2225b2dd5 100644 (file)
@@ -230,9 +230,6 @@ func (t *Transport) CloseIdleConnections() {
        t.idleConn = nil
        t.idleConnCh = nil
        t.idleMu.Unlock()
-       if m == nil {
-               return
-       }
        for _, conns := range m {
                for _, pconn := range conns {
                        pconn.close()
@@ -498,12 +495,13 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
        pa := cm.proxyAuth()
 
        pconn := &persistConn{
-               t:        t,
-               cacheKey: cm.key(),
-               conn:     conn,
-               reqch:    make(chan requestAndChan, 50),
-               writech:  make(chan writeRequest, 50),
-               closech:  make(chan struct{}),
+               t:          t,
+               cacheKey:   cm.key(),
+               conn:       conn,
+               reqch:      make(chan requestAndChan, 1),
+               writech:    make(chan writeRequest, 1),
+               closech:    make(chan struct{}),
+               writeErrCh: make(chan error, 1),
        }
 
        switch {
@@ -727,8 +725,13 @@ type persistConn struct {
        bw       *bufio.Writer       // to conn
        reqch    chan requestAndChan // written by roundTrip; read by readLoop
        writech  chan writeRequest   // written by roundTrip; read by writeLoop
-       closech  chan struct{}       // broadcast close when readLoop (TCP connection) closes
+       closech  chan struct{}       // closed when conn closed
        isProxy  bool
+       // writeErrCh passes the request write error (usually nil)
+       // from the writeLoop goroutine to the readLoop which passes
+       // it off to the res.Body reader, which then uses it to decide
+       // whether or not a connection can be reused. Issue 7569.
+       writeErrCh chan error
 
        lk                   sync.Mutex // guards following 3 fields
        numExpectedResponses int
@@ -739,6 +742,7 @@ type persistConn struct {
        mutateHeaderFunc func(Header)
 }
 
+// isBroken reports whether this connection is in a known broken state.
 func (pc *persistConn) isBroken() bool {
        pc.lk.Lock()
        b := pc.broken
@@ -763,7 +767,6 @@ func remoteSideClosed(err error) bool {
 }
 
 func (pc *persistConn) readLoop() {
-       defer close(pc.closech)
        alive := true
 
        for alive {
@@ -838,27 +841,18 @@ func (pc *persistConn) readLoop() {
                                return nil
                        }
                        resp.Body.(*bodyEOFSignal).fn = func(err error) {
-                               alive1 := alive
-                               if err != nil {
-                                       alive1 = false
-                               }
-                               if alive1 && pc.sawEOF {
-                                       alive1 = false
-                               }
-                               if alive1 && !pc.t.putIdleConn(pc) {
-                                       alive1 = false
-                               }
-                               if !alive1 || pc.isBroken() {
-                                       pc.close()
-                               }
-                               waitForBodyRead <- alive1
+                               waitForBodyRead <- alive &&
+                                       err == nil &&
+                                       !pc.sawEOF &&
+                                       pc.wroteRequest() &&
+                                       pc.t.putIdleConn(pc)
                        }
                }
 
                if alive && !hasBody {
-                       if !pc.t.putIdleConn(pc) {
-                               alive = false
-                       }
+                       alive = !pc.sawEOF &&
+                               pc.wroteRequest() &&
+                               pc.t.putIdleConn(pc)
                }
 
                rc.ch <- responseAndError{resp, err}
@@ -866,7 +860,11 @@ func (pc *persistConn) readLoop() {
                // Wait for the just-returned response body to be fully consumed
                // before we race and peek on the underlying bufio reader.
                if waitForBodyRead != nil {
-                       alive = <-waitForBodyRead
+                       select {
+                       case alive = <-waitForBodyRead:
+                       case <-pc.closech:
+                               alive = false
+                       }
                }
 
                pc.t.setReqCanceler(rc.req, nil)
@@ -892,13 +890,42 @@ func (pc *persistConn) writeLoop() {
                        if err != nil {
                                pc.markBroken()
                        }
-                       wr.ch <- err
+                       pc.writeErrCh <- err // to the body reader, which might recycle us
+                       wr.ch <- err         // to the roundTrip function
                case <-pc.closech:
                        return
                }
        }
 }
 
+// wroteRequest is a check before recycling a connection that the previous write
+// (from writeLoop above) happened and was successful.
+func (pc *persistConn) wroteRequest() bool {
+       select {
+       case err := <-pc.writeErrCh:
+               // Common case: the write happened well before the response, so
+               // avoid creating a timer.
+               return err == nil
+       default:
+               // Rare case: the request was written in writeLoop above but
+               // before it could send to pc.writeErrCh, the reader read it
+               // all, processed it, and called us here. In this case, give the
+               // write goroutine a bit of time to finish its send.
+               //
+               // Less rare case: We also get here in the legitimate case of
+               // Issue 7569, where the writer is still writing (or stalled),
+               // but the server has already replied. In this case, we don't
+               // want to wait too long, and we want to return false so this
+               // connection isn't re-used.
+               select {
+               case err := <-pc.writeErrCh:
+                       return err == nil
+               case <-time.After(50 * time.Millisecond):
+                       return false
+               }
+       }
+}
+
 type responseAndError struct {
        res *Response
        err error
@@ -1046,6 +1073,7 @@ func (pc *persistConn) closeLocked() {
        if !pc.closed {
                pc.conn.Close()
                pc.closed = true
+               close(pc.closech)
        }
        pc.mutateHeaderFunc = nil
 }
index 0eb6e63b369f7318b40e4df20df0dc9d0a29b9c6..24466e53690ba9cef75a8b1fbae6282a9f929736 100644 (file)
@@ -57,21 +57,21 @@ func (c *testCloseConn) Close() error {
 // been closed.
 type testConnSet struct {
        t      *testing.T
+       mu     sync.Mutex // guards closed and list
        closed map[net.Conn]bool
        list   []net.Conn // in order created
-       mutex  sync.Mutex
 }
 
 func (tcs *testConnSet) insert(c net.Conn) {
-       tcs.mutex.Lock()
-       defer tcs.mutex.Unlock()
+       tcs.mu.Lock()
+       defer tcs.mu.Unlock()
        tcs.closed[c] = false
        tcs.list = append(tcs.list, c)
 }
 
 func (tcs *testConnSet) remove(c net.Conn) {
-       tcs.mutex.Lock()
-       defer tcs.mutex.Unlock()
+       tcs.mu.Lock()
+       defer tcs.mu.Unlock()
        tcs.closed[c] = true
 }
 
@@ -94,11 +94,19 @@ func makeTestDial(t *testing.T) (*testConnSet, func(n, addr string) (net.Conn, e
 }
 
 func (tcs *testConnSet) check(t *testing.T) {
-       tcs.mutex.Lock()
-       defer tcs.mutex.Unlock()
-
-       for i, c := range tcs.list {
-               if !tcs.closed[c] {
+       tcs.mu.Lock()
+       defer tcs.mu.Unlock()
+       for i := 4; i >= 0; i-- {
+               for i, c := range tcs.list {
+                       if tcs.closed[c] {
+                               continue
+                       }
+                       if i != 0 {
+                               tcs.mu.Unlock()
+                               time.Sleep(50 * time.Millisecond)
+                               tcs.mu.Lock()
+                               continue
+                       }
                        t.Errorf("TCP connection #%d, %p (of %d total) was not closed", i+1, c, len(tcs.list))
                }
        }
@@ -1905,6 +1913,111 @@ func TestTLSServerClosesConnection(t *testing.T) {
        }
 }
 
+// byteFromChanReader is an io.Reader that reads a single byte at a
+// time from the channel.  When the channel is closed, the reader
+// returns io.EOF.
+type byteFromChanReader chan byte
+
+func (c byteFromChanReader) Read(p []byte) (n int, err error) {
+       if len(p) == 0 {
+               return
+       }
+       b, ok := <-c
+       if !ok {
+               return 0, io.EOF
+       }
+       p[0] = b
+       return 1, nil
+}
+
+// Verifies that the Transport doesn't reuse a connection in the case
+// where the server replies before the request has been fully
+// written. We still honor that reply (see TestIssue3595), but don't
+// send future requests on the connection because it's then in a
+// questionable state.
+// golang.org/issue/7569
+func TestTransportNoReuseAfterEarlyResponse(t *testing.T) {
+       defer afterTest(t)
+       var sconn struct {
+               sync.Mutex
+               c net.Conn
+       }
+       var getOkay bool
+       closeConn := func() {
+               sconn.Lock()
+               defer sconn.Unlock()
+               if sconn.c != nil {
+                       sconn.c.Close()
+                       sconn.c = nil
+                       if !getOkay {
+                               t.Logf("Closed server connection")
+                       }
+               }
+       }
+       defer closeConn()
+
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.Method == "GET" {
+                       io.WriteString(w, "bar")
+                       return
+               }
+               conn, _, _ := w.(Hijacker).Hijack()
+               sconn.Lock()
+               sconn.c = conn
+               sconn.Unlock()
+               conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive
+               go io.Copy(ioutil.Discard, conn)
+       }))
+       defer ts.Close()
+       tr := &Transport{}
+       defer tr.CloseIdleConnections()
+       client := &Client{Transport: tr}
+
+       const bodySize = 256 << 10
+       finalBit := make(byteFromChanReader, 1)
+       req, _ := NewRequest("POST", ts.URL, io.MultiReader(io.LimitReader(neverEnding('x'), bodySize-1), finalBit))
+       req.ContentLength = bodySize
+       res, err := client.Do(req)
+       if err := wantBody(res, err, "foo"); err != nil {
+               t.Errorf("POST response: %v", err)
+       }
+       donec := make(chan bool)
+       go func() {
+               defer close(donec)
+               res, err = client.Get(ts.URL)
+               if err := wantBody(res, err, "bar"); err != nil {
+                       t.Errorf("GET response: %v", err)
+                       return
+               }
+               getOkay = true // suppress test noise
+       }()
+       time.AfterFunc(5*time.Second, closeConn)
+       select {
+       case <-donec:
+               finalBit <- 'x' // unblock the writeloop of the first Post
+               close(finalBit)
+       case <-time.After(7 * time.Second):
+               t.Fatal("timeout waiting for GET request to finish")
+       }
+}
+
+func wantBody(res *http.Response, err error, want string) error {
+       if err != nil {
+               return err
+       }
+       slurp, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               return fmt.Errorf("error reading body: %v", err)
+       }
+       if string(slurp) != want {
+               return fmt.Errorf("body = %q; want %q", slurp, want)
+       }
+       if err := res.Body.Close(); err != nil {
+               return fmt.Errorf("body Close = %v", err)
+       }
+       return nil
+}
+
 func newLocalListener(t *testing.T) net.Listener {
        ln, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {