]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Server BaseContext & ConnContext fields to control early context
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 15 Mar 2019 14:58:55 +0000 (14:58 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 15 Apr 2019 19:15:11 +0000 (19:15 +0000)
Fixes golang/go#30694

Change-Id: I12a0a870e4aee6576e879d88a4868666ef448298
Reviewed-on: https://go-review.googlesource.com/c/go/+/167681
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: JP Sugarbroad <jpsugar@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/serve_test.go
src/net/http/server.go

index ea6d7c2fda2a1a9f04790e37e0c0310d218c4506..f10a4272abb4423dc83c74eba07e8f56957c1148 100644 (file)
@@ -6034,6 +6034,43 @@ func TestStripPortFromHost(t *testing.T) {
        }
 }
 
+func TestServerContexts(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+       type baseKey struct{}
+       type connKey struct{}
+       ch := make(chan context.Context, 1)
+       ts := httptest.NewUnstartedServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+               ch <- r.Context()
+       }))
+       ts.Config.BaseContext = func(ln net.Listener) context.Context {
+               if strings.Contains(reflect.TypeOf(ln).String(), "onceClose") {
+                       t.Errorf("unexpected onceClose listener type %T", ln)
+               }
+               return context.WithValue(context.Background(), baseKey{}, "base")
+       }
+       ts.Config.ConnContext = func(ctx context.Context, c net.Conn) context.Context {
+               if got, want := ctx.Value(baseKey{}), "base"; got != want {
+                       t.Errorf("in ConnContext, base context key = %#v; want %q", got, want)
+               }
+               return context.WithValue(ctx, connKey{}, "conn")
+       }
+       ts.Start()
+       defer ts.Close()
+       res, err := ts.Client().Get(ts.URL)
+       if err != nil {
+               t.Fatal(err)
+       }
+       res.Body.Close()
+       ctx := <-ch
+       if got, want := ctx.Value(baseKey{}), "base"; got != want {
+               t.Errorf("base context key = %#v; want %q", got, want)
+       }
+       if got, want := ctx.Value(connKey{}), "conn"; got != want {
+               t.Errorf("conn context key = %#v; want %q", got, want)
+       }
+}
+
 func BenchmarkResponseStatusLine(b *testing.B) {
        b.ReportAllocs()
        b.RunParallel(func(pb *testing.PB) {
index 14f74285c12adc8fb10117d6c5835376a4fa6cfd..bc6d93bce096a3807d1be56f58f4b3fab11d3d54 100644 (file)
@@ -2542,6 +2542,20 @@ type Server struct {
        // If nil, logging is done via the log package's standard logger.
        ErrorLog *log.Logger
 
+       // BaseContext optionally specifies a function that returns
+       // the base context for incoming requests on this server.
+       // The provided Listener is the specific Listener that's
+       // about to start accepting requests.
+       // If BaseContext is nil, the default is context.Background().
+       // If non-nil, it must return a non-nil context.
+       BaseContext func(net.Listener) context.Context
+
+       // ConnContext optionally specifies a function that modifies
+       // the context used for a newly connection c. The provided ctx
+       // is derived from the base context and has a ServerContextKey
+       // value.
+       ConnContext func(ctx context.Context, c net.Conn) context.Context
+
        disableKeepAlives int32     // accessed atomically.
        inShutdown        int32     // accessed atomically (non-zero means we're in Shutdown)
        nextProtoOnce     sync.Once // guards setupHTTP2_* init
@@ -2838,6 +2852,7 @@ func (srv *Server) Serve(l net.Listener) error {
                fn(srv, l) // call hook with unwrapped listener
        }
 
+       origListener := l
        l = &onceCloseListener{Listener: l}
        defer l.Close()
 
@@ -2850,8 +2865,16 @@ func (srv *Server) Serve(l net.Listener) error {
        }
        defer srv.trackListener(&l, false)
 
-       var tempDelay time.Duration     // how long to sleep on accept failure
-       baseCtx := context.Background() // base is always background, per Issue 16220
+       var tempDelay time.Duration // how long to sleep on accept failure
+
+       baseCtx := context.Background()
+       if srv.BaseContext != nil {
+               baseCtx = srv.BaseContext(origListener)
+               if baseCtx == nil {
+                       panic("BaseContext returned a nil context")
+               }
+       }
+
        ctx := context.WithValue(baseCtx, ServerContextKey, srv)
        for {
                rw, e := l.Accept()
@@ -2876,6 +2899,12 @@ func (srv *Server) Serve(l net.Listener) error {
                        }
                        return e
                }
+               if cc := srv.ConnContext; cc != nil {
+                       ctx = cc(ctx, rw)
+                       if ctx == nil {
+                               panic("ConnContext returned nil")
+                       }
+               }
                tempDelay = 0
                c := srv.newConn(rw)
                c.setState(c.rwc, StateNew) // before Serve can return