]> Cypherpunks repositories - gostls13.git/commitdiff
Add WebSocket server framework hooked into http.
authorFumitoshi Ukai <ukai@google.com>
Sun, 29 Nov 2009 22:22:44 +0000 (14:22 -0800)
committerRuss Cox <rsc@golang.org>
Sun, 29 Nov 2009 22:22:44 +0000 (14:22 -0800)
R=r, rsc
https://golang.org/cl/156071

src/pkg/Makefile
src/pkg/websocket/Makefile [new file with mode: 0644]
src/pkg/websocket/client.go [new file with mode: 0644]
src/pkg/websocket/server.go [new file with mode: 0644]
src/pkg/websocket/websocket.go [new file with mode: 0644]
src/pkg/websocket/websocket_test.go [new file with mode: 0644]

index 2b5e76c40af032dc469fd6177b0285576ca87c75..9386a8f5aa6dbccd52293199f83463481d404280 100644 (file)
@@ -98,6 +98,7 @@ DIRS=\
        time\
        unicode\
        utf8\
+       websocket\
        xml\
 
 NOTEST=\
diff --git a/src/pkg/websocket/Makefile b/src/pkg/websocket/Makefile
new file mode 100644 (file)
index 0000000..ba1b726
--- /dev/null
@@ -0,0 +1,9 @@
+include $(GOROOT)/src/Make.$(GOARCH)
+
+TARG=websocket
+GOFILES=\
+       client.go\
+       server.go\
+       websocket.go\
+
+include $(GOROOT)/src/Make.pkg
diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go
new file mode 100644 (file)
index 0000000..bedaec0
--- /dev/null
@@ -0,0 +1,136 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+       "bufio";
+       "http";
+       "io";
+       "net";
+       "os";
+)
+
+type ProtocolError struct {
+       os.ErrorString;
+}
+
+var (
+       ErrBadStatus            = &ProtocolError{"bad status"};
+       ErrNoUpgrade            = &ProtocolError{"no upgrade"};
+       ErrBadUpgrade           = &ProtocolError{"bad upgrade"};
+       ErrNoWebSocketOrigin    = &ProtocolError{"no WebSocket-Origin"};
+       ErrBadWebSocketOrigin   = &ProtocolError{"bad WebSocket-Origin"};
+       ErrNoWebSocketLocation  = &ProtocolError{"no WebSocket-Location"};
+       ErrBadWebSocketLocation = &ProtocolError{"bad WebSocket-Location"};
+       ErrNoWebSocketProtocol  = &ProtocolError{"no WebSocket-Protocol"};
+       ErrBadWebSocketProtocol = &ProtocolError{"bad WebSocket-Protocol"};
+)
+
+// newClient creates a new Web Socket client connection.
+func newClient(resourceName, host, origin, location, protocol string, rwc io.ReadWriteCloser) (ws *Conn, err os.Error) {
+       br := bufio.NewReader(rwc);
+       bw := bufio.NewWriter(rwc);
+       err = handshake(resourceName, host, origin, location, protocol, br, bw);
+       if err != nil {
+               return
+       }
+       buf := bufio.NewReadWriter(br, bw);
+       ws = newConn(origin, location, protocol, buf, rwc);
+       return;
+}
+
+// Dial opens new Web Socket client connection.
+//
+// A trivial example client is:
+//
+// package main
+//
+// import (
+//  "websocket"
+//  "strings"
+// )
+//
+// func main() {
+//    ws, err := websocket.Dial("ws://localhost/ws", "", "http://localhost/");
+//    if err != nil {
+//        panic("Dial: ", err.String())
+//    }
+//    if _, err := ws.Write(strings.Bytes("hello, world!\n")); err != nil {
+//        panic("Write: ", err.String())
+//    }
+//    var msg = make([]byte, 512);
+//    if n, err := ws.Read(msg); err != nil {
+//        panic("Read: ", err.String())
+//    }
+//    // msg[0:n]
+// }
+
+func Dial(url, protocol, origin string) (ws *Conn, err os.Error) {
+       parsedUrl, err := http.ParseURL(url);
+       if err != nil {
+               return
+       }
+       client, err := net.Dial("tcp", "", parsedUrl.Host);
+       if err != nil {
+               return
+       }
+       return newClient(parsedUrl.Path, parsedUrl.Host, origin, url, protocol, client);
+}
+
+func handshake(resourceName, host, origin, location, protocol string, br *bufio.Reader, bw *bufio.Writer) (err os.Error) {
+       bw.WriteString("GET " + resourceName + " HTTP/1.1\r\n");
+       bw.WriteString("Upgrade: WebSocket\r\n");
+       bw.WriteString("Connection: Upgrade\r\n");
+       bw.WriteString("Host: " + host + "\r\n");
+       bw.WriteString("Origin: " + origin + "\r\n");
+       if protocol != "" {
+               bw.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
+       }
+       bw.WriteString("\r\n");
+       bw.Flush();
+       resp, err := http.ReadResponse(br);
+       if err != nil {
+               return
+       }
+       if resp.Status != "101 Web Socket Protocol Handshake" {
+               return ErrBadStatus
+       }
+       upgrade, found := resp.Header["Upgrade"];
+       if !found {
+               return ErrNoUpgrade
+       }
+       if upgrade != "WebSocket" {
+               return ErrBadUpgrade
+       }
+       connection, found := resp.Header["Connection"];
+       if !found || connection != "Upgrade" {
+               return ErrBadUpgrade
+       }
+
+       ws_origin, found := resp.Header["Websocket-Origin"];
+       if !found {
+               return ErrNoWebSocketOrigin
+       }
+       if ws_origin != origin {
+               return ErrBadWebSocketOrigin
+       }
+       ws_location, found := resp.Header["Websocket-Location"];
+       if !found {
+               return ErrNoWebSocketLocation
+       }
+       if ws_location != location {
+               return ErrBadWebSocketLocation
+       }
+       if protocol != "" {
+               ws_protocol, found := resp.Header["Websocket-Protocol"];
+               if !found {
+                       return ErrNoWebSocketProtocol
+               }
+               if ws_protocol != protocol {
+                       return ErrBadWebSocketProtocol
+               }
+       }
+       return;
+}
diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go
new file mode 100644 (file)
index 0000000..e2d8cec
--- /dev/null
@@ -0,0 +1,73 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+       "http";
+       "io";
+)
+
+// Handler is a interface that use a WebSocket.
+//
+// A trivial example server is:
+//
+//  package main
+//
+//  import (
+//     "http"
+//     "io"
+//     "websocket"
+//  )
+//
+//  // echo back the websocket.
+//  func EchoServer(ws *websocket.Conn) {
+//       io.Copy(ws, ws);
+//  }
+//
+//  func main() {
+//    http.Handle("/echo", websocket.Handler(EchoServer));
+//    err := http.ListenAndServe(":12345", nil);
+//    if err != nil {
+//        panic("ListenAndServe: ", err.String())
+//    }
+//  }
+type Handler func(*Conn)
+
+func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
+       if req.Method != "GET" || req.Proto != "HTTP/1.1" ||
+               req.Header["Upgrade"] != "WebSocket" ||
+               req.Header["Connection"] != "Upgrade" {
+               c.WriteHeader(http.StatusNotFound);
+               io.WriteString(c, "must use websocket to connect here");
+               return;
+       }
+       rwc, buf, err := c.Hijack();
+       if err != nil {
+               panic("Hijack failed: ", err.String());
+               return;
+       }
+       defer rwc.Close();
+       origin := req.Header["Origin"];
+       location := "ws://" + req.Host + req.URL.Path;
+
+       // TODO(ukai): verify origin,location,protocol.
+
+       buf.WriteString("HTTP/1.1 101 Web Socket Protocol Handshake\r\n");
+       buf.WriteString("Upgrade: WebSocket\r\n");
+       buf.WriteString("Connection: Upgrade\r\n");
+       buf.WriteString("WebSocket-Origin: " + origin + "\r\n");
+       buf.WriteString("WebSocket-Location: " + location + "\r\n");
+       protocol := "";
+       // canonical header key of WebSocket-Protocol.
+       if protocol, found := req.Header["Websocket-Protocol"]; found {
+               buf.WriteString("WebSocket-Protocol: " + protocol + "\r\n")
+       }
+       buf.WriteString("\r\n");
+       if err := buf.Flush(); err != nil {
+               return
+       }
+       ws := newConn(origin, location, protocol, buf, rwc);
+       f(ws);
+}
diff --git a/src/pkg/websocket/websocket.go b/src/pkg/websocket/websocket.go
new file mode 100644 (file)
index 0000000..0fd32cd
--- /dev/null
@@ -0,0 +1,138 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// The websocket package implements Web Socket protocol server.
+package websocket
+
+// References:
+//   The Web Socket protocol: http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol
+
+// TODO(ukai):
+//   better logging.
+
+import (
+       "bufio";
+       "io";
+       "net";
+       "os";
+)
+
+type WebSocketAddr string
+
+func (addr WebSocketAddr) Network() string     { return "websocket" }
+
+func (addr WebSocketAddr) String() string      { return string(addr) }
+
+// Conn is an channels to communicate over Web Socket.
+type Conn struct {
+       // An origin URI of the Web Socket.
+       Origin  string;
+       // A location URI of the Web Socket.
+       Location        string;
+       // A subprotocol of the Web Socket.
+       Protocol        string;
+
+       buf     *bufio.ReadWriter;
+       rwc     io.ReadWriteCloser;
+}
+
+// newConn creates a new Web Socket.
+func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
+       if buf == nil {
+               br := bufio.NewReader(rwc);
+               bw := bufio.NewWriter(rwc);
+               buf = bufio.NewReadWriter(br, bw);
+       }
+       ws := &Conn{origin, location, protocol, buf, rwc};
+       return ws;
+}
+
+func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
+       for {
+               frameByte, err := ws.buf.ReadByte();
+               if err != nil {
+                       return
+               }
+               if (frameByte & 0x80) == 0x80 {
+                       length := 0;
+                       for {
+                               c, err := ws.buf.ReadByte();
+                               if err != nil {
+                                       return
+                               }
+                               if (c & 0x80) == 0x80 {
+                                       length = length*128 + int(c&0x7f)
+                               } else {
+                                       break
+                               }
+                       }
+                       for length > 0 {
+                               _, err := ws.buf.ReadByte();
+                               if err != nil {
+                                       return
+                               }
+                               length--;
+                       }
+               } else {
+                       for {
+                               c, err := ws.buf.ReadByte();
+                               if err != nil {
+                                       return
+                               }
+                               if c == '\xff' {
+                                       return
+                               }
+                               if frameByte == 0 {
+                                       if n+1 <= cap(msg) {
+                                               msg = msg[0 : n+1]
+                                       }
+                                       msg[n] = c;
+                                       n++;
+                               }
+                               if n >= cap(msg) {
+                                       err = os.E2BIG;
+                                       return;
+                               }
+                       }
+               }
+       }
+       return;
+}
+
+func (ws *Conn) Write(msg []byte) (n int, err os.Error) {
+       ws.buf.WriteByte(0);
+       ws.buf.Write(msg);
+       ws.buf.WriteByte(0xff);
+       err = ws.buf.Flush();
+       return len(msg), err;
+}
+
+func (ws *Conn) Close() os.Error       { return ws.rwc.Close() }
+
+func (ws *Conn) LocalAddr() net.Addr   { return WebSocketAddr(ws.Origin) }
+
+func (ws *Conn) RemoteAddr() net.Addr  { return WebSocketAddr(ws.Location) }
+
+func (ws *Conn) SetTimeout(nsec int64) os.Error {
+       if conn, ok := ws.rwc.(net.Conn); ok {
+               return conn.SetTimeout(nsec)
+       }
+       return os.EINVAL;
+}
+
+func (ws *Conn) SetReadTimeout(nsec int64) os.Error {
+       if conn, ok := ws.rwc.(net.Conn); ok {
+               return conn.SetReadTimeout(nsec)
+       }
+       return os.EINVAL;
+}
+
+func (ws *Conn) SetWriteTimeout(nsec int64) os.Error {
+       if conn, ok := ws.rwc.(net.Conn); ok {
+               return conn.SetWriteTimeout(nsec)
+       }
+       return os.EINVAL;
+}
+
+var _ net.Conn = (*Conn)(nil)  // compile-time check that *Conn implements net.Conn.
diff --git a/src/pkg/websocket/websocket_test.go b/src/pkg/websocket/websocket_test.go
new file mode 100644 (file)
index 0000000..ed25053
--- /dev/null
@@ -0,0 +1,61 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package websocket
+
+import (
+       "bytes";
+       "http";
+       "io";
+       "log";
+       "net";
+       "once";
+       "strings";
+       "testing";
+)
+
+var serverAddr string
+
+func echoServer(ws *Conn)      { io.Copy(ws, ws) }
+
+func startServer() {
+       l, e := net.Listen("tcp", ":0");        // any available address
+       if e != nil {
+               log.Exitf("net.Listen tcp :0 %v", e)
+       }
+       serverAddr = l.Addr().String();
+       log.Stderr("Test WebSocket server listening on ", serverAddr);
+       http.Handle("/echo", Handler(echoServer));
+       go http.Serve(l, nil);
+}
+
+func TestEcho(t *testing.T) {
+       once.Do(startServer);
+
+       client, err := net.Dial("tcp", "", serverAddr);
+       if err != nil {
+               t.Fatal("dialing", err)
+       }
+
+       ws, err := newClient("/echo", "localhost", "http://localhost",
+               "ws://localhost/echo", "", client);
+       if err != nil {
+               t.Errorf("WebSocket handshake error", err);
+               return;
+       }
+       msg := strings.Bytes("hello, world\n");
+       if _, err := ws.Write(msg); err != nil {
+               t.Errorf("Write: error %v", err)
+       }
+       var actual_msg = make([]byte, 512);
+       n, err := ws.Read(actual_msg);
+       if err != nil {
+               t.Errorf("Read: error %v", err)
+       }
+       actual_msg = actual_msg[0:n];
+       if !bytes.Equal(msg, actual_msg) {
+               t.Errorf("Echo: expected %q got %q", msg, actual_msg)
+       }
+       ws.Close();
+}