From: Brad Fitzpatrick Date: Thu, 27 Oct 2016 00:29:43 +0000 (+0000) Subject: net/http: update bundled http2 X-Git-Tag: go1.8beta1~554 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=07e72666ec6da774860c6f796c3adbdc57c92480;p=gostls13.git net/http: update bundled http2 Updates http2 to x/net git rev b626cca for: http2: implement support for server push https://golang.org/cl/29439 http2: reject stream self-dependencies https://golang.org/cl/31858 http2: optimize server frame writes https://golang.org/cl/31495 http2: interface to support pluggable schedulers https://golang.org/cl/25366 (no user-visible behavior change or API surface) http2: add Server.IdleTimeout https://golang.org/cl/31727 http2: make Server return conn protocol errors on bad idle stream frames https://golang.org/cl/31736 http2: fix optimized write scheduling https://golang.org/cl/32217 (fix for CL 31495 above) Change-Id: Ie894c72943d355115c8391573bf6b96dc1bd5894 Reviewed-on: https://go-review.googlesource.com/32215 Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot Reviewed-by: Tom Bergan Reviewed-by: Brad Fitzpatrick --- diff --git a/src/net/http/h2_bundle.go b/src/net/http/h2_bundle.go index 9d6d3caef6..f69623c1f5 100644 --- a/src/net/http/h2_bundle.go +++ b/src/net/http/h2_bundle.go @@ -2161,6 +2161,18 @@ func (cc *http2ClientConn) Ping(ctx context.Context) error { func http2cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } +var _ Pusher = (*http2responseWriter)(nil) + +// Push implements http.Pusher. +func (w *http2responseWriter) Push(target string, opts *PushOptions) error { + internalOpts := http2pushOptions{} + if opts != nil { + internalOpts.Method = opts.Method + internalOpts.Header = opts.Header + } + return w.push(target, internalOpts) +} + var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -2426,13 +2438,23 @@ var ( type http2streamState int +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. const ( http2stateIdle http2streamState = iota http2stateOpen http2stateHalfClosedLocal http2stateHalfClosedRemote - http2stateResvLocal - http2stateResvRemote http2stateClosed ) @@ -2441,8 +2463,6 @@ var http2stateName = [...]string{ http2stateOpen: "Open", http2stateHalfClosedLocal: "HalfClosedLocal", http2stateHalfClosedRemote: "HalfClosedRemote", - http2stateResvLocal: "ResvLocal", - http2stateResvRemote: "ResvRemote", http2stateClosed: "Closed", } @@ -2603,13 +2623,27 @@ func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { return &http2bufferedWriter{w: w} } +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + var http2bufWriterPool = sync.Pool{ New: func() interface{} { - - return bufio.NewWriterSize(nil, 4<<10) + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) }, } +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { bw := http2bufWriterPool.Get().(*bufio.Writer) @@ -2912,6 +2946,15 @@ type http2Server struct { // PermitProhibitedCipherSuites, if true, permits the use of // cipher suites prohibited by the HTTP/2 spec. PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler } func (s *http2Server) maxReadFrameSize() uint32 { @@ -3044,29 +3087,35 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { defer cancel() sc := &http2serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: http2newBufferedWriter(c), - handler: opts.handler(), - streams: make(map[uint32]*http2stream), - readFrameCh: make(chan http2readFrameResult), - wantWriteFrameCh: make(chan http2frameWriteMsg, 8), - wroteFrameCh: make(chan http2frameWriteResult, 1), - bodyReadCh: make(chan http2bodyReadMsg), - doneServing: make(chan struct{}), - advMaxStreams: s.maxConcurrentStreams(), - writeSched: http2writeScheduler{ - maxFrameSize: http2initialMaxFrameSize, - }, + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + wantStartPushCh: make(chan http2startPushRequest, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), + bodyReadCh: make(chan http2bodyReadMsg), + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, + advMaxStreams: s.maxConcurrentStreams(), initialWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, headerTableSize: http2initialHeaderTableSize, serveG: http2newGoroutineLock(), pushEnabled: true, } + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = http2NewRandomWriteScheduler() + } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -3120,16 +3169,18 @@ type http2serverConn struct { handler Handler baseCtx http2contextContext framer *http2Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan http2readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve - wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan http2bodyReadMsg // from handlers -> serve - testHookCh chan func(int) // code to run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wantStartPushCh chan http2startPushRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + testHookCh chan func(int) // code to run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string + writeSched http2WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): serveG http2goroutineLock // used to verify funcs are on serve() @@ -3139,21 +3190,27 @@ type http2serverConn struct { unackedSettings int // how many SETTINGS have we sent without ACKs? clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curOpenStreams uint32 // client's number of open streams - maxStreamID uint32 // max ever seen + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxStreamID uint32 // max ever seen from client + maxPushPromiseID uint32 // ID of the last push promise, or 0 if there have been no pushes streams map[uint32]*http2stream initialWindowSize int32 + maxFrameSize int32 headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush - writeSched http2writeScheduler - inGoAway bool // we've started to or sent GOAWAY - needToSendGoAway bool // we need to schedule a GOAWAY frame write + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode http2ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used + idleTimer *time.Timer // nil if unused + idleTimerCh <-chan time.Time // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -3357,17 +3414,17 @@ func (sc *http2serverConn) readFrames() { // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. type http2frameWriteResult struct { - wm http2frameWriteMsg // what was written (or attempted) - err error // result of the writeFrame call + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call } // writeFrameAsync runs in its own goroutine and writes a single frame // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) { - err := wm.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wm, err} +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr, err} } func (sc *http2serverConn) closeAllStreamsOnConnClose() { @@ -3411,7 +3468,7 @@ func (sc *http2serverConn) serve() { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeSettings{ {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, @@ -3428,6 +3485,12 @@ func (sc *http2serverConn) serve() { sc.setConnState(StateActive) sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + defer sc.idleTimer.Stop() + sc.idleTimerCh = sc.idleTimer.C + } + go sc.readFrames() settingsTimer := time.NewTimer(http2firstSettingsTimeout) @@ -3435,8 +3498,10 @@ func (sc *http2serverConn) serve() { for { loopNum++ select { - case wm := <-sc.wantWriteFrameCh: - sc.writeFrame(wm) + case wr := <-sc.wantWriteFrameCh: + sc.writeFrame(wr) + case spr := <-sc.wantStartPushCh: + sc.startPush(spr) case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: @@ -3456,6 +3521,9 @@ func (sc *http2serverConn) serve() { case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case <-sc.idleTimerCh: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) case fn := <-sc.testHookCh: fn(loopNum) } @@ -3506,7 +3574,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte ch := http2errChanPool.Get().(chan error) writeArg := http2writeDataPool.Get().(*http2writeData) *writeArg = http2writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(http2frameWriteMsg{ + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: writeArg, stream: stream, done: ch, @@ -3536,17 +3604,17 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte return err } -// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts // if the connection has gone away. // // This must not be run from the serve goroutine itself, else it might // deadlock writing to sc.wantWriteFrameCh (which is only mildly // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. -func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { sc.serveG.checkNotOn() select { - case sc.wantWriteFrameCh <- wm: + case sc.wantWriteFrameCh <- wr: return nil case <-sc.doneServing: @@ -3562,36 +3630,36 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { // make it onto the wire // // If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { sc.serveG.check() var ignoreWrite bool - switch wm.write.(type) { + switch wr.write.(type) { case *http2writeResHeaders: - wm.stream.wroteHeaders = true + wr.stream.wroteHeaders = true case http2write100ContinueHeadersFrame: - if wm.stream.wroteHeaders { + if wr.stream.wroteHeaders { ignoreWrite = true } } if !ignoreWrite { - sc.writeSched.add(wm) + sc.writeSched.Push(wr) } sc.scheduleFrameWrite() } -// startFrameWrite starts a goroutine to write wm (in a separate +// startFrameWrite starts a goroutine to write wr (in a separate // goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wm. -func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) { +// serve goroutine's state about the world, updated from info in wr. +func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { sc.serveG.check() if sc.writingFrame { panic("internal error: can only be writing one frame at a time") } - st := wm.stream + st := wr.stream if st != nil { switch st.state { case http2stateHalfClosedLocal: @@ -3602,13 +3670,31 @@ func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) { sc.scheduleFrameWrite() return } - panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + if wr.done != nil { + wr.done <- err + } + return } } sc.writingFrame = true sc.needsFrameFlush = true - go sc.writeFrameAsync(wm) + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr, err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } } // errHandlerPanicked is the error given to any callers blocked in a read from @@ -3624,24 +3710,25 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { panic("internal error: expected to be already writing a frame") } sc.writingFrame = false + sc.writingFrameAsync = false - wm := res.wm - st := wm.stream + wr := res.wr + st := wr.stream - closeStream := http2endsStream(wm.write) + closeStream := http2endsStream(wr.write) - if _, ok := wm.write.(http2handlerPanicRST); ok { + if _, ok := wr.write.(http2handlerPanicRST); ok { sc.closeStream(st, http2errHandlerPanicked) } - if ch := wm.done; ch != nil { + if ch := wr.done; ch != nil { select { case ch <- res.err: default: - panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write)) + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) } } - wm.write = nil + wr.write = nil if closeStream { if st == nil { @@ -3675,35 +3762,40 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // flush the write buffer. func (sc *http2serverConn) scheduleFrameWrite() { sc.serveG.check() - if sc.writingFrame { - return - } - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(http2frameWriteMsg{ - write: &http2writeGoAway{ - maxStreamID: sc.maxStreamID, - code: sc.goAwayCode, - }, - }) - return - } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}}) + if sc.writingFrame || sc.inFrameScheduleLoop { return } - if !sc.inGoAway { - if wm, ok := sc.writeSched.take(); ok { - sc.startFrameWrite(wm) - return + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxStreamID, + code: sc.goAwayCode, + }, + }) + continue } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway { + if wr, ok := sc.writeSched.Pop(); ok { + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false + continue + } + break } - if sc.needsFrameFlush { - sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}}) - sc.needsFrameFlush = false - return - } + sc.inFrameScheduleLoop = false } func (sc *http2serverConn) goAway(code http2ErrCode) { @@ -3731,7 +3823,7 @@ func (sc *http2serverConn) shutDownIn(d time.Duration) { func (sc *http2serverConn) resetStream(se http2StreamError) { sc.serveG.check() - sc.writeFrame(http2frameWriteMsg{write: se}) + sc.writeFrame(http2FrameWriteRequest{write: se}) if st, ok := sc.streams[se.StreamID]; ok { st.sentReset = true sc.closeStream(st, se) @@ -3830,15 +3922,25 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}}) + if sc.inGoAway { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) return nil } func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error { sc.serveG.check() + if sc.inGoAway { + return nil + } switch { case f.StreamID != 0: - st := sc.streams[f.StreamID] + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } if st == nil { return nil @@ -3857,6 +3959,9 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { sc.serveG.check() + if sc.inGoAway { + return nil + } state, st := sc.state(f.StreamID) if state == http2stateIdle { @@ -3877,11 +3982,18 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = http2stateClosed - sc.curOpenStreams-- - if sc.curOpenStreams == 0 { + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- + } + if sc.curClientStreams+sc.curPushedStreams == 0 { sc.setConnState(StateIdle) } delete(sc.streams, st.id) + if len(sc.streams) == 0 && sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } if p := st.body; p != nil { sc.sendWindowUpdate(nil, p.Len()) @@ -3889,7 +4001,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { p.CloseWithError(err) } st.cw.Close() - sc.writeSched.forgetStream(st.id) + sc.writeSched.CloseStream(st.id) } func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { @@ -3902,6 +4014,9 @@ func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { } return nil } + if sc.inGoAway { + return nil + } if err := f.ForeachSetting(sc.processSetting); err != nil { return err } @@ -3929,7 +4044,7 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { case http2SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) case http2SettingMaxFrameSize: - sc.writeSched.maxFrameSize = s.Val + sc.maxFrameSize = int32(s.Val) case http2SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: @@ -3958,11 +4073,18 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.serveG.check() + if sc.inGoAway { + return nil + } data := f.Data() id := f.Header().StreamID - st, ok := sc.streams[id] - if !ok || st.state != http2stateOpen || st.gotTrailerHeader { + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } + if st == nil || state != http2stateOpen || st.gotTrailerHeader { if sc.inflow.available() < int32(f.Length) { return http2streamError(id, http2ErrCodeFlowControl) @@ -4010,6 +4132,11 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { return nil } +// isPushed reports whether the stream is server-initiated. +func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + // endStream closes a Request.Body's pipe. It is called when a DATA // frame says a request body is over (or after trailers). func (st *http2stream) endStream() { @@ -4039,7 +4166,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() { func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() - id := f.Header().StreamID + id := f.StreamID if sc.inGoAway { return nil @@ -4049,8 +4176,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - st := sc.streams[f.Header().StreamID] - if st != nil { + if st := sc.streams[f.StreamID]; st != nil { return st.processTrailerHeaders(f) } @@ -4059,40 +4185,30 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { } sc.maxStreamID = id - ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) - st = &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, - ctx: ctx, - cancelCtx: cancelCtx, - } - if f.StreamEnded() { - st.state = http2stateHalfClosedRemote + if sc.idleTimer != nil { + sc.idleTimer.Stop() } - st.cw.Init() - st.flow.conn = &sc.flow - st.flow.add(sc.initialWindowSize) - st.inflow.conn = &sc.inflow - st.inflow.add(http2initialWindowSize) + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { - sc.streams[id] = st - if f.HasPriority() { - http2adjustStreamPriority(sc.streams, st.id, f.Priority) - } - sc.curOpenStreams++ - if sc.curOpenStreams == 1 { - sc.setConnState(StateActive) + return http2streamError(id, http2ErrCodeProtocol) + } + + return http2streamError(id, http2ErrCodeRefusedStream) } - if sc.curOpenStreams > sc.advMaxStreams { - if sc.unackedSettings == 0 { + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, initialState) - return http2streamError(st.id, http2ErrCodeProtocol) + if f.HasPriority() { + if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + return err } - - return http2streamError(st.id, http2ErrCodeRefusedStream) + sc.writeSched.AdjustStream(st.id, f.Priority) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -4110,7 +4226,7 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if f.Truncated { handler = http2handleHeaderListTooLong - } else if err := http2checkValidHTTP2Request(req); err != nil { + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { handler = http2new400Handler(err) } @@ -4150,90 +4266,138 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { - http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam) +func http2checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + + return http2streamError(streamID, http2ErrCodeProtocol) + } return nil } -func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) { - st, ok := streams[streamID] - if !ok { +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} - return +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") } - st.weight = priority.Weight - parent := streams[priority.StreamDep] - if parent == st { - return + ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, } + st.cw.Init() + st.flow.conn = &sc.flow + st.flow.add(sc.initialWindowSize) + st.inflow.conn = &sc.inflow + st.inflow.add(http2initialWindowSize) - for piter := parent; piter != nil; piter = piter.parent { - if piter == st { - parent.parent = st.parent - break - } + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ } - st.parent = parent - if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) { - for _, openStream := range streams { - if openStream != st && openStream.parent == st.parent { - openStream.parent = st - } - } + if sc.curClientStreams+sc.curPushedStreams == 1 { + sc.setConnState(StateActive) } + + return st } func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { sc.serveG.check() - method := f.PseudoValue("method") - path := f.PseudoValue("path") - scheme := f.PseudoValue("scheme") - authority := f.PseudoValue("authority") + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } - isConnect := method == "CONNECT" + isConnect := rp.method == "CONNECT" if isConnect { - if path != "" || scheme != "" || authority == "" { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - } else if method == "" || path == "" || - (scheme != "https" && scheme != "http") { + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } bodyOpen := !f.StreamEnded() - if method == "HEAD" && bodyOpen { + if rp.method == "HEAD" && bodyOpen { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - var tlsState *tls.ConnectionState // nil if not scheme https - if scheme == "https" { - tlsState = sc.tlsState + rp.header = make(Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") } - header := make(Header) - for _, hf := range f.RegularFields() { - header.Add(sc.canonicalHeader(hf.Name), hf.Value) + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err } + if bodyOpen { + st.reqBuf = http2getRequestBodyBuf() + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2fixedBuffer{buf: st.reqBuf}, + } + + if vv, ok := rp.header["Content-Length"]; ok { + req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + } else { + req.ContentLength = -1 + } + } + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) { + sc.serveG.check() - if authority == "" { - authority = header.Get("Host") + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState } - needsContinue := header.Get("Expect") == "100-continue" + + needsContinue := rp.header.Get("Expect") == "100-continue" if needsContinue { - header.Del("Expect") + rp.header.Del("Expect") } - if cookies := header["Cookie"]; len(cookies) > 1 { - header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer Header - for _, v := range header["Trailer"] { + for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -4247,53 +4411,42 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } } } - delete(header, "Trailer") + delete(rp.header, "Trailer") - body := &http2requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } var url_ *url.URL var requestURI string - if isConnect { - url_ = &url.URL{Host: authority} - requestURI = authority + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority } else { var err error - url_, err = url.ParseRequestURI(path) + url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) } - requestURI = path + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, } req := &Request{ - Method: method, + Method: rp.method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: header, + Header: rp.header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: authority, + Host: rp.authority, Body: body, Trailer: trailer, } req = http2requestWithContext(req, st.ctx) - if bodyOpen { - st.reqBuf = http2getRequestBodyBuf() - body.pipe = &http2pipe{ - b: &http2fixedBuffer{buf: st.reqBuf}, - } - - if vv, ok := header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) - } else { - req.ContentLength = -1 - } - } rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw @@ -4338,7 +4491,7 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) @@ -4370,7 +4523,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR errc = http2errChanPool.Get().(chan error) } - if err := sc.writeFrameFromHandler(http2frameWriteMsg{ + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: headerData, stream: st, done: errc, @@ -4393,7 +4546,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR // called from handler goroutines. func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2write100ContinueHeadersFrame{st.id}, stream: st, }) @@ -4463,7 +4616,7 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { if st != nil { streamID = st.id } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, stream: st, }) @@ -4844,74 +4997,236 @@ func (w *http2responseWriter) handlerDone() { http2responseWriterStatePool.Put(rws) } -// foreachHeaderElement splits v according to the "#rule" construction -// in RFC 2616 section 2.1 and calls fn for each non-empty element. -func http2foreachHeaderElement(v string, fn func(string)) { - v = textproto.TrimString(v) - if v == "" { - return +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +// pushOptions is the internal version of http.PushOptions, which we +// cannot include here because it's only defined in Go 1.8 and later. +type http2pushOptions struct { + Method string + Header Header +} + +func (w *http2responseWriter) push(target string, opts http2pushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + if st.isPushed() { + return http2ErrRecursivePush } - if !strings.Contains(v, ",") { - fn(v) - return + + if opts.Method == "" { + opts.Method = "GET" } - for _, f := range strings.Split(v, ",") { - if f = textproto.TrimString(f); f != "" { - fn(f) - } + if opts.Header == nil { + opts.Header = Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" } -} -// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 -var http2connHeaders = []string{ - "Connection", - "Keep-Alive", - "Proxy-Connection", - "Transfer-Encoding", - "Upgrade", -} + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include psuedo header %q", k) + } -// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, -// per RFC 7540 Section 8.1.2.2. -// The returned error is reported to users. -func http2checkValidHTTP2Request(req *Request) error { - for _, h := range http2connHeaders { - if _, ok := req.Header[h]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", h) + switch strings.ToLower(k) { + case "content-length", "content-encoding", "trailer", "te", "expect", "host": + return fmt.Errorf("promised request headers cannot include %q", k) } } - te := req.Header["Te"] - if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { - return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err } - return nil -} -func http2new400Handler(err error) HandlerFunc { - return func(w ResponseWriter, r *Request) { - Error(w, err.Error(), StatusBadRequest) + if opts.Method != "GET" && opts.Method != "HEAD" { + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) } -} -// ValidTrailerHeader reports whether name is a valid header field name to appear -// in trailers. -// See: http://tools.ietf.org/html/rfc7230#section-4.1.2 -func http2ValidTrailerHeader(name string) bool { - name = CanonicalHeaderKey(name) - if strings.HasPrefix(name, "If-") || http2badTrailer[name] { - return false + msg := http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), } - return true -} -var http2badTrailer = map[string]bool{ - "Authorization": true, - "Cache-Control": true, - "Connection": true, - "Content-Encoding": true, - "Content-Length": true, - "Content-Range": true, - "Content-Type": true, + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.wantStartPushCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header Header + done chan error +} + +func (sc *http2serverConn) startPush(msg http2startPushRequest) { + sc.serveG.check() + + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + + msg.done <- http2errStreamClosed + return + } + + if !sc.pushEnabled { + msg.done <- ErrNotSupported + return + } + + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + if !sc.pushEnabled { + return 0, ErrNotSupported + } + + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: msg.header, + }) + if err != nil { + + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + +// foreachHeaderElement splits v according to the "#rule" construction +// in RFC 2616 section 2.1 and calls fn for each non-empty element. +func http2foreachHeaderElement(v string, fn func(string)) { + v = textproto.TrimString(v) + if v == "" { + return + } + if !strings.Contains(v, ",") { + fn(v) + return + } + for _, f := range strings.Split(v, ",") { + if f = textproto.TrimString(f); f != "" { + fn(f) + } + } +} + +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", +} + +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2RequestHeaders(h Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) + } + } + te := h["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + } + return nil +} + +func http2new400Handler(err error) HandlerFunc { + return func(w ResponseWriter, r *Request) { + Error(w, err.Error(), StatusBadRequest) + } +} + +// ValidTrailerHeader reports whether name is a valid header field name to appear +// in trailers. +// See: http://tools.ietf.org/html/rfc7230#section-4.1.2 +func http2ValidTrailerHeader(name string) bool { + name = CanonicalHeaderKey(name) + if strings.HasPrefix(name, "If-") || http2badTrailer[name] { + return false + } + return true +} + +var http2badTrailer = map[string]bool{ + "Authorization": true, + "Cache-Control": true, + "Connection": true, + "Content-Encoding": true, + "Content-Length": true, + "Content-Range": true, + "Content-Type": true, "Expect": true, "Host": true, "Keep-Alive": true, @@ -6852,6 +7167,11 @@ func http2isConnectionCloseRequest(req *Request) bool { // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool } // writeContext is the interface needed by the various frame writer @@ -6894,8 +7214,16 @@ func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { return ctx.Flush() } +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + type http2writeSettings []http2Setting +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + func (s http2writeSettings) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettings([]http2Setting(s)...) } @@ -6915,6 +7243,8 @@ func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { return err } +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } + type http2writeData struct { streamID uint32 p []byte @@ -6929,6 +7259,10 @@ func (w *http2writeData) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) } +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + // handlerPanicRST is the message sent from handler goroutines when // the handler panics. type http2handlerPanicRST struct { @@ -6939,22 +7273,59 @@ func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) } +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (se http2StreamError) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + type http2writePingAck struct{ pf *http2PingFrame } func (w http2writePingAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WritePing(true, w.pf.Data) } +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + type http2writeSettingsAck struct{} func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettingsAck() } +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. +func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames // for HTTP response headers or trailers from a server handler. type http2writeResHeaders struct { @@ -6976,6 +7347,11 @@ func http2encKV(enc *hpack.Encoder, k, v string) { enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + + return false +} + func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() @@ -7001,39 +7377,69 @@ func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { panic("unexpected empty hpack") } - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. - const maxFrameSize = 16384 + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - endHeaders := len(headerBlock) == 0 - var err error - if first { - first = false - err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: endHeaders, - }) - } else { - err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) - } - if err != nil { - return err - } +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. + allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } - return nil } type http2write100ContinueHeadersFrame struct { @@ -7052,15 +7458,24 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err }) } +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + + return 9+2*(len(":status")+len("100")) <= max +} + type http2writeWindowUpdate struct { streamID uint32 // or 0 for conn-level n uint32 } +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) } +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only only if k is in keys. func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := http2sorterPool.Get().(*http2sorter) @@ -7090,14 +7505,51 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { } } -// frameWriteMsg is a request to write a frame. -type http2frameWriteMsg struct { +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { // write is the interface value that does the writing, once the - // writeScheduler (below) has decided to select this frame - // to write. The write functions are all defined in write.go. + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. write http2writeFramer - stream *http2stream // used for prioritization. nil for non-stream frames. + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + stream *http2stream // done, if non-nil, must be a buffered channel with space for // 1 message and is sent the return value from write (or an @@ -7105,247 +7557,626 @@ type http2frameWriteMsg struct { done chan error } -// for debugging only: -func (wm http2frameWriteMsg) String() string { +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + return 0 + } + return wr.stream.id +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. +// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + + endStream: false, + }, + + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 + } + + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { var streamID uint32 - if wm.stream != nil { - streamID = wm.stream.id + if wr.stream != nil { + streamID = wr.stream.id } var des string - if s, ok := wm.write.(fmt.Stringer); ok { + if s, ok := wr.write.(fmt.Stringer); ok { des = s.String() } else { - des = fmt.Sprintf("%T", wm.write) + des = fmt.Sprintf("%T", wr.write) } - return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", streamID, wr.done != nil, des) } -// writeScheduler tracks pending frames to write, priorities, and decides -// the next one to use. It is not thread-safe. -type http2writeScheduler struct { - // zero are frames not associated with a specific stream. - // They're sent before any stream-specific freams. - zero http2writeQueue +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} - // maxFrameSize is the maximum size of a DATA frame - // we'll write. Must be non-zero and between 16K-16M. - maxFrameSize uint32 +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } - // sq contains the stream-specific queues, keyed by stream ID. - // when a stream is idle, it's deleted from the map. - sq map[uint32]*http2writeQueue +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} - // canSend is a slice of memory that's reused between frame - // scheduling decisions to hold the list of writeQueues (from sq) - // which have enough flow control data to send. After canSend is - // built, the best is selected. - canSend []*http2writeQueue +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] - // pool of empty queues for reuse. - queuePool []*http2writeQueue + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr } -func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) { - if len(q.s) != 0 { - panic("queue must be empty") +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false } - ws.queuePool = append(ws.queuePool, q) + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true +} + +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) } -func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue { - ln := len(ws.queuePool) +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) if ln == 0 { return new(http2writeQueue) } - q := ws.queuePool[ln-1] - ws.queuePool = ws.queuePool[:ln-1] + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] return q } -func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 -func (ws *http2writeScheduler) add(wm http2frameWriteMsg) { - st := wm.stream - if st == nil { - ws.zero.push(wm) +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7340 Section 5.3. +// If cfg is nil, default options are used. +func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 } else { - ws.streamQueue(st.id).push(wm) + ws.writeThrottleLimit = math.MaxInt32 } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. + parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings } -func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue { - if q, ok := ws.sq[streamID]; ok { - return q +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") + } + if n.parent == parent { + return } - if ws.sq == nil { - ws.sq = make(map[uint32]*http2writeQueue) + + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n } - q := ws.getEmptyQueue() - ws.sq[streamID] = q - return q } -// take returns the most important frame to write and removes it from the scheduler. -// It is illegal to call this if the scheduler is empty or if there are no connection-level -// flow control bytes available. -func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) { - if ws.maxFrameSize == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b } +} - if !ws.zero.empty() { - return ws.zero.shift(), true +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this funcion returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). +func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true } - if len(ws.sq) == 0 { - return + if n.kids == nil { + return false } - for id, q := range ws.sq { - if q.firstIsNoCost() { - return ws.takeFrom(id, q) + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) + } + + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break + } + } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } } + return false } - if len(ws.canSend) != 0 { - panic("should be empty") + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) } - for _, q := range ws.sq { - if n := ws.streamWritableBytes(q); n > 0 { - ws.canSend = append(ws.canSend, q) + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true } } - if len(ws.canSend) == 0 { - return + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk } - defer ws.zeroCanSend() + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} + +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. + root http2priorityNode + + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 - q := ws.canSend[0] + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode - return ws.takeFrom(q.streamID(), q) + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// zeroCanSend is defered from take. -func (ws *http2writeScheduler) zeroCanSend() { - for i := range ws.canSend { - ws.canSend[i] = nil +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID } - ws.canSend = ws.canSend[:0] } -// streamWritableBytes returns the number of DATA bytes we could write -// from the given queue's stream, if this stream/queue were -// selected. It is an error to call this if q's head isn't a -// *writeData. -func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 { - wm := q.head() - ret := wm.stream.flow.available() - if ret == 0 { - return 0 +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") } - if int32(ws.maxFrameSize) < ret { - ret = int32(ws.maxFrameSize) + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) } - if ret == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) } - wd := wm.write.(*http2writeData) - if len(wd.p) < int(ret) { - ret = int32(len(wd.p)) + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) } - return ret } -func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) { - wm = q.head() - - if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 { - allowed := wm.stream.flow.available() - if allowed == 0 { +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } - return http2frameWriteMsg{}, false + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return } - if int32(ws.maxFrameSize) < allowed { - allowed = int32(ws.maxFrameSize) + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, } + n.setParent(&ws.root) + ws.nodes[streamID] = n + ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } - if len(wd.p) > int(allowed) { - wm.stream.flow.take(allowed) - chunk := wd.p[:allowed] - wd.p = wd.p[allowed:] + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } - return http2frameWriteMsg{ - stream: wm.stream, - write: &http2writeData{ - streamID: wd.streamID, - p: chunk, + if n == parent { + return + } - endStream: false, - }, + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } - done: nil, - }, true + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next } - wm.stream.flow.take(int32(len(wd.p))) } - q.shift() - if q.empty() { - ws.putEmptyQueue(q) - delete(ws.sq, id) + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + panic("add on non-open stream") + } } - return wm, true + n.q.push(wr) } -func (ws *http2writeScheduler) forgetStream(id uint32) { - q, ok := ws.sq[id] - if !ok { +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list *[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { return } - delete(ws.sq, id) + if len(*list) == maxSize { - for i := range q.s { - q.s[i] = http2frameWriteMsg{} + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] } - q.s = q.s[:0] - ws.putEmptyQueue(q) + *list = append(*list, n) } -type http2writeQueue struct { - s []http2frameWriteMsg +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) } -// streamID returns the stream ID for a non-empty stream-specific queue. -func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id } +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} -func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle or closed, it's deleted from the map. + sq map[uint32]*http2writeQueue -func (q *http2writeQueue) push(wm http2frameWriteMsg) { - q.s = append(q.s, wm) + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// head returns the next item that would be removed by shift. -func (q *http2writeQueue) head() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") - } - return q.s[0] +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + } -func (q *http2writeQueue) shift() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return } - wm := q.s[0] + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = http2frameWriteMsg{} - q.s = q.s[:len(q.s)-1] - return wm } -func (q *http2writeQueue) firstIsNoCost() bool { - if df, ok := q.s[0].write.(*http2writeData); ok { - return len(df.p) == 0 +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + id := wr.StreamID() + if id == 0 { + ws.zero.push(wr) + return } - return true + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + + if !ws.zero.empty() { + return ws.zero.shift(), true + } + + for _, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + return wr, true + } + } + return http2FrameWriteRequest{}, false }