]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix client goroutine leak with persistent connections
authorBrad Fitzpatrick <bradfitz@golang.org>
Tue, 14 Feb 2012 01:48:56 +0000 (12:48 +1100)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 14 Feb 2012 01:48:56 +0000 (12:48 +1100)
Thanks to Sascha Matzke & Florian Weimer for diagnosing.

R=golang-dev, adg, bradfitz, kevlar
CC=golang-dev
https://golang.org/cl/5656046

src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index 510e55b0586f3347fd498cca0fb15895fee1171d..3e48abafb5e72e9be6cdbb1d5eae6204465a224c 100644 (file)
@@ -235,15 +235,19 @@ func (cm *connectMethod) proxyAuth() string {
        return ""
 }
 
-func (t *Transport) putIdleConn(pconn *persistConn) {
+// putIdleConn adds pconn to the list of idle persistent connections awaiting
+// a new request.
+// If pconn is no longer needed or not in a good state, putIdleConn
+// returns false.
+func (t *Transport) putIdleConn(pconn *persistConn) bool {
        t.lk.Lock()
        defer t.lk.Unlock()
        if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
                pconn.close()
-               return
+               return false
        }
        if pconn.isBroken() {
-               return
+               return false
        }
        key := pconn.cacheKey
        max := t.MaxIdleConnsPerHost
@@ -252,9 +256,10 @@ func (t *Transport) putIdleConn(pconn *persistConn) {
        }
        if len(t.idleConn[key]) >= max {
                pconn.close()
-               return
+               return false
        }
        t.idleConn[key] = append(t.idleConn[key], pconn)
+       return true
 }
 
 func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
@@ -565,7 +570,9 @@ func (pc *persistConn) readLoop() {
                                lastbody = resp.Body
                                waitForBodyRead = make(chan bool)
                                resp.Body.(*bodyEOFSignal).fn = func() {
-                                       pc.t.putIdleConn(pc)
+                                       if !pc.t.putIdleConn(pc) {
+                                               alive = false
+                                       }
                                        waitForBodyRead <- true
                                }
                        } else {
@@ -578,7 +585,9 @@ func (pc *persistConn) readLoop() {
                                // read it (even though it'll just be 0, EOF).
                                lastbody = nil
 
-                               pc.t.putIdleConn(pc)
+                               if !pc.t.putIdleConn(pc) {
+                                       alive = false
+                               }
                        }
                }
 
index ab67fa0ebcf7261b5de0b85e68220cd2690772dc..82e3882eb3c6edabc0770b25d0a7d54301ebf385 100644 (file)
@@ -16,6 +16,7 @@ import (
        . "net/http"
        "net/http/httptest"
        "net/url"
+       "runtime"
        "strconv"
        "strings"
        "testing"
@@ -632,6 +633,66 @@ func TestTransportGzipRecursive(t *testing.T) {
        }
 }
 
+// tests that persistent goroutine connections shut down when no longer desired.
+func TestTransportPersistConnLeak(t *testing.T) {
+       gotReqCh := make(chan bool)
+       unblockCh := make(chan bool)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               gotReqCh <- true
+               <-unblockCh
+               w.Header().Set("Content-Length", "0")
+               w.WriteHeader(204)
+       }))
+       defer ts.Close()
+
+       tr := &Transport{}
+       c := &Client{Transport: tr}
+
+       n0 := runtime.Goroutines()
+
+       const numReq = 100
+       didReqCh := make(chan bool)
+       for i := 0; i < numReq; i++ {
+               go func() {
+                       c.Get(ts.URL)
+                       didReqCh <- true
+               }()
+       }
+
+       // Wait for all goroutines to be stuck in the Handler.
+       for i := 0; i < numReq; i++ {
+               <-gotReqCh
+       }
+
+       nhigh := runtime.Goroutines()
+
+       // Tell all handlers to unblock and reply.
+       for i := 0; i < numReq; i++ {
+               unblockCh <- true
+       }
+
+       // Wait for all HTTP clients to be done.
+       for i := 0; i < numReq; i++ {
+               <-didReqCh
+       }
+
+       time.Sleep(100 * time.Millisecond)
+       runtime.GC()
+       runtime.GC() // even more.
+       nfinal := runtime.Goroutines()
+
+       growth := nfinal - n0
+
+       // We expect 5 extra goroutines, empirically. That number is at least
+       // DefaultMaxIdleConnsPerHost * 2 (one reader goroutine, one writer),
+       // and something else.
+       expectedGoroutineGrowth := DefaultMaxIdleConnsPerHost*2 + 1
+
+       if int(growth) > expectedGoroutineGrowth*2 {
+               t.Errorf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth)
+       }
+}
+
 type fooProto struct{}
 
 func (fooProto) RoundTrip(req *Request) (*Response, error) {