package tls
import (
+ "fmt"
+ "golang_org/x/crypto/cryptobyte"
"strings"
)
+// The marshalingFunction type is an adapter to allow the use of ordinary
+// functions as cryptobyte.MarshalingValue.
+type marshalingFunction func(b *cryptobyte.Builder) error
+
+func (f marshalingFunction) Marshal(b *cryptobyte.Builder) error {
+ return f(b)
+}
+
+// addBytesWithLength appends a sequence of bytes to the cryptobyte.Builder. If
+// the length of the sequence is not the value specified, it produces an error.
+func addBytesWithLength(b *cryptobyte.Builder, v []byte, n int) {
+ b.AddValue(marshalingFunction(func(b *cryptobyte.Builder) error {
+ if len(v) != n {
+ return fmt.Errorf("invalid value length: expected %d, got %d", n, len(v))
+ }
+ b.AddBytes(v)
+ return nil
+ }))
+}
+
+// readUint8LengthPrefixed acts like s.ReadUint8LengthPrefixed, but targets a
+// []byte instead of a cryptobyte.String.
+func readUint8LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
+ return s.ReadUint8LengthPrefixed((*cryptobyte.String)(out))
+}
+
+// readUint16LengthPrefixed acts like s.ReadUint16LengthPrefixed, but targets a
+// []byte instead of a cryptobyte.String.
+func readUint16LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
+ return s.ReadUint16LengthPrefixed((*cryptobyte.String)(out))
+}
+
+// readUint24LengthPrefixed acts like s.ReadUint24LengthPrefixed, but targets a
+// []byte instead of a cryptobyte.String.
+func readUint24LengthPrefixed(s *cryptobyte.String, out *[]byte) bool {
+ return s.ReadUint24LengthPrefixed((*cryptobyte.String)(out))
+}
+
type clientHelloMsg struct {
raw []byte
vers uint16
return m.raw
}
- length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods)
- numExtensions := 0
- extensionsLength := 0
- if m.nextProtoNeg {
- numExtensions++
- }
- if m.ocspStapling {
- extensionsLength += 1 + 2 + 2
- numExtensions++
- }
- if len(m.serverName) > 0 {
- extensionsLength += 5 + len(m.serverName)
- numExtensions++
- }
- if len(m.supportedCurves) > 0 {
- extensionsLength += 2 + 2*len(m.supportedCurves)
- numExtensions++
- }
- if len(m.supportedPoints) > 0 {
- extensionsLength += 1 + len(m.supportedPoints)
- numExtensions++
- }
- if m.ticketSupported {
- extensionsLength += len(m.sessionTicket)
- numExtensions++
- }
- if len(m.supportedSignatureAlgorithms) > 0 {
- extensionsLength += 2 + 2*len(m.supportedSignatureAlgorithms)
- numExtensions++
- }
- if m.secureRenegotiationSupported {
- extensionsLength += 1 + len(m.secureRenegotiation)
- numExtensions++
- }
- if len(m.alpnProtocols) > 0 {
- extensionsLength += 2
- for _, s := range m.alpnProtocols {
- if l := len(s); l == 0 || l > 255 {
- panic("invalid ALPN protocol")
+ var b cryptobyte.Builder
+ b.AddUint8(typeClientHello)
+ b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16(m.vers)
+ addBytesWithLength(b, m.random, 32)
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.sessionId)
+ })
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ for _, suite := range m.cipherSuites {
+ b.AddUint16(suite)
+ }
+ })
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.compressionMethods)
+ })
+
+ // If extensions aren't present, omit them.
+ var extensionsPresent bool
+ bWithoutExtensions := *b
+
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ if m.nextProtoNeg {
+ // draft-agl-tls-nextprotoneg-04
+ b.AddUint16(extensionNextProtoNeg)
+ b.AddUint16(0) // empty extension_data
+ }
+ if len(m.serverName) > 0 {
+ // RFC 6066, Section 3
+ b.AddUint16(extensionServerName)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint8(0) // name_type = host_name
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes([]byte(m.serverName))
+ })
+ })
+ })
+ }
+ if m.ocspStapling {
+ // RFC 4366, Section 3.6
+ b.AddUint16(extensionStatusRequest)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint8(1) // status_type = ocsp
+ b.AddUint16(0) // empty responder_id_list
+ b.AddUint16(0) // empty request_extensions
+ })
+ }
+ if len(m.supportedCurves) > 0 {
+ // RFC 4492, Section 5.1.1
+ b.AddUint16(extensionSupportedCurves)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ for _, curve := range m.supportedCurves {
+ b.AddUint16(uint16(curve))
+ }
+ })
+ })
+ }
+ if len(m.supportedPoints) > 0 {
+ // RFC 4492, Section 5.1.2
+ b.AddUint16(extensionSupportedPoints)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.supportedPoints)
+ })
+ })
+ }
+ if m.ticketSupported {
+ // RFC 5077, Section 3.2
+ b.AddUint16(extensionSessionTicket)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.sessionTicket)
+ })
+ }
+ if len(m.supportedSignatureAlgorithms) > 0 {
+ // RFC 5246, Section 7.4.1.4.1
+ b.AddUint16(extensionSignatureAlgorithms)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ for _, sigAlgo := range m.supportedSignatureAlgorithms {
+ b.AddUint16(uint16(sigAlgo))
+ }
+ })
+ })
+ }
+ if m.secureRenegotiationSupported {
+ // RFC 5746, Section 3.2
+ b.AddUint16(extensionRenegotiationInfo)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.secureRenegotiation)
+ })
+ })
+ }
+ if len(m.alpnProtocols) > 0 {
+ // RFC 7301, Section 3.1
+ b.AddUint16(extensionALPN)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ for _, proto := range m.alpnProtocols {
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes([]byte(proto))
+ })
+ }
+ })
+ })
+ }
+ if m.scts {
+ // RFC 6962, Section 3.3.1
+ b.AddUint16(extensionSCT)
+ b.AddUint16(0) // empty extension_data
}
- extensionsLength++
- extensionsLength += len(s)
- }
- numExtensions++
- }
- if m.scts {
- numExtensions++
- }
- if numExtensions > 0 {
- extensionsLength += 4 * numExtensions
- length += 2 + extensionsLength
- }
-
- x := make([]byte, 4+length)
- x[0] = typeClientHello
- x[1] = uint8(length >> 16)
- x[2] = uint8(length >> 8)
- x[3] = uint8(length)
- x[4] = uint8(m.vers >> 8)
- x[5] = uint8(m.vers)
- copy(x[6:38], m.random)
- x[38] = uint8(len(m.sessionId))
- copy(x[39:39+len(m.sessionId)], m.sessionId)
- y := x[39+len(m.sessionId):]
- y[0] = uint8(len(m.cipherSuites) >> 7)
- y[1] = uint8(len(m.cipherSuites) << 1)
- for i, suite := range m.cipherSuites {
- y[2+i*2] = uint8(suite >> 8)
- y[3+i*2] = uint8(suite)
- }
- z := y[2+len(m.cipherSuites)*2:]
- z[0] = uint8(len(m.compressionMethods))
- copy(z[1:], m.compressionMethods)
-
- z = z[1+len(m.compressionMethods):]
- if numExtensions > 0 {
- z[0] = byte(extensionsLength >> 8)
- z[1] = byte(extensionsLength)
- z = z[2:]
- }
- if m.nextProtoNeg {
- z[0] = byte(extensionNextProtoNeg >> 8)
- z[1] = byte(extensionNextProtoNeg & 0xff)
- // The length is always 0
- z = z[4:]
- }
- if len(m.serverName) > 0 {
- z[0] = byte(extensionServerName >> 8)
- z[1] = byte(extensionServerName & 0xff)
- l := len(m.serverName) + 5
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- z = z[4:]
-
- // RFC 3546, Section 3.1
- //
- // struct {
- // NameType name_type;
- // select (name_type) {
- // case host_name: HostName;
- // } name;
- // } ServerName;
- //
- // enum {
- // host_name(0), (255)
- // } NameType;
- //
- // opaque HostName<1..2^16-1>;
- //
- // struct {
- // ServerName server_name_list<1..2^16-1>
- // } ServerNameList;
-
- z[0] = byte((len(m.serverName) + 3) >> 8)
- z[1] = byte(len(m.serverName) + 3)
- z[3] = byte(len(m.serverName) >> 8)
- z[4] = byte(len(m.serverName))
- copy(z[5:], []byte(m.serverName))
- z = z[l:]
- }
- if m.ocspStapling {
- // RFC 4366, Section 3.6
- z[0] = byte(extensionStatusRequest >> 8)
- z[1] = byte(extensionStatusRequest)
- z[2] = 0
- z[3] = 5
- z[4] = 1 // OCSP type
- // Two zero valued uint16s for the two lengths.
- z = z[9:]
- }
- if len(m.supportedCurves) > 0 {
- // RFC 4492, Section 5.5.1
- z[0] = byte(extensionSupportedCurves >> 8)
- z[1] = byte(extensionSupportedCurves)
- l := 2 + 2*len(m.supportedCurves)
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- l -= 2
- z[4] = byte(l >> 8)
- z[5] = byte(l)
- z = z[6:]
- for _, curve := range m.supportedCurves {
- z[0] = byte(curve >> 8)
- z[1] = byte(curve)
- z = z[2:]
- }
- }
- if len(m.supportedPoints) > 0 {
- // RFC 4492, Section 5.5.2
- z[0] = byte(extensionSupportedPoints >> 8)
- z[1] = byte(extensionSupportedPoints)
- l := 1 + len(m.supportedPoints)
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- l--
- z[4] = byte(l)
- z = z[5:]
- for _, pointFormat := range m.supportedPoints {
- z[0] = pointFormat
- z = z[1:]
- }
- }
- if m.ticketSupported {
- // RFC 5077, Section 3.2
- z[0] = byte(extensionSessionTicket >> 8)
- z[1] = byte(extensionSessionTicket)
- l := len(m.sessionTicket)
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- z = z[4:]
- copy(z, m.sessionTicket)
- z = z[len(m.sessionTicket):]
- }
- if len(m.supportedSignatureAlgorithms) > 0 {
- // RFC 5246, Section 7.4.1.4.1
- z[0] = byte(extensionSignatureAlgorithms >> 8)
- z[1] = byte(extensionSignatureAlgorithms)
- l := 2 + 2*len(m.supportedSignatureAlgorithms)
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- z = z[4:]
-
- l -= 2
- z[0] = byte(l >> 8)
- z[1] = byte(l)
- z = z[2:]
- for _, sigAlgo := range m.supportedSignatureAlgorithms {
- z[0] = byte(sigAlgo >> 8)
- z[1] = byte(sigAlgo)
- z = z[2:]
- }
- }
- if m.secureRenegotiationSupported {
- z[0] = byte(extensionRenegotiationInfo >> 8)
- z[1] = byte(extensionRenegotiationInfo & 0xff)
- z[2] = 0
- z[3] = byte(len(m.secureRenegotiation) + 1)
- z[4] = byte(len(m.secureRenegotiation))
- z = z[5:]
- copy(z, m.secureRenegotiation)
- z = z[len(m.secureRenegotiation):]
- }
- if len(m.alpnProtocols) > 0 {
- z[0] = byte(extensionALPN >> 8)
- z[1] = byte(extensionALPN & 0xff)
- lengths := z[2:]
- z = z[6:]
-
- stringsLength := 0
- for _, s := range m.alpnProtocols {
- l := len(s)
- z[0] = byte(l)
- copy(z[1:], s)
- z = z[1+l:]
- stringsLength += 1 + l
- }
- lengths[2] = byte(stringsLength >> 8)
- lengths[3] = byte(stringsLength)
- stringsLength += 2
- lengths[0] = byte(stringsLength >> 8)
- lengths[1] = byte(stringsLength)
- }
- if m.scts {
- // RFC 6962, Section 3.3.1
- z[0] = byte(extensionSCT >> 8)
- z[1] = byte(extensionSCT)
- // zero uint16 for the zero-length extension_data
- z = z[4:]
- }
+ extensionsPresent = len(b.BytesOrPanic()) > 2
+ })
- m.raw = x
+ if !extensionsPresent {
+ *b = bWithoutExtensions
+ }
+ })
- return x
+ m.raw = b.BytesOrPanic()
+ return m.raw
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
- if len(data) < 42 {
- return false
- }
- m.raw = data
- m.vers = uint16(data[4])<<8 | uint16(data[5])
- m.random = data[6:38]
- sessionIdLen := int(data[38])
- if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
- return false
- }
- m.sessionId = data[39 : 39+sessionIdLen]
- data = data[39+sessionIdLen:]
- if len(data) < 2 {
+ *m = clientHelloMsg{raw: data}
+ s := cryptobyte.String(data)
+
+ if !s.Skip(4) || // message type and uint24 length field
+ !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
+ !readUint8LengthPrefixed(&s, &m.sessionId) {
return false
}
- // cipherSuiteLen is the number of bytes of cipher suite numbers. Since
- // they are uint16s, the number must be even.
- cipherSuiteLen := int(data[0])<<8 | int(data[1])
- if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen {
+
+ var cipherSuites cryptobyte.String
+ if !s.ReadUint16LengthPrefixed(&cipherSuites) {
return false
}
- numCipherSuites := cipherSuiteLen / 2
- m.cipherSuites = make([]uint16, numCipherSuites)
- for i := 0; i < numCipherSuites; i++ {
- m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i])
- if m.cipherSuites[i] == scsvRenegotiation {
+ m.cipherSuites = []uint16{}
+ m.secureRenegotiationSupported = false
+ for !cipherSuites.Empty() {
+ var suite uint16
+ if !cipherSuites.ReadUint16(&suite) {
+ return false
+ }
+ if suite == scsvRenegotiation {
m.secureRenegotiationSupported = true
}
+ m.cipherSuites = append(m.cipherSuites, suite)
}
- data = data[2+cipherSuiteLen:]
- if len(data) < 1 {
- return false
- }
- compressionMethodsLen := int(data[0])
- if len(data) < 1+compressionMethodsLen {
+
+ if !readUint8LengthPrefixed(&s, &m.compressionMethods) {
return false
}
- m.compressionMethods = data[1 : 1+compressionMethodsLen]
-
- data = data[1+compressionMethodsLen:]
-
- m.nextProtoNeg = false
- m.serverName = ""
- m.ocspStapling = false
- m.ticketSupported = false
- m.sessionTicket = nil
- m.supportedSignatureAlgorithms = nil
- m.alpnProtocols = nil
- m.scts = false
- if len(data) == 0 {
+ if s.Empty() {
// ClientHello is optionally followed by extension data
return true
}
- if len(data) < 2 {
- return false
- }
- extensionsLength := int(data[0])<<8 | int(data[1])
- data = data[2:]
- if extensionsLength != len(data) {
+ var extensions cryptobyte.String
+ if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
- for len(data) != 0 {
- if len(data) < 4 {
- return false
- }
- extension := uint16(data[0])<<8 | uint16(data[1])
- length := int(data[2])<<8 | int(data[3])
- data = data[4:]
- if len(data) < length {
+ for !extensions.Empty() {
+ var extension uint16
+ var extData cryptobyte.String
+ if !extensions.ReadUint16(&extension) ||
+ !extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionServerName:
- d := data[:length]
- if len(d) < 2 {
+ // RFC 6066, Section 3
+ var nameList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
return false
}
- namesLen := int(d[0])<<8 | int(d[1])
- d = d[2:]
- if len(d) != namesLen {
- return false
- }
- for len(d) > 0 {
- if len(d) < 3 {
+ for !nameList.Empty() {
+ var nameType uint8
+ var serverName cryptobyte.String
+ if !nameList.ReadUint8(&nameType) ||
+ !nameList.ReadUint16LengthPrefixed(&serverName) ||
+ serverName.Empty() {
return false
}
- nameType := d[0]
- nameLen := int(d[1])<<8 | int(d[2])
- d = d[3:]
- if len(d) < nameLen {
+ if nameType != 0 {
+ continue
+ }
+ if len(m.serverName) != 0 {
+ // Multiple names of the same name_type are prohibited.
return false
}
- if nameType == 0 {
- m.serverName = string(d[:nameLen])
- // An SNI value may not include a trailing dot.
- // See RFC 6066, Section 3.
- if strings.HasSuffix(m.serverName, ".") {
- return false
- }
- break
+ m.serverName = string(serverName)
+ // An SNI value may not include a trailing dot.
+ if strings.HasSuffix(m.serverName, ".") {
+ return false
}
- d = d[nameLen:]
}
case extensionNextProtoNeg:
- if length > 0 {
- return false
- }
+ // draft-agl-tls-nextprotoneg-04
m.nextProtoNeg = true
case extensionStatusRequest:
- m.ocspStapling = length > 0 && data[0] == statusTypeOCSP
- case extensionSupportedCurves:
- // RFC 4492, Section 5.5.1
- if length < 2 {
+ // RFC 4366, Section 3.6
+ var statusType uint8
+ var ignored cryptobyte.String
+ if !extData.ReadUint8(&statusType) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) {
return false
}
- l := int(data[0])<<8 | int(data[1])
- if l%2 == 1 || length != l+2 {
+ m.ocspStapling = statusType == statusTypeOCSP
+ case extensionSupportedCurves:
+ // RFC 4492, Section 5.1.1
+ var curves cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&curves) || curves.Empty() {
return false
}
- numCurves := l / 2
- m.supportedCurves = make([]CurveID, numCurves)
- d := data[2:]
- for i := 0; i < numCurves; i++ {
- m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1])
- d = d[2:]
+ for !curves.Empty() {
+ var curve uint16
+ if !curves.ReadUint16(&curve) {
+ return false
+ }
+ m.supportedCurves = append(m.supportedCurves, CurveID(curve))
}
case extensionSupportedPoints:
- // RFC 4492, Section 5.5.2
- if length < 1 {
- return false
- }
- l := int(data[0])
- if length != l+1 {
+ // RFC 4492, Section 5.1.2
+ if !readUint8LengthPrefixed(&extData, &m.supportedPoints) ||
+ len(m.supportedPoints) == 0 {
return false
}
- m.supportedPoints = make([]uint8, l)
- copy(m.supportedPoints, data[1:])
case extensionSessionTicket:
// RFC 5077, Section 3.2
m.ticketSupported = true
- m.sessionTicket = data[:length]
+ extData.ReadBytes(&m.sessionTicket, len(extData))
case extensionSignatureAlgorithms:
// RFC 5246, Section 7.4.1.4.1
- if length < 2 || length&1 != 0 {
+ var sigAndAlgs cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
return false
}
- l := int(data[0])<<8 | int(data[1])
- if l != length-2 {
- return false
- }
- n := l / 2
- d := data[2:]
- m.supportedSignatureAlgorithms = make([]SignatureScheme, n)
- for i := range m.supportedSignatureAlgorithms {
- m.supportedSignatureAlgorithms[i] = SignatureScheme(d[0])<<8 | SignatureScheme(d[1])
- d = d[2:]
+ for !sigAndAlgs.Empty() {
+ var sigAndAlg uint16
+ if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+ return false
+ }
+ m.supportedSignatureAlgorithms = append(
+ m.supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
}
case extensionRenegotiationInfo:
- if length == 0 {
+ // RFC 5746, Section 3.2
+ if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
return false
}
- d := data[:length]
- l := int(d[0])
- d = d[1:]
- if l != len(d) {
- return false
- }
-
- m.secureRenegotiation = d
m.secureRenegotiationSupported = true
case extensionALPN:
- if length < 2 {
+ // RFC 7301, Section 3.1
+ var protoList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
- l := int(data[0])<<8 | int(data[1])
- if l != length-2 {
- return false
- }
- d := data[2:length]
- for len(d) != 0 {
- stringLen := int(d[0])
- d = d[1:]
- if stringLen == 0 || stringLen > len(d) {
+ for !protoList.Empty() {
+ var proto cryptobyte.String
+ if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
return false
}
- m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen]))
- d = d[stringLen:]
+ m.alpnProtocols = append(m.alpnProtocols, string(proto))
}
case extensionSCT:
+ // RFC 6962, Section 3.3.1
m.scts = true
- if length != 0 {
- return false
- }
+ default:
+ // Ignore unknown extensions.
+ continue
+ }
+
+ if !extData.Empty() {
+ return false
}
- data = data[length:]
}
return true
return m.raw
}
- length := 38 + len(m.sessionId)
- numExtensions := 0
- extensionsLength := 0
-
- nextProtoLen := 0
- if m.nextProtoNeg {
- numExtensions++
- for _, v := range m.nextProtos {
- nextProtoLen += len(v)
- }
- nextProtoLen += len(m.nextProtos)
- extensionsLength += nextProtoLen
- }
- if m.ocspStapling {
- numExtensions++
- }
- if m.ticketSupported {
- numExtensions++
- }
- if m.secureRenegotiationSupported {
- extensionsLength += 1 + len(m.secureRenegotiation)
- numExtensions++
- }
- if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
- if alpnLen >= 256 {
- panic("invalid ALPN protocol")
- }
- extensionsLength += 2 + 1 + alpnLen
- numExtensions++
- }
- sctLen := 0
- if len(m.scts) > 0 {
- for _, sct := range m.scts {
- sctLen += len(sct) + 2
- }
- extensionsLength += 2 + sctLen
- numExtensions++
- }
+ var b cryptobyte.Builder
+ b.AddUint8(typeServerHello)
+ b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16(m.vers)
+ addBytesWithLength(b, m.random, 32)
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.sessionId)
+ })
+ b.AddUint16(m.cipherSuite)
+ b.AddUint8(m.compressionMethod)
+
+ // If extensions aren't present, omit them.
+ var extensionsPresent bool
+ bWithoutExtensions := *b
+
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ if m.nextProtoNeg {
+ b.AddUint16(extensionNextProtoNeg)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ for _, proto := range m.nextProtos {
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes([]byte(proto))
+ })
+ }
+ })
+ }
+ if m.ocspStapling {
+ b.AddUint16(extensionStatusRequest)
+ b.AddUint16(0) // empty extension_data
+ }
+ if m.ticketSupported {
+ b.AddUint16(extensionSessionTicket)
+ b.AddUint16(0) // empty extension_data
+ }
+ if m.secureRenegotiationSupported {
+ b.AddUint16(extensionRenegotiationInfo)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.secureRenegotiation)
+ })
+ })
+ }
+ if len(m.alpnProtocol) > 0 {
+ b.AddUint16(extensionALPN)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes([]byte(m.alpnProtocol))
+ })
+ })
+ })
+ }
+ if len(m.scts) > 0 {
+ b.AddUint16(extensionSCT)
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ for _, sct := range m.scts {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(sct)
+ })
+ }
+ })
+ })
+ }
- if numExtensions > 0 {
- extensionsLength += 4 * numExtensions
- length += 2 + extensionsLength
- }
+ extensionsPresent = len(b.BytesOrPanic()) > 2
+ })
- x := make([]byte, 4+length)
- x[0] = typeServerHello
- x[1] = uint8(length >> 16)
- x[2] = uint8(length >> 8)
- x[3] = uint8(length)
- x[4] = uint8(m.vers >> 8)
- x[5] = uint8(m.vers)
- copy(x[6:38], m.random)
- x[38] = uint8(len(m.sessionId))
- copy(x[39:39+len(m.sessionId)], m.sessionId)
- z := x[39+len(m.sessionId):]
- z[0] = uint8(m.cipherSuite >> 8)
- z[1] = uint8(m.cipherSuite)
- z[2] = m.compressionMethod
-
- z = z[3:]
- if numExtensions > 0 {
- z[0] = byte(extensionsLength >> 8)
- z[1] = byte(extensionsLength)
- z = z[2:]
- }
- if m.nextProtoNeg {
- z[0] = byte(extensionNextProtoNeg >> 8)
- z[1] = byte(extensionNextProtoNeg & 0xff)
- z[2] = byte(nextProtoLen >> 8)
- z[3] = byte(nextProtoLen)
- z = z[4:]
-
- for _, v := range m.nextProtos {
- l := len(v)
- if l > 255 {
- l = 255
- }
- z[0] = byte(l)
- copy(z[1:], []byte(v[0:l]))
- z = z[1+l:]
+ if !extensionsPresent {
+ *b = bWithoutExtensions
}
- }
- if m.ocspStapling {
- z[0] = byte(extensionStatusRequest >> 8)
- z[1] = byte(extensionStatusRequest)
- z = z[4:]
- }
- if m.ticketSupported {
- z[0] = byte(extensionSessionTicket >> 8)
- z[1] = byte(extensionSessionTicket)
- z = z[4:]
- }
- if m.secureRenegotiationSupported {
- z[0] = byte(extensionRenegotiationInfo >> 8)
- z[1] = byte(extensionRenegotiationInfo & 0xff)
- z[2] = 0
- z[3] = byte(len(m.secureRenegotiation) + 1)
- z[4] = byte(len(m.secureRenegotiation))
- z = z[5:]
- copy(z, m.secureRenegotiation)
- z = z[len(m.secureRenegotiation):]
- }
- if alpnLen := len(m.alpnProtocol); alpnLen > 0 {
- z[0] = byte(extensionALPN >> 8)
- z[1] = byte(extensionALPN & 0xff)
- l := 2 + 1 + alpnLen
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- l -= 2
- z[4] = byte(l >> 8)
- z[5] = byte(l)
- l -= 1
- z[6] = byte(l)
- copy(z[7:], []byte(m.alpnProtocol))
- z = z[7+alpnLen:]
- }
- if sctLen > 0 {
- z[0] = byte(extensionSCT >> 8)
- z[1] = byte(extensionSCT)
- l := sctLen + 2
- z[2] = byte(l >> 8)
- z[3] = byte(l)
- z[4] = byte(sctLen >> 8)
- z[5] = byte(sctLen)
-
- z = z[6:]
- for _, sct := range m.scts {
- z[0] = byte(len(sct) >> 8)
- z[1] = byte(len(sct))
- copy(z[2:], sct)
- z = z[len(sct)+2:]
- }
- }
+ })
- m.raw = x
-
- return x
+ m.raw = b.BytesOrPanic()
+ return m.raw
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
- if len(data) < 42 {
+ *m = serverHelloMsg{raw: data}
+ s := cryptobyte.String(data)
+
+ if !s.Skip(4) || // message type and uint24 length field
+ !s.ReadUint16(&m.vers) || !s.ReadBytes(&m.random, 32) ||
+ !readUint8LengthPrefixed(&s, &m.sessionId) ||
+ !s.ReadUint16(&m.cipherSuite) ||
+ !s.ReadUint8(&m.compressionMethod) {
return false
}
- m.raw = data
- m.vers = uint16(data[4])<<8 | uint16(data[5])
- m.random = data[6:38]
- sessionIdLen := int(data[38])
- if sessionIdLen > 32 || len(data) < 39+sessionIdLen {
- return false
- }
- m.sessionId = data[39 : 39+sessionIdLen]
- data = data[39+sessionIdLen:]
- if len(data) < 3 {
- return false
- }
- m.cipherSuite = uint16(data[0])<<8 | uint16(data[1])
- m.compressionMethod = data[2]
- data = data[3:]
-
- m.nextProtoNeg = false
- m.nextProtos = nil
- m.ocspStapling = false
- m.scts = nil
- m.ticketSupported = false
- m.alpnProtocol = ""
-
- if len(data) == 0 {
+
+ if s.Empty() {
// ServerHello is optionally followed by extension data
return true
}
- if len(data) < 2 {
- return false
- }
- extensionsLength := int(data[0])<<8 | int(data[1])
- data = data[2:]
- if len(data) != extensionsLength {
+ var extensions cryptobyte.String
+ if !s.ReadUint16LengthPrefixed(&extensions) || !s.Empty() {
return false
}
- for len(data) != 0 {
- if len(data) < 4 {
- return false
- }
- extension := uint16(data[0])<<8 | uint16(data[1])
- length := int(data[2])<<8 | int(data[3])
- data = data[4:]
- if len(data) < length {
+ for !extensions.Empty() {
+ var extension uint16
+ var extData cryptobyte.String
+ if !extensions.ReadUint16(&extension) ||
+ !extensions.ReadUint16LengthPrefixed(&extData) {
return false
}
switch extension {
case extensionNextProtoNeg:
m.nextProtoNeg = true
- d := data[:length]
- for len(d) > 0 {
- l := int(d[0])
- d = d[1:]
- if l == 0 || l > len(d) {
+ for !extData.Empty() {
+ var proto cryptobyte.String
+ if !extData.ReadUint8LengthPrefixed(&proto) ||
+ proto.Empty() {
return false
}
- m.nextProtos = append(m.nextProtos, string(d[:l]))
- d = d[l:]
+ m.nextProtos = append(m.nextProtos, string(proto))
}
case extensionStatusRequest:
- if length > 0 {
- return false
- }
m.ocspStapling = true
case extensionSessionTicket:
- if length > 0 {
- return false
- }
m.ticketSupported = true
case extensionRenegotiationInfo:
- if length == 0 {
+ if !readUint8LengthPrefixed(&extData, &m.secureRenegotiation) {
return false
}
- d := data[:length]
- l := int(d[0])
- d = d[1:]
- if l != len(d) {
- return false
- }
-
- m.secureRenegotiation = d
m.secureRenegotiationSupported = true
case extensionALPN:
- d := data[:length]
- if len(d) < 3 {
+ var protoList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
return false
}
- l := int(d[0])<<8 | int(d[1])
- if l != len(d)-2 {
+ var proto cryptobyte.String
+ if !protoList.ReadUint8LengthPrefixed(&proto) ||
+ proto.Empty() || !protoList.Empty() {
return false
}
- d = d[2:]
- l = int(d[0])
- if l != len(d)-1 {
- return false
- }
- d = d[1:]
- if len(d) == 0 {
- // ALPN protocols must not be empty.
- return false
- }
- m.alpnProtocol = string(d)
+ m.alpnProtocol = string(proto)
case extensionSCT:
- d := data[:length]
-
- if len(d) < 2 {
- return false
- }
- l := int(d[0])<<8 | int(d[1])
- d = d[2:]
- if len(d) != l || l == 0 {
+ var sctList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&sctList) || sctList.Empty() {
return false
}
-
- m.scts = make([][]byte, 0, 3)
- for len(d) != 0 {
- if len(d) < 2 {
- return false
- }
- sctLen := int(d[0])<<8 | int(d[1])
- d = d[2:]
- if sctLen == 0 || len(d) < sctLen {
+ for !sctList.Empty() {
+ var sct []byte
+ if !readUint16LengthPrefixed(&sctList, &sct) ||
+ len(sct) == 0 {
return false
}
- m.scts = append(m.scts, d[:sctLen])
- d = d[sctLen:]
+ m.scts = append(m.scts, sct)
}
+ default:
+ // Ignore unknown extensions.
+ continue
+ }
+
+ if !extData.Empty() {
+ return false
}
- data = data[length:]
}
return true
verifyData []byte
}
-func (m *finishedMsg) marshal() (x []byte) {
+func (m *finishedMsg) marshal() []byte {
if m.raw != nil {
return m.raw
}
- x = make([]byte, 4+len(m.verifyData))
- x[0] = typeFinished
- x[3] = byte(len(m.verifyData))
- copy(x[4:], m.verifyData)
- m.raw = x
- return
+ var b cryptobyte.Builder
+ b.AddUint8(typeFinished)
+ b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.verifyData)
+ })
+
+ m.raw = b.BytesOrPanic()
+ return m.raw
}
func (m *finishedMsg) unmarshal(data []byte) bool {
m.raw = data
- if len(data) < 4 {
- return false
- }
- m.verifyData = data[4:]
- return true
+ s := cryptobyte.String(data)
+ return s.Skip(1) &&
+ readUint24LengthPrefixed(&s, &m.verifyData) &&
+ s.Empty()
}
type nextProtoMsg struct {
type certificateRequestMsg struct {
raw []byte
- // hasSignatureAndHash indicates whether this message includes a list
- // of signature and hash functions. This change was introduced with TLS
- // 1.2.
- hasSignatureAndHash bool
+ // hasSignatureAlgorithm indicates whether this message includes a list of
+ // supported signature algorithms. This change was introduced with TLS 1.2.
+ hasSignatureAlgorithm bool
certificateTypes []byte
supportedSignatureAlgorithms []SignatureScheme
}
length += casLength
- if m.hasSignatureAndHash {
+ if m.hasSignatureAlgorithm {
length += 2 + 2*len(m.supportedSignatureAlgorithms)
}
copy(x[5:], m.certificateTypes)
y := x[5+len(m.certificateTypes):]
- if m.hasSignatureAndHash {
+ if m.hasSignatureAlgorithm {
n := len(m.supportedSignatureAlgorithms) * 2
y[0] = uint8(n >> 8)
y[1] = uint8(n)
data = data[numCertTypes:]
- if m.hasSignatureAndHash {
+ if m.hasSignatureAlgorithm {
if len(data) < 2 {
return false
}
}
type certificateVerifyMsg struct {
- raw []byte
- hasSignatureAndHash bool
- signatureAlgorithm SignatureScheme
- signature []byte
+ raw []byte
+ hasSignatureAlgorithm bool // format change introduced in TLS 1.2
+ signatureAlgorithm SignatureScheme
+ signature []byte
}
func (m *certificateVerifyMsg) marshal() (x []byte) {
return m.raw
}
- // See RFC 4346, Section 7.4.8.
- siglength := len(m.signature)
- length := 2 + siglength
- if m.hasSignatureAndHash {
- length += 2
- }
- x = make([]byte, 4+length)
- x[0] = typeCertificateVerify
- x[1] = uint8(length >> 16)
- x[2] = uint8(length >> 8)
- x[3] = uint8(length)
- y := x[4:]
- if m.hasSignatureAndHash {
- y[0] = uint8(m.signatureAlgorithm >> 8)
- y[1] = uint8(m.signatureAlgorithm)
- y = y[2:]
- }
- y[0] = uint8(siglength >> 8)
- y[1] = uint8(siglength)
- copy(y[2:], m.signature)
-
- m.raw = x
+ var b cryptobyte.Builder
+ b.AddUint8(typeCertificateVerify)
+ b.AddUint24LengthPrefixed(func(b *cryptobyte.Builder) {
+ if m.hasSignatureAlgorithm {
+ b.AddUint16(uint16(m.signatureAlgorithm))
+ }
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(m.signature)
+ })
+ })
- return
+ m.raw = b.BytesOrPanic()
+ return m.raw
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
m.raw = data
+ s := cryptobyte.String(data)
- if len(data) < 6 {
- return false
- }
-
- length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3])
- if uint32(len(data))-4 != length {
- return false
- }
-
- data = data[4:]
- if m.hasSignatureAndHash {
- m.signatureAlgorithm = SignatureScheme(data[0])<<8 | SignatureScheme(data[1])
- data = data[2:]
- }
-
- if len(data) < 2 {
+ if !s.Skip(4) { // message type and uint24 length field
return false
}
- siglength := int(data[0])<<8 + int(data[1])
- data = data[2:]
- if len(data) != siglength {
- return false
+ if m.hasSignatureAlgorithm {
+ if !s.ReadUint16((*uint16)(&m.signatureAlgorithm)) {
+ return false
+ }
}
-
- m.signature = data
-
- return true
+ return readUint16LengthPrefixed(&s, &m.signature) && s.Empty()
}
type newSessionTicketMsg struct {