]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: add client OCSP stapling support.
authorAdam Langley <agl@golang.org>
Wed, 14 Jul 2010 14:40:15 +0000 (10:40 -0400)
committerAdam Langley <agl@golang.org>
Wed, 14 Jul 2010 14:40:15 +0000 (10:40 -0400)
R=r, rsc
CC=golang-dev
https://golang.org/cl/1750042

src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/conn.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_messages.go
src/pkg/crypto/tls/handshake_messages_test.go
src/pkg/crypto/tls/handshake_server_test.go

index 56c22cf7d8ea76082293d3d395661859ef7e880a..7c6940aa32be3ede8a5e2581f8be2e97804f80e4 100644 (file)
@@ -38,6 +38,7 @@ const (
        typeClientHello       uint8 = 1
        typeServerHello       uint8 = 2
        typeCertificate       uint8 = 11
+       typeCertificateStatus uint8 = 22
        typeServerHelloDone   uint8 = 14
        typeClientKeyExchange uint8 = 16
        typeFinished          uint8 = 20
@@ -45,25 +46,30 @@ const (
 )
 
 // TLS cipher suites.
-var (
+const (
        TLS_RSA_WITH_RC4_128_SHA uint16 = 5
 )
 
 // TLS compression types.
-var (
+const (
        compressionNone uint8 = 0
 )
 
 // TLS extension numbers
 var (
-       extensionServerName   uint16 = 0
-       extensionNextProtoNeg uint16 = 13172 // not IANA assigned
+       extensionServerName    uint16 = 0
+       extensionStatusRequest uint16 = 5
+       extensionNextProtoNeg  uint16 = 13172 // not IANA assigned
+)
+
+// TLS CertificateStatusType (RFC 3546)
+const (
+       statusTypeOCSP uint8 = 1
 )
 
 type ConnectionState struct {
        HandshakeComplete  bool
-       CipherSuite        string
-       Error              alert
+       CipherSuite        uint16
        NegotiatedProtocol string
 }
 
index 0798e26f65b5039ec079800102c06f826affb5e2..aa224e49d2275b66fcf978723828f271f69f6990 100644 (file)
@@ -26,6 +26,7 @@ type Conn struct {
        config            *Config    // configuration passed to constructor
        handshakeComplete bool
        cipherSuite       uint16
+       ocspResponse      []byte // stapled OCSP response
 
        clientProtocol string
 
@@ -531,6 +532,8 @@ func (c *Conn) readHandshake() (interface{}, os.Error) {
                m = new(serverHelloMsg)
        case typeCertificate:
                m = new(certificateMsg)
+       case typeCertificateStatus:
+               m = new(certificateStatusMsg)
        case typeServerHelloDone:
                m = new(serverHelloDoneMsg)
        case typeClientKeyExchange:
@@ -625,11 +628,26 @@ func (c *Conn) Handshake() os.Error {
        return c.serverHandshake()
 }
 
-// If c is a TLS server, ClientConnection returns the protocol
-// requested by the client during the TLS handshake.
-// Handshake must have been called already.
-func (c *Conn) ClientConnection() string {
+// ConnectionState returns basic TLS details about the connection.
+func (c *Conn) ConnectionState() ConnectionState {
        c.handshakeMutex.Lock()
        defer c.handshakeMutex.Unlock()
-       return c.clientProtocol
+
+       var state ConnectionState
+       state.HandshakeComplete = c.handshakeComplete
+       if c.handshakeComplete {
+               state.NegotiatedProtocol = c.clientProtocol
+               state.CipherSuite = c.cipherSuite
+       }
+
+       return state
+}
+
+// OCSPResponse returns the stapled OCSP response from the TLS server, if
+// any. (Only valid for client connections.)
+func (c *Conn) OCSPResponse() []byte {
+       c.handshakeMutex.Lock()
+       defer c.handshakeMutex.Unlock()
+
+       return c.ocspResponse
 }
index dd3009802db447074d6c65a24e492dd721d01fc8..b3b597327fc90a5c6bbe4ec6571f8c91a8aaa478 100644 (file)
@@ -18,21 +18,24 @@ import (
 func (c *Conn) clientHandshake() os.Error {
        finishedHash := newFinishedHash()
 
-       config := defaultConfig()
+       if c.config == nil {
+               c.config = defaultConfig()
+       }
 
        hello := &clientHelloMsg{
                vers:               maxVersion,
                cipherSuites:       []uint16{TLS_RSA_WITH_RC4_128_SHA},
                compressionMethods: []uint8{compressionNone},
                random:             make([]byte, 32),
+               ocspStapling:       true,
        }
 
-       t := uint32(config.Time())
+       t := uint32(c.config.Time())
        hello.random[0] = byte(t >> 24)
        hello.random[1] = byte(t >> 16)
        hello.random[2] = byte(t >> 8)
        hello.random[3] = byte(t)
-       _, err := io.ReadFull(config.Rand, hello.random[4:])
+       _, err := io.ReadFull(c.config.Rand, hello.random[4:])
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
@@ -89,8 +92,8 @@ func (c *Conn) clientHandshake() os.Error {
        }
 
        // TODO(rsc): Find certificates for OS X 10.6.
-       if false && config.RootCAs != nil {
-               root := config.RootCAs.FindParent(certs[len(certs)-1])
+       if false && c.config.RootCAs != nil {
+               root := c.config.RootCAs.FindParent(certs[len(certs)-1])
                if root == nil {
                        return c.sendAlert(alertBadCertificate)
                }
@@ -104,6 +107,22 @@ func (c *Conn) clientHandshake() os.Error {
                return c.sendAlert(alertUnsupportedCertificate)
        }
 
+       if serverHello.certStatus {
+               msg, err = c.readHandshake()
+               if err != nil {
+                       return err
+               }
+               cs, ok := msg.(*certificateStatusMsg)
+               if !ok {
+                       return c.sendAlert(alertUnexpectedMessage)
+               }
+               finishedHash.Write(cs.marshal())
+
+               if cs.statusType == statusTypeOCSP {
+                       c.ocspResponse = cs.response
+               }
+       }
+
        msg, err = c.readHandshake()
        if err != nil {
                return err
@@ -118,12 +137,12 @@ func (c *Conn) clientHandshake() os.Error {
        preMasterSecret := make([]byte, 48)
        preMasterSecret[0] = byte(hello.vers >> 8)
        preMasterSecret[1] = byte(hello.vers)
-       _, err = io.ReadFull(config.Rand, preMasterSecret[2:])
+       _, err = io.ReadFull(c.config.Rand, preMasterSecret[2:])
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
 
-       ckx.ciphertext, err = rsa.EncryptPKCS1v15(config.Rand, pub, preMasterSecret)
+       ckx.ciphertext, err = rsa.EncryptPKCS1v15(c.config.Rand, pub, preMasterSecret)
        if err != nil {
                return c.sendAlert(alertInternalError)
        }
index f0a48c8630ab7ceb696fa7d07dd69f54f4087a94..13c05fe574d67ed20ad66604974d4649cd0dad8a 100644 (file)
@@ -13,6 +13,7 @@ type clientHelloMsg struct {
        compressionMethods []uint8
        nextProtoNeg       bool
        serverName         string
+       ocspStapling       bool
 }
 
 func (m *clientHelloMsg) marshal() []byte {
@@ -26,6 +27,10 @@ func (m *clientHelloMsg) marshal() []byte {
        if m.nextProtoNeg {
                numExtensions++
        }
+       if m.ocspStapling {
+               extensionsLength += 1 + 2 + 2
+               numExtensions++
+       }
        if len(m.serverName) > 0 {
                extensionsLength += 5 + len(m.serverName)
                numExtensions++
@@ -101,6 +106,16 @@ func (m *clientHelloMsg) marshal() []byte {
                copy(z[5:], []byte(m.serverName))
                z = z[l:]
        }
+       if m.ocspStapling {
+               // RFC 4366, section 3.6
+               z[0] = byte(extensionStatusRequest >> 8)
+               z[1] = byte(extensionStatusRequest)
+               z[2] = 0
+               z[3] = 5
+               z[4] = 1 // OCSP type
+               // Two zero valued uint16s for the two lengths.
+               z = z[9:]
+       }
 
        m.raw = x
 
@@ -148,6 +163,7 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
 
        m.nextProtoNeg = false
        m.serverName = ""
+       m.ocspStapling = false
 
        if len(data) == 0 {
                // ClientHello is optionally followed by extension data
@@ -202,6 +218,8 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
                                return false
                        }
                        m.nextProtoNeg = true
+               case extensionStatusRequest:
+                       m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
                }
                data = data[length:]
        }
@@ -218,6 +236,7 @@ type serverHelloMsg struct {
        compressionMethod uint8
        nextProtoNeg      bool
        nextProtos        []string
+       certStatus        bool
 }
 
 func (m *serverHelloMsg) marshal() []byte {
@@ -238,6 +257,9 @@ func (m *serverHelloMsg) marshal() []byte {
                nextProtoLen += len(m.nextProtos)
                extensionsLength += nextProtoLen
        }
+       if m.certStatus {
+               numExtensions++
+       }
        if numExtensions > 0 {
                extensionsLength += 4 * numExtensions
                length += 2 + extensionsLength
@@ -281,6 +303,11 @@ func (m *serverHelloMsg) marshal() []byte {
                        z = z[1+l:]
                }
        }
+       if m.certStatus {
+               z[0] = byte(extensionStatusRequest >> 8)
+               z[1] = byte(extensionStatusRequest)
+               z = z[4:]
+       }
 
        m.raw = x
 
@@ -322,6 +349,7 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
 
        m.nextProtoNeg = false
        m.nextProtos = nil
+       m.certStatus = false
 
        if len(data) == 0 {
                // ServerHello is optionally followed by extension data
@@ -361,6 +389,11 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
                                m.nextProtos = append(m.nextProtos, string(d[0:l]))
                                d = d[l:]
                        }
+               case extensionStatusRequest:
+                       if length > 0 {
+                               return false
+                       }
+                       m.certStatus = true
                }
                data = data[length:]
        }
@@ -445,6 +478,61 @@ func (m *certificateMsg) unmarshal(data []byte) bool {
        return true
 }
 
+type certificateStatusMsg struct {
+       raw        []byte
+       statusType uint8
+       response   []byte
+}
+
+func (m *certificateStatusMsg) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+
+       var x []byte
+       if m.statusType == statusTypeOCSP {
+               x = make([]byte, 4+4+len(m.response))
+               x[0] = typeCertificateStatus
+               l := len(m.response) + 4
+               x[1] = byte(l >> 16)
+               x[2] = byte(l >> 8)
+               x[3] = byte(l)
+               x[4] = statusTypeOCSP
+
+               l -= 4
+               x[5] = byte(l >> 16)
+               x[6] = byte(l >> 8)
+               x[7] = byte(l)
+               copy(x[8:], m.response)
+       } else {
+               x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType}
+       }
+
+       m.raw = x
+       return x
+}
+
+func (m *certificateStatusMsg) unmarshal(data []byte) bool {
+       m.raw = data
+       if len(data) < 5 {
+               return false
+       }
+       m.statusType = data[4]
+
+       m.response = nil
+       if m.statusType == statusTypeOCSP {
+               if len(data) < 8 {
+                       return false
+               }
+               respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7])
+               if uint32(len(data)) != 4+4+respLen {
+                       return false
+               }
+               m.response = data[8:]
+       }
+       return true
+}
+
 type serverHelloDoneMsg struct{}
 
 func (m *serverHelloDoneMsg) marshal() []byte {
index 2e422cc6a007cc70852e7a867211d780c5723e34..274e16f9b5dbf8bf5581a30b89b0ac96c337eca4 100644 (file)
@@ -16,6 +16,7 @@ var tests = []interface{}{
        &serverHelloMsg{},
 
        &certificateMsg{},
+       &certificateStatusMsg{},
        &clientKeyExchangeMsg{},
        &finishedMsg{},
        &nextProtoMsg{},
@@ -111,6 +112,7 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        if rand.Intn(10) > 5 {
                m.serverName = randomString(rand.Intn(255), rand)
        }
+       m.ocspStapling = rand.Intn(10) > 5
 
        return reflect.NewValue(m)
 }
@@ -146,6 +148,17 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        return reflect.NewValue(m)
 }
 
+func (*certificateStatusMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &certificateStatusMsg{}
+       if rand.Intn(10) > 5 {
+               m.statusType = statusTypeOCSP
+               m.response = randomBytes(rand.Intn(10)+1, rand)
+       } else {
+               m.statusType = 42
+       }
+       return reflect.NewValue(m)
+}
+
 func (*clientKeyExchangeMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &clientKeyExchangeMsg{}
        m.ciphertext = randomBytes(rand.Intn(1000)+1, rand)
index d31dc497e37b50fda460f4d5b6843f524c4eac91..c1a72fce27b65bfbe326da7e9f1664cadf5d7ae8 100644 (file)
@@ -71,13 +71,13 @@ func TestRejectBadProtocolVersion(t *testing.T) {
 }
 
 func TestNoSuiteOverlap(t *testing.T) {
-       clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
+       clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{0xff00}, []uint8{0}, false, "", false}
        testClientHelloFailure(t, clientHello, alertHandshakeFailure)
 
 }
 
 func TestNoCompressionOverlap(t *testing.T) {
-       clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
+       clientHello := &clientHelloMsg{nil, 0x0301, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, "", false}
        testClientHelloFailure(t, clientHello, alertHandshakeFailure)
 }