]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/cryptotest: add tests for the cipher.Stream interface
authorManuel Sabin <msabin27@gmail.com>
Thu, 27 Jun 2024 19:06:55 +0000 (15:06 -0400)
committerGopher Robot <gobot@golang.org>
Fri, 2 Aug 2024 14:54:04 +0000 (14:54 +0000)
This CL creates tests for the cipher.Stream interface in the new
cryptotest package.  This set of tests is called from the tests of
implementations of the Stream interface e.g. ctr_test.go, ofb_test.go,
rc4_test.go, etc.

Updates #25309

Change-Id: I57204ef9f4c0ec09b94e88466deb03c6715e411d
Reviewed-on: https://go-review.googlesource.com/c/go/+/595564
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Russell Webb <russell.webb@protonmail.com>
src/crypto/cipher/cfb_test.go
src/crypto/cipher/ctr_test.go
src/crypto/cipher/ofb_test.go
src/crypto/internal/cryptotest/stream.go [new file with mode: 0644]
src/crypto/rc4/rc4_test.go

index 72f62e69d3d39cb9a6c40df6f110ff8a27774f07..67033d9a3bba1a50f2afaf7a393d000a74e0e94b 100644 (file)
@@ -8,8 +8,11 @@ import (
        "bytes"
        "crypto/aes"
        "crypto/cipher"
+       "crypto/des"
+       "crypto/internal/cryptotest"
        "crypto/rand"
        "encoding/hex"
+       "fmt"
        "testing"
 )
 
@@ -111,3 +114,47 @@ func TestCFBInverse(t *testing.T) {
                t.Errorf("got: %x, want: %x", plaintextCopy, plaintext)
        }
 }
+
+func TestCFBStream(t *testing.T) {
+
+       for _, keylen := range []int{128, 192, 256} {
+
+               t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) {
+                       rng := newRandReader(t)
+
+                       key := make([]byte, keylen/8)
+                       rng.Read(key)
+
+                       block, err := aes.NewCipher(key)
+                       if err != nil {
+                               panic(err)
+                       }
+
+                       t.Run("Encrypter", func(t *testing.T) {
+                               cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBEncrypter)
+                       })
+                       t.Run("Decrypter", func(t *testing.T) {
+                               cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBDecrypter)
+                       })
+               })
+       }
+
+       t.Run("DES", func(t *testing.T) {
+               rng := newRandReader(t)
+
+               key := make([]byte, 8)
+               rng.Read(key)
+
+               block, err := des.NewCipher(key)
+               if err != nil {
+                       panic(err)
+               }
+
+               t.Run("Encrypter", func(t *testing.T) {
+                       cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBEncrypter)
+               })
+               t.Run("Decrypter", func(t *testing.T) {
+                       cryptotest.TestStreamFromBlock(t, block, cipher.NewCFBDecrypter)
+               })
+       })
+}
index e5cce576c79e260750fb47eac6f75b38943a4d95..4bb9deab809658dfc3d5b57cd33e1724e6820aa9 100644 (file)
@@ -6,7 +6,11 @@ package cipher_test
 
 import (
        "bytes"
+       "crypto/aes"
        "crypto/cipher"
+       "crypto/des"
+       "crypto/internal/cryptotest"
+       "fmt"
        "testing"
 )
 
@@ -53,3 +57,37 @@ func TestCTR(t *testing.T) {
                }
        }
 }
+
+func TestCTRStream(t *testing.T) {
+
+       for _, keylen := range []int{128, 192, 256} {
+
+               t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) {
+                       rng := newRandReader(t)
+
+                       key := make([]byte, keylen/8)
+                       rng.Read(key)
+
+                       block, err := aes.NewCipher(key)
+                       if err != nil {
+                               panic(err)
+                       }
+
+                       cryptotest.TestStreamFromBlock(t, block, cipher.NewCTR)
+               })
+       }
+
+       t.Run("DES", func(t *testing.T) {
+               rng := newRandReader(t)
+
+               key := make([]byte, 8)
+               rng.Read(key)
+
+               block, err := des.NewCipher(key)
+               if err != nil {
+                       panic(err)
+               }
+
+               cryptotest.TestStreamFromBlock(t, block, cipher.NewCTR)
+       })
+}
index 8d3c5d3a3899bdb0510deb5147fda2fd2c48c439..036b76c45c620ac12707e0450ccd621af409d77e 100644 (file)
@@ -14,6 +14,9 @@ import (
        "bytes"
        "crypto/aes"
        "crypto/cipher"
+       "crypto/des"
+       "crypto/internal/cryptotest"
+       "fmt"
        "testing"
 )
 
@@ -100,3 +103,37 @@ func TestOFB(t *testing.T) {
                }
        }
 }
+
+func TestOFBStream(t *testing.T) {
+
+       for _, keylen := range []int{128, 192, 256} {
+
+               t.Run(fmt.Sprintf("AES-%d", keylen), func(t *testing.T) {
+                       rng := newRandReader(t)
+
+                       key := make([]byte, keylen/8)
+                       rng.Read(key)
+
+                       block, err := aes.NewCipher(key)
+                       if err != nil {
+                               panic(err)
+                       }
+
+                       cryptotest.TestStreamFromBlock(t, block, cipher.NewOFB)
+               })
+       }
+
+       t.Run("DES", func(t *testing.T) {
+               rng := newRandReader(t)
+
+               key := make([]byte, 8)
+               rng.Read(key)
+
+               block, err := des.NewCipher(key)
+               if err != nil {
+                       panic(err)
+               }
+
+               cryptotest.TestStreamFromBlock(t, block, cipher.NewOFB)
+       })
+}
diff --git a/src/crypto/internal/cryptotest/stream.go b/src/crypto/internal/cryptotest/stream.go
new file mode 100644 (file)
index 0000000..fb9c553
--- /dev/null
@@ -0,0 +1,241 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package cryptotest
+
+import (
+       "bytes"
+       "crypto/cipher"
+       "crypto/subtle"
+       "fmt"
+       "strings"
+       "testing"
+)
+
+// Each test is executed with each of the buffer lengths in bufLens.
+var (
+       bufLens = []int{0, 1, 3, 4, 8, 10, 15, 16, 20, 32, 50, 4096, 5000}
+       bufCap  = 10000
+)
+
+// MakeStream returns a cipher.Stream instance.
+//
+// Multiple calls to MakeStream must return equivalent instances,
+// so for example the key and/or IV must be fixed.
+type MakeStream func() cipher.Stream
+
+// TestStream performs a set of tests on cipher.Stream implementations,
+// checking the documented requirements of XORKeyStream.
+func TestStream(t *testing.T, ms MakeStream) {
+
+       t.Run("XORSemantics", func(t *testing.T) {
+               if strings.Contains(t.Name(), "TestCFBStream") {
+                       // This is ugly, but so is CFB's abuse of cipher.Stream.
+                       // Don't want to make it easier for anyone else to do that.
+                       t.Skip("CFB implements cipher.Stream but does not follow XOR semantics")
+               }
+
+               // Test that XORKeyStream inverts itself for encryption/decryption.
+               t.Run("Roundtrip", func(t *testing.T) {
+
+                       for _, length := range bufLens {
+                               t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
+                                       rng := newRandReader(t)
+
+                                       plaintext := make([]byte, length)
+                                       rng.Read(plaintext)
+
+                                       ciphertext := make([]byte, length)
+                                       decrypted := make([]byte, length)
+
+                                       ms().XORKeyStream(ciphertext, plaintext) // Encrypt plaintext
+                                       ms().XORKeyStream(decrypted, ciphertext) // Decrypt ciphertext
+                                       if !bytes.Equal(decrypted, plaintext) {
+                                               t.Errorf("plaintext is different after an encrypt/decrypt cycle; got %s, want %s", truncateHex(decrypted), truncateHex(plaintext))
+                                       }
+                               })
+                       }
+               })
+
+               // Test that XORKeyStream behaves the same as directly XORing
+               // plaintext with the stream.
+               t.Run("DirectXOR", func(t *testing.T) {
+
+                       for _, length := range bufLens {
+                               t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
+                                       rng := newRandReader(t)
+
+                                       plaintext := make([]byte, length)
+                                       rng.Read(plaintext)
+
+                                       // Encrypting all zeros should reveal the stream itself
+                                       stream, directXOR := make([]byte, length), make([]byte, length)
+                                       ms().XORKeyStream(stream, stream)
+                                       // Encrypt plaintext by directly XORing the stream
+                                       subtle.XORBytes(directXOR, stream, plaintext)
+
+                                       // Encrypt plaintext with XORKeyStream
+                                       ciphertext := make([]byte, length)
+                                       ms().XORKeyStream(ciphertext, plaintext)
+                                       if !bytes.Equal(ciphertext, directXOR) {
+                                               t.Errorf("xor semantics were not preserved; got %s, want %s", truncateHex(ciphertext), truncateHex(directXOR))
+                                       }
+                               })
+                       }
+               })
+       })
+
+       t.Run("AlterInput", func(t *testing.T) {
+               rng := newRandReader(t)
+               src, dst, before := make([]byte, bufCap), make([]byte, bufCap), make([]byte, bufCap)
+               rng.Read(src)
+
+               for _, length := range bufLens {
+
+                       t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
+                               copy(before, src)
+
+                               ms().XORKeyStream(dst[:length], src[:length])
+                               if !bytes.Equal(src, before) {
+                                       t.Errorf("XORKeyStream modified src; got %s, want %s", truncateHex(src), truncateHex(before))
+                               }
+                       })
+               }
+       })
+
+       t.Run("Aliasing", func(t *testing.T) {
+               rng := newRandReader(t)
+
+               buff, expectedOutput := make([]byte, bufCap), make([]byte, bufCap)
+
+               for _, length := range bufLens {
+                       // Record what output is when src and dst are different
+                       rng.Read(buff)
+                       ms().XORKeyStream(expectedOutput[:length], buff[:length])
+
+                       // Check that the same output is generated when src=dst alias to the same
+                       // memory
+                       ms().XORKeyStream(buff[:length], buff[:length])
+                       if !bytes.Equal(buff[:length], expectedOutput[:length]) {
+                               t.Errorf("block cipher produced different output when dst = src; got %x, want %x", buff[:length], expectedOutput[:length])
+                       }
+               }
+       })
+
+       t.Run("OutOfBoundsWrite", func(t *testing.T) { // Issue 21104
+               rng := newRandReader(t)
+
+               plaintext := make([]byte, bufCap)
+               rng.Read(plaintext)
+               ciphertext := make([]byte, bufCap)
+
+               for _, length := range bufLens {
+                       copy(ciphertext, plaintext) // Reset ciphertext buffer
+
+                       t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
+                               mustPanic(t, "output smaller than input", func() { ms().XORKeyStream(ciphertext[:length], plaintext) })
+
+                               if !bytes.Equal(ciphertext[length:], plaintext[length:]) {
+                                       t.Errorf("XORKeyStream did out of bounds write; got %s, want %s", truncateHex(ciphertext[length:]), truncateHex(plaintext[length:]))
+                               }
+                       })
+               }
+       })
+
+       t.Run("BufferOverlap", func(t *testing.T) {
+               rng := newRandReader(t)
+
+               buff := make([]byte, bufCap)
+               rng.Read(buff)
+
+               for _, length := range bufLens {
+                       if length == 0 || length == 1 {
+                               continue
+                       }
+
+                       t.Run(fmt.Sprintf("BuffLength=%d", length), func(t *testing.T) {
+                               // Make src and dst slices point to same array with inexact overlap
+                               src := buff[:length]
+                               dst := buff[1 : length+1]
+                               mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
+
+                               // Only overlap on one byte
+                               src = buff[:length]
+                               dst = buff[length-1 : 2*length-1]
+                               mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
+
+                               // src comes after dst with one byte overlap
+                               src = buff[length-1 : 2*length-1]
+                               dst = buff[:length]
+                               mustPanic(t, "invalid buffer overlap", func() { ms().XORKeyStream(dst, src) })
+                       })
+               }
+       })
+
+       t.Run("KeepState", func(t *testing.T) {
+               rng := newRandReader(t)
+
+               plaintext := make([]byte, bufCap)
+               rng.Read(plaintext)
+               ciphertext := make([]byte, bufCap)
+
+               // Make one long call to XORKeyStream
+               ms().XORKeyStream(ciphertext, plaintext)
+
+               for _, step := range bufLens {
+                       if step == 0 {
+                               continue
+                       }
+                       stepMsg := fmt.Sprintf("step %d: ", step)
+
+                       dst := make([]byte, bufCap)
+
+                       // Make a bunch of small calls to (stateful) XORKeyStream
+                       stream := ms()
+                       i := 0
+                       for i+step < len(plaintext) {
+                               stream.XORKeyStream(dst[i:], plaintext[i:i+step])
+                               i += step
+                       }
+                       stream.XORKeyStream(dst[i:], plaintext[i:])
+
+                       if !bytes.Equal(dst, ciphertext) {
+                               t.Errorf(stepMsg+"successive XORKeyStream calls returned a different result than a single one; got %s, want %s", truncateHex(dst), truncateHex(ciphertext))
+                       }
+               }
+       })
+}
+
+// TestStreamFromBlock creates a Stream from a cipher.Block used in a
+// cipher.BlockMode. It addresses Issue 68377 by checking for a panic when the
+// BlockMode uses an IV with incorrect length.
+// For a valid IV, it also runs all TestStream tests on the resulting stream.
+func TestStreamFromBlock(t *testing.T, block cipher.Block, blockMode func(b cipher.Block, iv []byte) cipher.Stream) {
+
+       t.Run("WrongIVLen", func(t *testing.T) {
+               t.Skip("see Issue 68377")
+
+               rng := newRandReader(t)
+               iv := make([]byte, block.BlockSize()+1)
+               rng.Read(iv)
+               mustPanic(t, "IV length must equal block size", func() { blockMode(block, iv) })
+       })
+
+       t.Run("BlockModeStream", func(t *testing.T) {
+               rng := newRandReader(t)
+               iv := make([]byte, block.BlockSize())
+               rng.Read(iv)
+
+               TestStream(t, func() cipher.Stream { return blockMode(block, iv) })
+       })
+}
+
+func truncateHex(b []byte) string {
+       numVals := 50
+
+       if len(b) <= numVals {
+               return fmt.Sprintf("%x", b)
+       }
+       return fmt.Sprintf("%x...", b[:numVals])
+}
index e7356aa45de1b2b75c48720bea9b7345bc9a84d0..f092f4f26591e6cf887a7d3dbd4d8432a48713d9 100644 (file)
@@ -6,6 +6,8 @@ package rc4
 
 import (
        "bytes"
+       "crypto/cipher"
+       "crypto/internal/cryptotest"
        "fmt"
        "testing"
 )
@@ -136,6 +138,13 @@ func TestBlock(t *testing.T) {
        }
 }
 
+func TestRC4Stream(t *testing.T) {
+       cryptotest.TestStream(t, func() cipher.Stream {
+               c, _ := NewCipher(golden[0].key)
+               return c
+       })
+}
+
 func benchmark(b *testing.B, size int64) {
        buf := make([]byte, size)
        c, err := NewCipher(golden[0].key)