]> Cypherpunks repositories - gostls13.git/commitdiff
crypto: allocate less.
authorAdam Langley <agl@golang.org>
Tue, 6 Dec 2011 23:25:14 +0000 (18:25 -0500)
committerAdam Langley <agl@golang.org>
Tue, 6 Dec 2011 23:25:14 +0000 (18:25 -0500)
The code in hash functions themselves could write directly into the
output buffer for a savings of about 50ns. But it's a little ugly so I
wasted a copy.

R=bradfitz
CC=golang-dev
https://golang.org/cl/5440111

12 files changed:
src/pkg/crypto/hmac/hmac.go
src/pkg/crypto/md5/md5.go
src/pkg/crypto/openpgp/s2k/s2k.go
src/pkg/crypto/ripemd160/ripemd160.go
src/pkg/crypto/rsa/rsa.go
src/pkg/crypto/sha1/sha1.go
src/pkg/crypto/sha256/sha256.go
src/pkg/crypto/sha512/sha512.go
src/pkg/crypto/tls/cipher_suites.go
src/pkg/crypto/tls/conn.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_server.go

index deaceafb260a39f4caf04a220dd551d64582cba7..6e7dd8762c814c0fd50ae774801d85163851fc83 100644 (file)
@@ -49,14 +49,13 @@ func (h *hmac) tmpPad(xor byte) {
 }
 
 func (h *hmac) Sum(in []byte) []byte {
-       sum := h.inner.Sum(nil)
+       origLen := len(in)
+       in = h.inner.Sum(in)
        h.tmpPad(0x5c)
-       for i, b := range sum {
-               h.tmp[padSize+i] = b
-       }
+       copy(h.tmp[padSize:], in[origLen:])
        h.outer.Reset()
        h.outer.Write(h.tmp)
-       return h.outer.Sum(in)
+       return h.outer.Sum(in[:origLen])
 }
 
 func (h *hmac) Write(p []byte) (n int, err error) {
index 182cfb8537077383e6bdbc200032ad6262870b28..f4e7b09ebf2c023aefaae417205fbbd945f07df5 100644 (file)
@@ -79,8 +79,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
 
 func (d0 *digest) Sum(in []byte) []byte {
        // Make a copy of d0 so that caller can keep writing and summing.
-       d := new(digest)
-       *d = *d0
+       d := *d0
 
        // Padding.  Add a 1 bit and 0 bits until 56 bytes mod 64.
        len := d.len
@@ -103,11 +102,13 @@ func (d0 *digest) Sum(in []byte) []byte {
                panic("d.nx != 0")
        }
 
-       for _, s := range d.s {
-               in = append(in, byte(s>>0))
-               in = append(in, byte(s>>8))
-               in = append(in, byte(s>>16))
-               in = append(in, byte(s>>24))
+       var digest [Size]byte
+       for i, s := range d.s {
+               digest[i*4] = byte(s)
+               digest[i*4+1] = byte(s >> 8)
+               digest[i*4+2] = byte(s >> 16)
+               digest[i*4+3] = byte(s >> 24)
        }
-       return in
+
+       return append(in, digest[:]...)
 }
index 83673e173353caaf30ff69bae031d947f3c833f5..8bc0bb320bb675d9e6f33e0009aa6bf469c7935e 100644 (file)
@@ -26,6 +26,7 @@ var zero [1]byte
 // 4880, section 3.7.1.2) using the given hash, input passphrase and salt.
 func Salted(out []byte, h hash.Hash, in []byte, salt []byte) {
        done := 0
+       var digest []byte
 
        for i := 0; done < len(out); i++ {
                h.Reset()
@@ -34,7 +35,8 @@ func Salted(out []byte, h hash.Hash, in []byte, salt []byte) {
                }
                h.Write(salt)
                h.Write(in)
-               n := copy(out[done:], h.Sum(nil))
+               digest = h.Sum(digest[:0])
+               n := copy(out[done:], digest)
                done += n
        }
 }
@@ -52,6 +54,7 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) {
        }
 
        done := 0
+       var digest []byte
        for i := 0; done < len(out); i++ {
                h.Reset()
                for j := 0; j < i; j++ {
@@ -68,7 +71,8 @@ func Iterated(out []byte, h hash.Hash, in []byte, salt []byte, count int) {
                                written += len(combined)
                        }
                }
-               n := copy(out[done:], h.Sum(nil))
+               digest = h.Sum(digest[:0])
+               n := copy(out[done:], digest)
                done += n
        }
 }
index c128ee445a5af95eb6369262bebe8e4a693ddb23..cd2cc39dbd170a2a0429164f7a9376cd9217ece3 100644 (file)
@@ -83,8 +83,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
 
 func (d0 *digest) Sum(in []byte) []byte {
        // Make a copy of d0 so that caller can keep writing and summing.
-       d := new(digest)
-       *d = *d0
+       d := *d0
 
        // Padding.  Add a 1 bit and 0 bits until 56 bytes mod 64.
        tc := d.tc
@@ -107,11 +106,13 @@ func (d0 *digest) Sum(in []byte) []byte {
                panic("d.nx != 0")
        }
 
-       for _, s := range d.s {
-               in = append(in, byte(s))
-               in = append(in, byte(s>>8))
-               in = append(in, byte(s>>16))
-               in = append(in, byte(s>>24))
+       var digest [Size]byte
+       for i, s := range d.s {
+               digest[i*4] = byte(s)
+               digest[i*4+1] = byte(s >> 8)
+               digest[i*4+2] = byte(s >> 16)
+               digest[i*4+3] = byte(s >> 24)
        }
-       return in
+
+       return append(in, digest[:]...)
 }
index f74525c103ad453e8e7537b727260862b8f55fcd..c07e8f90db7796f4bf5ae852ab5e9dff6882abf2 100644 (file)
@@ -189,12 +189,13 @@ func incCounter(c *[4]byte) {
 // specified in PKCS#1 v2.1.
 func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
        var counter [4]byte
+       var digest []byte
 
        done := 0
        for done < len(out) {
                hash.Write(seed)
                hash.Write(counter[0:4])
-               digest := hash.Sum(nil)
+               digest = hash.Sum(digest[:0])
                hash.Reset()
 
                for i := 0; i < len(digest) && done < len(out); i++ {
index f41cdb5b0279d8be461559889fc255e226ca247a..7bb68bbdbc8ab5d467436bb9d08216ce13d04f58 100644 (file)
@@ -81,8 +81,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
 
 func (d0 *digest) Sum(in []byte) []byte {
        // Make a copy of d0 so that caller can keep writing and summing.
-       d := new(digest)
-       *d = *d0
+       d := *d0
 
        // Padding.  Add a 1 bit and 0 bits until 56 bytes mod 64.
        len := d.len
@@ -105,11 +104,13 @@ func (d0 *digest) Sum(in []byte) []byte {
                panic("d.nx != 0")
        }
 
-       for _, s := range d.h {
-               in = append(in, byte(s>>24))
-               in = append(in, byte(s>>16))
-               in = append(in, byte(s>>8))
-               in = append(in, byte(s))
+       var digest [Size]byte
+       for i, s := range d.h {
+               digest[i*4] = byte(s >> 24)
+               digest[i*4+1] = byte(s >> 16)
+               digest[i*4+2] = byte(s >> 8)
+               digest[i*4+3] = byte(s)
        }
-       return in
+
+       return append(in, digest[:]...)
 }
index 34861f6cf49522a87de9f69b4df3b5a4b98f59a4..4525541a79caea0c61f53d811c1b6e57fbe84b38 100644 (file)
@@ -125,8 +125,7 @@ func (d *digest) Write(p []byte) (nn int, err error) {
 
 func (d0 *digest) Sum(in []byte) []byte {
        // Make a copy of d0 so that caller can keep writing and summing.
-       d := new(digest)
-       *d = *d0
+       d := *d0
 
        // Padding.  Add a 1 bit and 0 bits until 56 bytes mod 64.
        len := d.len
@@ -150,14 +149,19 @@ func (d0 *digest) Sum(in []byte) []byte {
        }
 
        h := d.h[:]
+       size := Size
        if d.is224 {
                h = d.h[:7]
+               size = Size224
        }
-       for _, s := range h {
-               in = append(in, byte(s>>24))
-               in = append(in, byte(s>>16))
-               in = append(in, byte(s>>8))
-               in = append(in, byte(s))
+
+       var digest [Size]byte
+       for i, s := range h {
+               digest[i*4] = byte(s >> 24)
+               digest[i*4+1] = byte(s >> 16)
+               digest[i*4+2] = byte(s >> 8)
+               digest[i*4+3] = byte(s)
        }
-       return in
+
+       return append(in, digest[:size]...)
 }
index 3cf65cbe7c825e9b76e42549975efa039a3ba682..927f28a28d84c2b9ac3a98f6dfe8590a2fa504c6 100644 (file)
@@ -150,18 +150,23 @@ func (d0 *digest) Sum(in []byte) []byte {
        }
 
        h := d.h[:]
+       size := Size
        if d.is384 {
                h = d.h[:6]
+               size = Size384
        }
-       for _, s := range h {
-               in = append(in, byte(s>>56))
-               in = append(in, byte(s>>48))
-               in = append(in, byte(s>>40))
-               in = append(in, byte(s>>32))
-               in = append(in, byte(s>>24))
-               in = append(in, byte(s>>16))
-               in = append(in, byte(s>>8))
-               in = append(in, byte(s))
+
+       var digest [Size]byte
+       for i, s := range h {
+               digest[i*8] = byte(s >> 56)
+               digest[i*8+1] = byte(s >> 48)
+               digest[i*8+2] = byte(s >> 40)
+               digest[i*8+3] = byte(s >> 32)
+               digest[i*8+4] = byte(s >> 24)
+               digest[i*8+5] = byte(s >> 16)
+               digest[i*8+6] = byte(s >> 8)
+               digest[i*8+7] = byte(s)
        }
-       return in
+
+       return append(in, digest[:size]...)
 }
index afe5129b5c8611c1edfd436ca7109ba21eef62c0..914491d6b4a84f4eb68272213b8b349e69efb6cc 100644 (file)
@@ -96,7 +96,7 @@ func macSHA1(version uint16, key []byte) macFunction {
 
 type macFunction interface {
        Size() int
-       MAC(seq, data []byte) []byte
+       MAC(digestBuf, seq, data []byte) []byte
 }
 
 // ssl30MAC implements the SSLv3 MAC function, as defined in
@@ -114,7 +114,7 @@ var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0
 
 var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c}
 
-func (s ssl30MAC) MAC(seq, record []byte) []byte {
+func (s ssl30MAC) MAC(digestBuf, seq, record []byte) []byte {
        padLength := 48
        if s.h.Size() == 20 {
                padLength = 40
@@ -127,13 +127,13 @@ func (s ssl30MAC) MAC(seq, record []byte) []byte {
        s.h.Write(record[:1])
        s.h.Write(record[3:5])
        s.h.Write(record[recordHeaderLen:])
-       digest := s.h.Sum(nil)
+       digestBuf = s.h.Sum(digestBuf[:0])
 
        s.h.Reset()
        s.h.Write(s.key)
        s.h.Write(ssl30Pad2[:padLength])
-       s.h.Write(digest)
-       return s.h.Sum(nil)
+       s.h.Write(digestBuf)
+       return s.h.Sum(digestBuf[:0])
 }
 
 // tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3.
@@ -145,11 +145,11 @@ func (s tls10MAC) Size() int {
        return s.h.Size()
 }
 
-func (s tls10MAC) MAC(seq, record []byte) []byte {
+func (s tls10MAC) MAC(digestBuf, seq, record []byte) []byte {
        s.h.Reset()
        s.h.Write(seq)
        s.h.Write(record)
-       return s.h.Sum(nil)
+       return s.h.Sum(digestBuf[:0])
 }
 
 func rsaKA() keyAgreement {
index b8fa2737f67e77a9c349ea191ac5cf8b75a922c7..6a03fa8042ae964979af5f8393ca69e635a2a98d 100644 (file)
@@ -118,6 +118,9 @@ type halfConn struct {
 
        nextCipher interface{} // next encryption state
        nextMac    macFunction // next MAC algorithm
+
+       // used to save allocating a new buffer for each MAC.
+       inDigestBuf, outDigestBuf []byte
 }
 
 // prepareCipherSpec sets the encryption and MAC states
@@ -280,12 +283,13 @@ func (hc *halfConn) decrypt(b *block) (bool, alert) {
                b.data[4] = byte(n)
                b.resize(recordHeaderLen + n)
                remoteMAC := payload[n:]
-               localMAC := hc.mac.MAC(hc.seq[0:], b.data)
+               localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data)
                hc.incSeq()
 
                if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
                        return false, alertBadRecordMAC
                }
+               hc.inDigestBuf = localMAC
        }
 
        return true, 0
@@ -312,12 +316,13 @@ func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
 func (hc *halfConn) encrypt(b *block) (bool, alert) {
        // mac
        if hc.mac != nil {
-               mac := hc.mac.MAC(hc.seq[0:], b.data)
+               mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data)
                hc.incSeq()
 
                n := len(b.data)
                b.resize(n + len(mac))
                copy(b.data[n:], mac)
+               hc.outDigestBuf = mac
        }
 
        payload := b.data[recordHeaderLen:]
index b4337f2aac61f915a72daee9f8fc9ab6411cb336..e39e59cd5a1d05505bb92f63b67da4e2ab600466 100644 (file)
@@ -231,10 +231,10 @@ func (c *Conn) clientHandshake() error {
 
        if cert != nil {
                certVerify := new(certificateVerifyMsg)
-               var digest [36]byte
-               copy(digest[0:16], finishedHash.serverMD5.Sum(nil))
-               copy(digest[16:36], finishedHash.serverSHA1.Sum(nil))
-               signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, crypto.MD5SHA1, digest[0:])
+               digest := make([]byte, 0, 36)
+               digest = finishedHash.serverMD5.Sum(digest)
+               digest = finishedHash.serverSHA1.Sum(digest)
+               signed, err := rsa.SignPKCS1v15(c.config.rand(), c.config.Certificates[0].PrivateKey, crypto.MD5SHA1, digest)
                if err != nil {
                        return c.sendAlert(alertInternalError)
                }
index bbb23c0c9f6dacbaf080ee65b27fbaa5ee36a5df..89c000dd6e9b04fed616564dc864843b842d9e92 100644 (file)
@@ -234,9 +234,9 @@ FindCipherSuite:
                        return c.sendAlert(alertUnexpectedMessage)
                }
 
-               digest := make([]byte, 36)
-               copy(digest[0:16], finishedHash.serverMD5.Sum(nil))
-               copy(digest[16:36], finishedHash.serverSHA1.Sum(nil))
+               digest := make([]byte, 0, 36)
+               digest = finishedHash.serverMD5.Sum(digest)
+               digest = finishedHash.serverSHA1.Sum(digest)
                err = rsa.VerifyPKCS1v15(pub, crypto.MD5SHA1, digest, certVerify.signature)
                if err != nil {
                        c.sendAlert(alertBadCertificate)