]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix Transport races & deadlocks
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 26 Nov 2012 21:31:02 +0000 (13:31 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 26 Nov 2012 21:31:02 +0000 (13:31 -0800)
Thanks to Dustin Sallings for exposing the most frustrating
bug ever, and for providing repro cases (which formed the
basis of the new tests in this CL), and to Dave Cheney and
Dmitry Vyukov for help debugging and fixing.

This CL depends on submited pollster CLs ffd1e075c260 (Unix)
and 14b544194509 (Windows), as well as unsubmitted 6852085.
Some operating systems (OpenBSD, NetBSD, ?) may still require
more pollster work, fixing races (Issue 4434 and
http://goo.gl/JXB6W).

Tested on linux-amd64 and darwin-amd64, both with GOMAXPROCS 1
and 4 (all combinations of which previously failed differently)

Fixes #4191
Update #4434 (related fallout from this bug)

R=dave, bradfitz, dsallings, rsc, fullung
CC=golang-dev
https://golang.org/cl/6851061

src/pkg/net/http/export_test.go
src/pkg/net/http/server.go
src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index 313c6af7a8258c84fced2e274a7ecbc8a84b4607..a7a07852d18b7dac221a694e1be0991292b9f484 100644 (file)
@@ -7,7 +7,14 @@
 
 package http
 
-import "time"
+import (
+       "net"
+       "time"
+)
+
+func NewLoggingConn(baseName string, c net.Conn) net.Conn {
+       return newLoggingConn(baseName, c)
+}
 
 func (t *Transport) IdleConnKeysForTesting() (keys []string) {
        keys = make([]string, 0)
index 3a4d61c21394be0350e718325bba4e39f6749d7f..b50c03ed3a1e73c3bad687a3073241c8f00275f8 100644 (file)
@@ -170,16 +170,23 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
 // noLimit is an effective infinite upper bound for io.LimitedReader
 const noLimit int64 = (1 << 63) - 1
 
+// debugServerConnections controls whether all server connections are wrapped
+// with a verbose logging wrapper.
+const debugServerConnections = false
+
 // Create new connection from rwc.
 func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
        c = new(conn)
        c.remoteAddr = rwc.RemoteAddr().String()
        c.server = srv
        c.rwc = rwc
+       if debugServerConnections {
+               c.rwc = newLoggingConn("server", c.rwc)
+       }
        c.body = make([]byte, sniffLen)
-       c.lr = io.LimitReader(rwc, noLimit).(*io.LimitedReader)
+       c.lr = io.LimitReader(c.rwc, noLimit).(*io.LimitedReader)
        br := bufio.NewReader(c.lr)
-       bw := bufio.NewWriter(rwc)
+       bw := bufio.NewWriter(c.rwc)
        c.buf = bufio.NewReadWriter(br, bw)
        return c, nil
 }
@@ -495,7 +502,7 @@ func (w *response) Write(data []byte) (n int, err error) {
        // then there would be fewer chunk headers.
        // On the other hand, it would make hijacking more difficult.
        if w.chunking {
-               fmt.Fprintf(w.conn.buf, "%x\r\n", len(data)) // TODO(rsc): use strconv not fmt
+               fmt.Fprintf(w.conn.buf, "%x\r\n", len(data))
        }
        n, err = w.conn.buf.Write(data)
        if err == nil && w.chunking {
@@ -1309,3 +1316,45 @@ func (tw *timeoutWriter) WriteHeader(code int) {
        tw.mu.Unlock()
        tw.w.WriteHeader(code)
 }
+
+// loggingConn is used for debugging.
+type loggingConn struct {
+       name string
+       net.Conn
+}
+
+var (
+       uniqNameMu   sync.Mutex
+       uniqNameNext = make(map[string]int)
+)
+
+func newLoggingConn(baseName string, c net.Conn) net.Conn {
+       uniqNameMu.Lock()
+       defer uniqNameMu.Unlock()
+       uniqNameNext[baseName]++
+       return &loggingConn{
+               name: fmt.Sprintf("%s-%d", baseName, uniqNameNext[baseName]),
+               Conn: c,
+       }
+}
+
+func (c *loggingConn) Write(p []byte) (n int, err error) {
+       log.Printf("%s.Write(%d) = ....", c.name, len(p))
+       n, err = c.Conn.Write(p)
+       log.Printf("%s.Write(%d) = %d, %v", c.name, len(p), n, err)
+       return
+}
+
+func (c *loggingConn) Read(p []byte) (n int, err error) {
+       log.Printf("%s.Read(%d) = ....", c.name, len(p))
+       n, err = c.Conn.Read(p)
+       log.Printf("%s.Read(%d) = %d, %v", c.name, len(p), n, err)
+       return
+}
+
+func (c *loggingConn) Close() (err error) {
+       log.Printf("%s.Close() = ...", c.name)
+       err = c.Conn.Close()
+       log.Printf("%s.Close() = %v", c.name, err)
+       return
+}
index 3e90d7a1a4d0136c8ea8d371505bb1a7ed9d54fa..0aec1ae51b53cf02c1f0a4e25db324963218dc5f 100644 (file)
@@ -24,7 +24,6 @@ import (
        "os"
        "strings"
        "sync"
-       "sync/atomic"
        "time"
 )
 
@@ -613,14 +612,18 @@ func (pc *persistConn) readLoop() {
                if hasBody {
                        lastbody = resp.Body
                        waitForBodyRead = make(chan bool, 1)
-                       resp.Body.(*bodyEOFSignal).fn = func() {
-                               if alive && !pc.t.putIdleConn(pc) {
-                                       alive = false
+                       resp.Body.(*bodyEOFSignal).fn = func(err error) {
+                               alive1 := alive
+                               if err != nil {
+                                       alive1 = false
                                }
-                               if !alive || pc.isBroken() {
+                               if alive1 && !pc.t.putIdleConn(pc) {
+                                       alive1 = false
+                               }
+                               if !alive1 || pc.isBroken() {
                                        pc.close()
                                }
-                               waitForBodyRead <- true
+                               waitForBodyRead <- alive1
                        }
                }
 
@@ -644,7 +647,7 @@ func (pc *persistConn) readLoop() {
                // Wait for the just-returned response body to be fully consumed
                // before we race and peek on the underlying bufio reader.
                if waitForBodyRead != nil {
-                       <-waitForBodyRead
+                       alive = <-waitForBodyRead
                }
 
                if !alive {
@@ -810,50 +813,61 @@ func canonicalAddr(url *url.URL) string {
 }
 
 // bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most
-// once, right before the final Read() or Close() call returns, but after
-// EOF has been seen.
+// once, right before its final (error-producing) Read or Close call
+// returns.
 type bodyEOFSignal struct {
-       body     io.ReadCloser
-       fn       func()
-       isClosed uint32 // atomic bool, non-zero if true
-       once     sync.Once
+       body   io.ReadCloser
+       mu     sync.Mutex  // guards closed, rerr and fn
+       closed bool        // whether Close has been called
+       rerr   error       // sticky Read error
+       fn     func(error) // error will be nil on Read io.EOF
 }
 
 func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
-       n, err = es.body.Read(p)
-       if es.closed() && n > 0 {
-               panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
+       es.mu.Lock()
+       closed, rerr := es.closed, es.rerr
+       es.mu.Unlock()
+       if closed {
+               return 0, errors.New("http: read on closed response body")
        }
-       if err == io.EOF {
-               es.condfn()
+       if rerr != nil {
+               return 0, rerr
        }
-       return
-}
 
-func (es *bodyEOFSignal) Close() (err error) {
-       if !es.setClosed() {
-               // already closed
-               return nil
-       }
-       err = es.body.Close()
-       if err == nil {
-               es.condfn()
+       n, err = es.body.Read(p)
+       if err != nil {
+               es.mu.Lock()
+               defer es.mu.Unlock()
+               if es.rerr == nil {
+                       es.rerr = err
+               }
+               es.condfn(err)
        }
        return
 }
 
-func (es *bodyEOFSignal) condfn() {
-       if es.fn != nil {
-               es.once.Do(es.fn)
+func (es *bodyEOFSignal) Close() error {
+       es.mu.Lock()
+       defer es.mu.Unlock()
+       if es.closed {
+               return nil
        }
+       es.closed = true
+       err := es.body.Close()
+       es.condfn(err)
+       return err
 }
 
-func (es *bodyEOFSignal) closed() bool {
-       return atomic.LoadUint32(&es.isClosed) != 0
-}
-
-func (es *bodyEOFSignal) setClosed() bool {
-       return atomic.CompareAndSwapUint32(&es.isClosed, 0, 1)
+// caller must hold es.mu.
+func (es *bodyEOFSignal) condfn(err error) {
+       if es.fn == nil {
+               return
+       }
+       if err == io.EOF {
+               err = nil
+       }
+       es.fn(err)
+       es.fn = nil
 }
 
 type readFirstCloseBoth struct {
index e114e71480d128a38c7177e917c1995959a97976..a594fa81d9e144dfbcd2b7db6eb90b6a1882cd50 100644 (file)
@@ -901,6 +901,111 @@ func TestTransportConcurrency(t *testing.T) {
        wg.Wait()
 }
 
+func TestIssue4191_InfiniteGetTimeout(t *testing.T) {
+       const debug = false
+       mux := NewServeMux()
+       mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+               io.Copy(w, neverEnding('a'))
+       })
+       ts := httptest.NewServer(mux)
+
+       client := &Client{
+               Transport: &Transport{
+                       Dial: func(n, addr string) (net.Conn, error) {
+                               conn, err := net.Dial(n, addr)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
+                               if debug {
+                                       conn = NewLoggingConn("client", conn)
+                               }
+                               return conn, nil
+                       },
+                       DisableKeepAlives: true,
+               },
+       }
+
+       nRuns := 5
+       if testing.Short() {
+               nRuns = 1
+       }
+       for i := 0; i < nRuns; i++ {
+               if debug {
+                       println("run", i+1, "of", nRuns)
+               }
+               sres, err := client.Get(ts.URL + "/get")
+               if err != nil {
+                       t.Errorf("Error issuing GET: %v", err)
+                       break
+               }
+               _, err = io.Copy(ioutil.Discard, sres.Body)
+               if err == nil {
+                       t.Errorf("Unexpected successful copy")
+                       break
+               }
+       }
+       if debug {
+               println("tests complete; waiting for handlers to finish")
+       }
+       ts.Close()
+}
+
+func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
+       const debug = false
+       mux := NewServeMux()
+       mux.HandleFunc("/get", func(w ResponseWriter, r *Request) {
+               io.Copy(w, neverEnding('a'))
+       })
+       mux.HandleFunc("/put", func(w ResponseWriter, r *Request) {
+               defer r.Body.Close()
+               io.Copy(ioutil.Discard, r.Body)
+       })
+       ts := httptest.NewServer(mux)
+
+       client := &Client{
+               Transport: &Transport{
+                       Dial: func(n, addr string) (net.Conn, error) {
+                               conn, err := net.Dial(n, addr)
+                               if err != nil {
+                                       return nil, err
+                               }
+                               conn.SetDeadline(time.Now().Add(100 * time.Millisecond))
+                               if debug {
+                                       conn = NewLoggingConn("client", conn)
+                               }
+                               return conn, nil
+                       },
+                       DisableKeepAlives: true,
+               },
+       }
+
+       nRuns := 5
+       if testing.Short() {
+               nRuns = 1
+       }
+       for i := 0; i < nRuns; i++ {
+               if debug {
+                       println("run", i+1, "of", nRuns)
+               }
+               sres, err := client.Get(ts.URL + "/get")
+               if err != nil {
+                       t.Errorf("Error issuing GET: %v", err)
+                       break
+               }
+               req, _ := NewRequest("PUT", ts.URL+"/put", sres.Body)
+               _, err = client.Do(req)
+               if err == nil {
+                       t.Errorf("Unexpected successful PUT")
+                       break
+               }
+       }
+       if debug {
+               println("tests complete; waiting for handlers to finish")
+       }
+       ts.Close()
+}
+
 type fooProto struct{}
 
 func (fooProto) RoundTrip(req *Request) (*Response, error) {