]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/tls: replace custom equal implementations with reflect.DeepEqual
authorFilippo Valsorda <filippo@golang.org>
Wed, 24 Oct 2018 17:16:04 +0000 (13:16 -0400)
committerFilippo Valsorda <filippo@golang.org>
Thu, 25 Oct 2018 19:07:36 +0000 (19:07 +0000)
The equal methods were only there for testing, and I remember regularly
getting them wrong while developing tls-tris. Replace them with simple
reflect.DeepEqual calls.

The only special thing that equal() would do is ignore the difference
between a nil and a zero-length slice. Fixed the Generate methods so
that they create the same value that unmarshal will decode. The
difference is not important: it wasn't tested, all checks are
"len(slice) > 0", and all cases in which presence matters are
accompanied by a boolean.

Change-Id: Iaabf56ea17c2406b5107c808c32f6c85b611aaa8
Reviewed-on: https://go-review.googlesource.com/c/144114
Run-TryBot: Filippo Valsorda <filippo@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Adam Langley <agl@golang.org>
src/crypto/tls/handshake_messages.go
src/crypto/tls/handshake_messages_test.go
src/crypto/tls/ticket.go

index 27004b2d698f937ba0c94789b3f1de4677b1edd7..c5d995060731da52f51ede9958d0cc87a0558586 100644 (file)
@@ -5,7 +5,6 @@
 package tls
 
 import (
-       "bytes"
        "strings"
 )
 
@@ -30,32 +29,6 @@ type clientHelloMsg struct {
        alpnProtocols                []string
 }
 
-func (m *clientHelloMsg) equal(i interface{}) bool {
-       m1, ok := i.(*clientHelloMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               m.vers == m1.vers &&
-               bytes.Equal(m.random, m1.random) &&
-               bytes.Equal(m.sessionId, m1.sessionId) &&
-               eqUint16s(m.cipherSuites, m1.cipherSuites) &&
-               bytes.Equal(m.compressionMethods, m1.compressionMethods) &&
-               m.nextProtoNeg == m1.nextProtoNeg &&
-               m.serverName == m1.serverName &&
-               m.ocspStapling == m1.ocspStapling &&
-               m.scts == m1.scts &&
-               eqCurveIDs(m.supportedCurves, m1.supportedCurves) &&
-               bytes.Equal(m.supportedPoints, m1.supportedPoints) &&
-               m.ticketSupported == m1.ticketSupported &&
-               bytes.Equal(m.sessionTicket, m1.sessionTicket) &&
-               eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms) &&
-               m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
-               bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
-               eqStrings(m.alpnProtocols, m1.alpnProtocols)
-}
-
 func (m *clientHelloMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -519,36 +492,6 @@ type serverHelloMsg struct {
        alpnProtocol                 string
 }
 
-func (m *serverHelloMsg) equal(i interface{}) bool {
-       m1, ok := i.(*serverHelloMsg)
-       if !ok {
-               return false
-       }
-
-       if len(m.scts) != len(m1.scts) {
-               return false
-       }
-       for i, sct := range m.scts {
-               if !bytes.Equal(sct, m1.scts[i]) {
-                       return false
-               }
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               m.vers == m1.vers &&
-               bytes.Equal(m.random, m1.random) &&
-               bytes.Equal(m.sessionId, m1.sessionId) &&
-               m.cipherSuite == m1.cipherSuite &&
-               m.compressionMethod == m1.compressionMethod &&
-               m.nextProtoNeg == m1.nextProtoNeg &&
-               eqStrings(m.nextProtos, m1.nextProtos) &&
-               m.ocspStapling == m1.ocspStapling &&
-               m.ticketSupported == m1.ticketSupported &&
-               m.secureRenegotiationSupported == m1.secureRenegotiationSupported &&
-               bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) &&
-               m.alpnProtocol == m1.alpnProtocol
-}
-
 func (m *serverHelloMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -838,16 +781,6 @@ type certificateMsg struct {
        certificates [][]byte
 }
 
-func (m *certificateMsg) equal(i interface{}) bool {
-       m1, ok := i.(*certificateMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               eqByteSlices(m.certificates, m1.certificates)
-}
-
 func (m *certificateMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -925,16 +858,6 @@ type serverKeyExchangeMsg struct {
        key []byte
 }
 
-func (m *serverKeyExchangeMsg) equal(i interface{}) bool {
-       m1, ok := i.(*serverKeyExchangeMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               bytes.Equal(m.key, m1.key)
-}
-
 func (m *serverKeyExchangeMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -966,17 +889,6 @@ type certificateStatusMsg struct {
        response   []byte
 }
 
-func (m *certificateStatusMsg) equal(i interface{}) bool {
-       m1, ok := i.(*certificateStatusMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               m.statusType == m1.statusType &&
-               bytes.Equal(m.response, m1.response)
-}
-
 func (m *certificateStatusMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -1028,11 +940,6 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
 
 type serverHelloDoneMsg struct{}
 
-func (m *serverHelloDoneMsg) equal(i interface{}) bool {
-       _, ok := i.(*serverHelloDoneMsg)
-       return ok
-}
-
 func (m *serverHelloDoneMsg) marshal() []byte {
        x := make([]byte, 4)
        x[0] = typeServerHelloDone
@@ -1048,16 +955,6 @@ type clientKeyExchangeMsg struct {
        ciphertext []byte
 }
 
-func (m *clientKeyExchangeMsg) equal(i interface{}) bool {
-       m1, ok := i.(*clientKeyExchangeMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               bytes.Equal(m.ciphertext, m1.ciphertext)
-}
-
 func (m *clientKeyExchangeMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -1092,16 +989,6 @@ type finishedMsg struct {
        verifyData []byte
 }
 
-func (m *finishedMsg) equal(i interface{}) bool {
-       m1, ok := i.(*finishedMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               bytes.Equal(m.verifyData, m1.verifyData)
-}
-
 func (m *finishedMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -1129,16 +1016,6 @@ type nextProtoMsg struct {
        proto string
 }
 
-func (m *nextProtoMsg) equal(i interface{}) bool {
-       m1, ok := i.(*nextProtoMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               m.proto == m1.proto
-}
-
 func (m *nextProtoMsg) marshal() []byte {
        if m.raw != nil {
                return m.raw
@@ -1206,18 +1083,6 @@ type certificateRequestMsg struct {
        certificateAuthorities       [][]byte
 }
 
-func (m *certificateRequestMsg) equal(i interface{}) bool {
-       m1, ok := i.(*certificateRequestMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               bytes.Equal(m.certificateTypes, m1.certificateTypes) &&
-               eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) &&
-               eqSignatureAlgorithms(m.supportedSignatureAlgorithms, m1.supportedSignatureAlgorithms)
-}
-
 func (m *certificateRequestMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -1356,18 +1221,6 @@ type certificateVerifyMsg struct {
        signature           []byte
 }
 
-func (m *certificateVerifyMsg) equal(i interface{}) bool {
-       m1, ok := i.(*certificateVerifyMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               m.hasSignatureAndHash == m1.hasSignatureAndHash &&
-               m.signatureAlgorithm == m1.signatureAlgorithm &&
-               bytes.Equal(m.signature, m1.signature)
-}
-
 func (m *certificateVerifyMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -1436,16 +1289,6 @@ type newSessionTicketMsg struct {
        ticket []byte
 }
 
-func (m *newSessionTicketMsg) equal(i interface{}) bool {
-       m1, ok := i.(*newSessionTicketMsg)
-       if !ok {
-               return false
-       }
-
-       return bytes.Equal(m.raw, m1.raw) &&
-               bytes.Equal(m.ticket, m1.ticket)
-}
-
 func (m *newSessionTicketMsg) marshal() (x []byte) {
        if m.raw != nil {
                return m.raw
@@ -1500,63 +1343,3 @@ func (*helloRequestMsg) marshal() []byte {
 func (*helloRequestMsg) unmarshal(data []byte) bool {
        return len(data) == 4
 }
-
-func eqUint16s(x, y []uint16) bool {
-       if len(x) != len(y) {
-               return false
-       }
-       for i, v := range x {
-               if y[i] != v {
-                       return false
-               }
-       }
-       return true
-}
-
-func eqCurveIDs(x, y []CurveID) bool {
-       if len(x) != len(y) {
-               return false
-       }
-       for i, v := range x {
-               if y[i] != v {
-                       return false
-               }
-       }
-       return true
-}
-
-func eqStrings(x, y []string) bool {
-       if len(x) != len(y) {
-               return false
-       }
-       for i, v := range x {
-               if y[i] != v {
-                       return false
-               }
-       }
-       return true
-}
-
-func eqByteSlices(x, y [][]byte) bool {
-       if len(x) != len(y) {
-               return false
-       }
-       for i, v := range x {
-               if !bytes.Equal(v, y[i]) {
-                       return false
-               }
-       }
-       return true
-}
-
-func eqSignatureAlgorithms(x, y []SignatureScheme) bool {
-       if len(x) != len(y) {
-               return false
-       }
-       for i, v := range x {
-               if v != y[i] {
-                       return false
-               }
-       }
-       return true
-}
index 52c5d30e8fa3c222c924e060e7d6e455e44ba1a7..c8cc0d6c5a0f69c39300f5a5c053bf3a33e814bc 100644 (file)
@@ -28,12 +28,6 @@ var tests = []interface{}{
        &sessionState{},
 }
 
-type testMessage interface {
-       marshal() []byte
-       unmarshal([]byte) bool
-       equal(interface{}) bool
-}
-
 func TestMarshalUnmarshal(t *testing.T) {
        rand := rand.New(rand.NewSource(0))
 
@@ -51,16 +45,16 @@ func TestMarshalUnmarshal(t *testing.T) {
                                break
                        }
 
-                       m1 := v.Interface().(testMessage)
+                       m1 := v.Interface().(handshakeMessage)
                        marshaled := m1.marshal()
-                       m2 := iface.(testMessage)
+                       m2 := iface.(handshakeMessage)
                        if !m2.unmarshal(marshaled) {
                                t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
                                break
                        }
                        m2.marshal() // to fill any marshal cache in the message
 
-                       if !m1.equal(m2) {
+                       if !reflect.DeepEqual(m1, m2) {
                                t.Errorf("#%d got:%#v want:%#v %x", i, m2, m1, marshaled)
                                break
                        }
@@ -85,7 +79,7 @@ func TestMarshalUnmarshal(t *testing.T) {
 func TestFuzz(t *testing.T) {
        rand := rand.New(rand.NewSource(0))
        for _, iface := range tests {
-               m := iface.(testMessage)
+               m := iface.(handshakeMessage)
 
                for j := 0; j < 1000; j++ {
                        len := rand.Intn(100)
@@ -142,14 +136,15 @@ func (*clientHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
                m.ticketSupported = true
                if rand.Intn(10) > 5 {
                        m.sessionTicket = randomBytes(rand.Intn(300), rand)
+               } else {
+                       m.sessionTicket = make([]byte, 0)
                }
        }
        if rand.Intn(10) > 5 {
                m.supportedSignatureAlgorithms = supportedSignatureAlgorithms
        }
-       m.alpnProtocols = make([]string, rand.Intn(5))
-       for i := range m.alpnProtocols {
-               m.alpnProtocols[i] = randomString(rand.Intn(20)+1, rand)
+       for i := 0; i < rand.Intn(5); i++ {
+               m.alpnProtocols = append(m.alpnProtocols, randomString(rand.Intn(20)+1, rand))
        }
        if rand.Intn(10) > 5 {
                m.scts = true
@@ -168,11 +163,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
 
        if rand.Intn(10) > 5 {
                m.nextProtoNeg = true
-
-               n := rand.Intn(10)
-               m.nextProtos = make([]string, n)
-               for i := 0; i < n; i++ {
-                       m.nextProtos[i] = randomString(20, rand)
+               for i := 0; i < rand.Intn(10); i++ {
+                       m.nextProtos = append(m.nextProtos, randomString(20, rand))
                }
        }
 
@@ -184,12 +176,8 @@ func (*serverHelloMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        }
        m.alpnProtocol = randomString(rand.Intn(32)+1, rand)
 
-       if rand.Intn(10) > 5 {
-               numSCTs := rand.Intn(4)
-               m.scts = make([][]byte, numSCTs)
-               for i := range m.scts {
-                       m.scts[i] = randomBytes(rand.Intn(500)+1, rand)
-               }
+       for i := 0; i < rand.Intn(4); i++ {
+               m.scts = append(m.scts, randomBytes(rand.Intn(500)+1, rand))
        }
 
        return reflect.ValueOf(m)
@@ -208,10 +196,8 @@ func (*certificateMsg) Generate(rand *rand.Rand, size int) reflect.Value {
 func (*certificateRequestMsg) Generate(rand *rand.Rand, size int) reflect.Value {
        m := &certificateRequestMsg{}
        m.certificateTypes = randomBytes(rand.Intn(5)+1, rand)
-       numCAs := rand.Intn(100)
-       m.certificateAuthorities = make([][]byte, numCAs)
-       for i := 0; i < numCAs; i++ {
-               m.certificateAuthorities[i] = randomBytes(rand.Intn(15)+1, rand)
+       for i := 0; i < rand.Intn(100); i++ {
+               m.certificateAuthorities = append(m.certificateAuthorities, randomBytes(rand.Intn(15)+1, rand))
        }
        return reflect.ValueOf(m)
 }
index 3e7aa93c82a7f1eb0605b02967ad1eda22fa926d..c1077e5ab2368963e17a927b5cc0eab56b0054cc 100644 (file)
@@ -27,31 +27,6 @@ type sessionState struct {
        usedOldKey bool
 }
 
-func (s *sessionState) equal(i interface{}) bool {
-       s1, ok := i.(*sessionState)
-       if !ok {
-               return false
-       }
-
-       if s.vers != s1.vers ||
-               s.cipherSuite != s1.cipherSuite ||
-               !bytes.Equal(s.masterSecret, s1.masterSecret) {
-               return false
-       }
-
-       if len(s.certificates) != len(s1.certificates) {
-               return false
-       }
-
-       for i := range s.certificates {
-               if !bytes.Equal(s.certificates[i], s1.certificates[i]) {
-                       return false
-               }
-       }
-
-       return true
-}
-
 func (s *sessionState) marshal() []byte {
        length := 2 + 2 + 2 + len(s.masterSecret) + 2
        for _, cert := range s.certificates {