]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/{aes,cipher}: add optimized implementation of AES-GCM for s390x
authorMichael Munday <munday@ca.ibm.com>
Fri, 30 Sep 2016 18:20:42 +0000 (14:20 -0400)
committerMichael Munday <munday@ca.ibm.com>
Wed, 5 Oct 2016 15:37:53 +0000 (15:37 +0000)
Also adds two tests: one to exercise the counter incrementing code
and one which checks the output of the optimized implementation
against that of the generic implementation for large/unaligned data
sizes.

Uses the KIMD instruction for GHASH and the KMCTR instruction for AES
in counter mode.

AESGCMSeal1K  75.0MB/s ± 2%  1008.7MB/s ± 1%  +1245.71%  (p=0.000 n=10+10)
AESGCMOpen1K  75.3MB/s ± 1%  1006.0MB/s ± 1%  +1235.59%   (p=0.000 n=10+9)
AESGCMSeal8K  78.5MB/s ± 1%  1748.4MB/s ± 1%  +2127.34%   (p=0.000 n=9+10)
AESGCMOpen8K  78.5MB/s ± 0%  1752.7MB/s ± 0%  +2134.07%   (p=0.000 n=10+9)

Change-Id: I88dbcfcb5988104bfd290ae15a60a2721c1338be
Reviewed-on: https://go-review.googlesource.com/30361
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/crypto/aes/asm_s390x.s
src/crypto/aes/gcm_s390x.go [new file with mode: 0644]
src/crypto/cipher/gcm_test.go

index e31415a4204d7b6fcee5ccf33cf5ac4fc27d3c77..5714aee3187ce6817b459e9562222f0b3844d9b9 100644 (file)
@@ -22,6 +22,19 @@ TEXT ·hasAsm(SB),NOSPLIT,$16-1
        AND     R3, R2
        CMPBNE  R2, R3, notfound
 
+       // check for KMCTR AES functions
+       WORD    $0xB92D4024 // cipher message with counter (KMCTR)
+       MOVD    mask-16(SP), R2
+       AND     R3, R2
+       CMPBNE  R2, R3, notfound
+
+       // check for KIMD GHASH function
+       WORD    $0xB93E0024    // compute intermediate message digest (KIMD)
+       MOVD    mask-8(SP), R2 // bits 64-127
+       MOVD    $(1<<62), R5
+       AND     R5, R2
+       CMPBNE  R2, R5, notfound
+
        MOVB    $1, ret+0(FP)
        RET
 notfound:
@@ -91,3 +104,86 @@ tail:
        BR      tail
 done:
        RET
+
+// func cryptBlocksGCM(fn code, key, dst, src, buf []byte, cnt *[16]byte)
+TEXT ·cryptBlocksGCM(SB),NOSPLIT,$0-112
+       MOVD    src_len+64(FP), R0
+       MOVD    buf_base+80(FP), R1
+       MOVD    cnt+104(FP), R12
+       LMG     (R12), R2, R3
+
+       // Check that the src size is less than or equal to the buffer size.
+       MOVD    buf_len+88(FP), R4
+       CMP     R0, R4
+       BGT     crash
+
+       // Check that the src size is a multiple of 16-bytes.
+       MOVD    R0, R4
+       AND     $0xf, R4
+       BLT     crash // non-zero
+
+       // Check that the src size is less than or equal to the dst size.
+       MOVD    dst_len+40(FP), R4
+       CMP     R0, R4
+       BGT     crash
+
+       MOVD    R2, R4
+       MOVD    R2, R6
+       MOVD    R2, R8
+       MOVD    R3, R5
+       MOVD    R3, R7
+       MOVD    R3, R9
+       ADDW    $1, R5
+       ADDW    $2, R7
+       ADDW    $3, R9
+incr:
+       CMP     R0, $64
+       BLT     tail
+       STMG    R2, R9, (R1)
+       ADDW    $4, R3
+       ADDW    $4, R5
+       ADDW    $4, R7
+       ADDW    $4, R9
+       MOVD    $64(R1), R1
+       SUB     $64, R0
+       BR      incr
+tail:
+       CMP     R0, $0
+       BEQ     crypt
+       STMG    R2, R3, (R1)
+       ADDW    $1, R3
+       MOVD    $16(R1), R1
+       SUB     $16, R0
+       BR      tail
+crypt:
+       STMG    R2, R3, (R12)       // update next counter value
+       MOVD    fn+0(FP), R0        // function code (encryption)
+       MOVD    key_base+8(FP), R1  // key
+       MOVD    buf_base+80(FP), R2 // counter values
+       MOVD    dst_base+32(FP), R4 // dst
+       MOVD    src_base+56(FP), R6 // src
+       MOVD    src_len+64(FP), R7  // len
+loop:
+       WORD    $0xB92D2046         // cipher message with counter (KMCTR)
+       BVS     loop                // branch back if interrupted
+       RET
+crash:
+       MOVD    $0, (R0)
+       RET
+
+// func ghash(key *gcmHashKey, hash *[16]byte, data []byte)
+TEXT ·ghash(SB),NOSPLIT,$32-40
+       MOVD    $65, R0 // GHASH function code
+       MOVD    key+0(FP), R2
+       LMG     (R2), R6, R7
+       MOVD    hash+8(FP), R8
+       LMG     (R8), R4, R5
+       MOVD    $params-32(SP), R1
+       STMG    R4, R7, (R1)
+       LMG     data+16(FP), R2, R3 // R2=base, R3=len
+loop:
+       WORD    $0xB93E0002 // compute intermediate message digest (KIMD)
+       BVS     loop        // branch back if interrupted
+       MVC     $16, (R1), (R8)
+       MOVD    $0, R0
+       RET
diff --git a/src/crypto/aes/gcm_s390x.go b/src/crypto/aes/gcm_s390x.go
new file mode 100644 (file)
index 0000000..9eaaf7c
--- /dev/null
@@ -0,0 +1,270 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package aes
+
+import (
+       "crypto/cipher"
+       "crypto/subtle"
+       "errors"
+)
+
+// gcmCount represents a 16-byte big-endian count value.
+type gcmCount [16]byte
+
+// inc increments the rightmost 32-bits of the count value by 1.
+func (x *gcmCount) inc() {
+       // The compiler should optimize this to a 32-bit addition.
+       n := uint32(x[15]) | uint32(x[14])<<8 | uint32(x[13])<<16 | uint32(x[12])<<24
+       n += 1
+       x[12] = byte(n >> 24)
+       x[13] = byte(n >> 16)
+       x[14] = byte(n >> 8)
+       x[15] = byte(n)
+}
+
+// gcmLengths writes len0 || len1 as big-endian values to a 16-byte array.
+func gcmLengths(len0, len1 uint64) [16]byte {
+       return [16]byte{
+               byte(len0 >> 56),
+               byte(len0 >> 48),
+               byte(len0 >> 40),
+               byte(len0 >> 32),
+               byte(len0 >> 24),
+               byte(len0 >> 16),
+               byte(len0 >> 8),
+               byte(len0),
+               byte(len1 >> 56),
+               byte(len1 >> 48),
+               byte(len1 >> 40),
+               byte(len1 >> 32),
+               byte(len1 >> 24),
+               byte(len1 >> 16),
+               byte(len1 >> 8),
+               byte(len1),
+       }
+}
+
+// gcmHashKey represents the 16-byte hash key required by the GHASH algorithm.
+type gcmHashKey [16]byte
+
+type gcmAsm struct {
+       block     *aesCipherAsm
+       hashKey   gcmHashKey
+       nonceSize int
+}
+
+const (
+       gcmBlockSize         = 16
+       gcmTagSize           = 16
+       gcmStandardNonceSize = 12
+)
+
+var errOpen = errors.New("cipher: message authentication failed")
+
+// Assert that aesCipherAsm implements the gcmAble interface.
+var _ gcmAble = (*aesCipherAsm)(nil)
+
+// NewGCM returns the AES cipher wrapped in Galois Counter Mode. This is only
+// called by crypto/cipher.NewGCM via the gcmAble interface.
+func (c *aesCipherAsm) NewGCM(nonceSize int) (cipher.AEAD, error) {
+       var hk gcmHashKey
+       c.Encrypt(hk[:], hk[:])
+       g := &gcmAsm{
+               block:     c,
+               hashKey:   hk,
+               nonceSize: nonceSize,
+       }
+       return g, nil
+}
+
+func (g *gcmAsm) NonceSize() int {
+       return g.nonceSize
+}
+
+func (*gcmAsm) Overhead() int {
+       return gcmTagSize
+}
+
+// sliceForAppend takes a slice and a requested number of bytes. It returns a
+// slice with the contents of the given slice followed by that many bytes and a
+// second slice that aliases into it and contains only the extra bytes. If the
+// original slice has sufficient capacity then no allocation is performed.
+func sliceForAppend(in []byte, n int) (head, tail []byte) {
+       if total := len(in) + n; cap(in) >= total {
+               head = in[:total]
+       } else {
+               head = make([]byte, total)
+               copy(head, in)
+       }
+       tail = head[len(in):]
+       return
+}
+
+// ghash uses the GHASH algorithm to hash data with the given key. The initial
+// hash value is given by hash which will be updated with the new hash value.
+// The length of data must be a multiple of 16-bytes.
+//go:noescape
+func ghash(key *gcmHashKey, hash *[16]byte, data []byte)
+
+// paddedGHASH pads data with zeroes until its length is a multiple of
+// 16-bytes. It then calculates a new value for hash using the GHASH algorithm.
+func (g *gcmAsm) paddedGHASH(hash *[16]byte, data []byte) {
+       siz := len(data) &^ 0xf // align size to 16-bytes
+       if siz > 0 {
+               ghash(&g.hashKey, hash, data[:siz])
+               data = data[siz:]
+       }
+       if len(data) > 0 {
+               var s [16]byte
+               copy(s[:], data)
+               ghash(&g.hashKey, hash, s[:])
+       }
+}
+
+// cryptBlocksGCM encrypts src using AES in counter mode using the given
+// function code and key. The rightmost 32-bits of the counter are incremented
+// between each block as required by the GCM spec. The initial counter value
+// is given by cnt, which is updated with the value of the next counter value
+// to use.
+//
+// The lengths of both dst and buf must be greater than or equal to the length
+// of src. buf may be partially or completely overwritten during the execution
+// of the function.
+//go:noescape
+func cryptBlocksGCM(fn code, key, dst, src, buf []byte, cnt *gcmCount)
+
+// counterCrypt encrypts src using AES in counter mode and places the result
+// into dst. cnt is the initial count value and will be updated with the next
+// count value. The length of dst must be greater than or equal to the length
+// of src.
+func (g *gcmAsm) counterCrypt(dst, src []byte, cnt *gcmCount) {
+       // Copying src into a buffer improves performance on some models when
+       // src and dst point to the same underlying array. We also need a
+       // buffer for counter values.
+       var ctrbuf, srcbuf [2048]byte
+       for len(src) >= 16 {
+               siz := len(src)
+               if len(src) > len(ctrbuf) {
+                       siz = len(ctrbuf)
+               }
+               siz &^= 0xf // align siz to 16-bytes
+               copy(srcbuf[:], src[:siz])
+               cryptBlocksGCM(g.block.function, g.block.key, dst[:siz], srcbuf[:siz], ctrbuf[:], cnt)
+               src = src[siz:]
+               dst = dst[siz:]
+       }
+       if len(src) > 0 {
+               var x [16]byte
+               g.block.Encrypt(x[:], cnt[:])
+               for i := range src {
+                       dst[i] = src[i] ^ x[i]
+               }
+               cnt.inc()
+       }
+}
+
+// deriveCounter computes the initial GCM counter state from the given nonce.
+// See NIST SP 800-38D, section 7.1.
+func (g *gcmAsm) deriveCounter(nonce []byte) gcmCount {
+       // GCM has two modes of operation with respect to the initial counter
+       // state: a "fast path" for 96-bit (12-byte) nonces, and a "slow path"
+       // for nonces of other lengths. For a 96-bit nonce, the nonce, along
+       // with a four-byte big-endian counter starting at one, is used
+       // directly as the starting counter. For other nonce sizes, the counter
+       // is computed by passing it through the GHASH function.
+       var counter gcmCount
+       if len(nonce) == gcmStandardNonceSize {
+               copy(counter[:], nonce)
+               counter[gcmBlockSize-1] = 1
+       } else {
+               var hash [16]byte
+               g.paddedGHASH(&hash, nonce)
+               lens := gcmLengths(0, uint64(len(nonce))*8)
+               g.paddedGHASH(&hash, lens[:])
+               copy(counter[:], hash[:])
+       }
+       return counter
+}
+
+// auth calculates GHASH(ciphertext, additionalData), masks the result with
+// tagMask and writes the result to out.
+func (g *gcmAsm) auth(out, ciphertext, additionalData []byte, tagMask *[gcmTagSize]byte) {
+       var hash [16]byte
+       g.paddedGHASH(&hash, additionalData)
+       g.paddedGHASH(&hash, ciphertext)
+       lens := gcmLengths(uint64(len(additionalData))*8, uint64(len(ciphertext))*8)
+       g.paddedGHASH(&hash, lens[:])
+
+       copy(out, hash[:])
+       for i := range out {
+               out[i] ^= tagMask[i]
+       }
+}
+
+// Seal encrypts and authenticates plaintext. See the cipher.AEAD interface for
+// details.
+func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte {
+       if len(nonce) != g.nonceSize {
+               panic("cipher: incorrect nonce length given to GCM")
+       }
+       if uint64(len(plaintext)) > ((1<<32)-2)*BlockSize {
+               panic("cipher: message too large for GCM")
+       }
+
+       ret, out := sliceForAppend(dst, len(plaintext)+gcmTagSize)
+
+       counter := g.deriveCounter(nonce)
+
+       var tagMask [gcmBlockSize]byte
+       g.block.Encrypt(tagMask[:], counter[:])
+       counter.inc()
+
+       g.counterCrypt(out, plaintext, &counter)
+       g.auth(out[len(plaintext):], out[:len(plaintext)], data, &tagMask)
+
+       return ret
+}
+
+// Open authenticates and decrypts ciphertext. See the cipher.AEAD interface
+// for details.
+func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) {
+       if len(nonce) != g.nonceSize {
+               panic("cipher: incorrect nonce length given to GCM")
+       }
+       if len(ciphertext) < gcmTagSize {
+               return nil, errOpen
+       }
+       if uint64(len(ciphertext)) > ((1<<32)-2)*BlockSize+gcmTagSize {
+               return nil, errOpen
+       }
+
+       tag := ciphertext[len(ciphertext)-gcmTagSize:]
+       ciphertext = ciphertext[:len(ciphertext)-gcmTagSize]
+
+       counter := g.deriveCounter(nonce)
+
+       var tagMask [gcmBlockSize]byte
+       g.block.Encrypt(tagMask[:], counter[:])
+       counter.inc()
+
+       var expectedTag [gcmTagSize]byte
+       g.auth(expectedTag[:], ciphertext, data, &tagMask)
+
+       ret, out := sliceForAppend(dst, len(ciphertext))
+
+       if subtle.ConstantTimeCompare(expectedTag[:], tag) != 1 {
+               // The AESNI code decrypts and authenticates concurrently, and
+               // so overwrites dst in the event of a tag mismatch. That
+               // behaviour is mimicked here in order to be consistent across
+               // platforms.
+               for i := range out {
+                       out[i] = 0
+               }
+               return nil, errOpen
+       }
+
+       g.counterCrypt(out, ciphertext, &counter)
+       return ret, nil
+}
index bb1ab3c0b0f27e0d976f5dab6c23bd13875e1a28..6878b4cb42bdf96b1c52bf866a687b1652a4eb0f 100644 (file)
@@ -8,7 +8,11 @@ import (
        "bytes"
        "crypto/aes"
        "crypto/cipher"
+       "crypto/rand"
        "encoding/hex"
+       "errors"
+       "io"
+       "reflect"
        "testing"
 )
 
@@ -274,3 +278,157 @@ func TestTagFailureOverwrite(t *testing.T) {
                }
        }
 }
+
+func TestGCMCounterWrap(t *testing.T) {
+       // Test that the last 32-bits of the counter wrap correctly.
+       tests := []struct {
+               nonce, tag string
+       }{
+               {"0fa72e25", "37e1948cdfff09fbde0c40ad99fee4a7"},   // counter: 7eb59e4d961dad0dfdd75aaffffffff0
+               {"afe05cc1", "438f3aa9fee5e54903b1927bca26bbdf"},   // counter: 75d492a7e6e6bfc979ad3a8ffffffff4
+               {"9ffecbef", "7b88ca424df9703e9e8611071ec7e16e"},   // counter: c8bb108b0ecdc71747b9d57ffffffff5
+               {"ffc3e5b3", "38d49c86e0abe853ac250e66da54c01a"},   // counter: 706414d2de9b36ab3b900a9ffffffff6
+               {"cfdd729d", "e08402eaac36a1a402e09b1bd56500e8"},   // counter: cd0b96fe36b04e750584e56ffffffff7
+               {"010ae3d486", "5405bb490b1f95d01e2ba735687154bc"}, // counter: e36c18e69406c49722808104fffffff8
+               {"01b1107a9d", "939a585f342e01e17844627492d44dbf"}, // counter: e6d56eaf9127912b6d62c6dcffffffff
+       }
+       key, err := aes.NewCipher(make([]byte, 16))
+       if err != nil {
+               t.Fatal(err)
+       }
+       plaintext := make([]byte, 16*17+1)
+       for i, test := range tests {
+               nonce, _ := hex.DecodeString(test.nonce)
+               want, _ := hex.DecodeString(test.tag)
+               aead, err := cipher.NewGCMWithNonceSize(key, len(nonce))
+               if err != nil {
+                       t.Fatal(err)
+               }
+               got := aead.Seal(nil, nonce, plaintext, nil)
+               if !bytes.Equal(got[len(plaintext):], want) {
+                       t.Errorf("test[%v]: got: %x, want: %x", i, got[len(plaintext):], want)
+               }
+               _, err = aead.Open(nil, nonce, got, nil)
+               if err != nil {
+                       t.Errorf("test[%v]: authentication failed", i)
+               }
+       }
+}
+
+var _ cipher.Block = (*wrapper)(nil)
+
+type wrapper struct {
+       block cipher.Block
+}
+
+func (w *wrapper) BlockSize() int          { return w.block.BlockSize() }
+func (w *wrapper) Encrypt(dst, src []byte) { w.block.Encrypt(dst, src) }
+func (w *wrapper) Decrypt(dst, src []byte) { w.block.Decrypt(dst, src) }
+
+// wrap wraps the Block interface so that it does not fulfill
+// any optimizing interfaces such as gcmAble.
+func wrap(b cipher.Block) cipher.Block {
+       return &wrapper{b}
+}
+
+func TestGCMAsm(t *testing.T) {
+       // Create a new pair of AEADs, one using the assembly implementation
+       // and one using the generic Go implementation.
+       newAESGCM := func(key []byte) (asm, generic cipher.AEAD, err error) {
+               block, err := aes.NewCipher(key[:])
+               if err != nil {
+                       return nil, nil, err
+               }
+               asm, err = cipher.NewGCM(block)
+               if err != nil {
+                       return nil, nil, err
+               }
+               generic, err = cipher.NewGCM(wrap(block))
+               if err != nil {
+                       return nil, nil, err
+               }
+               return asm, generic, nil
+       }
+
+       // check for assembly implementation
+       var key [16]byte
+       asm, generic, err := newAESGCM(key[:])
+       if err != nil {
+               t.Fatal(err)
+       }
+       if reflect.TypeOf(asm) == reflect.TypeOf(generic) {
+               t.Skipf("no assembly implementation of GCM")
+       }
+
+       // generate permutations
+       type pair struct{ align, length int }
+       lengths := []int{0, 8192, 8193, 8208}
+       keySizes := []int{16, 24, 32}
+       alignments := []int{0, 1, 2, 3}
+       if testing.Short() {
+               keySizes = []int{16}
+               alignments = []int{1}
+       }
+       perms := make([]pair, 0)
+       for _, l := range lengths {
+               for _, a := range alignments {
+                       if a != 0 && l == 0 {
+                               continue
+                       }
+                       perms = append(perms, pair{align: a, length: l})
+               }
+       }
+
+       // run test for all permutations
+       test := func(ks int, pt, ad []byte) error {
+               key := make([]byte, ks)
+               if _, err := io.ReadFull(rand.Reader, key); err != nil {
+                       return err
+               }
+               asm, generic, err := newAESGCM(key)
+               if err != nil {
+                       return err
+               }
+               if _, err := io.ReadFull(rand.Reader, pt); err != nil {
+                       return err
+               }
+               if _, err := io.ReadFull(rand.Reader, ad); err != nil {
+                       return err
+               }
+               nonce := make([]byte, 12)
+               if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
+                       return err
+               }
+               want := generic.Seal(nil, nonce, pt, ad)
+               got := asm.Seal(nil, nonce, pt, ad)
+               if !bytes.Equal(want, got) {
+                       return errors.New("incorrect Seal output")
+               }
+               got, err = asm.Open(nil, nonce, want, ad)
+               if err != nil {
+                       return errors.New("authentication failed")
+               }
+               if !bytes.Equal(pt, got) {
+                       return errors.New("incorrect Open output")
+               }
+               return nil
+       }
+       for _, a := range perms {
+               ad := make([]byte, a.align+a.length)
+               ad = ad[a.align:]
+               for _, p := range perms {
+                       pt := make([]byte, p.align+p.length)
+                       pt = pt[p.align:]
+                       for _, ks := range keySizes {
+                               if err := test(ks, pt, ad); err != nil {
+                                       t.Error(err)
+                                       t.Errorf("      key size: %v", ks)
+                                       t.Errorf("      plaintext alignment: %v", p.align)
+                                       t.Errorf("      plaintext length: %v", p.length)
+                                       t.Errorf("      additionalData alignment: %v", a.align)
+                                       t.Fatalf("      additionalData length: %v", a.length)
+                               }
+                       }
+               }
+       }
+}