]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rand: properly handle large Read on windows
authorRoland Shoemaker <roland@golang.org>
Tue, 26 Apr 2022 02:02:35 +0000 (19:02 -0700)
committerRoland Shoemaker <roland@golang.org>
Thu, 5 May 2022 22:44:02 +0000 (22:44 +0000)
Use the batched reader to chunk large Read calls on windows to a max of
1 << 31 - 1 bytes. This prevents an infinite loop when trying to read
more than 1 << 32 -1 bytes, due to how RtlGenRandom works.

This change moves the batched function from rand_unix.go to rand.go,
since it is now needed for both windows and unix implementations.

Fixes #52561

Change-Id: Id98fc4b1427e5cb2132762a445b2aed646a37473
Reviewed-on: https://go-review.googlesource.com/c/go/+/402257
Run-TryBot: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Filippo Valsorda <valsorda@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/crypto/rand/rand.go
src/crypto/rand/rand_batched_test.go
src/crypto/rand/rand_unix.go
src/crypto/rand/rand_windows.go

index b6248a443874df146ac71757fbdeed19660ed66c..af85b966df6e43fe35964b3b79c61fb415ab4e46 100644 (file)
@@ -24,3 +24,21 @@ var Reader io.Reader
 func Read(b []byte) (n int, err error) {
        return io.ReadFull(Reader, b)
 }
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+       return func(out []byte) error {
+               for len(out) > 0 {
+                       read := len(out)
+                       if read > readMax {
+                               read = readMax
+                       }
+                       if err := f(out[:read]); err != nil {
+                               return err
+                       }
+                       out = out[read:]
+               }
+               return nil
+       }
+}
index dfb9517d5e0a26ff101d1ba282b626f188fc7aca..89953776a89ac69e718b0b62c79b6a58d98af1cf 100644 (file)
@@ -23,8 +23,8 @@ func TestBatched(t *testing.T) {
        }, 5)
 
        p := make([]byte, 13)
-       if !fillBatched(p) {
-               t.Fatal("batched function returned false")
+       if err := fillBatched(p); err != nil {
+               t.Fatalf("batched function returned error: %s", err)
        }
        expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
        if !bytes.Equal(expected, p) {
@@ -55,8 +55,8 @@ func TestBatchedBuffering(t *testing.T) {
                        max = len(outputMarker)
                }
                howMuch := prand.Intn(max + 1)
-               if !fillBatched(outputMarker[:howMuch]) {
-                       t.Fatal("batched function returned false")
+               if err := fillBatched(outputMarker[:howMuch]); err != nil {
+                       t.Fatalf("batched function returned error: %s", err)
                }
                outputMarker = outputMarker[howMuch:]
        }
@@ -67,14 +67,14 @@ func TestBatchedBuffering(t *testing.T) {
 
 func TestBatchedError(t *testing.T) {
        b := batched(func(p []byte) error { return errors.New("failure") }, 5)
-       if b(make([]byte, 13)) {
+       if b(make([]byte, 13)) == nil {
                t.Fatal("batched function should have returned an error")
        }
 }
 
 func TestBatchedEmpty(t *testing.T) {
        b := batched(func(p []byte) error { return errors.New("failure") }, 5)
-       if !b(make([]byte, 0)) {
+       if b(make([]byte, 0)) != nil {
                t.Fatal("empty slice should always return successful")
        }
 }
index 87ba9e3af7aaa8bf601fab3f997d54e5beef794b..64b865289d99fba7673a82e8bfb1dd70795ae019 100644 (file)
@@ -40,25 +40,7 @@ type reader struct {
 
 // altGetRandom if non-nil specifies an OS-specific function to get
 // urandom-style randomness.
-var altGetRandom func([]byte) (ok bool)
-
-// batched returns a function that calls f to populate a []byte by chunking it
-// into subslices of, at most, readMax bytes.
-func batched(f func([]byte) error, readMax int) func([]byte) bool {
-       return func(out []byte) bool {
-               for len(out) > 0 {
-                       read := len(out)
-                       if read > readMax {
-                               read = readMax
-                       }
-                       if f(out[:read]) != nil {
-                               return false
-                       }
-                       out = out[read:]
-               }
-               return true
-       }
-}
+var altGetRandom func([]byte) (err error)
 
 func warnBlocked() {
        println("crypto/rand: blocked for 60 seconds waiting to read random data from the kernel")
@@ -72,7 +54,7 @@ func (r *reader) Read(b []byte) (n int, err error) {
                t := time.AfterFunc(time.Minute, warnBlocked)
                defer t.Stop()
        }
-       if altGetRandom != nil && altGetRandom(b) {
+       if altGetRandom != nil && altGetRandom(b) == nil {
                return len(b), nil
        }
        if atomic.LoadUint32(&r.used) != 2 {
index 7379f1489adbf13bb41598f56c745b9e03fccd6d..6c0655c72b692ae444a85025ecc3d27d65107404 100644 (file)
@@ -9,7 +9,6 @@ package rand
 
 import (
        "internal/syscall/windows"
-       "os"
 )
 
 func init() { Reader = &rngReader{} }
@@ -17,16 +16,11 @@ func init() { Reader = &rngReader{} }
 type rngReader struct{}
 
 func (r *rngReader) Read(b []byte) (n int, err error) {
-       // RtlGenRandom only accepts 2**32-1 bytes at a time, so truncate.
-       inputLen := uint32(len(b))
-
-       if inputLen == 0 {
-               return 0, nil
-       }
-
-       err = windows.RtlGenRandom(b)
-       if err != nil {
-               return 0, os.NewSyscallError("RtlGenRandom", err)
+       // RtlGenRandom only returns 1<<32-1 bytes at a time. We only read at
+       // most 1<<31-1 bytes at a time so that  this works the same on 32-bit
+       // and 64-bit systems.
+       if err := batched(windows.RtlGenRandom, 1<<31-1)(b); err != nil {
+               return 0, err
        }
-       return int(inputLen), nil
+       return len(b), nil
 }