]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix Transport race(s) with high GOMAXPROCS
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 11 Jul 2012 23:40:44 +0000 (16:40 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 11 Jul 2012 23:40:44 +0000 (16:40 -0700)
Also adds a new test for GOMAXPROCS=16 explicitly, which now passes
reliably in a stress loop like:

$ go test -c
$ (while ./http.test -test.v -test.run=Concurrency; do echo pass; done ) 2>&1 | tee foo; less foo

(It used to fail very quickly and reliably on at least Linux/amd64)

Fixes #3793

R=golang-dev, adg, r
CC=golang-dev
https://golang.org/cl/6347061

src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index e0e2856477ee7eeafb966172bdfa7fea94a7d006..746de4061d78fb6e94f7faf750c0481ff8089a8d 100644 (file)
@@ -24,6 +24,7 @@ import (
        "os"
        "strings"
        "sync"
+       "time"
 )
 
 // DefaultTransport is the default implementation of Transport and is
@@ -260,6 +261,11 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
                pconn.close()
                return false
        }
+       for _, exist := range t.idleConn[key] {
+               if exist == pconn {
+                       log.Fatalf("dup idle pconn %p in freelist", pconn)
+               }
+       }
        t.idleConn[key] = append(t.idleConn[key], pconn)
        return true
 }
@@ -289,7 +295,7 @@ func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
                        return
                }
        }
-       return
+       panic("unreachable")
 }
 
 func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
@@ -324,6 +330,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, error) {
                conn:     conn,
                reqch:    make(chan requestAndChan, 50),
                writech:  make(chan writeRequest, 50),
+               closech:  make(chan struct{}),
        }
 
        switch {
@@ -491,6 +498,7 @@ type persistConn struct {
        bw       *bufio.Writer       // to conn
        reqch    chan requestAndChan // written by roundTrip; read by readLoop
        writech  chan writeRequest   // written by roundTrip; read by writeLoop
+       closech  chan struct{}       // broadcast close when readLoop (TCP connection) closes
        isProxy  bool
 
        // mutateHeaderFunc is an optional func to modify extra
@@ -522,6 +530,7 @@ func remoteSideClosed(err error) bool {
 }
 
 func (pc *persistConn) readLoop() {
+       defer close(pc.closech)
        defer close(pc.writech)
        alive := true
        var lastbody io.ReadCloser // last response body, if any, read on this connection
@@ -549,7 +558,11 @@ func (pc *persistConn) readLoop() {
                        lastbody.Close() // assumed idempotent
                        lastbody = nil
                }
-               resp, err := ReadResponse(pc.br, rc.req)
+
+               var resp *Response
+               if err == nil {
+                       resp, err = ReadResponse(pc.br, rc.req)
+               }
 
                if err != nil {
                        pc.close()
@@ -578,7 +591,7 @@ func (pc *persistConn) readLoop() {
                var waitForBodyRead chan bool
                if hasBody {
                        lastbody = resp.Body
-                       waitForBodyRead = make(chan bool)
+                       waitForBodyRead = make(chan bool, 1)
                        resp.Body.(*bodyEOFSignal).fn = func() {
                                if alive && !pc.t.putIdleConn(pc) {
                                        alive = false
@@ -692,6 +705,8 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
 
        var re responseAndError
+       var pconnDeadCh = pc.closech
+       var failTicker <-chan time.Time
 WaitResponse:
        for {
                select {
@@ -700,6 +715,24 @@ WaitResponse:
                                re = responseAndError{nil, err}
                                break WaitResponse
                        }
+               case <-pconnDeadCh:
+                       // The persist connection is dead. This shouldn't
+                       // usually happen (only with Connection: close responses
+                       // with no response bodies), but if it does happen it
+                       // means either a) the remote server hung up on us
+                       // prematurely, or b) the readLoop sent us a response &
+                       // closed its closech at roughly the same time, and we
+                       // selected this case first, in which case a response
+                       // might still be coming soon.
+                       //
+                       // We can't avoid the select race in b) by using a unbuffered
+                       // resc channel instead, because then goroutines can
+                       // leak if we exit due to other errors.
+                       pconnDeadCh = nil                               // avoid spinning
+                       failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
+               case <-failTicker:
+                       re = responseAndError{nil, errors.New("net/http: transport closed before response was received")}
+                       break WaitResponse
                case re = <-resc:
                        break WaitResponse
                }
@@ -762,6 +795,7 @@ type bodyEOFSignal struct {
        body     io.ReadCloser
        fn       func()
        isClosed bool
+       once     sync.Once
 }
 
 func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
@@ -769,9 +803,8 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err error) {
        if es.isClosed && n > 0 {
                panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725")
        }
-       if err == io.EOF && es.fn != nil {
-               es.fn()
-               es.fn = nil
+       if err == io.EOF {
+               es.condfn()
        }
        return
 }
@@ -782,13 +815,18 @@ func (es *bodyEOFSignal) Close() (err error) {
        }
        es.isClosed = true
        err = es.body.Close()
-       if err == nil && es.fn != nil {
-               es.fn()
-               es.fn = nil
+       if err == nil {
+               es.condfn()
        }
        return
 }
 
+func (es *bodyEOFSignal) condfn() {
+       if es.fn != nil {
+               es.once.Do(es.fn)
+       }
+}
+
 type readFirstCloseBoth struct {
        io.ReadCloser
        io.Closer
index c377eff5d1bd90c150806951c90a34bf45d2fa81..9cf292cee49880a8401ff8560f98de2f95ef2f37 100644 (file)
@@ -857,6 +857,50 @@ func TestIssue3595(t *testing.T) {
        }
 }
 
+func TestTransportConcurrency(t *testing.T) {
+       const maxProcs = 16
+       const numReqs = 500
+       defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               fmt.Fprintf(w, "%v", r.FormValue("echo"))
+       }))
+       defer ts.Close()
+       tr := &Transport{}
+       c := &Client{Transport: tr}
+       reqs := make(chan string)
+       defer close(reqs)
+
+       var wg sync.WaitGroup
+       wg.Add(numReqs)
+       for i := 0; i < maxProcs*2; i++ {
+               go func() {
+                       for req := range reqs {
+                               res, err := c.Get(ts.URL + "/?echo=" + req)
+                               if err != nil {
+                                       t.Errorf("error on req %s: %v", req, err)
+                                       wg.Done()
+                                       continue
+                               }
+                               all, err := ioutil.ReadAll(res.Body)
+                               if err != nil {
+                                       t.Errorf("read error on req %s: %v", req, err)
+                                       wg.Done()
+                                       continue
+                               }
+                               if string(all) != req {
+                                       t.Errorf("body of req %s = %q; want %q", req, all, req)
+                               }
+                               wg.Done()
+                               res.Body.Close()
+                       }
+               }()
+       }
+       for i := 0; i < numReqs; i++ {
+               reqs <- fmt.Sprintf("request-%d", i)
+       }
+       wg.Wait()
+}
+
 type fooProto struct{}
 
 func (fooProto) RoundTrip(req *Request) (*Response, error) {