]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: update bundled x/net/http2 for Server context changes
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 18 May 2016 17:59:12 +0000 (17:59 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 18 May 2016 21:20:10 +0000 (21:20 +0000)
Updates x/net/http2 to golang.org/cl/23220
(http2: with Go 1.7 set Request.Context in ServeHTTP handlers)

Fixes #15134

Change-Id: I73bac2601118614528f051e85dab51dc48e74f41
Reviewed-on: https://go-review.googlesource.com/23221
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
src/net/http/h2_bundle.go
src/net/http/serve_test.go

index 21b10355a9d5b0a0e0e05f5d4fb3c8adb12faad0..22047c582634c80993aae968d78864aac382b8a0 100644 (file)
@@ -1974,6 +1974,27 @@ func http2summarizeFrame(f http2Frame) string {
        return buf.String()
 }
 
+type http2contextContext interface {
+       context.Context
+}
+
+func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) {
+       ctx, cancel = context.WithCancel(context.Background())
+       ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr())
+       if hs := opts.baseConfig(); hs != nil {
+               ctx = context.WithValue(ctx, ServerContextKey, hs)
+       }
+       return
+}
+
+func http2contextWithCancel(ctx http2contextContext) (_ http2contextContext, cancel func()) {
+       return context.WithCancel(ctx)
+}
+
+func http2requestWithContext(req *Request, ctx http2contextContext) *Request {
+       return req.WithContext(ctx)
+}
+
 type http2clientTrace httptrace.ClientTrace
 
 func http2reqContext(r *Request) context.Context { return r.Context() }
@@ -2994,10 +3015,14 @@ func (o *http2ServeConnOpts) handler() Handler {
 //
 // The opts parameter is optional. If nil, default values are used.
 func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
+       baseCtx, cancel := http2serverConnBaseContext(c, opts)
+       defer cancel()
+
        sc := &http2serverConn{
                srv:              s,
                hs:               opts.baseConfig(),
                conn:             c,
+               baseCtx:          baseCtx,
                remoteAddrStr:    c.RemoteAddr().String(),
                bw:               http2newBufferedWriter(c),
                handler:          opts.handler(),
@@ -3016,6 +3041,7 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) {
                serveG:            http2newGoroutineLock(),
                pushEnabled:       true,
        }
+
        sc.flow.add(http2initialWindowSize)
        sc.inflow.add(http2initialWindowSize)
        sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
@@ -3088,6 +3114,7 @@ type http2serverConn struct {
        conn             net.Conn
        bw               *http2bufferedWriter // writing to conn
        handler          Handler
+       baseCtx          http2contextContext
        framer           *http2Framer
        doneServing      chan struct{}              // closed when serverConn.serve ends
        readFrameCh      chan http2readFrameResult  // written by serverConn.readFrames
@@ -3151,10 +3178,12 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 {
 // responseWriter's state field.
 type http2stream struct {
        // immutable:
-       sc   *http2serverConn
-       id   uint32
-       body *http2pipe       // non-nil if expecting DATA frames
-       cw   http2closeWaiter // closed wait stream transitions to closed state
+       sc        *http2serverConn
+       id        uint32
+       body      *http2pipe       // non-nil if expecting DATA frames
+       cw        http2closeWaiter // closed wait stream transitions to closed state
+       ctx       http2contextContext
+       cancelCtx func()
 
        // owned by serverConn's serve loop:
        bodyBytes        int64        // body bytes seen so far
@@ -3818,6 +3847,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
        }
        if st != nil {
                st.gotReset = true
+               st.cancelCtx()
                sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode})
        }
        return nil
@@ -3997,10 +4027,13 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error {
        }
        sc.maxStreamID = id
 
+       ctx, cancelCtx := http2contextWithCancel(sc.baseCtx)
        st = &http2stream{
-               sc:    sc,
-               id:    id,
-               state: http2stateOpen,
+               sc:        sc,
+               id:        id,
+               state:     http2stateOpen,
+               ctx:       ctx,
+               cancelCtx: cancelCtx,
        }
        if f.StreamEnded() {
                st.state = http2stateHalfClosedRemote
@@ -4208,6 +4241,7 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead
                Body:       body,
                Trailer:    trailer,
        }
+       req = http2requestWithContext(req, st.ctx)
        if bodyOpen {
 
                buf := make([]byte, http2initialWindowSize)
@@ -4250,6 +4284,7 @@ func (sc *http2serverConn) getRequestBodyBuf() []byte {
 func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
        didPanic := true
        defer func() {
+               rw.rws.stream.cancelCtx()
                if didPanic {
                        e := recover()
                        // Same as net/http:
index e398c92638eb24acd326f79156ece7864ae35cc7..c32ff299029a49ed74be0f48dbb6a7bec5beac6c 100644 (file)
@@ -4064,10 +4064,16 @@ func TestServerValidatesHeaders(t *testing.T) {
        }
 }
 
-func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
+func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) {
+       testServerRequestContextCancel_ServeHTTPDone(t, h1Mode)
+}
+func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) {
+       testServerRequestContextCancel_ServeHTTPDone(t, h2Mode)
+}
+func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) {
        defer afterTest(t)
        ctxc := make(chan context.Context, 1)
-       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+       cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
                ctx := r.Context()
                select {
                case <-ctx.Done():
@@ -4076,8 +4082,8 @@ func TestServerRequestContextCancel_ServeHTTPDone(t *testing.T) {
                }
                ctxc <- ctx
        }))
-       defer ts.Close()
-       res, err := Get(ts.URL)
+       defer cst.close()
+       res, err := cst.c.Get(cst.ts.URL)
        if err != nil {
                t.Fatal(err)
        }
@@ -4130,9 +4136,15 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) {
        }
 }
 
-func TestServerContext_ServerContextKey(t *testing.T) {
+func TestServerContext_ServerContextKey_h1(t *testing.T) {
+       testServerContext_ServerContextKey(t, h1Mode)
+}
+func TestServerContext_ServerContextKey_h2(t *testing.T) {
+       testServerContext_ServerContextKey(t, h2Mode)
+}
+func testServerContext_ServerContextKey(t *testing.T, h2 bool) {
        defer afterTest(t)
-       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+       cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
                ctx := r.Context()
                got := ctx.Value(ServerContextKey)
                if _, ok := got.(*Server); !ok {
@@ -4140,12 +4152,14 @@ func TestServerContext_ServerContextKey(t *testing.T) {
                }
 
                got = ctx.Value(LocalAddrContextKey)
-               if _, ok := got.(net.Addr); !ok {
+               if addr, ok := got.(net.Addr); !ok {
                        t.Errorf("local addr value = %T; want net.Addr", got)
+               } else if fmt.Sprint(addr) != r.Host {
+                       t.Errorf("local addr = %v; want %v", addr, r.Host)
                }
        }))
-       defer ts.Close()
-       res, err := Get(ts.URL)
+       defer cst.close()
+       res, err := cst.c.Get(cst.ts.URL)
        if err != nil {
                t.Fatal(err)
        }