From 371496e0b8eb63b6c26b9207860ed6fc1875e6f5 Mon Sep 17 00:00:00 2001 From: Fumitoshi Ukai Date: Tue, 23 Mar 2010 18:09:24 -0700 Subject: [PATCH] websocket: implement new protocol http://www.whatwg.org/specs/web-socket-protocol/ (draft of draft-hixie-thewebsocketprotocol-76) draft-hixie-thewebsocketprotocol-76 will introduce new handshake incompatible draft 75 or prior. http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol R=rsc CC=golang-dev https://golang.org/cl/583041 --- src/pkg/websocket/client.go | 229 +++++++++++++++++++++++++++- src/pkg/websocket/server.go | 152 +++++++++++++++++- src/pkg/websocket/websocket_test.go | 65 +++++++- 3 files changed, 431 insertions(+), 15 deletions(-) diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go index 7bf53d840c..52870800cc 100644 --- a/src/pkg/websocket/client.go +++ b/src/pkg/websocket/client.go @@ -5,11 +5,18 @@ package websocket import ( + "encoding/binary" "bufio" + "bytes" + "container/vector" + "crypto/md5" + "fmt" "http" "io" "net" "os" + "rand" + "strings" ) type ProtocolError struct { @@ -26,10 +33,26 @@ var ( ErrBadWebSocketLocation = &ProtocolError{"bad WebSocket-Location"} ErrNoWebSocketProtocol = &ProtocolError{"no WebSocket-Protocol"} ErrBadWebSocketProtocol = &ProtocolError{"bad WebSocket-Protocol"} + ErrChallengeResponse = &ProtocolError{"mismatch challange/response"} + secKeyRandomChars [0x30 - 0x21 + 0x7F - 0x3A]byte ) +func init() { + i := 0 + for ch := byte(0x21); ch < 0x30; ch++ { + secKeyRandomChars[i] = ch + i++ + } + for ch := byte(0x3a); ch < 0x7F; ch++ { + secKeyRandomChars[i] = ch + i++ + } +} + +type handshaker func(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) os.Error + // newClient creates a new Web Socket client connection. -func newClient(resourceName, host, origin, location, protocol string, rwc io.ReadWriteCloser) (ws *Conn, err os.Error) { +func newClient(resourceName, host, origin, location, protocol string, rwc io.ReadWriteCloser, handshake handshaker) (ws *Conn, err os.Error) { br := bufio.NewReader(rwc) bw := bufio.NewWriter(rwc) err = handshake(resourceName, host, origin, location, protocol, br, bw) @@ -76,10 +99,212 @@ func Dial(url, protocol, origin string) (ws *Conn, err os.Error) { if err != nil { return } - return newClient(parsedUrl.Path, parsedUrl.Host, origin, url, protocol, client) + return newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake) +} + +/* + Generates handshake key as described in 4.1 Opening handshake + step 16 to 22. + cf. http://www.whatwg.org/specs/web-socket-protocol/ +*/ +func generateKeyNumber() (key string, number uint32) { + // 16. Let /spaces_n/ be a random integer from 1 to 12 inclusive. + spaces := rand.Intn(12) + 1 + + // 17. Let /max_n/ be the largest integer not greater than + // 4,294,967,295 divided by /spaces_n/ + max := int(4294967295 / uint32(spaces)) + + // 18. Let /number_n/ be a random integer from 0 to /max_n/ inclusive. + number = uint32(rand.Intn(max + 1)) + + // 19. Let /product_n/ be the result of multiplying /number_n/ and + // /spaces_n/ together. + product := number * uint32(spaces) + + // 20. Let /key_n/ be a string consisting of /product_n/, expressed + // in base ten using the numerals in the range U+0030 DIGIT ZERO (0) + // to U+0039 DIGIT NINE (9). + key = fmt.Sprintf("%d", product) + + // 21. Insert /spaces_n/ U+0020 SPACE characters into /key_n/ at random + // posisions. + for i := 0; i < spaces; i++ { + pos := rand.Intn(len(key)-1) + 1 + key = key[0:pos] + " " + key[pos:] + } + + // 22. Insert between one and twelve random characters from the ranges + // U+0021 to U+002F and U+003A to U+007E into /key_n/ at random + // positions. + n := rand.Intn(12) + 1 + for i := 0; i < n; i++ { + pos := rand.Intn(len(key)) + 1 + ch := secKeyRandomChars[rand.Intn(len(secKeyRandomChars))] + key = key[0:pos] + string(ch) + key[pos:] + } + return +} + +/* + Generates handshake key_3 as described in 4.1 Opening handshake + step 26. + cf. http://www.whatwg.org/specs/web-socket-protocol/ +*/ +func generateKey3() (key []byte) { + // 26. Let /key3/ be a string consisting of eight random bytes (or + // equivalently, a random 64 bit integer encoded in big-endian order). + key = make([]byte, 8) + for i := 0; i < 8; i++ { + key[i] = byte(rand.Intn(256)) + } + return } +/* + Gets expected from challenge as described in 4.1 Opening handshake + Step 42 to 43. + cf. http://www.whatwg.org/specs/web-socket-protocol/ +*/ +func getExpectedForChallenge(number1, number2 uint32, key3 []byte) (expected []byte, err os.Error) { + // 41. Let /challenge/ be the concatenation of /number_1/, expressed + // a big-endian 32 bit integer, /number_2/, expressed in a big- + // endian 32 bit integer, and the eight bytes of /key_3/ in the + // order they were sent to the wire. + challenge := make([]byte, 16) + challengeBuf := bytes.NewBuffer(challenge) + binary.Write(challengeBuf, binary.BigEndian, number1) + binary.Write(challengeBuf, binary.BigEndian, number2) + copy(challenge[8:], key3) + + // 42. Let /expected/ be the MD5 fingerprint of /challenge/ as a big- + // endian 128 bit string. + h := md5.New() + if _, err = h.Write(challenge); err != nil { + return + } + expected = h.Sum() + return +} + +/* + Web Socket protocol handshake based on + http://www.whatwg.org/specs/web-socket-protocol/ + (draft of http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol) +*/ func handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { + // 4.1. Opening handshake. + // Step 5. send a request line. + bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n") + + // Step 6-14. push request headers in fields. + var fields vector.StringVector + fields.Push("Upgrade: WebSocket\r\n") + fields.Push("Connection: Upgrade\r\n") + fields.Push("Host: " + host + "\r\n") + fields.Push("Origin: " + origin + "\r\n") + if protocol != "" { + fields.Push("Sec-WebSocket-Protocol: " + protocol + "\r\n") + } + // TODO(ukai): Step 15. send cookie if any. + + // Step 16-23. generate keys and push Sec-WebSocket-Key in fields. + key1, number1 := generateKeyNumber() + key2, number2 := generateKeyNumber() + fields.Push("Sec-WebSocket-Key1: " + key1 + "\r\n") + fields.Push("Sec-WebSocket-Key2: " + key2 + "\r\n") + + // Step 24. shuffle fields and send them out. + for i := 1; i < len(fields); i++ { + j := rand.Intn(i) + fields[i], fields[j] = fields[j], fields[i] + } + for i := 0; i < len(fields); i++ { + bw.WriteString(fields[i]) + } + // Step 25. send CRLF. + bw.WriteString("\r\n") + + // Step 26. genearte 8 bytes random key. + key3 := generateKey3() + // Step 27. send it out. + bw.Write(key3) + if err = bw.Flush(); err != nil { + return + } + + // Step 28-29, 32-40. read response from server. + resp, err := http.ReadResponse(br, "GET") + if err != nil { + return err + } + // Step 30. check response code is 101. + if resp.StatusCode != 101 { + return ErrBadStatus + } + + // Step 41. check websocket headers. + upgrade, found := resp.Header["Upgrade"] + if !found { + return ErrNoUpgrade + } + if upgrade != "WebSocket" { + return ErrBadUpgrade + } + connection, found := resp.Header["Connection"] + if !found || strings.ToLower(connection) != "upgrade" { + return ErrBadUpgrade + } + + s, found := resp.Header["Sec-Websocket-Origin"] + if !found { + return ErrNoWebSocketOrigin + } + if s != origin { + return ErrBadWebSocketOrigin + } + s, found = resp.Header["Sec-Websocket-Location"] + if !found { + return ErrNoWebSocketLocation + } + if s != location { + return ErrBadWebSocketLocation + } + if protocol != "" { + s, found = resp.Header["Sec-Websocket-Protocol"] + if !found { + return ErrNoWebSocketProtocol + } + if s != protocol { + return ErrBadWebSocketProtocol + } + } + + // Step 42-43. get expected data from challange data. + expected, err := getExpectedForChallenge(number1, number2, key3) + if err != nil { + return err + } + + // Step 44. read 16 bytes from server. + reply := make([]byte, 16) + if _, err = io.ReadFull(br, reply); err != nil { + return err + } + + // Step 45. check the reply equals to expected data. + if !bytes.Equal(expected, reply) { + return ErrChallengeResponse + } + // WebSocket connection is established. + return +} + +/* + Handhake described in (soon obsolete) + draft-hixie-thewebsocket-protocol-75. +*/ +func draft75handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) { bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n") bw.WriteString("Upgrade: WebSocket\r\n") bw.WriteString("Connection: Upgrade\r\n") diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 93d8b7afd2..78c42990a0 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -5,8 +5,12 @@ package websocket import ( + "bytes" + "crypto/md5" + "encoding/binary" "http" "io" + "strings" ) /* @@ -36,25 +40,159 @@ import ( */ type Handler func(*Conn) -// ServeHTTP implements the http.Handler interface for a Web Socket. +/* + Gets key number from Sec-WebSocket-Key: field as described + in 5.2 Sending the server's opening handshake, 4. +*/ +func getKeyNumber(s string) (r uint32) { + // 4. Let /key-number_n/ be the digits (characters in the range + // U+0030 DIGIT ZERO (0) to U+0039 DIGIT NINE (9)) in /key_1/, + // interpreted as a base ten integer, ignoring all other characters + // in /key_n/. + r = 0 + for i := 0; i < len(s); i++ { + if s[i] >= '0' && s[i] <= '9' { + r = r*10 + uint32(s[i]) - '0' + } + } + return +} + +// ServeHTTP implements the http.Handler interface for a Web Socket func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { + rwc, buf, err := c.Hijack() + if err != nil { + panic("Hijack failed: ", err.String()) + return + } + // The server should abort the WebSocket connection if it finds + // the client did not send a handshake that matches with protocol + // specification. + defer rwc.Close() + + if req.Method != "GET" { + return + } + // HTTP version can be safely ignored. + + if v, found := req.Header["Upgrade"]; !found || + strings.ToLower(v) != "websocket" { + return + } + if v, found := req.Header["Connection"]; !found || + strings.ToLower(v) != "upgrade" { + return + } + // TODO(ukai): check Host + origin, found := req.Header["Origin"] + if !found { + return + } + + key1, found := req.Header["Sec-Websocket-Key1"] + if !found { + return + } + key2, found := req.Header["Sec-Websocket-Key2"] + if !found { + return + } + key3 := make([]byte, 8) + if _, err := io.ReadFull(buf, key3); err != nil { + return + } + + location := "ws://" + req.Host + req.URL.RawPath + + // Step 4. get key number in Sec-WebSocket-Key fields. + keyNumber1 := getKeyNumber(key1) + keyNumber2 := getKeyNumber(key2) + + // Step 5. get number of spaces in Sec-WebSocket-Key fields. + space1 := uint32(strings.Count(key1, " ")) + space2 := uint32(strings.Count(key2, " ")) + if space1 == 0 || space2 == 0 { + return + } + + // Step 6. key number must be an integral multiple of spaces. + if keyNumber1%space1 != 0 || keyNumber2%space2 != 0 { + return + } + + // Step 7. let part be key number divided by spaces. + part1 := keyNumber1 / space1 + part2 := keyNumber2 / space2 + + // Step 8. let challenge to be concatination of part1, part2 and key3. + challenge := make([]byte, 16) + challengeBuf := bytes.NewBuffer(challenge) + err = binary.Write(challengeBuf, binary.BigEndian, part1) + if err != nil { + return + } + err = binary.Write(challengeBuf, binary.BigEndian, part2) + if err != nil { + return + } + if n := copy(challenge[8:], key3); n != 8 { + return + } + // Step 9. get MD5 fingerprint of challenge. + h := md5.New() + if _, err = h.Write(challenge); err != nil { + return + } + response := h.Sum() + + // Step 10. send response status line. + buf.WriteString("HTTP/1.1 101 WebSocket Protocol Handshake\r\n") + // Step 11. send response headers. + buf.WriteString("Upgrade: WebSocket\r\n") + buf.WriteString("Connection: Upgrade\r\n") + buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n") + buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n") + protocol, found := req.Header["Sec-WebSocket-Protocol"] + if found { + buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n") + } + // Step 12. send CRLF. + buf.WriteString("\r\n") + // Step 13. send response data. + buf.Write(response) + if err := buf.Flush(); err != nil { + return + } + ws := newConn(origin, location, protocol, buf, rwc) + f(ws) +} + + +/* + Draft75Handler is an interface to a WebSocket based on + (soon obsolete) draft-hixie-thewebsocketprotocol-75. +*/ +type Draft75Handler func(*Conn) + +// ServeHTTP implements the http.Handler interface for a Web Socket. +func (f Draft75Handler) ServeHTTP(c *http.Conn, req *http.Request) { if req.Method != "GET" || req.Proto != "HTTP/1.1" { c.WriteHeader(http.StatusBadRequest) io.WriteString(c, "Unexpected request") return } - if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" { + if v, found := req.Header["Upgrade"]; !found || v != "WebSocket" { c.WriteHeader(http.StatusBadRequest) io.WriteString(c, "missing Upgrade: WebSocket header") return } - if v, present := req.Header["Connection"]; !present || v != "Upgrade" { + if v, found := req.Header["Connection"]; !found || v != "Upgrade" { c.WriteHeader(http.StatusBadRequest) io.WriteString(c, "missing Connection: Upgrade header") return } - origin, present := req.Header["Origin"] - if !present { + origin, found := req.Header["Origin"] + if !found { c.WriteHeader(http.StatusBadRequest) io.WriteString(c, "missing Origin header") return @@ -75,9 +213,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { buf.WriteString("Connection: Upgrade\r\n") buf.WriteString("WebSocket-Origin: " + origin + "\r\n") buf.WriteString("WebSocket-Location: " + location + "\r\n") - protocol, present := req.Header["Websocket-Protocol"] + protocol, found := req.Header["Websocket-Protocol"] // canonical header key of WebSocket-Protocol. - if present { + if found { buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n") } buf.WriteString("\r\n") diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go index 92582b1ef2..58065580e7 100644 --- a/src/pkg/websocket/websocket_test.go +++ b/src/pkg/websocket/websocket_test.go @@ -27,23 +27,56 @@ func startServer() { serverAddr = l.Addr().String() log.Stderr("Test WebSocket server listening on ", serverAddr) http.Handle("/echo", Handler(echoServer)) + http.Handle("/echoDraft75", Draft75Handler(echoServer)) go http.Serve(l, nil) } func TestEcho(t *testing.T) { once.Do(startServer) + // websocket.Dial() client, err := net.Dial("tcp", "", serverAddr) if err != nil { t.Fatal("dialing", err) } - ws, err := newClient("/echo", "localhost", "http://localhost", - "ws://localhost/echo", "", client) + "ws://localhost/echo", "", client, handshake) + if err != nil { + t.Errorf("WebSocket handshake error", err) + return + } + + msg := []byte("hello, world\n") + if _, err := ws.Write(msg); err != nil { + t.Errorf("Write: error %v", err) + } + var actual_msg = make([]byte, 512) + n, err := ws.Read(actual_msg) + if err != nil { + t.Errorf("Read: error %v", err) + } + actual_msg = actual_msg[0:n] + if !bytes.Equal(msg, actual_msg) { + t.Errorf("Echo: expected %q got %q", msg, actual_msg) + } + ws.Close() +} + +func TestEchoDraft75(t *testing.T) { + once.Do(startServer) + + // websocket.Dial() + client, err := net.Dial("tcp", "", serverAddr) + if err != nil { + t.Fatal("dialing", err) + } + ws, err := newClient("/echoDraft75", "localhost", "http://localhost", + "ws://localhost/echoDraft75", "", client, draft75handshake) if err != nil { t.Errorf("WebSocket handshake error", err) return } + msg := []byte("hello, world\n") if _, err := ws.Write(msg); err != nil { t.Errorf("Write: error %v", err) @@ -69,7 +102,7 @@ func TestWithQuery(t *testing.T) { } ws, err := newClient("/echo?q=v", "localhost", "http://localhost", - "ws://localhost/echo?q=v", "", client) + "ws://localhost/echo?q=v", "", client, handshake) if err != nil { t.Errorf("WebSocket handshake error", err) return @@ -80,13 +113,33 @@ func TestWithQuery(t *testing.T) { func TestHTTP(t *testing.T) { once.Do(startServer) - r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) + // If the client did not send a handshake that matches the protocol + // specification, the server should abort the WebSocket connection. + _, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) + if err == nil { + t.Errorf("Get: unexpected success") + return + } + urlerr, ok := err.(*http.URLError) + if !ok { + t.Errorf("Get: not URLError %#v", err) + return + } + if urlerr.Error != io.ErrUnexpectedEOF { + t.Errorf("Get: error %#v", err) + return + } +} + +func TestHTTPDraft75(t *testing.T) { + once.Do(startServer) + + r, _, err := http.Get(fmt.Sprintf("http://%s/echoDraft75", serverAddr)) if err != nil { - t.Errorf("Get: error %v", err) + t.Errorf("Get: error %#v", err) return } if r.StatusCode != http.StatusBadRequest { t.Errorf("Get: got status %d", r.StatusCode) - return } } -- 2.48.1