]> Cypherpunks repositories - gostls13.git/commitdiff
exp/ssh: new package.
authorAdam Langley <agl@golang.org>
Sat, 17 Sep 2011 19:57:24 +0000 (15:57 -0400)
committerAdam Langley <agl@golang.org>
Sat, 17 Sep 2011 19:57:24 +0000 (15:57 -0400)
The typical UNIX method for controlling long running process is to
send the process signals. Since this doesn't get you very far, various
ad-hoc, remote-control protocols have been used over time by programs
like Apache and BIND.

Implementing an SSH server means that Go code will have a standard,
secure way to do this in the future.

R=bradfitz, borman, dave, gustavo, dsymonds, r, adg, rsc, rogpeppe, lvd, kevlar, raul.san
CC=golang-dev
https://golang.org/cl/4962064

src/pkg/exp/ssh/Makefile [new file with mode: 0644]
src/pkg/exp/ssh/channel.go [new file with mode: 0644]
src/pkg/exp/ssh/common.go [new file with mode: 0644]
src/pkg/exp/ssh/doc.go [new file with mode: 0644]
src/pkg/exp/ssh/messages.go [new file with mode: 0644]
src/pkg/exp/ssh/messages_test.go [new file with mode: 0644]
src/pkg/exp/ssh/server.go [new file with mode: 0644]
src/pkg/exp/ssh/server_shell.go [new file with mode: 0644]
src/pkg/exp/ssh/server_shell_test.go [new file with mode: 0644]
src/pkg/exp/ssh/transport.go [new file with mode: 0644]

diff --git a/src/pkg/exp/ssh/Makefile b/src/pkg/exp/ssh/Makefile
new file mode 100644 (file)
index 0000000..e8f33b7
--- /dev/null
@@ -0,0 +1,16 @@
+# Copyright 2011 The Go Authors. All rights reserved.
+# Use of this source code is governed by a BSD-style
+# license that can be found in the LICENSE file.
+
+include ../../../Make.inc
+
+TARG=exp/ssh
+GOFILES=\
+       common.go\
+       messages.go\
+       server.go\
+       transport.go\
+        channel.go\
+        server_shell.go\
+
+include ../../../Make.pkg
diff --git a/src/pkg/exp/ssh/channel.go b/src/pkg/exp/ssh/channel.go
new file mode 100644 (file)
index 0000000..10f6235
--- /dev/null
@@ -0,0 +1,317 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "os"
+       "sync"
+)
+
+// A Channel is an ordered, reliable, duplex stream that is multiplexed over an
+// SSH connection.
+type Channel interface {
+       // Accept accepts the channel creation request.
+       Accept() os.Error
+       // Reject rejects the channel creation request. After calling this, no
+       // other methods on the Channel may be called. If they are then the
+       // peer is likely to signal a protocol error and drop the connection.
+       Reject(reason RejectionReason, message string) os.Error
+
+       // Read may return a ChannelRequest as an os.Error.
+       Read(data []byte) (int, os.Error)
+       Write(data []byte) (int, os.Error)
+       Close() os.Error
+
+       // AckRequest either sends an ack or nack to the channel request.
+       AckRequest(ok bool) os.Error
+
+       // ChannelType returns the type of the channel, as supplied by the
+       // client.
+       ChannelType() string
+       // ExtraData returns the arbitary payload for this channel, as supplied
+       // by the client. This data is specific to the channel type.
+       ExtraData() []byte
+}
+
+// ChannelRequest represents a request sent on a channel, outside of the normal
+// stream of bytes. It may result from calling Read on a Channel.
+type ChannelRequest struct {
+       Request   string
+       WantReply bool
+       Payload   []byte
+}
+
+func (c ChannelRequest) String() string {
+       return "channel request received"
+}
+
+// RejectionReason is an enumeration used when rejecting channel creation
+// requests. See RFC 4254, section 5.1.
+type RejectionReason int
+
+const (
+       Prohibited RejectionReason = iota + 1
+       ConnectionFailed
+       UnknownChannelType
+       ResourceShortage
+)
+
+type channel struct {
+       // immutable once created
+       chanType  string
+       extraData []byte
+
+       theyClosed  bool
+       theySentEOF bool
+       weClosed    bool
+       dead        bool
+
+       serverConn            *ServerConnection
+       myId, theirId         uint32
+       myWindow, theirWindow uint32
+       maxPacketSize         uint32
+       err                   os.Error
+
+       pendingRequests []ChannelRequest
+       pendingData     []byte
+       head, length    int
+
+       // This lock is inferior to serverConn.lock
+       lock sync.Mutex
+       cond *sync.Cond
+}
+
+func (c *channel) Accept() os.Error {
+       c.serverConn.lock.Lock()
+       defer c.serverConn.lock.Unlock()
+
+       if c.serverConn.err != nil {
+               return c.serverConn.err
+       }
+
+       confirm := channelOpenConfirmMsg{
+               PeersId:       c.theirId,
+               MyId:          c.myId,
+               MyWindow:      c.myWindow,
+               MaxPacketSize: c.maxPacketSize,
+       }
+       return c.serverConn.out.writePacket(marshal(msgChannelOpenConfirm, confirm))
+}
+
+func (c *channel) Reject(reason RejectionReason, message string) os.Error {
+       c.serverConn.lock.Lock()
+       defer c.serverConn.lock.Unlock()
+
+       if c.serverConn.err != nil {
+               return c.serverConn.err
+       }
+
+       reject := channelOpenFailureMsg{
+               PeersId:  c.theirId,
+               Reason:   uint32(reason),
+               Message:  message,
+               Language: "en",
+       }
+       return c.serverConn.out.writePacket(marshal(msgChannelOpenFailure, reject))
+}
+
+func (c *channel) handlePacket(packet interface{}) {
+       c.lock.Lock()
+       defer c.lock.Unlock()
+
+       switch packet := packet.(type) {
+       case *channelRequestMsg:
+               req := ChannelRequest{
+                       Request:   packet.Request,
+                       WantReply: packet.WantReply,
+                       Payload:   packet.RequestSpecificData,
+               }
+
+               c.pendingRequests = append(c.pendingRequests, req)
+               c.cond.Signal()
+       case *channelCloseMsg:
+               c.theyClosed = true
+               c.cond.Signal()
+       case *channelEOFMsg:
+               c.theySentEOF = true
+               c.cond.Signal()
+       default:
+               panic("unknown packet type")
+       }
+}
+
+func (c *channel) handleData(data []byte) {
+       c.lock.Lock()
+       defer c.lock.Unlock()
+
+       // The other side should never send us more than our window.
+       if len(data)+c.length > len(c.pendingData) {
+               // TODO(agl): we should tear down the channel with a protocol
+               // error.
+               return
+       }
+
+       c.myWindow -= uint32(len(data))
+       for i := 0; i < 2; i++ {
+               tail := c.head + c.length
+               if tail > len(c.pendingData) {
+                       tail -= len(c.pendingData)
+               }
+               n := copy(c.pendingData[tail:], data)
+               data = data[n:]
+               c.length += n
+       }
+
+       c.cond.Signal()
+}
+
+func (c *channel) Read(data []byte) (n int, err os.Error) {
+       c.lock.Lock()
+       defer c.lock.Unlock()
+
+       if c.err != nil {
+               return 0, c.err
+       }
+
+       if c.myWindow <= uint32(len(c.pendingData))/2 {
+               packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{
+                       PeersId:         c.theirId,
+                       AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow,
+               })
+               if err := c.serverConn.out.writePacket(packet); err != nil {
+                       return 0, err
+               }
+       }
+
+       for {
+               if c.theySentEOF || c.theyClosed || c.dead {
+                       return 0, os.EOF
+               }
+
+               if len(c.pendingRequests) > 0 {
+                       req := c.pendingRequests[0]
+                       if len(c.pendingRequests) == 1 {
+                               c.pendingRequests = nil
+                       } else {
+                               oldPendingRequests := c.pendingRequests
+                               c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1)
+                               copy(c.pendingRequests, oldPendingRequests[1:])
+                       }
+
+                       return 0, req
+               }
+
+               if c.length > 0 {
+                       tail := c.head + c.length
+                       if tail > len(c.pendingData) {
+                               tail -= len(c.pendingData)
+                       }
+                       n = copy(data, c.pendingData[c.head:tail])
+                       c.head += n
+                       c.length -= n
+                       if c.head == len(c.pendingData) {
+                               c.head = 0
+                       }
+                       return
+               }
+
+               c.cond.Wait()
+       }
+
+       panic("unreachable")
+}
+
+func (c *channel) Write(data []byte) (n int, err os.Error) {
+       for len(data) > 0 {
+               c.lock.Lock()
+               if c.dead || c.weClosed {
+                       return 0, os.EOF
+               }
+
+               if c.theirWindow == 0 {
+                       c.cond.Wait()
+                       continue
+               }
+               c.lock.Unlock()
+
+               todo := data
+               if uint32(len(todo)) > c.theirWindow {
+                       todo = todo[:c.theirWindow]
+               }
+
+               packet := make([]byte, 1+4+4+len(todo))
+               packet[0] = msgChannelData
+               packet[1] = byte(c.theirId) >> 24
+               packet[2] = byte(c.theirId) >> 16
+               packet[3] = byte(c.theirId) >> 8
+               packet[4] = byte(c.theirId)
+               packet[5] = byte(len(todo)) >> 24
+               packet[6] = byte(len(todo)) >> 16
+               packet[7] = byte(len(todo)) >> 8
+               packet[8] = byte(len(todo))
+               copy(packet[9:], todo)
+
+               c.serverConn.lock.Lock()
+               if err = c.serverConn.out.writePacket(packet); err != nil {
+                       c.serverConn.lock.Unlock()
+                       return
+               }
+               c.serverConn.lock.Unlock()
+
+               n += len(todo)
+               data = data[len(todo):]
+       }
+
+       return
+}
+
+func (c *channel) Close() os.Error {
+       c.serverConn.lock.Lock()
+       defer c.serverConn.lock.Unlock()
+
+       if c.serverConn.err != nil {
+               return c.serverConn.err
+       }
+
+       if c.weClosed {
+               return os.NewError("ssh: channel already closed")
+       }
+       c.weClosed = true
+
+       closeMsg := channelCloseMsg{
+               PeersId: c.theirId,
+       }
+       return c.serverConn.out.writePacket(marshal(msgChannelClose, closeMsg))
+}
+
+func (c *channel) AckRequest(ok bool) os.Error {
+       c.serverConn.lock.Lock()
+       defer c.serverConn.lock.Unlock()
+
+       if c.serverConn.err != nil {
+               return c.serverConn.err
+       }
+
+       if ok {
+               ack := channelRequestSuccessMsg{
+                       PeersId: c.theirId,
+               }
+               return c.serverConn.out.writePacket(marshal(msgChannelSuccess, ack))
+       } else {
+               ack := channelRequestFailureMsg{
+                       PeersId: c.theirId,
+               }
+               return c.serverConn.out.writePacket(marshal(msgChannelFailure, ack))
+       }
+       panic("unreachable")
+}
+
+func (c *channel) ChannelType() string {
+       return c.chanType
+}
+
+func (c *channel) ExtraData() []byte {
+       return c.extraData
+}
diff --git a/src/pkg/exp/ssh/common.go b/src/pkg/exp/ssh/common.go
new file mode 100644 (file)
index 0000000..c951d1a
--- /dev/null
@@ -0,0 +1,96 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "strconv"
+)
+
+// These are string constants in the SSH protocol.
+const (
+       kexAlgoDH14SHA1 = "diffie-hellman-group14-sha1"
+       hostAlgoRSA     = "ssh-rsa"
+       cipherAES128CTR = "aes128-ctr"
+       macSHA196       = "hmac-sha1-96"
+       compressionNone = "none"
+       serviceUserAuth = "ssh-userauth"
+       serviceSSH      = "ssh-connection"
+)
+
+// UnexpectedMessageError results when the SSH message that we received didn't
+// match what we wanted.
+type UnexpectedMessageError struct {
+       expected, got uint8
+}
+
+func (u UnexpectedMessageError) String() string {
+       return "ssh: unexpected message type " + strconv.Itoa(int(u.got)) + " (expected " + strconv.Itoa(int(u.expected)) + ")"
+}
+
+// ParseError results from a malformed SSH message.
+type ParseError struct {
+       msgType uint8
+}
+
+func (p ParseError) String() string {
+       return "ssh: parse error in message type " + strconv.Itoa(int(p.msgType))
+}
+
+func findCommonAlgorithm(clientAlgos []string, serverAlgos []string) (commonAlgo string, ok bool) {
+       for _, clientAlgo := range clientAlgos {
+               for _, serverAlgo := range serverAlgos {
+                       if clientAlgo == serverAlgo {
+                               return clientAlgo, true
+                       }
+               }
+       }
+
+       return
+}
+
+func findAgreedAlgorithms(clientToServer, serverToClient *halfConnection, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
+       kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
+       if !ok {
+               return
+       }
+
+       hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
+       if !ok {
+               return
+       }
+
+       clientToServer.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
+       if !ok {
+               return
+       }
+
+       serverToClient.cipherAlgo, ok = findCommonAlgorithm(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
+       if !ok {
+               return
+       }
+
+       clientToServer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
+       if !ok {
+               return
+       }
+
+       serverToClient.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
+       if !ok {
+               return
+       }
+
+       clientToServer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
+       if !ok {
+               return
+       }
+
+       serverToClient.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
+       if !ok {
+               return
+       }
+
+       ok = true
+       return
+}
diff --git a/src/pkg/exp/ssh/doc.go b/src/pkg/exp/ssh/doc.go
new file mode 100644 (file)
index 0000000..8dbdb07
--- /dev/null
@@ -0,0 +1,79 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+/*
+Package ssh implements an SSH server.
+
+SSH is a transport security protocol, an authentication protocol and a
+family of application protocols. The most typical application level
+protocol is a remote shell and this is specifically implemented.  However,
+the multiplexed nature of SSH is exposed to users that wish to support
+others.
+
+An SSH server is represented by a Server, which manages a number of
+ServerConnections and handles authentication.
+
+       var s Server
+       s.PubKeyCallback = pubKeyAuth
+       s.PasswordCallback = passwordAuth
+
+       pemBytes, err := ioutil.ReadFile("id_rsa")
+       if err != nil {
+               panic("Failed to load private key")
+       }
+       err = s.SetRSAPrivateKey(pemBytes)
+       if err != nil {
+               panic("Failed to parse private key")
+       }
+
+Once a Server has been set up, connections can be attached.
+
+       var sConn ServerConnection
+       sConn.Server = &s
+       err = sConn.Handshake(conn)
+       if err != nil {
+               panic("failed to handshake")
+       }
+
+An SSH connection multiplexes several channels, which must be accepted themselves:
+
+
+       for {
+               channel, err := sConn.Accept()
+               if err != nil {
+                       panic("error from Accept")
+               }
+
+               ...
+       }
+
+Accept reads from the connection, demultiplexes packets to their corresponding
+channels and returns when a new channel request is seen. Some goroutine must
+always be calling Accept; otherwise no messages will be forwarded to the
+channels.
+
+Channels have a type, depending on the application level protocol intended. In
+the case of a shell, the type is "session" and ServerShell may be used to
+present a simple terminal interface.
+
+       if channel.ChannelType() != "session" {
+               c.Reject(RejectUnknownChannelType, "unknown channel type")
+               return
+       }
+       channel.Accept()
+
+       shell := NewServerShell(channel, "> ")
+       go func() {
+               defer channel.Close()
+               for {
+                       line, err := shell.ReadLine()
+                       if err != nil {
+                               break
+                       }
+                       println(line)
+               }
+               return
+       }()
+*/
+package ssh
diff --git a/src/pkg/exp/ssh/messages.go b/src/pkg/exp/ssh/messages.go
new file mode 100644 (file)
index 0000000..d375eaf
--- /dev/null
@@ -0,0 +1,557 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "big"
+       "bytes"
+       "io"
+       "os"
+       "reflect"
+)
+
+// These are SSH message type numbers. They are scattered around several
+// documents but many were taken from
+// http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1
+const (
+       msgDisconnect     = 1
+       msgIgnore         = 2
+       msgUnimplemented  = 3
+       msgDebug          = 4
+       msgServiceRequest = 5
+       msgServiceAccept  = 6
+
+       msgKexInit = 20
+       msgNewKeys = 21
+
+       msgKexDHInit  = 30
+       msgKexDHReply = 31
+
+       msgUserAuthRequest  = 50
+       msgUserAuthFailure  = 51
+       msgUserAuthSuccess  = 52
+       msgUserAuthBanner   = 53
+       msgUserAuthPubKeyOk = 60
+
+       msgGlobalRequest  = 80
+       msgRequestSuccess = 81
+       msgRequestFailure = 82
+
+       msgChannelOpen         = 90
+       msgChannelOpenConfirm  = 91
+       msgChannelOpenFailure  = 92
+       msgChannelWindowAdjust = 93
+       msgChannelData         = 94
+       msgChannelExtendedData = 95
+       msgChannelEOF          = 96
+       msgChannelClose        = 97
+       msgChannelRequest      = 98
+       msgChannelSuccess      = 99
+       msgChannelFailure      = 100
+)
+
+// SSH messages:
+//
+// These structures mirror the wire format of the corresponding SSH messages.
+// They are marshaled using reflection with the marshal and unmarshal functions
+// in this file. The only wrinkle is that a final member of type []byte with a
+// tag of "rest" receives the remainder of a packet when unmarshaling.
+
+// See RFC 4253, section 7.1.
+type kexInitMsg struct {
+       Cookie                  [16]byte
+       KexAlgos                []string
+       ServerHostKeyAlgos      []string
+       CiphersClientServer     []string
+       CiphersServerClient     []string
+       MACsClientServer        []string
+       MACsServerClient        []string
+       CompressionClientServer []string
+       CompressionServerClient []string
+       LanguagesClientServer   []string
+       LanguagesServerClient   []string
+       FirstKexFollows         bool
+       Reserved                uint32
+}
+
+// See RFC 4253, section 8.
+type kexDHInitMsg struct {
+       X *big.Int
+}
+
+type kexDHReplyMsg struct {
+       HostKey   []byte
+       Y         *big.Int
+       Signature []byte
+}
+
+// See RFC 4253, section 10.
+type serviceRequestMsg struct {
+       Service string
+}
+
+// See RFC 4253, section 10.
+type serviceAcceptMsg struct {
+       Service string
+}
+
+// See RFC 4252, section 5.
+type userAuthRequestMsg struct {
+       User    string
+       Service string
+       Method  string
+       Payload []byte "rest"
+}
+
+// See RFC 4252, section 5.1
+type userAuthFailureMsg struct {
+       Methods        []string
+       PartialSuccess bool
+}
+
+// See RFC 4254, section 5.1.
+type channelOpenMsg struct {
+       ChanType         string
+       PeersId          uint32
+       PeersWindow      uint32
+       MaxPacketSize    uint32
+       TypeSpecificData []byte "rest"
+}
+
+// See RFC 4254, section 5.1.
+type channelOpenConfirmMsg struct {
+       PeersId          uint32
+       MyId             uint32
+       MyWindow         uint32
+       MaxPacketSize    uint32
+       TypeSpecificData []byte "rest"
+}
+
+// See RFC 4254, section 5.1.
+type channelOpenFailureMsg struct {
+       PeersId  uint32
+       Reason   uint32
+       Message  string
+       Language string
+}
+
+type channelRequestMsg struct {
+       PeersId             uint32
+       Request             string
+       WantReply           bool
+       RequestSpecificData []byte "rest"
+}
+
+// See RFC 4254, section 5.4.
+type channelRequestSuccessMsg struct {
+       PeersId uint32
+}
+
+// See RFC 4254, section 5.4.
+type channelRequestFailureMsg struct {
+       PeersId uint32
+}
+
+// See RFC 4254, section 5.3
+type channelCloseMsg struct {
+       PeersId uint32
+}
+
+// See RFC 4254, section 5.3
+type channelEOFMsg struct {
+       PeersId uint32
+}
+
+// See RFC 4254, section 4
+type globalRequestMsg struct {
+       Type      string
+       WantReply bool
+}
+
+// See RFC 4254, section 5.2
+type windowAdjustMsg struct {
+       PeersId         uint32
+       AdditionalBytes uint32
+}
+
+// See RFC 4252, section 7
+type userAuthPubKeyOkMsg struct {
+       Algo   string
+       PubKey string
+}
+
+// unmarshal parses the SSH wire data in packet into out using reflection.
+// expectedType is the expected SSH message type. It either returns nil on
+// success, or a ParseError or UnexpectedMessageError on error.
+func unmarshal(out interface{}, packet []byte, expectedType uint8) os.Error {
+       if len(packet) == 0 {
+               return ParseError{expectedType}
+       }
+       if packet[0] != expectedType {
+               return UnexpectedMessageError{expectedType, packet[0]}
+       }
+       packet = packet[1:]
+
+       v := reflect.ValueOf(out).Elem()
+       structType := v.Type()
+       var ok bool
+       for i := 0; i < v.NumField(); i++ {
+               field := v.Field(i)
+               t := field.Type()
+               switch t.Kind() {
+               case reflect.Bool:
+                       if len(packet) < 1 {
+                               return ParseError{expectedType}
+                       }
+                       field.SetBool(packet[0] != 0)
+                       packet = packet[1:]
+               case reflect.Array:
+                       if t.Elem().Kind() != reflect.Uint8 {
+                               panic("array of non-uint8")
+                       }
+                       if len(packet) < t.Len() {
+                               return ParseError{expectedType}
+                       }
+                       for j := 0; j < t.Len(); j++ {
+                               field.Index(j).Set(reflect.ValueOf(packet[j]))
+                       }
+                       packet = packet[t.Len():]
+               case reflect.Uint32:
+                       var u32 uint32
+                       if u32, packet, ok = parseUint32(packet); !ok {
+                               return ParseError{expectedType}
+                       }
+                       field.SetUint(uint64(u32))
+               case reflect.String:
+                       var s []byte
+                       if s, packet, ok = parseString(packet); !ok {
+                               return ParseError{expectedType}
+                       }
+                       field.SetString(string(s))
+               case reflect.Slice:
+                       switch t.Elem().Kind() {
+                       case reflect.Uint8:
+                               if structType.Field(i).Tag == "rest" {
+                                       field.Set(reflect.ValueOf(packet))
+                                       packet = nil
+                               } else {
+                                       var s []byte
+                                       if s, packet, ok = parseString(packet); !ok {
+                                               return ParseError{expectedType}
+                                       }
+                                       field.Set(reflect.ValueOf(s))
+                               }
+                       case reflect.String:
+                               var nl []string
+                               if nl, packet, ok = parseNameList(packet); !ok {
+                                       return ParseError{expectedType}
+                               }
+                               field.Set(reflect.ValueOf(nl))
+                       default:
+                               panic("slice of unknown type")
+                       }
+               case reflect.Ptr:
+                       if t == bigIntType {
+                               var n *big.Int
+                               if n, packet, ok = parseInt(packet); !ok {
+                                       return ParseError{expectedType}
+                               }
+                               field.Set(reflect.ValueOf(n))
+                       } else {
+                               panic("pointer to unknown type")
+                       }
+               default:
+                       panic("unknown type")
+               }
+       }
+
+       if len(packet) != 0 {
+               return ParseError{expectedType}
+       }
+
+       return nil
+}
+
+// marshal serializes the message in msg, using the given message type.
+func marshal(msgType uint8, msg interface{}) []byte {
+       var out []byte
+       out = append(out, msgType)
+
+       v := reflect.ValueOf(msg)
+       structType := v.Type()
+       for i := 0; i < v.NumField(); i++ {
+               field := v.Field(i)
+               t := field.Type()
+               switch t.Kind() {
+               case reflect.Bool:
+                       var v uint8
+                       if field.Bool() {
+                               v = 1
+                       }
+                       out = append(out, v)
+               case reflect.Array:
+                       if t.Elem().Kind() != reflect.Uint8 {
+                               panic("array of non-uint8")
+                       }
+                       for j := 0; j < t.Len(); j++ {
+                               out = append(out, byte(field.Index(j).Uint()))
+                       }
+               case reflect.Uint32:
+                       u32 := uint32(field.Uint())
+                       out = append(out, byte(u32>>24))
+                       out = append(out, byte(u32>>16))
+                       out = append(out, byte(u32>>8))
+                       out = append(out, byte(u32))
+               case reflect.String:
+                       s := field.String()
+                       out = append(out, byte(len(s)>>24))
+                       out = append(out, byte(len(s)>>16))
+                       out = append(out, byte(len(s)>>8))
+                       out = append(out, byte(len(s)))
+                       out = append(out, []byte(s)...)
+               case reflect.Slice:
+                       switch t.Elem().Kind() {
+                       case reflect.Uint8:
+                               length := field.Len()
+                               if structType.Field(i).Tag != "rest" {
+                                       out = append(out, byte(length>>24))
+                                       out = append(out, byte(length>>16))
+                                       out = append(out, byte(length>>8))
+                                       out = append(out, byte(length))
+                               }
+                               for j := 0; j < length; j++ {
+                                       out = append(out, byte(field.Index(j).Uint()))
+                               }
+                       case reflect.String:
+                               var length int
+                               for j := 0; j < field.Len(); j++ {
+                                       if j != 0 {
+                                               length++ /* comma */
+                                       }
+                                       length += len(field.Index(j).String())
+                               }
+
+                               out = append(out, byte(length>>24))
+                               out = append(out, byte(length>>16))
+                               out = append(out, byte(length>>8))
+                               out = append(out, byte(length))
+                               for j := 0; j < field.Len(); j++ {
+                                       if j != 0 {
+                                               out = append(out, ',')
+                                       }
+                                       out = append(out, []byte(field.Index(j).String())...)
+                               }
+                       default:
+                               panic("slice of unknown type")
+                       }
+               case reflect.Ptr:
+                       if t == bigIntType {
+                               var n *big.Int
+                               nValue := reflect.ValueOf(&n)
+                               nValue.Elem().Set(field)
+                               needed := intLength(n)
+                               oldLength := len(out)
+
+                               if cap(out)-len(out) < needed {
+                                       newOut := make([]byte, len(out), 2*(len(out)+needed))
+                                       copy(newOut, out)
+                                       out = newOut
+                               }
+                               out = out[:oldLength+needed]
+                               marshalInt(out[oldLength:], n)
+                       } else {
+                               panic("pointer to unknown type")
+                       }
+               }
+       }
+
+       return out
+}
+
+var bigOne = big.NewInt(1)
+
+func parseString(in []byte) (out, rest []byte, ok bool) {
+       if len(in) < 4 {
+               return
+       }
+       length := uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
+       if uint32(len(in)) < 4+length {
+               return
+       }
+       out = in[4 : 4+length]
+       rest = in[4+length:]
+       ok = true
+       return
+}
+
+var comma = []byte{','}
+
+func parseNameList(in []byte) (out []string, rest []byte, ok bool) {
+       contents, rest, ok := parseString(in)
+       if !ok {
+               return
+       }
+       if len(contents) == 0 {
+               return
+       }
+       parts := bytes.Split(contents, comma)
+       out = make([]string, len(parts))
+       for i, part := range parts {
+               out[i] = string(part)
+       }
+       return
+}
+
+func parseInt(in []byte) (out *big.Int, rest []byte, ok bool) {
+       contents, rest, ok := parseString(in)
+       if !ok {
+               return
+       }
+       out = new(big.Int)
+
+       if len(contents) > 0 && contents[0]&0x80 == 0x80 {
+               // This is a negative number
+               notBytes := make([]byte, len(contents))
+               for i := range notBytes {
+                       notBytes[i] = ^contents[i]
+               }
+               out.SetBytes(notBytes)
+               out.Add(out, bigOne)
+               out.Neg(out)
+       } else {
+               // Positive number
+               out.SetBytes(contents)
+       }
+       ok = true
+       return
+}
+
+func parseUint32(in []byte) (out uint32, rest []byte, ok bool) {
+       if len(in) < 4 {
+               return
+       }
+       out = uint32(in[0])<<24 | uint32(in[1])<<16 | uint32(in[2])<<8 | uint32(in[3])
+       rest = in[4:]
+       ok = true
+       return
+}
+
+const maxPacketSize = 36000
+
+func nameListLength(namelist []string) int {
+       length := 4 /* uint32 length prefix */
+       for i, name := range namelist {
+               if i != 0 {
+                       length++ /* comma */
+               }
+               length += len(name)
+       }
+       return length
+}
+
+func intLength(n *big.Int) int {
+       length := 4 /* length bytes */
+       if n.Sign() < 0 {
+               nMinus1 := new(big.Int).Neg(n)
+               nMinus1.Sub(nMinus1, bigOne)
+               bitLen := nMinus1.BitLen()
+               if bitLen%8 == 0 {
+                       // The number will need 0xff padding
+                       length++
+               }
+               length += (bitLen + 7) / 8
+       } else if n.Sign() == 0 {
+               // A zero is the zero length string
+       } else {
+               bitLen := n.BitLen()
+               if bitLen%8 == 0 {
+                       // The number will need 0x00 padding
+                       length++
+               }
+               length += (bitLen + 7) / 8
+       }
+
+       return length
+}
+
+func marshalInt(to []byte, n *big.Int) []byte {
+       lengthBytes := to
+       to = to[4:]
+       length := 0
+
+       if n.Sign() < 0 {
+               // A negative number has to be converted to two's-complement
+               // form. So we'll subtract 1 and invert. If the
+               // most-significant-bit isn't set then we'll need to pad the
+               // beginning with 0xff in order to keep the number negative.
+               nMinus1 := new(big.Int).Neg(n)
+               nMinus1.Sub(nMinus1, bigOne)
+               bytes := nMinus1.Bytes()
+               for i := range bytes {
+                       bytes[i] ^= 0xff
+               }
+               if len(bytes) == 0 || bytes[0]&0x80 == 0 {
+                       to[0] = 0xff
+                       to = to[1:]
+                       length++
+               }
+               nBytes := copy(to, bytes)
+               to = to[nBytes:]
+               length += nBytes
+       } else if n.Sign() == 0 {
+               // A zero is the zero length string
+       } else {
+               bytes := n.Bytes()
+               if len(bytes) > 0 && bytes[0]&0x80 != 0 {
+                       // We'll have to pad this with a 0x00 in order to
+                       // stop it looking like a negative number.
+                       to[0] = 0
+                       to = to[1:]
+                       length++
+               }
+               nBytes := copy(to, bytes)
+               to = to[nBytes:]
+               length += nBytes
+       }
+
+       lengthBytes[0] = byte(length >> 24)
+       lengthBytes[1] = byte(length >> 16)
+       lengthBytes[2] = byte(length >> 8)
+       lengthBytes[3] = byte(length)
+       return to
+}
+
+func writeInt(w io.Writer, n *big.Int) {
+       length := intLength(n)
+       buf := make([]byte, length)
+       marshalInt(buf, n)
+       w.Write(buf)
+}
+
+func writeString(w io.Writer, s []byte) {
+       var lengthBytes [4]byte
+       lengthBytes[0] = byte(len(s) >> 24)
+       lengthBytes[1] = byte(len(s) >> 16)
+       lengthBytes[2] = byte(len(s) >> 8)
+       lengthBytes[3] = byte(len(s))
+       w.Write(lengthBytes[:])
+       w.Write(s)
+}
+
+func stringLength(s []byte) int {
+       return 4 + len(s)
+}
+
+func marshalString(to []byte, s []byte) []byte {
+       to[0] = byte(len(s) >> 24)
+       to[1] = byte(len(s) >> 16)
+       to[2] = byte(len(s) >> 8)
+       to[3] = byte(len(s))
+       to = to[4:]
+       copy(to, s)
+       return to[len(s):]
+}
+
+var bigIntType = reflect.TypeOf((*big.Int)(nil))
diff --git a/src/pkg/exp/ssh/messages_test.go b/src/pkg/exp/ssh/messages_test.go
new file mode 100644 (file)
index 0000000..629f3d3
--- /dev/null
@@ -0,0 +1,125 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "big"
+       "rand"
+       "reflect"
+       "testing"
+       "testing/quick"
+)
+
+var intLengthTests = []struct {
+       val, length int
+}{
+       {0, 4 + 0},
+       {1, 4 + 1},
+       {127, 4 + 1},
+       {128, 4 + 2},
+       {-1, 4 + 1},
+}
+
+func TestIntLength(t *testing.T) {
+       for _, test := range intLengthTests {
+               v := new(big.Int).SetInt64(int64(test.val))
+               length := intLength(v)
+               if length != test.length {
+                       t.Errorf("For %d, got length %d but expected %d", test.val, length, test.length)
+               }
+       }
+}
+
+var messageTypes = []interface{}{
+       &kexInitMsg{},
+       &kexDHInitMsg{},
+       &serviceRequestMsg{},
+       &serviceAcceptMsg{},
+       &userAuthRequestMsg{},
+       &channelOpenMsg{},
+       &channelOpenConfirmMsg{},
+       &channelRequestMsg{},
+       &channelRequestSuccessMsg{},
+}
+
+func TestMarshalUnmarshal(t *testing.T) {
+       rand := rand.New(rand.NewSource(0))
+       for i, iface := range messageTypes {
+               ty := reflect.ValueOf(iface).Type()
+
+               n := 100
+               if testing.Short() {
+                       n = 5
+               }
+               for j := 0; j < n; j++ {
+                       v, ok := quick.Value(ty, rand)
+                       if !ok {
+                               t.Errorf("#%d: failed to create value", i)
+                               break
+                       }
+
+                       m1 := v.Elem().Interface()
+                       m2 := iface
+
+                       marshaled := marshal(msgIgnore, m1)
+                       if err := unmarshal(m2, marshaled, msgIgnore); err != nil {
+                               t.Errorf("#%d failed to unmarshal %#v: %s", i, m1, err)
+                               break
+                       }
+
+                       if !reflect.DeepEqual(v.Interface(), m2) {
+                               t.Errorf("#%d\ngot: %#v\nwant:%#v\n%x", i, m2, m1, marshaled)
+                               break
+                       }
+               }
+       }
+}
+
+func randomBytes(out []byte, rand *rand.Rand) {
+       for i := 0; i < len(out); i++ {
+               out[i] = byte(rand.Int31())
+       }
+}
+
+func randomNameList(rand *rand.Rand) []string {
+       ret := make([]string, rand.Int31()&15)
+       for i := range ret {
+               s := make([]byte, 1+(rand.Int31()&15))
+               for j := range s {
+                       s[j] = 'a' + uint8(rand.Int31()&15)
+               }
+               ret[i] = string(s)
+       }
+       return ret
+}
+
+func randomInt(rand *rand.Rand) *big.Int {
+       return new(big.Int).SetInt64(int64(int32(rand.Uint32())))
+}
+
+func (*kexInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       ki := &kexInitMsg{}
+       randomBytes(ki.Cookie[:], rand)
+       ki.KexAlgos = randomNameList(rand)
+       ki.ServerHostKeyAlgos = randomNameList(rand)
+       ki.CiphersClientServer = randomNameList(rand)
+       ki.CiphersServerClient = randomNameList(rand)
+       ki.MACsClientServer = randomNameList(rand)
+       ki.MACsServerClient = randomNameList(rand)
+       ki.CompressionClientServer = randomNameList(rand)
+       ki.CompressionServerClient = randomNameList(rand)
+       ki.LanguagesClientServer = randomNameList(rand)
+       ki.LanguagesServerClient = randomNameList(rand)
+       if rand.Int31()&1 == 1 {
+               ki.FirstKexFollows = true
+       }
+       return reflect.ValueOf(ki)
+}
+
+func (*kexDHInitMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       dhi := &kexDHInitMsg{}
+       dhi.X = randomInt(rand)
+       return reflect.ValueOf(dhi)
+}
diff --git a/src/pkg/exp/ssh/server.go b/src/pkg/exp/ssh/server.go
new file mode 100644 (file)
index 0000000..57cd597
--- /dev/null
@@ -0,0 +1,711 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "big"
+       "bufio"
+       "bytes"
+       "crypto"
+       "crypto/rand"
+       "crypto/rsa"
+       _ "crypto/sha1"
+       "crypto/x509"
+       "encoding/pem"
+       "net"
+       "os"
+       "sync"
+)
+
+var supportedKexAlgos = []string{kexAlgoDH14SHA1}
+var supportedHostKeyAlgos = []string{hostAlgoRSA}
+var supportedCiphers = []string{cipherAES128CTR}
+var supportedMACs = []string{macSHA196}
+var supportedCompressions = []string{compressionNone}
+
+// Server represents an SSH server. A Server may have several ServerConnections.
+type Server struct {
+       rsa           *rsa.PrivateKey
+       rsaSerialized []byte
+
+       // NoClientAuth is true if clients are allowed to connect without
+       // authenticating.
+       NoClientAuth bool
+
+       // PasswordCallback, if non-nil, is called when a user attempts to
+       // authenticate using a password. It may be called concurrently from
+       // several goroutines.
+       PasswordCallback func(user, password string) bool
+
+       // PubKeyCallback, if non-nil, is called when a client attempts public
+       // key authentication. It must return true iff the given public key is
+       // valid for the given user.
+       PubKeyCallback func(user, algo string, pubkey []byte) bool
+}
+
+// SetRSAPrivateKey sets the private key for a Server. A Server must have a
+// private key configured in order to accept connections. The private key must
+// be in the form of a PEM encoded, PKCS#1, RSA private key. The file "id_rsa"
+// typically contains such a key.
+func (s *Server) SetRSAPrivateKey(pemBytes []byte) os.Error {
+       block, _ := pem.Decode(pemBytes)
+       if block == nil {
+               return os.NewError("ssh: no key found")
+       }
+       var err os.Error
+       s.rsa, err = x509.ParsePKCS1PrivateKey(block.Bytes)
+       if err != nil {
+               return err
+       }
+
+       s.rsaSerialized = marshalRSA(s.rsa)
+       return nil
+}
+
+// marshalRSA serializes an RSA private key according to RFC 4256, section 6.6.
+func marshalRSA(priv *rsa.PrivateKey) []byte {
+       e := new(big.Int).SetInt64(int64(priv.E))
+       length := stringLength([]byte(hostAlgoRSA))
+       length += intLength(e)
+       length += intLength(priv.N)
+
+       ret := make([]byte, length)
+       r := marshalString(ret, []byte(hostAlgoRSA))
+       r = marshalInt(r, e)
+       r = marshalInt(r, priv.N)
+
+       return ret
+}
+
+// parseRSA parses an RSA key according to RFC 4256, section 6.6.
+func parseRSA(in []byte) (pubKey *rsa.PublicKey, ok bool) {
+       algo, in, ok := parseString(in)
+       if !ok || string(algo) != hostAlgoRSA {
+               return nil, false
+       }
+       bigE, in, ok := parseInt(in)
+       if !ok || bigE.BitLen() > 24 {
+               return nil, false
+       }
+       e := bigE.Int64()
+       if e < 3 || e&1 == 0 {
+               return nil, false
+       }
+       N, in, ok := parseInt(in)
+       if !ok || len(in) > 0 {
+               return nil, false
+       }
+       return &rsa.PublicKey{
+               N: N,
+               E: int(e),
+       }, true
+}
+
+func parseRSASig(in []byte) (sig []byte, ok bool) {
+       algo, in, ok := parseString(in)
+       if !ok || string(algo) != hostAlgoRSA {
+               return nil, false
+       }
+       sig, in, ok = parseString(in)
+       if len(in) > 0 {
+               ok = false
+       }
+       return
+}
+
+// cachedPubKey contains the results of querying whether a public key is
+// acceptable for a user. The cache only applies to a single ServerConnection.
+type cachedPubKey struct {
+       user, algo string
+       pubKey     []byte
+       result     bool
+}
+
+const maxCachedPubKeys = 16
+
+// ServerConnection represents an incomming connection to a Server.
+type ServerConnection struct {
+       Server *Server
+
+       in, out *halfConnection
+
+       channels   map[uint32]*channel
+       nextChanId uint32
+
+       // lock protects err and also allows Channels to serialise their writes
+       // to out.
+       lock sync.RWMutex
+       err  os.Error
+
+       // cachedPubKeys contains the cache results of tests for public keys.
+       // Since SSH clients will query whether a public key is acceptable
+       // before attempting to authenticate with it, we end up with duplicate
+       // queries for public key validity.
+       cachedPubKeys []cachedPubKey
+}
+
+// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
+type dhGroup struct {
+       g, p *big.Int
+}
+
+// dhGroup14 is the group called diffie-hellman-group14-sha1 in RFC 4253 and
+// Oakley Group 14 in RFC 3526.
+var dhGroup14 *dhGroup
+
+var dhGroup14Once sync.Once
+
+func initDHGroup14() {
+       p, _ := new(big.Int).SetString("FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE45B3DC2007CB8A163BF0598DA48361C55D39A69163FA8FD24CF5F83655D23DCA3AD961C62F356208552BB9ED529077096966D670C354E4ABC9804F1746C08CA18217C32905E462E36CE3BE39E772C180E86039B2783A2EC07A28FB5C55DF06F4C52C9DE2BCBF6955817183995497CEA956AE515D2261898FA051015728E5A8AACAA68FFFFFFFFFFFFFFFF", 16)
+
+       dhGroup14 = &dhGroup{
+               g: new(big.Int).SetInt64(2),
+               p: p,
+       }
+}
+
+type handshakeMagics struct {
+       clientVersion, serverVersion []byte
+       clientKexInit, serverKexInit []byte
+}
+
+// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
+// returned values are given the same names as in RFC 4253, section 8.
+func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
+       packet, err := s.in.readPacket()
+       if err != nil {
+               return
+       }
+       var kexDHInit kexDHInitMsg
+       if err = unmarshal(&kexDHInit, packet, msgKexDHInit); err != nil {
+               return
+       }
+
+       if kexDHInit.X.Sign() == 0 || kexDHInit.X.Cmp(group.p) >= 0 {
+               return nil, nil, os.NewError("client DH parameter out of bounds")
+       }
+
+       y, err := rand.Int(rand.Reader, group.p)
+       if err != nil {
+               return
+       }
+
+       Y := new(big.Int).Exp(group.g, y, group.p)
+       kInt := new(big.Int).Exp(kexDHInit.X, y, group.p)
+
+       var serializedHostKey []byte
+       switch hostKeyAlgo {
+       case hostAlgoRSA:
+               serializedHostKey = s.Server.rsaSerialized
+       default:
+               return nil, nil, os.NewError("internal error")
+       }
+
+       h := hashFunc.New()
+       writeString(h, magics.clientVersion)
+       writeString(h, magics.serverVersion)
+       writeString(h, magics.clientKexInit)
+       writeString(h, magics.serverKexInit)
+       writeString(h, serializedHostKey)
+       writeInt(h, kexDHInit.X)
+       writeInt(h, Y)
+       K = make([]byte, intLength(kInt))
+       marshalInt(K, kInt)
+       h.Write(K)
+
+       H = h.Sum()
+
+       h.Reset()
+       h.Write(H)
+       hh := h.Sum()
+
+       var sig []byte
+       switch hostKeyAlgo {
+       case hostAlgoRSA:
+               sig, err = rsa.SignPKCS1v15(rand.Reader, s.Server.rsa, hashFunc, hh)
+               if err != nil {
+                       return
+               }
+       default:
+               return nil, nil, os.NewError("internal error")
+       }
+
+       serializedSig := serializeRSASignature(sig)
+
+       kexDHReply := kexDHReplyMsg{
+               HostKey:   serializedHostKey,
+               Y:         Y,
+               Signature: serializedSig,
+       }
+       packet = marshal(msgKexDHReply, kexDHReply)
+
+       err = s.out.writePacket(packet)
+       return
+}
+
+func serializeRSASignature(sig []byte) []byte {
+       length := stringLength([]byte(hostAlgoRSA))
+       length += stringLength(sig)
+
+       ret := make([]byte, length)
+       r := marshalString(ret, []byte(hostAlgoRSA))
+       r = marshalString(r, sig)
+
+       return ret
+}
+
+// serverVersion is the fixed identification string that Server will use.
+var serverVersion = []byte("SSH-2.0-Go\r\n")
+
+// buildDataSignedForAuth returns the data that is signed in order to prove
+// posession of a private key. See RFC 4252, section 7.
+func buildDataSignedForAuth(sessionId []byte, req userAuthRequestMsg, algo, pubKey []byte) []byte {
+       user := []byte(req.User)
+       service := []byte(req.Service)
+       method := []byte(req.Method)
+
+       length := stringLength(sessionId)
+       length += 1
+       length += stringLength(user)
+       length += stringLength(service)
+       length += stringLength(method)
+       length += 1
+       length += stringLength(algo)
+       length += stringLength(pubKey)
+
+       ret := make([]byte, length)
+       r := marshalString(ret, sessionId)
+       r[0] = msgUserAuthRequest
+       r = r[1:]
+       r = marshalString(r, user)
+       r = marshalString(r, service)
+       r = marshalString(r, method)
+       r[0] = 1
+       r = r[1:]
+       r = marshalString(r, algo)
+       r = marshalString(r, pubKey)
+       return ret
+}
+
+// Handshake performs an SSH transport and client authentication on the given ServerConnection.
+func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
+       var magics handshakeMagics
+       inBuf := bufio.NewReader(conn)
+
+       _, err := conn.Write(serverVersion)
+       if err != nil {
+               return err
+       }
+
+       magics.serverVersion = serverVersion[:len(serverVersion)-2]
+       serverKexInit := kexInitMsg{
+               KexAlgos:                supportedKexAlgos,
+               ServerHostKeyAlgos:      supportedHostKeyAlgos,
+               CiphersClientServer:     supportedCiphers,
+               CiphersServerClient:     supportedCiphers,
+               MACsClientServer:        supportedMACs,
+               MACsServerClient:        supportedMACs,
+               CompressionClientServer: supportedCompressions,
+               CompressionServerClient: supportedCompressions,
+       }
+       kexInitPacket := marshal(msgKexInit, serverKexInit)
+       magics.serverKexInit = kexInitPacket
+
+       var out halfConnection
+       out.out = conn
+       out.rand = rand.Reader
+       s.out = &out
+       err = out.writePacket(kexInitPacket)
+       if err != nil {
+               return err
+       }
+
+       version, ok := readVersion(inBuf)
+       if !ok {
+               return os.NewError("failed to read version string from client")
+       }
+       magics.clientVersion = version
+
+       var in halfConnection
+       in.in = inBuf
+       s.in = &in
+       packet, err := in.readPacket()
+       if err != nil {
+               return err
+       }
+       magics.clientKexInit = packet
+
+       var clientKexInit kexInitMsg
+       if err = unmarshal(&clientKexInit, packet, msgKexInit); err != nil {
+               return err
+       }
+
+       kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit)
+       if !ok {
+               return os.NewError("ssh: no common algorithms")
+       }
+
+       if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
+               // The client sent a Kex message for the wrong algorithm,
+               // which we have to ignore.
+               _, err := in.readPacket()
+               if err != nil {
+                       return err
+               }
+       }
+
+       var H, K []byte
+       var hashFunc crypto.Hash
+       switch kexAlgo {
+       case kexAlgoDH14SHA1:
+               hashFunc = crypto.SHA1
+               dhGroup14Once.Do(initDHGroup14)
+               H, K, err = s.kexDH(dhGroup14, hashFunc, &magics, hostKeyAlgo)
+       default:
+               err = os.NewError("ssh: internal error")
+       }
+
+       if err != nil {
+               return err
+       }
+
+       packet = []byte{msgNewKeys}
+       if err = out.writePacket(packet); err != nil {
+               return err
+       }
+       if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+               return err
+       }
+
+       if packet, err = in.readPacket(); err != nil {
+               return err
+       }
+       if packet[0] != msgNewKeys {
+               return UnexpectedMessageError{msgNewKeys, packet[0]}
+       }
+
+       in.setupKeys(clientKeys, K, H, H, hashFunc)
+
+       packet, err = in.readPacket()
+       if err != nil {
+               return err
+       }
+
+       var serviceRequest serviceRequestMsg
+       if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
+               return err
+       }
+       if serviceRequest.Service != serviceUserAuth {
+               return os.NewError("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
+       }
+
+       serviceAccept := serviceAcceptMsg{
+               Service: serviceUserAuth,
+       }
+       packet = marshal(msgServiceAccept, serviceAccept)
+       if err = out.writePacket(packet); err != nil {
+               return err
+       }
+
+       if err = s.authenticate(H); err != nil {
+               return err
+       }
+
+       s.channels = make(map[uint32]*channel)
+       return nil
+}
+
+func isAcceptableAlgo(algo string) bool {
+       return algo == hostAlgoRSA
+}
+
+// testPubKey returns true if the given public key is acceptable for the user.
+func (s *ServerConnection) testPubKey(user, algo string, pubKey []byte) bool {
+       if s.Server.PubKeyCallback == nil || !isAcceptableAlgo(algo) {
+               return false
+       }
+
+       for _, c := range s.cachedPubKeys {
+               if c.user == user && c.algo == algo && bytes.Equal(c.pubKey, pubKey) {
+                       return c.result
+               }
+       }
+
+       result := s.Server.PubKeyCallback(user, algo, pubKey)
+       if len(s.cachedPubKeys) < maxCachedPubKeys {
+               c := cachedPubKey{
+                       user:   user,
+                       algo:   algo,
+                       pubKey: make([]byte, len(pubKey)),
+                       result: result,
+               }
+               copy(c.pubKey, pubKey)
+               s.cachedPubKeys = append(s.cachedPubKeys, c)
+       }
+
+       return result
+}
+
+func (s *ServerConnection) authenticate(H []byte) os.Error {
+       var userAuthReq userAuthRequestMsg
+       var err os.Error
+       var packet []byte
+
+userAuthLoop:
+       for {
+               if packet, err = s.in.readPacket(); err != nil {
+                       return err
+               }
+               if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
+                       return err
+               }
+
+               if userAuthReq.Service != serviceSSH {
+                       return os.NewError("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
+               }
+
+               switch userAuthReq.Method {
+               case "none":
+                       if s.Server.NoClientAuth {
+                               break userAuthLoop
+                       }
+               case "password":
+                       if s.Server.PasswordCallback == nil {
+                               break
+                       }
+                       payload := userAuthReq.Payload
+                       if len(payload) < 1 || payload[0] != 0 {
+                               return ParseError{msgUserAuthRequest}
+                       }
+                       payload = payload[1:]
+                       password, payload, ok := parseString(payload)
+                       if !ok || len(payload) > 0 {
+                               return ParseError{msgUserAuthRequest}
+                       }
+
+                       if s.Server.PasswordCallback(userAuthReq.User, string(password)) {
+                               break userAuthLoop
+                       }
+               case "publickey":
+                       if s.Server.PubKeyCallback == nil {
+                               break
+                       }
+                       payload := userAuthReq.Payload
+                       if len(payload) < 1 {
+                               return ParseError{msgUserAuthRequest}
+                       }
+                       isQuery := payload[0] == 0
+                       payload = payload[1:]
+                       algoBytes, payload, ok := parseString(payload)
+                       if !ok {
+                               return ParseError{msgUserAuthRequest}
+                       }
+                       algo := string(algoBytes)
+
+                       pubKey, payload, ok := parseString(payload)
+                       if !ok {
+                               return ParseError{msgUserAuthRequest}
+                       }
+                       if isQuery {
+                               // The client can query if the given public key
+                               // would be ok.
+                               if len(payload) > 0 {
+                                       return ParseError{msgUserAuthRequest}
+                               }
+                               if s.testPubKey(userAuthReq.User, algo, pubKey) {
+                                       okMsg := userAuthPubKeyOkMsg{
+                                               Algo:   algo,
+                                               PubKey: string(pubKey),
+                                       }
+                                       if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+                                               return err
+                                       }
+                                       continue userAuthLoop
+                               }
+                       } else {
+                               sig, payload, ok := parseString(payload)
+                               if !ok || len(payload) > 0 {
+                                       return ParseError{msgUserAuthRequest}
+                               }
+                               if !isAcceptableAlgo(algo) {
+                                       break
+                               }
+                               rsaSig, ok := parseRSASig(sig)
+                               if !ok {
+                                       return ParseError{msgUserAuthRequest}
+                               }
+                               signedData := buildDataSignedForAuth(H, userAuthReq, algoBytes, pubKey)
+                               switch algo {
+                               case hostAlgoRSA:
+                                       hashFunc := crypto.SHA1
+                                       h := hashFunc.New()
+                                       h.Write(signedData)
+                                       digest := h.Sum()
+                                       rsaKey, ok := parseRSA(pubKey)
+                                       if !ok {
+                                               return ParseError{msgUserAuthRequest}
+                                       }
+                                       if rsa.VerifyPKCS1v15(rsaKey, hashFunc, digest, rsaSig) != nil {
+                                               return ParseError{msgUserAuthRequest}
+                                       }
+                               default:
+                                       return os.NewError("ssh: isAcceptableAlgo incorrect")
+                               }
+                               if s.testPubKey(userAuthReq.User, algo, pubKey) {
+                                       break userAuthLoop
+                               }
+                       }
+               }
+
+               var failureMsg userAuthFailureMsg
+               if s.Server.PasswordCallback != nil {
+                       failureMsg.Methods = append(failureMsg.Methods, "password")
+               }
+               if s.Server.PubKeyCallback != nil {
+                       failureMsg.Methods = append(failureMsg.Methods, "publickey")
+               }
+
+               if len(failureMsg.Methods) == 0 {
+                       return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false")
+               }
+
+               if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+                       return err
+               }
+       }
+
+       packet = []byte{msgUserAuthSuccess}
+       if err = s.out.writePacket(packet); err != nil {
+               return err
+       }
+
+       return nil
+}
+
+const defaultWindowSize = 32768
+
+// Accept reads and processes messages on a ServerConnection. It must be called
+// in order to demultiplex messages to any resulting Channels.
+func (s *ServerConnection) Accept() (Channel, os.Error) {
+       if s.err != nil {
+               return nil, s.err
+       }
+
+       for {
+               packet, err := s.in.readPacket()
+               if err != nil {
+
+                       s.lock.Lock()
+                       s.err = err
+                       s.lock.Unlock()
+
+                       for _, c := range s.channels {
+                               c.dead = true
+                               c.handleData(nil)
+                       }
+
+                       return nil, err
+               }
+
+               switch packet[0] {
+               case msgChannelOpen:
+                       var chanOpen channelOpenMsg
+                       if err := unmarshal(&chanOpen, packet, msgChannelOpen); err != nil {
+                               return nil, err
+                       }
+
+                       c := new(channel)
+                       c.chanType = chanOpen.ChanType
+                       c.theirId = chanOpen.PeersId
+                       c.theirWindow = chanOpen.PeersWindow
+                       c.maxPacketSize = chanOpen.MaxPacketSize
+                       c.extraData = chanOpen.TypeSpecificData
+                       c.myWindow = defaultWindowSize
+                       c.serverConn = s
+                       c.cond = sync.NewCond(&c.lock)
+                       c.pendingData = make([]byte, c.myWindow)
+
+                       s.lock.Lock()
+                       c.myId = s.nextChanId
+                       s.nextChanId++
+                       s.channels[c.myId] = c
+                       s.lock.Unlock()
+                       return c, nil
+
+               case msgChannelRequest:
+                       var chanRequest channelRequestMsg
+                       if err := unmarshal(&chanRequest, packet, msgChannelRequest); err != nil {
+                               return nil, err
+                       }
+
+                       s.lock.Lock()
+                       c, ok := s.channels[chanRequest.PeersId]
+                       if !ok {
+                               continue
+                       }
+                       c.handlePacket(&chanRequest)
+                       s.lock.Unlock()
+
+               case msgChannelData:
+                       if len(packet) < 5 {
+                               return nil, ParseError{msgChannelData}
+                       }
+                       chanId := uint32(packet[1])<<24 | uint32(packet[2])<<16 | uint32(packet[3])<<8 | uint32(packet[4])
+
+                       s.lock.Lock()
+                       c, ok := s.channels[chanId]
+                       if !ok {
+                               continue
+                       }
+                       c.handleData(packet[9:])
+                       s.lock.Unlock()
+
+               case msgChannelEOF:
+                       var eofMsg channelEOFMsg
+                       if err := unmarshal(&eofMsg, packet, msgChannelEOF); err != nil {
+                               return nil, err
+                       }
+
+                       s.lock.Lock()
+                       c, ok := s.channels[eofMsg.PeersId]
+                       if !ok {
+                               continue
+                       }
+                       c.handlePacket(&eofMsg)
+                       s.lock.Unlock()
+
+               case msgChannelClose:
+                       var closeMsg channelCloseMsg
+                       if err := unmarshal(&closeMsg, packet, msgChannelClose); err != nil {
+                               return nil, err
+                       }
+
+                       s.lock.Lock()
+                       c, ok := s.channels[closeMsg.PeersId]
+                       if !ok {
+                               continue
+                       }
+                       c.handlePacket(&closeMsg)
+                       s.lock.Unlock()
+
+               case msgGlobalRequest:
+                       var request globalRequestMsg
+                       if err := unmarshal(&request, packet, msgGlobalRequest); err != nil {
+                               return nil, err
+                       }
+
+                       if request.WantReply {
+                               if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil {
+                                       return nil, err
+                               }
+                       }
+
+               default:
+                       // Unknown message. Ignore.
+               }
+       }
+
+       panic("unreachable")
+}
diff --git a/src/pkg/exp/ssh/server_shell.go b/src/pkg/exp/ssh/server_shell.go
new file mode 100644 (file)
index 0000000..53a3241
--- /dev/null
@@ -0,0 +1,399 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "os"
+)
+
+// ServerShell contains the state for running a VT100 terminal that is capable
+// of reading lines of input.
+type ServerShell struct {
+       c      Channel
+       prompt string
+
+       // line is the current line being entered.
+       line []byte
+       // pos is the logical position of the cursor in line
+       pos int
+
+       // cursorX contains the current X value of the cursor where the left
+       // edge is 0. cursorY contains the row number where the first row of
+       // the current line is 0.
+       cursorX, cursorY int
+       // maxLine is the greatest value of cursorY so far.
+       maxLine int
+
+       termWidth, termHeight int
+
+       // outBuf contains the terminal data to be sent.
+       outBuf []byte
+       // remainder contains the remainder of any partial key sequences after
+       // a read. It aliases into inBuf.
+       remainder []byte
+       inBuf     [256]byte
+}
+
+// NewServerShell runs a VT100 terminal on the given channel. prompt is a
+// string that is written at the start of each input line. For example: "> ".
+func NewServerShell(c Channel, prompt string) *ServerShell {
+       return &ServerShell{
+               c:          c,
+               prompt:     prompt,
+               termWidth:  80,
+               termHeight: 24,
+       }
+}
+
+const (
+       keyCtrlD     = 4
+       keyEnter     = '\r'
+       keyEscape    = 27
+       keyBackspace = 127
+       keyUnknown   = 256 + iota
+       keyUp
+       keyDown
+       keyLeft
+       keyRight
+       keyAltLeft
+       keyAltRight
+)
+
+// bytesToKey tries to parse a key sequence from b. If successful, it returns
+// the key and the remainder of the input. Otherwise it returns -1.
+func bytesToKey(b []byte) (int, []byte) {
+       if len(b) == 0 {
+               return -1, nil
+       }
+
+       if b[0] != keyEscape {
+               return int(b[0]), b[1:]
+       }
+
+       if len(b) >= 3 && b[0] == keyEscape && b[1] == '[' {
+               switch b[2] {
+               case 'A':
+                       return keyUp, b[3:]
+               case 'B':
+                       return keyDown, b[3:]
+               case 'C':
+                       return keyRight, b[3:]
+               case 'D':
+                       return keyLeft, b[3:]
+               }
+       }
+
+       if len(b) >= 6 && b[0] == keyEscape && b[1] == '[' && b[2] == '1' && b[3] == ';' && b[4] == '3' {
+               switch b[5] {
+               case 'C':
+                       return keyAltRight, b[6:]
+               case 'D':
+                       return keyAltLeft, b[6:]
+               }
+       }
+
+       // If we get here then we have a key that we don't recognise, or a
+       // partial sequence. It's not clear how one should find the end of a
+       // sequence without knowing them all, but it seems that [a-zA-Z] only
+       // appears at the end of a sequence.
+       for i, c := range b[0:] {
+               if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
+                       return keyUnknown, b[i+1:]
+               }
+       }
+
+       return -1, b
+}
+
+// queue appends data to the end of ss.outBuf
+func (ss *ServerShell) queue(data []byte) {
+       if len(ss.outBuf)+len(data) > cap(ss.outBuf) {
+               newOutBuf := make([]byte, len(ss.outBuf), 2*(len(ss.outBuf)+len(data)))
+               copy(newOutBuf, ss.outBuf)
+               ss.outBuf = newOutBuf
+       }
+
+       oldLen := len(ss.outBuf)
+       ss.outBuf = ss.outBuf[:len(ss.outBuf)+len(data)]
+       copy(ss.outBuf[oldLen:], data)
+}
+
+var eraseUnderCursor = []byte{' ', keyEscape, '[', 'D'}
+
+func isPrintable(key int) bool {
+       return key >= 32 && key < 127
+}
+
+// moveCursorToPos appends data to ss.outBuf which will move the cursor to the
+// given, logical position in the text.
+func (ss *ServerShell) moveCursorToPos(pos int) {
+       x := len(ss.prompt) + pos
+       y := x / ss.termWidth
+       x = x % ss.termWidth
+
+       up := 0
+       if y < ss.cursorY {
+               up = ss.cursorY - y
+       }
+
+       down := 0
+       if y > ss.cursorY {
+               down = y - ss.cursorY
+       }
+
+       left := 0
+       if x < ss.cursorX {
+               left = ss.cursorX - x
+       }
+
+       right := 0
+       if x > ss.cursorX {
+               right = x - ss.cursorX
+       }
+
+       movement := make([]byte, 3*(up+down+left+right))
+       m := movement
+       for i := 0; i < up; i++ {
+               m[0] = keyEscape
+               m[1] = '['
+               m[2] = 'A'
+               m = m[3:]
+       }
+       for i := 0; i < down; i++ {
+               m[0] = keyEscape
+               m[1] = '['
+               m[2] = 'B'
+               m = m[3:]
+       }
+       for i := 0; i < left; i++ {
+               m[0] = keyEscape
+               m[1] = '['
+               m[2] = 'D'
+               m = m[3:]
+       }
+       for i := 0; i < right; i++ {
+               m[0] = keyEscape
+               m[1] = '['
+               m[2] = 'C'
+               m = m[3:]
+       }
+
+       ss.cursorX = x
+       ss.cursorY = y
+       ss.queue(movement)
+}
+
+const maxLineLength = 4096
+
+// handleKey processes the given key and, optionally, returns a line of text
+// that the user has entered.
+func (ss *ServerShell) handleKey(key int) (line string, ok bool) {
+       switch key {
+       case keyBackspace:
+               if ss.pos == 0 {
+                       return
+               }
+               ss.pos--
+
+               copy(ss.line[ss.pos:], ss.line[1+ss.pos:])
+               ss.line = ss.line[:len(ss.line)-1]
+               ss.writeLine(ss.line[ss.pos:])
+               ss.moveCursorToPos(ss.pos)
+               ss.queue(eraseUnderCursor)
+       case keyAltLeft:
+               // move left by a word.
+               if ss.pos == 0 {
+                       return
+               }
+               ss.pos--
+               for ss.pos > 0 {
+                       if ss.line[ss.pos] != ' ' {
+                               break
+                       }
+                       ss.pos--
+               }
+               for ss.pos > 0 {
+                       if ss.line[ss.pos] == ' ' {
+                               ss.pos++
+                               break
+                       }
+                       ss.pos--
+               }
+               ss.moveCursorToPos(ss.pos)
+       case keyAltRight:
+               // move right by a word.
+               for ss.pos < len(ss.line) {
+                       if ss.line[ss.pos] == ' ' {
+                               break
+                       }
+                       ss.pos++
+               }
+               for ss.pos < len(ss.line) {
+                       if ss.line[ss.pos] != ' ' {
+                               break
+                       }
+                       ss.pos++
+               }
+               ss.moveCursorToPos(ss.pos)
+       case keyLeft:
+               if ss.pos == 0 {
+                       return
+               }
+               ss.pos--
+               ss.moveCursorToPos(ss.pos)
+       case keyRight:
+               if ss.pos == len(ss.line) {
+                       return
+               }
+               ss.pos++
+               ss.moveCursorToPos(ss.pos)
+       case keyEnter:
+               ss.moveCursorToPos(len(ss.line))
+               ss.queue([]byte("\r\n"))
+               line = string(ss.line)
+               ok = true
+               ss.line = ss.line[:0]
+               ss.pos = 0
+               ss.cursorX = 0
+               ss.cursorY = 0
+               ss.maxLine = 0
+       default:
+               if !isPrintable(key) {
+                       return
+               }
+               if len(ss.line) == maxLineLength {
+                       return
+               }
+               if len(ss.line) == cap(ss.line) {
+                       newLine := make([]byte, len(ss.line), 2*(1+len(ss.line)))
+                       copy(newLine, ss.line)
+                       ss.line = newLine
+               }
+               ss.line = ss.line[:len(ss.line)+1]
+               copy(ss.line[ss.pos+1:], ss.line[ss.pos:])
+               ss.line[ss.pos] = byte(key)
+               ss.writeLine(ss.line[ss.pos:])
+               ss.pos++
+               ss.moveCursorToPos(ss.pos)
+       }
+       return
+}
+
+func (ss *ServerShell) writeLine(line []byte) {
+       for len(line) != 0 {
+               if ss.cursorX == ss.termWidth {
+                       ss.queue([]byte("\r\n"))
+                       ss.cursorX = 0
+                       ss.cursorY++
+                       if ss.cursorY > ss.maxLine {
+                               ss.maxLine = ss.cursorY
+                       }
+               }
+
+               remainingOnLine := ss.termWidth - ss.cursorX
+               todo := len(line)
+               if todo > remainingOnLine {
+                       todo = remainingOnLine
+               }
+               ss.queue(line[:todo])
+               ss.cursorX += todo
+               line = line[todo:]
+       }
+}
+
+// parsePtyRequest parses the payload of the pty-req message and extracts the
+// dimensions of the terminal. See RFC 4254, section 6.2.
+func parsePtyRequest(s []byte) (width, height int, ok bool) {
+       _, s, ok = parseString(s)
+       if !ok {
+               return
+       }
+       width32, s, ok := parseUint32(s)
+       if !ok {
+               return
+       }
+       height32, _, ok := parseUint32(s)
+       width = int(width32)
+       height = int(height32)
+       if width < 1 {
+               ok = false
+       }
+       if height < 1 {
+               ok = false
+       }
+       return
+}
+
+func (ss *ServerShell) Write(buf []byte) (n int, err os.Error) {
+       return ss.c.Write(buf)
+}
+
+// ReadLine returns a line of input from the terminal.
+func (ss *ServerShell) ReadLine() (line string, err os.Error) {
+       ss.writeLine([]byte(ss.prompt))
+       ss.c.Write(ss.outBuf)
+       ss.outBuf = ss.outBuf[:0]
+
+       for {
+               // ss.remainder is a slice at the beginning of ss.inBuf
+               // containing a partial key sequence
+               readBuf := ss.inBuf[len(ss.remainder):]
+               n, err := ss.c.Read(readBuf)
+               if err == nil {
+                       ss.remainder = ss.inBuf[:n+len(ss.remainder)]
+                       rest := ss.remainder
+                       lineOk := false
+                       for !lineOk {
+                               var key int
+                               key, rest = bytesToKey(rest)
+                               if key < 0 {
+                                       break
+                               }
+                               if key == keyCtrlD {
+                                       return "", os.EOF
+                               }
+                               line, lineOk = ss.handleKey(key)
+                       }
+                       if len(rest) > 0 {
+                               n := copy(ss.inBuf[:], rest)
+                               ss.remainder = ss.inBuf[:n]
+                       } else {
+                               ss.remainder = nil
+                       }
+                       ss.c.Write(ss.outBuf)
+                       ss.outBuf = ss.outBuf[:0]
+                       if lineOk {
+                               return
+                       }
+                       continue
+               }
+
+               if req, ok := err.(ChannelRequest); ok {
+                       ok := false
+                       switch req.Request {
+                       case "pty-req":
+                               ss.termWidth, ss.termHeight, ok = parsePtyRequest(req.Payload)
+                               if !ok {
+                                       ss.termWidth = 80
+                                       ss.termHeight = 24
+                               }
+                       case "shell":
+                               ok = true
+                               if len(req.Payload) > 0 {
+                                       // We don't accept any commands, only the default shell.
+                                       ok = false
+                               }
+                       case "env":
+                               ok = true
+                       }
+                       if req.WantReply {
+                               ss.c.AckRequest(ok)
+                       }
+               } else {
+                       return "", err
+               }
+       }
+       panic("unreachable")
+}
diff --git a/src/pkg/exp/ssh/server_shell_test.go b/src/pkg/exp/ssh/server_shell_test.go
new file mode 100644 (file)
index 0000000..622cf7c
--- /dev/null
@@ -0,0 +1,134 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "testing"
+       "os"
+)
+
+type MockChannel struct {
+       toSend       []byte
+       bytesPerRead int
+       received     []byte
+}
+
+func (c *MockChannel) Accept() os.Error {
+       return nil
+}
+
+func (c *MockChannel) Reject(RejectionReason, string) os.Error {
+       return nil
+}
+
+func (c *MockChannel) Read(data []byte) (n int, err os.Error) {
+       n = len(data)
+       if n == 0 {
+               return
+       }
+       if n > len(c.toSend) {
+               n = len(c.toSend)
+       }
+       if n == 0 {
+               return 0, os.EOF
+       }
+       if c.bytesPerRead > 0 && n > c.bytesPerRead {
+               n = c.bytesPerRead
+       }
+       copy(data, c.toSend[:n])
+       c.toSend = c.toSend[n:]
+       return
+}
+
+func (c *MockChannel) Write(data []byte) (n int, err os.Error) {
+       c.received = append(c.received, data...)
+       return len(data), nil
+}
+
+func (c *MockChannel) Close() os.Error {
+       return nil
+}
+
+func (c *MockChannel) AckRequest(ok bool) os.Error {
+       return nil
+}
+
+func (c *MockChannel) ChannelType() string {
+       return ""
+}
+
+func (c *MockChannel) ExtraData() []byte {
+       return nil
+}
+
+func TestClose(t *testing.T) {
+       c := &MockChannel{}
+       ss := NewServerShell(c, "> ")
+       line, err := ss.ReadLine()
+       if line != "" {
+               t.Errorf("Expected empty line but got: %s", line)
+       }
+       if err != os.EOF {
+               t.Errorf("Error should have been EOF but got: %s", err)
+       }
+}
+
+var keyPressTests = []struct {
+       in   string
+       line string
+       err  os.Error
+}{
+       {
+               "",
+               "",
+               os.EOF,
+       },
+       {
+               "\r",
+               "",
+               nil,
+       },
+       {
+               "foo\r",
+               "foo",
+               nil,
+       },
+       {
+               "a\x1b[Cb\r", // right
+               "ab",
+               nil,
+       },
+       {
+               "a\x1b[Db\r", // left
+               "ba",
+               nil,
+       },
+       {
+               "a\177b\r", // backspace
+               "b",
+               nil,
+       },
+}
+
+func TestKeyPresses(t *testing.T) {
+       for i, test := range keyPressTests {
+               for j := 0; j < len(test.in); j++ {
+                       c := &MockChannel{
+                               toSend:       []byte(test.in),
+                               bytesPerRead: j,
+                       }
+                       ss := NewServerShell(c, "> ")
+                       line, err := ss.ReadLine()
+                       if line != test.line {
+                               t.Errorf("Line resulting from test %d (%d bytes per read) was '%s', expected '%s'", i, j, line, test.line)
+                               break
+                       }
+                       if err != test.err {
+                               t.Errorf("Error resulting from test %d (%d bytes per read) was '%v', expected '%v'", i, j, err, test.err)
+                               break
+                       }
+               }
+       }
+}
diff --git a/src/pkg/exp/ssh/transport.go b/src/pkg/exp/ssh/transport.go
new file mode 100644 (file)
index 0000000..919759f
--- /dev/null
@@ -0,0 +1,308 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ssh
+
+import (
+       "bufio"
+       "crypto"
+       "crypto/aes"
+       "crypto/cipher"
+       "crypto/hmac"
+       "crypto/subtle"
+       "hash"
+       "io"
+       "net"
+       "os"
+)
+
+// halfConnection represents one direction of an SSH connection. It maintains
+// the cipher state needed to process messages.
+type halfConnection struct {
+       // Only one of these two will be non-nil
+       in  *bufio.Reader
+       out net.Conn
+
+       rand            io.Reader
+       cipherAlgo      string
+       macAlgo         string
+       compressionAlgo string
+       paddingMultiple int
+
+       seqNum uint32
+
+       mac    hash.Hash
+       cipher cipher.Stream
+}
+
+func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) {
+       var lengthBytes [5]byte
+
+       _, err = io.ReadFull(hc.in, lengthBytes[:])
+       if err != nil {
+               return
+       }
+
+       if hc.cipher != nil {
+               hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
+       }
+
+       macSize := 0
+       if hc.mac != nil {
+               hc.mac.Reset()
+               var seqNumBytes [4]byte
+               seqNumBytes[0] = byte(hc.seqNum >> 24)
+               seqNumBytes[1] = byte(hc.seqNum >> 16)
+               seqNumBytes[2] = byte(hc.seqNum >> 8)
+               seqNumBytes[3] = byte(hc.seqNum)
+               hc.mac.Write(seqNumBytes[:])
+               hc.mac.Write(lengthBytes[:])
+               macSize = hc.mac.Size()
+       }
+
+       length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
+
+       paddingLength := uint32(lengthBytes[4])
+
+       if length <= paddingLength+1 {
+               return nil, os.NewError("invalid packet length")
+       }
+       if length > maxPacketSize {
+               return nil, os.NewError("packet too large")
+       }
+
+       packet = make([]byte, length-1+uint32(macSize))
+       _, err = io.ReadFull(hc.in, packet)
+       if err != nil {
+               return nil, err
+       }
+       mac := packet[length-1:]
+       if hc.cipher != nil {
+               hc.cipher.XORKeyStream(packet, packet[:length-1])
+       }
+
+       if hc.mac != nil {
+               hc.mac.Write(packet[:length-1])
+               if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 {
+                       return nil, os.NewError("ssh: MAC failure")
+               }
+       }
+
+       hc.seqNum++
+       packet = packet[:length-paddingLength-1]
+       return
+}
+
+func (hc *halfConnection) readPacket() (packet []byte, err os.Error) {
+       for {
+               packet, err := hc.readOnePacket()
+               if err != nil {
+                       return nil, err
+               }
+               if packet[0] != msgIgnore && packet[0] != msgDebug {
+                       return packet, nil
+               }
+       }
+       panic("unreachable")
+}
+
+func (hc *halfConnection) writePacket(packet []byte) os.Error {
+       paddingMultiple := hc.paddingMultiple
+       if paddingMultiple == 0 {
+               paddingMultiple = 8
+       }
+
+       paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple
+       if paddingLength < 4 {
+               paddingLength += paddingMultiple
+       }
+
+       var lengthBytes [5]byte
+       length := len(packet) + 1 + paddingLength
+       lengthBytes[0] = byte(length >> 24)
+       lengthBytes[1] = byte(length >> 16)
+       lengthBytes[2] = byte(length >> 8)
+       lengthBytes[3] = byte(length)
+       lengthBytes[4] = byte(paddingLength)
+
+       var padding [32]byte
+       _, err := io.ReadFull(hc.rand, padding[:paddingLength])
+       if err != nil {
+               return err
+       }
+
+       if hc.mac != nil {
+               hc.mac.Reset()
+               var seqNumBytes [4]byte
+               seqNumBytes[0] = byte(hc.seqNum >> 24)
+               seqNumBytes[1] = byte(hc.seqNum >> 16)
+               seqNumBytes[2] = byte(hc.seqNum >> 8)
+               seqNumBytes[3] = byte(hc.seqNum)
+               hc.mac.Write(seqNumBytes[:])
+               hc.mac.Write(lengthBytes[:])
+               hc.mac.Write(packet)
+               hc.mac.Write(padding[:paddingLength])
+       }
+
+       if hc.cipher != nil {
+               hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
+               hc.cipher.XORKeyStream(packet, packet)
+               hc.cipher.XORKeyStream(padding[:], padding[:paddingLength])
+       }
+
+       _, err = hc.out.Write(lengthBytes[:])
+       if err != nil {
+               return err
+       }
+       _, err = hc.out.Write(packet)
+       if err != nil {
+               return err
+       }
+       _, err = hc.out.Write(padding[:paddingLength])
+       if err != nil {
+               return err
+       }
+
+       if hc.mac != nil {
+               _, err = hc.out.Write(hc.mac.Sum())
+       }
+
+       hc.seqNum++
+
+       return err
+}
+
+const (
+       serverKeys = iota
+       clientKeys
+)
+
+// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as
+// described in RFC 4253, section 6.4. direction should either be serverKeys
+// (to setup server->client keys) or clientKeys (for client->server keys).
+func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
+       h := hashFunc.New()
+
+       // We only support these algorithms for now.
+       if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 {
+               return os.NewError("ssh: setupServerKeys internal error")
+       }
+
+       blockSize := 16
+       keySize := 16
+       macKeySize := 20
+
+       var ivTag, keyTag, macKeyTag byte
+       if direction == serverKeys {
+               ivTag, keyTag, macKeyTag = 'B', 'D', 'F'
+       } else {
+               ivTag, keyTag, macKeyTag = 'A', 'C', 'E'
+       }
+
+       iv := make([]byte, blockSize)
+       key := make([]byte, keySize)
+       macKey := make([]byte, macKeySize)
+       generateKeyMaterial(iv, ivTag, K, H, sessionId, h)
+       generateKeyMaterial(key, keyTag, K, H, sessionId, h)
+       generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h)
+
+       hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
+       aes, err := aes.NewCipher(key)
+       if err != nil {
+               return err
+       }
+       hc.cipher = cipher.NewCTR(aes, iv)
+       hc.paddingMultiple = 16
+       return nil
+}
+
+// generateKeyMaterial fills out with key material generated from tag, K, H
+// and sessionId, as specified in RFC 4253, section 7.2.
+func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) {
+       var digestsSoFar []byte
+
+       for len(out) > 0 {
+               h.Reset()
+               h.Write(K)
+               h.Write(H)
+
+               if len(digestsSoFar) == 0 {
+                       h.Write([]byte{tag})
+                       h.Write(sessionId)
+               } else {
+                       h.Write(digestsSoFar)
+               }
+
+               digest := h.Sum()
+               n := copy(out, digest)
+               out = out[n:]
+               if len(out) > 0 {
+                       digestsSoFar = append(digestsSoFar, digest...)
+               }
+       }
+}
+
+// truncatingMAC wraps around a hash.Hash and truncates the output digest to
+// a given size.
+type truncatingMAC struct {
+       length int
+       hmac   hash.Hash
+}
+
+func (t truncatingMAC) Write(data []byte) (int, os.Error) {
+       return t.hmac.Write(data)
+}
+
+func (t truncatingMAC) Sum() []byte {
+       digest := t.hmac.Sum()
+       return digest[:t.length]
+}
+
+func (t truncatingMAC) Reset() {
+       t.hmac.Reset()
+}
+
+func (t truncatingMAC) Size() int {
+       return t.length
+}
+
+// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
+// version string. In the event that the client is talking a different protocol
+// we need to set a limit otherwise we will keep using more and more memory
+// while searching for the end of the version handshake.
+const maxVersionStringBytes = 1024
+
+func readVersion(r *bufio.Reader) (versionString []byte, ok bool) {
+       versionString = make([]byte, 0, 64)
+       seenCR := false
+
+forEachByte:
+       for len(versionString) < maxVersionStringBytes {
+               b, err := r.ReadByte()
+               if err != nil {
+                       return
+               }
+
+               if !seenCR {
+                       if b == '\r' {
+                               seenCR = true
+                       }
+               } else {
+                       if b == '\n' {
+                               ok = true
+                               break forEachByte
+                       } else {
+                               seenCR = false
+                       }
+               }
+               versionString = append(versionString, b)
+       }
+
+       if ok {
+               // We need to remove the CR from versionString
+               versionString = versionString[:len(versionString)-1]
+       }
+
+       return
+}