]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Protocols field to Server and Transport
authorDamien Neil <dneil@google.com>
Wed, 29 May 2024 16:24:20 +0000 (09:24 -0700)
committerDamien Neil <dneil@google.com>
Tue, 5 Nov 2024 22:14:59 +0000 (22:14 +0000)
Support configuring which HTTP version(s) a server or client use
via an explicit set of protocols. The Protocols field takes
precedence over TLSNextProto and ForceAttemptHTTP2.

Fixes #67814

Change-Id: I09ece88f78ad4d98ca1f213157b5f62ae11e063f
Reviewed-on: https://go-review.googlesource.com/c/go/+/607496
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Jonathan Amsterdam <jba@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
api/next/67814.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/http/67814.md [new file with mode: 0644]
src/net/http/example_test.go
src/net/http/http.go
src/net/http/http_test.go
src/net/http/server.go
src/net/http/transport.go
src/net/http/transport_test.go

diff --git a/api/next/67814.txt b/api/next/67814.txt
new file mode 100644 (file)
index 0000000..05f5391
--- /dev/null
@@ -0,0 +1,8 @@
+pkg net/http, method (*Protocols) SetHTTP1(bool) #67814
+pkg net/http, method (*Protocols) SetHTTP2(bool) #67814
+pkg net/http, method (Protocols) String() string #67814
+pkg net/http, method (Protocols) HTTP1() bool #67814
+pkg net/http, method (Protocols) HTTP2() bool #67814
+pkg net/http, type Protocols struct #67814
+pkg net/http, type Server struct, Protocols *Protocols #67814
+pkg net/http, type Transport struct, Protocols *Protocols #67814
diff --git a/doc/next/6-stdlib/99-minor/net/http/67814.md b/doc/next/6-stdlib/99-minor/net/http/67814.md
new file mode 100644 (file)
index 0000000..902664d
--- /dev/null
@@ -0,0 +1,2 @@
+The new [Server.Protocols] and [Transport.Protocols] fields provide
+a simple way to configure what HTTP protocols a server or client use.
index 2f411d1d2ec23b3131567bc27bb9a614f08c92d0..f40273f14a2a24bc7b30ebbda4febdbae48142fb 100644 (file)
@@ -193,3 +193,31 @@ func ExampleNotFoundHandler() {
 
        log.Fatal(http.ListenAndServe(":8080", mux))
 }
+
+func ExampleProtocols_http1() {
+       srv := http.Server{
+               Addr: ":8443",
+       }
+
+       // Serve only HTTP/1.
+       srv.Protocols = new(http.Protocols)
+       srv.Protocols.SetHTTP1(true)
+
+       log.Fatal(srv.ListenAndServeTLS("cert.pem", "key.pem"))
+}
+
+func ExampleProtocols_http1or2() {
+       t := http.DefaultTransport.(*http.Transport).Clone()
+
+       // Use either HTTP/1 and HTTP/2.
+       t.Protocols = new(http.Protocols)
+       t.Protocols.SetHTTP1(true)
+       t.Protocols.SetHTTP2(true)
+
+       cli := &http.Client{Transport: t}
+       res, err := cli.Get("http://www.google.com/robots.txt")
+       if err != nil {
+               log.Fatal(err)
+       }
+       res.Body.Close()
+}
index 9dfc36c791323702f28181d239084dcbfcd2f3b1..55f518607d2808ece56cdc79e55d12a6781b3d94 100644 (file)
@@ -16,6 +16,54 @@ import (
        "golang.org/x/net/http/httpguts"
 )
 
+// Protocols is a set of HTTP protocols.
+//
+// The supported protocols are:
+//
+//   - HTTP1 is the HTTP/1.0 and HTTP/1.1 protocols.
+//     HTTP1 is supported on both unsecured TCP and secured TLS connections.
+//
+//   - HTTP2 is the HTTP/2 protcol over a TLS connection.
+type Protocols struct {
+       bits uint8
+}
+
+const (
+       protoHTTP1 = 1 << iota
+       protoHTTP2
+)
+
+// HTTP1 reports whether p includes HTTP/1.
+func (p Protocols) HTTP1() bool { return p.bits&protoHTTP1 != 0 }
+
+// SetHTTP1 adds or removes HTTP/1 from p.
+func (p *Protocols) SetHTTP1(ok bool) { p.setBit(protoHTTP1, ok) }
+
+// HTTP2 reports whether p includes HTTP/2.
+func (p Protocols) HTTP2() bool { return p.bits&protoHTTP2 != 0 }
+
+// SetHTTP2 adds or removes HTTP/2 from p.
+func (p *Protocols) SetHTTP2(ok bool) { p.setBit(protoHTTP2, ok) }
+
+func (p *Protocols) setBit(bit uint8, ok bool) {
+       if ok {
+               p.bits |= bit
+       } else {
+               p.bits &^= bit
+       }
+}
+
+func (p Protocols) String() string {
+       var s []string
+       if p.HTTP1() {
+               s = append(s, "HTTP1")
+       }
+       if p.HTTP2() {
+               s = append(s, "HTTP2")
+       }
+       return "{" + strings.Join(s, ",") + "}"
+}
+
 // incomparable is a zero-width, non-comparable type. Adding it to a struct
 // makes that struct also non-comparable, and generally doesn't add
 // any size (as long as it's first).
index 777634bbb25e0c9e489a0b8702a06a719165cb43..5aba3ed5a6e5363b4a9db91cb9b642b7a037c45b 100644 (file)
@@ -187,6 +187,28 @@ func TestNoUnicodeStrings(t *testing.T) {
        }
 }
 
+func TestProtocols(t *testing.T) {
+       var p Protocols
+       if p.HTTP1() {
+               t.Errorf("zero-value protocols: p.HTTP1() = true, want false")
+       }
+       p.SetHTTP1(true)
+       p.SetHTTP2(true)
+       if !p.HTTP1() {
+               t.Errorf("initialized protocols: p.HTTP1() = false, want true")
+       }
+       if !p.HTTP2() {
+               t.Errorf("initialized protocols: p.HTTP2() = false, want true")
+       }
+       p.SetHTTP1(false)
+       if p.HTTP1() {
+               t.Errorf("after unsetting HTTP1: p.HTTP1() = true, want false")
+       }
+       if !p.HTTP2() {
+               t.Errorf("after unsetting HTTP1: p.HTTP2() = false, want true")
+       }
+}
+
 const redirectURL = "/thisaredirect细雪withasciilettersのけぶabcdefghijk.html"
 
 func BenchmarkHexEscapeNonASCII(b *testing.B) {
index db44e7c5c241fe3eecb27d7181efa9bd0ca64a59..2c9774a7a581fad6fd310de02e17828e85883bec 100644 (file)
@@ -2979,6 +2979,13 @@ type Server struct {
        // See https://go.dev/issue/67813.
        HTTP2 *HTTP2Config
 
+       // Protocols is the set of protocols accepted by the server.
+       //
+       // If Protocols is nil, the default is usually HTTP/1 and HTTP/2.
+       // If TLSNextProto is non-nil and does not contain an "h2" entry,
+       // the default is HTTP/1 only.
+       Protocols *Protocols
+
        inShutdown atomic.Bool // true when server is in shutdown
 
        disableKeepAlives atomic.Bool
@@ -3389,9 +3396,7 @@ func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
        }
 
        config := cloneTLSConfig(s.TLSConfig)
-       if !slices.Contains(config.NextProtos, "http/1.1") {
-               config.NextProtos = append(config.NextProtos, "http/1.1")
-       }
+       config.NextProtos = adjustNextProtos(config.NextProtos, s.protocols())
 
        configHasCert := len(config.Certificates) > 0 || config.GetCertificate != nil || config.GetConfigForClient != nil
        if !configHasCert || certFile != "" || keyFile != "" {
@@ -3407,6 +3412,59 @@ func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
        return s.Serve(tlsListener)
 }
 
+func (s *Server) protocols() Protocols {
+       if s.Protocols != nil {
+               return *s.Protocols // user-configured set
+       }
+
+       // The historic way of disabling HTTP/2 is to set TLSNextProto to
+       // a non-nil map with no "h2" entry.
+       _, hasH2 := s.TLSNextProto["h2"]
+       http2Disabled := s.TLSNextProto != nil && !hasH2
+
+       // If GODEBUG=http2server=0, then HTTP/2 is disabled unless
+       // the user has manually added an "h2" entry to TLSNextProto
+       // (probably by using x/net/http2 directly).
+       if http2server.Value() == "0" && !hasH2 {
+               http2Disabled = true
+       }
+
+       var p Protocols
+       p.SetHTTP1(true) // default always includes HTTP/1
+       if !http2Disabled {
+               p.SetHTTP2(true)
+       }
+       return p
+}
+
+// adjustNextProtos adds or removes "http/1.1" and "h2" entries from
+// a tls.Config.NextProtos list, according to the set of protocols in protos.
+func adjustNextProtos(nextProtos []string, protos Protocols) []string {
+       var have Protocols
+       nextProtos = slices.DeleteFunc(nextProtos, func(s string) bool {
+               switch s {
+               case "http/1.1":
+                       if !protos.HTTP1() {
+                               return true
+                       }
+                       have.SetHTTP1(true)
+               case "h2":
+                       if !protos.HTTP2() {
+                               return true
+                       }
+                       have.SetHTTP2(true)
+               }
+               return false
+       })
+       if protos.HTTP2() && !have.HTTP2() {
+               nextProtos = append(nextProtos, "h2")
+       }
+       if protos.HTTP1() && !have.HTTP1() {
+               nextProtos = append(nextProtos, "http/1.1")
+       }
+       return nextProtos
+}
+
 // trackListener adds or removes a net.Listener to the set of tracked
 // listeners.
 //
@@ -3600,16 +3658,21 @@ func (s *Server) onceSetNextProtoDefaults() {
        if omitBundledHTTP2 {
                return
        }
+       if !s.protocols().HTTP2() {
+               return
+       }
        if http2server.Value() == "0" {
                http2server.IncNonDefault()
                return
        }
-       // Enable HTTP/2 by default if the user hasn't otherwise
-       // configured their TLSNextProto map.
-       if s.TLSNextProto == nil {
-               conf := &http2Server{}
-               s.nextProtoErr = http2ConfigureServer(s, conf)
+       if _, ok := s.TLSNextProto["h2"]; ok {
+               // TLSNextProto already contains an HTTP/2 implementation.
+               // The user probably called golang.org/x/net/http2.ConfigureServer
+               // to add it.
+               return
        }
+       conf := &http2Server{}
+       s.nextProtoErr = http2ConfigureServer(s, conf)
 }
 
 // TimeoutHandler returns a [Handler] that runs h with the given time limit.
index c980e727a66b546bd3d892a6e4e3b36e93024d61..a42533d2d51e317b749f6ad80723ce5f901a235b 100644 (file)
@@ -76,8 +76,7 @@ const DefaultMaxIdleConnsPerHost = 2
 // Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2
 // for HTTPS URLs, depending on whether the server supports HTTP/2,
 // and how the Transport is configured. The [DefaultTransport] supports HTTP/2.
-// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2
-// and call ConfigureTransport. See the package docs for more about HTTP/2.
+// To explicitly enable HTTP/2 on a transport, set [Transport.Protocols].
 //
 // Responses with status codes in the 1xx range are either handled
 // automatically (100 expect-continue) or ignored. The one
@@ -300,6 +299,13 @@ type Transport struct {
        // This field does not yet have any effect.
        // See https://go.dev/issue/67813.
        HTTP2 *HTTP2Config
+
+       // Protocols is the set of protocols supported by the transport.
+       //
+       // If Protocols is nil, the default is usually HTTP/1 only.
+       // If ForceAttemptHTTP2 is true, or if TLSNextProto contains an "h2" entry,
+       // the default is HTTP/1 and HTTP/2.
+       Protocols *Protocols
 }
 
 func (t *Transport) writeBufferSize() int {
@@ -349,6 +355,10 @@ func (t *Transport) Clone() *Transport {
                t2.HTTP2 = &HTTP2Config{}
                *t2.HTTP2 = *t.HTTP2
        }
+       if t.Protocols != nil {
+               t2.Protocols = &Protocols{}
+               *t2.Protocols = *t.Protocols
+       }
        if !t.tlsNextProtoWasNil {
                npm := maps.Clone(t.TLSNextProto)
                if npm == nil {
@@ -399,18 +409,8 @@ func (t *Transport) onceSetNextProtoDefaults() {
                }
        }
 
-       if t.TLSNextProto != nil {
-               // This is the documented way to disable http2 on a
-               // Transport.
-               return
-       }
-       if !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()) {
-               // Be conservative and don't automatically enable
-               // http2 if they've specified a custom TLS config or
-               // custom dialers. Let them opt-in themselves via
-               // http2.ConfigureTransport so we don't surprise them
-               // by modifying their tls.Config. Issue 14275.
-               // However, if ForceAttemptHTTP2 is true, it overrides the above checks.
+       protocols := t.protocols()
+       if !protocols.HTTP2() {
                return
        }
        if omitBundledHTTP2 {
@@ -437,6 +437,40 @@ func (t *Transport) onceSetNextProtoDefaults() {
                        t2.MaxHeaderListSize = uint32(limit1)
                }
        }
+
+       // Server.ServeTLS clones the tls.Config before modifying it.
+       // Transport doesn't. We may want to make the two consistent some day.
+       //
+       // http2configureTransport will have already set NextProtos, but adjust it again
+       // here to remove HTTP/1.1 if the user has disabled it.
+       t.TLSClientConfig.NextProtos = adjustNextProtos(t.TLSClientConfig.NextProtos, protocols)
+}
+
+func (t *Transport) protocols() Protocols {
+       if t.Protocols != nil {
+               return *t.Protocols // user-configured set
+       }
+       var p Protocols
+       p.SetHTTP1(true) // default always includes HTTP/1
+       switch {
+       case t.TLSNextProto != nil:
+               // Setting TLSNextProto to an empty map is is a documented way
+               // to disable HTTP/2 on a Transport.
+               if t.TLSNextProto["h2"] != nil {
+                       p.SetHTTP2(true)
+               }
+       case !t.ForceAttemptHTTP2 && (t.TLSClientConfig != nil || t.Dial != nil || t.DialContext != nil || t.hasCustomTLSDialer()):
+               // Be conservative and don't automatically enable
+               // http2 if they've specified a custom TLS config or
+               // custom dialers. Let them opt-in themselves via
+               // Transport.Protocols.SetHTTP2(true) so we don't surprise them
+               // by modifying their tls.Config. Issue 14275.
+               // However, if ForceAttemptHTTP2 is true, it overrides the above checks.
+       case http2client.Value() == "0":
+       default:
+               p.SetHTTP2(true)
+       }
+       return p
 }
 
 // ProxyFromEnvironment returns the URL of the proxy to use for a
index 30a7e5eabd53d6a9ede4349200e00bb168992b31..9892fcaae501c7cf3c362f83cc877291921faf71 100644 (file)
@@ -6361,12 +6361,15 @@ func TestTransportClone(t *testing.T) {
                MaxResponseHeaderBytes: 1,
                ForceAttemptHTTP2:      true,
                HTTP2:                  &HTTP2Config{MaxConcurrentStreams: 1},
+               Protocols:              &Protocols{},
                TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{
                        "foo": func(authority string, c *tls.Conn) RoundTripper { panic("") },
                },
                ReadBufferSize:  1,
                WriteBufferSize: 1,
        }
+       tr.Protocols.SetHTTP1(true)
+       tr.Protocols.SetHTTP2(true)
        tr2 := tr.Clone()
        rv := reflect.ValueOf(tr2).Elem()
        rt := rv.Type()
@@ -7167,3 +7170,202 @@ func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
                })
        }
 }
+
+func TestTransportServerProtocols(t *testing.T) {
+       CondSkipHTTP2(t)
+       DefaultTransport.(*Transport).CloseIdleConnections()
+
+       cert, err := tls.X509KeyPair(testcert.LocalhostCert, testcert.LocalhostKey)
+       if err != nil {
+               t.Fatal(err)
+       }
+       leafCert, err := x509.ParseCertificate(cert.Certificate[0])
+       if err != nil {
+               t.Fatal(err)
+       }
+       certpool := x509.NewCertPool()
+       certpool.AddCert(leafCert)
+
+       for _, test := range []struct {
+               name      string
+               scheme    string
+               setup     func(t *testing.T)
+               transport func(*Transport)
+               server    func(*Server)
+               want      string
+       }{{
+               name:   "http default",
+               scheme: "http",
+               want:   "HTTP/1.1",
+       }, {
+               name:   "https default",
+               scheme: "https",
+               transport: func(tr *Transport) {
+                       // Transport default is HTTP/1.
+               },
+               want: "HTTP/1.1",
+       }, {
+               name:   "https transport protocols include HTTP2",
+               scheme: "https",
+               transport: func(tr *Transport) {
+                       // Server default is to support HTTP/2, so if the Transport enables
+                       // HTTP/2 we get it.
+                       tr.Protocols = &Protocols{}
+                       tr.Protocols.SetHTTP1(true)
+                       tr.Protocols.SetHTTP2(true)
+               },
+               want: "HTTP/2.0",
+       }, {
+               name:   "https transport protocols only include HTTP1",
+               scheme: "https",
+               transport: func(tr *Transport) {
+                       // Explicitly enable only HTTP/1.
+                       tr.Protocols = &Protocols{}
+                       tr.Protocols.SetHTTP1(true)
+               },
+               want: "HTTP/1.1",
+       }, {
+               name:   "https transport ForceAttemptHTTP2",
+               scheme: "https",
+               transport: func(tr *Transport) {
+                       // Pre-Protocols-field way of enabling HTTP/2.
+                       tr.ForceAttemptHTTP2 = true
+               },
+               want: "HTTP/2.0",
+       }, {
+               name:   "https transport protocols override TLSNextProto",
+               scheme: "https",
+               transport: func(tr *Transport) {
+                       // Setting TLSNextProto to an empty map is the historical way
+                       // of disabling HTTP/2. Explicitly enabling HTTP2 in the Protocols
+                       // field takes precedence.
+                       tr.Protocols = &Protocols{}
+                       tr.Protocols.SetHTTP1(true)
+                       tr.Protocols.SetHTTP2(true)
+                       tr.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{}
+               },
+               want: "HTTP/2.0",
+       }, {
+               name:   "https server disables HTTP2 with TLSNextProto",
+               scheme: "https",
+               server: func(srv *Server) {
+                       // Disable HTTP/2 on the server with TLSNextProto,
+                       // use default Protocols value.
+                       srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
+               },
+               want: "HTTP/1.1",
+       }, {
+               name:   "https server Protocols overrides empty TLSNextProto",
+               scheme: "https",
+               server: func(srv *Server) {
+                       // Explicitly enabling HTTP2 in the Protocols field takes precedence
+                       // over setting an empty TLSNextProto.
+                       srv.Protocols = &Protocols{}
+                       srv.Protocols.SetHTTP1(true)
+                       srv.Protocols.SetHTTP2(true)
+                       srv.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
+               },
+               want: "HTTP/2.0",
+       }, {
+               name:   "https server protocols only include HTTP1",
+               scheme: "https",
+               server: func(srv *Server) {
+                       srv.Protocols = &Protocols{}
+                       srv.Protocols.SetHTTP1(true)
+               },
+               want: "HTTP/1.1",
+       }, {
+               name:   "https server protocols include HTTP2",
+               scheme: "https",
+               server: func(srv *Server) {
+                       srv.Protocols = &Protocols{}
+                       srv.Protocols.SetHTTP1(true)
+                       srv.Protocols.SetHTTP2(true)
+               },
+               want: "HTTP/2.0",
+       }, {
+               name:   "GODEBUG disables HTTP2 client",
+               scheme: "https",
+               setup: func(t *testing.T) {
+                       t.Setenv("GODEBUG", "http2client=0")
+               },
+               transport: func(tr *Transport) {
+                       // Server default is to support HTTP/2, so if the Transport enables
+                       // HTTP/2 we get it.
+                       tr.Protocols = &Protocols{}
+                       tr.Protocols.SetHTTP1(true)
+                       tr.Protocols.SetHTTP2(true)
+               },
+               want: "HTTP/1.1",
+       }, {
+               name:   "GODEBUG disables HTTP2 server",
+               scheme: "https",
+               setup: func(t *testing.T) {
+                       t.Setenv("GODEBUG", "http2server=0")
+               },
+               transport: func(tr *Transport) {
+                       // Server default is to support HTTP/2, so if the Transport enables
+                       // HTTP/2 we get it.
+                       tr.Protocols = &Protocols{}
+                       tr.Protocols.SetHTTP1(true)
+                       tr.Protocols.SetHTTP2(true)
+               },
+               want: "HTTP/1.1",
+       }} {
+               t.Run(test.name, func(t *testing.T) {
+                       // We don't use httptest here because it makes its own decisions
+                       // about how to enable/disable HTTP/2.
+                       srv := &Server{
+                               TLSConfig: &tls.Config{
+                                       Certificates: []tls.Certificate{cert},
+                               },
+                               Handler: HandlerFunc(func(w ResponseWriter, req *Request) {
+                                       w.Header().Set("X-Proto", req.Proto)
+                               }),
+                       }
+                       tr := &Transport{
+                               TLSClientConfig: &tls.Config{
+                                       RootCAs: certpool,
+                               },
+                       }
+
+                       if test.setup != nil {
+                               test.setup(t)
+                       }
+                       if test.server != nil {
+                               test.server(srv)
+                       }
+                       if test.transport != nil {
+                               test.transport(tr)
+                       } else {
+                               tr.Protocols = &Protocols{}
+                               tr.Protocols.SetHTTP1(true)
+                               tr.Protocols.SetHTTP2(true)
+                       }
+
+                       listener := newLocalListener(t)
+                       srvc := make(chan error, 1)
+                       go func() {
+                               switch test.scheme {
+                               case "http":
+                                       srvc <- srv.Serve(listener)
+                               case "https":
+                                       srvc <- srv.ServeTLS(listener, "", "")
+                               }
+                       }()
+                       t.Cleanup(func() {
+                               srv.Close()
+                               <-srvc
+                       })
+
+                       client := &Client{Transport: tr}
+                       resp, err := client.Get(test.scheme + "://" + listener.Addr().String())
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if got := resp.Header.Get("X-Proto"); got != test.want {
+                               t.Fatalf("request proto %q, want %q", got, test.want)
+                       }
+               })
+       }
+}