]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: make Server validate Host headers
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 16 Dec 2015 18:51:12 +0000 (18:51 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 16 Dec 2015 19:52:07 +0000 (19:52 +0000)
Fixes #11206 (that we accept invalid bytes)
Fixes #13624 (that we don't require a Host header in HTTP/1.1 per spec)

Change-Id: I4138281d513998789163237e83bb893aeda43336
Reviewed-on: https://go-review.googlesource.com/17892
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/request.go
src/net/http/serve_test.go
src/net/http/server.go

index 9f740422ede4629648d7441069252fa8e599c30d..01575f33a5ae9f016031a373a65a693298062c10 100644 (file)
@@ -689,8 +689,9 @@ func putTextprotoReader(r *textproto.Reader) {
 }
 
 // ReadRequest reads and parses an incoming request from b.
-func ReadRequest(b *bufio.Reader) (req *Request, err error) {
+func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, true) }
 
+func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
        tp := newTextprotoReader(b)
        req = new(Request)
 
@@ -757,7 +758,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
        if req.Host == "" {
                req.Host = req.Header.get("Host")
        }
-       delete(req.Header, "Host")
+       if deleteHostHeader {
+               delete(req.Header, "Host")
+       }
 
        fixPragmaCacheControl(req.Header)
 
@@ -1060,3 +1063,59 @@ func (r *Request) isReplayable() bool {
                        r.Method == "OPTIONS" ||
                        r.Method == "TRACE")
 }
+
+func validHostHeader(h string) bool {
+       // The latests spec is actually this:
+       //
+       // http://tools.ietf.org/html/rfc7230#section-5.4
+       //     Host = uri-host [ ":" port ]
+       //
+       // Where uri-host is:
+       //     http://tools.ietf.org/html/rfc3986#section-3.2.2
+       //
+       // But we're going to be much more lenient for now and just
+       // search for any byte that's not a valid byte in any of those
+       // expressions.
+       for i := 0; i < len(h); i++ {
+               if !validHostByte[h[i]] {
+                       return false
+               }
+       }
+       return true
+}
+
+// See the validHostHeader comment.
+var validHostByte = [256]bool{
+       '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
+       '8': true, '9': true,
+
+       'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
+       'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
+       'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
+       'y': true, 'z': true,
+
+       'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
+       'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
+       'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
+       'Y': true, 'Z': true,
+
+       '!':  true, // sub-delims
+       '$':  true, // sub-delims
+       '%':  true, // pct-encoded (and used in IPv6 zones)
+       '&':  true, // sub-delims
+       '(':  true, // sub-delims
+       ')':  true, // sub-delims
+       '*':  true, // sub-delims
+       '+':  true, // sub-delims
+       ',':  true, // sub-delims
+       '-':  true, // unreserved
+       '.':  true, // unreserved
+       ':':  true, // IPv6address + Host expression's optional port
+       ';':  true, // sub-delims
+       '=':  true, // sub-delims
+       '[':  true,
+       '\'': true, // sub-delims
+       ']':  true,
+       '_':  true, // unreserved
+       '~':  true, // unreserved
+}
index 3e84f2e11d3d775d626fb40ad3e37b3fde83b0a9..31ba06a2674974ad227e0fdd3342aab7117db352 100644 (file)
@@ -2201,7 +2201,7 @@ func TestClientWriteShutdown(t *testing.T) {
 // buffered before chunk headers are added, not after chunk headers.
 func TestServerBufferedChunking(t *testing.T) {
        conn := new(testConn)
-       conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+       conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
        conn.closec = make(chan bool, 1)
        ls := &oneConnListener{conn}
        go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -2934,9 +2934,9 @@ func TestCodesPreventingContentTypeAndBody(t *testing.T) {
                        "GET / HTTP/1.0",
                        "GET /header HTTP/1.0",
                        "GET /more HTTP/1.0",
-                       "GET / HTTP/1.1",
-                       "GET /header HTTP/1.1",
-                       "GET /more HTTP/1.1",
+                       "GET / HTTP/1.1\nHost: foo",
+                       "GET /header HTTP/1.1\nHost: foo",
+                       "GET /more HTTP/1.1\nHost: foo",
                } {
                        got := ht.rawResponse(req)
                        wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
@@ -2957,7 +2957,7 @@ func TestContentTypeOkayOn204(t *testing.T) {
                w.Header().Set("Content-Type", "foo/bar")
                w.WriteHeader(204)
        }))
-       got := ht.rawResponse("GET / HTTP/1.1")
+       got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
        if !strings.Contains(got, "Content-Type: foo/bar") {
                t.Errorf("Response = %q; want Content-Type: foo/bar", got)
        }
@@ -3628,6 +3628,54 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) {
        }
 }
 
+// Test that we validate the Host header.
+func TestServerValidatesHostHeader(t *testing.T) {
+       tests := []struct {
+               proto string
+               host  string
+               want  int
+       }{
+               {"HTTP/1.1", "", 400},
+               {"HTTP/1.1", "Host: \r\n", 200},
+               {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+               {"HTTP/1.1", "Host: foo.com\r\n", 200},
+               {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
+               {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
+               {"HTTP/1.1", "Host: ::1\r\n", 200},
+               {"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it
+               {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
+               {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
+               {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+               {"HTTP/1.1", "Host: \x06\r\n", 400},
+               {"HTTP/1.1", "Host: \xff\r\n", 400},
+               {"HTTP/1.1", "Host: {\r\n", 400},
+               {"HTTP/1.1", "Host: }\r\n", 400},
+               {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
+
+               // HTTP/1.0 can lack a host header, but if present
+               // must play by the rules too:
+               {"HTTP/1.0", "", 200},
+               {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
+               {"HTTP/1.0", "Host: \xff\r\n", 400},
+       }
+       for _, tt := range tests {
+               conn := &testConn{closec: make(chan bool)}
+               io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n")
+
+               ln := &oneConnListener{conn}
+               go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+               <-conn.closec
+               res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+               if err != nil {
+                       t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
+                       continue
+               }
+               if res.StatusCode != tt.want {
+                       t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
+               }
+       }
+}
+
 func BenchmarkClientServer(b *testing.B) {
        b.ReportAllocs()
        b.StopTimer()
index cd5f9cf34f45e880353685046b39123637ee2f6f..a00085c2498dfa6c2e188319960ffb6f9cc715eb 100644 (file)
@@ -686,7 +686,7 @@ func (c *conn) readRequest() (w *response, err error) {
                peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
                c.bufr.Discard(numLeadingCRorLF(peek))
        }
-       req, err := ReadRequest(c.bufr)
+       req, err := readRequest(c.bufr, false)
        c.mu.Unlock()
        if err != nil {
                if c.r.hitReadLimit() {
@@ -697,6 +697,18 @@ func (c *conn) readRequest() (w *response, err error) {
        c.lastMethod = req.Method
        c.r.setInfiniteReadLimit()
 
+       hosts, haveHost := req.Header["Host"]
+       if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) {
+               return nil, badRequestError("missing required Host header")
+       }
+       if len(hosts) > 1 {
+               return nil, badRequestError("too many Host headers")
+       }
+       if len(hosts) == 1 && !validHostHeader(hosts[0]) {
+               return nil, badRequestError("malformed Host header")
+       }
+       delete(req.Header, "Host")
+
        req.RemoteAddr = c.remoteAddr
        req.TLS = c.tlsState
        if body, ok := req.Body.(*body); ok {
@@ -1334,6 +1346,13 @@ func (c *conn) setState(nc net.Conn, state ConnState) {
        }
 }
 
+// badRequestError is a literal string (used by in the server in HTML,
+// unescaped) to tell the user why their request was bad. It should
+// be plain text without user info or other embeddded errors.
+type badRequestError string
+
+func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
+
 // Serve a new connection.
 func (c *conn) serve() {
        c.remoteAddr = c.rwc.RemoteAddr().String()
@@ -1399,7 +1418,11 @@ func (c *conn) serve() {
                        if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
                                return // don't reply
                        }
-                       io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request")
+                       var publicErr string
+                       if v, ok := err.(badRequestError); ok {
+                               publicErr = ": " + string(v)
+                       }
+                       io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr)
                        return
                }