]> Cypherpunks repositories - gostls13.git/commitdiff
web socket: fix short Read
authorFumitoshi Ukai <ukai@google.com>
Thu, 21 Oct 2010 02:36:06 +0000 (22:36 -0400)
committerRuss Cox <rsc@golang.org>
Thu, 21 Oct 2010 02:36:06 +0000 (22:36 -0400)
Fixes #1145.

R=rsc
CC=golang-dev
https://golang.org/cl/2302042

src/pkg/websocket/websocket.go
src/pkg/websocket/websocket_test.go

index 99e1d144854c875ca643a04f810790eac1406b23..d5996abe1a5836461b9503efbece00a31fbab28b 100644 (file)
@@ -27,6 +27,13 @@ func (addr WebSocketAddr) Network() string { return "websocket" }
 // String returns the network address for a Web Socket.
 func (addr WebSocketAddr) String() string { return string(addr) }
 
+const (
+       stateFrameByte = iota
+       stateFrameLength
+       stateFrameData
+       stateFrameTextData
+)
+
 // Conn is a channel to communicate to a Web Socket.
 // It implements the net.Conn interface.
 type Conn struct {
@@ -39,6 +46,10 @@ type Conn struct {
 
        buf *bufio.ReadWriter
        rwc io.ReadWriteCloser
+
+       // It holds text data in previous Read() that failed with small buffer.
+       data    []byte
+       reading bool
 }
 
 // newConn creates a new Web Socket.
@@ -48,60 +59,66 @@ func newConn(origin, location, protocol string, buf *bufio.ReadWriter, rwc io.Re
                bw := bufio.NewWriter(rwc)
                buf = bufio.NewReadWriter(br, bw)
        }
-       ws := &Conn{origin, location, protocol, buf, rwc}
+       ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc}
        return ws
 }
 
 // Read implements the io.Reader interface for a Conn.
 func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
-       for {
-               frameByte, err := ws.buf.ReadByte()
+Frame:
+       for !ws.reading && len(ws.data) == 0 {
+               // Beginning of frame, possibly.
+               b, err := ws.buf.ReadByte()
                if err != nil {
-                       return n, err
+                       return 0, err
                }
-               if (frameByte & 0x80) == 0x80 {
+               if b&0x80 == 0x80 {
+                       // Skip length frame.
                        length := 0
                        for {
                                c, err := ws.buf.ReadByte()
                                if err != nil {
-                                       return n, err
+                                       return 0, err
                                }
                                length = length*128 + int(c&0x7f)
-                               if (c & 0x80) == 0 {
+                               if c&0x80 == 0 {
                                        break
                                }
                        }
                        for length > 0 {
                                _, err := ws.buf.ReadByte()
                                if err != nil {
-                                       return n, err
+                                       return 0, err
                                }
-                               length--
                        }
-               } else {
+                       continue Frame
+               }
+               // In text mode
+               if b != 0 {
+                       // Skip this frame
                        for {
                                c, err := ws.buf.ReadByte()
                                if err != nil {
-                                       return n, err
+                                       return 0, err
                                }
                                if c == '\xff' {
-                                       return n, err
-                               }
-                               if frameByte == 0 {
-                                       if n+1 <= cap(msg) {
-                                               msg = msg[0 : n+1]
-                                       }
-                                       msg[n] = c
-                                       n++
-                               }
-                               if n >= cap(msg) {
-                                       return n, os.E2BIG
+                                       break
                                }
                        }
+                       continue Frame
                }
+               ws.reading = true
        }
-
-       panic("unreachable")
+       if len(ws.data) == 0 {
+               ws.data, err = ws.buf.ReadSlice('\xff')
+               if err == nil {
+                       ws.reading = false
+                       ws.data = ws.data[:len(ws.data)-1] // trim \xff
+               }
+       }
+       n = copy(msg, ws.data)
+       ws.data = ws.data[n:]
+       return n, err
 }
 
 // Write implements the io.Writer interface for a Conn.
index 9639d8f88b1b4ff829f150eb7336dd5e2f2654e7..c66c114589d95135ff3ac000a178fe518b4a71d0 100644 (file)
@@ -5,6 +5,7 @@
 package websocket
 
 import (
+       "bufio"
        "bytes"
        "fmt"
        "http"
@@ -195,3 +196,77 @@ func TestTrailingSpaces(t *testing.T) {
                }
        }
 }
+
+func TestSmallBuffer(t *testing.T) {
+       // http://code.google.com/p/go/issues/detail?id=1145
+       // Read should be able to handle reading a fragment of a frame.
+       once.Do(startServer)
+
+       // websocket.Dial()
+       client, err := net.Dial("tcp", "", serverAddr)
+       if err != nil {
+               t.Fatal("dialing", err)
+       }
+       ws, err := newClient("/echo", "localhost", "http://localhost",
+               "ws://localhost/echo", "", client, handshake)
+       if err != nil {
+               t.Errorf("WebSocket handshake error: %v", err)
+               return
+       }
+
+       msg := []byte("hello, world\n")
+       if _, err := ws.Write(msg); err != nil {
+               t.Errorf("Write: %v", err)
+       }
+       var small_msg = make([]byte, 8)
+       n, err := ws.Read(small_msg)
+       if err != nil {
+               t.Errorf("Read: %v", err)
+       }
+       if !bytes.Equal(msg[:len(small_msg)], small_msg) {
+               t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
+       }
+       var second_msg = make([]byte, len(msg))
+       n, err = ws.Read(second_msg)
+       if err != nil {
+               t.Errorf("Read: %v", err)
+       }
+       second_msg = second_msg[0:n]
+       if !bytes.Equal(msg[len(small_msg):], second_msg) {
+               t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
+       }
+       ws.Close()
+
+}
+
+func testSkipLengthFrame(t *testing.T) {
+       b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
+       buf := bytes.NewBuffer(b)
+       br := bufio.NewReader(buf)
+       bw := bufio.NewWriter(buf)
+       ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
+       msg := make([]byte, 5)
+       n, err := ws.Read(msg)
+       if err != nil {
+               t.Errorf("Read: %v", err)
+       }
+       if !bytes.Equal(b[4:8], msg[0:n]) {
+               t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
+       }
+}
+
+func testSkipNoUTF8Frame(t *testing.T) {
+       b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
+       buf := bytes.NewBuffer(b)
+       br := bufio.NewReader(buf)
+       bw := bufio.NewWriter(buf)
+       ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
+       msg := make([]byte, 5)
+       n, err := ws.Read(msg)
+       if err != nil {
+               t.Errorf("Read: %v", err)
+       }
+       if !bytes.Equal(b[4:8], msg[0:n]) {
+               t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
+       }
+}