]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptest: add EnableHTTP2 to Server
authorEmmanuel T Odeke <emmanuel@orijtech.com>
Wed, 16 Oct 2019 23:05:24 +0000 (16:05 -0700)
committerEmmanuel Odeke <emm.odeke@gmail.com>
Fri, 18 Oct 2019 19:29:10 +0000 (19:29 +0000)
Adds a knob EnableHTTP2, that enables an unstarted
Server and its respective client to speak HTTP/2,
but only after StartTLS has been invoked.

Fixes #34939

Change-Id: I287c568b8708a4d3c03e7d9eca7c323b8f4c65b6
Reviewed-on: https://go-review.googlesource.com/c/go/+/201557
Run-TryBot: Emmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/http/httptest/example_test.go
src/net/http/httptest/server.go
src/net/http/httptest/server_test.go

index e3d392130e0fa66554d440a6748d1addfb08f643..54e77dbb84c6882bb7540e7711720b70b936e857 100644 (file)
@@ -55,6 +55,28 @@ func ExampleServer() {
        // Output: Hello, client
 }
 
+func ExampleServer_hTTP2() {
+       ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+               fmt.Fprintf(w, "Hello, %s", r.Proto)
+       }))
+       ts.EnableHTTP2 = true
+       ts.StartTLS()
+       defer ts.Close()
+
+       res, err := ts.Client().Get(ts.URL)
+       if err != nil {
+               log.Fatal(err)
+       }
+       greeting, err := ioutil.ReadAll(res.Body)
+       res.Body.Close()
+       if err != nil {
+               log.Fatal(err)
+       }
+       fmt.Printf("%s", greeting)
+
+       // Output: Hello, HTTP/2.0
+}
+
 func ExampleNewTLSServer() {
        ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
                fmt.Fprintln(w, "Hello, client")
index b4e2e9266e685db3a3e48b1cabce041e60812b37..65165d9eb3272a7fcaf71daca0e6575681abcf80 100644 (file)
@@ -27,6 +27,11 @@ type Server struct {
        URL      string // base URL of form http://ipaddr:port with no trailing slash
        Listener net.Listener
 
+       // EnableHTTP2 controls whether HTTP/2 is enabled
+       // on the server. It must be set between calling
+       // NewUnstartedServer and calling Server.StartTLS.
+       EnableHTTP2 bool
+
        // TLS is the optional TLS configuration, populated with a new config
        // after TLS is started. If set on an unstarted server before StartTLS
        // is called, existing fields are copied into the new config.
@@ -151,7 +156,11 @@ func (s *Server) StartTLS() {
                s.TLS = new(tls.Config)
        }
        if s.TLS.NextProtos == nil {
-               s.TLS.NextProtos = []string{"http/1.1"}
+               nextProtos := []string{"http/1.1"}
+               if s.EnableHTTP2 {
+                       nextProtos = []string{"h2"}
+               }
+               s.TLS.NextProtos = nextProtos
        }
        if len(s.TLS.Certificates) == 0 {
                s.TLS.Certificates = []tls.Certificate{cert}
@@ -166,6 +175,7 @@ func (s *Server) StartTLS() {
                TLSClientConfig: &tls.Config{
                        RootCAs: certpool,
                },
+               ForceAttemptHTTP2: s.EnableHTTP2,
        }
        s.Listener = tls.NewListener(s.Listener, s.TLS)
        s.URL = "https://" + s.Listener.Addr().String()
index 8ab50cdb0abd39a72074823aa1ba20086f186d92..0aad15c5ed2129b4fca23fd122bd0a2ecb8f71a2 100644 (file)
@@ -202,3 +202,39 @@ func TestServerZeroValueClose(t *testing.T) {
 
        ts.Close() // tests that it doesn't panic
 }
+
+func TestTLSServerWithHTTP2(t *testing.T) {
+       modes := []struct {
+               name      string
+               wantProto string
+       }{
+               {"http1", "HTTP/1.1"},
+               {"http2", "HTTP/2.0"},
+       }
+
+       for _, tt := range modes {
+               t.Run(tt.name, func(t *testing.T) {
+                       cst := NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+                               w.Header().Set("X-Proto", r.Proto)
+                       }))
+
+                       switch tt.name {
+                       case "http2":
+                               cst.EnableHTTP2 = true
+                               cst.StartTLS()
+                       default:
+                               cst.Start()
+                       }
+
+                       defer cst.Close()
+
+                       res, err := cst.Client().Get(cst.URL)
+                       if err != nil {
+                               t.Fatalf("Failed to make request: %v", err)
+                       }
+                       if g, w := res.Header.Get("X-Proto"), tt.wantProto; g != w {
+                               t.Fatalf("X-Proto header mismatch:\n\tgot:  %q\n\twant: %q", g, w)
+                       }
+               })
+       }
+}