From: Michael Munday Date: Fri, 30 Sep 2016 18:20:42 +0000 (-0400) Subject: crypto/{aes,cipher}: add optimized implementation of AES-GCM for s390x X-Git-Tag: go1.8beta1~1025 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=fb4f4f4e96058165c0e7be32aa9ce493515c22a3;p=gostls13.git crypto/{aes,cipher}: add optimized implementation of AES-GCM for s390x Also adds two tests: one to exercise the counter incrementing code and one which checks the output of the optimized implementation against that of the generic implementation for large/unaligned data sizes. Uses the KIMD instruction for GHASH and the KMCTR instruction for AES in counter mode. AESGCMSeal1K 75.0MB/s ± 2% 1008.7MB/s ± 1% +1245.71% (p=0.000 n=10+10) AESGCMOpen1K 75.3MB/s ± 1% 1006.0MB/s ± 1% +1235.59% (p=0.000 n=10+9) AESGCMSeal8K 78.5MB/s ± 1% 1748.4MB/s ± 1% +2127.34% (p=0.000 n=9+10) AESGCMOpen8K 78.5MB/s ± 0% 1752.7MB/s ± 0% +2134.07% (p=0.000 n=10+9) Change-Id: I88dbcfcb5988104bfd290ae15a60a2721c1338be Reviewed-on: https://go-review.googlesource.com/30361 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- diff --git a/src/crypto/aes/asm_s390x.s b/src/crypto/aes/asm_s390x.s index e31415a420..5714aee318 100644 --- a/src/crypto/aes/asm_s390x.s +++ b/src/crypto/aes/asm_s390x.s @@ -22,6 +22,19 @@ TEXT ·hasAsm(SB),NOSPLIT,$16-1 AND R3, R2 CMPBNE R2, R3, notfound + // check for KMCTR AES functions + WORD $0xB92D4024 // cipher message with counter (KMCTR) + MOVD mask-16(SP), R2 + AND R3, R2 + CMPBNE R2, R3, notfound + + // check for KIMD GHASH function + WORD $0xB93E0024 // compute intermediate message digest (KIMD) + MOVD mask-8(SP), R2 // bits 64-127 + MOVD $(1<<62), R5 + AND R5, R2 + CMPBNE R2, R5, notfound + MOVB $1, ret+0(FP) RET notfound: @@ -91,3 +104,86 @@ tail: BR tail done: RET + +// func cryptBlocksGCM(fn code, key, dst, src, buf []byte, cnt *[16]byte) +TEXT ·cryptBlocksGCM(SB),NOSPLIT,$0-112 + MOVD src_len+64(FP), R0 + MOVD buf_base+80(FP), R1 + MOVD cnt+104(FP), R12 + LMG (R12), R2, R3 + + // Check that the src size is less than or equal to the buffer size. + MOVD buf_len+88(FP), R4 + CMP R0, R4 + BGT crash + + // Check that the src size is a multiple of 16-bytes. + MOVD R0, R4 + AND $0xf, R4 + BLT crash // non-zero + + // Check that the src size is less than or equal to the dst size. + MOVD dst_len+40(FP), R4 + CMP R0, R4 + BGT crash + + MOVD R2, R4 + MOVD R2, R6 + MOVD R2, R8 + MOVD R3, R5 + MOVD R3, R7 + MOVD R3, R9 + ADDW $1, R5 + ADDW $2, R7 + ADDW $3, R9 +incr: + CMP R0, $64 + BLT tail + STMG R2, R9, (R1) + ADDW $4, R3 + ADDW $4, R5 + ADDW $4, R7 + ADDW $4, R9 + MOVD $64(R1), R1 + SUB $64, R0 + BR incr +tail: + CMP R0, $0 + BEQ crypt + STMG R2, R3, (R1) + ADDW $1, R3 + MOVD $16(R1), R1 + SUB $16, R0 + BR tail +crypt: + STMG R2, R3, (R12) // update next counter value + MOVD fn+0(FP), R0 // function code (encryption) + MOVD key_base+8(FP), R1 // key + MOVD buf_base+80(FP), R2 // counter values + MOVD dst_base+32(FP), R4 // dst + MOVD src_base+56(FP), R6 // src + MOVD src_len+64(FP), R7 // len +loop: + WORD $0xB92D2046 // cipher message with counter (KMCTR) + BVS loop // branch back if interrupted + RET +crash: + MOVD $0, (R0) + RET + +// func ghash(key *gcmHashKey, hash *[16]byte, data []byte) +TEXT ·ghash(SB),NOSPLIT,$32-40 + MOVD $65, R0 // GHASH function code + MOVD key+0(FP), R2 + LMG (R2), R6, R7 + MOVD hash+8(FP), R8 + LMG (R8), R4, R5 + MOVD $params-32(SP), R1 + STMG R4, R7, (R1) + LMG data+16(FP), R2, R3 // R2=base, R3=len +loop: + WORD $0xB93E0002 // compute intermediate message digest (KIMD) + BVS loop // branch back if interrupted + MVC $16, (R1), (R8) + MOVD $0, R0 + RET diff --git a/src/crypto/aes/gcm_s390x.go b/src/crypto/aes/gcm_s390x.go new file mode 100644 index 0000000000..9eaaf7c21e --- /dev/null +++ b/src/crypto/aes/gcm_s390x.go @@ -0,0 +1,270 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package aes + +import ( + "crypto/cipher" + "crypto/subtle" + "errors" +) + +// gcmCount represents a 16-byte big-endian count value. +type gcmCount [16]byte + +// inc increments the rightmost 32-bits of the count value by 1. +func (x *gcmCount) inc() { + // The compiler should optimize this to a 32-bit addition. + n := uint32(x[15]) | uint32(x[14])<<8 | uint32(x[13])<<16 | uint32(x[12])<<24 + n += 1 + x[12] = byte(n >> 24) + x[13] = byte(n >> 16) + x[14] = byte(n >> 8) + x[15] = byte(n) +} + +// gcmLengths writes len0 || len1 as big-endian values to a 16-byte array. +func gcmLengths(len0, len1 uint64) [16]byte { + return [16]byte{ + byte(len0 >> 56), + byte(len0 >> 48), + byte(len0 >> 40), + byte(len0 >> 32), + byte(len0 >> 24), + byte(len0 >> 16), + byte(len0 >> 8), + byte(len0), + byte(len1 >> 56), + byte(len1 >> 48), + byte(len1 >> 40), + byte(len1 >> 32), + byte(len1 >> 24), + byte(len1 >> 16), + byte(len1 >> 8), + byte(len1), + } +} + +// gcmHashKey represents the 16-byte hash key required by the GHASH algorithm. +type gcmHashKey [16]byte + +type gcmAsm struct { + block *aesCipherAsm + hashKey gcmHashKey + nonceSize int +} + +const ( + gcmBlockSize = 16 + gcmTagSize = 16 + gcmStandardNonceSize = 12 +) + +var errOpen = errors.New("cipher: message authentication failed") + +// Assert that aesCipherAsm implements the gcmAble interface. +var _ gcmAble = (*aesCipherAsm)(nil) + +// NewGCM returns the AES cipher wrapped in Galois Counter Mode. This is only +// called by crypto/cipher.NewGCM via the gcmAble interface. +func (c *aesCipherAsm) NewGCM(nonceSize int) (cipher.AEAD, error) { + var hk gcmHashKey + c.Encrypt(hk[:], hk[:]) + g := &gcmAsm{ + block: c, + hashKey: hk, + nonceSize: nonceSize, + } + return g, nil +} + +func (g *gcmAsm) NonceSize() int { + return g.nonceSize +} + +func (*gcmAsm) Overhead() int { + return gcmTagSize +} + +// sliceForAppend takes a slice and a requested number of bytes. It returns a +// slice with the contents of the given slice followed by that many bytes and a +// second slice that aliases into it and contains only the extra bytes. If the +// original slice has sufficient capacity then no allocation is performed. +func sliceForAppend(in []byte, n int) (head, tail []byte) { + if total := len(in) + n; cap(in) >= total { + head = in[:total] + } else { + head = make([]byte, total) + copy(head, in) + } + tail = head[len(in):] + return +} + +// ghash uses the GHASH algorithm to hash data with the given key. The initial +// hash value is given by hash which will be updated with the new hash value. +// The length of data must be a multiple of 16-bytes. +//go:noescape +func ghash(key *gcmHashKey, hash *[16]byte, data []byte) + +// paddedGHASH pads data with zeroes until its length is a multiple of +// 16-bytes. It then calculates a new value for hash using the GHASH algorithm. +func (g *gcmAsm) paddedGHASH(hash *[16]byte, data []byte) { + siz := len(data) &^ 0xf // align size to 16-bytes + if siz > 0 { + ghash(&g.hashKey, hash, data[:siz]) + data = data[siz:] + } + if len(data) > 0 { + var s [16]byte + copy(s[:], data) + ghash(&g.hashKey, hash, s[:]) + } +} + +// cryptBlocksGCM encrypts src using AES in counter mode using the given +// function code and key. The rightmost 32-bits of the counter are incremented +// between each block as required by the GCM spec. The initial counter value +// is given by cnt, which is updated with the value of the next counter value +// to use. +// +// The lengths of both dst and buf must be greater than or equal to the length +// of src. buf may be partially or completely overwritten during the execution +// of the function. +//go:noescape +func cryptBlocksGCM(fn code, key, dst, src, buf []byte, cnt *gcmCount) + +// counterCrypt encrypts src using AES in counter mode and places the result +// into dst. cnt is the initial count value and will be updated with the next +// count value. The length of dst must be greater than or equal to the length +// of src. +func (g *gcmAsm) counterCrypt(dst, src []byte, cnt *gcmCount) { + // Copying src into a buffer improves performance on some models when + // src and dst point to the same underlying array. We also need a + // buffer for counter values. + var ctrbuf, srcbuf [2048]byte + for len(src) >= 16 { + siz := len(src) + if len(src) > len(ctrbuf) { + siz = len(ctrbuf) + } + siz &^= 0xf // align siz to 16-bytes + copy(srcbuf[:], src[:siz]) + cryptBlocksGCM(g.block.function, g.block.key, dst[:siz], srcbuf[:siz], ctrbuf[:], cnt) + src = src[siz:] + dst = dst[siz:] + } + if len(src) > 0 { + var x [16]byte + g.block.Encrypt(x[:], cnt[:]) + for i := range src { + dst[i] = src[i] ^ x[i] + } + cnt.inc() + } +} + +// deriveCounter computes the initial GCM counter state from the given nonce. +// See NIST SP 800-38D, section 7.1. +func (g *gcmAsm) deriveCounter(nonce []byte) gcmCount { + // GCM has two modes of operation with respect to the initial counter + // state: a "fast path" for 96-bit (12-byte) nonces, and a "slow path" + // for nonces of other lengths. For a 96-bit nonce, the nonce, along + // with a four-byte big-endian counter starting at one, is used + // directly as the starting counter. For other nonce sizes, the counter + // is computed by passing it through the GHASH function. + var counter gcmCount + if len(nonce) == gcmStandardNonceSize { + copy(counter[:], nonce) + counter[gcmBlockSize-1] = 1 + } else { + var hash [16]byte + g.paddedGHASH(&hash, nonce) + lens := gcmLengths(0, uint64(len(nonce))*8) + g.paddedGHASH(&hash, lens[:]) + copy(counter[:], hash[:]) + } + return counter +} + +// auth calculates GHASH(ciphertext, additionalData), masks the result with +// tagMask and writes the result to out. +func (g *gcmAsm) auth(out, ciphertext, additionalData []byte, tagMask *[gcmTagSize]byte) { + var hash [16]byte + g.paddedGHASH(&hash, additionalData) + g.paddedGHASH(&hash, ciphertext) + lens := gcmLengths(uint64(len(additionalData))*8, uint64(len(ciphertext))*8) + g.paddedGHASH(&hash, lens[:]) + + copy(out, hash[:]) + for i := range out { + out[i] ^= tagMask[i] + } +} + +// Seal encrypts and authenticates plaintext. See the cipher.AEAD interface for +// details. +func (g *gcmAsm) Seal(dst, nonce, plaintext, data []byte) []byte { + if len(nonce) != g.nonceSize { + panic("cipher: incorrect nonce length given to GCM") + } + if uint64(len(plaintext)) > ((1<<32)-2)*BlockSize { + panic("cipher: message too large for GCM") + } + + ret, out := sliceForAppend(dst, len(plaintext)+gcmTagSize) + + counter := g.deriveCounter(nonce) + + var tagMask [gcmBlockSize]byte + g.block.Encrypt(tagMask[:], counter[:]) + counter.inc() + + g.counterCrypt(out, plaintext, &counter) + g.auth(out[len(plaintext):], out[:len(plaintext)], data, &tagMask) + + return ret +} + +// Open authenticates and decrypts ciphertext. See the cipher.AEAD interface +// for details. +func (g *gcmAsm) Open(dst, nonce, ciphertext, data []byte) ([]byte, error) { + if len(nonce) != g.nonceSize { + panic("cipher: incorrect nonce length given to GCM") + } + if len(ciphertext) < gcmTagSize { + return nil, errOpen + } + if uint64(len(ciphertext)) > ((1<<32)-2)*BlockSize+gcmTagSize { + return nil, errOpen + } + + tag := ciphertext[len(ciphertext)-gcmTagSize:] + ciphertext = ciphertext[:len(ciphertext)-gcmTagSize] + + counter := g.deriveCounter(nonce) + + var tagMask [gcmBlockSize]byte + g.block.Encrypt(tagMask[:], counter[:]) + counter.inc() + + var expectedTag [gcmTagSize]byte + g.auth(expectedTag[:], ciphertext, data, &tagMask) + + ret, out := sliceForAppend(dst, len(ciphertext)) + + if subtle.ConstantTimeCompare(expectedTag[:], tag) != 1 { + // The AESNI code decrypts and authenticates concurrently, and + // so overwrites dst in the event of a tag mismatch. That + // behaviour is mimicked here in order to be consistent across + // platforms. + for i := range out { + out[i] = 0 + } + return nil, errOpen + } + + g.counterCrypt(out, ciphertext, &counter) + return ret, nil +} diff --git a/src/crypto/cipher/gcm_test.go b/src/crypto/cipher/gcm_test.go index bb1ab3c0b0..6878b4cb42 100644 --- a/src/crypto/cipher/gcm_test.go +++ b/src/crypto/cipher/gcm_test.go @@ -8,7 +8,11 @@ import ( "bytes" "crypto/aes" "crypto/cipher" + "crypto/rand" "encoding/hex" + "errors" + "io" + "reflect" "testing" ) @@ -274,3 +278,157 @@ func TestTagFailureOverwrite(t *testing.T) { } } } + +func TestGCMCounterWrap(t *testing.T) { + // Test that the last 32-bits of the counter wrap correctly. + tests := []struct { + nonce, tag string + }{ + {"0fa72e25", "37e1948cdfff09fbde0c40ad99fee4a7"}, // counter: 7eb59e4d961dad0dfdd75aaffffffff0 + {"afe05cc1", "438f3aa9fee5e54903b1927bca26bbdf"}, // counter: 75d492a7e6e6bfc979ad3a8ffffffff4 + {"9ffecbef", "7b88ca424df9703e9e8611071ec7e16e"}, // counter: c8bb108b0ecdc71747b9d57ffffffff5 + {"ffc3e5b3", "38d49c86e0abe853ac250e66da54c01a"}, // counter: 706414d2de9b36ab3b900a9ffffffff6 + {"cfdd729d", "e08402eaac36a1a402e09b1bd56500e8"}, // counter: cd0b96fe36b04e750584e56ffffffff7 + {"010ae3d486", "5405bb490b1f95d01e2ba735687154bc"}, // counter: e36c18e69406c49722808104fffffff8 + {"01b1107a9d", "939a585f342e01e17844627492d44dbf"}, // counter: e6d56eaf9127912b6d62c6dcffffffff + } + key, err := aes.NewCipher(make([]byte, 16)) + if err != nil { + t.Fatal(err) + } + plaintext := make([]byte, 16*17+1) + for i, test := range tests { + nonce, _ := hex.DecodeString(test.nonce) + want, _ := hex.DecodeString(test.tag) + aead, err := cipher.NewGCMWithNonceSize(key, len(nonce)) + if err != nil { + t.Fatal(err) + } + got := aead.Seal(nil, nonce, plaintext, nil) + if !bytes.Equal(got[len(plaintext):], want) { + t.Errorf("test[%v]: got: %x, want: %x", i, got[len(plaintext):], want) + } + _, err = aead.Open(nil, nonce, got, nil) + if err != nil { + t.Errorf("test[%v]: authentication failed", i) + } + } +} + +var _ cipher.Block = (*wrapper)(nil) + +type wrapper struct { + block cipher.Block +} + +func (w *wrapper) BlockSize() int { return w.block.BlockSize() } +func (w *wrapper) Encrypt(dst, src []byte) { w.block.Encrypt(dst, src) } +func (w *wrapper) Decrypt(dst, src []byte) { w.block.Decrypt(dst, src) } + +// wrap wraps the Block interface so that it does not fulfill +// any optimizing interfaces such as gcmAble. +func wrap(b cipher.Block) cipher.Block { + return &wrapper{b} +} + +func TestGCMAsm(t *testing.T) { + // Create a new pair of AEADs, one using the assembly implementation + // and one using the generic Go implementation. + newAESGCM := func(key []byte) (asm, generic cipher.AEAD, err error) { + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, nil, err + } + asm, err = cipher.NewGCM(block) + if err != nil { + return nil, nil, err + } + generic, err = cipher.NewGCM(wrap(block)) + if err != nil { + return nil, nil, err + } + return asm, generic, nil + } + + // check for assembly implementation + var key [16]byte + asm, generic, err := newAESGCM(key[:]) + if err != nil { + t.Fatal(err) + } + if reflect.TypeOf(asm) == reflect.TypeOf(generic) { + t.Skipf("no assembly implementation of GCM") + } + + // generate permutations + type pair struct{ align, length int } + lengths := []int{0, 8192, 8193, 8208} + keySizes := []int{16, 24, 32} + alignments := []int{0, 1, 2, 3} + if testing.Short() { + keySizes = []int{16} + alignments = []int{1} + } + perms := make([]pair, 0) + for _, l := range lengths { + for _, a := range alignments { + if a != 0 && l == 0 { + continue + } + perms = append(perms, pair{align: a, length: l}) + } + } + + // run test for all permutations + test := func(ks int, pt, ad []byte) error { + key := make([]byte, ks) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + return err + } + asm, generic, err := newAESGCM(key) + if err != nil { + return err + } + if _, err := io.ReadFull(rand.Reader, pt); err != nil { + return err + } + if _, err := io.ReadFull(rand.Reader, ad); err != nil { + return err + } + nonce := make([]byte, 12) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return err + } + want := generic.Seal(nil, nonce, pt, ad) + got := asm.Seal(nil, nonce, pt, ad) + if !bytes.Equal(want, got) { + return errors.New("incorrect Seal output") + } + got, err = asm.Open(nil, nonce, want, ad) + if err != nil { + return errors.New("authentication failed") + } + if !bytes.Equal(pt, got) { + return errors.New("incorrect Open output") + } + return nil + } + for _, a := range perms { + ad := make([]byte, a.align+a.length) + ad = ad[a.align:] + for _, p := range perms { + pt := make([]byte, p.align+p.length) + pt = pt[p.align:] + for _, ks := range keySizes { + if err := test(ks, pt, ad); err != nil { + t.Error(err) + t.Errorf(" key size: %v", ks) + t.Errorf(" plaintext alignment: %v", p.align) + t.Errorf(" plaintext length: %v", p.length) + t.Errorf(" additionalData alignment: %v", a.align) + t.Fatalf(" additionalData length: %v", a.length) + } + } + } + } +}