http2noDialOnMiss = false
)
+// shouldTraceGetConn reports whether getClientConn should call any
+// ClientTrace.GetConn hook associated with the http.Request.
+//
+// This complexity is needed to avoid double calls of the GetConn hook
+// during the back-and-forth between net/http and x/net/http2 (when the
+// net/http.Transport is upgraded to also speak http2), as well as support
+// the case where x/net/http2 is being used directly.
+func (p *http2clientConnPool) shouldTraceGetConn(st http2clientConnIdleState) bool {
+ // If our Transport wasn't made via ConfigureTransport, always
+ // trace the GetConn hook if provided, because that means the
+ // http2 package is being used directly and it's the one
+ // dialing, as opposed to net/http.
+ if _, ok := p.t.ConnPool.(http2noDialClientConnPool); !ok {
+ return true
+ }
+ // Otherwise, only use the GetConn hook if this connection has
+ // been used previously for other requests. For fresh
+ // connections, the net/http package does the dialing.
+ return !st.freshConn
+}
+
func (p *http2clientConnPool) getClientConn(req *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
if http2isConnectionCloseRequest(req) && dialOnMiss {
// It gets its own connection.
+ http2traceGetConn(req, addr)
const singleUse = true
cc, err := p.t.dialClientConn(addr, singleUse)
if err != nil {
}
p.mu.Lock()
for _, cc := range p.conns[addr] {
- if cc.CanTakeNewRequest() {
+ if st := cc.idleState(); st.canTakeNewRequest {
+ if p.shouldTraceGetConn(st) {
+ http2traceGetConn(req, addr)
+ }
p.mu.Unlock()
return cc, nil
}
p.mu.Unlock()
return nil, http2ErrNoCachedConn
}
+ http2traceGetConn(req, addr)
call := p.getStartDialLocked(addr)
p.mu.Unlock()
<-call.done
return buf.String()
}
+func http2traceHasWroteHeaderField(trace *http2clientTrace) bool {
+ return trace != nil && trace.WroteHeaderField != nil
+}
+
+func http2traceWroteHeaderField(trace *http2clientTrace, k, v string) {
+ if trace != nil && trace.WroteHeaderField != nil {
+ trace.WroteHeaderField(k, []string{v})
+ }
+}
+
func http2transportExpectContinueTimeout(t1 *Transport) time.Duration {
return t1.ExpectContinueTimeout
}
context.Context
}
+var http2errCanceled = context.Canceled
+
func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) {
ctx, cancel = context.WithCancel(context.Background())
ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
func http2setResponseUncompressed(res *Response) { res.Uncompressed = true }
+func http2traceGetConn(req *Request, hostPort string) {
+ trace := httptrace.ContextClientTrace(req.Context())
+ if trace == nil || trace.GetConn == nil {
+ return
+ }
+ trace.GetConn(hostPort)
+}
+
func http2traceGotConn(req *Request, cc *http2ClientConn) {
trace := httptrace.ContextClientTrace(req.Context())
if trace == nil || trace.GotConn == nil {
return cc.ping(ctx)
}
+// Shutdown gracefully closes the client connection, waiting for running streams to complete.
+func (cc *http2ClientConn) Shutdown(ctx context.Context) error {
+ return cc.shutdown(ctx)
+}
+
func http2cloneTLSConfig(c *tls.Config) *tls.Config {
c2 := c.Clone()
c2.GetClientCertificate = c.GetClientCertificate // golang.org/issue/19264
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow http2flow // our conn-level flow control quota (cs.flow is per stream)
inflow http2flow // peer's conn-level flow control
+ closing bool
closed bool
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
return cc.canTakeNewRequestLocked()
}
-func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+// clientConnIdleState describes the suitability of a client
+// connection to initiate a new RoundTrip request.
+type http2clientConnIdleState struct {
+ canTakeNewRequest bool
+ freshConn bool // whether it's unused by any previous request
+}
+
+func (cc *http2ClientConn) idleState() http2clientConnIdleState {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.idleStateLocked()
+}
+
+func (cc *http2ClientConn) idleStateLocked() (st http2clientConnIdleState) {
if cc.singleUse && cc.nextStreamID > 1 {
- return false
+ return
}
- return cc.goAway == nil && !cc.closed &&
+ st.canTakeNewRequest = cc.goAway == nil && !cc.closed && !cc.closing &&
int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
+ st.freshConn = cc.nextStreamID == 1 && st.canTakeNewRequest
+ return
+}
+
+func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+ st := cc.idleStateLocked()
+ return st.canTakeNewRequest
}
// onIdleTimeout is called from a time.AfterFunc goroutine. It will
cc.tconn.Close()
}
+var http2shutdownEnterWaitStateHook = func() {}
+
+// Shutdown gracefully close the client connection, waiting for running streams to complete.
+// Public implementation is in go17.go and not_go17.go
+func (cc *http2ClientConn) shutdown(ctx http2contextContext) error {
+ if err := cc.sendGoAway(); err != nil {
+ return err
+ }
+ // Wait for all in-flight streams to complete or connection to close
+ done := make(chan error, 1)
+ cancelled := false // guarded by cc.mu
+ go func() {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ for {
+ if len(cc.streams) == 0 || cc.closed {
+ cc.closed = true
+ done <- cc.tconn.Close()
+ break
+ }
+ if cancelled {
+ break
+ }
+ cc.cond.Wait()
+ }
+ }()
+ http2shutdownEnterWaitStateHook()
+ select {
+ case err := <-done:
+ return err
+ case <-ctx.Done():
+ cc.mu.Lock()
+ // Free the goroutine above
+ cancelled = true
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+ return ctx.Err()
+ }
+}
+
+func (cc *http2ClientConn) sendGoAway() error {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if cc.closing {
+ // GOAWAY sent already
+ return nil
+ }
+ // Send a graceful shutdown frame to server
+ maxStreamID := cc.nextStreamID
+ if err := cc.fr.WriteGoAway(maxStreamID, http2ErrCodeNo, nil); err != nil {
+ return err
+ }
+ if err := cc.bw.Flush(); err != nil {
+ return err
+ }
+ // Prevent new requests
+ cc.closing = true
+ return nil
+}
+
+// Close closes the client connection immediately.
+//
+// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
+func (cc *http2ClientConn) Close() error {
+ cc.mu.Lock()
+ defer cc.cond.Broadcast()
+ defer cc.mu.Unlock()
+ err := errors.New("http2: client connection force closed via ClientConn.Close")
+ for id, cs := range cc.streams {
+ select {
+ case cs.resc <- http2resAndError{err: err}:
+ default:
+ }
+ cs.bufPipe.CloseWithError(err)
+ delete(cc.streams, id)
+ }
+ cc.closed = true
+ return cc.tconn.Close()
+}
+
const http2maxAllocFrameSize = 512 << 10
// frameBuffer returns a scratch buffer suitable for writing DATA frames.
if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") {
return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv)
}
- if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") {
+ if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && !strings.EqualFold(vv[0], "close") && !strings.EqualFold(vv[0], "keep-alive")) {
return fmt.Errorf("http2: invalid Connection request header: %q", vv)
}
return nil
return nil, http2errRequestHeaderListSize
}
+ trace := http2requestTrace(req)
+ traceHeaders := http2traceHasWroteHeaderField(trace)
+
// Header list size is ok. Write the headers.
enumerateHeaders(func(name, value string) {
- cc.writeHeader(strings.ToLower(name), value)
+ name = strings.ToLower(name)
+ cc.writeHeader(name, value)
+ if traceHeaders {
+ http2traceWroteHeaderField(trace, name, value)
+ }
})
return cc.hbuf.Bytes(), nil
func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
defer afterTest(t)
const resBody = "some body"
- gotWroteReqEvent := make(chan struct{})
+ gotWroteReqEvent := make(chan struct{}, 500)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.Method == "GET" {
+ // Do nothing for the second request.
+ return
+ }
if _, err := ioutil.ReadAll(r.Body); err != nil {
t.Error(err)
}
Got100Continue: func() { logf("Got100Continue") },
WroteRequest: func(e httptrace.WroteRequestInfo) {
logf("WroteRequest: %+v", e)
- close(gotWroteReqEvent)
+ gotWroteReqEvent <- struct{}{}
},
}
if h2 {
if t.Failed() {
t.Errorf("Output:\n%s", got)
}
+
+ // And do a second request:
+ req, _ = NewRequest("GET", cst.scheme()+"://dns-is-faked.golang:"+port, nil)
+ req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
+ res, err = cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if res.StatusCode != 200 {
+ t.Fatal(res.Status)
+ }
+ res.Body.Close()
+
+ mu.Lock()
+ got = buf.String()
+ mu.Unlock()
+
+ sub := "Getting conn for dns-is-faked.golang:"
+ if gotn, want := strings.Count(got, sub), 2; gotn != want {
+ t.Errorf("substring %q appeared %d times; want %d. Log:\n%s", sub, gotn, want, got)
+ }
+
}
func TestTransportEventTraceTLSVerify(t *testing.T) {