]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: add experimental ssh client
authorDave Cheney <dave@cheney.net>
Thu, 20 Oct 2011 19:44:45 +0000 (15:44 -0400)
committerAdam Langley <agl@golang.org>
Thu, 20 Oct 2011 19:44:45 +0000 (15:44 -0400)
Requires CL 5285044

client.go:
* add Dial, ClientConn, ClientChan, ClientConfig and Cmd.

doc.go:
* add Client documentation.

server.go:
* adjust for readVersion change.

transport.go:
* return an os.Error not a bool from readVersion.

R=rsc, agl, n13m3y3r
CC=golang-dev
https://golang.org/cl/5162047

src/pkg/exp/ssh/Makefile
src/pkg/exp/ssh/client.go [new file with mode: 0644]
src/pkg/exp/ssh/doc.go
src/pkg/exp/ssh/server.go
src/pkg/exp/ssh/transport.go
src/pkg/exp/ssh/transport_test.go

index 1a100e9b69bb830f701be0d8ef662c287010894f..1084d029db403fa671481a284ee2d09a6052f353 100644 (file)
@@ -7,6 +7,7 @@ include ../../../Make.inc
 TARG=exp/ssh
 GOFILES=\
        channel.go\
+       client.go\
        common.go\
        messages.go\
        transport.go\
diff --git a/src/pkg/exp/ssh/client.go b/src/pkg/exp/ssh/client.go
new file mode 100644 (file)
index 0000000..edb95ec
--- /dev/null
@@ -0,0 +1,628 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "big"
+       "crypto"
+       "crypto/rand"
+       "encoding/binary"
+       "fmt"
+       "io"
+       "os"
+       "net"
+       "sync"
+)
+
+// clientVersion is the fixed identification string that the client will use.
+var clientVersion = []byte("SSH-2.0-Go\r\n")
+
+// ClientConn represents the client side of an SSH connection.
+type ClientConn struct {
+       *transport
+       config *ClientConfig
+       chanlist
+}
+
+// Client returns a new SSH client connection using c as the underlying transport.
+func Client(c net.Conn, config *ClientConfig) (*ClientConn, os.Error) {
+       conn := &ClientConn{
+               transport: newTransport(c, config.rand()),
+               config:    config,
+               chanlist: chanlist{
+                       Mutex: new(sync.Mutex),
+                       chans: make(map[uint32]*ClientChan),
+               },
+       }
+       if err := conn.handshake(); err != nil {
+               conn.Close()
+               return nil, err
+       }
+       if err := conn.authenticate(); err != nil {
+               conn.Close()
+               return nil, err
+       }
+       go conn.mainLoop()
+       return conn, nil
+}
+
+// handshake performs the client side key exchange. See RFC 4253 Section 7.
+func (c *ClientConn) handshake() os.Error {
+       var magics handshakeMagics
+
+       if _, err := c.Write(clientVersion); err != nil {
+               return err
+       }
+       if err := c.Flush(); err != nil {
+               return err
+       }
+       magics.clientVersion = clientVersion[:len(clientVersion)-2]
+
+       // read remote server version
+       version, err := readVersion(c)
+       if err != nil {
+               return err
+       }
+       magics.serverVersion = version
+       clientKexInit := kexInitMsg{
+               KexAlgos:                supportedKexAlgos,
+               ServerHostKeyAlgos:      supportedHostKeyAlgos,
+               CiphersClientServer:     supportedCiphers,
+               CiphersServerClient:     supportedCiphers,
+               MACsClientServer:        supportedMACs,
+               MACsServerClient:        supportedMACs,
+               CompressionClientServer: supportedCompressions,
+               CompressionServerClient: supportedCompressions,
+       }
+       kexInitPacket := marshal(msgKexInit, clientKexInit)
+       magics.clientKexInit = kexInitPacket
+
+       if err := c.writePacket(kexInitPacket); err != nil {
+               return err
+       }
+       packet, err := c.readPacket()
+       if err != nil {
+               return err
+       }
+
+       magics.serverKexInit = packet
+
+       var serverKexInit kexInitMsg
+       if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil {
+               return err
+       }
+
+       kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
+       if !ok {
+               return os.NewError("ssh: no common algorithms")
+       }
+
+       if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
+               // The server sent a Kex message for the wrong algorithm,
+               // which we have to ignore.
+               if _, err := c.readPacket(); err != nil {
+                       return err
+               }
+       }
+
+       var H, K []byte
+       var hashFunc crypto.Hash
+       switch kexAlgo {
+       case kexAlgoDH14SHA1:
+               hashFunc = crypto.SHA1
+               dhGroup14Once.Do(initDHGroup14)
+               H, K, err = c.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
+       default:
+               fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
+       }
+       if err != nil {
+               return err
+       }
+
+       if err = c.writePacket([]byte{msgNewKeys}); err != nil {
+               return err
+       }
+       if err = c.transport.writer.setupKeys(clientKeys, K, H, H, hashFunc); err != nil {
+               return err
+       }
+       if packet, err = c.readPacket(); err != nil {
+               return err
+       }
+       if packet[0] != msgNewKeys {
+               return UnexpectedMessageError{msgNewKeys, packet[0]}
+       }
+       return c.transport.reader.setupKeys(serverKeys, K, H, H, hashFunc)
+}
+
+// authenticate authenticates with the remote server. See RFC 4252. 
+// Only "password" authentication is supported.
+func (c *ClientConn) authenticate() os.Error {
+       if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
+               return err
+       }
+       packet, err := c.readPacket()
+       if err != nil {
+               return err
+       }
+
+       var serviceAccept serviceAcceptMsg
+       if err = unmarshal(&serviceAccept, packet, msgServiceAccept); err != nil {
+               return err
+       }
+
+       // TODO(dfc) support proper authentication method negotation
+       method := "none"
+       if c.config.Password != "" {
+               method = "password"
+       }
+       if err := c.sendUserAuthReq(method); err != nil {
+               return err
+       }
+
+       if packet, err = c.readPacket(); err != nil {
+               return err
+       }
+
+       if packet[0] != msgUserAuthSuccess {
+               return UnexpectedMessageError{msgUserAuthSuccess, packet[0]}
+       }
+       return nil
+}
+
+func (c *ClientConn) sendUserAuthReq(method string) os.Error {
+       length := stringLength([]byte(c.config.Password)) + 1
+       payload := make([]byte, length)
+       // always false for password auth, see RFC 4252 Section 8.
+       payload[0] = 0
+       marshalString(payload[1:], []byte(c.config.Password))
+
+       return c.writePacket(marshal(msgUserAuthRequest, userAuthRequestMsg{
+               User:    c.config.User,
+               Service: serviceSSH,
+               Method:  method,
+               Payload: payload,
+       }))
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ClientConn. The
+// returned values are given the same names as in RFC 4253, section 8.
+func (c *ClientConn) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) ([]byte, []byte, os.Error) {
+       x, err := rand.Int(c.config.rand(), group.p)
+       if err != nil {
+               return nil, nil, err
+       }
+       X := new(big.Int).Exp(group.g, x, group.p)
+       kexDHInit := kexDHInitMsg{
+               X: X,
+       }
+       if err := c.writePacket(marshal(msgKexDHInit, kexDHInit)); err != nil {
+               return nil, nil, err
+       }
+
+       packet, err := c.readPacket()
+       if err != nil {
+               return nil, nil, err
+       }
+
+       var kexDHReply = new(kexDHReplyMsg)
+       if err = unmarshal(kexDHReply, packet, msgKexDHReply); err != nil {
+               return nil, nil, err
+       }
+
+       if kexDHReply.Y.Sign() == 0 || kexDHReply.Y.Cmp(group.p) >= 0 {
+               return nil, nil, os.NewError("server DH parameter out of bounds")
+       }
+
+       kInt := new(big.Int).Exp(kexDHReply.Y, x, group.p)
+       h := hashFunc.New()
+       writeString(h, magics.clientVersion)
+       writeString(h, magics.serverVersion)
+       writeString(h, magics.clientKexInit)
+       writeString(h, magics.serverKexInit)
+       writeString(h, kexDHReply.HostKey)
+       writeInt(h, X)
+       writeInt(h, kexDHReply.Y)
+       K := make([]byte, intLength(kInt))
+       marshalInt(K, kInt)
+       h.Write(K)
+
+       H := h.Sum()
+
+       return H, K, nil
+}
+
+// OpenChan opens a new client channel. The most common session type is "session". 
+// The full set of valid session types are listed in RFC 4250 4.9.1.
+func (c *ClientConn) OpenChan(typ string) (*ClientChan, os.Error) {
+       ch, id := c.newChan(c.transport)
+       if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
+               ChanType:      typ,
+               PeersId:       id,
+               PeersWindow:   8192,
+               MaxPacketSize: 16384,
+       })); err != nil {
+               // remove channel reference
+               c.chanlist.remove(id)
+               return nil, err
+       }
+       // wait for response
+       switch msg := (<-ch.msg).(type) {
+       case *channelOpenConfirmMsg:
+               ch.peersId = msg.MyId
+       case *channelOpenFailureMsg:
+               c.chanlist.remove(id)
+               return nil, os.NewError(msg.Message)
+       default:
+               c.chanlist.remove(id)
+               return nil, os.NewError("Unexpected packet")
+       }
+       return ch, nil
+}
+
+// mainloop reads incoming messages and routes channel messages
+// to their respective ClientChans.
+func (c *ClientConn) mainLoop() {
+       for {
+               packet, err := c.readPacket()
+               if err != nil {
+                       // TODO(dfc) signal the underlying close to all channels
+                       c.Close()
+                       return
+               }
+               switch msg := decode(packet).(type) {
+               case *channelOpenMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *channelOpenConfirmMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *channelOpenFailureMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *channelCloseMsg:
+                       ch := c.getChan(msg.PeersId)
+                       close(ch.stdinWriter.win)
+                       close(ch.stdoutReader.data)
+                       close(ch.stderrReader.dataExt)
+                       c.chanlist.remove(msg.PeersId)
+               case *channelEOFMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *channelRequestSuccessMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *channelRequestFailureMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *channelRequestMsg:
+                       c.getChan(msg.PeersId).msg <- msg
+               case *windowAdjustMsg:
+                       c.getChan(msg.PeersId).stdinWriter.win <- int(msg.AdditionalBytes)
+               case *channelData:
+                       c.getChan(msg.PeersId).stdoutReader.data <- msg.Payload
+               case *channelExtendedData:
+                       // TODO(dfc) should this send be non blocking. RFC 4254 5.2 suggests
+                       // ext data consumes window size, does that need to be handled as well ?
+                       c.getChan(msg.PeersId).stderrReader.dataExt <- msg.Data
+               default:
+                       fmt.Printf("mainLoop: unhandled %#v\n", msg)
+               }
+       }
+}
+
+// Dial connects to the given network address using net.Dial and 
+// then initiates a SSH handshake, returning the resulting client connection.
+func Dial(network, addr string, config *ClientConfig) (*ClientConn, os.Error) {
+       conn, err := net.Dial(network, addr)
+       if err != nil {
+               return nil, err
+       }
+       return Client(conn, config)
+}
+
+// A ClientConfig structure is used to configure a ClientConn. After one has 
+// been passed to an SSH function it must not be modified.
+type ClientConfig struct {
+       // Rand provides the source of entropy for key exchange. If Rand is 
+       // nil, the cryptographic random reader in package crypto/rand will 
+       // be used.
+       Rand io.Reader
+
+       // The username to authenticate.
+       User string
+
+       // Used for "password" method authentication.
+       Password string
+}
+
+func (c *ClientConfig) rand() io.Reader {
+       if c.Rand == nil {
+               return rand.Reader
+       }
+       return c.Rand
+}
+
+// A ClientChan represents a single RFC 4254 channel that is multiplexed 
+// over a single SSH connection.
+type ClientChan struct {
+       packetWriter
+       *stdinWriter  // used by Exec and Shell
+       *stdoutReader // used by Exec and Shell
+       *stderrReader // used by Exec and Shell
+       id, peersId   uint32
+       msg           chan interface{} // incoming messages 
+}
+
+func newClientChan(t *transport, id uint32) *ClientChan {
+       // TODO(DFC) allocating stdin/out/err on ClientChan creation is
+       // wasteful, but ClientConn.mainLoop() needs a way of finding 
+       // those channels before Exec/Shell is called because the remote 
+       // may send window adjustments at any time.
+       return &ClientChan{
+               packetWriter: t,
+               stdinWriter: &stdinWriter{
+                       packetWriter: t,
+                       id:           id,
+                       win:          make(chan int, 16),
+               },
+               stdoutReader: &stdoutReader{
+                       packetWriter: t,
+                       id:           id,
+                       win:          8192,
+                       data:         make(chan []byte, 16),
+               },
+               stderrReader: &stderrReader{
+                       dataExt: make(chan string, 16),
+               },
+               id:  id,
+               msg: make(chan interface{}, 16),
+       }
+}
+
+// Close closes the channel. This does not close the underlying connection.
+func (c *ClientChan) Close() os.Error {
+       return c.writePacket(marshal(msgChannelClose, channelCloseMsg{
+               PeersId: c.id,
+       }))
+}
+
+// Setenv sets an environment variable that will be applied to any
+// command executed by Shell or Exec.
+func (c *ClientChan) Setenv(name, value string) os.Error {
+       namLen := stringLength([]byte(name))
+       valLen := stringLength([]byte(value))
+       payload := make([]byte, namLen+valLen)
+       marshalString(payload[:namLen], []byte(name))
+       marshalString(payload[namLen:], []byte(value))
+
+       return c.sendChanReq(channelRequestMsg{
+               PeersId:             c.id,
+               Request:             "env",
+               WantReply:           true,
+               RequestSpecificData: payload,
+       })
+}
+
+func (c *ClientChan) sendChanReq(req channelRequestMsg) os.Error {
+       if err := c.writePacket(marshal(msgChannelRequest, req)); err != nil {
+               return err
+       }
+       for {
+               switch msg := (<-c.msg).(type) {
+               case *channelRequestSuccessMsg:
+                       return nil
+               case *channelRequestFailureMsg:
+                       return os.NewError(req.Request)
+               default:
+                       return fmt.Errorf("%#v", msg)
+               }
+       }
+       panic("unreachable")
+}
+
+// An empty mode list (a string of 1 character, opcode 0), see RFC 4254 Section 8.
+var emptyModeList = []byte{0, 0, 0, 1, 0}
+
+// RequstPty requests a pty to be allocated on the remote side of this channel.
+func (c *ClientChan) RequestPty(term string, h, w int) os.Error {
+       buf := make([]byte, 4+len(term)+16+len(emptyModeList))
+       b := marshalString(buf, []byte(term))
+       binary.BigEndian.PutUint32(b, uint32(h))
+       binary.BigEndian.PutUint32(b[4:], uint32(w))
+       binary.BigEndian.PutUint32(b[8:], uint32(h*8))
+       binary.BigEndian.PutUint32(b[12:], uint32(w*8))
+       copy(b[16:], emptyModeList)
+
+       return c.sendChanReq(channelRequestMsg{
+               PeersId:             c.id,
+               Request:             "pty-req",
+               WantReply:           true,
+               RequestSpecificData: buf,
+       })
+}
+
+// Exec runs cmd on the remote host.
+// Typically, the remote server passes cmd to the shell for interpretation.
+func (c *ClientChan) Exec(cmd string) (*Cmd, os.Error) {
+       cmdLen := stringLength([]byte(cmd))
+       payload := make([]byte, cmdLen)
+       marshalString(payload, []byte(cmd))
+       err := c.sendChanReq(channelRequestMsg{
+               PeersId:             c.id,
+               Request:             "exec",
+               WantReply:           true,
+               RequestSpecificData: payload,
+       })
+       return &Cmd{
+               c.stdinWriter,
+               c.stdoutReader,
+               c.stderrReader,
+       }, err
+}
+
+// Shell starts a login shell on the remote host.
+func (c *ClientChan) Shell() (*Cmd, os.Error) {
+       err := c.sendChanReq(channelRequestMsg{
+               PeersId:   c.id,
+               Request:   "shell",
+               WantReply: true,
+       })
+       return &Cmd{
+               c.stdinWriter,
+               c.stdoutReader,
+               c.stderrReader,
+       }, err
+
+}
+
+// Thread safe channel list.
+type chanlist struct {
+       *sync.Mutex
+       // TODO(dfc) should could be converted to a slice
+       chans map[uint32]*ClientChan
+}
+
+// Allocate a new ClientChan with the next avail local id.
+func (c *chanlist) newChan(t *transport) (*ClientChan, uint32) {
+       c.Lock()
+       defer c.Unlock()
+
+       for i := uint32(0); i < 1<<31; i++ {
+               if _, ok := c.chans[i]; !ok {
+                       ch := newClientChan(t, i)
+                       c.chans[i] = ch
+                       return ch, uint32(i)
+               }
+       }
+       panic("unable to find free channel")
+}
+
+func (c *chanlist) getChan(id uint32) *ClientChan {
+       c.Lock()
+       defer c.Unlock()
+       return c.chans[id]
+}
+
+func (c *chanlist) remove(id uint32) {
+       c.Lock()
+       defer c.Unlock()
+       delete(c.chans, id)
+}
+
+// A Cmd represents a connection to a remote command or shell
+// Closing Cmd.Stdin will be observed by the remote process.
+type Cmd struct {
+       // Writes to Stdin are made available to the command's standard input.
+       // Closing Stdin causes the command to observe an EOF on its standard input.
+       Stdin io.WriteCloser
+
+       // Reads from Stdout consume the command's standard output.
+       // There is a fixed amount of buffering of the command's standard output.
+       // Failing to read from Stdout will eventually cause the command to block
+       // when writing to its standard output.  Closing Stdout unblocks any
+       // such writes and makes them return errors.
+       Stdout io.ReadCloser
+
+       // Reads from Stderr consume the command's standard error.
+       // The SSH protocol assumes it can always send standard error;
+       // the command will never block writing to its standard error.
+       // However, failure to read from Stderr will eventually cause the
+       // SSH protocol to jam, so it is important to arrange for reading
+       // from Stderr, even if by
+       //    go io.Copy(ioutil.Discard, cmd.Stderr)
+       Stderr io.Reader
+}
+
+// A stdinWriter represents the stdin of a remote process.
+type stdinWriter struct {
+       win          chan int // receives window adjustments
+       id           uint32
+       rwin         int // current rwin size
+       packetWriter     // for sending channelDataMsg
+}
+
+// Write writes data to the remote process's standard input.
+func (w *stdinWriter) Write(data []byte) (n int, err os.Error) {
+       for {
+               if w.rwin == 0 {
+                       win, ok := <-w.win
+                       if !ok {
+                               return 0, os.EOF
+                       }
+                       w.rwin += win
+                       continue
+               }
+               n = len(data)
+               packet := make([]byte, 0, 9+n)
+               packet = append(packet, msgChannelData,
+                       byte(w.id)>>24, byte(w.id)>>16, byte(w.id)>>8, byte(w.id),
+                       byte(n)>>24, byte(n)>>16, byte(n)>>8, byte(n))
+               err = w.writePacket(append(packet, data...))
+               w.rwin -= n
+               return
+       }
+       panic("unreachable")
+}
+
+func (w *stdinWriter) Close() os.Error {
+       return w.writePacket(marshal(msgChannelEOF, channelEOFMsg{w.id}))
+}
+
+// A stdoutReader represents the stdout of a remote process.
+type stdoutReader struct {
+       // TODO(dfc) a fixed size channel may not be the right data structure.
+       // If writes to this channel block, they will block mainLoop, making
+       // it unable to receive new messages from the remote side.
+       data         chan []byte // receives data from remote
+       id           uint32
+       win          int // current win size
+       packetWriter     // for sending windowAdjustMsg
+       buf          []byte
+}
+
+// Read reads data from the remote process's standard output.
+func (r *stdoutReader) Read(data []byte) (int, os.Error) {
+       var ok bool
+       for {
+               if len(r.buf) > 0 {
+                       n := copy(data, r.buf)
+                       r.buf = r.buf[n:]
+                       r.win += n
+                       msg := windowAdjustMsg{
+                               PeersId:         r.id,
+                               AdditionalBytes: uint32(n),
+                       }
+                       err := r.writePacket(marshal(msgChannelWindowAdjust, msg))
+                       return n, err
+               }
+               r.buf, ok = <-r.data
+               if !ok {
+                       return 0, os.EOF
+               }
+               r.win -= len(r.buf)
+       }
+       panic("unreachable")
+}
+
+func (r *stdoutReader) Close() os.Error {
+       return r.writePacket(marshal(msgChannelEOF, channelEOFMsg{r.id}))
+}
+
+// A stderrReader represents the stderr of a remote process.
+type stderrReader struct {
+       dataExt chan string // receives dataExt from remote
+       buf     []byte      // buffer current dataExt
+}
+
+// Read reads a line of data from the remote process's stderr.
+func (r *stderrReader) Read(data []byte) (int, os.Error) {
+       for {
+               if len(r.buf) > 0 {
+                       n := copy(data, r.buf)
+                       r.buf = r.buf[n:]
+                       return n, nil
+               }
+               buf, ok := <-r.dataExt
+               if !ok {
+                       return 0, os.EOF
+               }
+               r.buf = []byte(buf)
+       }
+       panic("unreachable")
+}
index 54a7ba9fdae3d7f97470fce250e5060fd39327bc..a2ec3faca74014a56b054ebde8c9ff936cc27ec1 100644 (file)
@@ -3,7 +3,7 @@
 // license that can be found in the LICENSE file.
 
 /*
-Package ssh implements an SSH server.
+Package ssh implements an SSH client and server.
 
 SSH is a transport security protocol, an authentication protocol and a
 family of application protocols. The most typical application level
@@ -75,5 +75,27 @@ present a simple terminal interface.
                }
                return
        }()
+
+An SSH client is represented with a ClientConn. Currently only the "password"
+authentication method is supported. 
+
+       config := &ClientConfig{
+               User: "username",
+               Password: "123456",
+       }
+       client, err := Dial("yourserver.com:22", config)
+
+Each ClientConn can support multiple channels, represented by ClientChan. Each
+channel should be of a type specified in rfc4250, 4.9.1.
+
+       ch, err := client.OpenChan("session")
+
+Once the ClientChan is opened, you can execute a single command on the remote side 
+using the Exec method.
+
+       cmd, err := ch.Exec("/usr/bin/whoami")
+       reader := bufio.NewReader(cmd.Stdin)
+       line, _, _ := reader.ReadLine()
+       fmt.Println(line)
 */
 package ssh
index 410cafc44c29ef95a3f7825e98aacb35f5b7d613..b5a5e017d39d6f15f7c29193d8b437df3aa3d1a3 100644 (file)
@@ -267,9 +267,9 @@ func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
        }
        magics.serverVersion = serverVersion[:len(serverVersion)-2]
 
-       version, ok := readVersion(s.transport)
-       if !ok {
-               return os.NewError("failed to read version string from client")
+       version, err := readVersion(s.transport)
+       if err != nil {
+               return err
        }
        magics.clientVersion = version
 
index 5994004d86656ba854a8e8c010bd36575751a983..97eaf975d10f306bd6ca26f2eade366edcd1ea3a 100644 (file)
@@ -332,16 +332,15 @@ func (t truncatingMAC) Size() int {
 const maxVersionStringBytes = 1024
 
 // Read version string as specified by RFC 4253, section 4.2.
-func readVersion(r io.Reader) (versionString []byte, ok bool) {
-       versionString = make([]byte, 0, 64)
-       seenCR := false
-
+func readVersion(r io.Reader) ([]byte, os.Error) {
+       versionString := make([]byte, 0, 64)
+       var ok, seenCR bool
        var buf [1]byte
 forEachByte:
        for len(versionString) < maxVersionStringBytes {
                _, err := io.ReadFull(r, buf[:])
                if err != nil {
-                       return
+                       return nil, err
                }
                b := buf[0]
 
@@ -360,10 +359,10 @@ forEachByte:
                versionString = append(versionString, b)
        }
 
-       if ok {
-               // We need to remove the CR from versionString
-               versionString = versionString[:len(versionString)-1]
+       if !ok {
+               return nil, os.NewError("failed to read version string")
        }
 
-       return
+       // We need to remove the CR from versionString
+       return versionString[:len(versionString)-1], nil
 }
index 9a610a7803c30b834718967b42a6b08c6dfce326..b2e2a7fc92aa6587a978fe637c1566492cbb788b 100644 (file)
@@ -12,9 +12,9 @@ import (
 
 func TestReadVersion(t *testing.T) {
        buf := []byte(serverVersion)
-       result, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
-       if !ok {
-               t.Error("readVersion didn't read version correctly")
+       result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
+       if err != nil {
+               t.Errorf("readVersion didn't read version correctly: %s", err)
        }
        if !bytes.Equal(buf[:len(buf)-2], result) {
                t.Error("version read did not match expected")
@@ -23,7 +23,7 @@ func TestReadVersion(t *testing.T) {
 
 func TestReadVersionTooLong(t *testing.T) {
        buf := make([]byte, maxVersionStringBytes+1)
-       if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok {
+       if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
                t.Errorf("readVersion consumed %d bytes without error", len(buf))
        }
 }
@@ -31,7 +31,7 @@ func TestReadVersionTooLong(t *testing.T) {
 func TestReadVersionWithoutCRLF(t *testing.T) {
        buf := []byte(serverVersion)
        buf = buf[:len(buf)-1]
-       if _, ok := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); ok {
+       if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
                t.Error("readVersion did not notice \\n was missing")
        }
 }