//
// See https://http2.golang.org/ for a test server running this code.
//
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+//
package http
return err
}
+// testSyncHooks coordinates goroutines in tests.
+//
+// For example, a call to ClientConn.RoundTrip involves several goroutines, including:
+// - the goroutine running RoundTrip;
+// - the clientStream.doRequest goroutine, which writes the request; and
+// - the clientStream.readLoop goroutine, which reads the response.
+//
+// Using testSyncHooks, a test can start a RoundTrip and identify when all these goroutines
+// are blocked waiting for some condition such as reading the Request.Body or waiting for
+// flow control to become available.
+//
+// The testSyncHooks also manage timers and synthetic time in tests.
+// This permits us to, for example, start a request and cause it to time out waiting for
+// response headers without resorting to time.Sleep calls.
+type http2testSyncHooks struct {
+ // active/inactive act as a mutex and condition variable.
+ //
+ // - neither chan contains a value: testSyncHooks is locked.
+ // - active contains a value: unlocked, and at least one goroutine is not blocked
+ // - inactive contains a value: unlocked, and all goroutines are blocked
+ active chan struct{}
+ inactive chan struct{}
+
+ // goroutine counts
+ total int // total goroutines
+ condwait map[*sync.Cond]int // blocked in sync.Cond.Wait
+ blocked []*http2testBlockedGoroutine // otherwise blocked
+
+ // fake time
+ now time.Time
+ timers []*http2fakeTimer
+
+ // Transport testing: Report various events.
+ newclientconn func(*http2ClientConn)
+ newstream func(*http2clientStream)
+}
+
+// testBlockedGoroutine is a blocked goroutine.
+type http2testBlockedGoroutine struct {
+ f func() bool // blocked until f returns true
+ ch chan struct{} // closed when unblocked
+}
+
+func http2newTestSyncHooks() *http2testSyncHooks {
+ h := &http2testSyncHooks{
+ active: make(chan struct{}, 1),
+ inactive: make(chan struct{}, 1),
+ condwait: map[*sync.Cond]int{},
+ }
+ h.inactive <- struct{}{}
+ h.now = time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC)
+ return h
+}
+
+// lock acquires the testSyncHooks mutex.
+func (h *http2testSyncHooks) lock() {
+ select {
+ case <-h.active:
+ case <-h.inactive:
+ }
+}
+
+// waitInactive waits for all goroutines to become inactive.
+func (h *http2testSyncHooks) waitInactive() {
+ for {
+ <-h.inactive
+ if !h.unlock() {
+ break
+ }
+ }
+}
+
+// unlock releases the testSyncHooks mutex.
+// It reports whether any goroutines are active.
+func (h *http2testSyncHooks) unlock() (active bool) {
+ // Look for a blocked goroutine which can be unblocked.
+ blocked := h.blocked[:0]
+ unblocked := false
+ for _, b := range h.blocked {
+ if !unblocked && b.f() {
+ unblocked = true
+ close(b.ch)
+ } else {
+ blocked = append(blocked, b)
+ }
+ }
+ h.blocked = blocked
+
+ // Count goroutines blocked on condition variables.
+ condwait := 0
+ for _, count := range h.condwait {
+ condwait += count
+ }
+
+ if h.total > condwait+len(blocked) {
+ h.active <- struct{}{}
+ return true
+ } else {
+ h.inactive <- struct{}{}
+ return false
+ }
+}
+
+// goRun starts a new goroutine.
+func (h *http2testSyncHooks) goRun(f func()) {
+ h.lock()
+ h.total++
+ h.unlock()
+ go func() {
+ defer func() {
+ h.lock()
+ h.total--
+ h.unlock()
+ }()
+ f()
+ }()
+}
+
+// blockUntil indicates that a goroutine is blocked waiting for some condition to become true.
+// It waits until f returns true before proceeding.
+//
+// Example usage:
+//
+// h.blockUntil(func() bool {
+// // Is the context done yet?
+// select {
+// case <-ctx.Done():
+// default:
+// return false
+// }
+// return true
+// })
+// // Wait for the context to become done.
+// <-ctx.Done()
+//
+// The function f passed to blockUntil must be non-blocking and idempotent.
+func (h *http2testSyncHooks) blockUntil(f func() bool) {
+ if f() {
+ return
+ }
+ ch := make(chan struct{})
+ h.lock()
+ h.blocked = append(h.blocked, &http2testBlockedGoroutine{
+ f: f,
+ ch: ch,
+ })
+ h.unlock()
+ <-ch
+}
+
+// broadcast is sync.Cond.Broadcast.
+func (h *http2testSyncHooks) condBroadcast(cond *sync.Cond) {
+ h.lock()
+ delete(h.condwait, cond)
+ h.unlock()
+ cond.Broadcast()
+}
+
+// broadcast is sync.Cond.Wait.
+func (h *http2testSyncHooks) condWait(cond *sync.Cond) {
+ h.lock()
+ h.condwait[cond]++
+ h.unlock()
+}
+
+// newTimer creates a new fake timer.
+func (h *http2testSyncHooks) newTimer(d time.Duration) http2timer {
+ h.lock()
+ defer h.unlock()
+ t := &http2fakeTimer{
+ hooks: h,
+ when: h.now.Add(d),
+ c: make(chan time.Time),
+ }
+ h.timers = append(h.timers, t)
+ return t
+}
+
+// afterFunc creates a new fake AfterFunc timer.
+func (h *http2testSyncHooks) afterFunc(d time.Duration, f func()) http2timer {
+ h.lock()
+ defer h.unlock()
+ t := &http2fakeTimer{
+ hooks: h,
+ when: h.now.Add(d),
+ f: f,
+ }
+ h.timers = append(h.timers, t)
+ return t
+}
+
+func (h *http2testSyncHooks) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ ctx, cancel := context.WithCancel(ctx)
+ t := h.afterFunc(d, cancel)
+ return ctx, func() {
+ t.Stop()
+ cancel()
+ }
+}
+
+func (h *http2testSyncHooks) timeUntilEvent() time.Duration {
+ h.lock()
+ defer h.unlock()
+ var next time.Time
+ for _, t := range h.timers {
+ if next.IsZero() || t.when.Before(next) {
+ next = t.when
+ }
+ }
+ if d := next.Sub(h.now); d > 0 {
+ return d
+ }
+ return 0
+}
+
+// advance advances time and causes synthetic timers to fire.
+func (h *http2testSyncHooks) advance(d time.Duration) {
+ h.lock()
+ defer h.unlock()
+ h.now = h.now.Add(d)
+ timers := h.timers[:0]
+ for _, t := range h.timers {
+ t := t // remove after go.mod depends on go1.22
+ t.mu.Lock()
+ switch {
+ case t.when.After(h.now):
+ timers = append(timers, t)
+ case t.when.IsZero():
+ // stopped timer
+ default:
+ t.when = time.Time{}
+ if t.c != nil {
+ close(t.c)
+ }
+ if t.f != nil {
+ h.total++
+ go func() {
+ defer func() {
+ h.lock()
+ h.total--
+ h.unlock()
+ }()
+ t.f()
+ }()
+ }
+ }
+ t.mu.Unlock()
+ }
+ h.timers = timers
+}
+
+// A timer wraps a time.Timer, or a synthetic equivalent in tests.
+// Unlike time.Timer, timer is single-use: The timer channel is closed when the timer expires.
+type http2timer interface {
+ C() <-chan time.Time
+ Stop() bool
+ Reset(d time.Duration) bool
+}
+
+// timeTimer implements timer using real time.
+type http2timeTimer struct {
+ t *time.Timer
+ c chan time.Time
+}
+
+// newTimeTimer creates a new timer using real time.
+func http2newTimeTimer(d time.Duration) http2timer {
+ ch := make(chan time.Time)
+ t := time.AfterFunc(d, func() {
+ close(ch)
+ })
+ return &http2timeTimer{t, ch}
+}
+
+// newTimeAfterFunc creates an AfterFunc timer using real time.
+func http2newTimeAfterFunc(d time.Duration, f func()) http2timer {
+ return &http2timeTimer{
+ t: time.AfterFunc(d, f),
+ }
+}
+
+func (t http2timeTimer) C() <-chan time.Time { return t.c }
+
+func (t http2timeTimer) Stop() bool { return t.t.Stop() }
+
+func (t http2timeTimer) Reset(d time.Duration) bool { return t.t.Reset(d) }
+
+// fakeTimer implements timer using fake time.
+type http2fakeTimer struct {
+ hooks *http2testSyncHooks
+
+ mu sync.Mutex
+ when time.Time // when the timer will fire
+ c chan time.Time // closed when the timer fires; mutually exclusive with f
+ f func() // called when the timer fires; mutually exclusive with c
+}
+
+func (t *http2fakeTimer) C() <-chan time.Time { return t.c }
+
+func (t *http2fakeTimer) Stop() bool {
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ stopped := t.when.IsZero()
+ t.when = time.Time{}
+ return stopped
+}
+
+func (t *http2fakeTimer) Reset(d time.Duration) bool {
+ if t.c != nil || t.f == nil {
+ panic("fakeTimer only supports Reset on AfterFunc timers")
+ }
+ t.mu.Lock()
+ defer t.mu.Unlock()
+ t.hooks.lock()
+ defer t.hooks.unlock()
+ active := !t.when.IsZero()
+ t.when = t.hooks.now.Add(d)
+ if !active {
+ t.hooks.timers = append(t.hooks.timers, t)
+ }
+ return active
+}
+
const (
// transportDefaultConnFlow is how many connection-level flow control
// tokens we give the server at start-up, past the default 64k.
connPoolOnce sync.Once
connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
+
+ syncHooks *http2testSyncHooks
}
func (t *http2Transport) maxHeaderListSize() uint32 {
werr error // first write error that has occurred
hbuf bytes.Buffer // HPACK encoder writes into this
henc *hpack.Encoder
+
+ syncHooks *http2testSyncHooks // can be nil
+}
+
+// Hook points used for testing.
+// Outside of tests, cc.syncHooks is nil and these all have minimal implementations.
+// Inside tests, see the testSyncHooks function docs.
+
+// goRun starts a new goroutine.
+func (cc *http2ClientConn) goRun(f func()) {
+ if cc.syncHooks != nil {
+ cc.syncHooks.goRun(f)
+ return
+ }
+ go f()
+}
+
+// condBroadcast is cc.cond.Broadcast.
+func (cc *http2ClientConn) condBroadcast() {
+ if cc.syncHooks != nil {
+ cc.syncHooks.condBroadcast(cc.cond)
+ }
+ cc.cond.Broadcast()
+}
+
+// condWait is cc.cond.Wait.
+func (cc *http2ClientConn) condWait() {
+ if cc.syncHooks != nil {
+ cc.syncHooks.condWait(cc.cond)
+ }
+ cc.cond.Wait()
+}
+
+// newTimer creates a new time.Timer, or a synthetic timer in tests.
+func (cc *http2ClientConn) newTimer(d time.Duration) http2timer {
+ if cc.syncHooks != nil {
+ return cc.syncHooks.newTimer(d)
+ }
+ return http2newTimeTimer(d)
+}
+
+// afterFunc creates a new time.AfterFunc timer, or a synthetic timer in tests.
+func (cc *http2ClientConn) afterFunc(d time.Duration, f func()) http2timer {
+ if cc.syncHooks != nil {
+ return cc.syncHooks.afterFunc(d, f)
+ }
+ return http2newTimeAfterFunc(d, f)
+}
+
+func (cc *http2ClientConn) contextWithTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
+ if cc.syncHooks != nil {
+ return cc.syncHooks.contextWithTimeout(ctx, d)
+ }
+ return context.WithTimeout(ctx, d)
}
// clientStream is the state for a single HTTP/2 stream. One of these
// TODO(dneil): Clean up tests where cs.cc.cond is nil.
if cs.cc.cond != nil {
// Wake up writeRequestBody if it is waiting on flow control.
- cs.cc.cond.Broadcast()
+ cs.cc.condBroadcast()
}
}
defer cc.mu.Unlock()
if cs.reqBody != nil && cs.reqBodyClosed == nil {
cs.closeReqBodyLocked()
- cc.cond.Broadcast()
+ cc.condBroadcast()
}
}
}
cs.reqBodyClosed = make(chan struct{})
reqBodyClosed := cs.reqBodyClosed
- go func() {
+ cs.cc.goRun(func() {
cs.reqBody.Close()
close(reqBodyClosed)
- }()
+ })
}
type http2stickyErrWriter struct {
return net.JoinHostPort(host, port)
}
-var http2retryBackoffHook func(time.Duration) *time.Timer
-
-func http2backoffNewTimer(d time.Duration) *time.Timer {
- if http2retryBackoffHook != nil {
- return http2retryBackoffHook(d)
- }
- return time.NewTimer(d)
-}
-
// RoundTripOpt is like RoundTrip, but takes options.
func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) {
backoff := float64(uint(1) << (uint(retry) - 1))
backoff += backoff * (0.1 * mathrand.Float64())
d := time.Second * time.Duration(backoff)
- timer := http2backoffNewTimer(d)
+ var tm http2timer
+ if t.syncHooks != nil {
+ tm = t.syncHooks.newTimer(d)
+ t.syncHooks.blockUntil(func() bool {
+ select {
+ case <-tm.C():
+ case <-req.Context().Done():
+ default:
+ return false
+ }
+ return true
+ })
+ } else {
+ tm = http2newTimeTimer(d)
+ }
select {
- case <-timer.C:
+ case <-tm.C():
t.vlogf("RoundTrip retrying after failure: %v", roundTripErr)
continue
case <-req.Context().Done():
- timer.Stop()
+ tm.Stop()
err = req.Context().Err()
}
}
}
func (t *http2Transport) dialClientConn(ctx context.Context, addr string, singleUse bool) (*http2ClientConn, error) {
+ if t.syncHooks != nil {
+ return t.newClientConn(nil, singleUse, t.syncHooks)
+ }
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
if err != nil {
return nil, err
}
- return t.newClientConn(tconn, singleUse)
+ return t.newClientConn(tconn, singleUse, nil)
}
func (t *http2Transport) newTLSConfig(host string) *tls.Config {
}
func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
- return t.newClientConn(c, t.disableKeepAlives())
+ return t.newClientConn(c, t.disableKeepAlives(), nil)
}
-func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2ClientConn, error) {
+func (t *http2Transport) newClientConn(c net.Conn, singleUse bool, hooks *http2testSyncHooks) (*http2ClientConn, error) {
cc := &http2ClientConn{
t: t,
tconn: c,
wantSettingsAck: true,
pings: make(map[[8]byte]chan struct{}),
reqHeaderMu: make(chan struct{}, 1),
+ syncHooks: hooks,
+ }
+ if hooks != nil {
+ hooks.newclientconn(cc)
+ c = cc.tconn
}
if d := t.idleConnTimeout(); d != 0 {
cc.idleTimeout = d
return nil, cc.werr
}
- go cc.readLoop()
+ cc.goRun(cc.readLoop)
return cc, nil
}
pingTimeout := cc.t.pingTimeout()
// We don't need to periodically ping in the health check, because the readLoop of ClientConn will
// trigger the healthCheck again if there is no frame received.
- ctx, cancel := context.WithTimeout(context.Background(), pingTimeout)
+ ctx, cancel := cc.contextWithTimeout(context.Background(), pingTimeout)
defer cancel()
cc.vlogf("http2: Transport sending health check")
err := cc.Ping(ctx)
// Wait for all in-flight streams to complete or connection to close
done := make(chan struct{})
cancelled := false // guarded by cc.mu
- go func() {
+ cc.goRun(func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if cancelled {
break
}
- cc.cond.Wait()
+ cc.condWait()
}
- }()
+ })
http2shutdownEnterWaitStateHook()
select {
case <-done:
cc.mu.Lock()
// Free the goroutine above
cancelled = true
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.mu.Unlock()
return ctx.Err()
}
for _, cs := range cc.streams {
cs.abortStreamLocked(err)
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.mu.Unlock()
cc.closeConn()
}
}
func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
+ return cc.roundTrip(req, nil)
+}
+
+func (cc *http2ClientConn) roundTrip(req *Request, streamf func(*http2clientStream)) (*Response, error) {
ctx := req.Context()
cs := &http2clientStream{
cc: cc,
respHeaderRecv: make(chan struct{}),
donec: make(chan struct{}),
}
- go cs.doRequest(req)
+ cc.goRun(func() {
+ cs.doRequest(req)
+ })
waitDone := func() error {
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-cs.donec:
+ case <-ctx.Done():
+ case <-cs.reqCancel:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-cs.donec:
return nil
return err
}
+ if streamf != nil {
+ streamf(cs)
+ }
+
for {
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-cs.respHeaderRecv:
+ case <-cs.abort:
+ case <-ctx.Done():
+ case <-cs.reqCancel:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-cs.respHeaderRecv:
return handleResponseHeaders()
if cc.reqHeaderMu == nil {
panic("RoundTrip on uninitialized ClientConn") // for tests
}
+ var newStreamHook func(*http2clientStream)
+ if cc.syncHooks != nil {
+ newStreamHook = cc.syncHooks.newstream
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case cc.reqHeaderMu <- struct{}{}:
+ <-cc.reqHeaderMu
+ case <-cs.reqCancel:
+ case <-ctx.Done():
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case cc.reqHeaderMu <- struct{}{}:
case <-cs.reqCancel:
}
cc.mu.Unlock()
+ if newStreamHook != nil {
+ newStreamHook(cs)
+ }
+
// TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere?
if !cc.t.disableCompression() &&
req.Header.Get("Accept-Encoding") == "" &&
var respHeaderTimer <-chan time.Time
var respHeaderRecv chan struct{}
if d := cc.responseHeaderTimeout(); d != 0 {
- timer := time.NewTimer(d)
+ timer := cc.newTimer(d)
defer timer.Stop()
- respHeaderTimer = timer.C
+ respHeaderTimer = timer.C()
respHeaderRecv = cs.respHeaderRecv
}
// Wait until the peer half-closes its end of the stream,
// or until the request is aborted (via context, error, or otherwise),
// whichever comes first.
for {
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-cs.peerClosed:
+ case <-respHeaderTimer:
+ case <-respHeaderRecv:
+ case <-cs.abort:
+ case <-ctx.Done():
+ case <-cs.reqCancel:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-cs.peerClosed:
return nil
return nil
}
cc.pendingRequests++
- cc.cond.Wait()
+ cc.condWait()
cc.pendingRequests--
select {
case <-cs.abort:
cs.flow.take(take)
return take, nil
}
- cc.cond.Wait()
+ cc.condWait()
}
}
+func http2validateHeaders(hdrs Header) string {
+ for k, vv := range hdrs {
+ if !httpguts.ValidHeaderFieldName(k) {
+ return fmt.Sprintf("name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // Don't include the value in the error,
+ // because it may be sensitive.
+ return fmt.Sprintf("value for header %q", k)
+ }
+ }
+ }
+ return ""
+}
+
var http2errNilRequestURL = errors.New("http2: Request.URI is nil")
// requires cc.wmu be held.
}
}
- // Check for any invalid headers and return an error before we
+ // Check for any invalid headers+trailers and return an error before we
// potentially pollute our hpack state. (We want to be able to
// continue to reuse the hpack encoder for future requests)
- for k, vv := range req.Header {
- if !httpguts.ValidHeaderFieldName(k) {
- return nil, fmt.Errorf("invalid HTTP header name %q", k)
- }
- for _, v := range vv {
- if !httpguts.ValidHeaderFieldValue(v) {
- // Don't include the value in the error, because it may be sensitive.
- return nil, fmt.Errorf("invalid HTTP header value for header %q", k)
- }
- }
+ if err := http2validateHeaders(req.Header); err != "" {
+ return nil, fmt.Errorf("invalid HTTP header %s", err)
+ }
+ if err := http2validateHeaders(req.Trailer); err != "" {
+ return nil, fmt.Errorf("invalid HTTP trailer %s", err)
}
enumerateHeaders := func(f func(name, value string)) {
}
// Wake up writeRequestBody via clientStream.awaitFlowControl and
// wake up RoundTrip if there is a pending request.
- cc.cond.Broadcast()
+ cc.condBroadcast()
closeOnIdle := cc.singleUse || cc.doNotReuse || cc.t.disableKeepAlives() || cc.goAway != nil
if closeOnIdle && cc.streamsReserved == 0 && len(cc.streams) == 0 {
cs.abortStreamLocked(err)
}
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.mu.Unlock()
}
cc := rl.cc
gotSettings := false
readIdleTimeout := cc.t.ReadIdleTimeout
- var t *time.Timer
+ var t http2timer
if readIdleTimeout != 0 {
- t = time.AfterFunc(readIdleTimeout, cc.healthCheck)
- defer t.Stop()
+ t = cc.afterFunc(readIdleTimeout, cc.healthCheck)
}
for {
f, err := cc.fr.ReadFrame()
})
return nil
}
- if !cs.firstByte {
+ if !cs.pastHeaders {
cc.logf("protocol error: received DATA before a HEADERS frame")
rl.endStreamError(cs, http2StreamError{
StreamID: f.StreamID,
for _, cs := range cc.streams {
cs.flow.add(delta)
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
cc.initialWindowSize = s.Val
case http2SettingHeaderTableSize:
return http2ConnectionError(http2ErrCodeFlowControl)
}
- cc.cond.Broadcast()
+ cc.condBroadcast()
return nil
}
}
cc.mu.Unlock()
}
- errc := make(chan error, 1)
- go func() {
+ var pingError error
+ errc := make(chan struct{})
+ cc.goRun(func() {
cc.wmu.Lock()
defer cc.wmu.Unlock()
- if err := cc.fr.WritePing(false, p); err != nil {
- errc <- err
+ if pingError = cc.fr.WritePing(false, p); pingError != nil {
+ close(errc)
return
}
- if err := cc.bw.Flush(); err != nil {
- errc <- err
+ if pingError = cc.bw.Flush(); pingError != nil {
+ close(errc)
return
}
- }()
+ })
+ if cc.syncHooks != nil {
+ cc.syncHooks.blockUntil(func() bool {
+ select {
+ case <-c:
+ case <-errc:
+ case <-ctx.Done():
+ case <-cc.readerDone:
+ default:
+ return false
+ }
+ return true
+ })
+ }
select {
case <-c:
return nil
- case err := <-errc:
- return err
+ case <-errc:
+ return pingError
case <-ctx.Done():
return ctx.Err()
case <-cc.readerDone: