]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: fix handshake message test
authorRuss Cox <rsc@golang.org>
Mon, 14 Nov 2011 20:21:08 +0000 (15:21 -0500)
committerRuss Cox <rsc@golang.org>
Mon, 14 Nov 2011 20:21:08 +0000 (15:21 -0500)
This test breaks when I make reflect.DeepEqual
distinguish empty slices from nil slices.

R=agl
CC=golang-dev
https://golang.org/cl/5369110

src/pkg/crypto/tls/handshake_messages.go
src/pkg/crypto/tls/handshake_messages_test.go

index f11232d8ee57ef268bb96df0cac7c6cfe986f20d..5438e749ce8b8713a7955617a4444e3947835e83 100644 (file)
@@ -4,6 +4,8 @@
 
 package tls
 
+import "bytes"
+
 type clientHelloMsg struct {
        raw                []byte
        vers               uint16
@@ -18,6 +20,25 @@ type clientHelloMsg struct {
        supportedPoints    []uint8
 }
 
+func (m *clientHelloMsg) equal(i interface{}) bool {
+       m1, ok := i.(*clientHelloMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               m.vers == m1.vers &&
+               bytes.Equal(m.random, m1.random) &&
+               bytes.Equal(m.sessionId, m1.sessionId) &&
+               eqUint16s(m.cipherSuites, m1.cipherSuites) &&
+               bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
+               m.nextProtoNeg == m1.nextProtoNeg &&
+               m.serverName == m1.serverName &&
+               m.ocspStapling == m1.ocspStapling &&
+               eqUint16s(m.supportedCurves, m1.supportedCurves) &&
+               bytes.Equal(m.supportedPoints, m1.supportedPoints)
+}
+
 func (m *clientHelloMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -309,6 +330,23 @@ type serverHelloMsg struct {
        ocspStapling      bool
 }
 
+func (m *serverHelloMsg) equal(i interface{}) bool {
+       m1, ok := i.(*serverHelloMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               m.vers == m1.vers &&
+               bytes.Equal(m.random, m1.random) &&
+               bytes.Equal(m.sessionId, m1.sessionId) &&
+               m.cipherSuite == m1.cipherSuite &&
+               m.compressionMethod == m1.compressionMethod &&
+               m.nextProtoNeg == m1.nextProtoNeg &&
+               eqStrings(m.nextProtos, m1.nextProtos) &&
+               m.ocspStapling == m1.ocspStapling
+}
+
 func (m *serverHelloMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -463,6 +501,16 @@ type certificateMsg struct {
        certificates [][]byte
 }
 
+func (m *certificateMsg) equal(i interface{}) bool {
+       m1, ok := i.(*certificateMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               eqByteSlices(m.certificates, m1.certificates)
+}
+
 func (m *certificateMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -540,6 +588,16 @@ type serverKeyExchangeMsg struct {
        key []byte
 }
 
+func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
+       m1, ok := i.(*serverKeyExchangeMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               bytes.Equal(m.key, m1.key)
+}
+
 func (m *serverKeyExchangeMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -571,6 +629,17 @@ type certificateStatusMsg struct {
        response   []byte
 }
 
+func (m *certificateStatusMsg) equal(i interface{}) bool {
+       m1, ok := i.(*certificateStatusMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               m.statusType == m1.statusType &&
+               bytes.Equal(m.response, m1.response)
+}
+
 func (m *certificateStatusMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -622,6 +691,11 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
 
 type serverHelloDoneMsg struct{}
 
+func (m *serverHelloDoneMsg) equal(i interface{}) bool {
+       _, ok := i.(*serverHelloDoneMsg)
+       return ok
+}
+
 func (m *serverHelloDoneMsg) marshal() []byte {
        x := make([]byte, 4)
        x[0] = typeServerHelloDone
@@ -637,6 +711,16 @@ type clientKeyExchangeMsg struct {
        ciphertext []byte
 }
 
+func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
+       m1, ok := i.(*clientKeyExchangeMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               bytes.Equal(m.ciphertext, m1.ciphertext)
+}
+
 func (m *clientKeyExchangeMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -671,6 +755,16 @@ type finishedMsg struct {
        verifyData []byte
 }
 
+func (m *finishedMsg) equal(i interface{}) bool {
+       m1, ok := i.(*finishedMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               bytes.Equal(m.verifyData, m1.verifyData)
+}
+
 func (m *finishedMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -698,6 +792,16 @@ type nextProtoMsg struct {
        proto string
 }
 
+func (m *nextProtoMsg) equal(i interface{}) bool {
+       m1, ok := i.(*nextProtoMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               m.proto == m1.proto
+}
+
 func (m *nextProtoMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -759,6 +863,17 @@ type certificateRequestMsg struct {
        certificateAuthorities [][]byte
 }
 
+func (m *certificateRequestMsg) equal(i interface{}) bool {
+       m1, ok := i.(*certificateRequestMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
+               eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities)
+}
+
 func (m *certificateRequestMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -859,6 +974,16 @@ type certificateVerifyMsg struct {
        signature []byte
 }
 
+func (m *certificateVerifyMsg) equal(i interface{}) bool {
+       m1, ok := i.(*certificateVerifyMsg)
+       if !ok {
+               return false
+       }
+
+       return bytes.Equal(m.raw, m1.raw) &&
+               bytes.Equal(m.signature, m1.signature)
+}
+
 func (m *certificateVerifyMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -902,3 +1027,39 @@ func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
 
        return true
 }
+
+func eqUint16s(x, y []uint16) bool {
+       if len(x) != len(y) {
+               return false
+       }
+       for i, v := range x {
+               if y[i] != v {
+                       return false
+               }
+       }
+       return true
+}
+
+func eqStrings(x, y []string) bool {
+       if len(x) != len(y) {
+               return false
+       }
+       for i, v := range x {
+               if y[i] != v {
+                       return false
+               }
+       }
+       return true
+}
+
+func eqByteSlices(x, y [][]byte) bool {
+       if len(x) != len(y) {
+               return false
+       }
+       for i, v := range x {
+               if !bytes.Equal(v, y[i]) {
+                       return false
+               }
+       }
+       return true
+}
index 87e8f7e428d90e24eb6932fa620233b278bf8b05..e62a9d581b353b438d8289d8d4ce03bd375facad 100644 (file)
@@ -27,10 +27,12 @@ var tests = []interface{}{
 type testMessage interface {
        marshal() []byte
        unmarshal([]byte) bool
+       equal(interface{}) bool
 }
 
 func TestMarshalUnmarshal(t *testing.T) {
        rand := rand.New(rand.NewSource(0))
+
        for i, iface := range tests {
                ty := reflect.ValueOf(iface).Type()
 
@@ -54,7 +56,7 @@ func TestMarshalUnmarshal(t *testing.T) {
                        }
                        m2.marshal() // to fill any marshal cache in the message
 
-                       if !reflect.DeepEqual(m1, m2) {
+                       if !m1.equal(m2) {
                                t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
                                break
                        }