]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: simpler implementation of record layer
authorRuss Cox <rsc@golang.org>
Tue, 27 Apr 2010 05:19:04 +0000 (22:19 -0700)
committerRuss Cox <rsc@golang.org>
Tue, 27 Apr 2010 05:19:04 +0000 (22:19 -0700)
Depends on CL 957045, 980043, 1004043.
Fixes #715.

R=agl1, agl
CC=golang-dev
https://golang.org/cl/943043

15 files changed:
src/pkg/crypto/tls/Makefile
src/pkg/crypto/tls/alert.go
src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/conn.go [new file with mode: 0644]
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_messages.go
src/pkg/crypto/tls/handshake_messages_test.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/handshake_server_test.go
src/pkg/crypto/tls/record_process.go [deleted file]
src/pkg/crypto/tls/record_process_test.go [deleted file]
src/pkg/crypto/tls/record_read.go [deleted file]
src/pkg/crypto/tls/record_read_test.go [deleted file]
src/pkg/crypto/tls/record_write.go [deleted file]
src/pkg/crypto/tls/tls.go

index 55c9d87cf9e3b4ced377f0702e4889c9d914928f..5e25bd43adc846eff484940a301e3dad7d64e18b 100644 (file)
@@ -7,15 +7,13 @@ include ../../../Make.$(GOARCH)
 TARG=crypto/tls
 GOFILES=\
        alert.go\
+       ca_set.go\
        common.go\
+       conn.go\
        handshake_client.go\
        handshake_messages.go\
        handshake_server.go\
        prf.go\
-       record_process.go\
-       record_read.go\
-       record_write.go\
-       ca_set.go\
        tls.go\
 
 include ../../../Make.pkg
index 2f740b39ea2015b8c8468816257d495539e19631..3b9e0e2415b2afd22e9b52c2186fc4d51d6605f9 100644 (file)
@@ -4,40 +4,70 @@
 
 package tls
 
-type alertLevel int
-type alertType int
+import "strconv"
+
+type alert uint8
 
 const (
-       alertLevelWarning alertLevel = 1
-       alertLevelError   alertLevel = 2
+       // alert level
+       alertLevelWarning = 1
+       alertLevelError   = 2
 )
 
 const (
-       alertCloseNotify            alertType = 0
-       alertUnexpectedMessage      alertType = 10
-       alertBadRecordMAC           alertType = 20
-       alertDecryptionFailed       alertType = 21
-       alertRecordOverflow         alertType = 22
-       alertDecompressionFailure   alertType = 30
-       alertHandshakeFailure       alertType = 40
-       alertBadCertificate         alertType = 42
-       alertUnsupportedCertificate alertType = 43
-       alertCertificateRevoked     alertType = 44
-       alertCertificateExpired     alertType = 45
-       alertCertificateUnknown     alertType = 46
-       alertIllegalParameter       alertType = 47
-       alertUnknownCA              alertType = 48
-       alertAccessDenied           alertType = 49
-       alertDecodeError            alertType = 50
-       alertDecryptError           alertType = 51
-       alertProtocolVersion        alertType = 70
-       alertInsufficientSecurity   alertType = 71
-       alertInternalError          alertType = 80
-       alertUserCanceled           alertType = 90
-       alertNoRenegotiation        alertType = 100
+       alertCloseNotify            alert = 0
+       alertUnexpectedMessage      alert = 10
+       alertBadRecordMAC           alert = 20
+       alertDecryptionFailed       alert = 21
+       alertRecordOverflow         alert = 22
+       alertDecompressionFailure   alert = 30
+       alertHandshakeFailure       alert = 40
+       alertBadCertificate         alert = 42
+       alertUnsupportedCertificate alert = 43
+       alertCertificateRevoked     alert = 44
+       alertCertificateExpired     alert = 45
+       alertCertificateUnknown     alert = 46
+       alertIllegalParameter       alert = 47
+       alertUnknownCA              alert = 48
+       alertAccessDenied           alert = 49
+       alertDecodeError            alert = 50
+       alertDecryptError           alert = 51
+       alertProtocolVersion        alert = 70
+       alertInsufficientSecurity   alert = 71
+       alertInternalError          alert = 80
+       alertUserCanceled           alert = 90
+       alertNoRenegotiation        alert = 100
 )
 
-type alert struct {
-       level alertLevel
-       error alertType
+var alertText = map[alert]string{
+       alertCloseNotify:            "close notify",
+       alertUnexpectedMessage:      "unexpected message",
+       alertBadRecordMAC:           "bad record MAC",
+       alertDecryptionFailed:       "decryption failed",
+       alertRecordOverflow:         "record overflow",
+       alertDecompressionFailure:   "decompression failure",
+       alertHandshakeFailure:       "handshake failure",
+       alertBadCertificate:         "bad certificate",
+       alertUnsupportedCertificate: "unsupported certificate",
+       alertCertificateRevoked:     "revoked certificate",
+       alertCertificateExpired:     "expired certificate",
+       alertCertificateUnknown:     "unknown certificate",
+       alertIllegalParameter:       "illegal parameter",
+       alertUnknownCA:              "unknown certificate authority",
+       alertAccessDenied:           "access denied",
+       alertDecodeError:            "error decoding message",
+       alertDecryptError:           "error decrypting message",
+       alertProtocolVersion:        "protocol version not supported",
+       alertInsufficientSecurity:   "insufficient security level",
+       alertInternalError:          "internal error",
+       alertUserCanceled:           "user canceled",
+       alertNoRenegotiation:        "no renegotiation",
+}
+
+func (e alert) String() string {
+       s, ok := alertText[e]
+       if ok {
+               return s
+       }
+       return "alert(" + strconv.Itoa(int(e)) + ")"
 }
index ef54a1db76363e29f07761915c20086dc417e068..56c22cf7d8ea76082293d3d395661859ef7e880a 100644 (file)
@@ -10,22 +10,18 @@ import (
        "io"
        "io/ioutil"
        "once"
-       "os"
        "time"
 )
 
 const (
-       // maxTLSCiphertext is the maximum length of a plaintext payload.
-       maxTLSPlaintext = 16384
-       // maxTLSCiphertext is the maximum length payload after compression and encryption.
-       maxTLSCiphertext = 16384 + 2048
-       // maxHandshakeMsg is the largest single handshake message that we'll buffer.
-       maxHandshakeMsg = 65536
-       // defaultMajor and defaultMinor are the maximum TLS version that we support.
-       defaultMajor = 3
-       defaultMinor = 2
-)
+       maxPlaintext    = 16384        // maximum plaintext payload length
+       maxCiphertext   = 16384 + 2048 // maximum ciphertext payload length
+       recordHeaderLen = 5            // record header length
+       maxHandshake    = 65536        // maximum handshake we support (protocol max is 16 MB)
 
+       minVersion = 0x0301 // minimum supported version - TLS 1.0
+       maxVersion = 0x0302 // maximum supported version - TLS 1.1
+)
 
 // TLS record types.
 type recordType uint8
@@ -67,7 +63,7 @@ var (
 type ConnectionState struct {
        HandshakeComplete  bool
        CipherSuite        string
-       Error              alertType
+       Error              alert
        NegotiatedProtocol string
 }
 
@@ -99,6 +95,7 @@ type record struct {
 
 type handshakeMessage interface {
        marshal() []byte
+       unmarshal([]byte) bool
 }
 
 type encryptor interface {
@@ -108,34 +105,16 @@ type encryptor interface {
 
 // mutualVersion returns the protocol version to use given the advertised
 // version of the peer.
-func mutualVersion(theirMajor, theirMinor uint8) (major, minor uint8, ok bool) {
-       // We don't deal with peers < TLS 1.0 (aka version 3.1).
-       if theirMajor < 3 || theirMajor == 3 && theirMinor < 1 {
-               return 0, 0, false
+func mutualVersion(vers uint16) (uint16, bool) {
+       if vers < minVersion {
+               return 0, false
        }
-       major = 3
-       minor = 2
-       if theirMinor < minor {
-               minor = theirMinor
+       if vers > maxVersion {
+               vers = maxVersion
        }
-       ok = true
-       return
+       return vers, true
 }
 
-// A nop implements the NULL encryption and MAC algorithms.
-type nop struct{}
-
-func (nop) XORKeyStream(buf []byte) {}
-
-func (nop) Write(buf []byte) (int, os.Error) { return len(buf), nil }
-
-func (nop) Sum() []byte { return nil }
-
-func (nop) Reset() {}
-
-func (nop) Size() int { return 0 }
-
-
 // The defaultConfig is used in place of a nil *Config in the TLS server and client.
 var varDefaultConfig *Config
 
diff --git a/src/pkg/crypto/tls/conn.go b/src/pkg/crypto/tls/conn.go
new file mode 100644 (file)
index 0000000..d0e8464
--- /dev/null
@@ -0,0 +1,635 @@
+// TLS low level connection and record layer
+
+package tls
+
+import (
+       "bytes"
+       "crypto/subtle"
+       "hash"
+       "io"
+       "net"
+       "os"
+       "sync"
+)
+
+// A Conn represents a secured connection.
+// It implements the net.Conn interface.
+type Conn struct {
+       // constant
+       conn     net.Conn
+       isClient bool
+
+       // constant after handshake; protected by handshakeMutex
+       handshakeMutex    sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex
+       vers              uint16     // TLS version
+       haveVers          bool       // version has been negotiated
+       config            *Config    // configuration passed to constructor
+       handshakeComplete bool
+       cipherSuite       uint16
+
+       clientProtocol string
+
+       // first permanent error
+       errMutex sync.Mutex
+       err      os.Error
+
+       // input/output
+       in, out  halfConn     // in.Mutex < out.Mutex
+       rawInput *block       // raw input, right off the wire
+       input    *block       // application data waiting to be read
+       hand     bytes.Buffer // handshake data waiting to be read
+
+       tmp [16]byte
+}
+
+func (c *Conn) setError(err os.Error) os.Error {
+       c.errMutex.Lock()
+       defer c.errMutex.Unlock()
+
+       if c.err == nil {
+               c.err = err
+       }
+       return err
+}
+
+func (c *Conn) error() os.Error {
+       c.errMutex.Lock()
+       defer c.errMutex.Unlock()
+
+       return c.err
+}
+
+// Access to net.Conn methods.
+// Cannot just embed net.Conn because that would
+// export the struct field too.
+
+// LocalAddr returns the local network address.
+func (c *Conn) LocalAddr() net.Addr {
+       return c.conn.LocalAddr()
+}
+
+// RemoteAddr returns the remote network address.
+func (c *Conn) RemoteAddr() net.Addr {
+       return c.conn.RemoteAddr()
+}
+
+// SetTimeout sets the read deadline associated with the connection.
+// There is no write deadline.
+func (c *Conn) SetTimeout(nsec int64) os.Error {
+       return c.conn.SetTimeout(nsec)
+}
+
+// SetReadTimeout sets the time (in nanoseconds) that
+// Read will wait for data before returning os.EAGAIN.
+// Setting nsec == 0 (the default) disables the deadline.
+func (c *Conn) SetReadTimeout(nsec int64) os.Error {
+       return c.conn.SetReadTimeout(nsec)
+}
+
+// SetWriteTimeout exists to satisfy the net.Conn interface
+// but is not implemented by TLS.  It always returns an error.
+func (c *Conn) SetWriteTimeout(nsec int64) os.Error {
+       return os.NewError("TLS does not support SetWriteTimeout")
+}
+
+// A halfConn represents one direction of the record layer
+// connection, either sending or receiving.
+type halfConn struct {
+       sync.Mutex
+       crypt encryptor // encryption state
+       mac   hash.Hash // MAC algorithm
+       seq   [8]byte   // 64-bit sequence number
+       bfree *block    // list of free blocks
+
+       nextCrypt encryptor // next encryption state
+       nextMac   hash.Hash // next MAC algorithm
+}
+
+// prepareCipherSpec sets the encryption and MAC states
+// that a subsequent changeCipherSpec will use.
+func (hc *halfConn) prepareCipherSpec(crypt encryptor, mac hash.Hash) {
+       hc.nextCrypt = crypt
+       hc.nextMac = mac
+}
+
+// changeCipherSpec changes the encryption and MAC states
+// to the ones previously passed to prepareCipherSpec.
+func (hc *halfConn) changeCipherSpec() os.Error {
+       if hc.nextCrypt == nil {
+               return alertInternalError
+       }
+       hc.crypt = hc.nextCrypt
+       hc.mac = hc.nextMac
+       hc.nextCrypt = nil
+       hc.nextMac = nil
+       return nil
+}
+
+// incSeq increments the sequence number.
+func (hc *halfConn) incSeq() {
+       for i := 7; i >= 0; i-- {
+               hc.seq[i]++
+               if hc.seq[i] != 0 {
+                       return
+               }
+       }
+
+       // Not allowed to let sequence number wrap.
+       // Instead, must renegotiate before it does.
+       // Not likely enough to bother.
+       panic("TLS: sequence number wraparound")
+}
+
+// resetSeq resets the sequence number to zero.
+func (hc *halfConn) resetSeq() {
+       for i := range hc.seq {
+               hc.seq[i] = 0
+       }
+}
+
+// decrypt checks and strips the mac and decrypts the data in b.
+func (hc *halfConn) decrypt(b *block) (bool, alert) {
+       // pull out payload
+       payload := b.data[recordHeaderLen:]
+
+       // decrypt
+       if hc.crypt != nil {
+               hc.crypt.XORKeyStream(payload)
+       }
+
+       // check, strip mac
+       if hc.mac != nil {
+               if len(payload) < hc.mac.Size() {
+                       return false, alertBadRecordMAC
+               }
+
+               // strip mac off payload, b.data
+               n := len(payload) - hc.mac.Size()
+               b.data[3] = byte(n >> 8)
+               b.data[4] = byte(n)
+               b.data = b.data[0 : recordHeaderLen+n]
+               remoteMAC := payload[n:]
+
+               hc.mac.Reset()
+               hc.mac.Write(&hc.seq)
+               hc.incSeq()
+               hc.mac.Write(b.data)
+
+               if subtle.ConstantTimeCompare(hc.mac.Sum(), remoteMAC) != 1 {
+                       return false, alertBadRecordMAC
+               }
+       }
+
+       return true, 0
+}
+
+// encrypt encrypts and macs the data in b.
+func (hc *halfConn) encrypt(b *block) (bool, alert) {
+       // mac
+       if hc.mac != nil {
+               hc.mac.Reset()
+               hc.mac.Write(&hc.seq)
+               hc.incSeq()
+               hc.mac.Write(b.data)
+               mac := hc.mac.Sum()
+               n := len(b.data)
+               b.resize(n + len(mac))
+               copy(b.data[n:], mac)
+
+               // update length to include mac
+               n = len(b.data) - recordHeaderLen
+               b.data[3] = byte(n >> 8)
+               b.data[4] = byte(n)
+       }
+
+       // encrypt
+       if hc.crypt != nil {
+               hc.crypt.XORKeyStream(b.data[recordHeaderLen:])
+       }
+
+       return true, 0
+}
+
+// A block is a simple data buffer.
+type block struct {
+       data []byte
+       off  int // index for Read
+       link *block
+}
+
+// resize resizes block to be n bytes, growing if necessary.
+func (b *block) resize(n int) {
+       if n > cap(b.data) {
+               b.reserve(n)
+       }
+       b.data = b.data[0:n]
+}
+
+// reserve makes sure that block contains a capacity of at least n bytes.
+func (b *block) reserve(n int) {
+       if cap(b.data) >= n {
+               return
+       }
+       m := cap(b.data)
+       if m == 0 {
+               m = 1024
+       }
+       for m < n {
+               m *= 2
+       }
+       data := make([]byte, len(b.data), m)
+       copy(data, b.data)
+       b.data = data
+}
+
+// readFromUntil reads from r into b until b contains at least n bytes
+// or else returns an error.
+func (b *block) readFromUntil(r io.Reader, n int) os.Error {
+       // quick case
+       if len(b.data) >= n {
+               return nil
+       }
+
+       // read until have enough.
+       b.reserve(n)
+       for {
+               m, err := r.Read(b.data[len(b.data):cap(b.data)])
+               b.data = b.data[0 : len(b.data)+m]
+               if len(b.data) >= n {
+                       break
+               }
+               if err != nil {
+                       return err
+               }
+       }
+       return nil
+}
+
+func (b *block) Read(p []byte) (n int, err os.Error) {
+       n = copy(p, b.data[b.off:])
+       b.off += n
+       return
+}
+
+// newBlock allocates a new block, from hc's free list if possible.
+func (hc *halfConn) newBlock() *block {
+       b := hc.bfree
+       if b == nil {
+               return new(block)
+       }
+       hc.bfree = b.link
+       b.link = nil
+       b.resize(0)
+       return b
+}
+
+// freeBlock returns a block to hc's free list.
+// The protocol is such that each side only has a block or two on
+// its free list at a time, so there's no need to worry about
+// trimming the list, etc.
+func (hc *halfConn) freeBlock(b *block) {
+       b.link = hc.bfree
+       hc.bfree = b
+}
+
+// splitBlock splits a block after the first n bytes,
+// returning a block with those n bytes and a
+// block with the remaindec.  the latter may be nil.
+func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
+       if len(b.data) <= n {
+               return b, nil
+       }
+       bb := hc.newBlock()
+       bb.resize(len(b.data) - n)
+       copy(bb.data, b.data[n:])
+       b.data = b.data[0:n]
+       return b, bb
+}
+
+// readRecord reads the next TLS record from the connection
+// and updates the record layer state.
+// c.in.Mutex <= L; c.input == nil.
+func (c *Conn) readRecord(want recordType) os.Error {
+       // Caller must be in sync with connection:
+       // handshake data if handshake not yet completed,
+       // else application data.  (We don't support renegotiation.)
+       switch want {
+       default:
+               return c.sendAlert(alertInternalError)
+       case recordTypeHandshake, recordTypeChangeCipherSpec:
+               if c.handshakeComplete {
+                       return c.sendAlert(alertInternalError)
+               }
+       case recordTypeApplicationData:
+               if !c.handshakeComplete {
+                       return c.sendAlert(alertInternalError)
+               }
+       }
+
+Again:
+       if c.rawInput == nil {
+               c.rawInput = c.in.newBlock()
+       }
+       b := c.rawInput
+
+       // Read header, payload.
+       if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
+               // RFC suggests that EOF without an alertCloseNotify is
+               // an error, but popular web sites seem to do this,
+               // so we can't make it an error.
+               // if err == os.EOF {
+               //      err = io.ErrUnexpectedEOF
+               // }
+               if e, ok := err.(net.Error); !ok || !e.Temporary() {
+                       c.setError(err)
+               }
+               return err
+       }
+       typ := recordType(b.data[0])
+       vers := uint16(b.data[1])<<8 | uint16(b.data[2])
+       n := int(b.data[3])<<8 | int(b.data[4])
+       if c.haveVers && vers != c.vers {
+               return c.sendAlert(alertProtocolVersion)
+       }
+       if n > maxCiphertext {
+               return c.sendAlert(alertRecordOverflow)
+       }
+       if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
+               if err == os.EOF {
+                       err = io.ErrUnexpectedEOF
+               }
+               if e, ok := err.(net.Error); !ok || !e.Temporary() {
+                       c.setError(err)
+               }
+               return err
+       }
+
+       // Process message.
+       b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
+       b.off = recordHeaderLen
+       if ok, err := c.in.decrypt(b); !ok {
+               return c.sendAlert(err)
+       }
+       data := b.data[b.off:]
+       if len(data) > maxPlaintext {
+               c.sendAlert(alertRecordOverflow)
+               c.in.freeBlock(b)
+               return c.error()
+       }
+
+       switch typ {
+       default:
+               c.sendAlert(alertUnexpectedMessage)
+
+       case recordTypeAlert:
+               if len(data) != 2 {
+                       c.sendAlert(alertUnexpectedMessage)
+                       break
+               }
+               if alert(data[1]) == alertCloseNotify {
+                       c.setError(os.EOF)
+                       break
+               }
+               switch data[0] {
+               case alertLevelWarning:
+                       // drop on the floor
+                       c.in.freeBlock(b)
+                       goto Again
+               case alertLevelError:
+                       c.setError(&net.OpError{Op: "remote error", Error: alert(data[1])})
+               default:
+                       c.sendAlert(alertUnexpectedMessage)
+               }
+
+       case recordTypeChangeCipherSpec:
+               if typ != want || len(data) != 1 || data[0] != 1 {
+                       c.sendAlert(alertUnexpectedMessage)
+                       break
+               }
+               err := c.in.changeCipherSpec()
+               if err != nil {
+                       c.sendAlert(err.(alert))
+               }
+
+       case recordTypeApplicationData:
+               if typ != want {
+                       c.sendAlert(alertUnexpectedMessage)
+                       break
+               }
+               c.input = b
+               b = nil
+
+       case recordTypeHandshake:
+               // TODO(rsc): Should at least pick off connection close.
+               if typ != want {
+                       return c.sendAlert(alertNoRenegotiation)
+               }
+               c.hand.Write(data)
+       }
+
+       if b != nil {
+               c.in.freeBlock(b)
+       }
+       return c.error()
+}
+
+// sendAlert sends a TLS alert message.
+// c.out.Mutex <= L.
+func (c *Conn) sendAlertLocked(err alert) os.Error {
+       c.tmp[0] = alertLevelError
+       if err == alertNoRenegotiation {
+               c.tmp[0] = alertLevelWarning
+       }
+       c.tmp[1] = byte(err)
+       c.writeRecord(recordTypeAlert, c.tmp[0:2])
+       return c.setError(&net.OpError{Op: "local error", Error: err})
+}
+
+// sendAlert sends a TLS alert message.
+// L < c.out.Mutex.
+func (c *Conn) sendAlert(err alert) os.Error {
+       c.out.Lock()
+       defer c.out.Unlock()
+       return c.sendAlertLocked(err)
+}
+
+// writeRecord writes a TLS record with the given type and payload
+// to the connection and updates the record layer state.
+// c.out.Mutex <= L.
+func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err os.Error) {
+       b := c.out.newBlock()
+       for len(data) > 0 {
+               m := len(data)
+               if m > maxPlaintext {
+                       m = maxPlaintext
+               }
+               b.resize(recordHeaderLen + m)
+               b.data[0] = byte(typ)
+               vers := c.vers
+               if vers == 0 {
+                       vers = maxVersion
+               }
+               b.data[1] = byte(vers >> 8)
+               b.data[2] = byte(vers)
+               b.data[3] = byte(m >> 8)
+               b.data[4] = byte(m)
+               copy(b.data[recordHeaderLen:], data)
+               c.out.encrypt(b)
+               _, err = c.conn.Write(b.data)
+               if err != nil {
+                       break
+               }
+               n += m
+               data = data[m:]
+       }
+       c.out.freeBlock(b)
+
+       if typ == recordTypeChangeCipherSpec {
+               err = c.out.changeCipherSpec()
+               if err != nil {
+                       // Cannot call sendAlert directly,
+                       // because we already hold c.out.Mutex.
+                       c.tmp[0] = alertLevelError
+                       c.tmp[1] = byte(err.(alert))
+                       c.writeRecord(recordTypeAlert, c.tmp[0:2])
+                       c.err = &net.OpError{Op: "local error", Error: err}
+                       return n, c.err
+               }
+       }
+       return
+}
+
+// readHandshake reads the next handshake message from
+// the record layer.
+// c.in.Mutex < L; c.out.Mutex < L.
+func (c *Conn) readHandshake() (interface{}, os.Error) {
+       for c.hand.Len() < 4 {
+               if c.err != nil {
+                       return nil, c.err
+               }
+               c.readRecord(recordTypeHandshake)
+       }
+
+       data := c.hand.Bytes()
+       n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
+       if n > maxHandshake {
+               c.sendAlert(alertInternalError)
+               return nil, c.err
+       }
+       for c.hand.Len() < 4+n {
+               if c.err != nil {
+                       return nil, c.err
+               }
+               c.readRecord(recordTypeHandshake)
+       }
+       data = c.hand.Next(4 + n)
+       var m handshakeMessage
+       switch data[0] {
+       case typeClientHello:
+               m = new(clientHelloMsg)
+       case typeServerHello:
+               m = new(serverHelloMsg)
+       case typeCertificate:
+               m = new(certificateMsg)
+       case typeServerHelloDone:
+               m = new(serverHelloDoneMsg)
+       case typeClientKeyExchange:
+               m = new(clientKeyExchangeMsg)
+       case typeNextProtocol:
+               m = new(nextProtoMsg)
+       case typeFinished:
+               m = new(finishedMsg)
+       default:
+               c.sendAlert(alertUnexpectedMessage)
+               return nil, alertUnexpectedMessage
+       }
+
+       // The handshake message unmarshallers
+       // expect to be able to keep references to data,
+       // so pass in a fresh copy that won't be overwritten.
+       data = bytes.Add(nil, data)
+
+       if !m.unmarshal(data) {
+               c.sendAlert(alertUnexpectedMessage)
+               return nil, alertUnexpectedMessage
+       }
+       return m, nil
+}
+
+// Write writes data to the connection.
+func (c *Conn) Write(b []byte) (n int, err os.Error) {
+       if err = c.Handshake(); err != nil {
+               return
+       }
+
+       c.out.Lock()
+       defer c.out.Unlock()
+
+       if !c.handshakeComplete {
+               return 0, alertInternalError
+       }
+       if c.err != nil {
+               return 0, c.err
+       }
+       return c.writeRecord(recordTypeApplicationData, b)
+}
+
+// Read can be made to time out and return err == os.EAGAIN
+// after a fixed time limit; see SetTimeout and SetReadTimeout.
+func (c *Conn) Read(b []byte) (n int, err os.Error) {
+       if err = c.Handshake(); err != nil {
+               return
+       }
+
+       c.in.Lock()
+       defer c.in.Unlock()
+
+       for c.input == nil && c.err == nil {
+               c.readRecord(recordTypeApplicationData)
+       }
+       if c.err != nil {
+               return 0, c.err
+       }
+       n, err = c.input.Read(b)
+       if c.input.off >= len(c.input.data) {
+               c.in.freeBlock(c.input)
+               c.input = nil
+       }
+       return n, nil
+}
+
+// Close closes the connection.
+func (c *Conn) Close() os.Error {
+       if err := c.Handshake(); err != nil {
+               return err
+       }
+       return c.sendAlert(alertCloseNotify)
+}
+
+// Handshake runs the client or server handshake
+// protocol if it has not yet been run.
+// Most uses of this packge need not call Handshake
+// explicitly: the first Read or Write will call it automatically.
+func (c *Conn) Handshake() os.Error {
+       c.handshakeMutex.Lock()
+       defer c.handshakeMutex.Unlock()
+       if err := c.error(); err != nil {
+               return err
+       }
+       if c.handshakeComplete {
+               return nil
+       }
+       if c.isClient {
+               return c.clientHandshake()
+       }
+       return c.serverHandshake()
+}
+
+// If c is a TLS server, ClientConnection returns the protocol
+// requested by the client during the TLS handshake.
+// Handshake must have been called already.
+func (c *Conn) ClientConnection() string {
+       c.handshakeMutex.Lock()
+       defer c.handshakeMutex.Unlock()
+       return c.clientProtocol
+}
index 8cc6b7409c4fa2559579c7ed1086774274944e99..dd3009802db447074d6c65a24e492dd721d01fc8 100644 (file)
@@ -12,74 +12,63 @@ import (
        "crypto/subtle"
        "crypto/x509"
        "io"
+       "os"
 )
 
-// A serverHandshake performs the server side of the TLS 1.1 handshake protocol.
-type clientHandshake struct {
-       writeChan   chan<- interface{}
-       controlChan chan<- interface{}
-       msgChan     <-chan interface{}
-       config      *Config
-}
-
-func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
-       h.writeChan = writeChan
-       h.controlChan = controlChan
-       h.msgChan = msgChan
-       h.config = config
-
-       defer close(writeChan)
-       defer close(controlChan)
-
+func (c *Conn) clientHandshake() os.Error {
        finishedHash := newFinishedHash()
 
+       config := defaultConfig()
+
        hello := &clientHelloMsg{
-               major:              defaultMajor,
-               minor:              defaultMinor,
+               vers:               maxVersion,
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
                random:             make([]byte, 32),
        }
 
-       currentTime := uint32(config.Time())
-       hello.random[0] = byte(currentTime >> 24)
-       hello.random[1] = byte(currentTime >> 16)
-       hello.random[2] = byte(currentTime >> 8)
-       hello.random[3] = byte(currentTime)
+       t := uint32(config.Time())
+       hello.random[0] = byte(t >> 24)
+       hello.random[1] = byte(t >> 16)
+       hello.random[2] = byte(t >> 8)
+       hello.random[3] = byte(t)
        _, err := io.ReadFull(config.Rand, hello.random[4:])
        if err != nil {
-               h.error(alertInternalError)
-               return
+               return c.sendAlert(alertInternalError)
        }
 
        finishedHash.Write(hello.marshal())
-       writeChan <- writerSetVersion{defaultMajor, defaultMinor}
-       writeChan <- hello
+       c.writeRecord(recordTypeHandshake, hello.marshal())
 
-       serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg)
+       msg, err := c.readHandshake()
+       if err != nil {
+               return err
+       }
+       serverHello, ok := msg.(*serverHelloMsg)
        if !ok {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
        finishedHash.Write(serverHello.marshal())
-       major, minor, ok := mutualVersion(serverHello.major, serverHello.minor)
+
+       vers, ok := mutualVersion(serverHello.vers)
        if !ok {
-               h.error(alertProtocolVersion)
-               return
+               c.sendAlert(alertProtocolVersion)
        }
-
-       writeChan <- writerSetVersion{major, minor}
+       c.vers = vers
+       c.haveVers = true
 
        if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA ||
                serverHello.compressionMethod != compressionNone {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
 
-       certMsg, ok := h.readHandshakeMsg().(*certificateMsg)
+       msg, err = c.readHandshake()
+       if err != nil {
+               return err
+       }
+       certMsg, ok := msg.(*certificateMsg)
        if !ok || len(certMsg.certificates) == 0 {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
        finishedHash.Write(certMsg.marshal())
 
@@ -87,139 +76,98 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
        for i, asn1Data := range certMsg.certificates {
                cert, err := x509.ParseCertificate(asn1Data)
                if err != nil {
-                       h.error(alertBadCertificate)
-                       return
+                       return c.sendAlert(alertBadCertificate)
                }
                certs[i] = cert
        }
 
        // TODO(agl): do better validation of certs: max path length, name restrictions etc.
        for i := 1; i < len(certs); i++ {
-               if certs[i-1].CheckSignatureFrom(certs[i]) != nil {
-                       h.error(alertBadCertificate)
-                       return
+               if err := certs[i-1].CheckSignatureFrom(certs[i]); err != nil {
+                       return c.sendAlert(alertBadCertificate)
                }
        }
 
-       if config.RootCAs != nil {
+       // TODO(rsc): Find certificates for OS X 10.6.
+       if false && config.RootCAs != nil {
                root := config.RootCAs.FindParent(certs[len(certs)-1])
                if root == nil {
-                       h.error(alertBadCertificate)
-                       return
+                       return c.sendAlert(alertBadCertificate)
                }
                if certs[len(certs)-1].CheckSignatureFrom(root) != nil {
-                       h.error(alertBadCertificate)
-                       return
+                       return c.sendAlert(alertBadCertificate)
                }
        }
 
        pub, ok := certs[0].PublicKey.(*rsa.PublicKey)
        if !ok {
-               h.error(alertUnsupportedCertificate)
-               return
+               return c.sendAlert(alertUnsupportedCertificate)
        }
 
-       shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg)
+       msg, err = c.readHandshake()
+       if err != nil {
+               return err
+       }
+       shd, ok := msg.(*serverHelloDoneMsg)
        if !ok {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
        finishedHash.Write(shd.marshal())
 
        ckx := new(clientKeyExchangeMsg)
        preMasterSecret := make([]byte, 48)
-       // Note that the version number in the preMasterSecret must be the
-       // version offered in the ClientHello.
-       preMasterSecret[0] = defaultMajor
-       preMasterSecret[1] = defaultMinor
+       preMasterSecret[0] = byte(hello.vers >> 8)
+       preMasterSecret[1] = byte(hello.vers)
        _, err = io.ReadFull(config.Rand, preMasterSecret[2:])
        if err != nil {
-               h.error(alertInternalError)
-               return
+               return c.sendAlert(alertInternalError)
        }
 
        ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret)
        if err != nil {
-               h.error(alertInternalError)
-               return
+               return c.sendAlert(alertInternalError)
        }
 
        finishedHash.Write(ckx.marshal())
-       writeChan <- ckx
+       c.writeRecord(recordTypeHandshake, ckx.marshal())
 
        suite := cipherSuites[0]
        masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
                keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength)
 
        cipher, _ := rc4.NewCipher(clientKey)
-       writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
+
+       c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
+       c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 
        finished := new(finishedMsg)
        finished.verifyData = finishedHash.clientSum(masterSecret)
        finishedHash.Write(finished.marshal())
-       writeChan <- finished
-
-       // TODO(agl): this is cut-through mode which should probably be an option.
-       writeChan <- writerEnableApplicationData{}
-
-       _, ok = h.readHandshakeMsg().(changeCipherSpec)
-       if !ok {
-               h.error(alertUnexpectedMessage)
-               return
-       }
+       c.writeRecord(recordTypeHandshake, finished.marshal())
 
        cipher2, _ := rc4.NewCipher(serverKey)
-       controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}
+       c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
+       c.readRecord(recordTypeChangeCipherSpec)
+       if c.err != nil {
+               return c.err
+       }
 
-       serverFinished, ok := h.readHandshakeMsg().(*finishedMsg)
+       msg, err = c.readHandshake()
+       if err != nil {
+               return err
+       }
+       serverFinished, ok := msg.(*finishedMsg)
        if !ok {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
 
        verify := finishedHash.serverSum(masterSecret)
        if len(verify) != len(serverFinished.verifyData) ||
                subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
-               h.error(alertHandshakeFailure)
-               return
+               return c.sendAlert(alertHandshakeFailure)
        }
 
-       controlChan <- ConnectionState{HandshakeComplete: true, CipherSuite: "TLS_RSA_WITH_RC4_128_SHA"}
-
-       // This should just block forever.
-       _ = h.readHandshakeMsg()
-       h.error(alertUnexpectedMessage)
-       return
-}
-
-func (h *clientHandshake) readHandshakeMsg() interface{} {
-       v := <-h.msgChan
-       if closed(h.msgChan) {
-               // If the channel closed then the processor received an error
-               // from the peer and we don't want to echo it back to them.
-               h.msgChan = nil
-               return 0
-       }
-       if _, ok := v.(alert); ok {
-               // We got an alert from the processor. We forward to the writer
-               // and shutdown.
-               h.writeChan <- v
-               h.msgChan = nil
-               return 0
-       }
-       return v
-}
-
-func (h *clientHandshake) error(e alertType) {
-       if h.msgChan != nil {
-               // If we didn't get an error from the processor, then we need
-               // to tell it about the error.
-               go func() {
-                       for _ = range h.msgChan {
-                       }
-               }()
-               h.controlChan <- ConnectionState{Error: e}
-               close(h.controlChan)
-               h.writeChan <- alert{alertLevelError, e}
-       }
+       c.handshakeComplete = true
+       c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA
+       return nil
 }
index 966314857f35b4be8f3ae01c2b4dc6e0d44eb093..f0a48c8630ab7ceb696fa7d07dd69f54f4087a94 100644 (file)
@@ -6,7 +6,7 @@ package tls
 
 type clientHelloMsg struct {
        raw                []byte
-       major, minor       uint8
+       vers               uint16
        random             []byte
        sessionId          []byte
        cipherSuites       []uint16
@@ -40,8 +40,8 @@ func (m *clientHelloMsg) marshal() []byte {
        x[1] = uint8(length >> 16)
        x[2] = uint8(length >> 8)
        x[3] = uint8(length)
-       x[4] = m.major
-       x[5] = m.minor
+       x[4] = uint8(m.vers >> 8)
+       x[5] = uint8(m.vers)
        copy(x[6:38], m.random)
        x[38] = uint8(len(m.sessionId))
        copy(x[39:39+len(m.sessionId)], m.sessionId)
@@ -108,12 +108,11 @@ func (m *clientHelloMsg) marshal() []byte {
 }
 
 func (m *clientHelloMsg) unmarshal(data []byte) bool {
-       if len(data) < 43 {
+       if len(data) < 42 {
                return false
        }
        m.raw = data
-       m.major = data[4]
-       m.minor = data[5]
+       m.vers = uint16(data[4])<<8 | uint16(data[5])
        m.random = data[6:38]
        sessionIdLen := int(data[38])
        if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
@@ -136,7 +135,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
                m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
        }
        data = data[2+cipherSuiteLen:]
-       if len(data) < 2 {
+       if len(data) < 1 {
                return false
        }
        compressionMethodsLen := int(data[0])
@@ -212,7 +211,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
 
 type serverHelloMsg struct {
        raw               []byte
-       major, minor      uint8
+       vers              uint16
        random            []byte
        sessionId         []byte
        cipherSuite       uint16
@@ -249,8 +248,8 @@ func (m *serverHelloMsg) marshal() []byte {
        x[1] = uint8(length >> 16)
        x[2] = uint8(length >> 8)
        x[3] = uint8(length)
-       x[4] = m.major
-       x[5] = m.minor
+       x[4] = uint8(m.vers >> 8)
+       x[5] = uint8(m.vers)
        copy(x[6:38], m.random)
        x[38] = uint8(len(m.sessionId))
        copy(x[39:39+len(m.sessionId)], m.sessionId)
@@ -306,8 +305,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
                return false
        }
        m.raw = data
-       m.major = data[4]
-       m.minor = data[5]
+       m.vers = uint16(data[4])<<8 | uint16(data[5])
        m.random = data[6:38]
        sessionIdLen := int(data[38])
        if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
index 3c5902e2458a65594c4207584cdbe5081a05acf9..2e422cc6a007cc70852e7a867211d780c5723e34 100644 (file)
@@ -97,8 +97,7 @@ func randomString(n int, rand *rand.Rand) string {
 
 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &clientHelloMsg{}
-       m.major = uint8(rand.Intn(256))
-       m.minor = uint8(rand.Intn(256))
+       m.vers = uint16(rand.Intn(65536))
        m.random = randomBytes(32, rand)
        m.sessionId = randomBytes(rand.Intn(32), rand)
        m.cipherSuites = make([]uint16, rand.Intn(63)+1)
@@ -118,8 +117,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
 
 func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &serverHelloMsg{}
-       m.major = uint8(rand.Intn(256))
-       m.minor = uint8(rand.Intn(256))
+       m.vers = uint16(rand.Intn(65536))
        m.random = randomBytes(32, rand)
        m.sessionId = randomBytes(rand.Intn(32), rand)
        m.cipherSuite = uint16(rand.Int31())
index 50854d1543bb0602067e8c787a769db1c8501f0b..ebf956763ac637318b18aa1754c69006e93ec555 100644 (file)
@@ -19,6 +19,7 @@ import (
        "crypto/sha1"
        "crypto/subtle"
        "io"
+       "os"
 )
 
 type cipherSuite struct {
@@ -31,33 +32,22 @@ var cipherSuites = []cipherSuite{
        cipherSuite{TLS_RSA_WITH_RC4_128_SHA, 20, 16},
 }
 
-// A serverHandshake performs the server side of the TLS 1.1 handshake protocol.
-type serverHandshake struct {
-       writeChan   chan<- interface{}
-       controlChan chan<- interface{}
-       msgChan     <-chan interface{}
-       config      *Config
-}
-
-func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
-       h.writeChan = writeChan
-       h.controlChan = controlChan
-       h.msgChan = msgChan
-       h.config = config
-
-       defer close(writeChan)
-       defer close(controlChan)
-
-       clientHello, ok := h.readHandshakeMsg().(*clientHelloMsg)
+func (c *Conn) serverHandshake() os.Error {
+       config := c.config
+       msg, err := c.readHandshake()
+       if err != nil {
+               return err
+       }
+       clientHello, ok := msg.(*clientHelloMsg)
        if !ok {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
-       major, minor, ok := mutualVersion(clientHello.major, clientHello.minor)
+       vers, ok := mutualVersion(clientHello.vers)
        if !ok {
-               h.error(alertProtocolVersion)
-               return
+               return c.sendAlert(alertProtocolVersion)
        }
+       c.vers = vers
+       c.haveVers = true
 
        finishedHash := newFinishedHash()
        finishedHash.Write(clientHello.marshal())
@@ -89,23 +79,20 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
        }
 
        if suite == nil || !foundCompression {
-               h.error(alertHandshakeFailure)
-               return
+               return c.sendAlert(alertHandshakeFailure)
        }
 
-       hello.major = major
-       hello.minor = minor
+       hello.vers = vers
        hello.cipherSuite = suite.id
-       currentTime := uint32(config.Time())
+       t := uint32(config.Time())
        hello.random = make([]byte, 32)
-       hello.random[0] = byte(currentTime >> 24)
-       hello.random[1] = byte(currentTime >> 16)
-       hello.random[2] = byte(currentTime >> 8)
-       hello.random[3] = byte(currentTime)
-       _, err := io.ReadFull(config.Rand, hello.random[4:])
+       hello.random[0] = byte(t >> 24)
+       hello.random[1] = byte(t >> 16)
+       hello.random[2] = byte(t >> 8)
+       hello.random[3] = byte(t)
+       _, err = io.ReadFull(config.Rand, hello.random[4:])
        if err != nil {
-               h.error(alertInternalError)
-               return
+               return c.sendAlert(alertInternalError)
        }
        hello.compressionMethod = compressionNone
        if clientHello.nextProtoNeg {
@@ -114,41 +101,40 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
        }
 
        finishedHash.Write(hello.marshal())
-       writeChan <- writerSetVersion{major, minor}
-       writeChan <- hello
+       c.writeRecord(recordTypeHandshake, hello.marshal())
 
        if len(config.Certificates) == 0 {
-               h.error(alertInternalError)
-               return
+               return c.sendAlert(alertInternalError)
        }
 
        certMsg := new(certificateMsg)
        certMsg.certificates = config.Certificates[0].Certificate
        finishedHash.Write(certMsg.marshal())
-       writeChan <- certMsg
+       c.writeRecord(recordTypeHandshake, certMsg.marshal())
 
        helloDone := new(serverHelloDoneMsg)
        finishedHash.Write(helloDone.marshal())
-       writeChan <- helloDone
+       c.writeRecord(recordTypeHandshake, helloDone.marshal())
 
-       ckx, ok := h.readHandshakeMsg().(*clientKeyExchangeMsg)
+       msg, err = c.readHandshake()
+       if err != nil {
+               return err
+       }
+       ckx, ok := msg.(*clientKeyExchangeMsg)
        if !ok {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
        finishedHash.Write(ckx.marshal())
 
        preMasterSecret := make([]byte, 48)
        _, err = io.ReadFull(config.Rand, preMasterSecret[2:])
        if err != nil {
-               h.error(alertInternalError)
-               return
+               return c.sendAlert(alertInternalError)
        }
 
        err = rsa.DecryptPKCS1v15SessionKey(config.Rand, config.Certificates[0].PrivateKey, ckx.ciphertext, preMasterSecret)
        if err != nil {
-               h.error(alertHandshakeFailure)
-               return
+               return c.sendAlert(alertHandshakeFailure)
        }
        // We don't check the version number in the premaster secret. For one,
        // by checking it, we would leak information about the validity of the
@@ -160,91 +146,53 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
        masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
                keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength)
 
-       _, ok = h.readHandshakeMsg().(changeCipherSpec)
-       if !ok {
-               h.error(alertUnexpectedMessage)
-               return
-       }
-
        cipher, _ := rc4.NewCipher(clientKey)
-       controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
+       c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
+       c.readRecord(recordTypeChangeCipherSpec)
+       if err := c.error(); err != nil {
+               return err
+       }
 
-       clientProtocol := ""
        if hello.nextProtoNeg {
-               nextProto, ok := h.readHandshakeMsg().(*nextProtoMsg)
+               msg, err = c.readHandshake()
+               if err != nil {
+                       return err
+               }
+               nextProto, ok := msg.(*nextProtoMsg)
                if !ok {
-                       h.error(alertUnexpectedMessage)
-                       return
+                       return c.sendAlert(alertUnexpectedMessage)
                }
                finishedHash.Write(nextProto.marshal())
-               clientProtocol = nextProto.proto
+               c.clientProtocol = nextProto.proto
        }
 
-       clientFinished, ok := h.readHandshakeMsg().(*finishedMsg)
+       msg, err = c.readHandshake()
+       if err != nil {
+               return err
+       }
+       clientFinished, ok := msg.(*finishedMsg)
        if !ok {
-               h.error(alertUnexpectedMessage)
-               return
+               return c.sendAlert(alertUnexpectedMessage)
        }
 
        verify := finishedHash.clientSum(masterSecret)
        if len(verify) != len(clientFinished.verifyData) ||
                subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
-               h.error(alertHandshakeFailure)
-               return
+               return c.sendAlert(alertHandshakeFailure)
        }
 
-       controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, clientProtocol}
-
        finishedHash.Write(clientFinished.marshal())
 
        cipher2, _ := rc4.NewCipher(serverKey)
-       writeChan <- writerChangeCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)}
+       c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
+       c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 
        finished := new(finishedMsg)
        finished.verifyData = finishedHash.serverSum(masterSecret)
-       writeChan <- finished
-
-       writeChan <- writerEnableApplicationData{}
-
-       for {
-               _, ok := h.readHandshakeMsg().(*clientHelloMsg)
-               if !ok {
-                       h.error(alertUnexpectedMessage)
-                       return
-               }
-               // We reject all renegotication requests.
-               writeChan <- alert{alertLevelWarning, alertNoRenegotiation}
-       }
-}
+       c.writeRecord(recordTypeHandshake, finished.marshal())
 
-func (h *serverHandshake) readHandshakeMsg() interface{} {
-       v := <-h.msgChan
-       if closed(h.msgChan) {
-               // If the channel closed then the processor received an error
-               // from the peer and we don't want to echo it back to them.
-               h.msgChan = nil
-               return 0
-       }
-       if _, ok := v.(alert); ok {
-               // We got an alert from the processor. We forward to the writer
-               // and shutdown.
-               h.writeChan <- v
-               h.msgChan = nil
-               return 0
-       }
-       return v
-}
+       c.handshakeComplete = true
+       c.cipherSuite = TLS_RSA_WITH_RC4_128_SHA
 
-func (h *serverHandshake) error(e alertType) {
-       if h.msgChan != nil {
-               // If we didn't get an error from the processor, then we need
-               // to tell it about the error.
-               go func() {
-                       for _ = range h.msgChan {
-                       }
-               }()
-               h.controlChan <- ConnectionState{false, "", e, ""}
-               close(h.controlChan)
-               h.writeChan <- alert{alertLevelError, e}
-       }
+       return nil
 }
index a580b14e3c5a8cbf9762e2f08160beac573bbb56..d31dc497e37b50fda460f4d5b6843f524c4eac91 100644 (file)
@@ -5,12 +5,16 @@
 package tls
 
 import (
-       "bytes"
+       //      "bytes"
        "big"
        "crypto/rsa"
+       "encoding/hex"
+       "flag"
+       "io"
+       "net"
        "os"
        "testing"
-       "testing/script"
+       //      "testing/script"
 )
 
 type zeroSource struct{}
@@ -34,29 +38,23 @@ func init() {
        testConfig.Certificates[0].PrivateKey = testPrivateKey
 }
 
-func setupServerHandshake() (writeChan chan interface{}, controlChan chan interface{}, msgChan chan interface{}) {
-       sh := new(serverHandshake)
-       writeChan = make(chan interface{})
-       controlChan = make(chan interface{})
-       msgChan = make(chan interface{})
-
-       go sh.loop(writeChan, controlChan, msgChan, testConfig)
-       return
-}
-
-func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert alertType) {
-       writeChan, controlChan, msgChan := setupServerHandshake()
-       defer close(msgChan)
-
-       send := script.NewEvent("send", nil, script.Send{msgChan, clientHello})
-       recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}})
-       close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan})
-       recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert, ""}})
-       close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan})
-
-       err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
+func testClientHelloFailure(t *testing.T, m handshakeMessage, expected os.Error) {
+       // Create in-memory network connection,
+       // send message to server.  Should return
+       // expected error.
+       c, s := net.Pipe()
+       go func() {
+               cli := Client(c, testConfig)
+               if ch, ok := m.(*clientHelloMsg); ok {
+                       cli.vers = ch.vers
+               }
+               cli.writeRecord(recordTypeHandshake, m.marshal())
+               c.Close()
+       }()
+       err := Server(s, testConfig).Handshake()
+       s.Close()
+       if e, ok := err.(*net.OpError); !ok || e.Error != expected {
+               t.Errorf("Got error: %s; expected: %s", err, expected)
        }
 }
 
@@ -64,134 +62,100 @@ func TestSimpleError(t *testing.T) {
        testClientHelloFailure(t, &serverHelloDoneMsg{}, alertUnexpectedMessage)
 }
 
-var badProtocolVersions = []uint8{0, 0, 0, 5, 1, 0, 1, 5, 2, 0, 2, 5, 3, 0}
+var badProtocolVersions = []uint16{0x0000, 0x0005, 0x0100, 0x0105, 0x0200, 0x0205, 0x0300}
 
 func TestRejectBadProtocolVersion(t *testing.T) {
-       clientHello := new(clientHelloMsg)
-
-       for i := 0; i < len(badProtocolVersions); i += 2 {
-               clientHello.major = badProtocolVersions[i]
-               clientHello.minor = badProtocolVersions[i+1]
-
-               testClientHelloFailure(t, clientHello, alertProtocolVersion)
+       for _, v := range badProtocolVersions {
+               testClientHelloFailure(t, &clientHelloMsg{vers: v}, alertProtocolVersion)
        }
 }
 
 func TestNoSuiteOverlap(t *testing.T) {
-       clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
+       clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
        testClientHelloFailure(t, clientHello, alertHandshakeFailure)
 
 }
 
 func TestNoCompressionOverlap(t *testing.T) {
-       clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
+       clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
        testClientHelloFailure(t, clientHello, alertHandshakeFailure)
 }
 
-func matchServerHello(v interface{}) bool {
-       serverHello, ok := v.(*serverHelloMsg)
-       if !ok {
-               return false
-       }
-       return serverHello.major == 3 &&
-               serverHello.minor == 2 &&
-               serverHello.cipherSuite == TLS_RSA_WITH_RC4_128_SHA &&
-               serverHello.compressionMethod == compressionNone
-}
-
 func TestAlertForwarding(t *testing.T) {
-       writeChan, controlChan, msgChan := setupServerHandshake()
-       defer close(msgChan)
-
-       a := alert{alertLevelError, alertNoRenegotiation}
-       sendAlert := script.NewEvent("send alert", nil, script.Send{msgChan, a})
-       recvAlert := script.NewEvent("recv alert", []*script.Event{sendAlert}, script.Recv{writeChan, a})
-       closeWriter := script.NewEvent("close writer", []*script.Event{recvAlert}, script.Closed{writeChan})
-       closeControl := script.NewEvent("close control", []*script.Event{recvAlert}, script.Closed{controlChan})
-
-       err := script.Perform(0, []*script.Event{sendAlert, recvAlert, closeWriter, closeControl})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
+       c, s := net.Pipe()
+       go func() {
+               Client(c, testConfig).sendAlert(alertUnknownCA)
+               c.Close()
+       }()
+
+       err := Server(s, testConfig).Handshake()
+       s.Close()
+       if e, ok := err.(*net.OpError); !ok || e.Error != os.Error(alertUnknownCA) {
+               t.Errorf("Got error: %s; expected: %s", err, alertUnknownCA)
        }
 }
 
 func TestClose(t *testing.T) {
-       writeChan, controlChan, msgChan := setupServerHandshake()
-
-       close := script.NewEvent("close", nil, script.Close{msgChan})
-       closed1 := script.NewEvent("closed1", []*script.Event{close}, script.Closed{writeChan})
-       closed2 := script.NewEvent("closed2", []*script.Event{close}, script.Closed{controlChan})
-
-       err := script.Perform(0, []*script.Event{close, closed1, closed2})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
+       c, s := net.Pipe()
+       go c.Close()
 
-func matchCertificate(v interface{}) bool {
-       cert, ok := v.(*certificateMsg)
-       if !ok {
-               return false
+       err := Server(s, testConfig).Handshake()
+       s.Close()
+       if err != os.EOF {
+               t.Errorf("Got error: %s; expected: %s", err, os.EOF)
        }
-       return len(cert.certificates) == 1 &&
-               bytes.Compare(cert.certificates[0], testCertificate) == 0
 }
 
-func matchSetCipher(v interface{}) bool {
-       _, ok := v.(writerChangeCipherSpec)
-       return ok
-}
 
-func matchDone(v interface{}) bool {
-       _, ok := v.(*serverHelloDoneMsg)
-       return ok
-}
+func TestHandshakeServer(t *testing.T) {
+       c, s := net.Pipe()
+       srv := Server(s, testConfig)
+       go func() {
+               srv.Write([]byte("hello, world\n"))
+               srv.Close()
+       }()
+
+       defer c.Close()
+       for i, b := range serverScript {
+               if i%2 == 0 {
+                       c.Write(b)
+                       continue
+               }
+               bb := make([]byte, len(b))
+               _, err := io.ReadFull(c, bb)
+               if err != nil {
+                       t.Fatalf("#%d: %s", i, err)
+               }
+       }
 
-func matchFinished(v interface{}) bool {
-       finished, ok := v.(*finishedMsg)
-       if !ok {
-               return false
+       if !srv.haveVers || srv.vers != 0x0302 {
+               t.Errorf("server version incorrect: %v %v", srv.haveVers, srv.vers)
        }
-       return bytes.Compare(finished.verifyData, fromHex("29122ae11453e631487b02ed")) == 0
-}
 
-func matchNewCipherSpec(v interface{}) bool {
-       _, ok := v.(*newCipherSpec)
-       return ok
+       // TODO: check protocol
 }
 
-func TestFullHandshake(t *testing.T) {
-       writeChan, controlChan, msgChan := setupServerHandshake()
-       defer close(msgChan)
-
-       // The values for this test were obtained from running `gnutls-cli --insecure --debug 9`
-       clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}, false, ""}
+var serve = flag.Bool("serve", false, "run a TLS server on :10443")
 
-       sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello})
-       setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}})
-       recvHello := script.NewEvent("recv hello", []*script.Event{setVersion}, script.RecvMatch{writeChan, matchServerHello})
-       recvCert := script.NewEvent("recv cert", []*script.Event{recvHello}, script.RecvMatch{writeChan, matchCertificate})
-       recvDone := script.NewEvent("recv done", []*script.Event{recvCert}, script.RecvMatch{writeChan, matchDone})
-
-       ckx := &clientKeyExchangeMsg{nil, fromHex("872e1fee5f37dd86f3215938ac8de20b302b90074e9fb93097e6b7d1286d0f45abf2daf179deb618bb3c70ed0afee6ee24476ee4649e5a23358143c0f1d9c251")}
-       sendCKX := script.NewEvent("send ckx", []*script.Event{recvDone}, script.Send{msgChan, ckx})
-
-       sendCCS := script.NewEvent("send ccs", []*script.Event{sendCKX}, script.Send{msgChan, changeCipherSpec{}})
-       recvNCS := script.NewEvent("recv done", []*script.Event{sendCCS}, script.RecvMatch{controlChan, matchNewCipherSpec})
-
-       finished := &finishedMsg{nil, fromHex("c8faca5d242f4423325c5b1a")}
-       sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished})
-       recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished})
-       setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher})
-       recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, ""}})
+func TestRunServer(t *testing.T) {
+       if !*serve {
+               return
+       }
 
-       err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished})
+       l, err := Listen("tcp", ":10443", testConfig)
        if err != nil {
-               t.Errorf("Got error: %s", err)
+               t.Fatal(err)
        }
-}
 
-var testCertificate = fromHex("3082025930820203a003020102020900c2ec326b95228959300d06092a864886f70d01010505003054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374301e170d3039313032303232323434355a170d3130313032303232323434355a3054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374305c300d06092a864886f70d0101010500034b003048024100b2990f49c47dfa8cd400ae6a4d1b8a3b6a13642b23f28b003bfb97790ade9a4cc82b8b2a81747ddec08b6296e53a08c331687ef25c4bf4936ba1c0e6041e9d150203010001a381b73081b4301d0603551d0e0416041478a06086837c9293a8c9b70c0bdabdb9d77eeedf3081840603551d23047d307b801478a06086837c9293a8c9b70c0bdabdb9d77eeedfa158a4563054310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464310d300b0603550403130474657374820900c2ec326b95228959300c0603551d13040530030101ff300d06092a864886f70d0101050500034100ac23761ae1349d85a439caad4d0b932b09ea96de1917c3e0507c446f4838cb3076fb4d431db8c1987e96f1d7a8a2054dea3a64ec99a3f0eda4d47a163bf1f6ac")
+       for {
+               c, err := l.Accept()
+               if err != nil {
+                       break
+               }
+               c.Write([]byte("hello, world\n"))
+               c.Close()
+       }
+}
 
 func bigFromString(s string) *big.Int {
        ret := new(big.Int)
@@ -199,12 +163,131 @@ func bigFromString(s string) *big.Int {
        return ret
 }
 
+func fromHex(s string) []byte {
+       b, _ := hex.DecodeString(s)
+       return b
+}
+
+var testCertificate = fromHex("308202b030820219a00302010202090085b0bba48a7fb8ca300d06092a864886f70d01010505003045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c7464301e170d3130303432343039303933385a170d3131303432343039303933385a3045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746430819f300d06092a864886f70d010101050003818d0030818902818100bb79d6f517b5e5bf4610d0dc69bee62b07435ad0032d8a7a4385b71452e7a5654c2c78b8238cb5b482e5de1f953b7e62a52ca533d6fe125c7a56fcf506bffa587b263fb5cd04d3d0c921964ac7f4549f5abfef427100fe1899077f7e887d7df10439c4a22edb51c97ce3c04c3b326601cfafb11db8719a1ddbdb896baeda2d790203010001a381a73081a4301d0603551d0e04160414b1ade2855acfcb28db69ce2369ded3268e18883930750603551d23046e306c8014b1ade2855acfcb28db69ce2369ded3268e188839a149a4473045310b3009060355040613024155311330110603550408130a536f6d652d53746174653121301f060355040a1318496e7465726e6574205769646769747320507479204c746482090085b0bba48a7fb8ca300c0603551d13040530030101ff300d06092a864886f70d010105050003818100086c4524c76bb159ab0c52ccf2b014d7879d7a6475b55a9566e4c52b8eae12661feb4f38b36e60d392fdf74108b52513b1187a24fb301dbaed98b917ece7d73159db95d31d78ea50565cd5825a2d5a5f33c4b6d8c97590968c0f5298b5cd981f89205ff2a01ca31b9694dda9fd57e970e8266d71999b266e3850296c90a7bdd9")
+
 var testPrivateKey = &rsa.PrivateKey{
        PublicKey: rsa.PublicKey{
-               N: bigFromString("9353930466774385905609975137998169297361893554149986716853295022578535724979677252958524466350471210367835187480748268864277464700638583474144061408845077"),
+               N: bigFromString("131650079503776001033793877885499001334664249354723305978524647182322416328664556247316495448366990052837680518067798333412266673813370895702118944398081598789828837447552603077848001020611640547221687072142537202428102790818451901395596882588063427854225330436740647715202971973145151161964464812406232198521"),
                E: 65537,
        },
-       D: bigFromString("7266398431328116344057699379749222532279343923819063639497049039389899328538543087657733766554155839834519529439851673014800261285757759040931985506583861"),
-       P: bigFromString("98920366548084643601728869055592650835572950932266967461790948584315647051443"),
-       Q: bigFromString("94560208308847015747498523884063394671606671904944666360068158221458669711639"),
+       D: bigFromString("29354450337804273969007277378287027274721892607543397931919078829901848876371746653677097639302788129485893852488285045793268732234230875671682624082413996177431586734171663258657462237320300610850244186316880055243099640544518318093544057213190320837094958164973959123058337475052510833916491060913053867729"),
+       P: bigFromString("11969277782311800166562047708379380720136961987713178380670422671426759650127150688426177829077494755200794297055316163155755835813760102405344560929062149"),
+       Q: bigFromString("10998999429884441391899182616418192492905073053684657075974935218461686523870125521822756579792315215543092255516093840728890783887287417039645833477273829"),
+}
+
+// Script of interaction with gnutls implementation.
+// The values for this test are obtained by building a test binary (gotest)
+// and then running 6.out -serve to start a server and then
+// gnutls-cli --insecure --debug 100 -p 10443 localhost
+// to dump a session.
+var serverScript = [][]byte{
+       // Alternate write and read.
+       []byte{
+               0x16, 0x03, 0x02, 0x00, 0x71, 0x01, 0x00, 0x00, 0x6d, 0x03, 0x02, 0x4b, 0xd4, 0xee, 0x6e, 0xab,
+               0x0b, 0xc3, 0x01, 0xd6, 0x8d, 0xe0, 0x72, 0x7e, 0x6c, 0x04, 0xbe, 0x9a, 0x3c, 0xa3, 0xd8, 0x95,
+               0x28, 0x00, 0xb2, 0xe8, 0x1f, 0xdd, 0xb0, 0xec, 0xca, 0x46, 0x1f, 0x00, 0x00, 0x28, 0x00, 0x33,
+               0x00, 0x39, 0x00, 0x16, 0x00, 0x32, 0x00, 0x38, 0x00, 0x13, 0x00, 0x66, 0x00, 0x90, 0x00, 0x91,
+               0x00, 0x8f, 0x00, 0x8e, 0x00, 0x2f, 0x00, 0x35, 0x00, 0x0a, 0x00, 0x05, 0x00, 0x04, 0x00, 0x8c,
+               0x00, 0x8d, 0x00, 0x8b, 0x00, 0x8a, 0x01, 0x00, 0x00, 0x1c, 0x00, 0x09, 0x00, 0x03, 0x02, 0x00,
+               0x01, 0x00, 0x00, 0x00, 0x11, 0x00, 0x0f, 0x00, 0x00, 0x0c, 0x31, 0x39, 0x32, 0x2e, 0x31, 0x36,
+               0x38, 0x2e, 0x30, 0x2e, 0x31, 0x30,
+       },
+
+       []byte{
+               0x16, 0x03, 0x02, 0x00, 0x2a,
+               0x02, 0x00, 0x00, 0x26, 0x03, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
+               0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x00,
+
+               0x16, 0x03, 0x02, 0x02, 0xbe,
+               0x0b, 0x00, 0x02, 0xba, 0x00, 0x02, 0xb7, 0x00, 0x02, 0xb4, 0x30, 0x82, 0x02, 0xb0, 0x30, 0x82,
+               0x02, 0x19, 0xa0, 0x03, 0x02, 0x01, 0x02, 0x02, 0x09, 0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f,
+               0xb8, 0xca, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05,
+               0x00, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55,
+               0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d,
+               0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18,
+               0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73,
+               0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30, 0x1e, 0x17, 0x0d, 0x31, 0x30, 0x30, 0x34,
+               0x32, 0x34, 0x30, 0x39, 0x30, 0x39, 0x33, 0x38, 0x5a, 0x17, 0x0d, 0x31, 0x31, 0x30, 0x34, 0x32,
+               0x34, 0x30, 0x39, 0x30, 0x39, 0x33, 0x38, 0x5a, 0x30, 0x45, 0x31, 0x0b, 0x30, 0x09, 0x06, 0x03,
+               0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03, 0x55, 0x04, 0x08,
+               0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31, 0x21, 0x30, 0x1f,
+               0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e, 0x65, 0x74, 0x20,
+               0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c, 0x74, 0x64, 0x30,
+               0x81, 0x9f, 0x30, 0x0d, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x01, 0x05,
+               0x00, 0x03, 0x81, 0x8d, 0x00, 0x30, 0x81, 0x89, 0x02, 0x81, 0x81, 0x00, 0xbb, 0x79, 0xd6, 0xf5,
+               0x17, 0xb5, 0xe5, 0xbf, 0x46, 0x10, 0xd0, 0xdc, 0x69, 0xbe, 0xe6, 0x2b, 0x07, 0x43, 0x5a, 0xd0,
+               0x03, 0x2d, 0x8a, 0x7a, 0x43, 0x85, 0xb7, 0x14, 0x52, 0xe7, 0xa5, 0x65, 0x4c, 0x2c, 0x78, 0xb8,
+               0x23, 0x8c, 0xb5, 0xb4, 0x82, 0xe5, 0xde, 0x1f, 0x95, 0x3b, 0x7e, 0x62, 0xa5, 0x2c, 0xa5, 0x33,
+               0xd6, 0xfe, 0x12, 0x5c, 0x7a, 0x56, 0xfc, 0xf5, 0x06, 0xbf, 0xfa, 0x58, 0x7b, 0x26, 0x3f, 0xb5,
+               0xcd, 0x04, 0xd3, 0xd0, 0xc9, 0x21, 0x96, 0x4a, 0xc7, 0xf4, 0x54, 0x9f, 0x5a, 0xbf, 0xef, 0x42,
+               0x71, 0x00, 0xfe, 0x18, 0x99, 0x07, 0x7f, 0x7e, 0x88, 0x7d, 0x7d, 0xf1, 0x04, 0x39, 0xc4, 0xa2,
+               0x2e, 0xdb, 0x51, 0xc9, 0x7c, 0xe3, 0xc0, 0x4c, 0x3b, 0x32, 0x66, 0x01, 0xcf, 0xaf, 0xb1, 0x1d,
+               0xb8, 0x71, 0x9a, 0x1d, 0xdb, 0xdb, 0x89, 0x6b, 0xae, 0xda, 0x2d, 0x79, 0x02, 0x03, 0x01, 0x00,
+               0x01, 0xa3, 0x81, 0xa7, 0x30, 0x81, 0xa4, 0x30, 0x1d, 0x06, 0x03, 0x55, 0x1d, 0x0e, 0x04, 0x16,
+               0x04, 0x14, 0xb1, 0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69, 0xce, 0x23, 0x69, 0xde,
+               0xd3, 0x26, 0x8e, 0x18, 0x88, 0x39, 0x30, 0x75, 0x06, 0x03, 0x55, 0x1d, 0x23, 0x04, 0x6e, 0x30,
+               0x6c, 0x80, 0x14, 0xb1, 0xad, 0xe2, 0x85, 0x5a, 0xcf, 0xcb, 0x28, 0xdb, 0x69, 0xce, 0x23, 0x69,
+               0xde, 0xd3, 0x26, 0x8e, 0x18, 0x88, 0x39, 0xa1, 0x49, 0xa4, 0x47, 0x30, 0x45, 0x31, 0x0b, 0x30,
+               0x09, 0x06, 0x03, 0x55, 0x04, 0x06, 0x13, 0x02, 0x41, 0x55, 0x31, 0x13, 0x30, 0x11, 0x06, 0x03,
+               0x55, 0x04, 0x08, 0x13, 0x0a, 0x53, 0x6f, 0x6d, 0x65, 0x2d, 0x53, 0x74, 0x61, 0x74, 0x65, 0x31,
+               0x21, 0x30, 0x1f, 0x06, 0x03, 0x55, 0x04, 0x0a, 0x13, 0x18, 0x49, 0x6e, 0x74, 0x65, 0x72, 0x6e,
+               0x65, 0x74, 0x20, 0x57, 0x69, 0x64, 0x67, 0x69, 0x74, 0x73, 0x20, 0x50, 0x74, 0x79, 0x20, 0x4c,
+               0x74, 0x64, 0x82, 0x09, 0x00, 0x85, 0xb0, 0xbb, 0xa4, 0x8a, 0x7f, 0xb8, 0xca, 0x30, 0x0c, 0x06,
+               0x03, 0x55, 0x1d, 0x13, 0x04, 0x05, 0x30, 0x03, 0x01, 0x01, 0xff, 0x30, 0x0d, 0x06, 0x09, 0x2a,
+               0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01, 0x05, 0x05, 0x00, 0x03, 0x81, 0x81, 0x00, 0x08, 0x6c,
+               0x45, 0x24, 0xc7, 0x6b, 0xb1, 0x59, 0xab, 0x0c, 0x52, 0xcc, 0xf2, 0xb0, 0x14, 0xd7, 0x87, 0x9d,
+               0x7a, 0x64, 0x75, 0xb5, 0x5a, 0x95, 0x66, 0xe4, 0xc5, 0x2b, 0x8e, 0xae, 0x12, 0x66, 0x1f, 0xeb,
+               0x4f, 0x38, 0xb3, 0x6e, 0x60, 0xd3, 0x92, 0xfd, 0xf7, 0x41, 0x08, 0xb5, 0x25, 0x13, 0xb1, 0x18,
+               0x7a, 0x24, 0xfb, 0x30, 0x1d, 0xba, 0xed, 0x98, 0xb9, 0x17, 0xec, 0xe7, 0xd7, 0x31, 0x59, 0xdb,
+               0x95, 0xd3, 0x1d, 0x78, 0xea, 0x50, 0x56, 0x5c, 0xd5, 0x82, 0x5a, 0x2d, 0x5a, 0x5f, 0x33, 0xc4,
+               0xb6, 0xd8, 0xc9, 0x75, 0x90, 0x96, 0x8c, 0x0f, 0x52, 0x98, 0xb5, 0xcd, 0x98, 0x1f, 0x89, 0x20,
+               0x5f, 0xf2, 0xa0, 0x1c, 0xa3, 0x1b, 0x96, 0x94, 0xdd, 0xa9, 0xfd, 0x57, 0xe9, 0x70, 0xe8, 0x26,
+               0x6d, 0x71, 0x99, 0x9b, 0x26, 0x6e, 0x38, 0x50, 0x29, 0x6c, 0x90, 0xa7, 0xbd, 0xd9,
+               0x16, 0x03, 0x02, 0x00, 0x04,
+               0x0e, 0x00, 0x00, 0x00,
+       },
+
+       []byte{
+               0x16, 0x03, 0x02, 0x00, 0x86, 0x10, 0x00, 0x00, 0x82, 0x00, 0x80, 0x3b, 0x7a, 0x9b, 0x05, 0xfd,
+               0x1b, 0x0d, 0x81, 0xf0, 0xac, 0x59, 0x57, 0x4e, 0xb6, 0xf5, 0x81, 0xed, 0x52, 0x78, 0xc5, 0xff,
+               0x36, 0x33, 0x9c, 0x94, 0x31, 0xc3, 0x14, 0x98, 0x5d, 0xa0, 0x49, 0x23, 0x11, 0x67, 0xdf, 0x73,
+               0x1b, 0x81, 0x0b, 0xdd, 0x10, 0xda, 0xee, 0xb5, 0x68, 0x61, 0xa9, 0xb6, 0x15, 0xae, 0x1a, 0x11,
+               0x31, 0x42, 0x2e, 0xde, 0x01, 0x4b, 0x81, 0x70, 0x03, 0xc8, 0x5b, 0xca, 0x21, 0x88, 0x25, 0xef,
+               0x89, 0xf0, 0xb7, 0xff, 0x24, 0x32, 0xd3, 0x14, 0x76, 0xe2, 0x50, 0x5c, 0x2e, 0x75, 0x9d, 0x5c,
+               0xa9, 0x80, 0x3d, 0x6f, 0xd5, 0x46, 0xd3, 0xdb, 0x42, 0x6e, 0x55, 0x81, 0x88, 0x42, 0x0e, 0x45,
+               0xfe, 0x9e, 0xe4, 0x41, 0x79, 0xcf, 0x71, 0x0e, 0xed, 0x27, 0xa8, 0x20, 0x05, 0xe9, 0x7a, 0x42,
+               0x4f, 0x05, 0x10, 0x2e, 0x52, 0x5d, 0x8c, 0x3c, 0x40, 0x49, 0x4c,
+
+               0x14, 0x03, 0x02, 0x00, 0x01, 0x01,
+
+               0x16, 0x03, 0x02, 0x00, 0x24, 0x8b, 0x12, 0x24, 0x06, 0xaa, 0x92, 0x74, 0xa1, 0x46, 0x6f, 0xc1,
+               0x4e, 0x4a, 0xf7, 0x16, 0xdd, 0xd6, 0xe1, 0x2d, 0x37, 0x0b, 0x44, 0xba, 0xeb, 0xc4, 0x6c, 0xc7,
+               0xa0, 0xb7, 0x8c, 0x9d, 0x24, 0xbd, 0x99, 0x33, 0x1e,
+       },
+
+       []byte{
+               0x14, 0x03, 0x02, 0x00, 0x01,
+               0x01,
+
+               0x16, 0x03, 0x02, 0x00, 0x24,
+               0x6e, 0xd1, 0x3e, 0x49, 0x68, 0xc1, 0xa0, 0xa5, 0xb7, 0xaf, 0xb0, 0x7c, 0x52, 0x1f, 0xf7, 0x2d,
+               0x51, 0xf3, 0xa5, 0xb6, 0xf6, 0xd4, 0x18, 0x4b, 0x7a, 0xd5, 0x24, 0x1d, 0x09, 0xb6, 0x41, 0x1c,
+               0x1c, 0x98, 0xf6, 0x90,
+
+               0x17, 0x03, 0x02, 0x00, 0x21,
+               0x50, 0xb7, 0x92, 0x4f, 0xd8, 0x78, 0x29, 0xa2, 0xe7, 0xa5, 0xa6, 0xbd, 0x1a, 0x0c, 0xf1, 0x5a,
+               0x6e, 0x6c, 0xeb, 0x38, 0x99, 0x9b, 0x3c, 0xfd, 0xee, 0x53, 0xe8, 0x4d, 0x7b, 0xa5, 0x5b, 0x00,
+
+               0xb9,
+
+               0x15, 0x03, 0x02, 0x00, 0x16,
+               0xc7, 0xc9, 0x5a, 0x72, 0xfb, 0x02, 0xa5, 0x93, 0xdd, 0x69, 0xeb, 0x30, 0x68, 0x5e, 0xbc, 0xe0,
+               0x44, 0xb9, 0x59, 0x33, 0x68, 0xa9,
+       },
 }
diff --git a/src/pkg/crypto/tls/record_process.go b/src/pkg/crypto/tls/record_process.go
deleted file mode 100644 (file)
index 77470f0..0000000
+++ /dev/null
@@ -1,302 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tls
-
-// A recordProcessor accepts reassembled records, decrypts and verifies them
-// and routes them either to the handshake processor, to up to the application.
-// It also accepts requests from the application for the current connection
-// state, or for a notification when the state changes.
-
-import (
-       "container/list"
-       "crypto/subtle"
-       "hash"
-)
-
-// getConnectionState is a request from the application to get the current
-// ConnectionState.
-type getConnectionState struct {
-       reply chan<- ConnectionState
-}
-
-// waitConnectionState is a request from the application to be notified when
-// the connection state changes.
-type waitConnectionState struct {
-       reply chan<- ConnectionState
-}
-
-// connectionStateChange is a message from the handshake processor that the
-// connection state has changed.
-type connectionStateChange struct {
-       connState ConnectionState
-}
-
-// changeCipherSpec is a message send to the handshake processor to signal that
-// the peer is switching ciphers.
-type changeCipherSpec struct{}
-
-// newCipherSpec is a message from the handshake processor that future
-// records should be processed with a new cipher and MAC function.
-type newCipherSpec struct {
-       encrypt encryptor
-       mac     hash.Hash
-}
-
-type recordProcessor struct {
-       decrypt       encryptor
-       mac           hash.Hash
-       seqNum        uint64
-       handshakeBuf  []byte
-       appDataChan   chan<- []byte
-       requestChan   <-chan interface{}
-       controlChan   <-chan interface{}
-       recordChan    <-chan *record
-       handshakeChan chan<- interface{}
-
-       // recordRead is nil when we don't wish to read any more.
-       recordRead <-chan *record
-       // appDataSend is nil when len(appData) == 0.
-       appDataSend chan<- []byte
-       // appData contains any application data queued for upstream.
-       appData []byte
-       // A list of channels waiting for connState to change.
-       waitQueue *list.List
-       connState ConnectionState
-       shutdown  bool
-       header    [13]byte
-}
-
-// drainRequestChannel processes messages from the request channel until it's closed.
-func drainRequestChannel(requestChan <-chan interface{}, c ConnectionState) {
-       for v := range requestChan {
-               if closed(requestChan) {
-                       return
-               }
-               switch r := v.(type) {
-               case getConnectionState:
-                       r.reply <- c
-               case waitConnectionState:
-                       r.reply <- c
-               }
-       }
-}
-
-func (p *recordProcessor) loop(appDataChan chan<- []byte, requestChan <-chan interface{}, controlChan <-chan interface{}, recordChan <-chan *record, handshakeChan chan<- interface{}) {
-       noop := nop{}
-       p.decrypt = noop
-       p.mac = noop
-       p.waitQueue = list.New()
-
-       p.appDataChan = appDataChan
-       p.requestChan = requestChan
-       p.controlChan = controlChan
-       p.recordChan = recordChan
-       p.handshakeChan = handshakeChan
-       p.recordRead = recordChan
-
-       for !p.shutdown {
-               select {
-               case p.appDataSend <- p.appData:
-                       p.appData = nil
-                       p.appDataSend = nil
-                       p.recordRead = p.recordChan
-               case c := <-controlChan:
-                       p.processControlMsg(c)
-               case r := <-requestChan:
-                       p.processRequestMsg(r)
-               case r := <-p.recordRead:
-                       p.processRecord(r)
-               }
-       }
-
-       p.wakeWaiters()
-       go drainRequestChannel(p.requestChan, p.connState)
-       go func() {
-               for _ = range controlChan {
-               }
-       }()
-
-       close(handshakeChan)
-       if len(p.appData) > 0 {
-               appDataChan <- p.appData
-       }
-       close(appDataChan)
-}
-
-func (p *recordProcessor) processRequestMsg(requestMsg interface{}) {
-       if closed(p.requestChan) {
-               p.shutdown = true
-               return
-       }
-
-       switch r := requestMsg.(type) {
-       case getConnectionState:
-               r.reply <- p.connState
-       case waitConnectionState:
-               if p.connState.HandshakeComplete {
-                       r.reply <- p.connState
-               }
-               p.waitQueue.PushBack(r.reply)
-       }
-}
-
-func (p *recordProcessor) processControlMsg(msg interface{}) {
-       connState, ok := msg.(ConnectionState)
-       if !ok || closed(p.controlChan) {
-               p.shutdown = true
-               return
-       }
-
-       p.connState = connState
-       p.wakeWaiters()
-}
-
-func (p *recordProcessor) wakeWaiters() {
-       for i := p.waitQueue.Front(); i != nil; i = i.Next() {
-               i.Value.(chan<- ConnectionState) <- p.connState
-       }
-       p.waitQueue.Init()
-}
-
-func (p *recordProcessor) processRecord(r *record) {
-       if closed(p.recordChan) {
-               p.shutdown = true
-               return
-       }
-
-       p.decrypt.XORKeyStream(r.payload)
-       if len(r.payload) < p.mac.Size() {
-               p.error(alertBadRecordMAC)
-               return
-       }
-
-       fillMACHeader(&p.header, p.seqNum, len(r.payload)-p.mac.Size(), r)
-       p.seqNum++
-
-       p.mac.Reset()
-       p.mac.Write(p.header[0:13])
-       p.mac.Write(r.payload[0 : len(r.payload)-p.mac.Size()])
-       macBytes := p.mac.Sum()
-
-       if subtle.ConstantTimeCompare(macBytes, r.payload[len(r.payload)-p.mac.Size():]) != 1 {
-               p.error(alertBadRecordMAC)
-               return
-       }
-
-       switch r.contentType {
-       case recordTypeHandshake:
-               p.processHandshakeRecord(r.payload[0 : len(r.payload)-p.mac.Size()])
-       case recordTypeChangeCipherSpec:
-               if len(r.payload) != 1 || r.payload[0] != 1 {
-                       p.error(alertUnexpectedMessage)
-                       return
-               }
-
-               p.handshakeChan <- changeCipherSpec{}
-               newSpec, ok := (<-p.controlChan).(*newCipherSpec)
-               if !ok {
-                       p.connState.Error = alertUnexpectedMessage
-                       p.shutdown = true
-                       return
-               }
-               p.decrypt = newSpec.encrypt
-               p.mac = newSpec.mac
-               p.seqNum = 0
-       case recordTypeApplicationData:
-               if p.connState.HandshakeComplete == false {
-                       p.error(alertUnexpectedMessage)
-                       return
-               }
-               p.recordRead = nil
-               p.appData = r.payload[0 : len(r.payload)-p.mac.Size()]
-               p.appDataSend = p.appDataChan
-       default:
-               p.error(alertUnexpectedMessage)
-               return
-       }
-}
-
-func (p *recordProcessor) processHandshakeRecord(data []byte) {
-       if p.handshakeBuf == nil {
-               p.handshakeBuf = data
-       } else {
-               if len(p.handshakeBuf) > maxHandshakeMsg {
-                       p.error(alertInternalError)
-                       return
-               }
-               newBuf := make([]byte, len(p.handshakeBuf)+len(data))
-               copy(newBuf, p.handshakeBuf)
-               copy(newBuf[len(p.handshakeBuf):], data)
-               p.handshakeBuf = newBuf
-       }
-
-       for len(p.handshakeBuf) >= 4 {
-               handshakeLen := int(p.handshakeBuf[1])<<16 |
-                       int(p.handshakeBuf[2])<<8 |
-                       int(p.handshakeBuf[3])
-               if handshakeLen+4 > len(p.handshakeBuf) {
-                       break
-               }
-
-               bytes := p.handshakeBuf[0 : handshakeLen+4]
-               p.handshakeBuf = p.handshakeBuf[handshakeLen+4:]
-               if bytes[0] == typeFinished {
-                       // Special case because Finished is synchronous: the
-                       // handshake handler has to tell us if it's ok to start
-                       // forwarding application data.
-                       m := new(finishedMsg)
-                       if !m.unmarshal(bytes) {
-                               p.error(alertUnexpectedMessage)
-                       }
-                       p.handshakeChan <- m
-                       var ok bool
-                       p.connState, ok = (<-p.controlChan).(ConnectionState)
-                       if !ok || p.connState.Error != 0 {
-                               p.shutdown = true
-                               return
-                       }
-               } else {
-                       msg, ok := parseHandshakeMsg(bytes)
-                       if !ok {
-                               p.error(alertUnexpectedMessage)
-                               return
-                       }
-                       p.handshakeChan <- msg
-               }
-       }
-}
-
-func (p *recordProcessor) error(err alertType) {
-       close(p.handshakeChan)
-       p.connState.Error = err
-       p.wakeWaiters()
-       p.shutdown = true
-}
-
-func parseHandshakeMsg(data []byte) (interface{}, bool) {
-       var m interface {
-               unmarshal([]byte) bool
-       }
-
-       switch data[0] {
-       case typeClientHello:
-               m = new(clientHelloMsg)
-       case typeServerHello:
-               m = new(serverHelloMsg)
-       case typeCertificate:
-               m = new(certificateMsg)
-       case typeServerHelloDone:
-               m = new(serverHelloDoneMsg)
-       case typeClientKeyExchange:
-               m = new(clientKeyExchangeMsg)
-       case typeNextProtocol:
-               m = new(nextProtoMsg)
-       default:
-               return nil, false
-       }
-
-       ok := m.unmarshal(data)
-       return m, ok
-}
diff --git a/src/pkg/crypto/tls/record_process_test.go b/src/pkg/crypto/tls/record_process_test.go
deleted file mode 100644 (file)
index fe001a2..0000000
+++ /dev/null
@@ -1,137 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tls
-
-import (
-       "encoding/hex"
-       "testing"
-       "testing/script"
-)
-
-func setup() (appDataChan chan []byte, requestChan chan interface{}, controlChan chan interface{}, recordChan chan *record, handshakeChan chan interface{}) {
-       rp := new(recordProcessor)
-       appDataChan = make(chan []byte)
-       requestChan = make(chan interface{})
-       controlChan = make(chan interface{})
-       recordChan = make(chan *record)
-       handshakeChan = make(chan interface{})
-
-       go rp.loop(appDataChan, requestChan, controlChan, recordChan, handshakeChan)
-       return
-}
-
-func fromHex(s string) []byte {
-       b, _ := hex.DecodeString(s)
-       return b
-}
-
-func TestNullConnectionState(t *testing.T) {
-       _, requestChan, controlChan, recordChan, _ := setup()
-       defer close(requestChan)
-       defer close(controlChan)
-       defer close(recordChan)
-
-       // Test a simple request for the connection state.
-       replyChan := make(chan ConnectionState)
-       sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}})
-       getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0, ""}})
-
-       err := script.Perform(0, []*script.Event{sendReq, getReply})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
-
-func TestWaitConnectionState(t *testing.T) {
-       _, requestChan, controlChan, recordChan, _ := setup()
-       defer close(requestChan)
-       defer close(controlChan)
-       defer close(recordChan)
-
-       // Test that waitConnectionState doesn't get a reply until the connection state changes.
-       replyChan := make(chan ConnectionState)
-       sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}})
-       replyChan2 := make(chan ConnectionState)
-       sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}})
-       getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0, ""}})
-       sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1, ""}})
-       getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1, ""}})
-
-       err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
-
-func TestHandshakeAssembly(t *testing.T) {
-       _, requestChan, controlChan, recordChan, handshakeChan := setup()
-       defer close(requestChan)
-       defer close(controlChan)
-       defer close(recordChan)
-
-       // Test the reassembly of a fragmented handshake message.
-       send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("10000003")}})
-       send2 := script.NewEvent("send 2", []*script.Event{send1}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("0001")}})
-       send3 := script.NewEvent("send 3", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("42")}})
-       recvMsg := script.NewEvent("recv", []*script.Event{send3}, script.Recv{handshakeChan, &clientKeyExchangeMsg{fromHex("10000003000142"), fromHex("42")}})
-
-       err := script.Perform(0, []*script.Event{send1, send2, send3, recvMsg})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
-
-func TestEarlyApplicationData(t *testing.T) {
-       _, requestChan, controlChan, recordChan, handshakeChan := setup()
-       defer close(requestChan)
-       defer close(controlChan)
-       defer close(recordChan)
-
-       // Test that applicaton data received before the handshake has completed results in an error.
-       send := script.NewEvent("send", nil, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("")}})
-       recv := script.NewEvent("recv", []*script.Event{send}, script.Closed{handshakeChan})
-
-       err := script.Perform(0, []*script.Event{send, recv})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
-
-func TestApplicationData(t *testing.T) {
-       appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup()
-       defer close(requestChan)
-       defer close(controlChan)
-       defer close(recordChan)
-
-       // Test that the application data is forwarded after a successful Finished message.
-       send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}})
-       recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}})
-       send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0, ""}})
-       send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}})
-       recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}})
-
-       err := script.Perform(0, []*script.Event{send1, recv1, send2, send3, recv2})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
-
-func TestInvalidChangeCipherSpec(t *testing.T) {
-       appDataChan, requestChan, controlChan, recordChan, handshakeChan := setup()
-       defer close(requestChan)
-       defer close(controlChan)
-       defer close(recordChan)
-
-       send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}})
-       recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}})
-       send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42, ""}})
-       close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan})
-       close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan})
-
-       err := script.Perform(0, []*script.Event{send1, recv1, send2, close, close2})
-       if err != nil {
-               t.Errorf("Got error: %s", err)
-       }
-}
diff --git a/src/pkg/crypto/tls/record_read.go b/src/pkg/crypto/tls/record_read.go
deleted file mode 100644 (file)
index 682fde8..0000000
+++ /dev/null
@@ -1,42 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tls
-
-// The record reader handles reading from the connection and reassembling TLS
-// record structures. It loops forever doing this and writes the TLS records to
-// it's outbound channel. On error, it closes its outbound channel.
-
-import (
-       "io"
-       "bufio"
-)
-
-// recordReader loops, reading TLS records from source and writing them to the
-// given channel. The channel is closed on EOF or on error.
-func recordReader(c chan<- *record, source io.Reader) {
-       defer close(c)
-       buf := bufio.NewReader(source)
-
-       for {
-               var header [5]byte
-               n, _ := buf.Read(&header)
-               if n != 5 {
-                       return
-               }
-
-               recordLength := int(header[3])<<8 | int(header[4])
-               if recordLength > maxTLSCiphertext {
-                       return
-               }
-
-               payload := make([]byte, recordLength)
-               n, _ = buf.Read(payload)
-               if n != recordLength {
-                       return
-               }
-
-               c <- &record{recordType(header[0]), header[1], header[2], payload}
-       }
-}
diff --git a/src/pkg/crypto/tls/record_read_test.go b/src/pkg/crypto/tls/record_read_test.go
deleted file mode 100644 (file)
index f897599..0000000
+++ /dev/null
@@ -1,73 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tls
-
-import (
-       "bytes"
-       "testing"
-       "testing/iotest"
-)
-
-func matchRecord(r1, r2 *record) bool {
-       if (r1 == nil) != (r2 == nil) {
-               return false
-       }
-       if r1 == nil {
-               return true
-       }
-       return r1.contentType == r2.contentType &&
-               r1.major == r2.major &&
-               r1.minor == r2.minor &&
-               bytes.Compare(r1.payload, r2.payload) == 0
-}
-
-type recordReaderTest struct {
-       in  []byte
-       out []*record
-}
-
-var recordReaderTests = []recordReaderTest{
-       recordReaderTest{nil, nil},
-       recordReaderTest{fromHex("01"), nil},
-       recordReaderTest{fromHex("0102"), nil},
-       recordReaderTest{fromHex("010203"), nil},
-       recordReaderTest{fromHex("01020300"), nil},
-       recordReaderTest{fromHex("0102030000"), []*record{&record{1, 2, 3, nil}}},
-       recordReaderTest{fromHex("01020300000102030000"), []*record{&record{1, 2, 3, nil}, &record{1, 2, 3, nil}}},
-       recordReaderTest{fromHex("0102030001fe0102030002feff"), []*record{&record{1, 2, 3, []byte{0xfe}}, &record{1, 2, 3, []byte{0xfe, 0xff}}}},
-       recordReaderTest{fromHex("010203000001020300"), []*record{&record{1, 2, 3, nil}}},
-}
-
-func TestRecordReader(t *testing.T) {
-       for i, test := range recordReaderTests {
-               buf := bytes.NewBuffer(test.in)
-               c := make(chan *record)
-               go recordReader(c, buf)
-               matchRecordReaderOutput(t, i, test, c)
-
-               buf = bytes.NewBuffer(test.in)
-               buf2 := iotest.OneByteReader(buf)
-               c = make(chan *record)
-               go recordReader(c, buf2)
-               matchRecordReaderOutput(t, i*2, test, c)
-       }
-}
-
-func matchRecordReaderOutput(t *testing.T, i int, test recordReaderTest, c <-chan *record) {
-       for j, r1 := range test.out {
-               r2 := <-c
-               if r2 == nil {
-                       t.Errorf("#%d truncated after %d values", i, j)
-                       break
-               }
-               if !matchRecord(r1, r2) {
-                       t.Errorf("#%d (%d) got:%#v want:%#v", i, j, r2, r1)
-               }
-       }
-       <-c
-       if !closed(c) {
-               t.Errorf("#%d: channel didn't close", i)
-       }
-}
diff --git a/src/pkg/crypto/tls/record_write.go b/src/pkg/crypto/tls/record_write.go
deleted file mode 100644 (file)
index 5f3fb5b..0000000
+++ /dev/null
@@ -1,170 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package tls
-
-import (
-       "fmt"
-       "hash"
-       "io"
-)
-
-// writerEnableApplicationData is a message which instructs recordWriter to
-// start reading and transmitting data from the application data channel.
-type writerEnableApplicationData struct{}
-
-// writerChangeCipherSpec updates the encryption and MAC functions and resets
-// the sequence count.
-type writerChangeCipherSpec struct {
-       encryptor encryptor
-       mac       hash.Hash
-}
-
-// writerSetVersion sets the version number bytes that we included in the
-// record header for future records.
-type writerSetVersion struct {
-       major, minor uint8
-}
-
-// A recordWriter accepts messages from the handshake processor and
-// application data. It writes them to the outgoing connection and blocks on
-// writing. It doesn't read from the application data channel until the
-// handshake processor has signaled that the handshake is complete.
-type recordWriter struct {
-       writer       io.Writer
-       encryptor    encryptor
-       mac          hash.Hash
-       seqNum       uint64
-       major, minor uint8
-       shutdown     bool
-       appChan      <-chan []byte
-       controlChan  <-chan interface{}
-       header       [13]byte
-}
-
-func (w *recordWriter) loop(writer io.Writer, appChan <-chan []byte, controlChan <-chan interface{}) {
-       w.writer = writer
-       w.encryptor = nop{}
-       w.mac = nop{}
-       w.appChan = appChan
-       w.controlChan = controlChan
-
-       for !w.shutdown {
-               msg := <-controlChan
-               if _, ok := msg.(writerEnableApplicationData); ok {
-                       break
-               }
-               w.processControlMessage(msg)
-       }
-
-       for !w.shutdown {
-               // Always process control messages first.
-               if controlMsg, ok := <-controlChan; ok {
-                       w.processControlMessage(controlMsg)
-                       continue
-               }
-
-               select {
-               case controlMsg := <-controlChan:
-                       w.processControlMessage(controlMsg)
-               case appMsg := <-appChan:
-                       w.processAppMessage(appMsg)
-               }
-       }
-
-       if !closed(appChan) {
-               go func() {
-                       for _ = range appChan {
-                       }
-               }()
-       }
-       if !closed(controlChan) {
-               go func() {
-                       for _ = range controlChan {
-                       }
-               }()
-       }
-}
-
-// fillMACHeader generates a MAC header. See RFC 4346, section 6.2.3.1.
-func fillMACHeader(header *[13]byte, seqNum uint64, length int, r *record) {
-       header[0] = uint8(seqNum >> 56)
-       header[1] = uint8(seqNum >> 48)
-       header[2] = uint8(seqNum >> 40)
-       header[3] = uint8(seqNum >> 32)
-       header[4] = uint8(seqNum >> 24)
-       header[5] = uint8(seqNum >> 16)
-       header[6] = uint8(seqNum >> 8)
-       header[7] = uint8(seqNum)
-       header[8] = uint8(r.contentType)
-       header[9] = r.major
-       header[10] = r.minor
-       header[11] = uint8(length >> 8)
-       header[12] = uint8(length)
-}
-
-func (w *recordWriter) writeRecord(r *record) {
-       w.mac.Reset()
-
-       fillMACHeader(&w.header, w.seqNum, len(r.payload), r)
-
-       w.mac.Write(w.header[0:13])
-       w.mac.Write(r.payload)
-       macBytes := w.mac.Sum()
-
-       w.encryptor.XORKeyStream(r.payload)
-       w.encryptor.XORKeyStream(macBytes)
-
-       length := len(r.payload) + len(macBytes)
-       w.header[11] = uint8(length >> 8)
-       w.header[12] = uint8(length)
-       w.writer.Write(w.header[8:13])
-       w.writer.Write(r.payload)
-       w.writer.Write(macBytes)
-
-       w.seqNum++
-}
-
-func (w *recordWriter) processControlMessage(controlMsg interface{}) {
-       if controlMsg == nil {
-               w.shutdown = true
-               return
-       }
-
-       switch msg := controlMsg.(type) {
-       case writerChangeCipherSpec:
-               w.writeRecord(&record{recordTypeChangeCipherSpec, w.major, w.minor, []byte{0x01}})
-               w.encryptor = msg.encryptor
-               w.mac = msg.mac
-               w.seqNum = 0
-       case writerSetVersion:
-               w.major = msg.major
-               w.minor = msg.minor
-       case alert:
-               w.writeRecord(&record{recordTypeAlert, w.major, w.minor, []byte{byte(msg.level), byte(msg.error)}})
-       case handshakeMessage:
-               // TODO(agl): marshal may return a slice too large for a single record.
-               w.writeRecord(&record{recordTypeHandshake, w.major, w.minor, msg.marshal()})
-       default:
-               fmt.Printf("processControlMessage: unknown %#v\n", msg)
-       }
-}
-
-func (w *recordWriter) processAppMessage(appMsg []byte) {
-       if closed(w.appChan) {
-               w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, []byte{byte(alertCloseNotify)}})
-               w.shutdown = true
-               return
-       }
-
-       var done int
-       for done < len(appMsg) {
-               todo := len(appMsg)
-               if todo > maxTLSPlaintext {
-                       todo = maxTLSPlaintext
-               }
-               w.writeRecord(&record{recordTypeApplicationData, w.major, w.minor, appMsg[done : done+todo]})
-               done += todo
-       }
-}
index 5fbf850daa0dda523da693501ee9f9f88057ef3b..1a5da3ac43d929755d1140c776d442a1868c8fc9 100644 (file)
 package tls
 
 import (
-       "io"
        "os"
        "net"
-       "time"
 )
 
-// A Conn represents a secure connection.
-type Conn struct {
-       net.Conn
-       writeChan                 chan<- []byte
-       readChan                  <-chan []byte
-       requestChan               chan<- interface{}
-       readBuf                   []byte
-       eof                       bool
-       readTimeout, writeTimeout int64
-}
-
-func timeout(c chan<- bool, nsecs int64) {
-       time.Sleep(nsecs)
-       c <- true
-}
-
-func (tls *Conn) Read(p []byte) (int, os.Error) {
-       if len(tls.readBuf) == 0 {
-               if tls.eof {
-                       return 0, os.EOF
-               }
-
-               var timeoutChan chan bool
-               if tls.readTimeout > 0 {
-                       timeoutChan = make(chan bool)
-                       go timeout(timeoutChan, tls.readTimeout)
-               }
-
-               select {
-               case b := <-tls.readChan:
-                       tls.readBuf = b
-               case <-timeoutChan:
-                       return 0, os.EAGAIN
-               }
-
-               // TLS distinguishes between orderly closes and truncations. An
-               // orderly close is represented by a zero length slice.
-               if closed(tls.readChan) {
-                       return 0, io.ErrUnexpectedEOF
-               }
-               if len(tls.readBuf) == 0 {
-                       tls.eof = true
-                       return 0, os.EOF
-               }
-       }
-
-       n := copy(p, tls.readBuf)
-       tls.readBuf = tls.readBuf[n:]
-       return n, nil
-}
-
-func (tls *Conn) Write(p []byte) (int, os.Error) {
-       if tls.eof || closed(tls.readChan) {
-               return 0, os.EOF
-       }
-
-       var timeoutChan chan bool
-       if tls.writeTimeout > 0 {
-               timeoutChan = make(chan bool)
-               go timeout(timeoutChan, tls.writeTimeout)
-       }
-
-       select {
-       case tls.writeChan <- p:
-       case <-timeoutChan:
-               return 0, os.EAGAIN
-       }
-
-       return len(p), nil
-}
-
-func (tls *Conn) Close() os.Error {
-       close(tls.writeChan)
-       close(tls.requestChan)
-       tls.eof = true
-       return nil
-}
-
-func (tls *Conn) SetTimeout(nsec int64) os.Error {
-       tls.readTimeout = nsec
-       tls.writeTimeout = nsec
-       return nil
-}
-
-func (tls *Conn) SetReadTimeout(nsec int64) os.Error {
-       tls.readTimeout = nsec
-       return nil
-}
-
-func (tls *Conn) SetWriteTimeout(nsec int64) os.Error {
-       tls.writeTimeout = nsec
-       return nil
-}
-
-func (tls *Conn) GetConnectionState() ConnectionState {
-       replyChan := make(chan ConnectionState)
-       tls.requestChan <- getConnectionState{replyChan}
-       return <-replyChan
-}
-
-func (tls *Conn) WaitConnectionState() ConnectionState {
-       replyChan := make(chan ConnectionState)
-       tls.requestChan <- waitConnectionState{replyChan}
-       return <-replyChan
-}
-
-type handshaker interface {
-       loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config)
-}
-
-// Server establishes a secure connection over the given connection and acts
-// as a TLS server.
-func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn {
-       if config == nil {
-               config = defaultConfig()
-       }
-       tls := new(Conn)
-       tls.Conn = conn
-
-       writeChan := make(chan []byte)
-       readChan := make(chan []byte)
-       requestChan := make(chan interface{})
-
-       tls.writeChan = writeChan
-       tls.readChan = readChan
-       tls.requestChan = requestChan
-
-       handshakeWriterChan := make(chan interface{})
-       processorHandshakeChan := make(chan interface{})
-       handshakeProcessorChan := make(chan interface{})
-       readerProcessorChan := make(chan *record)
-
-       go new(recordWriter).loop(conn, writeChan, handshakeWriterChan)
-       go recordReader(readerProcessorChan, conn)
-       go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan)
-       go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config)
-
-       return tls
-}
-
 func Server(conn net.Conn, config *Config) *Conn {
-       return startTLSGoroutines(conn, new(serverHandshake), config)
+       return &Conn{conn: conn, config: config}
 }
 
 func Client(conn net.Conn, config *Config) *Conn {
-       return startTLSGoroutines(conn, new(clientHandshake), config)
+       return &Conn{conn: conn, config: config, isClient: true}
 }
 
 type Listener struct {
@@ -180,22 +38,24 @@ func (l *Listener) Addr() net.Addr { return l.listener.Addr() }
 
 // NewListener creates a Listener which accepts connections from an inner
 // Listener and wraps each connection with Server.
+// The configuration config must be non-nil and must have
+// at least one certificate.
 func NewListener(listener net.Listener, config *Config) (l *Listener) {
-       if config == nil {
-               config = defaultConfig()
-       }
        l = new(Listener)
        l.listener = listener
        l.config = config
        return
 }
 
-func Listen(network, laddr string) (net.Listener, os.Error) {
+func Listen(network, laddr string, config *Config) (net.Listener, os.Error) {
+       if config == nil || len(config.Certificates) == 0 {
+               return nil, os.NewError("tls.Listen: no certificates in configuration")
+       }
        l, err := net.Listen(network, laddr)
        if err != nil {
                return nil, err
        }
-       return NewListener(l, nil), nil
+       return NewListener(l, config), nil
 }
 
 func Dial(network, laddr, raddr string) (net.Conn, os.Error) {