From 50e0fb4c791af190a93419519395db3c97fa7c65 Mon Sep 17 00:00:00 2001 From: Jukka-Pekka Kekkonen Date: Sat, 11 Sep 2010 00:27:16 -0400 Subject: [PATCH] websocket: Add support for secure WebSockets Fixes #842. Fixes #1061. R=rsc CC=golang-dev https://golang.org/cl/2119042 --- src/pkg/http/server.go | 7 +++++++ src/pkg/websocket/client.go | 41 +++++++++++++++++++++++++++++++++---- src/pkg/websocket/server.go | 15 ++++++++++++-- 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/src/pkg/http/server.go b/src/pkg/http/server.go index 2de0748677..c7fd942134 100644 --- a/src/pkg/http/server.go +++ b/src/pkg/http/server.go @@ -63,6 +63,7 @@ type Conn struct { header map[string]string // reply header parameters written int64 // number of bytes written in body status int // status code passed to WriteHeader + usingTLS bool // a flag indicating connection over TLS } // Create new connection from rwc. @@ -73,6 +74,7 @@ func newConn(rwc net.Conn, handler Handler) (c *Conn, err os.Error) { } c.handler = handler c.rwc = rwc + _, c.usingTLS = rwc.(*tls.Conn) br := bufio.NewReader(rwc) bw := bufio.NewWriter(rwc) c.buf = bufio.NewReadWriter(br, bw) @@ -151,6 +153,11 @@ func (c *Conn) readRequest() (req *Request, err os.Error) { return req, nil } +// UsingTLS returns true if the connection uses transport layer security (TLS). +func (c *Conn) UsingTLS() bool { + return c.usingTLS +} + // SetHeader sets a header line in the eventual reply. // For example, SetHeader("Content-Type", "text/html; charset=utf-8") // will result in the header line diff --git a/src/pkg/websocket/client.go b/src/pkg/websocket/client.go index a82a8804d3..caf63f16f6 100644 --- a/src/pkg/websocket/client.go +++ b/src/pkg/websocket/client.go @@ -8,6 +8,7 @@ import ( "bufio" "bytes" "container/vector" + "crypto/tls" "fmt" "http" "io" @@ -22,6 +23,7 @@ type ProtocolError struct { } var ( + ErrBadScheme = os.ErrorString("bad scheme") ErrBadStatus = &ProtocolError{"bad status"} ErrBadUpgrade = &ProtocolError{"missing or bad upgrade"} ErrBadWebSocketOrigin = &ProtocolError{"missing or bad WebSocket-Origin"} @@ -31,6 +33,17 @@ var ( secKeyRandomChars [0x30 - 0x21 + 0x7F - 0x3A]byte ) +type DialError struct { + URL string + Protocol string + Origin string + Error os.Error +} + +func (e *DialError) String() string { + return "websocket.Dial " + e.URL + ": " + e.Error.String() +} + func init() { i := 0 for ch := byte(0x21); ch < 0x30; ch++ { @@ -86,15 +99,35 @@ A trivial example client: } */ func Dial(url, protocol, origin string) (ws *Conn, err os.Error) { + var client net.Conn + parsedUrl, err := http.ParseURL(url) if err != nil { - return + goto Error + } + + switch parsedUrl.Scheme { + case "ws": + client, err = net.Dial("tcp", "", parsedUrl.Host) + + case "wss": + client, err = tls.Dial("tcp", "", parsedUrl.Host) + + default: + err = ErrBadScheme } - client, err := net.Dial("tcp", "", parsedUrl.Host) if err != nil { - return + goto Error + } + + ws, err = newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake) + if err != nil { + goto Error } - return newClient(parsedUrl.RawPath, parsedUrl.Host, origin, url, protocol, client, handshake) + return + +Error: + return nil, &DialError{url, protocol, origin, err} } /* diff --git a/src/pkg/websocket/server.go b/src/pkg/websocket/server.go index 6f33a9abed..b884884fa5 100644 --- a/src/pkg/websocket/server.go +++ b/src/pkg/websocket/server.go @@ -97,7 +97,12 @@ func (f Handler) ServeHTTP(c *http.Conn, req *http.Request) { return } - location := "ws://" + req.Host + req.URL.RawPath + var location string + if c.UsingTLS() { + location = "wss://" + req.Host + req.URL.RawPath + } else { + location = "ws://" + req.Host + req.URL.RawPath + } // Step 4. get key number in Sec-WebSocket-Key fields. keyNumber1 := getKeyNumber(key1) @@ -185,7 +190,13 @@ func (f Draft75Handler) ServeHTTP(c *http.Conn, req *http.Request) { return } defer rwc.Close() - location := "ws://" + req.Host + req.URL.RawPath + + var location string + if c.UsingTLS() { + location = "wss://" + req.Host + req.URL.RawPath + } else { + location = "ws://" + req.Host + req.URL.RawPath + } // TODO(ukai): verify origin,location,protocol. -- 2.50.0