]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.NewClientConn
authorDamien Neil <dneil@google.com>
Tue, 18 Nov 2025 22:15:05 +0000 (14:15 -0800)
committerGopher Robot <gobot@golang.org>
Tue, 25 Nov 2025 01:26:36 +0000 (17:26 -0800)
For #75772

Change-Id: Iad7607b40636bab1faf8653455e92e9700309003
Reviewed-on: https://go-review.googlesource.com/c/go/+/722223
Reviewed-by: Nicholas Husin <nsh@golang.org>
Reviewed-by: Nicholas Husin <husin@google.com>
Auto-Submit: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

api/next/75772.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/http/75772.md [new file with mode: 0644]
src/net/http/clientconn.go [new file with mode: 0644]
src/net/http/clientconn_test.go [new file with mode: 0644]
src/net/http/clientserver_test.go
src/net/http/transport.go

diff --git a/api/next/75772.txt b/api/next/75772.txt
new file mode 100644 (file)
index 0000000..18c0f86
--- /dev/null
@@ -0,0 +1,10 @@
+pkg net/http, method (*ClientConn) Available() int #75772
+pkg net/http, method (*ClientConn) Close() error #75772
+pkg net/http, method (*ClientConn) Err() error #75772
+pkg net/http, method (*ClientConn) InFlight() int #75772
+pkg net/http, method (*ClientConn) Release() #75772
+pkg net/http, method (*ClientConn) Reserve() error #75772
+pkg net/http, method (*ClientConn) RoundTrip(*Request) (*Response, error) #75772
+pkg net/http, method (*ClientConn) SetStateHook(func(*ClientConn)) #75772
+pkg net/http, method (*Transport) NewClientConn(context.Context, string, string) (*ClientConn, error) #75772
+pkg net/http, type ClientConn struct #75772
diff --git a/doc/next/6-stdlib/99-minor/net/http/75772.md b/doc/next/6-stdlib/99-minor/net/http/75772.md
new file mode 100644 (file)
index 0000000..59d3e87
--- /dev/null
@@ -0,0 +1,5 @@
+The new [Transport.NewClientConn] method returns a client connection
+to an HTTP server.
+Most users should continue to use [Transport.RoundTrip] to make requests,
+which manages a pool of connection.
+`NewClientConn` is useful for users who need to implement their own conection management.
diff --git a/src/net/http/clientconn.go b/src/net/http/clientconn.go
new file mode 100644 (file)
index 0000000..2e1f33e
--- /dev/null
@@ -0,0 +1,456 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+import (
+       "context"
+       "errors"
+       "fmt"
+       "net"
+       "net/http/httptrace"
+       "net/url"
+       "sync"
+)
+
+// A ClientConn is a client connection to an HTTP server.
+//
+// Unlike a [Transport], a ClientConn represents a single connection.
+// Most users should use a Transport rather than creating client connections directly.
+type ClientConn struct {
+       cc genericClientConn
+
+       stateHookMu      sync.Mutex
+       userStateHook    func(*ClientConn)
+       stateHookRunning bool
+       lastAvailable    int
+       lastInFlight     int
+       lastClosed       bool
+}
+
+// newClientConner is the interface implemented by HTTP/2 transports to create new client conns.
+//
+// The http package (this package) needs a way to ask the http2 package to
+// create a client connection.
+//
+// Transport.TLSNextProto["h2"] contains a function which appears to do this,
+// but for historical reasons it does not: The TLSNextProto function adds a
+// *tls.Conn to the http2.Transport's connection pool and returns a RoundTripper
+// which is backed by that connection pool. NewClientConn needs a way to get a
+// single client connection out of the http2 package.
+//
+// The http2 package registers a RoundTripper with Transport.RegisterProtocol.
+// If this RoundTripper implements newClientConner, then Transport.NewClientConn will use
+// it to create new HTTP/2 client connections.
+type newClientConner interface {
+       // NewClientConn creates a new client connection from a net.Conn.
+       //
+       // The RoundTripper returned by NewClientConn must implement genericClientConn.
+       // (We don't define NewClientConn as returning genericClientConn,
+       // because either we'd need to make genericClientConn an exported type
+       // or define it as a type alias. Neither is particularly appealing.)
+       //
+       // The state hook passed here is the internal state hook
+       // (ClientConn.maybeRunStateHook). The internal state hook calls
+       // the user state hook (if any), which is set by the user with
+       // ClientConn.SetStateHook.
+       //
+       // The client connection should arrange to call the internal state hook
+       // when the connection closes, when requests complete, and when the
+       // connection concurrency limit changes.
+       //
+       // The client connection must call the internal state hook when the connection state
+       // changes asynchronously, such as when a request completes.
+       //
+       // The internal state hook need not be called after synchronous changes to the state:
+       // Close, Reserve, Release, and RoundTrip calls which don't start a request
+       // do not need to call the hook.
+       //
+       // The general idea is that if we call (for example) Close,
+       // we know that the connection state has probably changed and we
+       // don't need the state hook to tell us that.
+       // However, if the connection closes asynchronously
+       // (because, for example, the other end of the conn closed it),
+       // the state hook needs to inform us.
+       NewClientConn(nc net.Conn, internalStateHook func()) (RoundTripper, error)
+}
+
+// genericClientConn is an interface implemented by HTTP/2 client conns
+// returned from newClientConner.NewClientConn.
+//
+// See the newClientConner doc comment for more information.
+type genericClientConn interface {
+       Close() error
+       Err() error
+       RoundTrip(req *Request) (*Response, error)
+       Reserve() error
+       Release()
+       Available() int
+       InFlight() int
+}
+
+// NewClientConn creates a new client connection to the given address.
+//
+// If scheme is "http", the connection is unencrypted.
+// If scheme is "https", the connection uses TLS.
+//
+// The protocol used for the new connection is determined by the scheme,
+// Transport.Protocols configuration field, and protocols supported by the
+// server. See Transport.Protocols for more details.
+//
+// If Transport.Proxy is set and indicates that a request sent to the given
+// address should use a proxy, the new connection uses that proxy.
+//
+// NewClientConn always creates a new connection,
+// even if the Transport has an existing cached connection to the given host.
+//
+// The new connection is not added to the Transport's connection cache,
+// and will not be used by [Transport.RoundTrip].
+// It does not count against the MaxIdleConns and MaxConnsPerHost limits.
+//
+// The caller is responsible for closing the new connection.
+func (t *Transport) NewClientConn(ctx context.Context, scheme, address string) (*ClientConn, error) {
+       t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+
+       switch scheme {
+       case "http", "https":
+       default:
+               return nil, fmt.Errorf("net/http: invalid scheme %q", scheme)
+       }
+
+       host, port, err := net.SplitHostPort(address)
+       if err != nil {
+               return nil, err
+       }
+       if port == "" {
+               port = schemePort(scheme)
+       }
+
+       var proxyURL *url.URL
+       if t.Proxy != nil {
+               // Transport.Proxy takes a *Request, so create a fake one to pass it.
+               req := &Request{
+                       ctx:    ctx,
+                       Method: "GET",
+                       URL: &url.URL{
+                               Scheme: scheme,
+                               Host:   host,
+                               Path:   "/",
+                       },
+                       Proto:      "HTTP/1.1",
+                       ProtoMajor: 1,
+                       ProtoMinor: 1,
+                       Header:     make(Header),
+                       Body:       NoBody,
+                       Host:       host,
+               }
+               var err error
+               proxyURL, err = t.Proxy(req)
+               if err != nil {
+                       return nil, err
+               }
+       }
+
+       cm := connectMethod{
+               targetScheme: scheme,
+               targetAddr:   net.JoinHostPort(host, port),
+               proxyURL:     proxyURL,
+       }
+
+       // The state hook is a bit tricky:
+       // The persistConn has a state hook which calls ClientConn.maybeRunStateHook,
+       // which in turn calls the user-provided state hook (if any).
+       //
+       // ClientConn.maybeRunStateHook handles debouncing hook calls for both
+       // HTTP/1 and HTTP/2.
+       //
+       // Since there's no need to change the persistConn's hook, we set it at creation time.
+       cc := &ClientConn{}
+       const isClientConn = true
+       pconn, err := t.dialConn(ctx, cm, isClientConn, cc.maybeRunStateHook)
+       if err != nil {
+               return nil, err
+       }
+
+       // Note that cc.maybeRunStateHook may have been called
+       // in the short window between dialConn and now.
+       // This is fine.
+       cc.stateHookMu.Lock()
+       defer cc.stateHookMu.Unlock()
+       if pconn.alt != nil {
+               // If pconn.alt is set, this is a connection implemented in another package
+               // (probably x/net/http2) or the bundled copy in h2_bundle.go.
+               gc, ok := pconn.alt.(genericClientConn)
+               if !ok {
+                       return nil, errors.New("http: NewClientConn returned something that is not a ClientConn")
+               }
+               cc.cc = gc
+               cc.lastAvailable = gc.Available()
+       } else {
+               // This is an HTTP/1 connection.
+               pconn.availch = make(chan struct{}, 1)
+               pconn.availch <- struct{}{}
+               cc.cc = http1ClientConn{pconn}
+               cc.lastAvailable = 1
+       }
+       return cc, nil
+}
+
+// Close closes the connection.
+// Outstanding RoundTrip calls are interrupted.
+func (cc *ClientConn) Close() error {
+       defer cc.maybeRunStateHook()
+       return cc.cc.Close()
+}
+
+// Err reports any fatal connection errors.
+// It returns nil if the connection is usable.
+// If it returns non-nil, the connection can no longer be used.
+func (cc *ClientConn) Err() error {
+       return cc.cc.Err()
+}
+
+func validateClientConnRequest(req *Request) error {
+       if req.URL == nil {
+               return errors.New("http: nil Request.URL")
+       }
+       if req.Header == nil {
+               return errors.New("http: nil Request.Header")
+       }
+       // Validate the outgoing headers.
+       if err := validateHeaders(req.Header); err != "" {
+               return fmt.Errorf("http: invalid header %s", err)
+       }
+       // Validate the outgoing trailers too.
+       if err := validateHeaders(req.Trailer); err != "" {
+               return fmt.Errorf("http: invalid trailer %s", err)
+       }
+       if req.Method != "" && !validMethod(req.Method) {
+               return fmt.Errorf("http: invalid method %q", req.Method)
+       }
+       if req.URL.Host == "" {
+               return errors.New("http: no Host in request URL")
+       }
+       return nil
+}
+
+// RoundTrip implements the [RoundTripper] interface.
+//
+// The request is sent on the client connection,
+// regardless of the URL being requested or any proxy settings.
+//
+// If the connection is at its concurrency limit,
+// RoundTrip waits for the connection to become available
+// before sending the request.
+func (cc *ClientConn) RoundTrip(req *Request) (*Response, error) {
+       defer cc.maybeRunStateHook()
+       if err := validateClientConnRequest(req); err != nil {
+               cc.Release()
+               return nil, err
+       }
+       return cc.cc.RoundTrip(req)
+}
+
+// Available reports the number of requests that may be sent
+// to the connection without blocking.
+// It returns 0 if the connection is closed.
+func (cc *ClientConn) Available() int {
+       return cc.cc.Available()
+}
+
+// InFlight reports the number of requests in flight,
+// including reserved requests.
+// It returns 0 if the connection is closed.
+func (cc *ClientConn) InFlight() int {
+       return cc.cc.InFlight()
+}
+
+// Reserve reserves a concurrency slot on the connection.
+// If Reserve returns nil, one additional RoundTrip call may be made
+// without waiting for an existing request to complete.
+//
+// The reserved concurrency slot is accounted as an in-flight request.
+// A successful call to RoundTrip will decrement the Available count
+// and increment the InFlight count.
+//
+// Each successful call to Reserve should be followed by exactly one call
+// to RoundTrip or Release, which will consume or release the reservation.
+//
+// If the connection is closed or at its concurrency limit,
+// Reserve returns an error.
+func (cc *ClientConn) Reserve() error {
+       defer cc.maybeRunStateHook()
+       return cc.cc.Reserve()
+}
+
+// Release releases an unused concurrency slot reserved by Reserve.
+// If there are no reserved concurrency slots, it has no effect.
+func (cc *ClientConn) Release() {
+       defer cc.maybeRunStateHook()
+       cc.cc.Release()
+}
+
+// shouldRunStateHook returns the user's state hook if we should call it,
+// or nil if we don't need to call it at this time.
+func (cc *ClientConn) shouldRunStateHook(stopRunning bool) func(*ClientConn) {
+       cc.stateHookMu.Lock()
+       defer cc.stateHookMu.Unlock()
+       if cc.cc == nil {
+               return nil
+       }
+       if stopRunning {
+               cc.stateHookRunning = false
+       }
+       if cc.userStateHook == nil {
+               return nil
+       }
+       if cc.stateHookRunning {
+               return nil
+       }
+       var (
+               available = cc.Available()
+               inFlight  = cc.InFlight()
+               closed    = cc.Err() != nil
+       )
+       var hook func(*ClientConn)
+       if available > cc.lastAvailable || inFlight < cc.lastInFlight || closed != cc.lastClosed {
+               hook = cc.userStateHook
+               cc.stateHookRunning = true
+       }
+       cc.lastAvailable = available
+       cc.lastInFlight = inFlight
+       cc.lastClosed = closed
+       return hook
+}
+
+func (cc *ClientConn) maybeRunStateHook() {
+       hook := cc.shouldRunStateHook(false)
+       if hook == nil {
+               return
+       }
+       // Run the hook synchronously.
+       //
+       // This means that if, for example, the user calls resp.Body.Close to finish a request,
+       // the Close call will synchronously run the hook, giving the hook the chance to
+       // return the ClientConn to a connection pool before the next request is made.
+       hook(cc)
+       // The connection state may have changed while the hook was running,
+       // in which case we need to run it again.
+       //
+       // If we do need to run the hook again, do so in a new goroutine to avoid blocking
+       // the current goroutine indefinitely.
+       hook = cc.shouldRunStateHook(true)
+       if hook != nil {
+               go func() {
+                       for hook != nil {
+                               hook(cc)
+                               hook = cc.shouldRunStateHook(true)
+                       }
+               }()
+       }
+}
+
+// SetStateHook arranges for f to be called when the state of the connection changes.
+// At most one call to f is made at a time.
+// If the connection's state has changed since it was created,
+// f is called immediately in a separate goroutine.
+// f may be called synchronously from RoundTrip or Response.Body.Close.
+//
+// If SetStateHook is called multiple times, the new hook replaces the old one.
+// If f is nil, no further calls will be made to f after SetStateHook returns.
+//
+// f is called when Available increases (more requests may be sent on the connection),
+// InFlight decreases (existing requests complete), or Err begins returning non-nil
+// (the connection is no longer usable).
+func (cc *ClientConn) SetStateHook(f func(*ClientConn)) {
+       cc.stateHookMu.Lock()
+       cc.userStateHook = f
+       cc.stateHookMu.Unlock()
+       cc.maybeRunStateHook()
+}
+
+// http1ClientConn is a genericClientConn implementation backed by
+// an HTTP/1 *persistConn (pconn.alt is nil).
+type http1ClientConn struct {
+       pconn *persistConn
+}
+
+func (cc http1ClientConn) RoundTrip(req *Request) (*Response, error) {
+       ctx := req.Context()
+       trace := httptrace.ContextClientTrace(ctx)
+
+       // Convert Request.Cancel into context cancelation.
+       ctx, cancel := context.WithCancelCause(req.Context())
+       if req.Cancel != nil {
+               go awaitLegacyCancel(ctx, cancel, req)
+       }
+
+       treq := &transportRequest{Request: req, trace: trace, ctx: ctx, cancel: cancel}
+       resp, err := cc.pconn.roundTrip(treq)
+       if err != nil {
+               return nil, err
+       }
+       resp.Request = req
+       return resp, nil
+}
+
+func (cc http1ClientConn) Close() error {
+       cc.pconn.close(errors.New("ClientConn closed"))
+       return nil
+}
+
+func (cc http1ClientConn) Err() error {
+       select {
+       case <-cc.pconn.closech:
+               return cc.pconn.closed
+       default:
+               return nil
+       }
+}
+
+func (cc http1ClientConn) Available() int {
+       cc.pconn.mu.Lock()
+       defer cc.pconn.mu.Unlock()
+       if cc.pconn.closed != nil || cc.pconn.reserved || cc.pconn.inFlight {
+               return 0
+       }
+       return 1
+}
+
+func (cc http1ClientConn) InFlight() int {
+       cc.pconn.mu.Lock()
+       defer cc.pconn.mu.Unlock()
+       if cc.pconn.closed == nil && (cc.pconn.reserved || cc.pconn.inFlight) {
+               return 1
+       }
+       return 0
+}
+
+func (cc http1ClientConn) Reserve() error {
+       cc.pconn.mu.Lock()
+       defer cc.pconn.mu.Unlock()
+       if cc.pconn.closed != nil {
+               return cc.pconn.closed
+       }
+       select {
+       case <-cc.pconn.availch:
+       default:
+               return errors.New("connection is unavailable")
+       }
+       cc.pconn.reserved = true
+       return nil
+}
+
+func (cc http1ClientConn) Release() {
+       cc.pconn.mu.Lock()
+       defer cc.pconn.mu.Unlock()
+       if cc.pconn.reserved {
+               select {
+               case cc.pconn.availch <- struct{}{}:
+               default:
+                       panic("cannot release reservation")
+               }
+               cc.pconn.reserved = false
+       }
+}
diff --git a/src/net/http/clientconn_test.go b/src/net/http/clientconn_test.go
new file mode 100644 (file)
index 0000000..e46f6e6
--- /dev/null
@@ -0,0 +1,374 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http_test
+
+import (
+       "context"
+       "fmt"
+       "io"
+       "net/http"
+       "sync"
+       "sync/atomic"
+       "testing"
+       "testing/synctest"
+)
+
+func TestTransportNewClientConnRoundTrip(t *testing.T) { run(t, testTransportNewClientConnRoundTrip) }
+func testTransportNewClientConnRoundTrip(t *testing.T, mode testMode) {
+       cst := newClientServerTest(t, mode, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
+               io.WriteString(w, req.Host)
+       }), optFakeNet)
+
+       scheme := mode.Scheme() // http or https
+       cc, err := cst.tr.NewClientConn(t.Context(), scheme, cst.ts.Listener.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer cc.Close()
+
+       // Send requests for a couple different domains.
+       // All use the same connection.
+       for _, host := range []string{"example.tld", "go.dev"} {
+               req, _ := http.NewRequest("GET", fmt.Sprintf("%v://%v/", scheme, host), nil)
+               resp, err := cc.RoundTrip(req)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               got, _ := io.ReadAll(resp.Body)
+               if string(got) != host {
+                       t.Errorf("got response body %q, want %v", got, host)
+               }
+               resp.Body.Close()
+
+               // CloseIdleConnections does not close connections created by NewClientConn.
+               cst.tr.CloseIdleConnections()
+       }
+
+       if err := cc.Err(); err != nil {
+               t.Errorf("before close: ClientConn.Err() = %v, want nil", err)
+       }
+
+       cc.Close()
+       if err := cc.Err(); err == nil {
+               t.Errorf("after close: ClientConn.Err() = nil, want error")
+       }
+
+       req, _ := http.NewRequest("GET", scheme+"://example.tld/", nil)
+       resp, err := cc.RoundTrip(req)
+       if err == nil {
+               resp.Body.Close()
+               t.Errorf("after close: cc.RoundTrip succeeded, want error")
+       }
+       t.Log(err)
+}
+
+func newClientConnTest(t testing.TB, mode testMode, h http.HandlerFunc, opts ...any) (*clientServerTest, *http.ClientConn) {
+       if h == nil {
+               h = func(w http.ResponseWriter, req *http.Request) {}
+       }
+       cst := newClientServerTest(t, mode, h, opts...)
+       cc, err := cst.tr.NewClientConn(t.Context(), mode.Scheme(), cst.ts.Listener.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+       t.Cleanup(func() {
+               cc.Close()
+       })
+       synctest.Wait()
+       return cst, cc
+}
+
+// TestClientConnReserveAll reserves every concurrency slot on a connection.
+func TestClientConnReserveAll(t *testing.T) { runSynctest(t, testClientConnReserveAll) }
+func testClientConnReserveAll(t *testing.T, mode testMode) {
+       cst, cc := newClientConnTest(t, mode, nil, optFakeNet, func(s *http.Server) {
+               s.HTTP2 = &http.HTTP2Config{
+                       MaxConcurrentStreams: 3,
+               }
+       })
+
+       want := 1
+       switch mode {
+       case http2Mode, http2UnencryptedMode:
+               want = cst.ts.Config.HTTP2.MaxConcurrentStreams
+       }
+       available := cc.Available()
+       if available != want {
+               t.Fatalf("cc.Available() = %v, want %v", available, want)
+       }
+
+       // Reserve every available concurrency slot on the connection.
+       for i := range available {
+               if err := cc.Reserve(); err != nil {
+                       t.Fatalf("cc.Reserve() #%v = %v, want nil", i, err)
+               }
+               if got, want := cc.Available(), available-i-1; got != want {
+                       t.Fatalf("cc.Available() = %v, want %v", got, want)
+               }
+               if got, want := cc.InFlight(), i+1; got != want {
+                       t.Fatalf("cc.InFlight() = %v, want %v", got, want)
+               }
+       }
+
+       // The next reservation attempt should fail, since every slot is consumed.
+       if err := cc.Reserve(); err == nil {
+               t.Fatalf("cc.Reserve() = nil, want error")
+       }
+}
+
+// TestClientConnReserveParallel starts concurrent goroutines which reserve every
+// concurrency slot on a connection.
+func TestClientConnReserveParallel(t *testing.T) { runSynctest(t, testClientConnReserveParallel) }
+func testClientConnReserveParallel(t *testing.T, mode testMode) {
+       _, cc := newClientConnTest(t, mode, nil, optFakeNet, func(s *http.Server) {
+               s.HTTP2 = &http.HTTP2Config{
+                       MaxConcurrentStreams: 3,
+               }
+       })
+       var (
+               wg      sync.WaitGroup
+               mu      sync.Mutex
+               success int
+               failure int
+       )
+       available := cc.Available()
+       const extra = 2
+       for range available + extra {
+               wg.Go(func() {
+                       err := cc.Reserve()
+                       mu.Lock()
+                       defer mu.Unlock()
+                       if err == nil {
+                               success++
+                       } else {
+                               failure++
+                       }
+               })
+       }
+       wg.Wait()
+
+       if got, want := success, available; got != want {
+               t.Errorf("%v successful reservations, want %v", got, want)
+       }
+       if got, want := failure, extra; got != want {
+               t.Errorf("%v failed reservations, want %v", got, want)
+       }
+}
+
+// TestClientConnReserveRelease repeatedly reserves and releases concurrency slots.
+func TestClientConnReserveRelease(t *testing.T) { runSynctest(t, testClientConnReserveRelease) }
+func testClientConnReserveRelease(t *testing.T, mode testMode) {
+       _, cc := newClientConnTest(t, mode, nil, optFakeNet, func(s *http.Server) {
+               s.HTTP2 = &http.HTTP2Config{
+                       MaxConcurrentStreams: 3,
+               }
+       })
+
+       available := cc.Available()
+       for i := range 2 * available {
+               if err := cc.Reserve(); err != nil {
+                       t.Fatalf("cc.Reserve() #%v = %v, want nil", i, err)
+               }
+               cc.Release()
+       }
+
+       if got, want := cc.Available(), available; got != want {
+               t.Fatalf("cc.Available() = %v, want %v", available, want)
+       }
+}
+
+// TestClientConnReserveAndConsume reserves a concurrency slot on a connection,
+// and then verifies that various events consume the reservation.
+func TestClientConnReserveAndConsume(t *testing.T) {
+       for _, test := range []struct {
+               name     string
+               consume  func(t *testing.T, cc *http.ClientConn, mode testMode)
+               handler  func(w http.ResponseWriter, req *http.Request, donec chan struct{})
+               h1Closed bool
+       }{{
+               // Explicit release.
+               name: "release",
+               consume: func(t *testing.T, cc *http.ClientConn, mode testMode) {
+                       cc.Release()
+               },
+       }, {
+               // Invalid request sent to RoundTrip.
+               name: "invalid field name",
+               consume: func(t *testing.T, cc *http.ClientConn, mode testMode) {
+                       req, _ := http.NewRequest("GET", mode.Scheme()+"://example.tld/", nil)
+                       req.Header["invalid field name"] = []string{"x"}
+                       _, err := cc.RoundTrip(req)
+                       if err == nil {
+                               t.Fatalf("RoundTrip succeeded, want failure")
+                       }
+               },
+       }, {
+               // Successful request/response cycle.
+               name: "body close",
+               consume: func(t *testing.T, cc *http.ClientConn, mode testMode) {
+                       req, _ := http.NewRequest("GET", mode.Scheme()+"://example.tld/", nil)
+                       resp, err := cc.RoundTrip(req)
+                       if err != nil {
+                               t.Fatalf("RoundTrip: %v", err)
+                       }
+                       resp.Body.Close()
+               },
+       }, {
+               // Request context canceled before headers received.
+               name: "cancel",
+               consume: func(t *testing.T, cc *http.ClientConn, mode testMode) {
+                       ctx, cancel := context.WithCancel(t.Context())
+                       go func() {
+                               req, _ := http.NewRequestWithContext(ctx, "GET", mode.Scheme()+"://example.tld/", nil)
+                               _, err := cc.RoundTrip(req)
+                               if err == nil {
+                                       t.Errorf("RoundTrip succeeded, want failure")
+                               }
+                       }()
+                       synctest.Wait()
+                       cancel()
+               },
+               handler: func(w http.ResponseWriter, req *http.Request, donec chan struct{}) {
+                       <-donec
+               },
+               // An HTTP/1 connection is closed after a request is canceled on it.
+               h1Closed: true,
+       }, {
+               // Response body closed before full response received.
+               name: "early body close",
+               consume: func(t *testing.T, cc *http.ClientConn, mode testMode) {
+                       req, _ := http.NewRequest("GET", mode.Scheme()+"://example.tld/", nil)
+                       resp, err := cc.RoundTrip(req)
+                       if err != nil {
+                               t.Fatalf("RoundTrip: %v", err)
+                       }
+                       t.Logf("%T", resp.Body)
+                       resp.Body.Close()
+               },
+               handler: func(w http.ResponseWriter, req *http.Request, donec chan struct{}) {
+                       w.WriteHeader(200)
+                       http.NewResponseController(w).Flush()
+                       <-donec
+               },
+               // An HTTP/1 connection is closed after a request is canceled on it.
+               h1Closed: true,
+       }} {
+               t.Run(test.name, func(t *testing.T) {
+                       runSynctest(t, func(t *testing.T, mode testMode) {
+                               donec := make(chan struct{})
+                               defer close(donec)
+                               handler := func(w http.ResponseWriter, req *http.Request) {
+                                       if test.handler != nil {
+                                               test.handler(w, req, donec)
+                                       }
+                               }
+
+                               _, cc := newClientConnTest(t, mode, handler, optFakeNet)
+                               stateHookCalls := 0
+                               cc.SetStateHook(func(cc *http.ClientConn) {
+                                       stateHookCalls++
+                               })
+                               synctest.Wait()
+                               stateHookCalls = 0 // ignore any initial update call
+
+                               avail := cc.Available()
+                               if err := cc.Reserve(); err != nil {
+                                       t.Fatalf("cc.Reserve() = %v, want nil", err)
+                               }
+                               synctest.Wait()
+                               if got, want := stateHookCalls, 0; got != want {
+                                       t.Errorf("connection state hook calls: %v, want %v", got, want)
+                               }
+
+                               test.consume(t, cc, mode)
+                               synctest.Wait()
+
+                               // State hook should be called, either to report the
+                               // connection availability increasing or the connection closing.
+                               if got, want := stateHookCalls, 1; got != want {
+                                       t.Errorf("connection state hook calls: %v, want %v", got, want)
+                               }
+
+                               if test.h1Closed && (mode == http1Mode || mode == https1Mode) {
+                                       if got, want := cc.Available(), 0; got != want {
+                                               t.Errorf("cc.Available() = %v, want %v", got, want)
+                                       }
+                                       if got, want := cc.InFlight(), 0; got != want {
+                                               t.Errorf("cc.InFlight() = %v, want %v", got, want)
+                                       }
+                                       if err := cc.Err(); err == nil {
+                                               t.Errorf("cc.Err() = nil, want closed connection")
+                                       }
+                               } else {
+                                       if got, want := cc.Available(), avail; got != want {
+                                               t.Errorf("cc.Available() = %v, want %v", got, want)
+                                       }
+                                       if got, want := cc.InFlight(), 0; got != want {
+                                               t.Errorf("cc.InFlight() = %v, want %v", got, want)
+                                       }
+                                       if err := cc.Err(); err != nil {
+                                               t.Errorf("cc.Err() = %v, want nil", err)
+                                       }
+                               }
+
+                               if cc.Available() > 0 {
+                                       if err := cc.Reserve(); err != nil {
+                                               t.Errorf("cc.Reserve() = %v, want success", err)
+                                       }
+                               }
+                       })
+               })
+       }
+
+}
+
+// TestClientConnRoundTripBlocks verifies that RoundTrip blocks until a concurrency
+// slot is available on a connection.
+func TestClientConnRoundTripBlocks(t *testing.T) { runSynctest(t, testClientConnRoundTripBlocks) }
+func testClientConnRoundTripBlocks(t *testing.T, mode testMode) {
+       var handlerCalls atomic.Int64
+       requestc := make(chan struct{})
+       handler := func(w http.ResponseWriter, req *http.Request) {
+               handlerCalls.Add(1)
+               <-requestc
+       }
+       _, cc := newClientConnTest(t, mode, handler, optFakeNet, func(s *http.Server) {
+               s.HTTP2 = &http.HTTP2Config{
+                       MaxConcurrentStreams: 3,
+               }
+       })
+
+       available := cc.Available()
+       var responses atomic.Int64
+       const extra = 2
+       for range available + extra {
+               go func() {
+                       req, _ := http.NewRequest("GET", mode.Scheme()+"://example.tld/", nil)
+                       resp, err := cc.RoundTrip(req)
+                       responses.Add(1)
+                       if err != nil {
+                               t.Errorf("RoundTrip: %v", err)
+                               return
+                       }
+                       resp.Body.Close()
+               }()
+       }
+
+       synctest.Wait()
+       if got, want := int(handlerCalls.Load()), available; got != want {
+               t.Errorf("got %v handler calls, want %v", got, want)
+       }
+       if got, want := int(responses.Load()), 0; got != want {
+               t.Errorf("got %v responses, want %v", got, want)
+       }
+
+       for i := range available + extra {
+               requestc <- struct{}{}
+               synctest.Wait()
+               if got, want := int(responses.Load()), i+1; got != want {
+                       t.Errorf("got %v responses, want %v", got, want)
+               }
+       }
+}
index 8665bae38ad0f26ccfd622fe9dbbce45b23847cf..2bca1d3253648b42cefb9c850b17aa42626a82d1 100644 (file)
@@ -46,6 +46,16 @@ const (
        http2UnencryptedMode = testMode("h2unencrypted") // HTTP/2
 )
 
+func (m testMode) Scheme() string {
+       switch m {
+       case http1Mode, http2UnencryptedMode:
+               return "http"
+       case https1Mode, http2Mode:
+               return "https"
+       }
+       panic("unknown testMode")
+}
+
 type testNotParallelOpt struct{}
 
 var (
index 033eddf1f5ba7e31fd746c723202bd40df9f1a32..26a25d2a022fd4fb9a32af7012afe04c421f8df6 100644 (file)
@@ -1067,6 +1067,22 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
                return errConnBroken
        }
        pconn.markReused()
+       if pconn.isClientConn {
+               // internalStateHook is always set for conns created by NewClientConn.
+               defer pconn.internalStateHook()
+               pconn.mu.Lock()
+               defer pconn.mu.Unlock()
+               if !pconn.inFlight {
+                       panic("pconn is not in flight")
+               }
+               pconn.inFlight = false
+               select {
+               case pconn.availch <- struct{}{}:
+               default:
+                       panic("unable to make pconn available")
+               }
+               return nil
+       }
 
        t.idleMu.Lock()
        defer t.idleMu.Unlock()
@@ -1243,6 +1259,9 @@ func (t *Transport) queueForIdleConn(w *wantConn) (delivered bool) {
 
 // removeIdleConn marks pconn as dead.
 func (t *Transport) removeIdleConn(pconn *persistConn) bool {
+       if pconn.isClientConn {
+               return true
+       }
        t.idleMu.Lock()
        defer t.idleMu.Unlock()
        return t.removeIdleConnLocked(pconn)
@@ -1625,7 +1644,8 @@ func (t *Transport) dialConnFor(w *wantConn) {
                return
        }
 
-       pc, err := t.dialConn(ctx, w.cm)
+       const isClientConn = false
+       pc, err := t.dialConn(ctx, w.cm, isClientConn, nil)
        delivered := w.tryDeliver(pc, err, time.Time{})
        if err == nil && (!delivered || pc.alt != nil) {
                // pconn was not passed to w,
@@ -1746,15 +1766,17 @@ type erringRoundTripper interface {
 
 var testHookProxyConnectTimeout = context.WithTimeout
 
-func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *persistConn, err error) {
+func (t *Transport) dialConn(ctx context.Context, cm connectMethod, isClientConn bool, internalStateHook func()) (pconn *persistConn, err error) {
        pconn = &persistConn{
-               t:             t,
-               cacheKey:      cm.key(),
-               reqch:         make(chan requestAndChan, 1),
-               writech:       make(chan writeRequest, 1),
-               closech:       make(chan struct{}),
-               writeErrCh:    make(chan error, 1),
-               writeLoopDone: make(chan struct{}),
+               t:                 t,
+               cacheKey:          cm.key(),
+               reqch:             make(chan requestAndChan, 1),
+               writech:           make(chan writeRequest, 1),
+               closech:           make(chan struct{}),
+               writeErrCh:        make(chan error, 1),
+               writeLoopDone:     make(chan struct{}),
+               isClientConn:      isClientConn,
+               internalStateHook: internalStateHook,
        }
        trace := httptrace.ContextClientTrace(ctx)
        wrapErr := func(err error) error {
@@ -1927,6 +1949,21 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers
                t.Protocols != nil &&
                t.Protocols.UnencryptedHTTP2() &&
                !t.Protocols.HTTP1()
+
+       if isClientConn && (unencryptedHTTP2 || (pconn.tlsState != nil && pconn.tlsState.NegotiatedProtocol == "h2")) {
+               altProto, _ := t.altProto.Load().(map[string]RoundTripper)
+               h2, ok := altProto["https"].(newClientConner)
+               if !ok {
+                       return nil, errors.New("http: HTTP/2 implementation does not support NewClientConn (update golang.org/x/net?)")
+               }
+               alt, err := h2.NewClientConn(pconn.conn, internalStateHook)
+               if err != nil {
+                       pconn.conn.Close()
+                       return nil, err
+               }
+               return &persistConn{t: t, cacheKey: pconn.cacheKey, alt: alt, isClientConn: true}, nil
+       }
+
        if unencryptedHTTP2 {
                next, ok := t.TLSNextProto[nextProtoUnencryptedHTTP2]
                if !ok {
@@ -2081,19 +2118,21 @@ type persistConn struct {
        // If it's non-nil, the rest of the fields are unused.
        alt RoundTripper
 
-       t         *Transport
-       cacheKey  connectMethodKey
-       conn      net.Conn
-       tlsState  *tls.ConnectionState
-       br        *bufio.Reader       // from conn
-       bw        *bufio.Writer       // to conn
-       nwrite    int64               // bytes written
-       reqch     chan requestAndChan // written by roundTrip; read by readLoop
-       writech   chan writeRequest   // written by roundTrip; read by writeLoop
-       closech   chan struct{}       // closed when conn closed
-       isProxy   bool
-       sawEOF    bool  // whether we've seen EOF from conn; owned by readLoop
-       readLimit int64 // bytes allowed to be read; owned by readLoop
+       t            *Transport
+       cacheKey     connectMethodKey
+       conn         net.Conn
+       tlsState     *tls.ConnectionState
+       br           *bufio.Reader       // from conn
+       bw           *bufio.Writer       // to conn
+       nwrite       int64               // bytes written
+       reqch        chan requestAndChan // written by roundTrip; read by readLoop
+       writech      chan writeRequest   // written by roundTrip; read by writeLoop
+       closech      chan struct{}       // closed when conn closed
+       availch      chan struct{}       // ClientConn only: contains a value when conn is usable
+       isProxy      bool
+       sawEOF       bool  // whether we've seen EOF from conn; owned by readLoop
+       isClientConn bool  // whether this is a ClientConn (outside any pool)
+       readLimit    int64 // bytes allowed to be read; owned by readLoop
        // writeErrCh passes the request write error (usually nil)
        // from the writeLoop goroutine to the readLoop which passes
        // it off to the res.Body reader, which then uses it to decide
@@ -2108,9 +2147,13 @@ type persistConn struct {
 
        mu                   sync.Mutex // guards following fields
        numExpectedResponses int
-       closed               error // set non-nil when conn is closed, before closech is closed
-       canceledErr          error // set non-nil if conn is canceled
-       reused               bool  // whether conn has had successful request/response and is being reused.
+       closed               error  // set non-nil when conn is closed, before closech is closed
+       canceledErr          error  // set non-nil if conn is canceled
+       reused               bool   // whether conn has had successful request/response and is being reused.
+       reserved             bool   // ClientConn only: concurrency slot reserved
+       inFlight             bool   // ClientConn only: request is in flight
+       internalStateHook    func() // ClientConn state hook
+
        // mutateHeaderFunc is an optional func to modify extra
        // headers on each outbound request before it's written. (the
        // original Request given to RoundTrip is not modified)
@@ -2250,6 +2293,9 @@ func (pc *persistConn) readLoop() {
        defer func() {
                pc.close(closeErr)
                pc.t.removeIdleConn(pc)
+               if pc.internalStateHook != nil {
+                       pc.internalStateHook()
+               }
        }()
 
        tryPutIdleConn := func(treq *transportRequest) bool {
@@ -2753,9 +2799,32 @@ var (
        testHookReadLoopBeforeNextRead             = nop
 )
 
+func (pc *persistConn) waitForAvailability(ctx context.Context) error {
+       select {
+       case <-pc.availch:
+               return nil
+       case <-pc.closech:
+               return pc.closed
+       case <-ctx.Done():
+               return ctx.Err()
+       }
+}
+
 func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
        testHookEnterRoundTrip()
+
        pc.mu.Lock()
+       if pc.isClientConn {
+               if !pc.reserved {
+                       pc.mu.Unlock()
+                       if err := pc.waitForAvailability(req.ctx); err != nil {
+                               return nil, err
+                       }
+                       pc.mu.Lock()
+               }
+               pc.reserved = false
+               pc.inFlight = true
+       }
        pc.numExpectedResponses++
        headerFn := pc.mutateHeaderFunc
        pc.mu.Unlock()