]> Cypherpunks repositories - gostls13.git/commitdiff
websocket: Add support for secure WebSockets
authorJukka-Pekka Kekkonen <karatepekka@gmail.com>
Sat, 11 Sep 2010 04:27:16 +0000 (00:27 -0400)
committerRuss Cox <rsc@golang.org>
Sat, 11 Sep 2010 04:27:16 +0000 (00:27 -0400)
Fixes #842.
Fixes #1061.

R=rsc
CC=golang-dev
https://golang.org/cl/2119042

src/pkg/http/server.go
src/pkg/websocket/client.go
src/pkg/websocket/server.go

index 2de0748677670b421196be5a0ecd7da57b0b8bd4..c7fd942134e5dce92d209b5f5bdc01bf411c410c 100644 (file)
@@ -63,6 +63,7 @@ type Conn struct {
        header          map[string]string // reply header parameters
        written         int64             // number of bytes written in body
        status          int               // status code passed to WriteHeader
+       usingTLS        bool              // a flag indicating connection over TLS
 }
 
 // Create new connection from rwc.
@@ -73,6 +74,7 @@ func newConn(rwc net.Conn, handler Handler) (c *Conn, err os.Error) {
        }
        c.handler = handler
        c.rwc = rwc
+       _, c.usingTLS = rwc.(*tls.Conn)
        br := bufio.NewReader(rwc)
        bw := bufio.NewWriter(rwc)
        c.buf = bufio.NewReadWriter(br, bw)
@@ -151,6 +153,11 @@ func (c *Conn) readRequest() (req *Request, err os.Error) {
        return req, nil
 }
 
+// UsingTLS returns true if the connection uses transport layer security (TLS).
+func (c *Conn) UsingTLS() bool {
+       return c.usingTLS
+}
+
 // SetHeader sets a header line in the eventual reply.
 // For example, SetHeader("Content-Type", "text/html; charset=utf-8")
 // will result in the header line
index a82a8804d373e7e2a39b110b4df011c2b90c0186..caf63f16f657d845a07dcc6a683c9d5cc5f53c85 100644 (file)
@@ -8,6 +8,7 @@ import (
        "bufio"
        "bytes"
        "container/vector"
+       "crypto/tls"
        "fmt"
        "http"
        "io"
@@ -22,6 +23,7 @@ type ProtocolError struct {
 }
 
 var (
+       ErrBadScheme            = os.ErrorString("bad scheme")
        ErrBadStatus            = &ProtocolError{"bad status"}
        ErrBadUpgrade           = &ProtocolError{"missing or bad upgrade"}
        ErrBadWebSocketOrigin   = &ProtocolError{"missing or bad WebSocket-Origin"}
@@ -31,6 +33,17 @@ var (
        secKeyRandomChars       [0x30 - 0x21 + 0x7F - 0x3A]byte
 )
 
+type DialError struct {
+       URL      string
+       Protocol string
+       Origin   string
+       Error    os.Error
+}
+
+func (e *DialError) String() string {
+       return "websocket.Dial " + e.URL + ": " + e.Error.String()
+}
+
 func init() {
        i := 0
        for ch := byte(0x21); ch < 0x30; ch++ {
@@ -86,15 +99,35 @@ A trivial example client:
        }
 */
 func Dial(url, protocol, origin string) (ws *Conn, err os.Error) {
+       var client net.Conn
+
        parsedUrl, err := http.ParseURL(url)
        if err != nil {
-               return
+               goto Error
+       }
+
+       switch parsedUrl.Scheme {
+       case "ws":
+               client, err = net.Dial("tcp", "", parsedUrl.Host)
+
+       case "wss":
+               client, err = tls.Dial("tcp", "", parsedUrl.Host)
+
+       default:
+               err = ErrBadScheme
        }
-       client, err := net.Dial("tcp", "", parsedUrl.Host)
        if err != nil {
-               return
+               goto Error
+       }
+
+       ws, err = newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake)
+       if err != nil {
+               goto Error
        }
-       return newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake)
+       return
+
+Error:
+       return nil, &DialError{url, protocol, origin, err}
 }
 
 /*
index 6f33a9abed540962325c661d584d4840390b65f3..b884884fa53b194e536478d135e7760e65be2be4 100644 (file)
@@ -97,7 +97,12 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) {
                return
        }
 
-       location := "ws://" + req.Host + req.URL.RawPath
+       var location string
+       if c.UsingTLS() {
+               location = "wss://" + req.Host + req.URL.RawPath
+       } else {
+               location = "ws://" + req.Host + req.URL.RawPath
+       }
 
        // Step 4. get key number in Sec-WebSocket-Key<n> fields.
        keyNumber1 := getKeyNumber(key1)
@@ -185,7 +190,13 @@ func (f Draft75Handler) ServeHTTP(c *http.Conn, req *http.Request) {
                return
        }
        defer rwc.Close()
-       location := "ws://" + req.Host + req.URL.RawPath
+
+       var location string
+       if c.UsingTLS() {
+               location = "wss://" + req.Host + req.URL.RawPath
+       } else {
+               location = "ws://" + req.Host + req.URL.RawPath
+       }
 
        // TODO(ukai): verify origin,location,protocol.