]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add initial client implementation.
authorAdam Langley <agl@golang.org>
Sat, 21 Nov 2009 23:53:03 +0000 (15:53 -0800)
committerAdam Langley <agl@golang.org>
Sat, 21 Nov 2009 23:53:03 +0000 (15:53 -0800)
R=rsc, agl
CC=golang-dev
https://golang.org/cl/157076

src/pkg/crypto/tls/Makefile
src/pkg/crypto/tls/ca_set.go [new file with mode: 0644]
src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/handshake_client.go [new file with mode: 0644]
src/pkg/crypto/tls/handshake_messages.go
src/pkg/crypto/tls/handshake_messages_test.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/record_process.go
src/pkg/crypto/tls/tls.go

index dd3df29573eebce62f0fcd8cf0f4ac92eea020d0..0c2d339988a6d4e23e05793451300b819d428685 100644 (file)
@@ -8,12 +8,14 @@ TARG=crypto/tls
 GOFILES=\
        alert.go\
        common.go\
+       handshake_client.go\
        handshake_messages.go\
        handshake_server.go\
        prf.go\
        record_process.go\
        record_read.go\
        record_write.go\
+       ca_set.go\
        tls.go\
 
 include $(GOROOT)/src/Make.pkg
diff --git a/src/pkg/crypto/tls/ca_set.go b/src/pkg/crypto/tls/ca_set.go
new file mode 100644 (file)
index 0000000..e8cddd6
--- /dev/null
@@ -0,0 +1,75 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+       "crypto/x509";
+       "encoding/pem";
+)
+
+// A CASet is a set of certificates.
+type CASet struct {
+       bySubjectKeyId  map[string]*x509.Certificate;
+       byName          map[string]*x509.Certificate;
+}
+
+func NewCASet() *CASet {
+       return &CASet{
+               make(map[string]*x509.Certificate),
+               make(map[string]*x509.Certificate),
+       }
+}
+
+func nameToKey(name *x509.Name) string {
+       return name.Country + "/" + name.OrganizationalUnit + "/" + name.OrganizationalUnit + "/" + name.CommonName
+}
+
+// FindParent attempts to find the certificate in s which signs the given
+// certificate. If no such certificate can be found, it returns nil.
+func (s *CASet) FindParent(cert *x509.Certificate) (parent *x509.Certificate) {
+       var ok bool;
+
+       if len(cert.AuthorityKeyId) > 0 {
+               parent, ok = s.bySubjectKeyId[string(cert.AuthorityKeyId)]
+       } else {
+               parent, ok = s.byName[nameToKey(&cert.Issuer)]
+       }
+
+       if !ok {
+               return nil
+       }
+       return parent;
+}
+
+// SetFromPEM attempts to parse a series of PEM encoded root certificates. It
+// appends any certificates found to s and returns true if any certificates
+// were successfully parsed. On many Linux systems, /etc/ssl/cert.pem will
+// contains the system wide set of root CAs in a format suitable for this
+// function.
+func (s *CASet) SetFromPEM(pemCerts []byte) (ok bool) {
+       for len(pemCerts) > 0 {
+               var block *pem.Block;
+               block, pemCerts = pem.Decode(pemCerts);
+               if block == nil {
+                       break
+               }
+               if block.Type != "CERTIFICATE" || len(block.Headers) != 0 {
+                       continue
+               }
+
+               cert, err := x509.ParseCertificate(block.Bytes);
+               if err != nil {
+                       continue
+               }
+
+               if len(cert.SubjectKeyId) > 0 {
+                       s.bySubjectKeyId[string(cert.SubjectKeyId)] = cert
+               }
+               s.byName[nameToKey(&cert.Subject)] = cert;
+               ok = true;
+       }
+
+       return;
+}
index e295cd472804817b3f8d20b8b8a54f30c08d3a45..e1318a8930d43461ba371643be263a4069704b0a 100644 (file)
@@ -17,6 +17,9 @@ const (
        maxTLSCiphertext        = 16384 + 2048;
        // maxHandshakeMsg is the largest single handshake message that we'll buffer.
        maxHandshakeMsg = 65536;
+       // defaultMajor and defaultMinor are the maximum TLS version that we support.
+       defaultMajor    = 3;
+       defaultMinor    = 2;
 )
 
 
@@ -64,6 +67,7 @@ type Config struct {
        // Time returns the current time as the number of seconds since the epoch.
        Time            func() int64;
        Certificates    []Certificate;
+       RootCAs         *CASet;
 }
 
 type Certificate struct {
diff --git a/src/pkg/crypto/tls/handshake_client.go b/src/pkg/crypto/tls/handshake_client.go
new file mode 100644 (file)
index 0000000..db9bb8c
--- /dev/null
@@ -0,0 +1,225 @@
+// Copyright 2009 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tls
+
+import (
+       "crypto/hmac";
+       "crypto/rc4";
+       "crypto/rsa";
+       "crypto/sha1";
+       "crypto/subtle";
+       "crypto/x509";
+       "io";
+)
+
+// A serverHandshake performs the server side of the TLS 1.1 handshake protocol.
+type clientHandshake struct {
+       writeChan       chan<- interface{};
+       controlChan     chan<- interface{};
+       msgChan         <-chan interface{};
+       config          *Config;
+}
+
+func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config) {
+       h.writeChan = writeChan;
+       h.controlChan = controlChan;
+       h.msgChan = msgChan;
+       h.config = config;
+
+       defer close(writeChan);
+       defer close(controlChan);
+
+       finishedHash := newFinishedHash();
+
+       hello := &clientHelloMsg{
+               major: defaultMajor,
+               minor: defaultMinor,
+               cipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA},
+               compressionMethods: []uint8{compressionNone},
+               random: make([]byte, 32),
+       };
+
+       currentTime := uint32(config.Time());
+       hello.random[0] = byte(currentTime >> 24);
+       hello.random[1] = byte(currentTime >> 16);
+       hello.random[2] = byte(currentTime >> 8);
+       hello.random[3] = byte(currentTime);
+       _, err := io.ReadFull(config.Rand, hello.random[4:len(hello.random)]);
+       if err != nil {
+               h.error(alertInternalError);
+               return;
+       }
+
+       finishedHash.Write(hello.marshal());
+       writeChan <- writerSetVersion{defaultMajor, defaultMinor};
+       writeChan <- hello;
+
+       serverHello, ok := h.readHandshakeMsg().(*serverHelloMsg);
+       if !ok {
+               h.error(alertUnexpectedMessage);
+               return;
+       }
+       finishedHash.Write(serverHello.marshal());
+       major, minor, ok := mutualVersion(serverHello.major, serverHello.minor);
+       if !ok {
+               h.error(alertProtocolVersion);
+               return;
+       }
+
+       writeChan <- writerSetVersion{major, minor};
+
+       if serverHello.cipherSuite != TLS_RSA_WITH_RC4_128_SHA ||
+               serverHello.compressionMethod != compressionNone {
+               h.error(alertUnexpectedMessage);
+               return;
+       }
+
+       certMsg, ok := h.readHandshakeMsg().(*certificateMsg);
+       if !ok || len(certMsg.certificates) == 0 {
+               h.error(alertUnexpectedMessage);
+               return;
+       }
+       finishedHash.Write(certMsg.marshal());
+
+       certs := make([]*x509.Certificate, len(certMsg.certificates));
+       for i, asn1Data := range certMsg.certificates {
+               cert, err := x509.ParseCertificate(asn1Data);
+               if err != nil {
+                       h.error(alertBadCertificate);
+                       return;
+               }
+               certs[i] = cert;
+       }
+
+       // TODO(agl): do better validation of certs: max path length, name restrictions etc.
+       for i := 1; i < len(certs); i++ {
+               if certs[i-1].CheckSignatureFrom(certs[i]) != nil {
+                       h.error(alertBadCertificate);
+                       return;
+               }
+       }
+
+       if config.RootCAs != nil {
+               root := config.RootCAs.FindParent(certs[len(certs)-1]);
+               if root == nil {
+                       h.error(alertBadCertificate);
+                       return;
+               }
+               if certs[len(certs)-1].CheckSignatureFrom(root) != nil {
+                       h.error(alertBadCertificate);
+                       return;
+               }
+       }
+
+       pub, ok := certs[0].PublicKey.(*rsa.PublicKey);
+       if !ok {
+               h.error(alertUnsupportedCertificate);
+               return;
+       }
+
+       shd, ok := h.readHandshakeMsg().(*serverHelloDoneMsg);
+       if !ok {
+               h.error(alertUnexpectedMessage);
+               return;
+       }
+       finishedHash.Write(shd.marshal());
+
+       ckx := new(clientKeyExchangeMsg);
+       preMasterSecret := make([]byte, 48);
+       // Note that the version number in the preMasterSecret must be the
+       // version offered in the ClientHello.
+       preMasterSecret[0] = defaultMajor;
+       preMasterSecret[1] = defaultMinor;
+       _, err = io.ReadFull(config.Rand, preMasterSecret[2:len(preMasterSecret)]);
+       if err != nil {
+               h.error(alertInternalError);
+               return;
+       }
+
+       ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret);
+       if err != nil {
+               h.error(alertInternalError);
+               return;
+       }
+
+       finishedHash.Write(ckx.marshal());
+       writeChan <- ckx;
+
+       suite := cipherSuites[0];
+       masterSecret, clientMAC, serverMAC, clientKey, serverKey :=
+               keysFromPreMasterSecret11(preMasterSecret, hello.random, serverHello.random, suite.hashLength, suite.cipherKeyLength);
+
+       cipher, _ := rc4.NewCipher(clientKey);
+       writeChan <- writerChangeCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)};
+
+       finished := new(finishedMsg);
+       finished.verifyData = finishedHash.clientSum(masterSecret);
+       finishedHash.Write(finished.marshal());
+       writeChan <- finished;
+
+       // TODO(agl): this is cut-through mode which should probably be an option.
+       writeChan <- writerEnableApplicationData{};
+
+       _, ok = h.readHandshakeMsg().(changeCipherSpec);
+       if !ok {
+               h.error(alertUnexpectedMessage);
+               return;
+       }
+
+       cipher2, _ := rc4.NewCipher(serverKey);
+       controlChan <- &newCipherSpec{cipher2, hmac.New(sha1.New(), serverMAC)};
+
+       serverFinished, ok := h.readHandshakeMsg().(*finishedMsg);
+       if !ok {
+               h.error(alertUnexpectedMessage);
+               return;
+       }
+
+       verify := finishedHash.serverSum(masterSecret);
+       if len(verify) != len(serverFinished.verifyData) ||
+               subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
+               h.error(alertHandshakeFailure);
+               return;
+       }
+
+       controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0};
+
+       // This should just block forever.
+       _ = h.readHandshakeMsg();
+       h.error(alertUnexpectedMessage);
+       return;
+}
+
+func (h *clientHandshake) readHandshakeMsg() interface{} {
+       v := <-h.msgChan;
+       if closed(h.msgChan) {
+               // If the channel closed then the processor received an error
+               // from the peer and we don't want to echo it back to them.
+               h.msgChan = nil;
+               return 0;
+       }
+       if _, ok := v.(alert); ok {
+               // We got an alert from the processor. We forward to the writer
+               // and shutdown.
+               h.writeChan <- v;
+               h.msgChan = nil;
+               return 0;
+       }
+       return v;
+}
+
+func (h *clientHandshake) error(e alertType) {
+       if h.msgChan != nil {
+               // If we didn't get an error from the processor, then we need
+               // to tell it about the error.
+               go func() {
+                       for _ = range h.msgChan {
+                       }
+               }();
+               h.controlChan <- ConnectionState{false, "", e};
+               close(h.controlChan);
+               h.writeChan <- alert{alertLevelError, e};
+       }
+}
index 87e2e779e39f2c412cfef36e8e73d04a0bbe2668..65dae8762594b851de68adc67bda56d7b8f71ca8 100644 (file)
@@ -45,7 +45,7 @@ func (m *clientHelloMsg) marshal() []byte {
 }
 
 func (m *clientHelloMsg) unmarshal(data []byte) bool {
-       if len(data) < 39 {
+       if len(data) < 43 {
                return false
        }
        m.raw = data;
@@ -120,6 +120,30 @@ func (m *serverHelloMsg) marshal() []byte {
        return x;
 }
 
+func (m *serverHelloMsg) unmarshal(data []byte) bool {
+       if len(data) < 42 {
+               return false
+       }
+       m.raw = data;
+       m.major = data[4];
+       m.minor = data[5];
+       m.random = data[6:38];
+       sessionIdLen := int(data[38]);
+       if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
+               return false
+       }
+       m.sessionId = data[39 : 39+sessionIdLen];
+       data = data[39+sessionIdLen : len(data)];
+       if len(data) < 3 {
+               return false
+       }
+       m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]);
+       m.compressionMethod = data[2];
+
+       // Trailing data is allowed because extensions may be present.
+       return true;
+}
+
 type certificateMsg struct {
        raw             []byte;
        certificates    [][]byte;
@@ -160,6 +184,43 @@ func (m *certificateMsg) marshal() (x []byte) {
        return;
 }
 
+func (m *certificateMsg) unmarshal(data []byte) bool {
+       if len(data) < 7 {
+               return false
+       }
+
+       m.raw = data;
+       certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]);
+       if uint32(len(data)) != certsLen+7 {
+               return false
+       }
+
+       numCerts := 0;
+       d := data[7:len(data)];
+       for certsLen > 0 {
+               if len(d) < 4 {
+                       return false
+               }
+               certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]);
+               if uint32(len(d)) < 3+certLen {
+                       return false
+               }
+               d = d[3+certLen : len(d)];
+               certsLen -= 3 + certLen;
+               numCerts++;
+       }
+
+       m.certificates = make([][]byte, numCerts);
+       d = data[7:len(data)];
+       for i := 0; i < numCerts; i++ {
+               certLen := uint32(d[0])<<24 | uint32(d[1])<<8 | uint32(d[2]);
+               m.certificates[i] = d[3 : 3+certLen];
+               d = d[3+certLen : len(d)];
+       }
+
+       return true;
+}
+
 type serverHelloDoneMsg struct{}
 
 func (m *serverHelloDoneMsg) marshal() []byte {
@@ -168,6 +229,10 @@ func (m *serverHelloDoneMsg) marshal() []byte {
        return x;
 }
 
+func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
+       return len(data) == 4
+}
+
 type clientKeyExchangeMsg struct {
        raw             []byte;
        ciphertext      []byte;
index 5dafc388bcad221bf812c7aef11368dc68331c37..c580f65c6854e01778b7824b07b1557f2e15dcd3 100644 (file)
@@ -13,6 +13,8 @@ import (
 
 var tests = []interface{}{
        &clientHelloMsg{},
+       &serverHelloMsg{},
+       &certificateMsg{},
        &clientKeyExchangeMsg{},
        &finishedMsg{},
 }
@@ -59,6 +61,20 @@ func TestMarshalUnmarshal(t *testing.T) {
        }
 }
 
+func TestFuzz(t *testing.T) {
+       rand := rand.New(rand.NewSource(0));
+       for _, iface := range tests {
+               m := iface.(testMessage);
+
+               for j := 0; j < 1000; j++ {
+                       len := rand.Intn(100);
+                       bytes := randomBytes(len, rand);
+                       // This just looks for crashes due to bounds errors etc.
+                       m.unmarshal(bytes);
+               }
+       }
+}
+
 func randomBytes(n int, rand *rand.Rand) []byte {
        r := make([]byte, n);
        for i := 0; i < n; i++ {
@@ -82,9 +98,30 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.NewValue(m);
 }
 
+func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &serverHelloMsg{};
+       m.major = uint8(rand.Intn(256));
+       m.minor = uint8(rand.Intn(256));
+       m.random = randomBytes(32, rand);
+       m.sessionId = randomBytes(rand.Intn(32), rand);
+       m.cipherSuite = uint16(rand.Int31());
+       m.compressionMethod = uint8(rand.Intn(256));
+       return reflect.NewValue(m);
+}
+
+func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &certificateMsg{};
+       numCerts := rand.Intn(20);
+       m.certificates = make([][]byte, numCerts);
+       for i := 0; i < numCerts; i++ {
+               m.certificates[i] = randomBytes(rand.Intn(10)+1, rand)
+       }
+       return reflect.NewValue(m);
+}
+
 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &clientKeyExchangeMsg{};
-       m.ciphertext = randomBytes(rand.Intn(1000), rand);
+       m.ciphertext = randomBytes(rand.Intn(1000)+1, rand);
        return reflect.NewValue(m);
 }
 
index 0e04c42af2d292bf3fc126897537e76f47728408..2e7760365c42469d93bd7422f7179e1ebe6b3f56 100644 (file)
@@ -224,12 +224,12 @@ func (h *serverHandshake) error(e alertType) {
        if h.msgChan != nil {
                // If we didn't get an error from the processor, then we need
                // to tell it about the error.
-               h.controlChan <- ConnectionState{false, "", e};
-               close(h.controlChan);
                go func() {
                        for _ = range h.msgChan {
                        }
                }();
+               h.controlChan <- ConnectionState{false, "", e};
+               close(h.controlChan);
                h.writeChan <- alert{alertLevelError, e};
        }
 }
index b7edd9fd16c1352543a22b8f97659d612f8ad831..e356d67ca29fce0716785578e1d00e5d474596b5 100644 (file)
@@ -210,7 +210,7 @@ func (p *recordProcessor) processRecord(r *record) {
                        return;
                }
                p.recordRead = nil;
-               p.appData = r.payload;
+               p.appData = r.payload[0 : len(r.payload)-p.mac.Size()];
                p.appDataSend = p.appDataChan;
        default:
                p.error(alertUnexpectedMessage);
@@ -283,6 +283,12 @@ func parseHandshakeMsg(data []byte) (interface{}, bool) {
        switch data[0] {
        case typeClientHello:
                m = new(clientHelloMsg)
+       case typeServerHello:
+               m = new(serverHelloMsg)
+       case typeCertificate:
+               m = new(certificateMsg)
+       case typeServerHelloDone:
+               m = new(serverHelloDoneMsg)
        case typeClientKeyExchange:
                m = new(clientKeyExchangeMsg)
        default:
index a162487de2c15e3753abe080c16f5a18ad77aea7..c5a0f69d5d4cfe97759198184982fc9c4dd6ff5d 100644 (file)
@@ -112,9 +112,19 @@ func (tls *Conn) GetConnectionState() ConnectionState {
        return <-replyChan;
 }
 
+func (tls *Conn) WaitConnectionState() ConnectionState {
+       replyChan := make(chan ConnectionState);
+       tls.requestChan <- waitConnectionState{replyChan};
+       return <-replyChan;
+}
+
+type handshaker interface {
+       loop(writeChan chan<- interface{}, controlChan chan<- interface{}, msgChan <-chan interface{}, config *Config);
+}
+
 // Server establishes a secure connection over the given connection and acts
 // as a TLS server.
-func Server(conn net.Conn, config *Config) *Conn {
+func startTLSGoroutines(conn net.Conn, h handshaker, config *Config) *Conn {
        tls := new(Conn);
        tls.Conn = conn;
 
@@ -134,11 +144,19 @@ func Server(conn net.Conn, config *Config) *Conn {
        go new(recordWriter).loop(conn, writeChan, handshakeWriterChan);
        go recordReader(readerProcessorChan, conn);
        go new(recordProcessor).loop(readChan, requestChan, handshakeProcessorChan, readerProcessorChan, processorHandshakeChan);
-       go new(serverHandshake).loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config);
+       go h.loop(handshakeWriterChan, handshakeProcessorChan, processorHandshakeChan, config);
 
        return tls;
 }
 
+func Server(conn net.Conn, config *Config) *Conn {
+       return startTLSGoroutines(conn, new(serverHandshake), config)
+}
+
+func Client(conn net.Conn, config *Config) *Conn {
+       return startTLSGoroutines(conn, new(clientHandshake), config)
+}
+
 type Listener struct {
        listener        net.Listener;
        config          *Config;