]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: extensions and Next Protocol Negotiation
authorAdam Langley <agl@golang.org>
Wed, 23 Dec 2009 19:13:09 +0000 (11:13 -0800)
committerAdam Langley <agl@golang.org>
Wed, 23 Dec 2009 19:13:09 +0000 (11:13 -0800)
Add support for TLS extensions in general and Next Protocol
Negotiation in particular.

R=rsc
CC=golang-dev
https://golang.org/cl/181045

src/pkg/crypto/tls/common.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_messages.go
src/pkg/crypto/tls/handshake_messages_test.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/handshake_server_test.go
src/pkg/crypto/tls/record_process.go
src/pkg/crypto/tls/record_process_test.go
src/pkg/crypto/tls/record_read.go

index 51de53389af8b9f10763461805ae20e3deab4d93..8ef8b09d8b5640ae7bd823aa0a12c21c2844b7d3 100644 (file)
@@ -41,6 +41,7 @@ const (
        typeServerHelloDone   uint8 = 14
        typeClientKeyExchange uint8 = 16
        typeFinished          uint8 = 20
+       typeNextProtocol      uint8 = 67 // Not IANA assigned
 )
 
 // TLS cipher suites.
@@ -53,10 +54,17 @@ var (
        compressionNone uint8 = 0
 )
 
+// TLS extension numbers
+var (
+       extensionServerName   uint16 = 0
+       extensionNextProtoNeg uint16 = 13172 // not IANA assigned
+)
+
 type ConnectionState struct {
-       HandshakeComplete bool
-       CipherSuite       string
-       Error             alertType
+       HandshakeComplete  bool
+       CipherSuite        string
+       Error              alertType
+       NegotiatedProtocol string
 }
 
 // A Config structure is used to configure a TLS client or server. After one
@@ -68,6 +76,9 @@ type Config struct {
        Time         func() int64
        Certificates []Certificate
        RootCAs      *CASet
+       // NextProtos is a list of supported, application level protocols.
+       // Currently only server-side handling is supported.
+       NextProtos []string
 }
 
 type Certificate struct {
index 4e31e70941031d257ad7c844ecc236569d991a66..d07e2d89f6a3e4d53461ea09fb61ef2de936a74c 100644 (file)
@@ -184,7 +184,7 @@ func (h *clientHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
                return
        }
 
-       controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}
+       controlChan <- ConnectionState{HandshakeComplete: true, CipherSuite: "TLS_RSA_WITH_RC4_128_SHA"}
 
        // This should just block forever.
        _ = h.readHandshakeMsg()
@@ -218,7 +218,7 @@ func (h *clientHandshake) error(e alertType) {
                        for _ = range h.msgChan {
                        }
                }()
-               h.controlChan <- ConnectionState{false, "", e}
+               h.controlChan <- ConnectionState{Error: e}
                close(h.controlChan)
                h.writeChan <- alert{alertLevelError, e}
        }
index 2870969eb3241a5e0474e3281c2aabcefc96e3a4..10d2ba3e2675e9932c4b6b25e53c6f3d044e75ec 100644 (file)
@@ -4,6 +4,8 @@
 
 package tls
 
+import "strings"
+
 type clientHelloMsg struct {
        raw                []byte
        major, minor       uint8
@@ -11,6 +13,8 @@ type clientHelloMsg struct {
        sessionId          []byte
        cipherSuites       []uint16
        compressionMethods []uint8
+       nextProtoNeg       bool
+       serverName         string
 }
 
 func (m *clientHelloMsg) marshal() []byte {
@@ -19,6 +23,20 @@ func (m *clientHelloMsg) marshal() []byte {
        }
 
        length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
+       numExtensions := 0
+       extensionsLength := 0
+       if m.nextProtoNeg {
+               numExtensions++
+       }
+       if len(m.serverName) > 0 {
+               extensionsLength += 5 + len(m.serverName)
+               numExtensions++
+       }
+       if numExtensions > 0 {
+               extensionsLength += 4 * numExtensions
+               length += 2 + extensionsLength
+       }
+
        x := make([]byte, 4+length)
        x[0] = typeClientHello
        x[1] = uint8(length >> 16)
@@ -39,6 +57,53 @@ func (m *clientHelloMsg) marshal() []byte {
        z := y[2+len(m.cipherSuites)*2:]
        z[0] = uint8(len(m.compressionMethods))
        copy(z[1:], m.compressionMethods)
+
+       z = z[1+len(m.compressionMethods):]
+       if numExtensions > 0 {
+               z[0] = byte(extensionsLength >> 8)
+               z[1] = byte(extensionsLength)
+               z = z[2:]
+       }
+       if m.nextProtoNeg {
+               z[0] = byte(extensionNextProtoNeg >> 8)
+               z[1] = byte(extensionNextProtoNeg)
+               // The length is always 0
+               z = z[4:]
+       }
+       if len(m.serverName) > 0 {
+               z[0] = byte(extensionServerName >> 8)
+               z[1] = byte(extensionServerName)
+               l := len(m.serverName) + 5
+               z[2] = byte(l >> 8)
+               z[3] = byte(l)
+               z = z[4:]
+
+               // RFC 3546, section 3.1
+               //
+               // struct {
+               //     NameType name_type;
+               //     select (name_type) {
+               //         case host_name: HostName;
+               //     } name;
+               // } ServerName;
+               //
+               // enum {
+               //     host_name(0), (255)
+               // } NameType;
+               //
+               // opaque HostName<1..2^16-1>;
+               //
+               // struct {
+               //     ServerName server_name_list<1..2^16-1>
+               // } ServerNameList;
+
+               z[1] = 1
+               z[3] = byte(len(m.serverName) >> 8)
+               z[4] = byte(len(m.serverName))
+               copy(z[5:], strings.Bytes(m.serverName))
+               z = z[l:]
+       }
+
        m.raw = x
 
        return x
@@ -82,7 +147,68 @@ func (m *clientHelloMsg) unmarshal(data []byte) bool {
        }
        m.compressionMethods = data[1 : 1+compressionMethodsLen]
 
-       // A ClientHello may be following by trailing data: RFC 4346 section 7.4.1.2
+       data = data[1+compressionMethodsLen:]
+
+       m.nextProtoNeg = false
+       m.serverName = ""
+
+       if len(data) == 0 {
+               // ClientHello is optionally followed by extension data
+               return true
+       }
+       if len(data) < 2 {
+               return false
+       }
+
+       extensionsLength := int(data[0])<<8 | int(data[1])
+       data = data[2:]
+       if extensionsLength != len(data) {
+               return false
+       }
+
+       for len(data) != 0 {
+               if len(data) < 4 {
+                       return false
+               }
+               extension := uint16(data[0])<<8 | uint16(data[1])
+               length := int(data[2])<<8 | int(data[3])
+               data = data[4:]
+               if len(data) < length {
+                       return false
+               }
+
+               switch extension {
+               case extensionServerName:
+                       if length < 2 {
+                               return false
+                       }
+                       numNames := int(data[0])<<8 | int(data[1])
+                       d := data[2:]
+                       for i := 0; i < numNames; i++ {
+                               if len(d) < 3 {
+                                       return false
+                               }
+                               nameType := d[0]
+                               nameLen := int(d[1])<<8 | int(d[2])
+                               d = d[3:]
+                               if len(d) < nameLen {
+                                       return false
+                               }
+                               if nameType == 0 {
+                                       m.serverName = string(d[0:nameLen])
+                                       break
+                               }
+                               d = d[nameLen:]
+                       }
+               case extensionNextProtoNeg:
+                       if length > 0 {
+                               return false
+                       }
+                       m.nextProtoNeg = true
+               }
+               data = data[length:]
+       }
+
        return true
 }
 
@@ -93,6 +219,8 @@ type serverHelloMsg struct {
        sessionId         []byte
        cipherSuite       uint16
        compressionMethod uint8
+       nextProtoNeg      bool
+       nextProtos        []string
 }
 
 func (m *serverHelloMsg) marshal() []byte {
@@ -101,6 +229,23 @@ func (m *serverHelloMsg) marshal() []byte {
        }
 
        length := 38 + len(m.sessionId)
+       numExtensions := 0
+       extensionsLength := 0
+
+       nextProtoLen := 0
+       if m.nextProtoNeg {
+               numExtensions++
+               for _, v := range m.nextProtos {
+                       nextProtoLen += len(v)
+               }
+               nextProtoLen += len(m.nextProtos)
+               extensionsLength += nextProtoLen
+       }
+       if numExtensions > 0 {
+               extensionsLength += 4 * numExtensions
+               length += 2 + extensionsLength
+       }
+
        x := make([]byte, 4+length)
        x[0] = typeServerHello
        x[1] = uint8(length >> 16)
@@ -115,11 +260,49 @@ func (m *serverHelloMsg) marshal() []byte {
        z[0] = uint8(m.cipherSuite >> 8)
        z[1] = uint8(m.cipherSuite)
        z[2] = uint8(m.compressionMethod)
+
+       z = z[3:]
+       if numExtensions > 0 {
+               z[0] = byte(extensionsLength >> 8)
+               z[1] = byte(extensionsLength)
+               z = z[2:]
+       }
+       if m.nextProtoNeg {
+               z[0] = byte(extensionNextProtoNeg >> 8)
+               z[1] = byte(extensionNextProtoNeg)
+               z[2] = byte(nextProtoLen >> 8)
+               z[3] = byte(nextProtoLen)
+               z = z[4:]
+
+               for _, v := range m.nextProtos {
+                       l := len(v)
+                       if l > 255 {
+                               l = 255
+                       }
+                       z[0] = byte(l)
+                       copy(z[1:], strings.Bytes(v[0:l]))
+                       z = z[1+l:]
+               }
+       }
+
        m.raw = x
 
        return x
 }
 
+func append(slice []string, elem string) []string {
+       if len(slice) < cap(slice) {
+               slice = slice[0 : len(slice)+1]
+               slice[len(slice)-1] = elem
+               return slice
+       }
+
+       fresh := make([]string, len(slice)+1, cap(slice)*2+1)
+       copy(fresh, slice)
+       fresh[len(slice)] = elem
+       return fresh
+}
+
 func (m *serverHelloMsg) unmarshal(data []byte) bool {
        if len(data) < 42 {
                return false
@@ -139,8 +322,53 @@ func (m *serverHelloMsg) unmarshal(data []byte) bool {
        }
        m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
        m.compressionMethod = data[2]
+       data = data[3:]
+
+       m.nextProtoNeg = false
+       m.nextProtos = nil
+
+       if len(data) == 0 {
+               // ServerHello is optionally followed by extension data
+               return true
+       }
+       if len(data) < 2 {
+               return false
+       }
+
+       extensionsLength := int(data[0])<<8 | int(data[1])
+       data = data[2:]
+       if len(data) != extensionsLength {
+               return false
+       }
+
+       for len(data) != 0 {
+               if len(data) < 4 {
+                       return false
+               }
+               extension := uint16(data[0])<<8 | uint16(data[1])
+               length := int(data[2])<<8 | int(data[3])
+               data = data[4:]
+               if len(data) < length {
+                       return false
+               }
+
+               switch extension {
+               case extensionNextProtoNeg:
+                       m.nextProtoNeg = true
+                       d := data
+                       for len(d) > 0 {
+                               l := int(d[0])
+                               d = d[1:]
+                               if l == 0 || l > len(d) {
+                                       return false
+                               }
+                               m.nextProtos = append(m.nextProtos, string(d[0:l]))
+                               d = d[l:]
+                       }
+               }
+               data = data[length:]
+       }
 
-       // Trailing data is allowed because extensions may be present.
        return true
 }
 
@@ -295,3 +523,63 @@ func (m *finishedMsg) unmarshal(data []byte) bool {
        m.verifyData = data[4:]
        return true
 }
+
+type nextProtoMsg struct {
+       raw   []byte
+       proto string
+}
+
+func (m *nextProtoMsg) marshal() []byte {
+       if m.raw != nil {
+               return m.raw
+       }
+       l := len(m.proto)
+       if l > 255 {
+               l = 255
+       }
+
+       padding := 32 - (l+2)%32
+       length := l + padding + 2
+       x := make([]byte, length+4)
+       x[0] = typeNextProtocol
+       x[1] = uint8(length >> 16)
+       x[2] = uint8(length >> 8)
+       x[3] = uint8(length)
+
+       y := x[4:]
+       y[0] = byte(l)
+       copy(y[1:], strings.Bytes(m.proto[0:l]))
+       y = y[1+l:]
+       y[0] = byte(padding)
+
+       m.raw = x
+
+       return x
+}
+
+func (m *nextProtoMsg) unmarshal(data []byte) bool {
+       m.raw = data
+
+       if len(data) < 5 {
+               return false
+       }
+       data = data[4:]
+       protoLen := int(data[0])
+       data = data[1:]
+       if len(data) < protoLen {
+               return false
+       }
+       m.proto = string(data[0:protoLen])
+       data = data[protoLen:]
+
+       if len(data) < 1 {
+               return false
+       }
+       paddingLen := int(data[0])
+       data = data[1:]
+       if len(data) != paddingLen {
+               return false
+       }
+
+       return true
+}
index 4bfdd6c5f1d215b7285c5f7b15fe394589f0923e..3c5902e2458a65594c4207584cdbe5081a05acf9 100644 (file)
@@ -14,9 +14,11 @@ import (
 var tests = []interface{}{
        &clientHelloMsg{},
        &serverHelloMsg{},
+
        &certificateMsg{},
        &clientKeyExchangeMsg{},
        &finishedMsg{},
+       &nextProtoMsg{},
 }
 
 type testMessage interface {
@@ -40,21 +42,26 @@ func TestMarshalUnmarshal(t *testing.T) {
                        marshaled := m1.marshal()
                        m2 := iface.(testMessage)
                        if !m2.unmarshal(marshaled) {
-                               t.Errorf("#%d failed to unmarshal %#v", i, m1)
+                               t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
                                break
                        }
                        m2.marshal() // to fill any marshal cache in the message
 
                        if !reflect.DeepEqual(m1, m2) {
-                               t.Errorf("#%d got:%#v want:%#v", i, m1, m2)
+                               t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
                                break
                        }
 
-                       // Now check that all prefixes are invalid.
-                       for j := 0; j < len(marshaled); j++ {
-                               if m2.unmarshal(marshaled[0:j]) {
-                                       t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
-                                       break
+                       if i >= 2 {
+                               // The first two message types (ClientHello and
+                               // ServerHello) are allowed to have parsable
+                               // prefixes because the extension data is
+                               // optional.
+                               for j := 0; j < len(marshaled); j++ {
+                                       if m2.unmarshal(marshaled[0:j]) {
+                                               t.Errorf("#%d unmarshaled a prefix of length %d of %#v", i, j, m1)
+                                               break
+                                       }
                                }
                        }
                }
@@ -83,6 +90,11 @@ func randomBytes(n int, rand *rand.Rand) []byte {
        return r
 }
 
+func randomString(n int, rand *rand.Rand) string {
+       b := randomBytes(n, rand)
+       return string(b)
+}
+
 func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &clientHelloMsg{}
        m.major = uint8(rand.Intn(256))
@@ -94,6 +106,12 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
                m.cipherSuites[i] = uint16(rand.Int31())
        }
        m.compressionMethods = randomBytes(rand.Intn(63)+1, rand)
+       if rand.Intn(10) > 5 {
+               m.nextProtoNeg = true
+       }
+       if rand.Intn(10) > 5 {
+               m.serverName = randomString(rand.Intn(255), rand)
+       }
 
        return reflect.NewValue(m)
 }
@@ -106,6 +124,17 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m.sessionId = randomBytes(rand.Intn(32), rand)
        m.cipherSuite = uint16(rand.Int31())
        m.compressionMethod = uint8(rand.Intn(256))
+
+       if rand.Intn(10) > 5 {
+               m.nextProtoNeg = true
+
+               n := rand.Intn(10)
+               m.nextProtos = make([]string, n)
+               for i := 0; i < n; i++ {
+                       m.nextProtos[i] = randomString(20, rand)
+               }
+       }
+
        return reflect.NewValue(m)
 }
 
@@ -130,3 +159,9 @@ func (*finishedMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m.verifyData = randomBytes(12, rand)
        return reflect.NewValue(m)
 }
+
+func (*nextProtoMsg) Generate(rand *rand.Rand, size int) reflect.Value {
+       m := &nextProtoMsg{}
+       m.proto = randomString(rand.Intn(255), rand)
+       return reflect.NewValue(m)
+}
index 5314e5cd197eb7a448d5937f90818874c94bcb76..50854d1543bb0602067e8c787a769db1c8501f0b 100644 (file)
@@ -108,6 +108,10 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
                return
        }
        hello.compressionMethod = compressionNone
+       if clientHello.nextProtoNeg {
+               hello.nextProtoNeg = true
+               hello.nextProtos = config.NextProtos
+       }
 
        finishedHash.Write(hello.marshal())
        writeChan <- writerSetVersion{major, minor}
@@ -165,6 +169,17 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
        cipher, _ := rc4.NewCipher(clientKey)
        controlChan <- &newCipherSpec{cipher, hmac.New(sha1.New(), clientMAC)}
 
+       clientProtocol := ""
+       if hello.nextProtoNeg {
+               nextProto, ok := h.readHandshakeMsg().(*nextProtoMsg)
+               if !ok {
+                       h.error(alertUnexpectedMessage)
+                       return
+               }
+               finishedHash.Write(nextProto.marshal())
+               clientProtocol = nextProto.proto
+       }
+
        clientFinished, ok := h.readHandshakeMsg().(*finishedMsg)
        if !ok {
                h.error(alertUnexpectedMessage)
@@ -178,7 +193,7 @@ func (h *serverHandshake) loop(writeChan chan<- interface{}, controlChan chan<-
                return
        }
 
-       controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}
+       controlChan <- ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, clientProtocol}
 
        finishedHash.Write(clientFinished.marshal())
 
@@ -228,7 +243,7 @@ func (h *serverHandshake) error(e alertType) {
                        for _ = range h.msgChan {
                        }
                }()
-               h.controlChan <- ConnectionState{false, "", e}
+               h.controlChan <- ConnectionState{false, "", e, ""}
                close(h.controlChan)
                h.writeChan <- alert{alertLevelError, e}
        }
index 716098530c3c2f759d115967d5772d7324d13457..a580b14e3c5a8cbf9762e2f08160beac573bbb56 100644 (file)
@@ -51,7 +51,7 @@ func testClientHelloFailure(t *testing.T, clientHello interface{}, expectedAlert
        send := script.NewEvent("send", nil, script.Send{msgChan, clientHello})
        recvAlert := script.NewEvent("recv alert", []*script.Event{send}, script.Recv{writeChan, alert{alertLevelError, expectedAlert}})
        close1 := script.NewEvent("msgChan close", []*script.Event{recvAlert}, script.Closed{writeChan})
-       recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert}})
+       recvState := script.NewEvent("recv state", []*script.Event{send}, script.Recv{controlChan, ConnectionState{false, "", expectedAlert, ""}})
        close2 := script.NewEvent("controlChan close", []*script.Event{recvState}, script.Closed{controlChan})
 
        err := script.Perform(0, []*script.Event{send, recvAlert, close1, recvState, close2})
@@ -78,13 +78,13 @@ func TestRejectBadProtocolVersion(t *testing.T) {
 }
 
 func TestNoSuiteOverlap(t *testing.T) {
-       clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}}
+       clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{0xff00}, []uint8{0}, false, ""}
        testClientHelloFailure(t, clientHello, alertHandshakeFailure)
 
 }
 
 func TestNoCompressionOverlap(t *testing.T) {
-       clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}}
+       clientHello := &clientHelloMsg{nil, 3, 1, nil, nil, []uint16{TLS_RSA_WITH_RC4_128_SHA}, []uint8{0xff}, false, ""}
        testClientHelloFailure(t, clientHello, alertHandshakeFailure)
 }
 
@@ -165,7 +165,7 @@ func TestFullHandshake(t *testing.T) {
        defer close(msgChan)
 
        // The values for this test were obtained from running `gnutls-cli --insecure --debug 9`
-       clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}}
+       clientHello := &clientHelloMsg{fromHex("0100007603024aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b310000340033004500390088001600320044003800870013006600900091008f008e002f004100350084000a00050004008c008d008b008a01000019000900030200010000000e000c0000093132372e302e302e31"), 3, 2, fromHex("4aef7d77e4686d5dfd9d953dfe280788759ffd440867d687670216da45516b31"), nil, []uint16{0x33, 0x45, 0x39, 0x88, 0x16, 0x32, 0x44, 0x38, 0x87, 0x13, 0x66, 0x90, 0x91, 0x8f, 0x8e, 0x2f, 0x41, 0x35, 0x84, 0xa, 0x5, 0x4, 0x8c, 0x8d, 0x8b, 0x8a}, []uint8{0x0}, false, ""}
 
        sendHello := script.NewEvent("send hello", nil, script.Send{msgChan, clientHello})
        setVersion := script.NewEvent("set version", []*script.Event{sendHello}, script.Recv{writeChan, writerSetVersion{3, 2}})
@@ -183,7 +183,7 @@ func TestFullHandshake(t *testing.T) {
        sendFinished := script.NewEvent("send finished", []*script.Event{recvNCS}, script.Send{msgChan, finished})
        recvFinished := script.NewEvent("recv finished", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchFinished})
        setCipher := script.NewEvent("set cipher", []*script.Event{sendFinished}, script.RecvMatch{writeChan, matchSetCipher})
-       recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0}})
+       recvConnectionState := script.NewEvent("recv state", []*script.Event{sendFinished}, script.Recv{controlChan, ConnectionState{true, "TLS_RSA_WITH_RC4_128_SHA", 0, ""}})
 
        err := script.Perform(0, []*script.Event{sendHello, setVersion, recvHello, recvCert, recvDone, sendCKX, sendCCS, recvNCS, sendFinished, setCipher, recvConnectionState, recvFinished})
        if err != nil {
index ddeca0e2b53ae6437f46f230e684cfb28fea2194..77470f04bc5d6ef7c6252566cd711dcd413ae2c5 100644 (file)
@@ -291,6 +291,8 @@ func parseHandshakeMsg(data []byte) (interface{}, bool) {
                m = new(serverHelloDoneMsg)
        case typeClientKeyExchange:
                m = new(clientKeyExchangeMsg)
+       case typeNextProtocol:
+               m = new(nextProtoMsg)
        default:
                return nil, false
        }
index 65ce3eba95ae5ee76d4cb8749b88eb67a46cada7..fe001a2f9a8396a84e5aebb1826754b5ba807dc0 100644 (file)
@@ -36,7 +36,7 @@ func TestNullConnectionState(t *testing.T) {
        // Test a simple request for the connection state.
        replyChan := make(chan ConnectionState)
        sendReq := script.NewEvent("send request", nil, script.Send{requestChan, getConnectionState{replyChan}})
-       getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0}})
+       getReply := script.NewEvent("get reply", []*script.Event{sendReq}, script.Recv{replyChan, ConnectionState{false, "", 0, ""}})
 
        err := script.Perform(0, []*script.Event{sendReq, getReply})
        if err != nil {
@@ -55,9 +55,9 @@ func TestWaitConnectionState(t *testing.T) {
        sendReq := script.NewEvent("send request", nil, script.Send{requestChan, waitConnectionState{replyChan}})
        replyChan2 := make(chan ConnectionState)
        sendReq2 := script.NewEvent("send request 2", []*script.Event{sendReq}, script.Send{requestChan, getConnectionState{replyChan2}})
-       getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0}})
-       sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1}})
-       getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1}})
+       getReply2 := script.NewEvent("get reply 2", []*script.Event{sendReq2}, script.Recv{replyChan2, ConnectionState{false, "", 0, ""}})
+       sendState := script.NewEvent("send state", []*script.Event{getReply2}, script.Send{controlChan, ConnectionState{true, "test", 1, ""}})
+       getReply := script.NewEvent("get reply", []*script.Event{sendState}, script.Recv{replyChan, ConnectionState{true, "test", 1, ""}})
 
        err := script.Perform(0, []*script.Event{sendReq, sendReq2, getReply2, sendState, getReply})
        if err != nil {
@@ -108,7 +108,7 @@ func TestApplicationData(t *testing.T) {
        // Test that the application data is forwarded after a successful Finished message.
        send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeHandshake, 0, 0, fromHex("1400000c000000000000000000000000")}})
        recv1 := script.NewEvent("recv finished", []*script.Event{send1}, script.Recv{handshakeChan, &finishedMsg{fromHex("1400000c000000000000000000000000"), fromHex("000000000000000000000000")}})
-       send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0}})
+       send2 := script.NewEvent("send connState", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{true, "", 0, ""}})
        send3 := script.NewEvent("send 2", []*script.Event{send2}, script.Send{recordChan, &record{recordTypeApplicationData, 0, 0, fromHex("0102")}})
        recv2 := script.NewEvent("recv data", []*script.Event{send3}, script.Recv{appDataChan, []byte{0x01, 0x02}})
 
@@ -126,7 +126,7 @@ func TestInvalidChangeCipherSpec(t *testing.T) {
 
        send1 := script.NewEvent("send 1", nil, script.Send{recordChan, &record{recordTypeChangeCipherSpec, 0, 0, []byte{1}}})
        recv1 := script.NewEvent("recv 1", []*script.Event{send1}, script.Recv{handshakeChan, changeCipherSpec{}})
-       send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42}})
+       send2 := script.NewEvent("send 2", []*script.Event{recv1}, script.Send{controlChan, ConnectionState{false, "", 42, ""}})
        close := script.NewEvent("close 1", []*script.Event{send2}, script.Closed{appDataChan})
        close2 := script.NewEvent("close 2", []*script.Event{send2}, script.Closed{handshakeChan})
 
index 0ddd884a4eb44691c2118422b7a059deb04b4526..682fde8b66cebc63078acbb342454534981d51e1 100644 (file)
@@ -21,7 +21,7 @@ func recordReader(c chan<- *record, source io.Reader) {
 
        for {
                var header [5]byte
-               n, _ := buf.Read(header[0:])
+               n, _ := buf.Read(&header)
                if n != 5 {
                        return
                }