type ServerConnection struct {
Server *Server
- in, out *halfConnection
+ *transport
channels map[uint32]*channel
nextChanId uint32
// kexDH performs Diffie-Hellman key agreement on a ServerConnection. The
// returned values are given the same names as in RFC 4253, section 8.
func (s *ServerConnection) kexDH(group *dhGroup, hashFunc crypto.Hash, magics *handshakeMagics, hostKeyAlgo string) (H, K []byte, err os.Error) {
- packet, err := s.in.readPacket()
+ packet, err := s.readPacket()
if err != nil {
return
}
}
packet = marshal(msgKexDHReply, kexDHReply)
- err = s.out.writePacket(packet)
+ err = s.writePacket(packet)
return
}
// Handshake performs an SSH transport and client authentication on the given ServerConnection.
func (s *ServerConnection) Handshake(conn net.Conn) os.Error {
var magics handshakeMagics
- inBuf := bufio.NewReader(conn)
-
- _, err := conn.Write(serverVersion)
- if err != nil {
+ s.transport = &transport{
+ reader: reader{
+ Reader: bufio.NewReader(conn),
+ },
+ writer: writer{
+ Writer: bufio.NewWriter(conn),
+ rand: rand.Reader,
+ },
+ Close: func() os.Error {
+ return conn.Close()
+ },
+ }
+
+ if _, err := conn.Write(serverVersion); err != nil {
return err
}
-
magics.serverVersion = serverVersion[:len(serverVersion)-2]
+
+ version, ok := readVersion(s.transport)
+ if !ok {
+ return os.NewError("failed to read version string from client")
+ }
+ magics.clientVersion = version
+
serverKexInit := kexInitMsg{
KexAlgos: supportedKexAlgos,
ServerHostKeyAlgos: supportedHostKeyAlgos,
kexInitPacket := marshal(msgKexInit, serverKexInit)
magics.serverKexInit = kexInitPacket
- var out halfConnection
- out.out = conn
- out.rand = rand.Reader
- s.out = &out
- err = out.writePacket(kexInitPacket)
- if err != nil {
+ if err := s.writePacket(kexInitPacket); err != nil {
return err
}
- version, ok := readVersion(inBuf)
- if !ok {
- return os.NewError("failed to read version string from client")
- }
- magics.clientVersion = version
-
- var in halfConnection
- in.in = inBuf
- s.in = &in
- packet, err := in.readPacket()
+ packet, err := s.readPacket()
if err != nil {
return err
}
+
magics.clientKexInit = packet
var clientKexInit kexInitMsg
return err
}
- kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(&in, &out, &clientKexInit, &serverKexInit)
+ kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, s.transport, &clientKexInit, &serverKexInit)
if !ok {
return os.NewError("ssh: no common algorithms")
}
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
- _, err := in.readPacket()
+ _, err := s.readPacket()
if err != nil {
return err
}
}
packet = []byte{msgNewKeys}
- if err = out.writePacket(packet); err != nil {
+ if err = s.writePacket(packet); err != nil {
return err
}
- if err = out.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
+ if err = s.transport.writer.setupKeys(serverKeys, K, H, H, hashFunc); err != nil {
return err
}
- if packet, err = in.readPacket(); err != nil {
+ if packet, err = s.readPacket(); err != nil {
return err
}
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
- in.setupKeys(clientKeys, K, H, H, hashFunc)
+ s.transport.reader.setupKeys(clientKeys, K, H, H, hashFunc)
- packet, err = in.readPacket()
+ packet, err = s.readPacket()
if err != nil {
return err
}
Service: serviceUserAuth,
}
packet = marshal(msgServiceAccept, serviceAccept)
- if err = out.writePacket(packet); err != nil {
+ if err = s.writePacket(packet); err != nil {
return err
}
userAuthLoop:
for {
- if packet, err = s.in.readPacket(); err != nil {
+ if packet, err = s.readPacket(); err != nil {
return err
}
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
Algo: algo,
PubKey: string(pubKey),
}
- if err = s.out.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
+ if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
return err
}
continue userAuthLoop
return os.NewError("ssh: no authentication methods configured but NoClientAuth is also false")
}
- if err = s.out.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
+ if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
return err
}
}
packet = []byte{msgUserAuthSuccess}
- if err = s.out.writePacket(packet); err != nil {
+ if err = s.writePacket(packet); err != nil {
return err
}
}
for {
- packet, err := s.in.readPacket()
+ packet, err := s.readPacket()
if err != nil {
s.lock.Lock()
}
if request.WantReply {
- if err := s.out.writePacket([]byte{msgRequestFailure}); err != nil {
+ if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
"crypto/subtle"
"hash"
"io"
- "net"
"os"
)
-// halfConnection represents one direction of an SSH connection. It maintains
-// the cipher state needed to process messages.
-type halfConnection struct {
- // Only one of these two will be non-nil
- in *bufio.Reader
- out net.Conn
+const (
+ paddingMultiple = 16 // TODO(dfc) does this need to be configurable?
+)
+
+// transport represents the SSH connection to the remote peer.
+type transport struct {
+ reader
+ writer
- rand io.Reader
cipherAlgo string
macAlgo string
compressionAlgo string
+
+ Close func() os.Error
+}
+
+// reader represents the incoming connection state.
+type reader struct {
+ io.Reader
+ common
+}
+
+// writer represnts the outgoing connection state.
+type writer struct {
+ *bufio.Writer
paddingMultiple int
+ rand io.Reader
+ common
+}
+// common represents the cipher state needed to process messages in a single
+// direction.
+type common struct {
seqNum uint32
-
mac hash.Hash
cipher cipher.Stream
}
-func (hc *halfConnection) readOnePacket() (packet []byte, err os.Error) {
- var lengthBytes [5]byte
+// Read and decrypt a single packet from the remote peer.
+func (r *reader) readOnePacket() ([]byte, os.Error) {
+ var lengthBytes = make([]byte, 5)
+ var macSize uint32
- _, err = io.ReadFull(hc.in, lengthBytes[:])
- if err != nil {
- return
+ if _, err := io.ReadFull(r, lengthBytes); err != nil {
+ return nil, err
}
- if hc.cipher != nil {
- hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
+ if r.cipher != nil {
+ r.cipher.XORKeyStream(lengthBytes, lengthBytes)
}
- macSize := 0
- if hc.mac != nil {
- hc.mac.Reset()
- var seqNumBytes [4]byte
- seqNumBytes[0] = byte(hc.seqNum >> 24)
- seqNumBytes[1] = byte(hc.seqNum >> 16)
- seqNumBytes[2] = byte(hc.seqNum >> 8)
- seqNumBytes[3] = byte(hc.seqNum)
- hc.mac.Write(seqNumBytes[:])
- hc.mac.Write(lengthBytes[:])
- macSize = hc.mac.Size()
+ if r.mac != nil {
+ r.mac.Reset()
+ seqNumBytes := []byte{
+ byte(r.seqNum >> 24),
+ byte(r.seqNum >> 16),
+ byte(r.seqNum >> 8),
+ byte(r.seqNum),
+ }
+ r.mac.Write(seqNumBytes)
+ r.mac.Write(lengthBytes)
+ macSize = uint32(r.mac.Size())
}
length := uint32(lengthBytes[0])<<24 | uint32(lengthBytes[1])<<16 | uint32(lengthBytes[2])<<8 | uint32(lengthBytes[3])
-
paddingLength := uint32(lengthBytes[4])
if length <= paddingLength+1 {
return nil, os.NewError("packet too large")
}
- packet = make([]byte, length-1+uint32(macSize))
- _, err = io.ReadFull(hc.in, packet)
- if err != nil {
+ packet := make([]byte, length-1+macSize)
+ if _, err := io.ReadFull(r, packet); err != nil {
return nil, err
}
mac := packet[length-1:]
- if hc.cipher != nil {
- hc.cipher.XORKeyStream(packet, packet[:length-1])
+ if r.cipher != nil {
+ r.cipher.XORKeyStream(packet, packet[:length-1])
}
- if hc.mac != nil {
- hc.mac.Write(packet[:length-1])
- if subtle.ConstantTimeCompare(hc.mac.Sum(), mac) != 1 {
+ if r.mac != nil {
+ r.mac.Write(packet[:length-1])
+ if subtle.ConstantTimeCompare(r.mac.Sum(), mac) != 1 {
return nil, os.NewError("ssh: MAC failure")
}
}
- hc.seqNum++
- packet = packet[:length-paddingLength-1]
- return
+ r.seqNum++
+ return packet[:length-paddingLength-1], nil
}
-func (hc *halfConnection) readPacket() (packet []byte, err os.Error) {
+// Read and decrypt next packet discarding debug and noop messages.
+func (t *transport) readPacket() ([]byte, os.Error) {
for {
- packet, err := hc.readOnePacket()
+ packet, err := t.readOnePacket()
if err != nil {
return nil, err
}
panic("unreachable")
}
-func (hc *halfConnection) writePacket(packet []byte) os.Error {
- paddingMultiple := hc.paddingMultiple
- if paddingMultiple == 0 {
- paddingMultiple = 8
- }
-
- paddingLength := paddingMultiple - (4+1+len(packet))%paddingMultiple
+// Encrypt and send a packet of data to the remote peer.
+func (w *writer) writePacket(packet []byte) os.Error {
+ paddingLength := paddingMultiple - (5+len(packet))%paddingMultiple
if paddingLength < 4 {
paddingLength += paddingMultiple
}
- var lengthBytes [5]byte
length := len(packet) + 1 + paddingLength
- lengthBytes[0] = byte(length >> 24)
- lengthBytes[1] = byte(length >> 16)
- lengthBytes[2] = byte(length >> 8)
- lengthBytes[3] = byte(length)
- lengthBytes[4] = byte(paddingLength)
-
- var padding [32]byte
- _, err := io.ReadFull(hc.rand, padding[:paddingLength])
+ lengthBytes := []byte{
+ byte(length >> 24),
+ byte(length >> 16),
+ byte(length >> 8),
+ byte(length),
+ byte(paddingLength),
+ }
+ padding := make([]byte, paddingLength)
+ _, err := io.ReadFull(w.rand, padding)
if err != nil {
return err
}
- if hc.mac != nil {
- hc.mac.Reset()
- var seqNumBytes [4]byte
- seqNumBytes[0] = byte(hc.seqNum >> 24)
- seqNumBytes[1] = byte(hc.seqNum >> 16)
- seqNumBytes[2] = byte(hc.seqNum >> 8)
- seqNumBytes[3] = byte(hc.seqNum)
- hc.mac.Write(seqNumBytes[:])
- hc.mac.Write(lengthBytes[:])
- hc.mac.Write(packet)
- hc.mac.Write(padding[:paddingLength])
+ if w.mac != nil {
+ w.mac.Reset()
+ seqNumBytes := []byte{
+ byte(w.seqNum >> 24),
+ byte(w.seqNum >> 16),
+ byte(w.seqNum >> 8),
+ byte(w.seqNum),
+ }
+ w.mac.Write(seqNumBytes)
+ w.mac.Write(lengthBytes)
+ w.mac.Write(packet)
+ w.mac.Write(padding)
}
- if hc.cipher != nil {
- hc.cipher.XORKeyStream(lengthBytes[:], lengthBytes[:])
- hc.cipher.XORKeyStream(packet, packet)
- hc.cipher.XORKeyStream(padding[:], padding[:paddingLength])
+ // TODO(dfc) lengthBytes, packet and padding should be
+ // subslices of a single buffer
+ if w.cipher != nil {
+ w.cipher.XORKeyStream(lengthBytes, lengthBytes)
+ w.cipher.XORKeyStream(packet, packet)
+ w.cipher.XORKeyStream(padding, padding)
}
- _, err = hc.out.Write(lengthBytes[:])
- if err != nil {
+ if _, err := w.Write(lengthBytes); err != nil {
return err
}
- _, err = hc.out.Write(packet)
- if err != nil {
+ if _, err := w.Write(packet); err != nil {
return err
}
- _, err = hc.out.Write(padding[:paddingLength])
- if err != nil {
+ if _, err := w.Write(padding); err != nil {
return err
}
- if hc.mac != nil {
- _, err = hc.out.Write(hc.mac.Sum())
+ if w.mac != nil {
+ if _, err := w.Write(w.mac.Sum()); err != nil {
+ return err
+ }
}
- hc.seqNum++
-
+ if err := w.Flush(); err != nil {
+ return err
+ }
+ w.seqNum++
return err
}
-const (
- serverKeys = iota
- clientKeys
+type direction struct {
+ ivTag []byte
+ keyTag []byte
+ macKeyTag []byte
+}
+
+// TODO(dfc) can this be made a constant ?
+var (
+ serverKeys = direction{[]byte{'B'}, []byte{'D'}, []byte{'F'}}
+ clientKeys = direction{[]byte{'A'}, []byte{'C'}, []byte{'E'}}
)
-// setupServerKeys sets the cipher and MAC keys from K, H and sessionId, as
+// setupKeys sets the cipher and MAC keys from K, H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
-func (hc *halfConnection) setupKeys(direction int, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
+func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) os.Error {
h := hashFunc.New()
- // We only support these algorithms for now.
- if hc.cipherAlgo != cipherAES128CTR || hc.macAlgo != macSHA196 {
- return os.NewError("ssh: setupServerKeys internal error")
- }
-
blockSize := 16
keySize := 16
macKeySize := 20
- var ivTag, keyTag, macKeyTag byte
- if direction == serverKeys {
- ivTag, keyTag, macKeyTag = 'B', 'D', 'F'
- } else {
- ivTag, keyTag, macKeyTag = 'A', 'C', 'E'
- }
-
iv := make([]byte, blockSize)
key := make([]byte, keySize)
macKey := make([]byte, macKeySize)
- generateKeyMaterial(iv, ivTag, K, H, sessionId, h)
- generateKeyMaterial(key, keyTag, K, H, sessionId, h)
- generateKeyMaterial(macKey, macKeyTag, K, H, sessionId, h)
+ generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
+ generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
+ generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
- hc.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
+ c.mac = truncatingMAC{12, hmac.NewSHA1(macKey)}
aes, err := aes.NewCipher(key)
if err != nil {
return err
}
- hc.cipher = cipher.NewCTR(aes, iv)
- hc.paddingMultiple = 16
+ c.cipher = cipher.NewCTR(aes, iv)
return nil
}
// generateKeyMaterial fills out with key material generated from tag, K, H
// and sessionId, as specified in RFC 4253, section 7.2.
-func generateKeyMaterial(out []byte, tag byte, K, H, sessionId []byte, h hash.Hash) {
+func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
var digestsSoFar []byte
for len(out) > 0 {
h.Write(H)
if len(digestsSoFar) == 0 {
- h.Write([]byte{tag})
+ h.Write(tag)
h.Write(sessionId)
} else {
h.Write(digestsSoFar)
// while searching for the end of the version handshake.
const maxVersionStringBytes = 1024
-func readVersion(r *bufio.Reader) (versionString []byte, ok bool) {
+// Read version string as specified by RFC 4253, section 4.2.
+func readVersion(r io.Reader) (versionString []byte, ok bool) {
versionString = make([]byte, 0, 64)
seenCR := false
+ var buf [1]byte
forEachByte:
for len(versionString) < maxVersionStringBytes {
- b, err := r.ReadByte()
+ _, err := io.ReadFull(r, buf[:])
if err != nil {
return
}
+ b := buf[0]
if !seenCR {
if b == '\r' {