]> Cypherpunks repositories - gostls13.git/commitdiff
http: avoid server crash on malformed client request
authorFumitoshi Ukai <ukai@google.com>
Fri, 19 Feb 2010 02:32:40 +0000 (18:32 -0800)
committerRuss Cox <rsc@golang.org>
Fri, 19 Feb 2010 02:32:40 +0000 (18:32 -0800)
R=r, rsc
CC=golang-dev
https://golang.org/cl/206079

src/pkg/websocket/server.go
src/pkg/websocket/websocket_test.go

index 43c2a7c8d0cfa5bb46a734ada3c498e89672f164..0ccb31e8a22e7cb9f391734dc3755d6b2364671c 100644 (file)
@@ -38,20 +38,34 @@ type Handler func(*Conn)
 
 // ServeHTTP implements the http.Handler interface for a Web Socket.
 func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
-       if req.Method != "GET" || req.Proto != "HTTP/1.1" ||
-               req.Header["Upgrade"] != "WebSocket" ||
-               req.Header["Connection"] != "Upgrade" {
-               c.WriteHeader(http.StatusNotFound)
-               io.WriteString(c, "must use websocket to connect here")
+       if req.Method != "GET" || req.Proto != "HTTP/1.1" {
+               c.WriteHeader(http.StatusBadRequest)
+               io.WriteString(c, "Unexpected request")
                return
        }
+       if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" {
+               c.WriteHeader(http.StatusBadRequest)
+               io.WriteString(c, "missing Upgrade: WebSocket header")
+               return
+       }
+       if v, present := req.Header["Connection"]; !present || v != "Upgrade" {
+               c.WriteHeader(http.StatusBadRequest)
+               io.WriteString(c, "missing Connection: Upgrade header")
+               return
+       }
+       origin, present := req.Header["Origin"]
+       if !present {
+               c.WriteHeader(http.StatusBadRequest)
+               io.WriteString(c, "missing Origin header")
+               return
+       }
+
        rwc, buf, err := c.Hijack()
        if err != nil {
                panic("Hijack failed: ", err.String())
                return
        }
        defer rwc.Close()
-       origin := req.Header["Origin"]
        location := "ws://" + req.Host + req.URL.Path
 
        // TODO(ukai): verify origin,location,protocol.
@@ -61,9 +75,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
        buf.WriteString("Connection: Upgrade\r\n")
        buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
        buf.WriteString("WebSocket-Location: " + location + "\r\n")
-       protocol := ""
+       protocol, present := req.Header["Websocket-Protocol"]
        // canonical header key of WebSocket-Protocol.
-       if protocol, found := req.Header["Websocket-Protocol"]; found {
+       if present {
                buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
        }
        buf.WriteString("\r\n")
index c62604621ec8c13b42f12043d0a3a9ec187896e2..c15c4353857ac771eecbc7de2222e27fd971be59 100644 (file)
@@ -6,6 +6,7 @@ package websocket
 
 import (
        "bytes"
+       "fmt"
        "http"
        "io"
        "log"
@@ -59,3 +60,17 @@ func TestEcho(t *testing.T) {
        }
        ws.Close()
 }
+
+func TestHTTP(t *testing.T) {
+       once.Do(startServer)
+
+       r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
+       if err != nil {
+               t.Errorf("Get: error %v", err)
+               return
+       }
+       if r.StatusCode != http.StatusBadRequest {
+               t.Errorf("Get: got status %d", r.StatusCode)
+               return
+       }
+}