]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add MaxBytesHandler
authorCarl Johnson <me@carlmjohnson.net>
Tue, 31 Aug 2021 20:35:35 +0000 (20:35 +0000)
committerDamien Neil <dneil@google.com>
Tue, 9 Nov 2021 18:23:16 +0000 (18:23 +0000)
Fixes #39567

Change-Id: I226089b678a6a13d7ce69f360a23fc5bd297d550
GitHub-Last-Rev: 6435fd5881fc70a276d04df5a60440e365924b49
GitHub-Pull-Request: golang/go#48104
Reviewed-on: https://go-review.googlesource.com/c/go/+/346569
Trust: Damien Neil <dneil@google.com>
Trust: Ian Lance Taylor <iant@golang.org>
Run-TryBot: Damien Neil <dneil@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Damien Neil <dneil@google.com>
src/net/http/serve_test.go
src/net/http/server.go

index 30a6555d30cdf7ff72b49378aeef45fffc30dc86..1156b187ae3124da28c8a57f94641fd8b7c4210e 100644 (file)
@@ -6682,3 +6682,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
                }
        }
 }
+
+func TestMaxBytesHandler(t *testing.T) {
+       setParallel(t)
+       defer afterTest(t)
+
+       for _, maxSize := range []int64{100, 1_000, 1_000_000} {
+               for _, requestSize := range []int64{100, 1_000, 1_000_000} {
+                       t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
+                               func(t *testing.T) {
+                                       testMaxBytesHandler(t, maxSize, requestSize)
+                               })
+               }
+       }
+}
+
+func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
+       var (
+               handlerN   int64
+               handlerErr error
+       )
+       echo := HandlerFunc(func(w ResponseWriter, r *Request) {
+               var buf bytes.Buffer
+               handlerN, handlerErr = io.Copy(&buf, r.Body)
+               io.Copy(w, &buf)
+       })
+
+       ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
+       defer ts.Close()
+
+       c := ts.Client()
+       var buf strings.Builder
+       body := strings.NewReader(strings.Repeat("a", int(requestSize)))
+       res, err := c.Post(ts.URL, "text/plain", body)
+       if err != nil {
+               t.Errorf("unexpected connection error: %v", err)
+       } else {
+               _, err = io.Copy(&buf, res.Body)
+               res.Body.Close()
+               if err != nil {
+                       t.Errorf("unexpected read error: %v", err)
+               }
+       }
+       if handlerN > maxSize {
+               t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
+       }
+       if requestSize > maxSize && handlerErr == nil {
+               t.Error("expected error on handler side; got nil")
+       }
+       if requestSize <= maxSize {
+               if handlerErr != nil {
+                       t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+               }
+               if handlerN != requestSize {
+                       t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+               }
+       }
+       if buf.Len() != int(handlerN) {
+               t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
+       }
+}
index 08fd478ed9e48d69fa7170bce998fe4d5ad6fcec..c4a2d57dd4da765ffd60068e89add9bed1b42ccf 100644 (file)
@@ -3610,3 +3610,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool {
        }
        return false
 }
+
+// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader.
+func MaxBytesHandler(h Handler, n int64) Handler {
+       return HandlerFunc(func(w ResponseWriter, r *Request) {
+               r2 := *r
+               r2.Body = MaxBytesReader(w, r.Body, n)
+               h.ServeHTTP(w, &r2)
+       })
+}