[release-branch.go1.13] net/http: fix wantConnQueue memory leaks in Transport
author Bryan C. Mills <bcmills@google.com>
Mon, 26 Aug 2019 20:44:36 +0000 (16:44 -0400)
committer Bryan C. Mills <bcmills@google.com>
Tue, 27 Aug 2019 18:37:25 +0000 (18:37 +0000)
I'm trying to keep the code changes minimal for backporting to Go 1.13,
so it is still possible for a handful of entries to leak,
but the leaks are now O(1) instead of O(N) in the steady state.

Longer-term, I think it would be a good idea to coalesce idleMu with
connsPerHostMu and clear entries out of both queues as soon as their
goroutines are done waiting.
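As context for the change below (this sketch is not part of the commit; the waiter and queue names are simplified stand-ins for the unexported net/http types): wantConnQueue is a FIFO built from two slices, and before this change an entry whose waiter had already given up stayed queued until it happened to reach the front. Having pushBack's callers run cleanFront first discards dead entries from the front eagerly, so in the steady state only a bounded handful of entries can remain queued.

package main

import "fmt"

// waiter stands in for *wantConn; waiting reports whether it still
// wants a connection.
type waiter struct{ gaveUp bool }

func (w *waiter) waiting() bool { return !w.gaveUp }

// queue mirrors wantConnQueue: a FIFO built from two slices so that
// popping the front does not slide the remaining elements.
type queue struct {
	head    []*waiter
	headPos int
	tail    []*waiter
}

func (q *queue) pushBack(w *waiter) {
	// As in queueForIdleConn and queueForDial after this change:
	// drop dead entries from the front before enqueueing.
	q.cleanFront()
	q.tail = append(q.tail, w)
}

func (q *queue) peekFront() *waiter {
	if q.headPos < len(q.head) {
		return q.head[q.headPos]
	}
	if len(q.tail) > 0 {
		return q.tail[0]
	}
	return nil
}

func (q *queue) popFront() *waiter {
	if q.headPos >= len(q.head) {
		if len(q.tail) == 0 {
			return nil
		}
		q.head, q.headPos, q.tail = q.tail, 0, q.head[:0]
	}
	w := q.head[q.headPos]
	q.head[q.headPos] = nil // drop the reference so the entry can be collected
	q.headPos++
	return w
}

func (q *queue) cleanFront() (cleaned bool) {
	for {
		w := q.peekFront()
		if w == nil || w.waiting() {
			return cleaned
		}
		q.popFront()
		cleaned = true
	}
}

func main() {
	var q queue
	for i := 0; i < 1000; i++ {
		w := &waiter{}
		q.pushBack(w)
		w.gaveUp = true // every waiter abandons its slot immediately
	}
	// Without cleanFront all 1000 dead entries would still be queued;
	// with it, only the most recently pushed one lingers.
	fmt.Println(len(q.head) - q.headPos + len(q.tail)) // prints 1
}

Cleaning only from the front keeps each pushBack cheap while still bounding the steady-state leak, which is the O(1)-instead-of-O(N) trade-off described above; draining dead entries from anywhere in the queue would be the larger change deferred to the coalescing idea mentioned.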

Cherry-picked from CL 191964.

Updates #33849
Updates #33850
Fixes #33878

Change-Id: Ia66bc64671eb1014369f2d3a01debfc023b44281
Reviewed-on: https://go-review.googlesource.com/c/go/+/191964
Run-TryBot: Bryan C. Mills <bcmills@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
(cherry picked from commit 94bf9a8d4ad479e5a9dd57b3cb8e682e841d58d4)
Reviewed-on: https://go-review.googlesource.com/c/go/+/191967

src/net/http/transport.go
src/net/http/transport_test.go

diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index f9d9f4451cdb635d83afa7894ac74baeb50e4e9d..ee279877e02e2cede7ed3e2016ce518642f89966 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -953,6 +953,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
                t.idleConnWait = make(map[connectMethodKey]wantConnQueue)
        }
        q := t.idleConnWait[w.key]
+       q.cleanFront()
        q.pushBack(w)
        t.idleConnWait[w.key] = q
        return false
@@ -1137,7 +1138,7 @@ func (q *wantConnQueue) pushBack(w *wantConn) {
        q.tail = append(q.tail, w)
 }
 
-// popFront removes and returns the w at the front of the queue.
+// popFront removes and returns the wantConn at the front of the queue.
 func (q *wantConnQueue) popFront() *wantConn {
        if q.headPos >= len(q.head) {
                if len(q.tail) == 0 {
@@ -1152,6 +1153,30 @@ func (q *wantConnQueue) popFront() *wantConn {
        return w
 }
 
+// peekFront returns the wantConn at the front of the queue without removing it.
+func (q *wantConnQueue) peekFront() *wantConn {
+       if q.headPos < len(q.head) {
+               return q.head[q.headPos]
+       }
+       if len(q.tail) > 0 {
+               return q.tail[0]
+       }
+       return nil
+}
+
+// cleanFront pops any wantConns that are no longer waiting from the head of the
+// queue, reporting whether any were popped.
+func (q *wantConnQueue) cleanFront() (cleaned bool) {
+       for {
+               w := q.peekFront()
+               if w == nil || w.waiting() {
+                       return cleaned
+               }
+               q.popFront()
+               cleaned = true
+       }
+}
+
 // getConn dials and creates a new persistConn to the target as
 // specified in the connectMethod. This includes doing a proxy CONNECT
 // and/or setting up TLS.  If this doesn't return an error, the persistConn
@@ -1261,6 +1286,7 @@ func (t *Transport) queueForDial(w *wantConn) {
                t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue)
        }
        q := t.connsPerHostWait[w.key]
+       q.cleanFront()
        q.pushBack(w)
        t.connsPerHostWait[w.key] = q
 }
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 1a6f631ea209332ee98142c6bddb8b0108470111..23afff5d843160008b89670a4e5d5a632cfbe40e 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -1658,6 +1658,176 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) {
        }
 }
 
+// A countedConn is a net.Conn that decrements an atomic counter when finalized.
+type countedConn struct {
+       net.Conn
+}
+
+// A countingDialer dials connections and counts the number that remain reachable.
+type countingDialer struct {
+       dialer      net.Dialer
+       mu          sync.Mutex
+       total, live int64
+}
+
+func (d *countingDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
+       conn, err := d.dialer.DialContext(ctx, network, address)
+       if err != nil {
+               return nil, err
+       }
+
+       counted := new(countedConn)
+       counted.Conn = conn
+
+       d.mu.Lock()
+       defer d.mu.Unlock()
+       d.total++
+       d.live++
+
+       runtime.SetFinalizer(counted, d.decrement)
+       return counted, nil
+}
+
+func (d *countingDialer) decrement(*countedConn) {
+       d.mu.Lock()
+       defer d.mu.Unlock()
+       d.live--
+}
+
+func (d *countingDialer) Read() (total, live int64) {
+       d.mu.Lock()
+       defer d.mu.Unlock()
+       return d.total, d.live
+}
+
+func TestTransportPersistConnLeakNeverIdle(t *testing.T) {
+       defer afterTest(t)
+
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               // Close every connection so that it cannot be kept alive.
+               conn, _, err := w.(Hijacker).Hijack()
+               if err != nil {
+                       t.Errorf("Hijack failed unexpectedly: %v", err)
+                       return
+               }
+               conn.Close()
+       }))
+       defer ts.Close()
+
+       var d countingDialer
+       c := ts.Client()
+       c.Transport.(*Transport).DialContext = d.DialContext
+
+       body := []byte("Hello")
+       for i := 0; ; i++ {
+               total, live := d.Read()
+               if live < total {
+                       break
+               }
+               if i >= 1<<12 {
+                       t.Fatalf("Count of live client net.Conns (%d) not lower than total (%d) after %d Do / GC iterations.", live, total, i)
+               }
+
+               req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+               if err != nil {
+                       t.Fatal(err)
+               }
+               _, err = c.Do(req)
+               if err == nil {
+                       t.Fatal("expected broken connection")
+               }
+
+               runtime.GC()
+       }
+}
+
+type countedContext struct {
+       context.Context
+}
+
+type contextCounter struct {
+       mu   sync.Mutex
+       live int64
+}
+
+func (cc *contextCounter) Track(ctx context.Context) context.Context {
+       counted := new(countedContext)
+       counted.Context = ctx
+       cc.mu.Lock()
+       defer cc.mu.Unlock()
+       cc.live++
+       runtime.SetFinalizer(counted, cc.decrement)
+       return counted
+}
+
+func (cc *contextCounter) decrement(*countedContext) {
+       cc.mu.Lock()
+       defer cc.mu.Unlock()
+       cc.live--
+}
+
+func (cc *contextCounter) Read() (live int64) {
+       cc.mu.Lock()
+       defer cc.mu.Unlock()
+       return cc.live
+}
+
+func TestTransportPersistConnContextLeakMaxConnsPerHost(t *testing.T) {
+       defer afterTest(t)
+
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               runtime.Gosched()
+               w.WriteHeader(StatusOK)
+       }))
+       defer ts.Close()
+
+       c := ts.Client()
+       c.Transport.(*Transport).MaxConnsPerHost = 1
+
+       ctx := context.Background()
+       body := []byte("Hello")
+       doPosts := func(cc *contextCounter) {
+               var wg sync.WaitGroup
+               for n := 64; n > 0; n-- {
+                       wg.Add(1)
+                       go func() {
+                               defer wg.Done()
+
+                               ctx := cc.Track(ctx)
+                               req, err := NewRequest("POST", ts.URL, bytes.NewReader(body))
+                               if err != nil {
+                                       t.Error(err)
+                               }
+
+                               _, err = c.Do(req.WithContext(ctx))
+                               if err != nil {
+                                       t.Errorf("Do failed with error: %v", err)
+                               }
+                       }()
+               }
+               wg.Wait()
+       }
+
+       var initialCC contextCounter
+       doPosts(&initialCC)
+
+       // flushCC exists only to put pressure on the GC to finalize the initialCC
+       // contexts: the flushCC allocations should eventually displace the initialCC
+       // allocations.
+       var flushCC contextCounter
+       for i := 0; ; i++ {
+               live := initialCC.Read()
+               if live == 0 {
+                       break
+               }
+               if i >= 100 {
+                       t.Fatalf("%d Contexts still not finalized after %d GC cycles.", live, i)
+               }
+               doPosts(&flushCC)
+               runtime.GC()
+       }
+}
+
 // This used to crash; https://golang.org/issue/3266
 func TestTransportIdleConnCrash(t *testing.T) {
        defer afterTest(t)