return altProto[req.URL.Scheme]
}
+func validateHeaders(hdrs Header) string {
+ for k, vv := range hdrs {
+ if !httpguts.ValidHeaderFieldName(k) {
+ return fmt.Sprintf("field name %q", k)
+ }
+ for _, v := range vv {
+ if !httpguts.ValidHeaderFieldValue(v) {
+ // Don't include the value in the error,
+ // because it may be sensitive.
+ return fmt.Sprintf("field value for %q", k)
+ }
+ }
+ }
+ return ""
+}
+
// roundTrip implements a RoundTripper over HTTP.
func (t *Transport) roundTrip(req *Request) (*Response, error) {
t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
scheme := req.URL.Scheme
isHTTP := scheme == "http" || scheme == "https"
if isHTTP {
- for k, vv := range req.Header {
- if !httpguts.ValidHeaderFieldName(k) {
- req.closeBody()
- return nil, fmt.Errorf("net/http: invalid header field name %q", k)
- }
- for _, v := range vv {
- if !httpguts.ValidHeaderFieldValue(v) {
- req.closeBody()
- // Don't include the value in the error, because it may be sensitive.
- return nil, fmt.Errorf("net/http: invalid header field value for %q", k)
- }
- }
+ // Validate the outgoing headers.
+ if err := validateHeaders(req.Header); err != "" {
+ req.closeBody()
+ return nil, fmt.Errorf("net/http: invalid header %s", err)
+ }
+
+ // Validate the outgoing trailers too.
+ if err := validateHeaders(req.Trailer); err != "" {
+ req.closeBody()
+ return nil, fmt.Errorf("net/http: invalid trailer %s", err)
}
}
return true
})
}
+
+func TestValidateClientRequestTrailers(t *testing.T) {
+ run(t, testValidateClientRequestTrailers)
+}
+
+func testValidateClientRequestTrailers(t *testing.T, mode testMode) {
+ cst := newClientServerTest(t, mode, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ rw.Write([]byte("Hello"))
+ })).ts
+
+ cases := []struct {
+ trailer Header
+ wantErr string
+ }{
+ {Header{"Trx": {"x\r\nX-Another-One"}}, `invalid trailer field value for "Trx"`},
+ {Header{"\r\nTrx": {"X-Another-One"}}, `invalid trailer field name "\r\nTrx"`},
+ }
+
+ for i, tt := range cases {
+ testName := fmt.Sprintf("%s%d", mode, i)
+ t.Run(testName, func(t *testing.T) {
+ req, err := NewRequest("GET", cst.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Trailer = tt.trailer
+ res, err := cst.Client().Do(req)
+ if err == nil {
+ t.Fatal("Expected an error")
+ }
+ if g, w := err.Error(), tt.wantErr; !strings.Contains(g, w) {
+ t.Fatalf("Mismatched error\n\t%q\ndoes not contain\n\t%q", g, w)
+ }
+ if res != nil {
+ t.Fatal("Unexpected non-nil response")
+ }
+ })
+ }
+}