]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/hmac: make Sum idempotent
authorJukka-Pekka Kekkonen <karatepekka@gmail.com>
Thu, 26 Aug 2010 17:32:29 +0000 (13:32 -0400)
committerRuss Cox <rsc@golang.org>
Thu, 26 Aug 2010 17:32:29 +0000 (13:32 -0400)
Fixes #978.

R=rsc
CC=golang-dev
https://golang.org/cl/1967045

src/pkg/crypto/hmac/hmac.go
src/pkg/crypto/hmac/hmac_test.go
src/pkg/crypto/tls/handshake_client.go
src/pkg/crypto/tls/handshake_server.go
src/pkg/crypto/tls/prf.go

index 38d13738dedcd6b716145c74819ff85114dbe1aa..3b5aa138b37a20d3144cdbf4b85c142c8e251bba 100644 (file)
@@ -34,10 +34,9 @@ const (
 )
 
 type hmac struct {
-       size  int
-       key   []byte
-       tmp   []byte
-       inner hash.Hash
+       size         int
+       key, tmp     []byte
+       outer, inner hash.Hash
 }
 
 func (h *hmac) tmpPad(xor byte) {
@@ -50,14 +49,14 @@ func (h *hmac) tmpPad(xor byte) {
 }
 
 func (h *hmac) Sum() []byte {
-       h.tmpPad(0x5c)
        sum := h.inner.Sum()
+       h.tmpPad(0x5c)
        for i, b := range sum {
                h.tmp[padSize+i] = b
        }
-       h.inner.Reset()
-       h.inner.Write(h.tmp)
-       return h.inner.Sum()
+       h.outer.Reset()
+       h.outer.Write(h.tmp)
+       return h.outer.Sum()
 }
 
 func (h *hmac) Write(p []byte) (n int, err os.Error) {
@@ -72,27 +71,26 @@ func (h *hmac) Reset() {
        h.inner.Write(h.tmp[0:padSize])
 }
 
-// New returns a new HMAC hash using the given hash and key.
-func New(h hash.Hash, key []byte) hash.Hash {
+// New returns a new HMAC hash using the given hash generator and key.
+func New(h func() hash.Hash, key []byte) hash.Hash {
+       hm := new(hmac)
+       hm.outer = h()
+       hm.inner = h()
+       hm.size = hm.inner.Size()
+       hm.tmp = make([]byte, padSize+hm.size)
        if len(key) > padSize {
                // If key is too big, hash it.
-               h.Write(key)
-               key = h.Sum()
+               hm.outer.Write(key)
+               key = hm.outer.Sum()
        }
-       hm := new(hmac)
-       hm.inner = h
-       hm.size = h.Size()
        hm.key = make([]byte, len(key))
-       for i, k := range key {
-               hm.key[i] = k
-       }
-       hm.tmp = make([]byte, padSize+hm.size)
+       copy(hm.key, key)
        hm.Reset()
        return hm
 }
 
 // NewMD5 returns a new HMAC-MD5 hash using the given key.
-func NewMD5(key []byte) hash.Hash { return New(md5.New(), key) }
+func NewMD5(key []byte) hash.Hash { return New(md5.New, key) }
 
 // NewSHA1 returns a new HMAC-SHA1 hash using the given key.
-func NewSHA1(key []byte) hash.Hash { return New(sha1.New(), key) }
+func NewSHA1(key []byte) hash.Hash { return New(sha1.New, key) }
index d867c83a9634b89a7e42c6e7a2b013d74a94cfae..6934df236951edea17bcf269539cbb91c1ef4380 100644 (file)
@@ -84,9 +84,13 @@ func TestHMAC(t *testing.T) {
                                t.Errorf("test %d.%d: Write(%d) = %d, %v", i, j, len(tt.in), n, err)
                                continue
                        }
-                       sum := fmt.Sprintf("%x", h.Sum())
-                       if sum != tt.out {
-                               t.Errorf("test %d.%d: have %s want %s\n", i, j, sum, tt.out)
+
+                       // Repetive Sum() calls should return the same value
+                       for k := 0; k < 2; k++ {
+                               sum := fmt.Sprintf("%x", h.Sum())
+                               if sum != tt.out {
+                                       t.Errorf("test %d.%d.%d: have %s want %s\n", i, j, k, sum, tt.out)
+                               }
                        }
 
                        // Second iteration: make sure reset works.
index 4c4626ced89f1edbed7aad2ab475a2514433d517..c629920648e99c71219c4b7d440758079e4f1878 100644 (file)
@@ -8,7 +8,6 @@ import (
        "crypto/hmac"
        "crypto/rc4"
        "crypto/rsa"
-       "crypto/sha1"
        "crypto/subtle"
        "crypto/x509"
        "io"
@@ -226,7 +225,7 @@ func (c *Conn) clientHandshake() os.Error {
 
        cipher, _ := rc4.NewCipher(clientKey)
 
-       c.out.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
+       c.out.prepareCipherSpec(cipher, hmac.NewSHA1(clientMAC))
        c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 
        finished := new(finishedMsg)
@@ -235,7 +234,7 @@ func (c *Conn) clientHandshake() os.Error {
        c.writeRecord(recordTypeHandshake, finished.marshal())
 
        cipher2, _ := rc4.NewCipher(serverKey)
-       c.in.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
+       c.in.prepareCipherSpec(cipher2, hmac.NewSHA1(serverMAC))
        c.readRecord(recordTypeChangeCipherSpec)
        if c.err != nil {
                return c.err
index 734c0fece1c44a73f2142c48d935e491aaaf25dd..118dd4352f31a16d706de986f1a4e9a3e2dab250 100644 (file)
@@ -16,7 +16,6 @@ import (
        "crypto/hmac"
        "crypto/rc4"
        "crypto/rsa"
-       "crypto/sha1"
        "crypto/subtle"
        "crypto/x509"
        "io"
@@ -227,7 +226,7 @@ func (c *Conn) serverHandshake() os.Error {
                keysFromPreMasterSecret11(preMasterSecret, clientHello.random, hello.random, suite.hashLength, suite.cipherKeyLength)
 
        cipher, _ := rc4.NewCipher(clientKey)
-       c.in.prepareCipherSpec(cipher, hmac.New(sha1.New(), clientMAC))
+       c.in.prepareCipherSpec(cipher, hmac.NewSHA1(clientMAC))
        c.readRecord(recordTypeChangeCipherSpec)
        if err := c.error(); err != nil {
                return err
@@ -264,7 +263,7 @@ func (c *Conn) serverHandshake() os.Error {
        finishedHash.Write(clientFinished.marshal())
 
        cipher2, _ := rc4.NewCipher(serverKey)
-       c.out.prepareCipherSpec(cipher2, hmac.New(sha1.New(), serverMAC))
+       c.out.prepareCipherSpec(cipher2, hmac.NewSHA1(serverMAC))
        c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 
        finished := new(finishedMsg)
index ee6cb780b4ebbefcc800254975759313984eb7a9..b206d26a4ab29a0f819507167ac4bb1472df0d49 100644 (file)
@@ -20,7 +20,7 @@ func splitPreMasterSecret(secret []byte) (s1, s2 []byte) {
 }
 
 // pHash implements the P_hash function, as defined in RFC 4346, section 5.
-func pHash(result, secret, seed []byte, hash hash.Hash) {
+func pHash(result, secret, seed []byte, hash func() hash.Hash) {
        h := hmac.New(hash, secret)
        h.Write(seed)
        a := h.Sum()
@@ -46,8 +46,8 @@ func pHash(result, secret, seed []byte, hash hash.Hash) {
 
 // pRF11 implements the TLS 1.1 pseudo-random function, as defined in RFC 4346, section 5.
 func pRF11(result, secret, label, seed []byte) {
-       hashSHA1 := sha1.New()
-       hashMD5 := md5.New()
+       hashSHA1 := sha1.New
+       hashMD5 := md5.New
 
        labelAndSeed := make([]byte, len(label)+len(seed))
        copy(labelAndSeed, label)