// String returns the network address for a Web Socket.
func (addr WebSocketAddr) String() string { return string(addr) }
+const (
+ stateFrameByte = iota
+ stateFrameLength
+ stateFrameData
+ stateFrameTextData
+)
+
// Conn is a channel to communicate to a Web Socket.
// It implements the net.Conn interface.
type Conn struct {
buf *bufio.ReadWriter
rwc io.ReadWriteCloser
+
+ // It holds text data in previous Read() that failed with small buffer.
+ data []byte
+ reading bool
}
// newConn creates a new Web Socket.
bw := bufio.NewWriter(rwc)
buf = bufio.NewReadWriter(br, bw)
}
- ws := &Conn{origin, location, protocol, buf, rwc}
+ ws := &Conn{Origin: origin, Location: location, Protocol: protocol, buf: buf, rwc: rwc}
return ws
}
// Read implements the io.Reader interface for a Conn.
func (ws *Conn) Read(msg []byte) (n int, err os.Error) {
- for {
- frameByte, err := ws.buf.ReadByte()
+Frame:
+ for !ws.reading && len(ws.data) == 0 {
+ // Beginning of frame, possibly.
+ b, err := ws.buf.ReadByte()
if err != nil {
- return n, err
+ return 0, err
}
- if (frameByte & 0x80) == 0x80 {
+ if b&0x80 == 0x80 {
+ // Skip length frame.
length := 0
for {
c, err := ws.buf.ReadByte()
if err != nil {
- return n, err
+ return 0, err
}
length = length*128 + int(c&0x7f)
- if (c & 0x80) == 0 {
+ if c&0x80 == 0 {
break
}
}
for length > 0 {
_, err := ws.buf.ReadByte()
if err != nil {
- return n, err
+ return 0, err
}
- length--
}
- } else {
+ continue Frame
+ }
+ // In text mode
+ if b != 0 {
+ // Skip this frame
for {
c, err := ws.buf.ReadByte()
if err != nil {
- return n, err
+ return 0, err
}
if c == '\xff' {
- return n, err
- }
- if frameByte == 0 {
- if n+1 <= cap(msg) {
- msg = msg[0 : n+1]
- }
- msg[n] = c
- n++
- }
- if n >= cap(msg) {
- return n, os.E2BIG
+ break
}
}
+ continue Frame
}
+ ws.reading = true
}
-
- panic("unreachable")
+ if len(ws.data) == 0 {
+ ws.data, err = ws.buf.ReadSlice('\xff')
+ if err == nil {
+ ws.reading = false
+ ws.data = ws.data[:len(ws.data)-1] // trim \xff
+ }
+ }
+ n = copy(msg, ws.data)
+ ws.data = ws.data[n:]
+ return n, err
}
// Write implements the io.Writer interface for a Conn.
package websocket
import (
+ "bufio"
"bytes"
"fmt"
"http"
}
}
}
+
+func TestSmallBuffer(t *testing.T) {
+ // http://code.google.com/p/go/issues/detail?id=1145
+ // Read should be able to handle reading a fragment of a frame.
+ once.Do(startServer)
+
+ // websocket.Dial()
+ client, err := net.Dial("tcp", "", serverAddr)
+ if err != nil {
+ t.Fatal("dialing", err)
+ }
+ ws, err := newClient("/echo", "localhost", "http://localhost",
+ "ws://localhost/echo", "", client, handshake)
+ if err != nil {
+ t.Errorf("WebSocket handshake error: %v", err)
+ return
+ }
+
+ msg := []byte("hello, world\n")
+ if _, err := ws.Write(msg); err != nil {
+ t.Errorf("Write: %v", err)
+ }
+ var small_msg = make([]byte, 8)
+ n, err := ws.Read(small_msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ if !bytes.Equal(msg[:len(small_msg)], small_msg) {
+ t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
+ }
+ var second_msg = make([]byte, len(msg))
+ n, err = ws.Read(second_msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ second_msg = second_msg[0:n]
+ if !bytes.Equal(msg[len(small_msg):], second_msg) {
+ t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
+ }
+ ws.Close()
+
+}
+
+func testSkipLengthFrame(t *testing.T) {
+ b := []byte{'\x80', '\x01', 'x', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
+ buf := bytes.NewBuffer(b)
+ br := bufio.NewReader(buf)
+ bw := bufio.NewWriter(buf)
+ ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
+ msg := make([]byte, 5)
+ n, err := ws.Read(msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ if !bytes.Equal(b[4:8], msg[0:n]) {
+ t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
+ }
+}
+
+func testSkipNoUTF8Frame(t *testing.T) {
+ b := []byte{'\x01', 'n', '\xff', 0, 'h', 'e', 'l', 'l', 'o', '\xff'}
+ buf := bytes.NewBuffer(b)
+ br := bufio.NewReader(buf)
+ bw := bufio.NewWriter(buf)
+ ws := newConn("http://127.0.0.1/", "ws://127.0.0.1/", "", bufio.NewReadWriter(br, bw), nil)
+ msg := make([]byte, 5)
+ n, err := ws.Read(msg)
+ if err != nil {
+ t.Errorf("Read: %v", err)
+ }
+ if !bytes.Equal(b[4:8], msg[0:n]) {
+ t.Errorf("Read: expected %q got %q", msg[4:8], msg[0:n])
+ }
+}