From: Damien Neil Date: Thu, 4 Apr 2024 18:01:28 +0000 (-0700) Subject: net/http: don't cancel Dials when requests are canceled X-Git-Tag: go1.23rc1~596 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=334ce510046ad30b1be466634cf313aad3040892;p=gostls13.git net/http: don't cancel Dials when requests are canceled Currently, when a Transport creates a new connection for a request, it uses the request's Context to make the Dial. If a request times out or is canceled before a Dial completes, the Dial is canceled. Change this so that the lifetime of a Dial call is not bound by the request that originated it. This change avoids a scenario where a Transport can start and then cancel many Dial calls in rapid succession: - Request starts a Dial. - A previous request completes, making its connection available. - The new request uses the now-idle connection, and completes. - The request Context is canceled, and the Dial is aborted. Fixes #59017 Change-Id: I996ffabc56d3b1b43129cbfd9b3e9ea7d53d263c Reviewed-on: https://go-review.googlesource.com/c/go/+/576555 Reviewed-by: Brad Fitzpatrick LUCI-TryBot-Result: Go LUCI Reviewed-by: Cherry Mui --- diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 569b58ca62..33e69467c6 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -1938,21 +1938,25 @@ func TestClientCloseIdleConnections(t *testing.T) { } } +type testRoundTripper func(*Request) (*Response, error) + +func (t testRoundTripper) RoundTrip(req *Request) (*Response, error) { + return t(req) +} + func TestClientPropagatesTimeoutToContext(t *testing.T) { - errDial := errors.New("not actually dialing") c := &Client{ Timeout: 5 * time.Second, - Transport: &Transport{ - DialContext: func(ctx context.Context, netw, addr string) (net.Conn, error) { - deadline, ok := ctx.Deadline() - if !ok { - t.Error("no deadline") - } else { - t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10)) - } - return nil, errDial - }, - }, + Transport: testRoundTripper(func(req *Request) (*Response, error) { + ctx := req.Context() + deadline, ok := ctx.Deadline() + if !ok { + t.Error("no deadline") + } else { + t.Logf("deadline in %v", deadline.Sub(time.Now()).Round(time.Second/10)) + } + return nil, errors.New("not actually making a request") + }), } c.Get("https://example.tld/") } diff --git a/src/net/http/export_test.go b/src/net/http/export_test.go index 8a6f4f192f..56ebda180b 100644 --- a/src/net/http/export_test.go +++ b/src/net/http/export_test.go @@ -86,6 +86,14 @@ func SetPendingDialHooks(before, after func()) { func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn } +func SetTestHookProxyConnectTimeout(t *testing.T, f func(context.Context, time.Duration) (context.Context, context.CancelFunc)) { + orig := testHookProxyConnectTimeout + t.Cleanup(func() { + testHookProxyConnectTimeout = orig + }) + testHookProxyConnectTimeout = f +} + func NewTestTimeoutHandler(handler Handler, ctx context.Context) Handler { return &timeoutHandler{ handler: handler, diff --git a/src/net/http/transport.go b/src/net/http/transport.go index d97298ecd9..e6a97a00c6 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -108,6 +108,7 @@ type Transport struct { connsPerHostMu sync.Mutex connsPerHost map[connectMethodKey]int connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns + dialsInProgress wantConnQueue // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -807,6 +808,13 @@ func (t *Transport) CloseIdleConnections() { pconn.close(errCloseIdleConns) } } + t.connsPerHostMu.Lock() + t.dialsInProgress.all(func(w *wantConn) { + if w.cancelCtx != nil && !w.waiting() { + w.cancelCtx() + } + }) + t.connsPerHostMu.Unlock() if t2 := t.h2transport; t2 != nil { t2.CloseIdleConnections() } @@ -1116,7 +1124,7 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) { t.idleConnWait = make(map[connectMethodKey]wantConnQueue) } q := t.idleConnWait[w.key] - q.cleanFront() + q.cleanFrontNotWaiting() q.pushBack(w) t.idleConnWait[w.key] = q return false @@ -1230,10 +1238,11 @@ type wantConn struct { beforeDial func() afterDial func() - mu sync.Mutex // protects ctx, done and sending of the result - ctx context.Context // context for dial, cleared after delivered or canceled - done bool // true after delivered or canceled - result chan connOrError // channel to deliver connection or error + mu sync.Mutex // protects ctx, done and sending of the result + ctx context.Context // context for dial, cleared after delivered or canceled + cancelCtx context.CancelFunc + done bool // true after delivered or canceled + result chan connOrError // channel to deliver connection or error } type connOrError struct { @@ -1352,9 +1361,9 @@ func (q *wantConnQueue) peekFront() *wantConn { return nil } -// cleanFront pops any wantConns that are no longer waiting from the head of the +// cleanFrontNotWaiting pops any wantConns that are no longer waiting from the head of the // queue, reporting whether any were popped. -func (q *wantConnQueue) cleanFront() (cleaned bool) { +func (q *wantConnQueue) cleanFrontNotWaiting() (cleaned bool) { for { w := q.peekFront() if w == nil || w.waiting() { @@ -1365,6 +1374,28 @@ func (q *wantConnQueue) cleanFront() (cleaned bool) { } } +// cleanFrontCanceled pops any wantConns with canceled dials from the head of the queue. +func (q *wantConnQueue) cleanFrontCanceled() { + for { + w := q.peekFront() + if w == nil || w.cancelCtx != nil { + return + } + q.popFront() + } +} + +// all iterates over all wantConns in the queue. +// The caller must not modify the queue while iterating. +func (q *wantConnQueue) all(f func(*wantConn)) { + for _, w := range q.head[q.headPos:] { + f(w) + } + for _, w := range q.tail { + f(w) + } +} + func (t *Transport) customDialTLS(ctx context.Context, network, addr string) (conn net.Conn, err error) { if t.DialTLSContext != nil { conn, err = t.DialTLSContext(ctx, network, addr) @@ -1389,10 +1420,18 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis trace.GetConn(cm.addr()) } + // Detach from the request context's cancellation signal. + // The dial should proceed even if the request is canceled, + // because a future request may be able to make use of the connection. + // + // We retain the request context's values. + dialCtx, dialCancel := context.WithCancel(context.WithoutCancel(ctx)) + w := &wantConn{ cm: cm, key: cm.key(), - ctx: ctx, + ctx: dialCtx, + cancelCtx: dialCancel, result: make(chan connOrError, 1), beforeDial: testHookPrePendingDial, afterDial: testHookPostPendingDial, @@ -1470,20 +1509,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (_ *persis // Once w receives permission to dial, it will do so in a separate goroutine. func (t *Transport) queueForDial(w *wantConn) { w.beforeDial() - if t.MaxConnsPerHost <= 0 { - go t.dialConnFor(w) - return - } t.connsPerHostMu.Lock() defer t.connsPerHostMu.Unlock() + if t.MaxConnsPerHost <= 0 { + t.startDialConnForLocked(w) + return + } + if n := t.connsPerHost[w.key]; n < t.MaxConnsPerHost { if t.connsPerHost == nil { t.connsPerHost = make(map[connectMethodKey]int) } t.connsPerHost[w.key] = n + 1 - go t.dialConnFor(w) + t.startDialConnForLocked(w) return } @@ -1491,11 +1531,24 @@ func (t *Transport) queueForDial(w *wantConn) { t.connsPerHostWait = make(map[connectMethodKey]wantConnQueue) } q := t.connsPerHostWait[w.key] - q.cleanFront() + q.cleanFrontNotWaiting() q.pushBack(w) t.connsPerHostWait[w.key] = q } +// startDialConnFor calls dialConn in a new goroutine. +// t.connsPerHostMu must be held. +func (t *Transport) startDialConnForLocked(w *wantConn) { + t.dialsInProgress.cleanFrontCanceled() + t.dialsInProgress.pushBack(w) + go func() { + t.dialConnFor(w) + t.connsPerHostMu.Lock() + defer t.connsPerHostMu.Unlock() + w.cancelCtx = nil + }() +} + // dialConnFor dials on behalf of w and delivers the result to w. // dialConnFor has received permission to dial w.cm and is counted in t.connCount[w.cm.key()]. // If the dial is canceled or unsuccessful, dialConnFor decrements t.connCount[w.cm.key()]. @@ -1545,7 +1598,7 @@ func (t *Transport) decConnsPerHost(key connectMethodKey) { for q.len() > 0 { w := q.popFront() if w.waiting() { - go t.dialConnFor(w) + t.startDialConnForLocked(w) done = true break } @@ -1626,6 +1679,8 @@ type erringRoundTripper interface { RoundTripErr() error } +var testHookProxyConnectTimeout = context.WithTimeout + func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) { pconn = &persistConn{ t: t, @@ -1742,17 +1797,11 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers Header: hdr, } - // If there's no done channel (no deadline or cancellation - // from the caller possible), at least set some (long) - // timeout here. This will make sure we don't block forever - // and leak a goroutine if the connection stops replying - // after the TCP connect. - connectCtx := ctx - if ctx.Done() == nil { - newCtx, cancel := context.WithTimeout(ctx, 1*time.Minute) - defer cancel() - connectCtx = newCtx - } + // Set a (long) timeout here to make sure we don't block forever + // and leak a goroutine if the connection stops replying after + // the TCP connect. + connectCtx, cancel := testHookProxyConnectTimeout(ctx, 1*time.Minute) + defer cancel() didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails var ( diff --git a/src/net/http/transport_dial_test.go b/src/net/http/transport_dial_test.go new file mode 100644 index 0000000000..39e35cec55 --- /dev/null +++ b/src/net/http/transport_dial_test.go @@ -0,0 +1,235 @@ +// Copyright 2024 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http_test + +import ( + "context" + "io" + "net" + "net/http" + "net/http/httptrace" + "testing" +) + +func TestTransportPoolConnReusePriorConnection(t *testing.T) { + dt := newTransportDialTester(t, http1Mode) + + // First request creates a new connection. + rt1 := dt.roundTrip() + c1 := dt.wantDial() + c1.finish(nil) + rt1.wantDone(c1) + rt1.finish() + + // Second request reuses the first connection. + rt2 := dt.roundTrip() + rt2.wantDone(c1) + rt2.finish() +} + +func TestTransportPoolConnCannotReuseConnectionInUse(t *testing.T) { + dt := newTransportDialTester(t, http1Mode) + + // First request creates a new connection. + rt1 := dt.roundTrip() + c1 := dt.wantDial() + c1.finish(nil) + rt1.wantDone(c1) + + // Second request is made while the first request is still using its connection, + // so it goes on a new connection. + rt2 := dt.roundTrip() + c2 := dt.wantDial() + c2.finish(nil) + rt2.wantDone(c2) +} + +func TestTransportPoolConnConnectionBecomesAvailableDuringDial(t *testing.T) { + dt := newTransportDialTester(t, http1Mode) + + // First request creates a new connection. + rt1 := dt.roundTrip() + c1 := dt.wantDial() + c1.finish(nil) + rt1.wantDone(c1) + + // Second request is made while the first request is still using its connection. + // The first connection completes while the second Dial is in progress, so the + // second request uses the first connection. + rt2 := dt.roundTrip() + c2 := dt.wantDial() + rt1.finish() + rt2.wantDone(c1) + + // This section is a bit overfitted to the current Transport implementation: + // A third request starts. We have an in-progress dial that was started by rt2, + // but this new request (rt3) is going to ignore it and make a dial of its own. + // rt3 will use the first of these dials that completes. + rt3 := dt.roundTrip() + c3 := dt.wantDial() + c2.finish(nil) + rt3.wantDone(c2) + + c3.finish(nil) +} + +// A transportDialTester manages a test of a connection's Dials. +type transportDialTester struct { + t *testing.T + cst *clientServerTest + + dials chan *transportDialTesterConn // each new conn is sent to this channel + + roundTripCount int + dialCount int +} + +// A transportDialTesterRoundTrip is a RoundTrip made as part of a dial test. +type transportDialTesterRoundTrip struct { + t *testing.T + + roundTripID int // distinguishes RoundTrips in logs + cancel context.CancelFunc // cancels the Request context + reqBody io.WriteCloser // write half of the Request.Body + finished bool + + done chan struct{} // closed when RoundTrip returns:w + res *http.Response + err error + conn *transportDialTesterConn +} + +// A transportDialTesterConn is a client connection created by the Transport as +// part of a dial test. +type transportDialTesterConn struct { + t *testing.T + + connID int // distinguished Dials in logs + ready chan error // sent on to complete the Dial + + net.Conn +} + +func newTransportDialTester(t *testing.T, mode testMode) *transportDialTester { + t.Helper() + dt := &transportDialTester{ + t: t, + dials: make(chan *transportDialTesterConn), + } + dt.cst = newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Write response headers when we receive a request. + http.NewResponseController(w).EnableFullDuplex() + w.WriteHeader(200) + http.NewResponseController(w).Flush() + // Wait for the client to send the request body, + // to synchronize with the rest of the test. + io.ReadAll(r.Body) + }), func(tr *http.Transport) { + tr.DialContext = func(ctx context.Context, network, address string) (net.Conn, error) { + c := &transportDialTesterConn{ + t: t, + ready: make(chan error), + } + // Notify the test that a Dial has started, + // and wait for the test to notify us that it should complete. + dt.dials <- c + if err := <-c.ready; err != nil { + return nil, err + } + nc, err := net.Dial(network, address) + if err != nil { + return nil, err + } + // Use the *transportDialTesterConn as the net.Conn, + // to let tests associate requests with connections. + c.Conn = nc + return c, err + } + }) + return dt +} + +// roundTrip starts a RoundTrip. +// It returns immediately, without waiting for the RoundTrip call to complete. +func (dt *transportDialTester) roundTrip() *transportDialTesterRoundTrip { + dt.t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + pr, pw := io.Pipe() + rt := &transportDialTesterRoundTrip{ + t: dt.t, + roundTripID: dt.roundTripCount, + done: make(chan struct{}), + reqBody: pw, + cancel: cancel, + } + dt.roundTripCount++ + dt.t.Logf("RoundTrip %v: started", rt.roundTripID) + dt.t.Cleanup(func() { + rt.cancel() + rt.finish() + }) + go func() { + ctx = httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + GotConn: func(info httptrace.GotConnInfo) { + rt.conn = info.Conn.(*transportDialTesterConn) + }, + }) + req, _ := http.NewRequestWithContext(ctx, "POST", dt.cst.ts.URL, pr) + req.Header.Set("Content-Type", "text/plain") + rt.res, rt.err = dt.cst.tr.RoundTrip(req) + dt.t.Logf("RoundTrip %v: done (err:%v)", rt.roundTripID, rt.err) + close(rt.done) + }() + return rt +} + +// wantDone indicates that a RoundTrip should have returned. +func (rt *transportDialTesterRoundTrip) wantDone(c *transportDialTesterConn) { + rt.t.Helper() + <-rt.done + if rt.err != nil { + rt.t.Fatalf("RoundTrip %v: want success, got err %v", rt.roundTripID, rt.err) + } + if rt.conn != c { + rt.t.Fatalf("RoundTrip %v: want on conn %v, got conn %v", rt.roundTripID, c.connID, rt.conn.connID) + } +} + +// finish completes a RoundTrip by sending the request body, consuming the response body, +// and closing the response body. +func (rt *transportDialTesterRoundTrip) finish() { + rt.t.Helper() + + if rt.finished { + return + } + rt.finished = true + + <-rt.done + + if rt.err != nil { + return + } + rt.reqBody.Close() + io.ReadAll(rt.res.Body) + rt.res.Body.Close() + rt.t.Logf("RoundTrip %v: closed request body", rt.roundTripID) +} + +// wantDial waits for the Transport to start a Dial. +func (dt *transportDialTester) wantDial() *transportDialTesterConn { + c := <-dt.dials + c.connID = dt.dialCount + dt.dialCount++ + dt.t.Logf("Dial %v: started", c.connID) + return c +} + +// finish completes a Dial. +func (c *transportDialTesterConn) finish(err error) { + c.t.Logf("Dial %v: finished (err:%v)", c.connID, err) + c.ready <- err + close(c.ready) +} diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index e8baa486a4..fa147e164e 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -1626,11 +1626,20 @@ func TestOnProxyConnectResponse(t *testing.T) { // Issue 28012: verify that the Transport closes its TCP connection to http proxies // when they're slow to reply to HTTPS CONNECT responses. func TestTransportProxyHTTPSConnectLeak(t *testing.T) { - setParallel(t) - defer afterTest(t) + cancelc := make(chan struct{}) + SetTestHookProxyConnectTimeout(t, func(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(ctx) + go func() { + select { + case <-cancelc: + case <-ctx.Done(): + } + cancel() + }() + return ctx, cancel + }) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + defer afterTest(t) ln := newLocalListener(t) defer ln.Close() @@ -1658,7 +1667,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) { // Now hang and never write a response; instead, cancel the request and wait // for the client to close. // (Prior to Issue 28012 being fixed, we never closed.) - cancel() + close(cancelc) var buf [1]byte _, err = br.Read(buf[:]) if err != io.EOF { @@ -1674,7 +1683,7 @@ func TestTransportProxyHTTPSConnectLeak(t *testing.T) { }, }, } - req, err := NewRequestWithContext(ctx, "GET", "https://golang.fake.tld/", nil) + req, err := NewRequest("GET", "https://golang.fake.tld/", nil) if err != nil { t.Fatal(err) } @@ -3927,9 +3936,13 @@ func testTransportDialTLS(t *testing.T, mode testMode) { func TestTransportDialContext(t *testing.T) { run(t, testTransportDialContext) } func testTransportDialContext(t *testing.T, mode testMode) { - var mu sync.Mutex // guards following - var gotReq bool - var receivedContext context.Context + ctxKey := "some-key" + ctxValue := "some-value" + var ( + mu sync.Mutex // guards following + gotReq bool + gotCtxValue any + ) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() @@ -3939,7 +3952,7 @@ func testTransportDialContext(t *testing.T, mode testMode) { c := ts.Client() c.Transport.(*Transport).DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() - receivedContext = ctx + gotCtxValue = ctx.Value(ctxKey) mu.Unlock() return net.Dial(netw, addr) } @@ -3948,7 +3961,7 @@ func testTransportDialContext(t *testing.T, mode testMode) { if err != nil { t.Fatal(err) } - ctx := context.WithValue(context.Background(), "some-key", "some-value") + ctx := context.WithValue(context.Background(), ctxKey, ctxValue) res, err := c.Do(req.WithContext(ctx)) if err != nil { t.Fatal(err) @@ -3958,8 +3971,8 @@ func testTransportDialContext(t *testing.T, mode testMode) { if !gotReq { t.Error("didn't get request") } - if receivedContext != ctx { - t.Error("didn't receive correct context") + if got, want := gotCtxValue, ctxValue; got != want { + t.Errorf("got context with value %v, want %v", got, want) } } @@ -3967,9 +3980,13 @@ func TestTransportDialTLSContext(t *testing.T) { run(t, testTransportDialTLSContext, []testMode{https1Mode, http2Mode}) } func testTransportDialTLSContext(t *testing.T, mode testMode) { - var mu sync.Mutex // guards following - var gotReq bool - var receivedContext context.Context + ctxKey := "some-key" + ctxValue := "some-value" + var ( + mu sync.Mutex // guards following + gotReq bool + gotCtxValue any + ) ts := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) { mu.Lock() @@ -3979,7 +3996,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) { c := ts.Client() c.Transport.(*Transport).DialTLSContext = func(ctx context.Context, netw, addr string) (net.Conn, error) { mu.Lock() - receivedContext = ctx + gotCtxValue = ctx.Value(ctxKey) mu.Unlock() c, err := tls.Dial(netw, addr, c.Transport.(*Transport).TLSClientConfig) if err != nil { @@ -3992,7 +4009,7 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) { if err != nil { t.Fatal(err) } - ctx := context.WithValue(context.Background(), "some-key", "some-value") + ctx := context.WithValue(context.Background(), ctxKey, ctxValue) res, err := c.Do(req.WithContext(ctx)) if err != nil { t.Fatal(err) @@ -4002,8 +4019,8 @@ func testTransportDialTLSContext(t *testing.T, mode testMode) { if !gotReq { t.Error("didn't get request") } - if receivedContext != ctx { - t.Error("didn't receive correct context") + if got, want := gotCtxValue, ctxValue; got != want { + t.Errorf("got context with value %v, want %v", got, want) } }