]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: limit Transport's reading of response header bytes from servers
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 31 Mar 2016 07:06:27 +0000 (00:06 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 1 Apr 2016 00:47:25 +0000 (00:47 +0000)
The default is 10MB, like http2, but can be configured with a new
field http.Transport.MaxResponseHeaderBytes.

Fixes #9115

Change-Id: I01808ac631ce4794ef2b0dfc391ed51cf951ceb1
Reviewed-on: https://go-review.googlesource.com/21329
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
Reviewed-by: Emmanuel Odeke <emm.odeke@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Andrew Gerrand <adg@golang.org>
src/net/http/http.go [new file with mode: 0644]
src/net/http/server.go
src/net/http/transport.go
src/net/http/transport_test.go

diff --git a/src/net/http/http.go b/src/net/http/http.go
new file mode 100644 (file)
index 0000000..a40b23d
--- /dev/null
@@ -0,0 +1,12 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// maxInt64 is the effective "infinite" value for the Server and
+// Transport's byte-limiting readers.
+const maxInt64 = 1<<63 - 1
+
+// TODO(bradfitz): move common stuff here. The other files have accumulated
+// generic http stuff in random places.
index 5718cafbc3d72847f6a6ebf37468bc694cd29931..a2f9083a511c820444162dac2ffcfd04d23d0fee 100644 (file)
@@ -497,7 +497,7 @@ type connReader struct {
 }
 
 func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
-func (cr *connReader) setInfiniteReadLimit()     { cr.remain = 1<<63 - 1 }
+func (cr *connReader) setInfiniteReadLimit()     { cr.remain = maxInt64 }
 func (cr *connReader) hitReadLimit() bool        { return cr.remain <= 0 }
 
 func (cr *connReader) Read(p []byte) (n int, err error) {
index 06ac939bd5973142233817927507ac7009306255..d1b64c7da9a888ebbc6fa85f5da5fa62ed5898b9 100644 (file)
@@ -146,6 +146,13 @@ type Transport struct {
        // If TLSNextProto is nil, HTTP/2 support is enabled automatically.
        TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
 
+       // MaxResponseHeaderBytes specifies a limit on how many
+       // response bytes are allowed in the server's response
+       // header.
+       //
+       // Zero means to use a default limit.
+       MaxResponseHeaderBytes int64
+
        // nextProtoOnce guards initialization of TLSNextProto and
        // h2transport (via onceSetNextProtoDefaults)
        nextProtoOnce sync.Once
@@ -188,8 +195,23 @@ func (t *Transport) onceSetNextProtoDefaults() {
        t2, err := http2configureTransport(t)
        if err != nil {
                log.Printf("Error enabling Transport HTTP/2 support: %v", err)
-       } else {
-               t.h2transport = t2
+               return
+       }
+       t.h2transport = t2
+
+       // Auto-configure the http2.Transport's MaxHeaderListSize from
+       // the http.Transport's MaxResponseHeaderBytes. They don't
+       // exactly mean the same thing, but they're close.
+       //
+       // TODO: also add this to x/net/http2.Configure Transport, behind
+       // a +build go1.7 build tag:
+       if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 {
+               const h2max = 1<<32 - 1
+               if limit1 >= h2max {
+                       t2.MaxHeaderListSize = h2max
+               } else {
+                       t2.MaxHeaderListSize = uint32(limit1)
+               }
        }
 }
 
@@ -351,7 +373,8 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
 // resent on a new connection. The non-nil input error is the error from
 // roundTrip, which might be wrapped in a beforeRespHeaderError error.
 //
-// The return value is err or the unwrapped error inside a
+// The return value is either nil to retry the request, the provided
+// err unmodified, or the unwrapped error inside a
 // beforeRespHeaderError.
 func checkTransportResend(err error, req *Request, pconn *persistConn) error {
        brhErr, ok := err.(beforeRespHeaderError)
@@ -864,7 +887,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
                }
        }
 
-       pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
+       pconn.br = bufio.NewReader(pconn)
        pconn.bw = bufio.NewWriter(pconn.conn)
        go pconn.readLoop()
        go pconn.writeLoop()
@@ -998,17 +1021,18 @@ type persistConn struct {
        // If it's non-nil, the rest of the fields are unused.
        alt RoundTripper
 
-       t        *Transport
-       cacheKey connectMethodKey
-       conn     net.Conn
-       tlsState *tls.ConnectionState
-       br       *bufio.Reader       // from conn
-       sawEOF   bool                // whether we've seen EOF from conn; owned by readLoop
-       bw       *bufio.Writer       // to conn
-       reqch    chan requestAndChan // written by roundTrip; read by readLoop
-       writech  chan writeRequest   // written by roundTrip; read by writeLoop
-       closech  chan struct{}       // closed when conn closed
-       isProxy  bool
+       t         *Transport
+       cacheKey  connectMethodKey
+       conn      net.Conn
+       tlsState  *tls.ConnectionState
+       br        *bufio.Reader       // from conn
+       bw        *bufio.Writer       // to conn
+       reqch     chan requestAndChan // written by roundTrip; read by readLoop
+       writech   chan writeRequest   // written by roundTrip; read by writeLoop
+       closech   chan struct{}       // closed when conn closed
+       isProxy   bool
+       sawEOF    bool  // whether we've seen EOF from conn; owned by readLoop
+       readLimit int64 // bytes allowed to be read; owned by readLoop
        // writeErrCh passes the request write error (usually nil)
        // from the writeLoop goroutine to the readLoop which passes
        // it off to the res.Body reader, which then uses it to decide
@@ -1027,6 +1051,28 @@ type persistConn struct {
        mutateHeaderFunc func(Header)
 }
 
+func (pc *persistConn) maxHeaderResponseSize() int64 {
+       if v := pc.t.MaxResponseHeaderBytes; v != 0 {
+               return v
+       }
+       return 10 << 20 // conservative default; same as http2
+}
+
+func (pc *persistConn) Read(p []byte) (n int, err error) {
+       if pc.readLimit <= 0 {
+               return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize())
+       }
+       if int64(len(p)) > pc.readLimit {
+               p = p[:pc.readLimit]
+       }
+       n, err = pc.conn.Read(p)
+       if err == io.EOF {
+               pc.sawEOF = true
+       }
+       pc.readLimit -= int64(n)
+       return
+}
+
 // isBroken reports whether this connection is in a known broken state.
 func (pc *persistConn) isBroken() bool {
        pc.mu.Lock()
@@ -1082,6 +1128,7 @@ func (pc *persistConn) readLoop() {
 
        alive := true
        for alive {
+               pc.readLimit = pc.maxHeaderResponseSize()
                _, err := pc.br.Peek(1)
                if err != nil {
                        err = beforeRespHeaderError{err}
@@ -1103,6 +1150,9 @@ func (pc *persistConn) readLoop() {
                }
 
                if err != nil {
+                       if pc.readLimit <= 0 {
+                               err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize())
+                       }
                        // If we won't be able to retry this request later (from the
                        // roundTrip goroutine), mark it as done now.
                        // BEFORE the send on rc.ch, as the client might re-use the
@@ -1120,6 +1170,7 @@ func (pc *persistConn) readLoop() {
                        }
                        return
                }
+               pc.readLimit = maxInt64 // effictively no limit for response bodies
 
                pc.mu.Lock()
                pc.numExpectedResponses--
@@ -1251,6 +1302,7 @@ func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err erro
                }
        }
        if resp.StatusCode == 100 {
+               pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
                resp, err = ReadResponse(pc.br, rc.req)
                if err != nil {
                        return
@@ -1706,19 +1758,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool   { return true }
 func (tlsHandshakeTimeoutError) Temporary() bool { return true }
 func (tlsHandshakeTimeoutError) Error() string   { return "net/http: TLS handshake timeout" }
 
-type noteEOFReader struct {
-       r      io.Reader
-       sawEOF *bool
-}
-
-func (nr noteEOFReader) Read(p []byte) (n int, err error) {
-       n, err = nr.r.Read(p)
-       if err == io.EOF {
-               *nr.sawEOF = true
-       }
-       return
-}
-
 // fakeLocker is a sync.Locker which does nothing. It's used to guard
 // test-only fields when not under test, to avoid runtime atomic
 // overhead.
index c4540d7e6a99fbe2ef847e66f72930b47e15caf7..9c2e40d7f51f6f61e36861cf2a4acfb17d341d4e 100644 (file)
@@ -3090,6 +3090,42 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) {
        }
 }
 
+func TestTransportResponseHeaderLength(t *testing.T) {
+       defer afterTest(t)
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.URL.Path == "/long" {
+                       w.Header().Set("Long", strings.Repeat("a", 1<<20))
+               }
+       }))
+       defer ts.Close()
+
+       tr := &Transport{
+               MaxResponseHeaderBytes: 512 << 10,
+       }
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+       if res, err := c.Get(ts.URL); err != nil {
+               t.Fatal(err)
+       } else {
+               res.Body.Close()
+       }
+
+       res, err := c.Get(ts.URL + "/long")
+       if err == nil {
+               defer res.Body.Close()
+               var n int64
+               for k, vv := range res.Header {
+                       for _, v := range vv {
+                               n += int64(len(k)) + int64(len(v))
+                       }
+               }
+               t.Fatalf("Unexpected success. Got %v and %d bytes of response headers", res.Status, n)
+       }
+       if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) {
+               t.Errorf("got error: %v; want %q", err, want)
+       }
+}
+
 var errFakeRoundTrip = errors.New("fake roundtrip")
 
 type funcRoundTripper func()