]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: validate outgoing/client request trailers
authorEmmanuel T Odeke <emmanuel@orijtech.com>
Tue, 19 Mar 2024 06:05:12 +0000 (23:05 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 20 Mar 2024 11:31:46 +0000 (11:31 +0000)
This change validates outbound client request trailers
just like we do for headers. This helps prevent header
injection or other sorts of smuggling from easily being
performed using the standard HTTP client.

Fixes #64766

Change-Id: Idb34df876a0c308b1f57e9ae2695b118ac6bcc2d
Reviewed-on: https://go-review.googlesource.com/c/go/+/572615
Reviewed-by: David Chase <drchase@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Damien Neil <dneil@google.com>
Run-TryBot: Emmanuel Odeke <emmanuel@orijtech.com>
Auto-Submit: Emmanuel Odeke <emmanuel@orijtech.com>

src/net/http/transport.go
src/net/http/transport_test.go

index 44d5515705f2f87b91e9feb60d677e9273bc2b9c..bbac2bf448b32cfbea80957592dd8a6f44c0a89d 100644 (file)
@@ -513,6 +513,22 @@ func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
        return altProto[req.URL.Scheme]
 }
 
+func validateHeaders(hdrs Header) string {
+       for k, vv := range hdrs {
+               if !httpguts.ValidHeaderFieldName(k) {
+                       return fmt.Sprintf("field name %q", k)
+               }
+               for _, v := range vv {
+                       if !httpguts.ValidHeaderFieldValue(v) {
+                               // Don't include the value in the error,
+                               // because it may be sensitive.
+                               return fmt.Sprintf("field value for %q", k)
+                       }
+               }
+       }
+       return ""
+}
+
 // roundTrip implements a RoundTripper over HTTP.
 func (t *Transport) roundTrip(req *Request) (*Response, error) {
        t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
@@ -530,18 +546,16 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
        scheme := req.URL.Scheme
        isHTTP := scheme == "http" || scheme == "https"
        if isHTTP {
-               for k, vv := range req.Header {
-                       if !httpguts.ValidHeaderFieldName(k) {
-                               req.closeBody()
-                               return nil, fmt.Errorf("net/http: invalid header field name %q", k)
-                       }
-                       for _, v := range vv {
-                               if !httpguts.ValidHeaderFieldValue(v) {
-                                       req.closeBody()
-                                       // Don't include the value in the error, because it may be sensitive.
-                                       return nil, fmt.Errorf("net/http: invalid header field value for %q", k)
-                               }
-                       }
+               // Validate the outgoing headers.
+               if err := validateHeaders(req.Header); err != "" {
+                       req.closeBody()
+                       return nil, fmt.Errorf("net/http: invalid header %s", err)
+               }
+
+               // Validate the outgoing trailers too.
+               if err := validateHeaders(req.Trailer); err != "" {
+                       req.closeBody()
+                       return nil, fmt.Errorf("net/http: invalid trailer %s", err)
                }
        }
 
index d3f43cfd9ab98187c1bcdfbd3989f1e89b050ae3..204133f1302d03b789abfe46fbfe256e2bb97c3e 100644 (file)
@@ -7031,3 +7031,42 @@ func TestTransportReqCancelerCleanupOnRequestBodyWriteError(t *testing.T) {
                return true
        })
 }
+
+func TestValidateClientRequestTrailers(t *testing.T) {
+       run(t, testValidateClientRequestTrailers)
+}
+
+func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
+       cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+               rw.Write([]byte("Hello"))
+       })).ts
+
+       cases := []struct {
+               trailer Header
+               wantErr string
+       }{
+               {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
+               {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
+       }
+
+       for i, tt := range cases {
+               testName := fmt.Sprintf("%s%d", mode, i)
+               t.Run(testName, func(t *testing.T) {
+                       req, err := NewRequest("GET", cst.URL, nil)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       req.Trailer = tt.trailer
+                       res, err := cst.Client().Do(req)
+                       if err == nil {
+                               t.Fatal("Expected an error")
+                       }
+                       if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
+                               t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
+                       }
+                       if res != nil {
+                               t.Fatal("Unexpected non-nil response")
+                       }
+               })
+       }
+}