]> Cypherpunks repositories - gostls13.git/commitdiff
websocket: implement new protocol
authorFumitoshi Ukai <ukai@google.com>
Wed, 24 Mar 2010 01:09:24 +0000 (18:09 -0700)
committerRuss Cox <rsc@golang.org>
Wed, 24 Mar 2010 01:09:24 +0000 (18:09 -0700)
http://www.whatwg.org/specs/web-socket-protocol/
(draft of draft-hixie-thewebsocketprotocol-76)

draft-hixie-thewebsocketprotocol-76 will introduce new handshake
incompatible draft 75 or prior.
http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol

R=rsc
CC=golang-dev
https://golang.org/cl/583041

src/pkg/websocket/client.go
src/pkg/websocket/server.go
src/pkg/websocket/websocket_test.go

index 7bf53d840c341c59bf810a897d9c05549d001095..52870800cc4e696dad3e76fff6c4b133d8d48b01 100644 (file)
@@ -5,11 +5,18 @@
 package websocket
 
 import (
+       "encoding/binary"
        "bufio"
+       "bytes"
+       "container/vector"
+       "crypto/md5"
+       "fmt"
        "http"
        "io"
        "net"
        "os"
+       "rand"
+       "strings"
 )
 
 type ProtocolError struct {
@@ -26,10 +33,26 @@ var (
        ErrBadWebSocketLocation = &ProtocolError{"bad WebSocket-Location"}
        ErrNoWebSocketProtocol  = &ProtocolError{"no WebSocket-Protocol"}
        ErrBadWebSocketProtocol = &ProtocolError{"bad WebSocket-Protocol"}
+       ErrChallengeResponse    = &ProtocolError{"mismatch challange/response"}
+       secKeyRandomChars       [0x30 - 0x21 + 0x7F - 0x3A]byte
 )
 
+func init() {
+       i := 0
+       for ch := byte(0x21); ch < 0x30; ch++ {
+               secKeyRandomChars[i] = ch
+               i++
+       }
+       for ch := byte(0x3a); ch < 0x7F; ch++ {
+               secKeyRandomChars[i] = ch
+               i++
+       }
+}
+
+type handshaker func(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) os.Error
+
 // newClient creates a new Web Socket client connection.
-func newClient(resourceName, host, origin, location, protocol string, rwc io.ReadWriteCloser) (ws *Conn, err os.Error) {
+func newClient(resourceName, host, origin, location, protocol string, rwc io.ReadWriteCloser, handshake handshaker) (ws *Conn, err os.Error) {
        br := bufio.NewReader(rwc)
        bw := bufio.NewWriter(rwc)
        err = handshake(resourceName, host, origin, location, protocol, br, bw)
@@ -76,10 +99,212 @@ func Dial(url, protocol, origin string) (ws *Conn, err os.Error) {
        if err != nil {
                return
        }
-       return newClient(parsedUrl.Path, parsedUrl.Host, origin, url, protocol, client)
+       return newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake)
+}
+
+/*
+       Generates handshake key as described in 4.1 Opening handshake
+        step 16 to 22.
+       cf. http://www.whatwg.org/specs/web-socket-protocol/
+*/
+func generateKeyNumber() (key string, number uint32) {
+       // 16.  Let /spaces_n/ be a random integer from 1 to 12 inclusive.
+       spaces := rand.Intn(12) + 1
+
+       // 17. Let /max_n/ be the largest integer not greater than
+       //     4,294,967,295 divided by /spaces_n/
+       max := int(4294967295 / uint32(spaces))
+
+       // 18. Let /number_n/ be a random integer from 0 to /max_n/ inclusive.
+       number = uint32(rand.Intn(max + 1))
+
+       // 19. Let /product_n/ be the result of multiplying /number_n/ and
+       //     /spaces_n/ together.
+       product := number * uint32(spaces)
+
+       // 20. Let /key_n/ be a string consisting of /product_n/, expressed
+       // in base ten using the numerals in the range U+0030 DIGIT ZERO (0)
+       // to U+0039 DIGIT NINE (9).
+       key = fmt.Sprintf("%d", product)
+
+       // 21. Insert /spaces_n/ U+0020 SPACE characters into /key_n/ at random
+       //     posisions.
+       for i := 0; i < spaces; i++ {
+               pos := rand.Intn(len(key)-1) + 1
+               key = key[0:pos] + " " + key[pos:]
+       }
+
+       // 22. Insert between one and twelve random characters from the ranges
+       //     U+0021 to U+002F and U+003A to U+007E into /key_n/ at random
+       //     positions.
+       n := rand.Intn(12) + 1
+       for i := 0; i < n; i++ {
+               pos := rand.Intn(len(key)) + 1
+               ch := secKeyRandomChars[rand.Intn(len(secKeyRandomChars))]
+               key = key[0:pos] + string(ch) + key[pos:]
+       }
+       return
+}
+
+/*
+       Generates handshake key_3 as described in 4.1 Opening handshake
+        step 26.
+       cf. http://www.whatwg.org/specs/web-socket-protocol/
+*/
+func generateKey3() (key []byte) {
+       // 26. Let /key3/ be a string consisting of eight random bytes (or
+       //  equivalently, a random 64 bit integer encoded in big-endian order).
+       key = make([]byte, 8)
+       for i := 0; i < 8; i++ {
+               key[i] = byte(rand.Intn(256))
+       }
+       return
 }
 
+/*
+       Gets expected from challenge as described in 4.1 Opening handshake
+        Step 42 to 43.
+       cf. http://www.whatwg.org/specs/web-socket-protocol/
+*/
+func getExpectedForChallenge(number1, number2 uint32, key3 []byte) (expected []byte, err os.Error) {
+       // 41. Let /challenge/ be the concatenation of /number_1/, expressed
+       // a big-endian 32 bit integer, /number_2/, expressed in a big-
+       // endian 32 bit integer, and the eight bytes of /key_3/ in the
+       // order they were sent to the wire.
+       challenge := make([]byte, 16)
+       challengeBuf := bytes.NewBuffer(challenge)
+       binary.Write(challengeBuf, binary.BigEndian, number1)
+       binary.Write(challengeBuf, binary.BigEndian, number2)
+       copy(challenge[8:], key3)
+
+       // 42. Let /expected/ be the MD5 fingerprint of /challenge/ as a big-
+       // endian 128 bit string.
+       h := md5.New()
+       if _, err = h.Write(challenge); err != nil {
+               return
+       }
+       expected = h.Sum()
+       return
+}
+
+/*
+       Web Socket protocol handshake based on
+       http://www.whatwg.org/specs/web-socket-protocol/
+        (draft of http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol)
+*/
 func handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) {
+       // 4.1. Opening handshake.
+       // Step 5.  send a request line.
+       bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n")
+
+       // Step 6-14. push request headers in fields.
+       var fields vector.StringVector
+       fields.Push("Upgrade: WebSocket\r\n")
+       fields.Push("Connection: Upgrade\r\n")
+       fields.Push("Host: " + host + "\r\n")
+       fields.Push("Origin: " + origin + "\r\n")
+       if protocol != "" {
+               fields.Push("Sec-WebSocket-Protocol: " + protocol + "\r\n")
+       }
+       // TODO(ukai): Step 15. send cookie if any.
+
+       // Step 16-23. generate keys and push Sec-WebSocket-Key<n> in fields.
+       key1, number1 := generateKeyNumber()
+       key2, number2 := generateKeyNumber()
+       fields.Push("Sec-WebSocket-Key1: " + key1 + "\r\n")
+       fields.Push("Sec-WebSocket-Key2: " + key2 + "\r\n")
+
+       // Step 24. shuffle fields and send them out.
+       for i := 1; i < len(fields); i++ {
+               j := rand.Intn(i)
+               fields[i], fields[j] = fields[j], fields[i]
+       }
+       for i := 0; i < len(fields); i++ {
+               bw.WriteString(fields[i])
+       }
+       // Step 25. send CRLF.
+       bw.WriteString("\r\n")
+
+       // Step 26. genearte 8 bytes random key.
+       key3 := generateKey3()
+       // Step 27. send it out.
+       bw.Write(key3)
+       if err = bw.Flush(); err != nil {
+               return
+       }
+
+       // Step 28-29, 32-40. read response from server.
+       resp, err := http.ReadResponse(br, "GET")
+       if err != nil {
+               return err
+       }
+       // Step 30. check response code is 101.
+       if resp.StatusCode != 101 {
+               return ErrBadStatus
+       }
+
+       // Step 41. check websocket headers.
+       upgrade, found := resp.Header["Upgrade"]
+       if !found {
+               return ErrNoUpgrade
+       }
+       if upgrade != "WebSocket" {
+               return ErrBadUpgrade
+       }
+       connection, found := resp.Header["Connection"]
+       if !found || strings.ToLower(connection) != "upgrade" {
+               return ErrBadUpgrade
+       }
+
+       s, found := resp.Header["Sec-Websocket-Origin"]
+       if !found {
+               return ErrNoWebSocketOrigin
+       }
+       if s != origin {
+               return ErrBadWebSocketOrigin
+       }
+       s, found = resp.Header["Sec-Websocket-Location"]
+       if !found {
+               return ErrNoWebSocketLocation
+       }
+       if s != location {
+               return ErrBadWebSocketLocation
+       }
+       if protocol != "" {
+               s, found = resp.Header["Sec-Websocket-Protocol"]
+               if !found {
+                       return ErrNoWebSocketProtocol
+               }
+               if s != protocol {
+                       return ErrBadWebSocketProtocol
+               }
+       }
+
+       // Step 42-43. get expected data from challange data.
+       expected, err := getExpectedForChallenge(number1, number2, key3)
+       if err != nil {
+               return err
+       }
+
+       // Step 44. read 16 bytes from server.
+       reply := make([]byte, 16)
+       if _, err = io.ReadFull(br, reply); err != nil {
+               return err
+       }
+
+       // Step 45. check the reply equals to expected data.
+       if !bytes.Equal(expected, reply) {
+               return ErrChallengeResponse
+       }
+       // WebSocket connection is established.
+       return
+}
+
+/*
+       Handhake described in (soon obsolete)
+       draft-hixie-thewebsocket-protocol-75.
+*/
+func draft75handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) {
        bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n")
        bw.WriteString("Upgrade: WebSocket\r\n")
        bw.WriteString("Connection: Upgrade\r\n")
index 93d8b7afd287ac2da02fa2153b9be738862a4b04..78c42990a0e2089476240bc8e9d3f75ff397dd0d 100644 (file)
@@ -5,8 +5,12 @@
 package websocket
 
 import (
+       "bytes"
+       "crypto/md5"
+       "encoding/binary"
        "http"
        "io"
+       "strings"
 )
 
 /*
@@ -36,25 +40,159 @@ import (
 */
 type Handler func(*Conn)
 
-// ServeHTTP implements the http.Handler interface for a Web Socket.
+/*
+       Gets key number from Sec-WebSocket-Key<n>: field as described
+       in 5.2 Sending the server's opening handshake, 4.
+*/
+func getKeyNumber(s string) (r uint32) {
+       // 4. Let /key-number_n/ be the digits (characters in the range
+       // U+0030 DIGIT ZERO (0) to U+0039 DIGIT NINE (9)) in /key_1/,
+       // interpreted as a base ten integer, ignoring all other characters
+       // in /key_n/.
+       r = 0
+       for i := 0; i < len(s); i++ {
+               if s[i] >= '0' && s[i] <= '9' {
+                       r = r*10 + uint32(s[i]) - '0'
+               }
+       }
+       return
+}
+
+// ServeHTTP implements the http.Handler interface for a Web Socket
 func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
+       rwc, buf, err := c.Hijack()
+       if err != nil {
+               panic("Hijack failed: ", err.String())
+               return
+       }
+       // The server should abort the WebSocket connection if it finds
+       // the client did not send a handshake that matches with protocol
+       // specification.
+       defer rwc.Close()
+
+       if req.Method != "GET" {
+               return
+       }
+       // HTTP version can be safely ignored.
+
+       if v, found := req.Header["Upgrade"]; !found ||
+               strings.ToLower(v) != "websocket" {
+               return
+       }
+       if v, found := req.Header["Connection"]; !found ||
+               strings.ToLower(v) != "upgrade" {
+               return
+       }
+       // TODO(ukai): check Host
+       origin, found := req.Header["Origin"]
+       if !found {
+               return
+       }
+
+       key1, found := req.Header["Sec-Websocket-Key1"]
+       if !found {
+               return
+       }
+       key2, found := req.Header["Sec-Websocket-Key2"]
+       if !found {
+               return
+       }
+       key3 := make([]byte, 8)
+       if _, err := io.ReadFull(buf, key3); err != nil {
+               return
+       }
+
+       location := "ws://" + req.Host + req.URL.RawPath
+
+       // Step 4. get key number in Sec-WebSocket-Key<n> fields.
+       keyNumber1 := getKeyNumber(key1)
+       keyNumber2 := getKeyNumber(key2)
+
+       // Step 5. get number of spaces in Sec-WebSocket-Key<n> fields.
+       space1 := uint32(strings.Count(key1, " "))
+       space2 := uint32(strings.Count(key2, " "))
+       if space1 == 0 || space2 == 0 {
+               return
+       }
+
+       // Step 6. key number must be an integral multiple of spaces.
+       if keyNumber1%space1 != 0 || keyNumber2%space2 != 0 {
+               return
+       }
+
+       // Step 7. let part be key number divided by spaces.
+       part1 := keyNumber1 / space1
+       part2 := keyNumber2 / space2
+
+       // Step 8. let challenge to be concatination of part1, part2 and key3.
+       challenge := make([]byte, 16)
+       challengeBuf := bytes.NewBuffer(challenge)
+       err = binary.Write(challengeBuf, binary.BigEndian, part1)
+       if err != nil {
+               return
+       }
+       err = binary.Write(challengeBuf, binary.BigEndian, part2)
+       if err != nil {
+               return
+       }
+       if n := copy(challenge[8:], key3); n != 8 {
+               return
+       }
+       // Step 9. get MD5 fingerprint of challenge.
+       h := md5.New()
+       if _, err = h.Write(challenge); err != nil {
+               return
+       }
+       response := h.Sum()
+
+       // Step 10. send response status line.
+       buf.WriteString("HTTP/1.1 101 WebSocket Protocol Handshake\r\n")
+       // Step 11. send response headers.
+       buf.WriteString("Upgrade: WebSocket\r\n")
+       buf.WriteString("Connection: Upgrade\r\n")
+       buf.WriteString("Sec-WebSocket-Location: " + location + "\r\n")
+       buf.WriteString("Sec-WebSocket-Origin: " + origin + "\r\n")
+       protocol, found := req.Header["Sec-WebSocket-Protocol"]
+       if found {
+               buf.WriteString("Sec-WebSocket-Protocol: " + protocol + "\r\n")
+       }
+       // Step 12. send CRLF.
+       buf.WriteString("\r\n")
+       // Step 13. send response data.
+       buf.Write(response)
+       if err := buf.Flush(); err != nil {
+               return
+       }
+       ws := newConn(origin, location, protocol, buf, rwc)
+       f(ws)
+}
+
+
+/*
+       Draft75Handler is an interface to a WebSocket based on
+        (soon obsolete) draft-hixie-thewebsocketprotocol-75.
+*/
+type Draft75Handler func(*Conn)
+
+// ServeHTTP implements the http.Handler interface for a Web Socket.
+func (f Draft75Handler) ServeHTTP(c *http.Conn, req *http.Request) {
        if req.Method != "GET" || req.Proto != "HTTP/1.1" {
                c.WriteHeader(http.StatusBadRequest)
                io.WriteString(c, "Unexpected request")
                return
        }
-       if v, present := req.Header["Upgrade"]; !present || v != "WebSocket" {
+       if v, found := req.Header["Upgrade"]; !found || v != "WebSocket" {
                c.WriteHeader(http.StatusBadRequest)
                io.WriteString(c, "missing Upgrade: WebSocket header")
                return
        }
-       if v, present := req.Header["Connection"]; !present || v != "Upgrade" {
+       if v, found := req.Header["Connection"]; !found || v != "Upgrade" {
                c.WriteHeader(http.StatusBadRequest)
                io.WriteString(c, "missing Connection: Upgrade header")
                return
        }
-       origin, present := req.Header["Origin"]
-       if !present {
+       origin, found := req.Header["Origin"]
+       if !found {
                c.WriteHeader(http.StatusBadRequest)
                io.WriteString(c, "missing Origin header")
                return
@@ -75,9 +213,9 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
        buf.WriteString("Connection: Upgrade\r\n")
        buf.WriteString("WebSocket-Origin: " + origin + "\r\n")
        buf.WriteString("WebSocket-Location: " + location + "\r\n")
-       protocol, present := req.Header["Websocket-Protocol"]
+       protocol, found := req.Header["Websocket-Protocol"]
        // canonical header key of WebSocket-Protocol.
-       if present {
+       if found {
                buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
        }
        buf.WriteString("\r\n")
index 92582b1ef2ea54c3a01242df955fed9aaf3aada4..58065580e7c1e548c0887e81c4b356582cc25260 100644 (file)
@@ -27,23 +27,56 @@ func startServer() {
        serverAddr = l.Addr().String()
        log.Stderr("Test WebSocket server listening on ", serverAddr)
        http.Handle("/echo", Handler(echoServer))
+       http.Handle("/echoDraft75", Draft75Handler(echoServer))
        go http.Serve(l, nil)
 }
 
 func TestEcho(t *testing.T) {
        once.Do(startServer)
 
+       // websocket.Dial()
        client, err := net.Dial("tcp", "", serverAddr)
        if err != nil {
                t.Fatal("dialing", err)
        }
-
        ws, err := newClient("/echo", "localhost", "http://localhost",
-               "ws://localhost/echo", "", client)
+               "ws://localhost/echo", "", client, handshake)
+       if err != nil {
+               t.Errorf("WebSocket handshake error", err)
+               return
+       }
+
+       msg := []byte("hello, world\n")
+       if _, err := ws.Write(msg); err != nil {
+               t.Errorf("Write: error %v", err)
+       }
+       var actual_msg = make([]byte, 512)
+       n, err := ws.Read(actual_msg)
+       if err != nil {
+               t.Errorf("Read: error %v", err)
+       }
+       actual_msg = actual_msg[0:n]
+       if !bytes.Equal(msg, actual_msg) {
+               t.Errorf("Echo: expected %q got %q", msg, actual_msg)
+       }
+       ws.Close()
+}
+
+func TestEchoDraft75(t *testing.T) {
+       once.Do(startServer)
+
+       // websocket.Dial()
+       client, err := net.Dial("tcp", "", serverAddr)
+       if err != nil {
+               t.Fatal("dialing", err)
+       }
+       ws, err := newClient("/echoDraft75", "localhost", "http://localhost",
+               "ws://localhost/echoDraft75", "", client, draft75handshake)
        if err != nil {
                t.Errorf("WebSocket handshake error", err)
                return
        }
+
        msg := []byte("hello, world\n")
        if _, err := ws.Write(msg); err != nil {
                t.Errorf("Write: error %v", err)
@@ -69,7 +102,7 @@ func TestWithQuery(t *testing.T) {
        }
 
        ws, err := newClient("/echo?q=v", "localhost", "http://localhost",
-               "ws://localhost/echo?q=v", "", client)
+               "ws://localhost/echo?q=v", "", client, handshake)
        if err != nil {
                t.Errorf("WebSocket handshake error", err)
                return
@@ -80,13 +113,33 @@ func TestWithQuery(t *testing.T) {
 func TestHTTP(t *testing.T) {
        once.Do(startServer)
 
-       r, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
+       // If the client did not send a handshake that matches the protocol
+       // specification, the server should abort the WebSocket connection.
+       _, _, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
+       if err == nil {
+               t.Errorf("Get: unexpected success")
+               return
+       }
+       urlerr, ok := err.(*http.URLError)
+       if !ok {
+               t.Errorf("Get: not URLError %#v", err)
+               return
+       }
+       if urlerr.Error != io.ErrUnexpectedEOF {
+               t.Errorf("Get: error %#v", err)
+               return
+       }
+}
+
+func TestHTTPDraft75(t *testing.T) {
+       once.Do(startServer)
+
+       r, _, err := http.Get(fmt.Sprintf("http://%s/echoDraft75", serverAddr))
        if err != nil {
-               t.Errorf("Get: error %v", err)
+               t.Errorf("Get: error %#v", err)
                return
        }
        if r.StatusCode != http.StatusBadRequest {
                t.Errorf("Get: got status %d", r.StatusCode)
-               return
        }
 }