]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: limit 1xx based on size, do not limit when delivered
authorDamien Neil <dneil@google.com>
Mon, 23 Sep 2024 18:43:19 +0000 (11:43 -0700)
committerDamien Neil <dneil@google.com>
Thu, 24 Oct 2024 16:10:37 +0000 (16:10 +0000)
Replace Transport's limit of 5 1xx responses with a limit based
on MaxResponseHeaderBytes: The total number of responses
(including 1xx reponses and the final response) must not exceed
this value.

When the user is reading 1xx responses using a Got1xxResponse
client trace hook, disable the limit: Each 1xx response is
individually limited by MaxResponseHeaderBytes, but there
is no limit on the total number of responses. The user is
responsible for imposing a limit if they want one.

For #65035

Change-Id: If4bbbbb0b808cb5016701d50963c89f0ce1229f8
Reviewed-on: https://go-review.googlesource.com/c/go/+/615255
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/net/http/transport.go
src/net/http/transport_test.go

index ed7c2a52c2cfa354fbdfdc3c42a475c19d054977..c980e727a66b546bd3d892a6e4e3b36e93024d61 100644 (file)
@@ -2398,8 +2398,6 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
                        trace.GotFirstResponseByte()
                }
        }
-       num1xx := 0               // number of informational 1xx headers received
-       const max1xxResponses = 5 // arbitrary bound on number of informational responses
 
        continueCh := rc.continueCh
        for {
@@ -2419,15 +2417,18 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
                // treat 101 as a terminal status, see issue 26161
                is1xxNonTerminal := is1xx && resCode != StatusSwitchingProtocols
                if is1xxNonTerminal {
-                       num1xx++
-                       if num1xx > max1xxResponses {
-                               return nil, errors.New("net/http: too many 1xx informational responses")
-                       }
-                       pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
                        if trace != nil && trace.Got1xxResponse != nil {
                                if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil {
                                        return nil, err
                                }
+                               // If the 1xx response was delivered to the user,
+                               // then they're responsible for limiting the number of
+                               // responses. Reset the header limit.
+                               //
+                               // If the user didn't examine the 1xx response, then we
+                               // limit the size of all headers (including both 1xx
+                               // and the final response) to maxHeaderResponseSize.
+                               pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
                        }
                        continue
                }
index b76b8dfcffd869f706f3339546cea199361a6c0f..30a7e5eabd53d6a9ede4349200e00bb168992b31 100644 (file)
@@ -3258,29 +3258,67 @@ func testTransportIgnore1xxResponses(t *testing.T, mode testMode) {
        }
 }
 
-func TestTransportLimits1xxResponses(t *testing.T) {
-       run(t, testTransportLimits1xxResponses, []testMode{http1Mode})
-}
+func TestTransportLimits1xxResponses(t *testing.T) { run(t, testTransportLimits1xxResponses) }
 func testTransportLimits1xxResponses(t *testing.T, mode testMode) {
        cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
-               conn, buf, _ := w.(Hijacker).Hijack()
+               w.Header().Add("X-Header", strings.Repeat("a", 100))
                for i := 0; i < 10; i++ {
-                       buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
+                       w.WriteHeader(123)
                }
-               buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
-               buf.Flush()
-               conn.Close()
+               w.WriteHeader(204)
        }))
        cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+       cst.tr.MaxResponseHeaderBytes = 1000
 
        res, err := cst.c.Get(cst.ts.URL)
-       if res != nil {
-               defer res.Body.Close()
+       if err == nil {
+               res.Body.Close()
+               t.Fatalf("RoundTrip succeeded; want error")
+       }
+       for _, want := range []string{
+               "response headers exceeded",
+               "too many 1xx",
+       } {
+               if strings.Contains(err.Error(), want) {
+                       return
+               }
+       }
+       t.Errorf(`got error %q; want "response headers exceeded" or "too many 1xx"`, err)
+}
+
+func TestTransportDoesNotLimitDelivered1xxResponses(t *testing.T) {
+       run(t, testTransportDoesNotLimitDelivered1xxResponses)
+}
+func testTransportDoesNotLimitDelivered1xxResponses(t *testing.T, mode testMode) {
+       if mode == http2Mode {
+               t.Skip("skip until x/net/http2 updated")
        }
-       got := fmt.Sprint(err)
-       wantSub := "too many 1xx informational responses"
-       if !strings.Contains(got, wantSub) {
-               t.Errorf("Get error = %v; want substring %q", err, wantSub)
+       const num1xx = 10
+       cst := newClientServerTest(t, mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+               w.Header().Add("X-Header", strings.Repeat("a", 100))
+               for i := 0; i < 10; i++ {
+                       w.WriteHeader(123)
+               }
+               w.WriteHeader(204)
+       }))
+       cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+       cst.tr.MaxResponseHeaderBytes = 1000
+
+       got1xx := 0
+       ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+               Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+                       got1xx++
+                       return nil
+               },
+       })
+       req, _ := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+       res, err := cst.c.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       res.Body.Close()
+       if got1xx != num1xx {
+               t.Errorf("Got %v 1xx responses, want %x", got1xx, num1xx)
        }
 }